On the architectural divide between parsing and tracing kernel DSLs, and what tends to go wrong in each.
CUDA C++ was long the only practical language for writing NVIDIA GPU kernels, but since Triton appeared, a wave of Pythonic DSLs has followed: CuTe-DSL, cuTile, Pallas, Gluon, Warp, and the more recent TileLang used in DeepSeek’s DeepGEMM. Most of these systems are embedded in Python and share the same goal: lowering a tile-oriented program into PTX or LLVM-IR.
The question is how to embed the DSL into Python. Triton and CuTe-DSL parse the source AST. Pallas runs the function under abstract values and traces the resulting operations. (PyTorch’s torch.compile intercepts CPython bytecode rather than source, but that is still parsing, just against a smaller, post-desugared grammar; the same trade-offs apply.)
Most DSLs follow Triton’s lead and use parsing. This essay takes the alternative and argues that a tracing-based approach is often preferable.
CUDA and Templates
A CUDA kernel directly specifies the execution code for each thread. A textbook fused-softmax kernel in CUDA looks roughly like this:
template <typename T, int BLOCK_SIZE>
__global__ void softmax_kernel(const T* __restrict__ x,
                               T* __restrict__ y,
                               int n_cols) {
    int row = blockIdx.x;
    int tid = threadIdx.x;
    __shared__ float sdata[BLOCK_SIZE];
    const T* row_ptr = x + row * n_cols;
    float local_max = -INFINITY;
    for (int i = tid; i < n_cols; i += BLOCK_SIZE)
        local_max = fmaxf(local_max, float(row_ptr[i]));
    sdata[tid] = local_max;
    __syncthreads();
    // ... tree reduction, exp, normalize, store ...
}
The element type T and the block size BLOCK_SIZE must be known at compile time, as __shared__ memory is statically sized, and the compiler must specialise loop bounds to enable vectorisation of the body. Hence any expansion of the supported configuration space multiplies the number of instantiations. Three element types and four block sizes already imply twelve instantiations, and the responsibility for dispatching among them rests with the caller.
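The multiplication described above can be sketched in a few lines of Python. This is only an illustration of the bookkeeping that templates push onto the caller: the three dtype names, the four block sizes, and the helper names (`DISPATCH`, `pick_block_size`, `dispatch`) are assumptions made for the example, not part of any real library.

```python
from itertools import product

# Hypothetical configuration space; the concrete values are illustrative only.
DTYPES = ("float", "__half", "__nv_bfloat16")
BLOCK_SIZES = (128, 256, 512, 1024)


def instantiation_name(dtype: str, block_size: int) -> str:
    """Spell the C++ template instantiation a caller would have to request."""
    return f"softmax_kernel<{dtype}, {block_size}>"


# One precompiled entry per (dtype, block size) pair: 3 * 4 = 12 entries.
DISPATCH = {
    (dtype, block): instantiation_name(dtype, block)
    for dtype, block in product(DTYPES, BLOCK_SIZES)
}


def pick_block_size(n_cols: int) -> int:
    """Smallest supported block size that covers n_cols, capped at the largest."""
    for block in BLOCK_SIZES:
        if block >= n_cols:
            return block
    return BLOCK_SIZES[-1]


def dispatch(dtype: str, n_cols: int) -> str:
    """The caller-side lookup that a templated kernel forces into host code."""
    return DISPATCH[(dtype, pick_block_size(n_cols))]
```

Adding a fourth dtype or a fifth block size grows the table multiplicatively, and every entry corresponds to a separately compiled kernel.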
Keep adding template parameters and generalisations to a CUDA kernel, and one eventually arrives at something that looks like CUTLASS.
CUTLASS: Building Blocks for CUDA Kernels
CUTLASS is what results when C++ template metaprogramming is adopted wholesale as a way of writing GPU kernels. Consider the declaration of its principal Gemm class, the entry point most users first encounter, from include/cutlass/gemm/device/gemm.h:
template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator_ = ElementC_,
    /// Operator class tag
    typename OperatorClass_ = arch::OpClassSimt,
    /// Tag indicating architecture to tune for
    typename ArchTag_ = arch::Sm70,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::WarpShape,
    // ... ten more parameters elided ...
    bool ScatterD = false,
    typename PermuteDLayout = layout::NoPermute>
class Gemm { /* ... */ };
cutlass/gemm/device/gemm.h, lines 169–233. Around twenty template parameters, several with defaults that recursively look up DefaultGemmConfiguration<...>.
A fragment of the canonical Hopper warp-specialized GEMM example shows how a user composes a kernel from nested CollectiveBuilders, each a template that pulls in dozens of further instantiations:
using namespace cute;
using TileShape = Shape<_128,_128,_32>;  // CTA tile
using ClusterShape = Shape<_4,_2,_1>;    // cluster of CTAs
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, LayoutC, AlignmentC,
    ElementC, LayoutC, AlignmentC,
    cutlass::epilogue::collective::EpilogueScheduleAuto
  >::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    ElementA, LayoutA, AlignmentA,
    ElementB, LayoutB, AlignmentB,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    cutlass::gemm::collective::KernelScheduleAuto
  >::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int>, CollectiveMainloop, CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu, lines 83–142.
Shape<_128,_128,_32> denotes a type rather than a value, and the compiler must instantiate every dependent template once per distinct shape. The result is long compile times.
Compile Time: The Cost of Templates
We compiled 48_hopper_warp_specialized_gemm, a single CuTe-based GEMM file of roughly 500 lines, with -c and no benchmark harness, inside an nvidia/cuda:12.5.0-devel-ubuntu22.04 container, invoking nvcc -std=c++17 -arch=sm_90a. The steady-state nvcc time, averaged over two warm runs on a consumer laptop, was:
~20.5 s to compile a single kernel for a single architecture
A full CUTLASS build targets several architectures, and the cost multiplies accordingly. NVIDIA’s own bug tracker records 17m22s for two Ampere i16832gemm_s8 kernels (issue #1042) and approximately two minutes for a 30-line CuTe-DSL kernel (issue #2677). The NVIDIA developer blog post introducing the CuTe Python DSL in November 2025 frames its principal contribution as “up to two orders of magnitude reduced” compile times relative to C++ CUTLASS, with a quoted “~100x compilation speedup” for Blackwell GEMM and “30-50x” for flash attention.
C++ templates compile too slowly for the iteration speed kernel authors need.
Triton: DSL Embedded Into Python
Using Triton is straightforward: decorate a Python function with @triton.jit, mark the compile-time parameters with tl.constexpr, and write the kernel body in something close to NumPy. Triton also simplifies the programming model, focusing on the tile a single thread block operates on, rather than on the code for an individual thread.
@triton.jit
def softmax_kernel(output_ptr, input_ptr,
                   input_row_stride, output_row_stride,
                   n_rows, n_cols,
                   BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        row_start_ptr = input_ptr + row_idx * input_row_stride
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        row_minus_max = row - tl.max(row, axis=0)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_out = numerator / denominator
        output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
        tl.store(output_ptrs, softmax_out, mask=mask)
triton/python/tutorials/02-fused-softmax.py, lines 84–109.
Triton is a pleasure to use when it works: the program looks straightforward and does what one would expect. Integration into PyTorch is first-class, there is no build system to set up, and it is relatively hard to construct a malformed program that triggers a crash. However, when one wants to write a reusable generic set of libraries in Triton, things get tricky.
Parsing Limitations
Suppose we want to expose a fused matmul whose epilogue (activation, scaling, or fusion applied to the accumulator after the inner product) can be supplied by the caller as a Python callable:
matmul(A, B, activation=lambda x: tl.where(x > 0, x, 0))
In a tracing-based framework this is one line: the user hands in a Python callable, and the trace records whatever operations the callable performs. In Triton this is tricky to achieve, and almost every other limitation in this section is a variation of the same underlying constraint.
What the Ecosystem Actually Does
Before we look at why, it helps to see how production kernel libraries handle this in practice. Across Liger-Kernel, FlagGems, Quack, and DeepGEMM, the answer is consistent: enumerate variants statically, dispatch by enum or string tag, never accept a runtime callable. Quack’s GEMM signature is representative:
def gemm_act_tuned(
    A, B, preact_out, postact_out,
    C=None, bias=None,
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    ...
):
quack/quack/gemm_interface.py, line 158. Gated variants are similarly enumerated: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"].
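The enumerate-and-dispatch pattern can be sketched in plain Python. The registry and the `matmul_act` function below are hypothetical stand-ins written for this illustration (they are not Quack's implementation, and the matmul is a naive list-of-lists loop); the point is only the shape of the API: a closed set of string tags, and no way for the caller to supply a new epilogue.

```python
import math
from typing import Callable, Dict, List, Optional

Matrix = List[List[float]]

# Hypothetical registry: the closed set of epilogues this library supports.
ACTIVATIONS: Dict[Optional[str], Callable[[float], float]] = {
    None: lambda x: x,
    "relu": lambda x: max(x, 0.0),
    "relu_sq": lambda x: max(x, 0.0) ** 2,
    "gelu_tanh_approx": lambda x: 0.5 * x * (
        1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3))
    ),
}


def matmul_act(a: Matrix, b: Matrix, activation: Optional[str] = None) -> Matrix:
    """Naive matmul with an epilogue chosen by string tag, never by callable."""
    if activation not in ACTIVATIONS:
        raise ValueError(f"unsupported activation: {activation!r}")
    act = ACTIVATIONS[activation]
    columns = list(zip(*b))  # transpose b once so each column is a tuple
    return [
        [act(sum(x * y for x, y in zip(row, col))) for col in columns]
        for row in a
    ]
```

A caller who wants an activation outside the registry has to change the library, which is the constraint the rest of this section traces back to the parsing front end.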
Why Lambdas Fail
If we try to pass a callable in Triton (say, matmul(A, B, activation=lambda x: tl.where(x > 0, x, 0))), it fails at compile time:
CompilationError: at 4:25: ... ACTIVATION(x)
ValueError("Did you forget to add @triton.jit ? (`_semantic`
argument must be provided outside of JIT functions.)")
Following the error’s advice and wrapping the lambda in triton.jit fails earlier, during construction of the JITFunction object:
  File "triton/runtime/jit.py", line 1110, in get_def_col_number
    raise ValueError("No function definition found for kernel")
The constructor calls inspect.getsourcelines(fn) and expects the returned source to contain a def name( line, matched via a module-level regex at jit.py:27. It uses the def to compute indentation and dedent the body before handing it to the AST walker; a lambda’s source line contains no def, and the construction step fails. The supported workaround is to lift the activation into a named @triton.jit def and pass that through the constexpr argument, which does work. But that is exactly what an enum entry already refers to, and the ecosystem’s convergence on enums reflects this constraint: once the supported shape is “any caller-defined @triton.jit-decorated function passed by name,” the library author may as well enumerate the small set they intend to support, autotune over it, and present a string-typed API.
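The failure mode is easy to reproduce without Triton. The sketch below scans source text for a line that opens a function definition; the regular expression and the `find_def_line` helper are assumptions made for illustration and are not copied from Triton's jit.py. The two source strings stand in for what a source-retrieval call would hand back for a named function and for a lambda, respectively.

```python
import re
import textwrap

# Assumed pattern for "a line that opens a function definition".
DEF_LINE = re.compile(r"^\s*(?:async\s+)?def\s+\w+\s*\(")


def find_def_line(source: str) -> int:
    """Return the index of the first `def name(` line, or raise if none exists."""
    for lineno, line in enumerate(source.splitlines()):
        if DEF_LINE.match(line):
            return lineno
    raise ValueError("No function definition found for kernel")


# Source text of a named, decorated function.
NAMED_SOURCE = textwrap.dedent(
    """
    @jit
    def relu(x):
        return where(x > 0, x, 0)
    """
).strip("\n")

# Source text for a lambda: the enclosing statement, with no `def` anywhere.
LAMBDA_SOURCE = "matmul(A, B, activation=lambda x: where(x > 0, x, 0))"
```

`find_def_line(NAMED_SOURCE)` finds the definition on the line after the decorator, while `find_def_line(LAMBDA_SOURCE)` raises, because a front end keyed on the `def` keyword has nothing to anchor on.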
Closures
A factory-style helper that captures configuration in an enclosing scope is a common pattern in Python:
def make_kernel(scale):
    @triton.jit
    def k(...):
        ...  # uses `scale` from the enclosing scope
    return k
In Triton this fails unconditionally:
NameError: Cannot access global variable scale from within
@jit'ed function. Triton kernels can only access global variables
that are instanstiated as constexpr
(`x = triton.language.constexpr(42)`).
Triton does not implement closure capture at all; every free name in a kernel body is resolved against module globals and must already be a tl.constexpr.
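The gap between what the interpreter knows and what a source-only front end sees can be shown with the standard library alone. In this sketch (plain Python, no Triton; `KERNEL_SOURCE` is a hand-written copy of the inner function's text), the running function object carries the captured value in a closure cell, whereas the parsed source carries only the bare name.

```python
import ast


def make_kernel(scale):
    def k(x):
        return x * scale  # `scale` is a free variable of k

    return k


k = make_kernel(3)

# What the interpreter knows: free-variable names and their captured cells.
free_names = k.__code__.co_freevars
captured = {
    name: cell.cell_contents for name, cell in zip(free_names, k.__closure__)
}

# What a source-only front end sees: the text of k, with no enclosing frame.
KERNEL_SOURCE = "def k(x):\n    return x * scale\n"
func = ast.parse(KERNEL_SOURCE).body[0]
params = {a.arg for a in func.args.args}
loaded = {n.id for n in ast.walk(func) if isinstance(n, ast.Name)}
unresolved = sorted(loaded - params)  # names the AST alone cannot bind
```

Here `captured` maps `scale` to 3, while `unresolved` lists `scale` as a name with no value attached. A tracing front end calls `k` and never has to ask; a parsing front end must either reimplement cell lookup or, as Triton does, reject the name.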
Higher-Order Primitives
The same constraint shows up most sharply in the region-builder APIs of tl.reduce and tl.associative_scan: even a properly @triton.jit-decorated combine function cannot be passed through a kernel argument at all. The combine must be resolved lexically in the kernel’s enclosing scope, and the standard library copes by shipping one named combine per parameter combination: _argmax_combine_tie_break_left and _argmax_combine_tie_break_fast exist as two file-scoped functions because the boolean argument they differ in cannot be threaded in at call time (triton/python/triton/language/standard.py:158-165). This is arguably more a limitation of the specific region-builder API than of the parsing approach as such; one could imagine a Triton in which tl.reduce accepted callable arguments. But the analogous JAX primitive (jax.lax.associative_scan(lambda x, y: x + y, xs)) does accept a lambda.
Aren’t Those Fixable?
Each issue described above is, in principle, fixable. But the picture stays the same: supporting metaprogramming in Triton means implementing more and more of Python inside the AST walker.
CuTe-DSL: The Same Pattern, Sharper Edges
While writing CuTe-DSL kernels we encountered two such failures and filed bug reports against both. Each illustrates the pattern:
Two CuTe-DSL bugs, both from a parsing frontend
- cutlass#3268: storage.<field>.get_tensor(...) works at the top level of a kernel but fails with “encountered a user-defined Python object” when placed inside a runtime if block. The shared-storage tensor lookup is a Python object that cannot be carried across the boundary of the lowered scf.if region; the surface language does not signal this in any way, so semantically identical code succeeds or fails depending on where it sits in the source. The lowering treats values in “Python” blocks and values in “IR” blocks differently, and the rules for what survives the transition are not part of the user-facing language.
- cutlass#3266: nvvm.load.ext rejects BFloat16 (“Unsupported FP type for ExtLoadOp”) even though bf16 is a first-class element type elsewhere in the DSL. Each lowering path re-enumerates the dtypes it knows about, and users reach the cliff every time a new combination of (op, dtype, layout) has not been wired up.
Let’s see how parsing actually works, and how it imposes semantic limitations on the expressiveness of the DSL.
Under the Hood: Where the Limits Come From
Triton compiles by parsing the body of the decorated function into a Python AST and walking it to emit MLIR. The walker lives in python/triton/compiler/code_generator.py, a subclass of ast.NodeVisitor. Consider visitor methods used to process control flow:
def visit_If(self, node):
    cond = self.visit(node.test)
    if _is_triton_tensor(cond):
        # runtime branch: emit scf.if in MLIR
        cond = cond.to(language.int1, _semantic=self.semantic)
        ...
        self.visit_if_scf(cond, node)
    else:
        # compile-time branch: take it now, discard the other one.
        cond = _unwrap_if_constexpr(cond)
        active_block = node.body if cond else node.orelse
        self.visit_compound_statement(active_block)
code_generator.py, lines 957–986.
A single Python if denotes two distinct constructs depending on the type of its condition. When cond is a Triton tensor, the walker emits an scf.if region and both branches are generated. However, when cond is a constexpr, the walker evaluates the condition at compile time, selects one branch, and never visits the other. The same duality governs for loops: tl.static_range(8) unrolls the body into eight copies at parse time, whereas tl.range(8) produces a single MLIR scf.for. Both are written as an ordinary Python for statement and differ only in the iterator they name, yet they denote entirely different constructs. For if there is not even that much: the reader has to infer from the type of the condition which semantics is in force at each site, and the syntax offers no way to force one or the other.
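The difference between the two loop forms can be made concrete with a toy emitter. Everything here is hypothetical: the op strings only imitate the look of MLIR, and `emit_static_range` and `emit_runtime_range` are illustrative names rather than Triton internals. The sketch shows that the same loop body yields eight straight-line ops in one case and a single loop op in the other.

```python
from typing import Callable, List

Body = Callable[[str], List[str]]


def emit_static_range(n: int, body: Body) -> List[str]:
    """Compile-time loop: run the body n times, each with a literal index."""
    ops: List[str] = []
    for i in range(n):
        ops.extend(body(str(i)))
    return ops


def emit_runtime_range(n: int, body: Body) -> List[str]:
    """Runtime loop: emit the body once, inside a single loop op."""
    ops = [f"scf.for %i = 0 to {n} {{"]
    ops.extend("  " + op for op in body("%i"))
    ops.append("}")
    return ops


def accumulate(index: str) -> List[str]:
    """Hypothetical loop body: one add per iteration."""
    return [f"acc = add acc, x[{index}]"]
```

`emit_static_range(8, accumulate)` returns eight ops indexed by the literals 0 through 7, while `emit_runtime_range(8, accumulate)` returns one `scf.for` wrapping a single op indexed by the symbolic `%i`.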
Pallas: Trace Instead of Parse
JAX’s Pallas kernel language follows JAX’s syntax and uses tracing. A Pallas kernel is an ordinary Python function that operates on JAX Refs and is passed to pl.pallas_call, which executes it under a tracer.
The body is not parsed; it is run with tracer values as arguments to generate a jaxpr, an IR capturing the semantics of the kernel.
def _vmappable_softmax_kernel(input_ref, probs_ref, *, block_row: int):
    row_len = input_ref.shape[-1]
    mask = jnp.arange(block_row) < row_len
    row = plgpu.load(input_ref.at[pl.ds(0, block_row)],
                     mask=mask, other=-float("inf"))
    row_max = jnp.max(row, axis=0)
    numerator = jnp.exp((row - row_max).astype(jnp.float32))
    denominator = jnp.sum(numerator, axis=0)
    plgpu.store(probs_ref.at[pl.ds(0, block_row)],
                (numerator / denominator).astype(probs_ref.dtype),
                mask=mask)
jax/jax/experimental/pallas/ops/gpu/softmax.py, lines 24–49.
There is no separate DSL to parse: the operations are jnp.max, jnp.exp, and jnp.sum, drawn from the NumPy-shaped API that JAX users already know and dispatched through tracers that record into the kernel’s IR rather than executing on the CPU.
The Pallas design document states the contrast: “JAX users are already accustomed to the benefits (and limitations) of programming with JAX and its tracing-based transformations. This means users can use closures and other familiar Python constructs when writing Pallas kernels.”
The clearest illustration of what tracing buys comes from Tokamax’s SM90 flash-attention kernel, where compile-time and runtime control flow sit side by side in a single body:
carry = (acc, m_i, m_i, l_i)
if is_causal:  # compile time
    causal_loop_body = functools.partial(loop_body, do_causal=True)
    ub_no_causal = lax.min(ub, lax.div(q_base, block_kv))
    carry = lax.fori_loop(lb, ub_no_causal, loop_body, carry)  # runtime
    # TODO: This cond should be redundant, but without it we
    # hit a weird compiler bug.
    acc, m_scale, m_i, l_i = lax.cond(  # runtime
        ub_no_causal < ub,
        lambda: lax.fori_loop(ub_no_causal, ub, causal_loop_body, carry),
        lambda: carry,
    )
else:
    acc, m_scale, m_i, l_i = lax.fori_loop(lb, ub, loop_body, carry)  # runtime
pl.when(wg == 0)(schedule_barrier_arrive)  # runtime
tokamax/_src/ops/attention/pallas_mosaic_gpu_kernel_sm90.py, lines 387–403.
The two regimes are doing distinct jobs.
The outer if is_causal: is a plain Python branch evaluated at trace time; the untaken branch is never compiled and never seen by the backend.
It selects between two entirely different runtime control-flow shapes: the non-causal branch is a single lax.fori_loop over the full KV range, while the causal branch splits the range into a non-causal prefix and a causal tail, each driven by its own lax.fori_loop.
(The lax.cond wrapping the tail is, per the source comment, a workaround for a compiler bug rather than a deliberate part of the algorithm.)
The causal body itself is specialised at trace time via functools.partial(loop_body, do_causal=True).
The backend never has to guess which is which: a Python if is always compile-time, lax.fori_loop is always runtime. The trace is the program.
Runtime Combinators Look Unusual
The most common complaint about tracing-based DSLs is that runtime control flow looks awkward. Where Triton lets the user write for k in tl.range(...) or if dynamic_cond: in plain Python, Pallas requires lax.fori_loop(lb, ub, body, carry) with an explicit carry tuple, or a decorated @pl.when(cond) over a freshly-defined nested function. The combinators feel like a step sideways from idiomatic Python at first reading.
Authors get over this quickly because the patterns are few and soon become second nature. Idiomatic Pallas rarely uses the combinators raw; it wraps each one in a layer of compile-time Python that decides which runtime shape to emit, so a single source file can serve causal and non-causal attention by building two specialised loop_body closures with functools.partial and feeding them to two different lax.fori_loop arrangements (exactly what the Tokamax kernel above does). Two recurring patterns are worth naming:
- functools.partial(body, ...static_flag=True) to specialise an inner-loop body before passing it to lax.fori_loop.
- A decorator-form @pl.when(cond) over a fresh _() function, so a runtime branch reads like a labelled block.
The combinators stay small and visible because the Python around them does the structural work first.
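The first pattern can be sketched without JAX. In the example below, `fori_loop` is a pure-Python stand-in that only mirrors the argument order of jax.lax.fori_loop, and `loop_body` and `attend` are hypothetical names that reduce attention to summing the KV indices a query position is allowed to see. The structure is what matters: a plain Python `if` picks the loop arrangement, and functools.partial bakes the static flag into the body before the loop ever sees it.

```python
import functools


def fori_loop(lower, upper, body, carry):
    """Pure-Python stand-in for a runtime loop combinator."""
    for i in range(lower, upper):
        carry = body(i, carry)
    return carry


def loop_body(kv_step, carry, *, q_pos, do_causal=False):
    """Hypothetical body: accumulate the KV indices visible from q_pos."""
    total, visited = carry
    if do_causal and kv_step > q_pos:  # the mask exists only when specialised
        return total, visited + 1
    return total + kv_step, visited + 1


def attend(q_pos, n_kv, is_causal):
    body = functools.partial(loop_body, q_pos=q_pos)
    carry = (0, 0)
    if is_causal:  # decided once, in ordinary Python
        causal_body = functools.partial(body, do_causal=True)
        split = min(n_kv, q_pos)  # steps before `split` never need the mask
        carry = fori_loop(0, split, body, carry)
        return fori_loop(split, n_kv, causal_body, carry)
    return fori_loop(0, n_kv, body, carry)
```

With five KV steps, `attend(2, 5, is_causal=False)` sums all five indices, while `attend(2, 5, is_causal=True)` runs an unmasked prefix followed by a masked tail and sums only indices 0 through 2.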
Compile-Time Python as a Metaprogramming Layer
Because compile-time control flow is just Python, anything Python can compute is available for metaprogramming the kernel, including patterns that a parsing-based DSL would have to special-case in its AST walker.
The Pallas ragged_dot Mosaic-GPU kernel is a good example.
The kernel needs to store a dynamic number of rows (somewhere between 1 and block_m) to global memory, but TMA descriptors require statically-known tile sizes.
Idiomatic Pallas resolves the contradiction by unrolling a Python while loop into a logarithmic ladder of fixed-size stores, each guarded by a runtime bit-test on the dynamic length:
smem_start = group_info.start_within_block
remaining_rows = min(block_m, m)
while remaining_rows > 0:
    const_rows_len = 1 << int(math.log2(remaining_rows))
    remaining_rows //= 2
    @pl.when(group_info.actual_size & const_rows_len != 0)
    def _():
        o_smem_slice = o_smem.at[pl.ds(smem_start, const_rows_len)]
        o_gref_slice = o_gmem.at[
            pl.ds(group_info.block_start + smem_start, const_rows_len),
            pl.ds(ni * block_n, block_n),
        ]
        plgpu.copy_smem_to_gmem(o_smem_slice, o_gref_slice)
    smem_start += group_info.actual_size & const_rows_len
jax/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py, lines 176–228 (35 lines of in-source comments explaining the algorithm are elided). To make the trick concrete: with block_m=8 and a dynamic length of 5, the Python loop emits stores of sizes 8, 4, 2, 1 at trace time, and runtime bit-tests on actual_size fire only the 4 and the 1.
The Python while runs at trace time and emits log2(block_m) + 1 separate TMA stores (for a power-of-two block_m) with compile-time constant tile sizes; @pl.when wraps each in a runtime bit-test on the dynamic length, so only the stores whose tile size contributes to the desired total actually fire. The Python loop variable does double duty as both an unroll counter and a bit position. Implementing this in a parsing-based DSL would require either bespoke AST support for Python integer loops that close over kernel refs, or, for block_m=8, four hand-written @pl.when blocks in the source.
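The ladder can be simulated in plain Python to check the arithmetic. This is a sketch of the indexing logic only: `plan_stores` mirrors the trace-time while loop, `run_stores` mirrors the runtime bit-tests, both names are hypothetical, and no memory is actually copied.

```python
import math
from typing import List, Tuple


def plan_stores(block_m: int, m: int) -> List[int]:
    """Trace-time loop: the fixed tile sizes for which a store is emitted."""
    sizes = []
    remaining_rows = min(block_m, m)
    while remaining_rows > 0:
        sizes.append(1 << int(math.log2(remaining_rows)))
        remaining_rows //= 2
    return sizes


def run_stores(sizes: List[int], actual_size: int, start: int = 0) -> List[Tuple[int, int]]:
    """Runtime bit-tests: (offset, length) of every store that fires."""
    fired = []
    offset = start
    for size in sizes:
        if (actual_size & size) != 0:
            fired.append((offset, size))
        offset += actual_size & size  # advances only past rows actually stored
    return fired
```

For `block_m=8`, `plan_stores` yields the sizes 8, 4, 2, 1, and a dynamic length of 5 fires a 4-row store at offset 0 followed by a 1-row store at offset 4. The fired lengths sum to the dynamic length for every value from 1 to 8, which is the property the kernel relies on.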
Ordinary Python debugging tooling continues to work inside Pallas kernels. A plain print call inside a traced function prints once, at trace time, with the tracer’s abstract value, which is useful for inspecting types and shapes during compilation. For runtime values, the deliberately side-effecting primitives go through:
def kernel(x_ref, o_ref):
    if dtype == jnp.int32:
        pl.debug_print('BEGIN1 x[0] == {}', x_ref[0])
        pl.debug_print('BEGIN2 x[0] == {} ; x[1] == {} ; END',
                       x_ref[0], x_ref[1])
    else:
        pl.debug_print('BEGIN1 x[0] == ', x_ref[0])
jax/tests/pallas/tpu_pallas_call_print_test.py, lines 88–95.
The tracing approach gets four debugging affordances either free or cheap. Parsing-based DSLs can offer the same, but at considerable engineering cost: Triton, for instance, provides both source-mapped error messages and MLIR loc attributes, but only because the AST walker has been carefully threaded with source locations at every visitor method. The four affordances:
- Native breakpoints. jax.debug.breakpoint() drops into a real Python debugger at the corresponding point in the compiled program (jax/_src/debugger/core.py:160); Pallas’s interpret mode runs the kernel as a plain Python loop on the CPU, where a bare Python breakpoint() and pdb work without any special integration.
- Native print. In interpret mode, plain print() just works because the kernel is a normal Python function. jax.debug.print and pl.debug_print extend that to compiled execution.
- Full Python tracebacks. When the trace raises, the traceback is the actual Python call stack into the kernel source (the line of the user’s function that produced the offending operation). Triton points at the kernel source too, via the def_file_col_number machinery we touched on earlier; tracing simply gets the same for free by being Python.
- Op-level provenance in MLIR. JAX threads a SourceInfo object (jax/_src/source_info_util.py:136) carrying the originating Python frame through every primitive binding, and the MLIR lowering converts it to a loc(...) attribute via source_info_to_location (jax/_src/interpreters/mlir.py:524). Triton emits loc attributes as well, but only because each visitor method explicitly attaches the source location it is responsible for; tracing gets the annotation as a side effect of every primitive dispatch.
What Tracing Gives Up
The largest cost of tracing is that runtime control flow can no longer be written succinctly.
Every Python if is evaluated by the Python interpreter at trace time, so if some_tracer > 0: is illegal: the condition is an abstract value rather than a concrete boolean, and Python raises. Branching on a runtime value therefore requires the explicit combinators lax.cond, lax.fori_loop, or pl.when, and the resulting code is more verbose than its Triton counterpart. A function written in plain Python control flow does not automatically function as a kernel; it must be refactored.
The Tokamax SM90 attention example earlier in the post shows what this looks like in practice: a pair of lax.fori_loop calls in the causal branch versus a single one in the non-causal, with a closure-specialised inner body to keep the combinators small. Triton would have written the same algorithm with two ordinary Python for loops and an ordinary if: visually cleaner at the call site, at the cost of the reader having to keep the compile-time / runtime distinction straight in their head. Pallas accepts the visual overhead in exchange for making that distinction syntactic.
Other tracing taxes are less visible but real: shape polymorphism requires explicit machinery (jax.export, polymorphic shape specs) rather than falling out of the source; effectful operations need lifting into JAX primitives; and a kernel parameter that ought to be compile-time but happens to arrive as a concrete Python value during one trace and as a tracer during the next will silently produce two different jaxprs. None of these are insurmountable, but they are real friction.
The two approaches treat Python itself very differently. Triton overloads if to mean either compile-time or runtime branching depending on the type of the condition, and pays for the overload indefinitely: every reader must re-derive which semantics is in force at each site, and every new Python feature needs a corresponding visit_X method. Pallas keeps the two regimes syntactically distinct and is never obliged to implement a visitor for ast.ListComp: a list comprehension simply executes under the Python interpreter, producing either a list of tracers or a list of concrete values, either of which is acceptable.
Parsing buys cleaner surface syntax at the cost of making generic libraries painful to write, while tracing accepts syntactic overhead in simple kernels in exchange for the full expressiveness of Python at compile time.
Two Toy Implementations
The architectural difference between the two approaches is compact enough to fit on a single screen. The two implementations below are sketches rather than runnable compilers (they omit a proper symbol table and a real type system), but they make the shape of each strategy concrete. The first is an AST-based mini-DSL.
import ast, inspect, textwrap
class ASTCompiler(ast.NodeVisitor):
    def __init__(self, constexprs):
        self.constexprs = constexprs  # {name: int}
        self.ops = []
        self.tmp = 0
    def fresh(self):
        self.tmp += 1
        return f"%{self.tmp}"
    def visit_Constant(self, node):
        return node.value  # literal: a compile-time value
    def visit_Name(self, node):
        if node.id in self.constexprs:
            return self.constexprs[node.id]  # compile-time value
        return node.id  # runtime SSA name
    def visit_BinOp(self, node):
        l, r = self.visit(node.left), self.visit(node.right)
        op = {ast.Add: 'add', ast.Mult: 'mul'}[type(node.op)]
        t = self.fresh()
        self.ops.append(f"{t} = {op} {l}, {r}")
        return t
    def visit_If(self, node):
        cond = self.visit(node.test)
        if isinstance(cond, int):  # constexpr branch
            chosen = node.body if cond else node.orelse
            for stmt in chosen: self.visit(stmt)
        else:  # runtime branch
            self.ops.append(f"scf.if {cond} {{ ... }} else {{ ... }}")
    def visit_Assign(self, node):
        v = self.visit(node.value)
        self.ops.append(f"{node.targets[0].id} = {v}")
def compile_ast(fn, **constexprs):
    src = textwrap.dedent(inspect.getsource(fn))
    tree = ast.parse(src).body[0]
    c = ASTCompiler(constexprs)
    for stmt in tree.body: c.visit(stmt)
    return "\n".join(c.ops)
The tracing analogue is comparable in size.
_ops = []
_tmp = [0]
class Tracer:
    def __init__(self, name): self.name = name
    def _bin(self, other, op):
        _tmp[0] += 1
        t = Tracer(f"%{_tmp[0]}")
        rhs = other.name if isinstance(other, Tracer) else other
        _ops.append(f"{t.name} = {op} {self.name}, {rhs}")
        return t
    def __add__(self, o): return self._bin(o, 'add')
    def __mul__(self, o): return self._bin(o, 'mul')
    def __bool__(self):
        raise TypeError("cannot branch on a tracer; use cond()")
def cond(pred, then_fn, else_fn):  # runtime branch on a tracer
    _ops.append(f"scf.if {pred.name} {{ ... }} else {{ ... }}")
def trace(fn, *arg_names):
    _ops.clear(); _tmp[0] = 0
    args = [Tracer(n) for n in arg_names]
    fn(*args)
    return "\n".join(_ops)
The two implementations differ by only a handful of lines, yet every distinction between them becomes a larger distinction between Triton and Pallas. The AST-based version requires a visitor for every Python construct the user might invoke, whereas the tracing version inherits every Python construct for free provided it terminates in operator dispatch. In the AST version, if x: demands special-casing inside visit_If for the constexpr and runtime cases; in the tracing version the same line either succeeds, because x is a concrete Python boolean, or raises, because x is a tracer and the author must instead reach for the explicit cond(pred, ...) combinator. Both behaviours are useful; only one needs a 1700-line visitor to implement.
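The "inherits every Python construct for free" claim can be checked on a list comprehension. The sketch below is self-contained and deliberately redefines a minimal recording class rather than reusing the one above; `SUPPORTED` is an assumed list of node types that a small AST walker might handle, and `unsupported_nodes` simply reports which constructs in the source fall outside it.

```python
import ast


class Tracer:
    """Minimal recording value; only `+` and `*` are overloaded."""

    counter = 0
    ops = []

    def __init__(self, name):
        self.name = name

    def _bin(self, other, op):
        Tracer.counter += 1
        out = Tracer(f"%{Tracer.counter}")
        rhs = other.name if isinstance(other, Tracer) else other
        Tracer.ops.append(f"{out.name} = {op} {self.name}, {rhs}")
        return out

    def __add__(self, other):
        return self._bin(other, "add")

    def __mul__(self, other):
        return self._bin(other, "mul")


def kernel(x):
    terms = [x * k for k in (2, 3)]  # the comprehension runs as plain Python
    return terms[0] + terms[1]


def trace(fn, arg_name):
    """Run fn on a tracer and return the recorded ops."""
    Tracer.counter = 0
    Tracer.ops = []
    fn(Tracer(arg_name))
    return list(Tracer.ops)


# The same program as text, as a parsing front end would receive it.
KERNEL_SOURCE = (
    "def kernel(x):\n"
    "    terms = [x * k for k in (2, 3)]\n"
    "    return terms[0] + terms[1]\n"
)

# Assumed coverage of a small AST walker; anything else needs a new visitor.
SUPPORTED = (
    ast.Module, ast.FunctionDef, ast.arguments, ast.arg, ast.Assign,
    ast.Return, ast.BinOp, ast.Name, ast.Constant, ast.Load, ast.Store,
    ast.Add, ast.Mult,
)


def unsupported_nodes(source):
    """Node types in the source that the assumed walker has no visitor for."""
    return sorted(
        {type(n).__name__ for n in ast.walk(ast.parse(source))
         if not isinstance(n, SUPPORTED)}
    )
```

Tracing `kernel` records two multiplies and one add without any knowledge of comprehensions, tuples, or subscripts. The walker, by contrast, reports `ListComp` (among others) as unsupported until someone writes the corresponding visitor.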
Conclusion
Tracing pays the cost of awkward runtime syntax. In exchange it gets a much simpler and more robust compiler implementation, and the ability to express complex generic algorithms through Python metaprogramming, which is often what a reusable set of libraries requires.
Source for the snippets above: Triton commit at triton-lang/triton, CUTLASS at NVIDIA/cutlass, JAX at jax-ml/jax. The two CuTe bugs are #3266 and #3268. The compile timing in the receipt box was measured locally in nvidia/cuda:12.5.0-devel-ubuntu22.04; the figure will vary with toolkit version and host configuration. Corrections and counter-examples are welcome by email.