Getting Started with Triton for CUDA Kernel Development

A systematic introduction to the Triton tile-based GPU programming model and practical optimization techniques, from vector addition to FlashAttention.

# Preface

It's been a while since my last blog post. I've spent the past year heads-down in research, but I've kept a close eye on the AI infra field. My strong impression is that it is becoming intensely competitive: the performance pressure of training and serving large models pushes every team to hand-write custom GPU kernels, and the question of which tools to use is being reshuffled right now. Learning the mainstream tile-based GPU DSLs (Domain-Specific Languages) is a good way to understand where AI kernel development is heading over the next few years, hence this article.

The current competition among tile-based GPU DSLs is fierce: OpenAI's Triton has the largest community and the most complete ecosystem; TileLang, built on top of TVM, takes a different approach; and NVIDIA's own CuTe/CUTLASS represents the extreme of C++ template metaprogramming. The three have completely different styles, but if you could only learn one, Triton is almost certainly the best starting point: writing kernels in Python has the lowest barrier to entry, and TorchInductor, the default backend of torch.compile in PyTorch 2.x, generates Triton kernels for GPUs.

This is the first post in the series, walking through the complete path from scratch: “Vector Addition → Matrix Multiplication → FlashAttention”.

# 1. Why Triton?

# 1.1 The Evolution of GPU Programming Paradigms

GPU programming has never been easy. Traditional CUDA C++ requires developers to think at the thread level: manually managing shared memory, manually handling bank conflicts, and manually orchestrating asynchronous copy pipelines. CUTLASS builds on this with C++ templates that abstract tile-level operations, significantly reducing the cognitive load, but it still demands deep expertise in C++ template metaprogramming.

Triton’s positioning is very clear: Write kernels in Python that achieve 90%+ of cuBLAS performance. It elevates the programming granularity from the thread level to the tile level, leaving the tedious low-level details for the compiler to handle.

The comparison table below summarizes the core differences between the three paradigms:

| Dimension | CUDA C++ | CUTLASS | Triton |
|---|---|---|---|
| Programming granularity | Thread-level | Tile-level (C++ templates) | Tile-level (Python) |
| Shared memory | Manual management | Template abstraction | Automated by the compiler |
| Tensor Cores | Hand-written PTX / wmma | CuTe layouts | tl.dot auto-mapping |
| Bank conflicts | Manual swizzle | Template swizzle | Automated by the compiler |
| Pipelining | Manual cp.async | Template parameters | A single num_stages parameter |
| Learning curve | High | Medium-high | Low |
| Performance ceiling | 100% | ~98% | ~90-95% |

An interesting phenomenon: Triton’s performance ceiling is indeed slightly lower, but in real-world projects, there are far more people who can write Triton to 90% of cuBLAS than those who can write CUDA C++ to 95%. Development efficiency is a part of performance itself.

# 1.2 Compilation Pipeline

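Triton's compilation process starts from the Python source of a @triton.jit function, lowers it through several layers of Intermediate Representation (IR), and ultimately produces a GPU binary. On NVIDIA GPUs the stages are Python AST → Triton IR (TTIR) → Triton GPU IR (TTGIR) → LLVM IR → PTX → cubin: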

Triton Compilation Pipeline Diagram

The heart of the pipeline is the lowering from Triton IR to Triton GPU IR. At this stage the compiler assigns a layout to every tile, which decides how its elements are distributed across threads and warps, what gets staged in shared memory, and which operations map to Tensor Core instructions. "Developers focus on tile-level logic, while the compiler handles the thread-level implementation": this division of labor is the core design philosophy behind how Triton lowers the barrier to kernel development.

# 2. Core Concepts & Instruction Cheat Sheet

# 2.1 Execution Model

Triton’s execution model has three key concepts:

  • Program: A single execution instance of a kernel, similar to a thread block in CUDA. Each program runs independently and has its own program ID.
  • Grid: The arrangement of all program instances, supporting 1D, 2D, and 3D.
  • Tile: The block of data processed by each program. The size of the tile is determined by a compile-time constant (tl.constexpr).

The most fundamental difference from CUDA: CUDA programmers think about “what each thread does”, while Triton programmers think about “which tile each program processes”. How threads collaborate and how data is distributed among registers is entirely handed over to the compiler.

Triton Execution Model Diagram

Core APIs:

# tl.program_id

tl.program_id(axis: int) -> int

Returns the index of the current program along the given axis of the launch grid. axis must be 0, 1, or 2, selecting one of the up to three grid dimensions. In a 2D grid, a common convention is to use tl.program_id(0) as the row-tile index and tl.program_id(1) as the column-tile index. The result is a scalar int32 value inside the kernel rather than a Python int.
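The following plain-Python sketch (no GPU or Triton required) emulates a 2D launch to make the grid concrete. The nested loops stand in for the hardware scheduler, and `pid_m`/`pid_n` play the roles of `tl.program_id(0)` and `tl.program_id(1)`; the helper names are hypothetical. On a real GPU, programs run in no guaranteed order.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b


def tiles_for_grid(M: int, N: int, BLOCK_M: int, BLOCK_N: int):
    """Emulate a 2D grid: list the rows/columns of C each program would own."""
    grid = (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N))
    tiles = []
    for pid_m in range(grid[0]):        # stands in for tl.program_id(0)
        for pid_n in range(grid[1]):    # stands in for tl.program_id(1)
            rows = range(pid_m * BLOCK_M, min((pid_m + 1) * BLOCK_M, M))
            cols = range(pid_n * BLOCK_N, min((pid_n + 1) * BLOCK_N, N))
            tiles.append((pid_m, pid_n, rows, cols))
    return grid, tiles


grid, tiles = tiles_for_grid(M=100, N=70, BLOCK_M=64, BLOCK_N=32)
print(grid)  # (2, 3): six programs in total
```

Every element of the 100 × 70 output is owned by exactly one program; the edge programs simply own fewer rows or columns.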

# tl.arange

tl.arange(start: int, end: int) -> Tensor

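Returns a 1D tensor containing the consecutive integers in [start, end). It is most commonly used to build the offsets inside a tile. start and end must be compile-time constants, and the range length end - start must be a power of 2, which is why tile sizes such as BLOCK_SIZE are powers of 2.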
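When a runtime size is not a power of 2, the usual pattern is to round the tile up to the next power of 2 and mask off the extra lanes (Triton ships a helper, triton.next_power_of_2, for this). Below is a stdlib-only sketch of that arithmetic; the function names are our own, not Triton's.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def is_power_of_2(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0


n_cols = 1000                              # runtime size, not a power of 2
BLOCK_SIZE = next_power_of_2(n_cols)       # 1024: a legal tl.arange length
offsets = list(range(0, BLOCK_SIZE))       # what tl.arange(0, BLOCK_SIZE) would produce
mask = [o < n_cols for o in offsets]       # the last 24 lanes are masked off
print(BLOCK_SIZE, sum(mask))               # 1024 1000
```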

# tl.constexpr

tl.constexpr(value)

A compile-time constant annotation (not a decorator). Parameters that must be known at compile time, such as tile sizes, are declared in the kernel signature with a type annotation, e.g. BLOCK_SIZE: tl.constexpr. Their values are baked into the compiled kernel, which lets the compiler unroll loops and eliminate branches based on them; each distinct value produces (and caches) a separately compiled kernel.

# 2.2 Memory Instructions

There are two instructions used for data transfer in Triton: tl.load and tl.store, which are responsible for loading data from VRAM into registers and writing data from registers back to VRAM, respectively.

The shapes involved must agree: mask and other are broadcast to the shape of the pointer tensor, and for tl.store the value is broadcast to the pointer's shape and implicitly cast to the pointed-to element type.

# tl.load

tl.load(pointer: Union[Tensor, BlockPtr],
        mask: Optional[Tensor] = None,
        other: Optional[Union[int, float, Tensor]] = None,
        boundary_check: Tuple[int, ...] = (),
        padding_option: str = "",
        cache_modifier: str = "",
        eviction_policy: str = "",
        volatile: bool = False) -> Tensor

Loads data from VRAM into registers.

  • pointer is a scalar pointer, a tensor of pointers, or a block pointer.
  • mask is a boolean tensor, broadcast to the shape of pointer, that marks which positions are valid memory accesses. Masked-off positions are not read and take the value of other instead.
  • other is the fill value for positions where mask is False.
  • boundary_check and padding_option apply only to block pointers: boundary_check lists the dimensions to bounds-check, and padding_option ("zero" or "nan") sets the value of out-of-bounds elements; the default "" leaves them undefined.
  • cache_modifier and eviction_policy are advanced parameters that control cache behavior; they usually do not need to be changed.

Warning

Out-of-bounds access without a mask is undefined behavior: it may silently return garbage or crash with an illegal memory access error, and it is extremely hard to debug.

The other parameter must be used in conjunction with mask, otherwise an error will be raised.

cache_modifier is only effective on NVIDIA GPUs.


Tip

1D Example:

offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < n_elements
x = tl.load(x_ptr + offs, mask=mask, other=0.0)

This means each program loads a vector block of length BLOCK_SIZE. The mask ensures that out-of-bounds positions are not accessed when n_elements is not a multiple of BLOCK_SIZE.
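To see what mask and other do, here is a pure-Python emulation of the 1D masked load for the last program of a 10-element array with BLOCK_SIZE=4. masked_load is a hypothetical helper, not a Triton API. Python would raise IndexError on the out-of-bounds offsets, which is exactly the access the mask prevents; on the GPU there is no such safety net.

```python
def masked_load(memory, offs, mask, other=0.0):
    """Emulate tl.load(ptr + offs, mask=mask, other=other) on a flat list."""
    # memory[o] is only evaluated where the mask is True.
    return [memory[o] if m else other for o, m in zip(offs, mask)]


x = [float(i) for i in range(10)]          # n_elements = 10
n_elements, BLOCK_SIZE = len(x), 4
pid = 2                                    # last program: offsets 8, 9, 10, 11
offs = [pid * BLOCK_SIZE + i for i in range(BLOCK_SIZE)]
mask = [o < n_elements for o in offs]
tile = masked_load(x, offs, mask, other=0.0)
print(tile)  # [8.0, 9.0, 0.0, 0.0]
```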

Tip

2D Example:

# Load a (BLOCK_M, BLOCK_N) sub-block of the matrix
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # shape: (BLOCK_M,)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # shape: (BLOCK_N,)
# [:, None] and [None, :] generate broadcasting to construct a (BLOCK_M, BLOCK_N) pointer matrix
ptrs = base_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tile = tl.load(ptrs, mask=mask, other=0.0)

Here, [:, None] turns (BLOCK_M,) into (BLOCK_M, 1), and [None, :] turns (BLOCK_N,) into (1, BLOCK_N). After broadcasting, a pointer tensor of (BLOCK_M, BLOCK_N) is obtained. This is the most common 2D indexing pattern in Triton.
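The same broadcasting can be spelled out with nested lists. This plain-Python sketch (hypothetical helper names) builds the (BLOCK_M, BLOCK_N) offset matrix for a row-major 3 × 5 matrix, where stride_m = N and stride_n = 1, then applies the 2D mask for a tile that hangs off the bottom edge.

```python
def tile_offsets(pid_m, pid_n, BLOCK_M, BLOCK_N, stride_m, stride_n):
    offs_m = [pid_m * BLOCK_M + i for i in range(BLOCK_M)]   # shape (BLOCK_M,)
    offs_n = [pid_n * BLOCK_N + j for j in range(BLOCK_N)]   # shape (BLOCK_N,)
    # offs_m[:, None] * stride_m + offs_n[None, :] * stride_n as a nested list
    ptrs = [[m * stride_m + n * stride_n for n in offs_n] for m in offs_m]
    return offs_m, offs_n, ptrs


M, N = 3, 5
A = [float(r * 10 + c) for r in range(M) for c in range(N)]  # A[r][c] = 10r + c, row-major
offs_m, offs_n, ptrs = tile_offsets(1, 0, BLOCK_M=2, BLOCK_N=4, stride_m=N, stride_n=1)
mask = [[m < M and n < N for n in offs_n] for m in offs_m]   # (offs_m[:, None] < M) & ...
tile = [[A[p] if ok else 0.0 for p, ok in zip(prow, mrow)]
        for prow, mrow in zip(ptrs, mask)]
print(tile)  # [[20.0, 21.0, 22.0, 23.0], [0.0, 0.0, 0.0, 0.0]]
```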

# tl.store

tl.store(
    pointer: Union[Tensor, BlockPtr],
    value: Tensor,
    mask: Optional[Tensor] = None,
    boundary_check: Tuple[int, ...] = (),
    cache_modifier: str = "",
    eviction_policy: str = "",
)

Writes data from registers back to VRAM. Parameters are similar to tl.load.

Tip

1D Example:

offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < n_elements
tl.store(out_ptr + offs, output, mask=mask)

This means each program writes a vector block of length BLOCK_SIZE back to VRAM. The mask ensures that out-of-bounds positions are not written to.

Tip

2D Example:

offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ptrs = base_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(ptrs, tile, mask=mask)

This means each program writes a (BLOCK_M, BLOCK_N) tile back to VRAM. The mask ensures out-of-bounds positions are not written to.

# tl.make_block_ptr

A block pointer is a high-level abstraction provided by Triton that describes a memory access as a "sub-block of a matrix" instead of through manually computed pointer offsets:

ptr = tl.make_block_ptr(
    base: Union[int, Tensor],  
    shape: Tuple[int, ...],
    strides: Tuple[int, ...],
    offsets: Tuple[int, ...],
    block_shape: Tuple[int, ...],
    order: Tuple[int, ...],
) -> BlockPtr
  • base is a pointer to the first element of the parent tensor.
  • shape is the full shape of the parent tensor, e.g., (M, K).
  • strides are the strides of the parent tensor in elements, e.g., (stride_m, stride_k).
  • offsets are the starting offsets of the block, e.g., (pid_m * BM, 0).
  • block_shape is the shape of the block to be loaded, e.g., (BM, BK).
  • order lists the dimensions from fastest- to slowest-varying in memory: (1, 0) means the last dimension is contiguous (row-major), and (0, 1) means column-major.

Tip

Example: tl.load using a Block pointer:

ptr = tl.make_block_ptr(
    base=a_ptr,
    shape=(M, K),
    strides=(stride_am, stride_ak),
    offsets=(pid_m * BLOCK_M, 0),
    block_shape=(BLOCK_M, BLOCK_K),
    order=(1, 0),
)
tile = tl.load(ptr, boundary_check=(0, 1), padding_option="zero")  # zero-fill like other=0.0

Compared to manual pointer arithmetic:

offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)

The semantics are much clearer, and the compiler has more room for optimization.

# tl.advance

tl.advance is the block-pointer counterpart of incrementing a pointer: it returns a new block pointer whose offsets are shifted by the given amounts:

ptr = tl.advance(ptr, offsets: Tuple[int, ...]) -> BlockPtr
  • ptr is a block pointer.
  • offsets are the element offsets to add in each dimension, e.g., (0, BLOCK_K) advances by one block along the K dimension.
  • Returns a new block pointer pointing to the newly advanced position.
  • Clear semantics allow the compiler to automatically handle boundary checks and optimizations.

Compared to manual pointer arithmetic:

a_ptrs += BLOCK_K * stride_ak

Why are block pointers recommended? Three reasons:

  1. Automatic boundary checking: Handled via a single boundary_check=(0, 1) parameter, eliminating the need to write manual mask broadcasts.
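  2. Can enable TMA: on Hopper (H100), block-shaped accesses are what the Tensor Memory Accelerator (TMA) handles, which can substantially improve memory bandwidth utilization. Whether block pointers are actually lowered to TMA depends on the Triton version; recent releases expose TMA mainly through tensor descriptors (tl.make_tensor_descriptor).
  3. Larger compiler optimization space: block pointers carry richer shape and stride information than a raw tensor of pointers, which gives the compiler more room to optimize the memory access.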
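As a mental model of make_block_ptr, advance, and boundary_check, here is a toy 2D block pointer in plain Python. The class and function names are hypothetical and it only handles non-negative offsets; it mimics the semantics of tl.load(..., boundary_check=(0, 1), padding_option="zero"), not Triton's actual lowering.

```python
from dataclasses import dataclass


@dataclass
class ToyBlockPtr:
    """Toy 2D block pointer over flat storage (not a Triton API)."""
    data: list           # flat storage of the parent tensor
    shape: tuple         # (rows, cols) of the parent tensor
    strides: tuple       # element strides for each dimension
    offsets: tuple       # top-left corner of the current block
    block_shape: tuple   # (BLOCK_R, BLOCK_C)


def advance(p: ToyBlockPtr, delta: tuple) -> ToyBlockPtr:
    """Like tl.advance: return a new pointer with shifted offsets; p is left unchanged."""
    new_offsets = (p.offsets[0] + delta[0], p.offsets[1] + delta[1])
    return ToyBlockPtr(p.data, p.shape, p.strides, new_offsets, p.block_shape)


def load(p: ToyBlockPtr):
    """Like tl.load(p, boundary_check=(0, 1), padding_option="zero")."""
    tile = []
    for i in range(p.block_shape[0]):
        r = p.offsets[0] + i
        row = []
        for j in range(p.block_shape[1]):
            c = p.offsets[1] + j
            inside = r < p.shape[0] and c < p.shape[1]
            row.append(p.data[r * p.strides[0] + c * p.strides[1]] if inside else 0.0)
        tile.append(row)
    return tile


A = [float(i) for i in range(10)]                      # a 2 x 5 row-major matrix
p = ToyBlockPtr(A, (2, 5), (5, 1), (0, 0), (2, 2))
print(load(p))                                         # [[0.0, 1.0], [5.0, 6.0]]
p_last = advance(advance(p, (0, 2)), (0, 2))          # two BLOCK_K steps along K
print(load(p_last))                                    # [[4.0, 0.0], [9.0, 0.0]]
```

The last block straddles the edge of the K dimension, and the zero padding is what keeps it from contributing garbage to a dot product.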

# 2.3 Compute Instructions

# tl.dot

tl.dot(
    input: Tensor,
    other: Tensor,
    acc: Optional[Tensor] = None,
    input_precision: Optional[str] = None,
    max_num_imprecise_acc: Optional[int] = None) -> Tensor

Computes matrix multiplication out = input @ other. Assuming input has shape (M, K) and other has shape (K, N), the output shape will be (M, N).

Tip

Example: Typical usage of tl.dot

a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
acc = tl.dot(a, b, acc)

tl.dot automatically maps to Tensor Core MMA (Matrix Multiply-Accumulate) instructions. It is the most performance-critical operation in Triton: from a single tl.dot call, the compiler generates warp-level mma.sync instructions on pre-Hopper GPUs such as Ampere and warpgroup-level wgmma instructions on Hopper.

Supported data type combinations:

| Input A | Input B | Accumulator | Hardware requirement |
|---|---|---|---|
| float16 | float16 | float32 | Volta+ (SM70) |
| bfloat16 | bfloat16 | float32 | Ampere+ (SM80) |
| float8e4nv | float8e4nv | float32 | Hopper+ (SM90) |
| int8 | int8 | int32 | Turing+ (SM75) |
| float32 | float32 | float32 | SM80+ (computed in TF32 by default) |

Important

Accumulators must ALWAYS use float32 (int32 for integer inputs). Because every partial sum is rounded to the accumulator type, a float16 accumulator loses precision quickly as K grows. This is an iron rule of Triton kernels.
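A quick way to see why: the stdlib-only sketch below emulates a length-4096 dot product of ones with a running sum rounded to float16 (via struct's "e" format) versus float32 after every add. It is a simplified serial model; Tensor Cores group their additions differently, so the exact error on real hardware will differ.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def serial_dot(a, b, round_acc):
    """Dot product whose running sum is rounded to the accumulator type after every add."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = round_acc(acc + x * y)
    return acc


K = 4096
a = [1.0] * K   # 1.0 is exact in fp16, so all the error comes from the accumulator
b = [1.0] * K
print(serial_dot(a, b, to_fp16))  # 2048.0: above 2048 the fp16 spacing is 2, so +1 rounds away
print(serial_dot(a, b, to_fp32))  # 4096.0
```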

# Reduction Operations

tl.sum(x, axis=0)    # Sum along axis=0
tl.max(x, axis=1)    # Max along axis=1
tl.min(x, axis=0)    # Min along axis=0

These reduction operations appear frequently in kernels like softmax and LayerNorm. The compiler lowers them to warp-level shuffle reductions, adding a pass through shared memory when the reduced axis spans several warps.
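Softmax is the standard example of why a max reduction precedes the sum: subtracting the row maximum keeps every exponent at or below zero. The row-wise logic in plain Python (the helper name is hypothetical):

```python
import math


def softmax_row(x):
    """Numerically stable softmax of one row, built from a max and a sum reduction."""
    m = max(x)                          # plays the role of tl.max(x, axis=...)
    e = [math.exp(v - m) for v in x]    # every exponent is <= 0, so exp() cannot overflow
    s = sum(e)                          # plays the role of tl.sum(e, axis=...)
    return [v / s for v in e]


row = [1000.0, 1000.0, 999.0]
p = softmax_row(row)
print(p)
# Without the max subtraction, math.exp(1000.0) would raise OverflowError.
```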

# Conditional Selection

tl.where(condition, x, y)

Semantically identical to NumPy’s np.where. Note: both branches are evaluated, so it cannot be used to prevent out-of-bounds access. A typical use case is implementing a causal mask—replacing attention scores of future positions with -inf.
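A plain-Python sketch of the causal-mask pattern (helper names hypothetical): the condition compares query offsets with key offsets, and both the score tile and the -inf tile exist before the select, mirroring the fact that tl.where evaluates both branches.

```python
NEG_INF = float("-inf")


def where(cond, x, y):
    """Elementwise select on 2D nested lists, like tl.where on a tile."""
    return [[xv if c else yv for c, xv, yv in zip(cr, xr, yr)]
            for cr, xr, yr in zip(cond, x, y)]


def causal_mask(scores, offs_m, offs_n):
    # Query offs_m[i] may attend to key offs_n[j] only when offs_n[j] <= offs_m[i],
    # i.e. the condition offs_m[:, None] >= offs_n[None, :].
    cond = [[m >= n for n in offs_n] for m in offs_m]
    neg = [[NEG_INF] * len(offs_n) for _ in offs_m]   # the "else" tile is built too
    return where(cond, scores, neg)


scores = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
masked = causal_mask(scores, offs_m=[0, 1, 2], offs_n=[0, 1, 2])
print(masked)  # [[1.0, -inf, -inf], [4.0, 5.0, -inf], [7.0, 8.0, 9.0]]
```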

# 2.4 Others

# @triton.autotune

Auto-tuning is a major feature of Triton. You simply provide a set of candidate configurations, and Triton will automatically benchmark each configuration on the first call, then cache the optimal one:

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=4),
    ],
    key=["M", "N", "K"],  
)
@triton.jit
def kernel(...):
    ...

The key parameter is crucial: when the values of M, N, and K change, the optimal configuration often changes as well (small matrices suit small tiles, large matrices suit large tiles). Thus, Triton searches for the optimal configuration for each different set of key values.
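To illustrate how key drives caching, here is a toy model in plain Python. It is not Triton's implementation: the real autotuner times actual kernel launches on the GPU, whereas the cost function below is invented purely so the example runs anywhere. The point is that benchmarking happens once per distinct (M, N, K) and later calls reuse the cached choice.

```python
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


class ToyAutotuner:
    """Hypothetical model of autotune caching: benchmark all configs once per key value."""

    def __init__(self, configs, key, bench):
        self.configs, self.key, self.bench = configs, key, bench
        self.cache = {}
        self.bench_calls = 0

    def pick(self, **args):
        k = tuple(args[name] for name in self.key)
        if k not in self.cache:                      # first call with this key: benchmark
            timings = []
            for cfg in self.configs:
                self.bench_calls += 1
                timings.append((self.bench(cfg, **args), cfg))
            self.cache[k] = min(timings, key=lambda t: t[0])[1]
        return self.cache[k]                         # later calls: cached choice


def fake_cost(cfg, M, N, K):
    # Invented cost model: padded tile work plus a fixed per-program overhead.
    tiles = cdiv(M, cfg["BLOCK_M"]) * cdiv(N, cfg["BLOCK_N"])
    return tiles * (cfg["BLOCK_M"] * cfg["BLOCK_N"] + 2000) * K


configs = [{"BLOCK_M": 128, "BLOCK_N": 256}, {"BLOCK_M": 64, "BLOCK_N": 128}]
tuner = ToyAutotuner(configs, key=["M", "N", "K"], bench=fake_cost)
print(tuner.pick(M=64, N=64, K=64))        # small problem -> the smaller tile
print(tuner.pick(M=4096, N=4096, K=4096))  # large problem -> the larger tile
tuner.pick(M=64, N=64, K=64)               # cache hit: no new benchmarking
print(tuner.bench_calls)                   # 4
```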

# @triton.heuristics

The heuristics decorator is used to derive compile-time constants from runtime parameters, avoiding unnecessary benchmark overhead:

@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0})

Typical application: when K is exactly a multiple of BLOCK_K, mask checks can be skipped, saving conditional branches.
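Conceptually, a heuristic is just a function of the call's arguments whose result becomes an extra constexpr. The toy below (hypothetical helper, plain Python) evaluates the same EVEN_K rule for two problem sizes; because EVEN_K is a constexpr, each distinct value corresponds to a separately compiled kernel variant.

```python
heuristics = {"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0}


def apply_heuristics(heuristics, args):
    """Toy version: add each heuristic's result to a copy of the call arguments."""
    derived = dict(args)
    for name, fn in heuristics.items():
        derived[name] = fn(derived)
    return derived


print(apply_heuristics(heuristics, {"K": 4096, "BLOCK_K": 64})["EVEN_K"])  # True
print(apply_heuristics(heuristics, {"K": 1000, "BLOCK_K": 64})["EVEN_K"])  # False (1000 % 64 == 40)
```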

# 3. Practice I: Vector Addition

Problem: Given two vectors A and B of length N, compute C[i] = A[i] + B[i]. This is the simplest Triton kernel, but it encompasses all the fundamental elements of Triton programming.

# 3.1 Tiling Strategy

Triton Vector Addition Tiling Strategy Diagram

Each program processes contiguous BLOCK_SIZE elements. The total number of programs = ceil(N / BLOCK_SIZE).
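For the test size used in the implementation that follows (98432 elements, BLOCK_SIZE=1024), the grid arithmetic works out as below; a stdlib-only sketch:

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, equivalent to triton.cdiv."""
    return (a + b - 1) // b


size, BLOCK_SIZE = 98432, 1024
num_programs = cdiv(size, BLOCK_SIZE)                 # the grid is (97,)
last_start = (num_programs - 1) * BLOCK_SIZE          # first offset of the last program
valid_in_last = size - last_start                     # lanes that pass the mask there
masked_lanes = num_programs * BLOCK_SIZE - size       # lanes masked off in total
print(num_programs, valid_in_last, masked_lanes)      # 97 128 896
```

Only the last program is partial, which is exactly the case the mask exists for.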

# 3.2 Complete Implementation

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr, y_ptr, out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(out_ptr + offsets, output, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n = output.numel()
    grid = lambda META: (triton.cdiv(n, META["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, output, n, BLOCK_SIZE=1024)
    return output

if __name__ == "__main__":
    torch.manual_seed(42)
    size = 98432 
    x = torch.rand(size, device="cuda")
    y = torch.rand(size, device="cuda")
    output_triton = add(x, y)
    output_torch = x + y
    print(f"Max error: {(output_triton - output_torch).abs().max().item()}")

# 3.3 Key Takeaways

  1. The role of the mask: When n_elements is not a multiple of BLOCK_SIZE, part of the offsets in the final program will exceed the array bounds. The mask ensures these out-of-bounds positions are neither read nor written. Forgetting the mask = undefined behavior.
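  2. Why BLOCK_SIZE must be a power of 2: tl.arange requires the length of its range to be a power of 2. This is a constraint of the Triton language and compiler, which keeps the mapping of tile elements onto threads and warps simple, rather than a hardware rule; BLOCK_SIZE does not need to be a multiple of the warp size (32).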
  3. triton.cdiv(n, BLOCK_SIZE): Ceiling division, equivalent to ceil(n / BLOCK_SIZE). It ensures all elements are covered.

# 4. Practice II: Matrix Multiplication

Matrix Multiplication (GEMM) is the most important computation kernel on the GPU. Almost all computational load in deep learning ultimately boils down to it. In this section, we will start from the tiling strategy and progressively build a high-performance Triton GEMM kernel.

# 4.1 Tiling Strategy

Triton Matrix Multiplication Tiling Strategy Diagram

Core algorithm:

  1. Each program is responsible for computing a (BLOCK_M, BLOCK_N) sub-block of matrix C.
  2. Initialize an fp32 accumulator to zero.
  3. Loop along the K dimension: In each step, load a (BLOCK_M, BLOCK_K) block from A and a (BLOCK_K, BLOCK_N) block from B, and perform a tl.dot.
  4. After the loop ends, cast the accumulator to the output precision and write it back to global memory (HBM).
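The four steps can be checked without a GPU. This plain-Python sketch (function names hypothetical, tiny tiles for readability) runs the same per-program loop: a zeroed accumulator, a walk along K in BLOCK_K steps with out-of-range elements read as 0, and a masked write-back. It is compared against a naive triple loop.

```python
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def tiled_matmul(A, B, BLOCK_M=2, BLOCK_N=2, BLOCK_K=2):
    M, K, N = len(A), len(A[0]), len(B[0])

    def load_a(r, k):                       # masked load with other=0.0
        return A[r][k] if r < M and k < K else 0.0

    def load_b(k, c):
        return B[k][c] if k < K and c < N else 0.0

    C = [[0.0] * N for _ in range(M)]
    for pid_m in range(cdiv(M, BLOCK_M)):                # step 1: one program per C tile
        for pid_n in range(cdiv(N, BLOCK_N)):
            acc = [[0.0] * BLOCK_N for _ in range(BLOCK_M)]      # step 2: zeroed accumulator
            for k0 in range(0, K, BLOCK_K):                      # step 3: loop along K
                for i in range(BLOCK_M):
                    for j in range(BLOCK_N):
                        r, c = pid_m * BLOCK_M + i, pid_n * BLOCK_N + j
                        acc[i][j] += sum(load_a(r, k0 + kk) * load_b(k0 + kk, c)
                                         for kk in range(BLOCK_K))
            for i in range(BLOCK_M):                             # step 4: masked store
                for j in range(BLOCK_N):
                    r, c = pid_m * BLOCK_M + i, pid_n * BLOCK_N + j
                    if r < M and c < N:
                        C[r][c] = acc[i][j]
    return C


def naive_matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


A = [[float(r + 2 * k) for k in range(5)] for r in range(3)]   # 3 x 5
B = [[float(k - c) for c in range(3)] for k in range(5)]       # 5 x 3
print(tiled_matmul(A, B) == naive_matmul(A, B))  # True
```

None of the dimensions (3, 5, 3) is a multiple of the tile size 2, so every edge case of the masking is exercised.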

# 4.2 Basic Implementation

import torch
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr
):
    # Row-major Mapping
    pid = tl.program_id(0)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    
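    # % M and % N wrap out-of-range rows/columns back into bounds, so the A/B loads need
    # no M/N mask; the wrapped results are discarded by the masked store at the end.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N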
    offs_k = tl.arange(0, BLOCK_K)
    
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    
    for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        acc = tl.dot(a, b, acc)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
        
    c = acc.to(tl.float16)
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.shape[1] == b.shape[0], "Dimensions mismatch"
    assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=128, BLOCK_N=128, BLOCK_K=32
    )
    return c

This is the most basic version, only about 50 lines of code. The measurements below were taken on an A100. At its best (4096 × 4096 × 4096, 217.91 TFLOPS), it reaches about 70% of the A100's theoretical dense fp16 Tensor Core peak of 312 TFLOPS.

| Size (M, K, N) | Time (ms) | TFLOPS |
|---|---|---|
| (512, 1024, 512) | 0.025 | 21.34 |
| (1024, 1024, 1024) | 0.026 | 82.78 |
| (2048, 2048, 2048) | 0.094 | 182.18 |
| (4096, 4096, 4096) | 0.631 | 217.91 |
| (8192, 8192, 8192) | 6.297 | 174.62 |
| (16384, 16384, 16384) | 59.825 | 147.03 |
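The TFLOPS column follows from the standard GEMM operation count: M·N·K multiply-adds, i.e. 2·M·N·K floating-point operations. Recomputing from the rounded times in the table reproduces the reported values to within rounding:

```python
def tflops(M: int, K: int, N: int, time_ms: float) -> float:
    """Throughput of an (M x K) @ (K x N) GEMM: 2*M*N*K flops over the runtime."""
    return 2 * M * N * K / (time_ms * 1e-3) / 1e12


A100_FP16_PEAK_TFLOPS = 312.0   # dense fp16 Tensor Core peak quoted in the text

t = tflops(4096, 4096, 4096, 0.631)
print(f"{t:.1f} TFLOPS, {t / A100_FP16_PEAK_TFLOPS:.0%} of peak")  # 217.8 TFLOPS, 70% of peak
```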

# 4.3 Advanced Optimizations

Matrix multiplication (GEMM) is an extremely fundamental operator, and common optimization strategies include swizzling, software pipelining, warp specialization, auto-tuning, and Tensor Core (WMMA/MMA) usage. Shared memory swizzling, which avoids bank conflicts, is already handled automatically by Triton. This section therefore focuses on tile swizzling (Grouped Ordering), block pointers, and auto-tuning; software pipelining is controlled through the num_stages field of each autotune config. The implementation of these optimizations in Triton is detailed below.

# Optimization 1: Grouped Ordering

Grouped Ordering is a tile scheduling strategy primarily aimed at solving the issue of low L2 cache hit rates.

In the basic GEMM implementation above, program IDs are mapped in row-major order: consecutive program IDs walk along one row of C tiles (pid_n changes fastest) before moving on to the next row.

  • Execution Order: Finish computing all tiles in row 0, then move to row 1.
  • Cache Issue: When computing row 0, we need to load the 0th row block of matrix A and all column blocks of matrix B. If matrix B is very large, by the time we start computing row 1, the matrix B data previously loaded into the L2 cache might have already been evicted.
  • Consequence: Every row computation has to re-read matrix B entirely from VRAM, causing immense bandwidth pressure.

The core idea of Grouped Ordering is to bundle GROUP_SIZE_M consecutive rows of output tiles into a group and, within a group, walk down the rows column by column (pid_m changes fastest). For example, with GROUP_SIZE_M=4, four consecutive programs read the same column tile of B, and the same four row tiles of A are reused as the walk moves to the next column, so both stay hot in the L2 cache.

Grouped Ordering Diagram

In a naive ordering, programs (0,0) and (1,0) both require the 0th column tile of B, but they are separated by num_pid_n programs. When num_pid_n is large, the 0th column tile of B is long evicted from the L2 cache. Grouped Ordering makes programs that require the same B tile execute consecutively, significantly boosting the L2 hit rate.
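The index arithmetic used by the final kernel can be replayed in plain Python. For an 8 × 8 grid of C tiles, the sketch below counts how many distinct input tiles the first 16 programs read, a rough proxy for the L2 working set under the assumption that those programs are in flight together. The helper names are hypothetical; grouped_pid copies the kernel's formulas.

```python
def grouped_pid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M):
    """Same index arithmetic as the grouped-ordering kernel, in plain Python."""
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)   # last group may be short
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


def row_major_pid(pid, num_pid_m, num_pid_n):
    return pid // num_pid_n, pid % num_pid_n


def tiles_loaded(order, first_k):
    """Distinct A row-tiles plus B column-tiles read by the first first_k programs."""
    window = order[:first_k]
    return len({m for m, _ in window}) + len({n for _, n in window})


nm, nn, G = 8, 8, 4
grouped = [grouped_pid(p, nm, nn, G) for p in range(nm * nn)]
naive = [row_major_pid(p, nm, nn) for p in range(nm * nn)]
print(grouped[:4])                                          # [(0, 0), (1, 0), (2, 0), (3, 0)]
print(tiles_loaded(grouped, 16), tiles_loaded(naive, 16))   # 8 10
```

Both orderings cover every tile exactly once; only the visiting order, and therefore the data reuse, changes.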

# Optimization 2: Block Pointer

Rewrite the core loop using tl.make_block_ptr:

a_block = tl.make_block_ptr(a_ptr, (M, K), (stride_am, stride_ak),
                             (pid_m * BLOCK_M, 0), (BLOCK_M, BLOCK_K), (1, 0))
b_block = tl.make_block_ptr(b_ptr, (K, N), (stride_bk, stride_bn),
                             (0, pid_n * BLOCK_N), (BLOCK_K, BLOCK_N), (1, 0))
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
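    # Zero-pad out-of-range elements so the K tail adds nothing to the dot product.
    a = tl.load(a_block, boundary_check=(0, 1), padding_option="zero")
    b = tl.load(b_block, boundary_check=(0, 1), padding_option="zero")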
    acc = tl.dot(a, b, acc)
    a_block = tl.advance(a_block, (0, BLOCK_K))
    b_block = tl.advance(b_block, (BLOCK_K, 0))

Section 2.2 already introduced tl.make_block_ptr and tl.advance, so we won’t repeat it here. In short, the benefits of this optimization are:

  1. Cleaner code: No need to manually calculate pointer offsets and masks; semantics are much clearer.
  2. Automatic boundary checking: The boundary_check=(0, 1) parameter lets the compiler automatically handle out-of-bounds access, avoiding the complex mask calculation and broadcasting from the previous version.
  3. Can enable TMA: on Hopper (H100), block-shaped accesses are the ones the Tensor Memory Accelerator can serve, although whether block pointers are actually lowered to TMA depends on the Triton version.

# Optimization 3: Autotune

We can use @triton.autotune and @triton.heuristics for auto-tuning and heuristic optimization, allowing Triton to automatically select the optimal combination of BLOCK_M, BLOCK_N, BLOCK_K, and GROUP_SIZE_M under different M, N, and K dimensions, thereby achieving excellent performance across various matrix sizes.

The final implemented version looks like this:

import torch
import triton
import triton.language as tl

def get_autotune_config():
    return [
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        # Good configs for fp8 inputs.
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4)
    ]

@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0})
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, EVEN_K: tl.constexpr
):
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    
    a_block = tl.make_block_ptr(
        a_ptr, (M, K), (stride_am, stride_ak), (pid_m * BLOCK_M, 0), (BLOCK_M, BLOCK_K), (1, 0))
    b_block = tl.make_block_ptr(
        b_ptr, (K, N), (stride_bk, stride_bn), (0, pid_n * BLOCK_N), (BLOCK_K, BLOCK_N), (1, 0))
        
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    
    for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
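        if EVEN_K:
            # K divides evenly, so only the M edge (A, dim 0) and N edge (B, dim 1) need checks.
            a = tl.load(a_block, boundary_check=(0,))
            b = tl.load(b_block, boundary_check=(1,))
        else:
            # Zero-pad the K tail so out-of-range elements add nothing to the dot product.
            a = tl.load(a_block, boundary_check=(0, 1), padding_option="zero")
            b = tl.load(b_block, boundary_check=(0, 1), padding_option="zero")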
        acc = tl.dot(a, b, acc)
        a_block = tl.advance(a_block, (0, BLOCK_K))
        b_block = tl.advance(b_block, (BLOCK_K, 0))
        
    c = acc.to(tl.float16)
    c_block = tl.make_block_ptr(c_ptr, (M, N), (stride_cm, stride_cn),
                                (pid_m * BLOCK_M, pid_n * BLOCK_N), (BLOCK_M, BLOCK_N), (1, 0))
    tl.store(c_block, c, boundary_check=(0, 1))

def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.shape[1] == b.shape[0], "Dimensions mismatch"
    assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
    
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c

The code grows to only about 100 lines, but the improvement is significant, especially on large matrices: at 16384 × 16384 × 16384, throughput rises from 147.03 to 231.15 TFLOPS (about 74% of the 312 TFLOPS peak), close to the fp16 performance of cuBLAS:

| Size (M, K, N) | Time (ms) | TFLOPS |
|---|---|---|
| (512, 1024, 512) | 0.014 | 38.31 |
| (1024, 1024, 1024) | 0.022 | 98.55 |
| (2048, 2048, 2048) | 0.091 | 189.26 |
| (4096, 4096, 4096) | 0.621 | 221.42 |
| (8192, 8192, 8192) | 4.905 | 224.14 |
| (16384, 16384, 16384) | 38.054 | 231.15 |

# 5. Practice III: FlashAttention

# 5.1 Memory Bottleneck of Standard Attention

The standard Multi-Head Attention computation process is as follows:

  1. Compute Attention Scores: Multiply the query matrix $Q$ with the transposed key matrix $K^T$ to generate the similarity matrix $S$. $$S = Q K^T, \quad S \in \mathbb{R}^{B \times H \times N \times N}$$ This step costs $O(N^2 d)$ time and $O(N^2)$ memory, producing a massive intermediate tensor.
  2. Normalization (Softmax): Scale $S$ and apply the Softmax operation to obtain the attention weight probability distribution $P$. $$P = \text{softmax}\left(\frac{S}{\sqrt{d_k}}\right), \quad P \in \mathbb{R}^{B \times H \times N \times N}$$ This produces another $O(N^2)$ full-scale weight matrix.
  3. Compute Weighted Output: Use the weight matrix $P$ to perform a weighted aggregation over the value matrix $V$. $$O = P V, \quad O \in \mathbb{R}^{B \times H \times N \times d}$$

This process generates two intermediate matrices of size $O(N^2)$: $S = Q K^T$ and $P = \text{softmax}(S)$. When the sequence length $N$ grows to 32K or even 128K, they no longer fit in GPU memory, and even when they do, moving these two giant matrices back and forth between HBM and on-chip SRAM makes attention bandwidth-bound.

FlashAttention's Solution: Never materialize the $S$ and $P$ matrices. Instead, partition $Q$, $K$, and $V$ into blocks (tiling) and fuse the whole computation so that each block is processed in on-chip SRAM in a single pass, reducing the extra memory from $O(N^2)$ to $O(N)$.
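To put numbers on the difference: assuming batch size B = 1, H = 32 heads, N = 32K, 2-byte fp16 scores, and fp32 row statistics (all assumptions chosen for illustration), one materialized $S$ already needs 64 GiB, while the per-row statistics $m$ and $l$ that FlashAttention keeps need only 8 MiB:

```python
GiB, MiB = 1024 ** 3, 1024 ** 2


def score_matrix_bytes(B: int, H: int, N: int, bytes_per_elem: int = 2) -> int:
    """Memory for one materialized B x H x N x N matrix such as S or P."""
    return B * H * N * N * bytes_per_elem


def row_stats_bytes(B: int, H: int, N: int, bytes_per_elem: int = 4) -> int:
    """Memory for the per-row running max m and denominator l (two values per row)."""
    return B * H * N * 2 * bytes_per_elem


B, H, N = 1, 32, 32 * 1024
print(score_matrix_bytes(B, H, N) / GiB)   # 64.0 (and P doubles it)
print(row_stats_bytes(B, H, N) / MiB)      # 8.0
```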

# 5.2 Online Softmax Algorithm

The core idea of Online Softmax is to maintain a state triplet $(m_i, l_i, \mathbf{o}_i)$ for each query row and update it incrementally, one $K/V$ block at a time. The state starts at $m_i = -\infty$, $l_i = 0$, $\mathbf{o}_i = \mathbf{0}$, and its components are:

  • $m_i$: The maximum value of the row processed so far.
  • $l_i$: The current local normalization constant (Softmax denominator).
  • $\mathbf{o}_i$: The current weighted accumulation result (output vector).

For each new $K/V$ block $j$, the update steps are as follows:

  1. Compute Local Similarity: Calculate the scaled scores $S^{(j)}$ for the current block (the softmax scale from Section 5.1 is folded in here): $$S^{(j)} = \frac{Q_{\text{block}} K_{\text{block}, j}^T}{\sqrt{d_k}}$$
  2. Update Row Maximum: Compare the maximum of the current block with the previous maximum to obtain the new running maximum $m_{\text{new}}$: $$m_{\text{new}} = \max\left(m_i, \operatorname{rowmax}(S^{(j)})\right)$$
  3. Compute Correction Coefficient: To keep the old statistics aligned with the new maximum, compute the scaling factor $\alpha$: $$\alpha = \exp(m_i - m_{\text{new}})$$
  4. Update Normalization Denominator: Rescale the old denominator by the correction coefficient and add the contribution of the current block: $$l_{\text{new}} = l_i \cdot \alpha + \operatorname{rowsum}\left(\exp(S^{(j)} - m_{\text{new}})\right)$$
  5. Update Accumulator: Rescale the old output vector and accumulate the contribution of the current block ($\exp$ is applied elementwise): $$\mathbf{o}_{\text{new}} = \mathbf{o}_i \cdot \alpha + \exp(S^{(j)} - m_{\text{new}}) \times V_{\text{block}, j}$$
  6. Iterate State: Carry the updated state to the next block: $$m_i \leftarrow m_{\text{new}}, \quad l_i \leftarrow l_{\text{new}}, \quad \mathbf{o}_i \leftarrow \mathbf{o}_{\text{new}}$$
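The six steps can be sketched in NumPy for a single head, assuming non-causal attention and base-e exponentials as in the formulas; the function names are hypothetical and the $1/\sqrt{d_k}$ factor is folded into the block scores. The sketch checks the block-wise result against a direct softmax, including a final partial block.

```python
import numpy as np


def online_softmax_attention(q, k, v, block_n=32):
    """Single-head attention computed block by block over K/V.

    q: (N_q, d); k, v: (N_kv, d). Only the running state (m, l, o) per query
    row is kept; no N_q x N_kv matrix is ever materialized.
    """
    n_q, d = q.shape
    m = np.full(n_q, -np.inf)   # running row maximum m_i
    l = np.zeros(n_q)           # running denominator l_i
    o = np.zeros((n_q, d))      # running un-normalized output o_i
    for start in range(0, k.shape[0], block_n):
        k_j = k[start:start + block_n]
        v_j = v[start:start + block_n]
        s_j = q @ k_j.T / np.sqrt(d)               # step 1: scaled block scores
        m_new = np.maximum(m, s_j.max(axis=1))     # step 2: new running max
        alpha = np.exp(m - m_new)                  # step 3: correction (0 on first block)
        p_j = np.exp(s_j - m_new[:, None])
        l = l * alpha + p_j.sum(axis=1)            # step 4: rescale, then add block
        o = o * alpha[:, None] + p_j @ v_j         # step 5: rescale, then add block
        m = m_new                                  # step 6: carry the state
    return o / l[:, None]                          # final normalization


def reference_attention(q, k, v):
    """Direct softmax(Q K^T / sqrt(d)) V for comparison."""
    s = q @ k.T / np.sqrt(q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


rng = np.random.default_rng(0)
q = rng.standard_normal((64, 16))
k = rng.standard_normal((200, 16))   # 200 is not a multiple of block_n
v = rng.standard_normal((200, 16))
print(np.allclose(online_softmax_attention(q, k, v), reference_attention(q, k, v)))  # True
```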

Algorithm Correctness: The correction coefficient $\alpha = \exp(m_i - m_{\text{new}})$ is the key to the algorithm. Whenever a block contains a larger score ($m_{\text{new}} > m_i$), $\alpha < 1$ rescales the terms accumulated so far, which were computed relative to the smaller $m_i$, so that every term in $l_i$ and $\mathbf{o}_i$ is expressed relative to the same maximum. After all blocks have been processed, the final output $O = \mathbf{o}_i / l_i$ is mathematically identical to computing the softmax over the whole row at once; only floating-point rounding differs.
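The correctness argument can be made precise with a short induction, written per query row $i$, where $s_c = (Q K^T)_{ic} / \sqrt{d_k}$ is the scaled score of key column $c$ and $\mathbf{v}_c$ is the matching row of $V$. The claim is that after every update the state satisfies

$$m_i = \max_{c \in \mathcal{C}} s_c, \qquad l_i = \sum_{c \in \mathcal{C}} e^{s_c - m_i}, \qquad \mathbf{o}_i = \sum_{c \in \mathcal{C}} e^{s_c - m_i}\, \mathbf{v}_c$$

where $\mathcal{C}$ is the set of columns processed so far. This holds before the first block ($\mathcal{C} = \emptyset$, $m_i = -\infty$, $l_i = 0$, $\mathbf{o}_i = \mathbf{0}$, so $\alpha = 0$ on the first update). If it holds before block $j$ with columns $\mathcal{B}_j$, then

$$l_{\text{new}} = \alpha\, l_i + \sum_{c \in \mathcal{B}_j} e^{s_c - m_{\text{new}}} = \sum_{c \in \mathcal{C}} e^{s_c - m_i}\, e^{m_i - m_{\text{new}}} + \sum_{c \in \mathcal{B}_j} e^{s_c - m_{\text{new}}} = \sum_{c \in \mathcal{C} \cup \mathcal{B}_j} e^{s_c - m_{\text{new}}}$$

and the same computation with an extra factor $\mathbf{v}_c$ gives $\mathbf{o}_{\text{new}}$. After the last block, $\mathcal{C}$ contains every column and $m_i = M$ is the true row maximum, so

$$\frac{\mathbf{o}_i}{l_i} = \frac{\sum_{c} e^{s_c - M}\, \mathbf{v}_c}{\sum_{c} e^{s_c - M}} = \sum_{c} \frac{e^{s_c}}{\sum_{c'} e^{s_{c'}}}\, \mathbf{v}_c$$

which is exactly row $i$ of $\text{softmax}(S / \sqrt{d_k})\, V$: the common factor $e^{-M}$ cancels.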

# 5.3 Kernel Implementation

```python
import torch
import triton
import triton.language as tl


def get_autotune_config():
    return [
        # Small blocks: short sequences / small head_dim
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32,
                       'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64,
                       'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64,
                       'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        # Medium blocks
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32,
                       'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64,
                       'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128,
                       'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128,
                       'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        # Large blocks: long sequences
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64,
                       'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128,
                       'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        # Higher pipeline depth
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64,
                       'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64,
                       'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128,
                       'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8),
    ]


@triton.autotune(
    configs=get_autotune_config(),
    # Causal and non-causal calls do different amounts of work, so tune them separately
    key=['seq_len', 'HEAD_DIM', 'IS_CAUSAL'],
)
@triton.heuristics({
    # True when seq_len is a multiple of BLOCK_N: full K/V tiles need no bounds checks
    "EVEN_N": lambda args: args["seq_len"] % args["BLOCK_N"] == 0,
})
@triton.jit
def flash_attention_fwd_kernel(
    Q, K, V, O,
    sm_scale,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_ok,
    n_heads,
    seq_len,
    n_batch_heads,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    EVEN_N: tl.constexpr,
):
    # Grouped ordering: consecutive program IDs take up to GROUP_SIZE_M query
    # blocks of the same (batch, head), so they share K/V tiles through L2.
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(seq_len, BLOCK_M)
    num_pid_in_group = GROUP_SIZE_M * n_batch_heads
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_bh = (pid % num_pid_in_group) // group_size_m
    batch_id = pid_bh // n_heads
    head_id = pid_bh % n_heads

    # Base offsets for this (batch, head)
    q_offset = batch_id * stride_qz + head_id * stride_qh
    k_offset = batch_id * stride_kz + head_id * stride_kh
    v_offset = batch_id * stride_vz + head_id * stride_vh
    o_offset = batch_id * stride_oz + head_id * stride_oh

    q_block = tl.make_block_ptr(
        Q + q_offset, (seq_len, HEAD_DIM), (stride_qm, stride_qk),
        (pid_m * BLOCK_M, 0), (BLOCK_M, HEAD_DIM), (1, 0),
    )
    k_block = tl.make_block_ptr(
        K + k_offset, (seq_len, HEAD_DIM), (stride_kn, stride_kk),
        (0, 0), (BLOCK_N, HEAD_DIM), (1, 0),
    )
    v_block = tl.make_block_ptr(
        V + v_offset, (seq_len, HEAD_DIM), (stride_vn, stride_vk),
        (0, 0), (BLOCK_N, HEAD_DIM), (1, 0),
    )

    # Fold the softmax scale and log2(e) into Q so the loop can use exp2 instead of exp
    qk_scale = sm_scale * 1.44269504088896340736

    q = tl.load(q_block, boundary_check=(0,))
    q = (q.to(tl.float32) * qk_scale).to(Q.dtype.element_ty)

    # Online softmax accumulators
    m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    # Row indices (needed for causal / boundary masks on attention scores)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    if IS_CAUSAL:
        # Phase 1: K/V blocks entirely below the diagonal (no causal mask needed)
        full_blocks = (pid_m * BLOCK_M) // BLOCK_N
        # Phase 2 ends at the last block touching the diagonal, capped at the
        # number of K/V blocks that actually exist
        n_blocks = min(tl.cdiv((pid_m + 1) * BLOCK_M, BLOCK_N),
                       tl.cdiv(seq_len, BLOCK_N))
    else:
        full_blocks = tl.cdiv(seq_len, BLOCK_N)
        n_blocks = full_blocks

    # Phase 1: no causal mask
    for block_n in range(0, full_blocks):
        cols = block_n * BLOCK_N + tl.arange(0, BLOCK_N)
        if EVEN_N:
            k = tl.load(k_block)
        else:
            k = tl.load(k_block, boundary_check=(0,), padding_option="zero")

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))

        if not EVEN_N:
            # Mask key columns past the end of the sequence
            qk = tl.where(cols[None, :] < seq_len, qk, float("-inf"))

        # Online softmax with exp2
        m_ij = tl.max(qk, axis=1)
        m_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp2(m_i - m_new)
        p = tl.exp2(qk - m_new[:, None])

        l_i = l_i * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None]

        if EVEN_N:
            v = tl.load(v_block)
        else:
            # Zero-fill rows past seq_len so padding cannot leak into acc
            v = tl.load(v_block, boundary_check=(0,), padding_option="zero")
        acc += tl.dot(p.to(v.dtype), v)

        m_i = m_new
        k_block = tl.advance(k_block, (BLOCK_N, 0))
        v_block = tl.advance(v_block, (BLOCK_N, 0))

    # Phase 2: blocks that intersect the diagonal (empty when not causal)
    for block_n in range(full_blocks, n_blocks):
        cols = block_n * BLOCK_N + tl.arange(0, BLOCK_N)
        k = tl.load(k_block, boundary_check=(0,), padding_option="zero")
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))

        if IS_CAUSAL:
            # Query row i may only attend to key columns j <= i
            qk = tl.where(offs_m[:, None] >= cols[None, :], qk, float("-inf"))
        if not EVEN_N:
            qk = tl.where(cols[None, :] < seq_len, qk, float("-inf"))

        m_ij = tl.max(qk, axis=1)
        m_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp2(m_i - m_new)
        p = tl.exp2(qk - m_new[:, None])

        l_i = l_i * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None]

        v = tl.load(v_block, boundary_check=(0,), padding_option="zero")
        acc += tl.dot(p.to(v.dtype), v)

        m_i = m_new
        k_block = tl.advance(k_block, (BLOCK_N, 0))
        v_block = tl.advance(v_block, (BLOCK_N, 0))

    # Final normalization: O = o_i / l_i
    acc = acc / l_i[:, None]

    o_block = tl.make_block_ptr(
        O + o_offset, (seq_len, HEAD_DIM), (stride_om, stride_ok),
        (pid_m * BLOCK_M, 0), (BLOCK_M, HEAD_DIM), (1, 0),
    )
    tl.store(o_block, acc.to(O.dtype.element_ty), boundary_check=(0,))


def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    assert q.dim() == 4, "Expected shape (batch, n_heads, seq_len, head_dim)"
    assert k.shape == q.shape and v.shape == q.shape, "q, k, v must have the same shape"
    batch, n_heads, seq_len, head_dim = q.shape
    assert head_dim in {16, 32, 64, 128, 256}, "head_dim must be a power of 2 in [16, 256]"
    assert q.is_contiguous() and k.is_contiguous() and v.is_contiguous()

    o = torch.empty_like(q)
    sm_scale = head_dim ** -0.5
    n_batch_heads = batch * n_heads

    # 1-D grid: the kernel delinearizes pid into (query block, batch*head)
    def grid(META):
        return (triton.cdiv(seq_len, META['BLOCK_M']) * n_batch_heads,)

    flash_attention_fwd_kernel[grid](
        q, k, v, o,
        sm_scale,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        n_heads,
        seq_len,
        n_batch_heads,
        HEAD_DIM=head_dim,
        IS_CAUSAL=causal,
    )
    return o
```

  • get_autotune_config defines a range of block sizes, pipeline depths, and warp counts, so the autotuner can pick a good configuration for each sequence length and head dimension.
  • Grouped ordering carries over from the GEMM kernel: consecutive program IDs process up to GROUP_SIZE_M query blocks of the same (batch, head) pair, so they read the same K/V tiles and benefit from L2 cache reuse.
  • Pre-scaling Q by sm_scale * log2(e) lets the inner loop use tl.exp2, which maps to the GPU's fast base-2 exponential, instead of multiplying every score tile by the scale and calling exp. Since $2^{x \log_2 e} = e^x$, the online softmax result is unchanged.
  • Split causal loop: the K/V loop runs in two phases. Phase 1 covers blocks that lie entirely below the diagonal, so the causal mask is skipped; Phase 2 applies the tl.where causal mask only to the blocks that intersect the diagonal (BLOCK_M / BLOCK_N of them when BLOCK_M ≥ BLOCK_N, otherwise a single block).
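The kernel's index arithmetic can be checked on the CPU. This pure-Python sketch mirrors the grouped-ordering delinearization and the two causal loop ranges (the function names are hypothetical, and the end-of-sequence boundary is ignored). It verifies that every (query block, batch·head) pair is visited exactly once, that Phase 1 tiles never need the causal mask, and how many diagonal blocks Phase 2 handles.

```python
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def grouped_pid(pid, num_pid_m, n_batch_heads, group_size_m_max):
    """Mirror of the kernel's 1-D pid -> (pid_m, pid_bh) mapping."""
    num_pid_in_group = group_size_m_max * n_batch_heads
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m_max
    group_size_m = min(num_pid_m - first_pid_m, group_size_m_max)
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_bh = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_bh


def causal_phases(pid_m, block_m, block_n):
    """K/V block ranges of the causal loops: (phase 1 unmasked, phase 2 masked)."""
    full_blocks = (pid_m * block_m) // block_n
    n_blocks = cdiv((pid_m + 1) * block_m, block_n)
    return range(0, full_blocks), range(full_blocks, n_blocks)


def check_causal(pid_m, block_m, block_n):
    """Assert the phase split is safe; return the number of masked blocks."""
    first_row = pid_m * block_m
    last_row = (pid_m + 1) * block_m - 1
    unmasked, masked = causal_phases(pid_m, block_m, block_n)
    for b in unmasked:
        last_col = (b + 1) * block_n - 1
        assert last_col <= first_row     # whole tile on or below the diagonal
    # every block from masked.stop onward starts after the last query row
    assert masked.stop * block_n > last_row
    return len(masked)


# Grouped ordering visits every (pid_m, batch*head) pair exactly once
num_pid_m, n_bh = 10, 6
pairs = {grouped_pid(p, num_pid_m, n_bh, 8) for p in range(num_pid_m * n_bh)}
assert pairs == {(m, bh) for m in range(num_pid_m) for bh in range(n_bh)}

print(check_causal(5, 128, 64))  # 2
print(check_causal(5, 64, 128))  # 1
```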

# 6. Benchmark: How to Test Your Kernel’s Performance

Writing a correct kernel is only the first step; in GPU programming, performance is what ultimately counts. PyTorch provides the correctness checks, and Triton ships a set of benchmarking utilities in triton.testing.

# 6.1 torch.testing.assert_close

Numerical correctness is a prerequisite for performance optimization. torch.testing.assert_close checks whether two tensors are approximately equal elementwise, using an absolute tolerance (atol) and a relative tolerance (rtol): each element must satisfy |actual - expected| <= atol + rtol * |expected|.

```python
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
```

Suggested tolerances for different precisions:

| dtype | atol | rtol | Notes |
| --- | --- | --- | --- |
| float32 | 1e-5 | 1e-5 | Strictest |
| float16 | 1e-2 | 1e-2 | Larger accumulated rounding error in half precision |
| bfloat16 | 1e-2 | 1e-2 | Shorter mantissa than float16, lower precision |
| float8 | 1e-1 | 1e-1 | Very low precision; tolerances must be loose |
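To see what these tolerances mean in practice, here is a NumPy sketch of the same elementwise criterion, using the table's suggested values (close_mask and TOLERANCES are hypothetical names). The real torch.testing.assert_close additionally checks dtype, device, and NaN/inf handling; this only illustrates the numeric test.

```python
import numpy as np

# Suggested (atol, rtol) pairs from the table above, keyed by dtype name
TOLERANCES = {
    "float32": (1e-5, 1e-5),
    "float16": (1e-2, 1e-2),
    "bfloat16": (1e-2, 1e-2),
    "float8": (1e-1, 1e-1),
}


def close_mask(actual, expected, dtype_name):
    """Elementwise |actual - expected| <= atol + rtol * |expected|."""
    atol, rtol = TOLERANCES[dtype_name]
    actual = np.asarray(actual, dtype=np.float64)
    expected = np.asarray(expected, dtype=np.float64)
    return np.abs(actual - expected) <= atol + rtol * np.abs(expected)


ref = np.array([1.0, 100.0, 0.0])
out = np.array([1.015, 100.9, 0.009])
print(close_mask(out, ref, "float16"))  # all True
print(close_mask(out, ref, "float32"))  # all False
```

The relative term dominates for large values (100.9 passes against 100.0 under the float16 tolerances), while the absolute term matters near zero.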

# 6.2 triton.testing.do_bench

# Timing

```python
import triton

# my_kernel, grid, and the argument list are placeholders for your own kernel launch
ms, min_ms, max_ms = triton.testing.do_bench(
    lambda: my_kernel[grid](...),
    warmup=25,  # warmup time budget in ms (absorbs JIT compilation and cold caches)
    rep=100,    # measurement time budget in ms
    quantiles=[0.5, 0.2, 0.8],  # returns the median, 20th, and 80th percentiles
)
print(f"Median: {ms:.3f} ms  [p20={min_ms:.3f}, p80={max_ms:.3f}]")
```

Tip

Example: Measuring the latency of a softmax kernel:

```python
ms = triton.testing.do_bench(lambda: softmax(x), warmup=25, rep=100)
print(f"Softmax latency: {ms:.3f} ms")
```

Warmup is crucial: the first call triggers JIT compilation (and, with @triton.autotune, a benchmark of every candidate config), which can take orders of magnitude longer than the kernel itself. do_bench first runs the function for roughly warmup milliseconds and discards those timings before it starts measuring.
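The warmup-then-measure pattern can be sketched on the CPU with the standard library. This is a simplified illustration with hypothetical names, not Triton's implementation: the real triton.testing.do_bench times on the GPU with CUDA events and also flushes the L2 cache between runs, while this sketch uses wall-clock time and caps the call count to stay fast.

```python
import statistics
import time


def simple_bench(fn, warmup_ms=25.0, rep_ms=100.0, max_calls=100_000):
    """Warmup-then-measure loop on the CPU; returns (median, p20, p80) in ms."""
    # Time one call to estimate how many calls fit in each time budget
    t0 = time.perf_counter()
    fn()
    estimate_ms = max((time.perf_counter() - t0) * 1e3, 1e-6)
    n_warmup = min(max(1, int(warmup_ms / estimate_ms)), max_calls)
    n_repeat = min(max(2, int(rep_ms / estimate_ms)), max_calls)

    for _ in range(n_warmup):      # warmup runs: timings discarded
        fn()

    times_ms = []
    for _ in range(n_repeat):      # measured runs
        t0 = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - t0) * 1e3)

    deciles = statistics.quantiles(times_ms, n=10)  # 10th, 20th, ..., 90th
    return statistics.median(times_ms), deciles[1], deciles[7]


med, p20, p80 = simple_bench(lambda: sum(range(10_000)))
print(f"Median: {med:.4f} ms  [p20={p20:.4f}, p80={p80:.4f}]")
```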

# FLOPS Throughput

FLOPS (floating-point operations per second) is the key metric for compute-bound kernels. To compute it, you first need the algorithm's theoretical operation count (FLOPs), which depends only on the algorithm and not on how the hardware parallelizes it. An add, mul, or sub counts as 1 FLOP, and an fma (fused multiply-add) counts as 2 FLOPs. Special functions such as sin and exp run on the SFU (Special Function Unit), and their cost varies across architectures, so standard FLOP counts usually include only multiplies and adds.

Example: a vector addition C[i] = A[i] + B[i] of length N performs N FLOPs. A matrix multiplication C = A @ B, with A of size M×K and B of size K×N, performs 2 * M * N * K FLOPs, because each of the M * N outputs needs K multiplies and K adds.

Dividing the FLOP count by the measured time gives the achieved throughput, FLOPS = FLOPs / execution time (seconds). For a GEMM kernel:

```python
flops = 2 * M * N * K
tflops = (flops / 1e12) / (ms / 1e3)  # ms is the latency reported by do_bench
```
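As a sanity check, applying the formula to the (8192, 8192, 8192) row of the matrix-multiplication benchmark table at the end of Section 4, assuming its second column is the runtime in ms, reproduces that table's throughput column; gemm_tflops is a hypothetical helper.

```python
def gemm_tflops(m: int, n: int, k: int, ms: float) -> float:
    """Achieved TFLOPS of an (M x K) @ (K x N) GEMM that took `ms` milliseconds."""
    flops = 2 * m * n * k          # one multiply + one add per inner-product term
    return (flops / 1e12) / (ms / 1e3)


print(f"{gemm_tflops(8192, 8192, 8192, 4.905):.1f} TFLOPS")  # about 224
```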

# Bandwidth Utilization (GBPS)

For memory-bound kernels, bandwidth utilization is a more appropriate performance metric. The calculation method is: Bandwidth (GB/s) = Transferred Data Volume (GB) / Execution Time (seconds). The data volume includes all bytes read from and written to memory.

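```python
# fp16 GEMM: 2 bytes per element; assumes each matrix crosses HBM exactly once
read_bytes = (M * K + K * N) * 2
write_bytes = M * N * 2
gbps = (read_bytes + write_bytes) / (ms / 1e3) / 1e9
hw_bw = 3350  # H100 SXM HBM3 bandwidth in GB/s
utilization = (gbps / hw_bw) * 100  # percent of peak bandwidth
```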
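For a memory-bound kernel such as row-wise softmax, the minimum traffic is one read of the input plus one write of the output, which is the same 2 * M * N * element_size formula the perf_report benchmark later in this section uses. The sketch below uses a made-up 0.025 ms runtime purely to show the arithmetic; softmax_gbps is a hypothetical helper.

```python
def softmax_gbps(m: int, n: int, elem_size: int, ms: float) -> float:
    """Effective bandwidth of a row-wise softmax over an M x N tensor.

    Minimum traffic: read the input once and write the output once.
    """
    bytes_moved = 2 * m * n * elem_size
    return bytes_moved / (ms / 1e3) / 1e9


H100_SXM_BW_GBPS = 3350
gbps = softmax_gbps(4096, 4096, 2, 0.025)  # hypothetical 0.025 ms run, fp16
print(f"{gbps:.0f} GB/s, {gbps / H100_SXM_BW_GBPS:.1%} of peak")
```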

# Roofline Model

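The Roofline model is a classic tool for deciding whether a kernel is compute-bound or memory-bound. Its core quantity is Arithmetic Intensity (AI) = FLOPs / bytes moved. The ridge point where the two regimes meet is peak compute divided by peak bandwidth; for an H100 SXM (about 989 TFLOPS of dense FP16/BF16 Tensor Core compute and 3.35 TB/s of HBM bandwidth) it is around 295 FLOP/Byte: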

  • AI < 295: memory-bound; optimization direction is to reduce memory access and increase bandwidth utilization.
  • AI > 295: compute-bound; optimization direction is to increase Tensor Core utilization.
```python
arithmetic_intensity = flops / (read_bytes + write_bytes)  # FLOP/Byte
if arithmetic_intensity < 295:
    print("Memory-bound kernel")
else:
    print("Compute-bound kernel")
```
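The roofline itself is just attainable = min(peak compute, AI × peak bandwidth). This sketch, assuming roughly 989 TFLOPS of dense FP16/BF16 Tensor Core compute and 3.35 TB/s for an H100 SXM (helper names are hypothetical), reproduces the ridge point and places an fp16 vector add and a square fp16 GEMM on either side of it, counting only the minimum HBM traffic.

```python
def ridge_point(peak_flops: float, peak_bytes_per_s: float) -> float:
    """Arithmetic intensity (FLOP/Byte) where the roofline bends."""
    return peak_flops / peak_bytes_per_s


def attainable_tflops(ai: float, peak_tflops: float, bw_tb_per_s: float) -> float:
    """Roofline: min(compute roof, arithmetic intensity * bandwidth roof)."""
    return min(peak_tflops, ai * bw_tb_per_s)


PEAK_TFLOPS, BW_TBPS = 989.4, 3.35  # assumed H100 SXM dense FP16/BF16, HBM3
print(f"ridge ~ {ridge_point(PEAK_TFLOPS, BW_TBPS):.0f} FLOP/Byte")

# Vector add, fp16: 1 FLOP per element; read 2 x 2 bytes, write 2 bytes
ai_vadd = 1 / 6
# Square fp16 GEMM with N = 4096: 2N^3 FLOPs over 3 N^2 elements of 2 bytes
n = 4096
ai_gemm = (2 * n**3) / (3 * n * n * 2)
for name, ai in [("vector add", ai_vadd), ("GEMM 4096", ai_gemm)]:
    perf = attainable_tflops(ai, PEAK_TFLOPS, BW_TBPS)
    print(f"{name}: AI={ai:.2f} FLOP/B, attainable ~{perf:.1f} TFLOPS")
```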

# 6.3 @triton.testing.perf_report

@triton.testing.perf_report is a decorator used in conjunction with triton.testing.Benchmark to automate multi-configuration benchmarking and generate plots.

```python
import torch
import triton

# softmax is assumed to be a Triton softmax wrapper (not defined in this section)
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],                           # x-axis: number of columns
        x_vals=[128 * i for i in range(2, 33)],  # N = 256, 384, ..., 4096
        line_arg="provider",                     # one line per provider
        line_vals=["triton", "torch"],
        line_names=["Triton", "PyTorch"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="GB/s",
        plot_name="softmax-performance",
        args={"M": 4096},                        # fixed number of rows
    )
)
def bench_softmax(M, N, provider):
    x = torch.randn(M, N, device="cuda", dtype=torch.float16)
    if provider == "triton":
        fn = lambda: softmax(x)
    else:
        fn = lambda: torch.softmax(x, dim=-1)
    ms = triton.testing.do_bench(fn, warmup=25, rep=100)
    # Read the input once and write the output once
    gbps = 2 * M * N * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps


# Print the table, show the plot, and save the results under ./benchmarks/ in one run
bench_softmax.run(show_plots=True, print_data=True, save_path="./benchmarks/")
```

Explanation of triton.testing.Benchmark parameters:

| Parameter | Type | Description |
| --- | --- | --- |
| x_names | List[str] | Names of the argument(s) swept along the x-axis |
| x_vals | List[Any] | Values for the x-axis; with several x_names, each entry may be a tuple with one value per name |
| line_arg | str | Argument whose value distinguishes the different lines |
| line_vals | List[Any] | Values of line_arg, one per line |
| line_names | List[str] | Line labels shown in the legend |
| styles | List[Tuple] | (color, line style) for each line |
| ylabel | str | y-axis label |
| plot_name | str | Plot name, also used as the file name when saving (without extension) |
| args | Dict[str, Any] | Other arguments held fixed during the sweep |
Figure: Softmax Performance

# 7. Cutting-Edge: FP4 and the Blackwell Architecture

# 7.1 Blackwell Architecture Overview

As of 2026, NVIDIA’s Blackwell (B200/B300) GPUs are deployed at scale:

| Feature | Hopper (H100) | Blackwell (B200) |
| --- | --- | --- |
| Microarchitecture | SM90 | SM100 |
| FP4 TFLOPS | N/A | ~4500 |
| FP8 TFLOPS | ~1979 | ~2250 |
| HBM Bandwidth | 3.35 TB/s | ~8 TB/s |
| Tensor Core | 4th Gen | 5th Gen |

# 7.2 NVFP4 (E2M1) Format

Figure: NVFP4 E2M1 Format

With only 4 bits, FP4 E2M1 has 16 encodings (15 distinct values, since both +0 and -0 exist: ±0, 0.5, 1, 1.5, 2, 3, 4, 6), so its expressive capacity is extremely limited. NVFP4 compensates with per-block scaling: each group of 16 FP4 values shares a single FP8 scaling factor, which lets the format cover a much larger numerical range. This is essentially the same idea as DeepGEMM’s per-128-element FP8 scaling: a coarse-grained, higher-precision scale compensates for the limited expressive power of the fine-grained, low-precision data.
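Per-block scaling can be illustrated with a NumPy sketch that rounds each 16-element block onto the E2M1 grid after dividing by a block scale that maps the block's largest magnitude to 6.0. This is a simplification under stated assumptions: the scale stays in float32 (NVFP4 stores it as FP8 E4M3 and adds a per-tensor FP32 scale), ties round toward the smaller magnitude rather than to even, and the function names are hypothetical. The last lines compare per-block scaling with a single scale for the whole tensor on a block of small values.

```python
import numpy as np

# Non-negative magnitudes representable in FP4 E2M1 (the sign bit covers negatives)
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def quantize_blocks(x: np.ndarray, block: int = 16):
    """Round each block to E2M1 values times one float32 scale per block."""
    x = x.reshape(-1, block)
    amax = np.abs(x).max(axis=1, keepdims=True)
    scale = np.where(amax > 0, amax / E2M1_GRID[-1], 1.0)  # block max -> 6.0
    y = np.abs(x) / scale
    # Nearest grid value; argmin picks the first (smaller) value on ties
    idx = np.abs(y[..., None] - E2M1_GRID).argmin(axis=-1)
    codes = np.sign(x) * E2M1_GRID[idx]
    return codes, scale


def dequantize_blocks(codes, scale):
    return (codes * scale).reshape(-1)


rng = np.random.default_rng(0)
# Four blocks of 16 values with very different magnitudes
x = rng.standard_normal(64) * np.repeat([0.01, 1.0, 100.0, 5.0], 16)
codes, scale = quantize_blocks(x)
x_hat = dequantize_blocks(codes, scale)

codes_g, scale_g = quantize_blocks(x, block=64)  # one scale for all 64 values
x_hat_g = dequantize_blocks(codes_g, scale_g)
print("small block, per-block scale:", np.abs(x - x_hat)[:16].max())
print("small block, single scale:   ", np.abs(x - x_hat_g)[:16].max())
```

With a single scale, the 0.01-magnitude block collapses to zero, while per-block scales keep its rounding error within one scale unit.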

# 7.3 Impact on Triton Programming

As Triton’s support for SM100 matures:

  1. Instruction Evolution: tl.dot is expected to support FP4 × FP4 → FP32 natively, lowered to the 5th-generation Tensor Core tcgen05 instructions under the hood.
  2. TMA Everywhere: In the Hopper era, Triton kernels used TMA mainly for loads; on Blackwell, TMA-based stores and large-scale multicast across thread-block clusters are expected to be used far more widely.
  3. Deeper Pipelines: Blackwell keeps MMA accumulators in a dedicated 256 KB per-SM Tensor Memory (TMEM) instead of registers, so deeper pipelines such as num_stages=5 or even 6 are expected to become common and hide memory latency more effectively.

# What’s Next

Triton is currently the most mature tile-based GPU DSL, but it is not the only choice. Tilelang takes a completely different path—offering more explicit control over memory hierarchies and more flexible abstractions for layout transformations. Next time, we will discuss Tilelang’s programming model and compare it with Triton.
