# Getting Started with Triton for CUDA Kernel Development
# Preface
It’s been a while since I last wrote a blog post. Over the past year, I’ve been heads-down in research, but I’ve kept a close eye on the AI infra field. I deeply feel that this field is getting intensely competitive—the performance pressure of training and inference for large models forces every team to write custom GPU kernels by hand, and the landscape of “what tools to use” is currently undergoing a reshuffle. Learning mainstream tile-based GPU DSLs (Domain-Specific Languages) is extremely helpful for understanding the trend of AI kernel development in the coming years, hence this article.
The current competition among tile-based GPU DSLs is fierce: OpenAI’s Triton has the largest community and the most comprehensive ecosystem; TVM’s Tilelang takes a different approach; NVIDIA’s own CuTe/CUTLASS represents the extreme of C++ template metaprogramming. The three have completely different styles, but if you could only learn one, Triton is almost certainly the best starting point. Writing kernels in Python has the lowest barrier to entry, and TorchInductor, the default GPU backend of PyTorch 2.0’s torch.compile, generates Triton kernels under the hood.
This is the first post in the series, walking through the complete path from scratch: “Vector Addition → Matrix Multiplication → FlashAttention”.
# 1. Why Triton?
# 1.1 The Evolution of GPU Programming Paradigms
GPU programming has never been easy. Traditional CUDA C++ requires developers to think at the thread level: manually managing shared memory, manually handling bank conflicts, and manually orchestrating asynchronous copy pipelines. Building on this, CUTLASS uses C++ templates to abstract tile-level operations, significantly reducing the cognitive load, but it still demands profound C++ template metaprogramming skills.
Triton’s positioning is very clear: Write kernels in Python that achieve 90%+ of cuBLAS performance. It elevates the programming granularity from the thread level to the tile level, leaving the tedious low-level details for the compiler to handle.
The comparison table below summarizes the core differences between the three paradigms:
| Dimension | CUDA C++ | CUTLASS | Triton |
|---|---|---|---|
| Programming Granularity | Thread-level | Tile-level (C++ Templates) | Tile-level (Python) |
| Shared Memory | Manual management | Template abstraction | Compiler automated |
| Tensor Cores | Hand-written PTX / wmma | CuTe Layouts | tl.dot auto-mapping |
| Bank Conflicts | Manual swizzle | Template swizzle | Compiler automated |
| Pipelining | Manual cp.async | Template parameters | Single num_stages param |
| Learning Curve | High | Medium-High | Low |
| Performance Ceiling | 100% | ~98% | ~90-95% |
An interesting phenomenon: Triton’s performance ceiling is indeed slightly lower, but in real-world projects, there are far more people who can write Triton to 90% of cuBLAS than those who can write CUDA C++ to 95%. Development efficiency is a part of performance itself.
# 1.2 Compilation Pipeline
Triton’s compilation process starts with Python source code, goes through multiple layers of Intermediate Representation (IR) transformations, and ultimately generates GPU binaries:

*(Figure: Triton compilation pipeline)*
The core of the entire pipeline lies in the transformation from Triton IR to Triton GPU IR. At this stage, the compiler maps abstract Tile-level operations to specific thread, Warp, and Tensor Core instructions. “Developers focus on Tile-level logic, while the compiler is responsible for thread-level implementation”—this is the core design philosophy behind how Triton lowers the barrier for kernel development.
# 2. Core Concepts & Instruction Cheat Sheet
# 2.1 Execution Model
Triton’s execution model has three key concepts:
- Program: A single execution instance of a kernel, similar to a thread block in CUDA. Each program runs independently and has its own program ID.
- Grid: The arrangement of all program instances, supporting 1D, 2D, and 3D.
- Tile: The block of data processed by each program. The tile size is determined by a compile-time constant (tl.constexpr).
The most fundamental difference from CUDA: CUDA programmers think about “what each thread does”, while Triton programmers think about “which tile each program processes”. How threads collaborate and how data is distributed among registers is entirely handed over to the compiler.

*(Figure: Triton execution model)*
Core APIs:
# tl.program_id
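A rough sketch of the call; the return value is a scalar integer:

```python
pid = tl.program_id(axis=0)   # index of this program along grid axis 0
```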
Returns the index of the current program along the specified axis. axis is the grid axis, taking the value 0, 1, or 2 for a 1D, 2D, or 3D grid respectively. For example, in a 2D grid, tl.program_id(0) returns the row index, and tl.program_id(1) returns the column index.
# tl.arange
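A rough sketch; both bounds must be compile-time constants:

```python
offsets = tl.arange(0, BLOCK_SIZE)   # [0, 1, ..., BLOCK_SIZE - 1]; BLOCK_SIZE is a power of 2
```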
Returns a 1D tensor containing contiguous integers in the range [start, end). Commonly used to construct offset indices within a tile. The length end - start must be a power of 2, which is why BLOCK_SIZE is always chosen as a power of 2.
# tl.constexpr
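A minimal sketch of how the annotation is used on a kernel parameter (names are illustrative):

```python
import triton
import triton.language as tl


@triton.jit
def my_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # BLOCK_SIZE is baked in at compile time, so the compiler can unroll
    # loops and eliminate branches that depend on it.
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
```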
A compile-time constant decorator. Certain parameters in a Triton kernel must be determined at compile time (such as tile sizes). Once declared with tl.constexpr, these parameters become compile-time constants, allowing the compiler to perform optimizations like loop unrolling and conditional branch elimination based on them.
# 2.2 Memory Instructions
There are two instructions used for data transfer in Triton: tl.load and tl.store, which are responsible for loading data from VRAM into registers and writing data from registers back to VRAM, respectively.
For data transfer instructions, the shape and data type of the input data block and output data block must be exactly the same.
# tl.load
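A schematic of the signature as I recall it (see the official API reference for the authoritative list):

```python
tl.load(pointer, mask=None, other=None,
        cache_modifier="", eviction_policy="", volatile=False)
```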
Loads data from VRAM into registers.
- `pointer` can be an integer address, a tensor of pointers, or a block pointer.
- `mask` is a boolean tensor with the same shape as `pointer`, used to indicate which positions are valid memory accesses. Invalid positions are replaced by the value of `other`.
- `other` is the default value used to fill invalid positions when `mask` is False.
- `cache_modifier` and `eviction_policy` are advanced parameters used to control cache behavior, usually not needing modification.
Warning
- Out-of-bounds access without a mask is undefined behavior—it won’t throw an error, but the results are unpredictable and extremely hard to debug.
- The `other` parameter must be used in conjunction with `mask`, otherwise an error will be raised.
- `cache_modifier` is only effective on NVIDIA GPUs.
Tip
1D Example:
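A minimal sketch of such a load (names like `x_ptr` and `n_elements` are illustrative):

```python
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements                       # guard the tail block
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
```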
This means each program loads a vector block of length BLOCK_SIZE. The mask ensures that out-of-bounds positions are not accessed when n_elements is not a multiple of BLOCK_SIZE.
Tip
2D Example:
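A sketch of the broadcasting pattern described below (`base_ptr`, the strides, and `M`/`N` are illustrative):

```python
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ptrs = base_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tile = tl.load(ptrs, mask=mask, other=0.0)        # (BLOCK_M, BLOCK_N) tile
```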
Here, [:, None] turns (BLOCK_M,) into (BLOCK_M, 1), and [None, :] turns (BLOCK_N,) into (1, BLOCK_N). After broadcasting, a pointer tensor of (BLOCK_M, BLOCK_N) is obtained. This is the most common 2D indexing pattern in Triton.
# tl.store
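A schematic of the signature:

```python
tl.store(pointer, value, mask=None)
```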
Writes data from registers back to VRAM. Parameters are similar to tl.load.
Tip
1D Example:
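Continuing the 1D load sketch above (`result` is whatever the kernel computed for this tile):

```python
tl.store(out_ptr + offsets, result, mask=mask)
```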
This means each program writes a vector block of length BLOCK_SIZE back to VRAM. The mask ensures that out-of-bounds positions are not written to.
Tip
2D Example:
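Continuing the 2D sketch (`acc` is an fp32 accumulator tile; stride names are illustrative):

```python
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, acc.to(tl.float16), mask=c_mask)
```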
This means each program writes a (BLOCK_M, BLOCK_N) tile back to VRAM. The mask ensures out-of-bounds positions are not written to.
# tl.make_block_ptr
Block pointer is a high-level abstraction provided by Triton, using the semantics of a “matrix sub-block” to describe memory access, rather than manually calculating pointer offsets:
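A sketch mirroring the parameter list below, for a (BLOCK_M, BLOCK_K) tile of a row-major (M, K) matrix A (names are illustrative):

```python
a_block_ptr = tl.make_block_ptr(
    base=a_ptr,                      # base pointer of the parent tensor
    shape=(M, K),                    # full shape of the parent tensor
    strides=(stride_am, stride_ak),  # strides of the parent tensor
    offsets=(pid_m * BLOCK_M, 0),    # starting offsets of this block
    block_shape=(BLOCK_M, BLOCK_K),  # shape of the block to load
    order=(1, 0),                    # (1, 0) = row-major
)
```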
- `base` is the base pointer, which can be an integer address or a tensor.
- `shape` is the full shape of the parent tensor, e.g., `(M, K)`.
- `strides` are the strides of the parent tensor, e.g., `(stride_m, stride_k)`.
- `offsets` are the starting offsets, e.g., `(pid_m * BM, 0)`.
- `block_shape` is the shape of the block to be loaded, e.g., `(BM, BK)`.
- `order` is the memory layout order: `(1, 0)` implies row-major, `(0, 1)` implies column-major.
Tip
Example: `tl.load` using a block pointer:
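A sketch (continuing the block pointer built above):

```python
a = tl.load(a_block_ptr, boundary_check=(0, 1), padding_option="zero")
```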
Compared to manual pointer arithmetic:
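Roughly the equivalent with raw pointers:

```python
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
```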
The semantics are much clearer, and the compiler has more room for optimization.
# tl.advance
Advance is a dedicated instruction for block pointers provided by Triton, used to advance a block pointer with clean, declarative semantics:
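A sketch:

```python
a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))   # move one block along the K dimension
```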
- `ptr` is a block pointer.
- `offsets` are the block offsets to advance, e.g., `(0, BLOCK_K)` means advancing one block along the K dimension.
- Returns a new block pointer pointing to the newly advanced position.
- Clear semantics allow the compiler to automatically handle boundary checks and optimizations.
Compared to manual pointer arithmetic:
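The raw-pointer equivalent is bare offset arithmetic:

```python
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
```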
Why are block pointers recommended? Three reasons:
- Automatic boundary checking: handled via a single `boundary_check=(0, 1)` parameter, eliminating the need to write manual mask broadcasts.
- Triggers TMA: on the Hopper (H100) architecture, block pointers automatically map to the Tensor Memory Accelerator (TMA), drastically improving memory bandwidth utilization.
- Larger compiler optimization space: block pointers carry richer semantic information, allowing the compiler to automatically perform optimizations like swizzling.
# 2.3 Compute Instructions
# tl.dot
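Schematically:

```python
out = tl.dot(input, other)   # (M, K) @ (K, N) -> (M, N), accumulated in fp32 by default
```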
Computes matrix multiplication out = input @ other. Assuming input has shape (M, K) and other has shape (K, N), the output shape will be (M, N).
Tip
Example: typical usage of `tl.dot`:
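A sketch of the canonical K-loop (assuming K is a multiple of BLOCK_K so the loads need no mask; pointer names are illustrative):

```python
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
    a = tl.load(a_ptrs)             # (BLOCK_M, BLOCK_K) tile of A
    b = tl.load(b_ptrs)             # (BLOCK_K, BLOCK_N) tile of B
    acc += tl.dot(a, b)             # maps to Tensor Core MMA, accumulates in fp32
    a_ptrs += BLOCK_K * stride_ak   # advance along K
    b_ptrs += BLOCK_K * stride_bk
c = acc.to(tl.float16)              # cast once, at the very end
```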
tl.dot automatically maps to the MMA (Matrix Multiply-Accumulate) instructions of Tensor Cores. This is the most performance-critical operation in Triton—with a single tl.dot call, the compiler automatically generates wmma (Volta/Turing), mma (Ampere), or wgmma (Hopper) instructions for you.
Supported data type combinations:
| Input A | Input B | Accumulator | Hardware Requirements |
|---|---|---|---|
| float16 | float16 | float32 | Volta+ (SM70) |
| bfloat16 | bfloat16 | float32 | Ampere+ (SM80) |
| float8e4nv | float8e4nv | float32 | Hopper+ (SM90) |
| int8 | int8 | int32 | Turing+ (SM75) |
| float32 | float32 | float32 | Defaults to TF32 (SM80+) |
Important
Accumulators must ALWAYS use float32 (int32 for integers). Using float16 for an accumulator will result in severe precision loss. This is an iron rule of Triton kernels.
# Reduction Operations
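A few of the common reduction calls, as a rough sketch (`x` is a 2D tile):

```python
row_max = tl.max(x, axis=1)   # (BLOCK_M, BLOCK_N) -> (BLOCK_M,), e.g. the softmax max
row_sum = tl.sum(x, axis=1)   # softmax denominator / LayerNorm statistics
col_min = tl.min(x, axis=0)
```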
These reduction operations appear frequently in kernels like softmax and LayerNorm. The compiler maps them to efficient warp-level shuffle reductions.
# Conditional Selection
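A sketch of the causal-mask use case described in the next paragraph (`offs_m` and `offs_n` are illustrative query/key offsets):

```python
causal = offs_m[:, None] >= offs_n[None, :]          # allow keys at or before the query
scores = tl.where(causal, scores, float("-inf"))     # mask out future positions
```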
Semantically identical to NumPy’s np.where. Note: both branches are evaluated, so it cannot be used to prevent out-of-bounds access. A typical use case is implementing a causal mask—replacing attention scores of future positions with -inf.
# 2.4 Others
# @triton.autotune
Auto-tuning is a major feature of Triton. You simply provide a set of candidate configurations, and Triton will automatically benchmark each configuration on the first call, then cache the optimal one:
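A sketch of the decorator on a GEMM kernel (the candidate configs are illustrative):

```python
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_SIZE_M": 8},
                      num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 8},
                      num_warps=8, num_stages=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 4},
                      num_warps=8, num_stages=3),
    ],
    key=["M", "N", "K"],   # re-benchmark whenever any of these runtime values changes
)
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    ...
```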
The key parameter is crucial: when the values of M, N, and K change, the optimal configuration often changes as well (small matrices suit small tiles, large matrices suit large tiles). Thus, Triton searches for the optimal configuration for each different set of key values.
# @triton.heuristics
The heuristics decorator is used to derive compile-time constants from runtime parameters, avoiding unnecessary benchmark overhead:
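A sketch of the `EVEN_K` pattern (the kernel body is only hinted at in comments):

```python
@triton.heuristics({
    # derive a compile-time flag from the runtime shape
    "EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0,
})
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr):
    # ... inside the K loop:
    # if EVEN_K: a = tl.load(a_ptrs)                       # no mask needed
    # else:      a = tl.load(a_ptrs, mask=k_mask, other=0.0)
    ...
```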
Typical application: when K is exactly a multiple of BLOCK_K, mask checks can be skipped, saving conditional branches.
# 3. Practice I: Vector Addition
Problem: Given two vectors A and B of length N, compute C[i] = A[i] + B[i].
This is the simplest Triton kernel, but it encompasses all the fundamental elements of Triton programming.
# 3.1 Tiling Strategy

*(Figure: Vector addition tiling strategy)*
Each program processes contiguous BLOCK_SIZE elements. The total number of programs = ceil(N / BLOCK_SIZE).
# 3.2 Complete Implementation
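A sketch very close to the standard vector-add tutorial kernel, including the host-side launch wrapper and a quick correctness check:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements              # guard the tail block
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    # one program per BLOCK_SIZE elements
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out


x = torch.rand(98432, device="cuda")
y = torch.rand(98432, device="cuda")
torch.testing.assert_close(add(x, y), x + y)
```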
# 3.3 Key Takeaways
- The role of the mask: when `n_elements` is not a multiple of `BLOCK_SIZE`, some offsets in the last program exceed the array bounds. The mask ensures these out-of-bounds positions are neither read nor written. Forgetting the mask = undefined behavior.
- Why `BLOCK_SIZE` must be a power of 2: this is a constraint of `tl.arange`. Powers of 2 also map cleanly onto the GPU's 32-thread warps, so the block can be distributed efficiently across hardware threads.
- `triton.cdiv(n, BLOCK_SIZE)`: ceiling division, equivalent to `ceil(n / BLOCK_SIZE)`. It ensures all elements are covered.
# 4. Practice II: Matrix Multiplication
Matrix Multiplication (GEMM) is the most important computation kernel on the GPU. Almost all computational load in deep learning ultimately boils down to it. In this section, we will start from the tiling strategy and progressively build a high-performance Triton GEMM kernel.
# 4.1 Tiling Strategy

*(Figure: Matrix multiplication tiling strategy)*
Core algorithm:
- Each program is responsible for computing a `(BLOCK_M, BLOCK_N)` sub-block of matrix C.
- Initialize an fp32 accumulator to zero.
- Loop along the K dimension: in each step, load a `(BLOCK_M, BLOCK_K)` block from A and a `(BLOCK_K, BLOCK_N)` block from B, and perform a `tl.dot`.
- After the loop ends, cast the accumulator to the output precision and write it back to HBM.
# 4.2 Basic Implementation
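A sketch of such a basic kernel, assuming fp16 row-major inputs (block sizes and names are illustrative):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)   # fp32 accumulator
    for k in range(0, K, BLOCK_K):
        a_mask = (offs_m[:, None] < M) & ((offs_k[None, :] + k) < K)
        b_mask = ((offs_k[:, None] + k) < K) & (offs_n[None, :] < N)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc.to(tl.float16), mask=c_mask)


def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"]))
    matmul_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
        BLOCK_M=128, BLOCK_N=128, BLOCK_K=32,
    )
    return c
```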
This is the most basic version, taking only around 50 lines of code. The following performance tests were conducted on an A100. As you can see, the peak performance reaches about 70% of the A100’s theoretical fp16 compute capacity (312 TFLOPS).
| Size (M, K, N) | Time (ms) | TFLOPS |
|---|---|---|
| (512, 1024, 512) | 0.025 | 21.34 |
| (1024, 1024, 1024) | 0.026 | 82.78 |
| (2048, 2048, 2048) | 0.094 | 182.18 |
| (4096, 4096, 4096) | 0.631 | 217.91 |
| (8192, 8192, 8192) | 6.297 | 174.62 |
| (16384, 16384, 16384) | 59.825 | 147.03 |
# 4.3 Advanced Optimizations
Matrix multiplication (GEMM) is an extremely fundamental operator. Common optimization strategies include swizzling, pipelining, warp specialization, auto-tuning, WMMA, etc. Shared memory swizzling is already handled automatically by Triton to resolve bank conflicts. Therefore, we focus more on Tile Swizzling (Grouped Ordering), warp specialization, pipeline overlap, and auto-tuning. The implementation of these optimizations in Triton is detailed below.
# Optimization 1: Grouped Ordering
Grouped Ordering is a tile scheduling strategy primarily aimed at solving the issue of low L2 cache hit rates.
In the basic GEMM implementation earlier, program IDs were mapped in a Row-major fashion: program IDs were assigned by row first, then by column.
- Execution Order: Finish computing all tiles in row 0, then move to row 1.
- Cache Issue: When computing row 0, we need to load the 0th row block of matrix A and all column blocks of matrix B. If matrix B is very large, by the time we start computing row 1, the matrix B data previously loaded into the L2 cache might have already been evicted.
- Consequence: Every row computation has to re-read matrix B entirely from VRAM, causing immense bandwidth pressure.
The core idea of Grouped Ordering is to divide computations into groups along both the M dimension (rows) and N dimension (columns), allowing programs within the same group to execute consecutively. For example, setting GROUP_SIZE_M=4 makes every 4 rows a group, ensuring that a small chunk of matrix B data can be reused by multiple programs while it remains in the L2 cache.

*(Figure: Grouped ordering)*
In a naive ordering, programs (0,0) and (1,0) both require the 0th column tile of B, but they are separated by num_pid_n programs. When num_pid_n is large, the 0th column tile of B has long since been evicted from the L2 cache by the time it is needed again. Grouped Ordering makes programs that require the same B tile execute consecutively, significantly boosting the L2 hit rate; the index remapping is sketched below.
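The remapping itself is only a few lines of index arithmetic at the top of the kernel. A sketch following the scheme used in the official matmul tutorial (the variable names are the conventional ones):

```python
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n                 # programs per group
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M                       # first tile-row of this group
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)   # last group may be smaller
pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
pid_n = (pid % num_pid_in_group) // group_size_m            # column advances only after the
                                                            # whole group has covered it
```

The kernel is then launched on a 1D grid of `num_pid_m * num_pid_n` programs, and `(pid_m, pid_n)` replaces the direct 2D `tl.program_id` lookup.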
# Optimization 2: Block Pointer
Rewrite the core loop using tl.make_block_ptr:
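A sketch of the rewritten core, reusing the names from the basic version:

```python
a_block_ptr = tl.make_block_ptr(
    base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
    offsets=(pid_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0),
)
b_block_ptr = tl.make_block_ptr(
    base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
    offsets=(0, pid_n * BLOCK_N), block_shape=(BLOCK_K, BLOCK_N), order=(1, 0),
)

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
    a = tl.load(a_block_ptr, boundary_check=(0, 1), padding_option="zero")
    b = tl.load(b_block_ptr, boundary_check=(0, 1), padding_option="zero")
    acc += tl.dot(a, b)
    a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
    b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))

c_block_ptr = tl.make_block_ptr(
    base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
    offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
    block_shape=(BLOCK_M, BLOCK_N), order=(1, 0),
)
tl.store(c_block_ptr, acc.to(tl.float16), boundary_check=(0, 1))
```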
Section 2.2 already introduced tl.make_block_ptr and tl.advance, so we won’t repeat it here. In short, the benefits of this optimization are:
- Cleaner code: No need to manually calculate pointer offsets and masks; semantics are much clearer.
- Automatic boundary checking: the `boundary_check=(0, 1)` parameter lets the compiler automatically handle out-of-bounds access, avoiding the complex mask calculation and broadcasting from the previous version.
- Triggers TMA: on the Hopper (H100) architecture, block pointers automatically map to TMA, dramatically improving memory bandwidth utilization.
# Optimization 3: Autotune
We can use @triton.autotune and @triton.heuristics for auto-tuning and heuristic optimization, allowing Triton to automatically select the optimal combination of BLOCK_M, BLOCK_N, BLOCK_K, and GROUP_SIZE_M under different M, N, and K dimensions, thereby achieving excellent performance across various matrix sizes.
The final implemented version looks like this:
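A sketch of how the decorators combine (the config values are illustrative, and the body is the grouped-ordering plus block-pointer loop shown above):

```python
def get_autotune_config():
    # Illustrative candidates; a real list would cover more shapes
    return [
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8},
                      num_warps=8, num_stages=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 8},
                      num_warps=4, num_stages=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_SIZE_M": 4},
                      num_warps=4, num_stages=3),
    ]


@triton.autotune(configs=get_autotune_config(), key=["M", "N", "K"])
@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0})
@triton.jit
def matmul_kernel_tuned(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, EVEN_K: tl.constexpr,
):
    # Body: grouped-ordering pid remapping + block-pointer K loop (as above),
    # dropping the K-dimension boundary check when EVEN_K is true.
    ...
```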
While the lines of code only increased to roughly 100, the performance improvement is significant, especially on large matrices. It already closely approaches the fp16 performance of cuBLAS:
| Size (M, K, N) | Time (ms) | TFLOPS |
|---|---|---|
| (512, 1024, 512) | 0.014 | 38.31 |
| (1024, 1024, 1024) | 0.022 | 98.55 |
| (2048, 2048, 2048) | 0.091 | 189.26 |
| (4096, 4096, 4096) | 0.621 | 221.42 |
| (8192, 8192, 8192) | 4.905 | 224.14 |
| (16384, 16384, 16384) | 38.054 | 231.15 |
# 5. Practice III: FlashAttention
# 5.1 Memory Bottleneck of Standard Attention
The standard Multi-Head Attention computation process is as follows:
- Compute Attention Scores: Multiply the query matrix $Q$ with the transposed key matrix $K^T$ to generate the similarity matrix $S$. $$S = Q K^T, \quad S \in \mathbb{R}^{B \times H \times N \times N}$$ The time and space complexity for this step are both $O(N^2)$, producing a massive intermediate tensor.
- Normalization (Softmax): Scale $S$ and apply the Softmax operation to obtain the attention weight probability distribution $P$. $$P = \text{softmax}\left(\frac{S}{\sqrt{d_k}}\right), \quad P \in \mathbb{R}^{B \times H \times N \times N}$$ This produces another $O(N^2)$ full-scale weight matrix.
- Compute Weighted Output: Use the weight matrix $P$ to perform a weighted aggregation over the value matrix $V$. $$O = P V, \quad O \in \mathbb{R}^{B \times H \times N \times d}$$
This process generates two intermediate matrices of size $O(N^2)$: $S = Q K^T$ and $P = \text{softmax}(S)$. When the sequence length $N$ scales to 32K or even 128K, these matrices alone exhaust VRAM. At the same time, repeatedly moving these two giant matrices between HBM and SRAM saturates memory bandwidth and becomes the real bottleneck.
FlashAttention’s Solution: Do not instantiate the $S$ and $P$ matrices. Instead, use block partitioning (Tiling) to fuse the computation entirely within SRAM in one go, reducing the memory complexity to $O(N)$.
# 5.2 Online Softmax Algorithm
The core idea of Online Softmax is to maintain a state triplet $(m_i, l_i, \mathbf{o}_i)$ for each row, implementing incremental computation through iterative updates:
- $m_i$: The maximum value of the row processed so far.
- $l_i$: The current local normalization constant (Softmax denominator).
- $\mathbf{o}_i$: The current weighted accumulation result (output vector).
For each new $K/V$ block $j$, the update steps are as follows:
- Compute Local Similarity: Calculate the raw score $S^{(j)}$ for the current block: $$S^{(j)} = Q_{\text{block}} \times K_{\text{block}, j}^T$$
- Update Row Maximum: Compare the maximum of the current block with the historical maximum to determine the new global maximum $m_{\text{new}}$: $$m_{\text{new}} = \max(m_i, \text{row_max}(S^{(j)}))$$
- Compute Correction Coefficient: To ensure numerical stability and alignment across different blocks, compute the scaling factor $\alpha$ for the old statistics: $$\alpha = \exp(m_i - m_{\text{new}})$$
- Update Normalization Denominator: Align the old denominator using the correction coefficient, and add the contribution from the current block: $$l_{\text{new}} = l_i \cdot \alpha + \sum \exp(S^{(j)} - m_{\text{new}})$$
- Update Accumulator: Rescale the old output vector and accumulate the new contribution from the current block: $$\mathbf{o}_{\text{new}} = \mathbf{o}_i \cdot \alpha + \exp(S^{(j)} - m_{\text{new}}) \times V_{\text{block}, j}$$
- Iterate State: Update the current state: $$m_i \leftarrow m_{\text{new}}, l_i \leftarrow l_{\text{new}}, \mathbf{o}_i \leftarrow \mathbf{o}_{\text{new}}$$
Algorithm Correctness: The correction coefficient $\alpha = \exp(m_i - m_{\text{new}})$ is the key to the algorithm. When a larger local maximum is found ($m_{\text{new}} > m_i$), $\alpha$ acts as a scaling factor less than 1, compensating for exponential terms previously computed based on a smaller $m_i$. This mechanism ensures that at the end of traversing all blocks, the final output $O = \mathbf{o}_i / l_i$ is mathematically strictly equivalent to a one-time global computation.
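As a sanity check of the recurrence, here is a tiny PyTorch reference for a single query row (illustrative only, no $\sqrt{d_k}$ scaling), showing that the streaming update matches a direct softmax-weighted sum:

```python
import torch

torch.manual_seed(0)
d, N, block = 64, 1024, 128
q = torch.randn(d)
K = torch.randn(N, d)
V = torch.randn(N, d)

# Streaming (online softmax) over K/V blocks
m = torch.tensor(float("-inf"))       # running row maximum
l = torch.tensor(0.0)                 # running softmax denominator
o = torch.zeros(d)                    # running weighted accumulator
for j in range(0, N, block):
    s = K[j:j + block] @ q            # local scores S^(j)
    m_new = torch.maximum(m, s.max())
    alpha = torch.exp(m - m_new)      # rescale old statistics
    p = torch.exp(s - m_new)
    l = l * alpha + p.sum()
    o = o * alpha + p @ V[j:j + block]
    m = m_new
out_streaming = o / l

# Direct computation that materializes the full score vector
out_direct = torch.softmax(K @ q, dim=0) @ V

torch.testing.assert_close(out_streaming, out_direct, atol=1e-4, rtol=1e-4)
```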
# 5.3 Kernel Implementation
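A condensed sketch of the heart of such a kernel, the online-softmax inner loop over K/V blocks. It omits the causal split, autotuning, and the exp2 trick described below, assumes `N_CTX` is a multiple of `BLOCK_N`, and all names are illustrative; K is loaded pre-transposed as `(HEAD_DIM, BLOCK_N)`:

```python
@triton.jit
def _attn_inner(acc, l_i, m_i, q,                 # running state + (BLOCK_M, HEAD_DIM) Q tile
                k_block_ptr, v_block_ptr,         # K as (HEAD_DIM, BLOCK_N), V as (BLOCK_N, HEAD_DIM)
                qk_scale, N_CTX,
                BLOCK_N: tl.constexpr):
    for _ in range(0, N_CTX, BLOCK_N):
        k = tl.load(k_block_ptr)
        v = tl.load(v_block_ptr)

        qk = tl.dot(q, k) * qk_scale                  # (BLOCK_M, BLOCK_N) local scores S^(j)
        m_new = tl.maximum(m_i, tl.max(qk, axis=1))   # updated row maximum
        alpha = tl.exp(m_i - m_new)                   # correction factor for old statistics
        p = tl.exp(qk - m_new[:, None])               # shifted local exponentials

        l_i = l_i * alpha + tl.sum(p, axis=1)         # updated softmax denominator
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)   # updated output accumulator
        m_i = m_new

        k_block_ptr = tl.advance(k_block_ptr, (0, BLOCK_N))
        v_block_ptr = tl.advance(v_block_ptr, (BLOCK_N, 0))
    return acc, l_i, m_i
```

The outer kernel loads one Q tile per program, initializes `m_i = -inf`, `l_i = 0`, `acc = 0`, runs this loop, and finally writes `acc / l_i[:, None]` back to the output.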
- `get_autotune_config` defines multiple configuration combinations, covering different sequence lengths and head dimensions to adapt to various use cases.
- Grouped Ordering is still applicable, ensuring highly efficient cache utilization of K/V blocks during computation.
- Pre-scale Q + exp2 replaces the traditional scaling and Softmax operations, utilizing the Online Softmax algorithm to achieve numerically stable and memory-efficient attention computation.
- Split causal loop: the inner loop is split into two phases. Phase 1 handles blocks completely below the diagonal, allowing the causal mask to be skipped; Phase 2 applies the `tl.where` causal mask only for the 1–2 blocks near the diagonal.
# 6. Benchmark: How to Test Your Kernel’s Performance
Writing a correct kernel is just the first step. In GPU programming, performance is the ultimate benchmark. Triton provides an entire benchmarking suite.
# 6.1 torch.testing.assert_close
Numerical correctness is a prerequisite for performance optimization. torch.testing.assert_close is a powerful tool to verify whether two tensors are approximately equal within specified absolute error (atol) and relative error (rtol) bounds:
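A minimal sketch, comparing the Triton GEMM wrapper from Section 4 (here called `matmul`) against the cuBLAS reference:

```python
import torch

a = torch.randn(512, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 512, device="cuda", dtype=torch.float16)

c_triton = matmul(a, b)        # the Triton GEMM wrapper from Section 4
c_ref = torch.matmul(a, b)     # cuBLAS reference

torch.testing.assert_close(c_triton, c_ref, atol=1e-2, rtol=1e-2)
```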
Tolerance reference for different precisions:
| dtype | atol | rtol | Description |
|---|---|---|---|
| float32 | 1e-5 | 1e-5 | Most strict |
| float16 | 1e-2 | 1e-2 | Half-precision cumulative error is larger |
| bfloat16 | 1e-2 | 1e-2 | Shorter mantissa, lower precision |
| float8 | 1e-1 | 1e-1 | Extremely low precision, tolerance must be loose |
# 6.2 triton.testing.do_bench
# Timing
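A rough sketch of the call (`kernel_launch` stands for whatever you want to time; `warmup` and `rep` are time budgets in milliseconds, and the return value is the measured time in milliseconds):

```python
ms = triton.testing.do_bench(lambda: kernel_launch(), warmup=25, rep=100)
```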
Tip
Example: Measuring the latency of a softmax kernel:
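A sketch using `torch.softmax` as the stand-in workload; swap in your own Triton softmax kernel to compare:

```python
import torch
import triton

x = torch.randn(4096, 8192, device="cuda", dtype=torch.float16)

ms = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))
print(f"softmax latency: {ms:.3f} ms")
```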
Warmup is crucial: the first call triggers JIT compilation, which can take hundreds of times longer than the actual execution time. do_bench will run a set of warmup rounds, discarding the results, before officially starting the timer.
# FLOPS Throughput
FLOPS (Floating-point Operations Per Second) is a key metric for evaluating the performance of compute-bound kernels. Before calculating this metric, you need to know how many total operations the algorithm requires. Theoretical compute load (FLOPs) is the total number of floating-point operations needed by the algorithm itself, independent of specific hardware parallelism. Common ops like add, mul, and sub count as 1 FLOP, while fma (fused multiply-add) counts as 2 FLOPs. Special built-in functions (like sin, exp) are handled by the SFU (Special Function Unit), and their compute conversions vary across architectures. However, in standard FLOPS calculations, usually only basic multiply-add operations are counted.
Example: For a simple vector addition C[i] = A[i] + B[i] with total length N, the total compute load is N FLOPs. For a matrix multiplication C = A @ B, where A is MxK and B is KxN, the total compute load is 2 * M * N * K FLOPs.
From this, you can compute the kernel's achieved throughput: FLOPS = FLOPs / execution time (seconds). For example, for a GEMM kernel:
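A sketch, assuming `matmul` is the Triton GEMM wrapper from Section 4:

```python
import torch
import triton

M = N = K = 4096
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)

ms = triton.testing.do_bench(lambda: matmul(a, b))
tflops = 2 * M * N * K / (ms * 1e-3) / 1e12     # 2*M*N*K FLOPs for C = A @ B
print(f"{ms:.3f} ms, {tflops:.2f} TFLOPS")
```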
# Bandwidth Utilization (GBPS)
For memory-bound kernels, bandwidth utilization is a more appropriate performance metric. The calculation method is: Bandwidth (GB/s) = Transferred Data Volume (GB) / Execution Time (seconds). The data volume includes all bytes read from and written to memory.
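A sketch for the vector-add kernel from Section 3 (here called `add`):

```python
n_elements = 98_432
x = torch.rand(n_elements, device="cuda")
y = torch.rand(n_elements, device="cuda")

ms = triton.testing.do_bench(lambda: add(x, y))
bytes_moved = 3 * n_elements * x.element_size()   # read x, read y, write the output
gbps = bytes_moved / (ms * 1e-3) / 1e9
print(f"{gbps:.1f} GB/s")
```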
# Roofline Model
The Roofline model is a classic tool to determine whether a kernel is compute-bound or memory-bound. The core concept is Arithmetic Intensity (AI) = Compute Volume / Data Transfer Volume. For the H100, the inflection point is around 295 FLOP/Byte:
- AI < 295: memory-bound; optimization direction is to reduce memory access and increase bandwidth utilization.
- AI > 295: compute-bound; optimization direction is to increase Tensor Core utilization.
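A small worked example for an fp16 GEMM; the ridge point of 295 FLOP/Byte is the H100 figure quoted above:

```python
# Arithmetic intensity of an fp16 GEMM: FLOPs per byte moved to/from HBM
M = N = K = 8192
flops = 2 * M * N * K
bytes_moved = 2 * (M * K + K * N + M * N)   # fp16 = 2 bytes/element; read A, B, write C
ai = flops / bytes_moved

RIDGE_POINT = 295                           # approx. H100 inflection point (FLOP/Byte)
print(f"AI = {ai:.1f} FLOP/Byte ->", "compute-bound" if ai > RIDGE_POINT else "memory-bound")
```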
# 6.3 @triton.testing.perf_report
@triton.testing.perf_report is a decorator used in conjunction with triton.testing.Benchmark to automate multi-configuration benchmarking and generate plots.
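A sketch of a softmax benchmark in this style; `triton_softmax` is a hypothetical placeholder for your own kernel:

```python
import torch
import triton


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],                          # x-axis: number of columns
        x_vals=[2**i for i in range(10, 16)],
        line_arg="provider",                    # one line per provider
        line_vals=["triton", "torch"],
        line_names=["Triton", "Torch"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="GB/s",
        plot_name="softmax-performance",
        args={"M": 4096},
    )
)
def benchmark(M, N, provider):
    x = torch.randn(M, N, device="cuda", dtype=torch.float16)
    if provider == "torch":
        ms = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))
    else:
        ms = triton.testing.do_bench(lambda: triton_softmax(x))   # your Triton kernel
    # read + write the whole tensor once each
    return 2 * x.numel() * x.element_size() / (ms * 1e-3) / 1e9


benchmark.run(print_data=True, show_plots=True)
```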
Explanation of triton.testing.Benchmark parameters:
| Parameter | Type | Description |
|---|---|---|
| x_names | List[str] | List of x-axis parameter names |
| x_vals | List[List] | List of values for each x-axis parameter |
| line_arg | str | Parameter name used to distinguish different lines |
| line_vals | List[Any] | List of parameter values corresponding to different lines |
| line_names | List[str] | List of line names (used in the legend) |
| styles | List[Tuple] | Color and line style for each curve |
| ylabel | str | y-axis label |
| plot_name | str | Filename for the generated plot (excluding extension) |
| args | Dict[str, Any] | Other fixed parameters |

*(Figure: Softmax performance)*
# 7. Cutting-Edge: FP4 and the Blackwell Architecture
# 7.1 Blackwell Architecture Overview
Standing in 2026, NVIDIA’s Blackwell (B200/B300) architecture has already achieved scaled deployment:
| Feature | Hopper (H100) | Blackwell (B200) |
|---|---|---|
| Microarchitecture | SM90 | SM100 |
| FP4 TOPS | N/A | ~4500 |
| FP8 TOPS | ~1979 | ~2250 |
| HBM Bandwidth | 3.35 TB/s | ~8 TB/s |
| Tensor Core | 4th Gen | 5th Gen |
# 7.2 NVFP4 (E2M1) Format

*(Figure: NVFP4 E2M1 format)*
Having only 16 representable values means the expressive capacity of FP4 is extremely limited. However, when combined with per-block scaling, where a group of 16 FP4 values shares a single FP8 scaling factor, it can cover a much larger numerical range. This is essentially the same idea as DeepGEMM’s per-128-element FP8 scaling—using a coarse-grained, higher-precision scale to compensate for the lack of expressive power in fine-grained, low-precision data.
# 7.3 Impact on Triton Programming
As Triton’s support for SM100 matures:
- Instruction Evolution: `tl.dot` will natively support `FP4 × FP4 -> FP32`, mapped to the 5th-generation Tensor Core `tcgen05` instructions under the hood.
- TMA Grand Slam: in the Hopper era, TMA was primarily responsible for loads, but Blackwell supports a much more powerful store-direction TMA and massive-scale multicast.
- Deeper Pipelines: thanks to the 256KB+ shared memory, `num_stages=5` or even `6` is becoming the norm, and memory latency can be almost completely hidden.
# What’s Next
Triton is currently the most mature tile-based GPU DSL, but it is not the only choice. Tilelang takes a completely different path—offering more explicit control over memory hierarchies and more flexible abstractions for layout transformations. Next time, we will discuss Tilelang’s programming model and compare it with Triton.
