Featured image of post Triton CUDA 算子开发入门

Triton CUDA 算子开发入门

从向量加法到 FlashAttention,系统介绍 Triton tile-based GPU 编程模型与实战优化技巧。

# Triton CUDA 算子开发入门

# 前言

好久没写博客了。最近一年都在埋头搞研究,但是也在一直关注 AI infra 领域,深感这个领域卷得越来越厉害——大模型训推的性能压力逼着每个团队手写 GPU 算子 kernel,而"用什么工具写"这件事,正在经历洗牌。学习一下主流的 tile-based GPU DSL 对于理解未来几年 AI 算子开发的趋势非常有帮助,所以就有了这篇文章。

当前基于 tile 的 GPU DSL 竞争激烈:OpenAI 的 Triton 社区最大、生态最全;TVM 阵营的 Tilelang 另辟蹊径;NVIDIA 自家的 CuTe/CUTLASS 则是 C++ 模板元编程的极致。三者风格迥异,但如果只能学一个,Triton 几乎一定是最佳起点,Python 写 kernel,上手门槛最低,而且 PyTorch 2.0 的 torch.compile 后端就是 Triton。

这是系列第一篇,从零开始走完"向量加法 → 矩阵乘法 → FlashAttention"的完整路径。

# 一、为什么是 Triton?

# 1.1 GPU 编程范式的演进

GPU 编程从来都不简单。传统的 CUDA C++ 要求开发者在线程级别思考:手动管理共享内存(shared memory)、手动处理 bank conflict、手动编排异步拷贝流水线。CUTLASS 在此基础上用 C++ 模板抽象了 tile 级别的操作,大幅降低了心智负担,但仍然需要深厚的 C++ 模板元编程功底。

Triton 的定位很明确:用 Python 写出 90%+ cuBLAS 性能的 kernel。它把编程粒度从线程级提升到 tile 级,让编译器去处理那些繁琐的底层细节。

下面这张对比表总结了三种范式的核心差异:

维度CUDA C++CUTLASSTriton
编程粒度线程级Tile 级(C++ 模板)Tile 级(Python)
共享内存手动管理模板抽象编译器自动
Tensor Core手写 PTX / wmmaCuTe 布局tl.dot 自动映射
Bank conflict手动 swizzle模板 swizzle编译器自动
流水线手动 cp.async模板参数num_stages 一个参数
上手难度中高
性能上限100%~98%~90-95%

一个有意思的现象:Triton 的性能天花板确实低一些,但实际项目里,能把 Triton 写到 90% cuBLAS 的人,远比能把 CUDA C++ 写到 95% 的人多。开发效率本身就是性能的一部分。

# 1.2 编译流水线

Triton 的编译流程始于 Python 源码,历经多层中间表示(IR)的逐级转换,最终生成 GPU 二进制文件:

Triton 编译流水线示意图

Triton 编译流水线示意图

整个流水线的核心在于 Triton IR → Triton GPU IR 的转换:在此阶段,编译器会将抽象的 Tile 级操作映射为具体的线程、Warp 及 Tensor Core 指令。“开发者专注 Tile 级逻辑,编译器负责线程级实现”,这正是 Triton 降低算子开发门槛的核心设计哲学。

# 二、核心概念与指令速查

# 2.1 执行模型

Triton 的执行模型有三个关键概念:

  • 程序实例(program):一个 kernel 的一次执行实例,类似 CUDA 中的 thread block。每个 program 独立运行,拥有自己的 program ID。
  • 网格(grid):所有 program 实例的排列方式,支持 1D、2D、3D。
  • 瓦片(tile):每个 program 处理的数据块。tile 的大小由编译期常量(tl.constexpr)决定。

和 CUDA 最本质的区别:CUDA 程序员想的是"每个线程做什么",Triton 程序员想的是"每个 program 处理哪块 tile"。线程怎么协作、数据怎么在寄存器间分布,全交给编译器。

Triton 执行模型示意图

Triton 执行模型示意图

核心 API:

# tl.program_id

1
tl.program_id(axis: int) -> int

返回当前 program 在该轴的索引。axis 是 program ID 的维度,取值范围是 0、1、2,分别对应 1D、2D、3D 网格的不同轴。比如在一个 2D 网格中,tl.program_id(0) 返回行索引,tl.program_id(1) 返回列索引。

# tl.arange

1
tl.arange(start: int, end: int) -> Tensor

返回一个包含 [start, end) 范围内连续整数的 1D 张量。常用于构造 tile 内的偏移索引。BLOCK_SIZE 必须是 2 的幂

# tl.constexpr

1
tl.constexpr(value)

编译期常量装饰器。Triton kernel 中的某些参数必须在编译期确定(如 tile 大小),用 tl.constexpr 声明后,这些参数就成为编译期常量,编译器可以基于它们进行循环展开、条件分支消除等优化。

# 2.2 内存指令

Triton 中用于数据传输的指令包括两个:tl.loadtl.store,分别负责从显存加载数据到寄存器,以及将寄存器中的数据写回显存。

对于数据传输指令,其输入数据块和输出数据块的形状、数据类型必须完全相同。

# tl.load

1
2
3
4
5
tl.load(pointer: Union[int, Tensor, BlockPtr],
        mask: Optional[Tensor] = None,
        other: Optional[Union[int, float]] = None,
        cache_modifier: str = "",
        eviction_policy: str = "") -> Tensor

将显存中的数据加载到寄存器。

  • pointer 可以是一个整数地址、一个张量(指针矩阵)或一个 block pointer。
  • mask 是一个布尔张量,形状与 pointer 相同,用于指示哪些位置是有效的内存访问。无效位置会被替换为 other 的值。
  • other 是当 mask 为 False 时,用于替换无效位置的默认值。
  • cache_modifiereviction_policy 是高级参数,用于控制缓存行为,通常不需要修改。

警告

无 mask 时越界访问是未定义行为——不会报错,但结果不可预测,且极难排查

other 参数必须配合 mask 使用,否则报错

cache_modifier 仅对 NVIDIA GPU 有效


提示

1D 示例:

1
2
3
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < n_elements
x = tl.load(x_ptr + offs, mask=mask, other=0.0)

表示每个 program 加载一个长度为 BLOCK_SIZE 的向量块。mask 确保当 n_elements 不是 BLOCK_SIZE 的倍数时,越界位置不会被访问。

提示

2D 示例:

1
2
3
4
5
6
7
8
# 加载矩阵的一个 (BLOCK_M, BLOCK_N) 子块
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # shape: (BLOCK_M,)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # shape: (BLOCK_N,)

# [:, None] 和 [None, :] 产生广播,构造 (BLOCK_M, BLOCK_N) 的指针矩阵
ptrs = base_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tile = tl.load(ptrs, mask=mask, other=0.0)

这里的 [:, None](BLOCK_M,) 变成 (BLOCK_M, 1)[None, :](BLOCK_N,) 变成 (1, BLOCK_N),广播后得到 (BLOCK_M, BLOCK_N) 的指针张量。这是 Triton 中最常见的 2D 索引模式。

# tl.store

1
2
3
4
5
6
7
tl.store(
    pointer: Union[int, Tensor, BlockPtr],
    value: Tensor,
    mask: Optional[Tensor] = None,
    cache_modifier: str = "",
    eviction_policy: str = "",
)

将寄存器中的数据写回显存。参数与 tl.load 类似

提示

1D 示例:

1
2
3
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < n_elements
tl.store(out_ptr + offs, output, mask=mask)

表示每个 program 将一个长度为 BLOCK_SIZE 的向量块写回显存。mask 确保当 n_elements 不是 BLOCK_SIZE 的倍数时,越界位置不会被写入。

提示

2D 示例:

1
2
3
4
5
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ptrs = base_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(ptrs, tile, mask=mask)

表示每个 program 将一个 (BLOCK_M, BLOCK_N) 的 tile 写回显存。mask 确保越界位置不会被写入。

# tl.make_block_ptr

Block pointer(块指针)是 Triton 提供的高层抽象,用"矩阵子块"的语义来描述内存访问,而非手动计算指针偏移:

1
2
3
4
5
6
7
8
ptr = tl.make_block_ptr(
    base: Union[int, Tensor],  
    shape: Tuple[int, ...],
    strides: Tuple[int, ...],
    offsets: Tuple[int, ...],
    block_shape: Tuple[int, ...],
    order: Tuple[int, ...],
) -> BlockPtr
  • base 是基础指针,可以是整数地址或张量。
  • shape 是父张量的完整形状,如 (M, K)
  • strides 是父张量的步长,如 (stride_m, stride_k)
  • offsets 是起始偏移,如 (pid_m * BM, 0)
  • block_shape 是要加载的块形状,如 (BM, BK)
  • order 是内存布局顺序,(1, 0) 表示行主序,(0, 1) 表示列主序。

提示

示例:Block pointer 版本的 tl.load

1
2
3
4
5
6
7
8
9
ptr = tl.make_block_ptr(
    base=a_ptr,
    shape=(M, K),
    strides=(stride_am, stride_ak),
    offsets=(pid_m * BLOCK_M, 0),
    block_shape=(BLOCK_M, BLOCK_K),
    order=(1, 0),
)
tile = tl.load(ptr, boundary_check=(0, 1))

对比手动指针算术:

1
2
3
4
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)

语义更加清晰,编译器也有更多优化空间。

# tl.advance

Advance 是 Triton 提供的 block pointer 专用指令,用于在块指针上进行语义清晰的推进:

1
ptr = tl.advance(ptr, offsets: Tuple[int, ...]) -> BlockPtr
  • ptr 是一个 block pointer。
  • offsets 是要推进的块偏移,如 (0, BLOCK_K) 表示沿 K 维度推进一个块。
  • 返回一个新的 block pointer,指向推进后的新位置。
  • 语义清晰,编译器可以自动处理边界检查和优化。

对比手动指针算术:

1
a_ptrs += BLOCK_K * stride_ak

为什么推荐 block pointer?三个原因:

  1. 自动边界检查boundary_check=(0, 1) 一个参数搞定,无需手写 mask 广播
  2. 触发 TMA:在 Hopper(H100)架构上,block pointer 会自动映射到 TMA,大幅提升内存带宽利用率
  3. 编译器优化空间更大:block pointer 携带了更丰富的语义信息,编译器可以自动做 swizzle 等优化

# 2.3 计算指令

# tl.dot

1
2
3
4
5
6
tl.dot(
    input: Tensor,
    other: Tensor,
    acc: Optional[Tensor] = None,
    input_precision: Optional[str] = None,
    max_num_imprecise_acc: Optional[int] = None) -> Tensor

计算矩阵乘法 out = input @ other,假设 input 的形状是 (M, K)other 的形状是 (K, N),则输出的形状是 (M, N)

提示

示例:tl.dot 的典型用法

1
2
3
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
acc = tl.dot(a, b, acc)

tl.dot 会自动映射到 Tensor Core 的 MMA(Matrix Multiply-Accumulate)指令。这是 Triton 中性能最关键的操作——一个 tl.dot 调用,编译器会自动帮你生成 wmma(Volta/Turing)或 mma(Ampere)或 wgmma(Hopper)指令。

支持的数据类型组合:

Input AInput B累加器硬件要求
float16float16float32Volta+ (SM70)
bfloat16bfloat16float32Ampere+ (SM80)
float8e4nvfloat8e4nvfloat32Hopper+ (SM90)
int8int8int32Turing+ (SM75)
float32float32float32默认用 TF32 (SM80+)

重要

累加器永远用 float32(整型用 int32)。如果用 float16 做累加器,精度损失会非常严重。这是 Triton kernel 的铁律。

# 归约操作

1
2
3
tl.sum(x, axis=0)    # 沿 axis=0 求和
tl.max(x, axis=1)    # 沿 axis=1 取最大值
tl.min(x, axis=0)    # 沿 axis=0 取最小值

这些归约操作在 softmax、LayerNorm 等 kernel 中频繁出现,编译器会将其映射为高效的 warp-level shuffle 归约。

# 条件选择

1
tl.where(condition, x, y)

与 NumPy 的 np.where 语义一致。注意:两个分支都会被求值,所以不能用它来避免越界访问。典型用途是实现 causal mask——把未来位置的注意力分数替换为 -inf

# 2.4 其它

# @triton.autotune

自动调优是 Triton 的一大特性。你只需要提供一组候选配置,Triton 会在第一次调用时自动 benchmark 每种配置,然后缓存最优的那个:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=4),
    ],
    key=["M", "N", "K"],  
)
@triton.jit
def kernel(...):
    ...

key 参数非常重要:当 M、N、K 的值发生变化时,最优配置往往也会变(小矩阵适合小 tile,大矩阵适合大 tile),所以 Triton 会为每组不同的 key 值分别搜索最优配置。

# @triton.heuristics

启发式装饰器(heuristics)用于从运行时参数派生编译期常量,避免不必要的 benchmark 开销:

1
@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0})

典型应用:当 K 恰好是 BLOCK_K 的倍数时,可以跳过 mask 检查,省掉条件分支。

# 三、实战一:向量加法

问题:给定两个长度为 N 的向量 A 和 B,计算 C[i] = A[i] + B[i]

这算是最简单的 Triton kernel,但包含了 Triton 编程的所有基本要素。

# 3.1 分块思路

Triton 向量加法分块示意图

Triton 向量加法分块示意图

每个 program 处理连续的 BLOCK_SIZE 个元素。program 的总数 = ceil(N / BLOCK_SIZE)

# 3.2 完整实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr, y_ptr, out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(out_ptr + offsets, output, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n = output.numel()
    grid = lambda META: (triton.cdiv(n, META["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, output, n, BLOCK_SIZE=1024)
    return output

if __name__ == "__main__":
    torch.manual_seed(42)
    size = 98432 
    x = torch.rand(size, device="cuda")
    y = torch.rand(size, device="cuda")
    output_triton = add(x, y)
    output_torch = x + y
    print(f"Max error: {(output_triton - output_torch).abs().max().item()}")

# 3.3 关键要点

  1. mask 的作用:当 n_elements 不是 BLOCK_SIZE 的倍数时,最后一个 program 的部分 offset 会超出数组边界。mask 确保这些越界位置不会被读写。忘记 mask = 未定义行为。
  2. BLOCK_SIZE 为什么必须是 2 的幂:这是 tl.arange 的硬件约束。GPU 的 warp 大小是 32,BLOCK_SIZE 必须是 warp 大小的整数倍才能高效映射到硬件线程。
  3. triton.cdiv(n, BLOCK_SIZE):向上取整除法,等价于 ceil(n / BLOCK_SIZE)。保证所有元素都被覆盖。

# 四、实战二:矩阵乘法

矩阵乘法(GEMM)是 GPU 上最重要的计算 kernel,几乎所有深度学习的计算量最终都归结于此。这一节我们从分块思路开始,逐步写出一个高性能的 Triton GEMM kernel。

# 4.1 分块思路

Triton 矩阵乘法分块示意图

Triton 矩阵乘法分块示意图

核心算法:

  1. 每个 program 负责计算 C 的一个 (BLOCK_M, BLOCK_N) 子块
  2. 初始化 fp32 累加器为零
  3. 沿 K 维度循环:每步加载 A 的 (BLOCK_M, BLOCK_K) 和 B 的 (BLOCK_K, BLOCK_N),做一次 tl.dot
  4. 循环结束后,将累加器转换为输出精度并写回 HBM 显存

# 4.2 基础实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr
):
    # Row-major Mapping
    pid = tl.program_id(0)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        acc = tl.dot(a, b, acc)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    c = acc.to(tl.float16)
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + offs_cm[:, None] * \
        stride_cm + offs_cn[None, :] * stride_cn
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.shape[1] == b.shape[0], "Dimensions mismatch"
    assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=128, BLOCK_N=128, BLOCK_K=32
    )
    return c

这个实现是最基础的版本,只需要 50 多行代码,以下性能测试在 A100 上进行,可以看到峰值性能约达 A100 fp16 理论算力(312 TFLOPS)的 70%

Size (M, K, N)Time (ms)TFLOPS
(512, 1024, 512)0.02521.34
(1024, 1024, 1024)0.02682.78
(2048, 2048, 2048)0.094182.18
(4096, 4096, 4096)0.631217.91
(8192, 8192, 8192)6.297174.62
(16384, 16384, 16384)59.825147.03

# 4.3 进阶优化

矩阵乘法 GEMM 是一个非常基础的算子了,常见的优化思路无非是:swizzling、pipelining、warp specialization、自动调优、WMMA 等,其中共享内存 swizzling 已经被 Triton 自动管理来解决 bank conflict 问题了,我们更关注的是 Tile Swizzling,warp specialization,流水线重叠和自动调优,下面将逐一介绍这些优化在 Triton 中的实现方式。

# 优化 1:Grouped Ordering

Grouped Ordering 是一种 tile 调度策略,主要解决的是 L2 缓存命中率低的问题。

之前 GEMM 的基础实现中,program ID 是按照行优先(Row-major)方式映射的:先按行分配 program ID,再按列分配。

  • 计算顺序:先算完第 0 行的所有 Tile,再算第 1 行。
  • 缓存问题:在计算第 0 行时,我们需要加载矩阵 A 的第 0 行块和矩阵 B 的所有列块。如果矩阵 B 非常大,等到开始算第 1 行时,之前加载到 L2 缓存里的矩阵 B 的数据可能已经被踢出去了。
  • 后果:每一行计算都要重新从显存(VRAM)读取一遍矩阵 B,带宽压力巨大。

Grouped Ordering 的核心思想是:让计算在 M 维度(行)和 N 维度(列)上都分成小组(group),同一组内的 program 紧挨着执行。比如设置 GROUP_SIZE_M=4,每 4 行为一个 group,使得一小块矩阵 B 的数据能在 L2 缓存中被多个 Program 循环利用。

Grouped Ordering 示意图

Grouped Ordering 示意图

在朴素顺序中,program (0,0)(1,0) 都需要 B 的第 0 列 tile,但它们之间隔了 num_pid_n 个 program。当 num_pid_n 很大时,B 的第 0 列 tile 早已被 L2 缓存驱逐。而 Grouped Ordering 让需要同一列 B tile 的 program 紧挨着执行,L2 命中率显著提升。

# 优化 2:Block Pointer

tl.make_block_ptr 重写核心循环

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
a_block = tl.make_block_ptr(a_ptr, (M, K), (stride_am, stride_ak),
                             (pid_m * BLOCK_M, 0), (BLOCK_M, BLOCK_K), (1, 0))
b_block = tl.make_block_ptr(b_ptr, (K, N), (stride_bk, stride_bn),
                             (0, pid_n * BLOCK_N), (BLOCK_K, BLOCK_N), (1, 0))

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
    a = tl.load(a_block, boundary_check=(0, 1))
    b = tl.load(b_block, boundary_check=(0, 1))
    acc = tl.dot(a, b, acc)
    a_block = tl.advance(a_block, (0, BLOCK_K))
    b_block = tl.advance(b_block, (BLOCK_K, 0))

2.2 节已经介绍过 tl.make_block_ptrtl.advance,这里不再赘述。总之,这个优化的好处在于:

  1. 代码更简洁:不需要手动计算指针偏移和 mask,语义更清晰
  2. 自动边界检查boundary_check=(0, 1) 参数让编译器自动处理越界访问,避免了之前版本中复杂的 mask 计算和广播
  3. 触发 TMA:在 Hopper(H100)架构上,block pointer 会自动映射到 TMA,大幅提升内存带宽利用率

# 优化 3:Autotune

可以使用 @triton.autotune@triton.heuristics 加上自动调优和启发式优化,让 Triton 在不同的 M、N、K 维度下自动选择最优的 BLOCK_M、BLOCK_N、BLOCK_K 和 GROUP_SIZE_M 参数组合,从而在各种矩阵大小上都能达到较好的性能。

最终版本的实现如下:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
import torch
import triton
import triton.language as tl

def get_autotune_config():
    return [
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        # Good config for fp8 inputs.
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4)
    ]


@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0})
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, EVEN_K: tl.constexpr
):
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    a_block = tl.make_block_ptr(
        a_ptr, (M, K), (stride_am, stride_ak), (pid_m * BLOCK_M, 0), (BLOCK_M, BLOCK_K), (1, 0))
    b_block = tl.make_block_ptr(
        b_ptr, (K, N), (stride_bk, stride_bn), (0, pid_n * BLOCK_N), (BLOCK_K, BLOCK_N), (1, 0))

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            a = tl.load(a_block)
            b = tl.load(b_block)
        else:
            a = tl.load(a_block, boundary_check=(0, 1))
            b = tl.load(b_block, boundary_check=(0, 1))
        acc = tl.dot(a, b, acc)
        a_block = tl.advance(a_block, (0, BLOCK_K))
        b_block = tl.advance(b_block, (BLOCK_K, 0))

    c = acc.to(tl.float16)
    c_block = tl.make_block_ptr(c_ptr, (M, N), (stride_cm, stride_cn),
                                (pid_m * BLOCK_M, pid_n * BLOCK_N), (BLOCK_M, BLOCK_N), (1, 0))
    tl.store(c_block, c, boundary_check=(0, 1))


def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.shape[1] == b.shape[0], "Dimensions mismatch"
    assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
    
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c

代码行仅提升到约 100 行,但性能提升显著,尤其是在大矩阵上,已经十分接近 cuBLAS 的 fp16 性能了:

Size (M, K, N)Time (ms)TFLOPS
(512, 1024, 512)0.01438.31
(1024, 1024, 1024)0.02298.55
(2048, 2048, 2048)0.091189.26
(4096, 4096, 4096)0.621221.42
(8192, 8192, 8192)4.905224.14
(16384, 16384, 16384)38.054231.15

# 五、实战三:FlashAttention

# 5.1 标准 Attention 的内存瓶颈

标准的多头注意力(Multi-Head Attention)计算流程如下:

  1. 计算注意力分数(Attention Score): 将查询矩阵 $Q$ 与转置后的键矩阵 $K^T$ 相乘,生成相似度矩阵 $S$。 $$S = Q K^T, \quad S \in \mathbb{R}^{B \times H \times N \times N}$$ 此步骤的时间与空间复杂度均为 $O(N^2)$,会产生巨大的中间张量。

  2. 归一化处理(Softmax): 对 $S$ 进行缩放和 Softmax 操作,得到注意力权重概率分布 $P$。 $$P = \text{softmax}\left(\frac{S}{\sqrt{d_k}}\right), \quad P \in \mathbb{R}^{B \times H \times N \times N}$$ 此处再次产生一个 $O(N^2)$ 的全量权重矩阵。

  3. 计算加权输出(Output): 利用权重矩阵 $P$ 对值矩阵 $V$ 进行加权聚合。 $$O = P V, \quad O \in \mathbb{R}^{B \times H \times N \times d}$$

这个过程中会产生 $S = Q K^T$ 和 $P = \text{softmax}(S)$ 两个规模为 $O(N^2)$ 的中间矩阵。当序列长度 $N$ 上升到 32K 甚至 128K 时,显存直接撑爆;同时,频繁在 HBM 与 SRAM 之间搬运这两个巨型矩阵,导致带宽被彻底锁死。

FlashAttention 的解法:不实例化 $S$ 和 $P$ 矩阵,通过分块(Tiling)在 SRAM 中一口气算到底,将显存复杂度降为 $O(N)$。

# 5.2 Online Softmax 算法

Online Softmax 的核心思想是为每一行维护一个状态三元组 $(m_i, l_i, \mathbf{o}_i)$,通过迭代更新实现增量计算:

  • $m_i$:当前已处理部分的行最大值。
  • $l_i$:当前局部归一化常数(Softmax 分母)。
  • $\mathbf{o}_i$:当前加权累加结果(输出向量)。

对于每个新的 $K/V$ 块 $j$,更新步骤如下:

  1. 计算局部相似度: 计算当前块的原始分数 $S^{(j)}$: $$S^{(j)} = Q_{\text{block}} \times K_{\text{block}, j}^T$$

  2. 更新行最大值: 对比当前块的最大值与历史最大值,确定新的全局最大值 $m_{\text{new}}$: $$m_{\text{new}} = \max(m_i, \text{row\_max}(S^{(j)}))$$

  3. 计算修正系数: 为了确保数值稳定性以及不同块之间的对齐,计算旧统计量的缩放因子 $\alpha$: $$\alpha = \exp(m_i - m_{\text{new}})$$

  4. 更新归一化分母: 利用修正系数对齐旧分母,并加入当前块的贡献: $$l_{\text{new}} = l_i \cdot \alpha + \sum \exp(S^{(j)} - m_{\text{new}})$$

  5. 更新累加器: 对旧的输出向量进行重缩放,并累加当前块的新贡献: $$\mathbf{o}_{\text{new}} = \mathbf{o}_i \cdot \alpha + \exp(S^{(j)} - m_{\text{new}}) \times V_{\text{block}, j}$$

  6. 迭代状态: 更新当前状态: $$m_i \leftarrow m_{\text{new}}, l_i \leftarrow l_{\text{new}}, \mathbf{o}_i \leftarrow \mathbf{o}_{\text{new}}$$

算法正确性: 修正系数 $\alpha = \exp(m_i - m_{\text{new}})$ 是算法的关键。当发现更大的局部最大值时($m_{\text{new}} > m_i$),$\alpha$ 会作为一个小于 1 的缩放因子,对之前基于较小 $m_i$ 计算出的指数项进行补偿。这种机制确保了在分块遍历结束时,最终输出 $O = \mathbf{o}_i / l_i$ 与一次性全局计算的结果在数学上完全等价。

# 5.3 Kernel 实现

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import torch
import triton
import triton.language as tl

def get_autotune_config():
    return [
        # Small blocks — short sequences / small head_dim
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32,
                      'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64,
                      'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64,
                      'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        # Medium blocks
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32,
                      'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64,
                      'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128,
                      'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128,
                      'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        # Large blocks — long sequences
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64,
                      'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128,
                      'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        # Higher pipeline depth
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64,
                      'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64,
                      'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128,
                      'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8),
    ]


@triton.autotune(
    configs=get_autotune_config(),
    key=['seq_len', 'HEAD_DIM'],
)
@triton.heuristics({
    "EVEN_N": lambda args: args["seq_len"] % args["BLOCK_N"] == 0,
})
@triton.jit
def flash_attention_fwd_kernel(
    Q, K, V, O,
    sm_scale,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_ok,
    n_heads,
    seq_len,
    n_batch_heads,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    EVEN_N: tl.constexpr,
):
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(seq_len, BLOCK_M)
    num_pid_in_group = GROUP_SIZE_M * n_batch_heads
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_bh = (pid % num_pid_in_group) // group_size_m

    batch_id = pid_bh // n_heads
    head_id = pid_bh % n_heads

    # Base offsets for this (batch, head)
    q_offset = batch_id * stride_qz + head_id * stride_qh
    k_offset = batch_id * stride_kz + head_id * stride_kh
    v_offset = batch_id * stride_vz + head_id * stride_vh
    o_offset = batch_id * stride_oz + head_id * stride_oh

    q_block = tl.make_block_ptr(
        Q + q_offset, (seq_len, HEAD_DIM), (stride_qm, stride_qk),
        (pid_m * BLOCK_M, 0), (BLOCK_M, HEAD_DIM), (1, 0),
    )
    k_block = tl.make_block_ptr(
        K + k_offset, (seq_len, HEAD_DIM), (stride_kn, stride_kk),
        (0, 0), (BLOCK_N, HEAD_DIM), (1, 0),
    )
    v_block = tl.make_block_ptr(
        V + v_offset, (seq_len, HEAD_DIM), (stride_vn, stride_vk),
        (0, 0), (BLOCK_N, HEAD_DIM), (1, 0),
    )

    # sm_scale * log2(e)
    qk_scale = sm_scale * 1.44269504088896340736  # sm_scale * log2(e)
    q = tl.load(q_block, boundary_check=(0,))
    q = (q.to(tl.float32) * qk_scale).to(Q.dtype.element_ty)

    # Online softmax accumulators
    m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    # Row indices (needed for causal / boundary masks on attention scores)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    if IS_CAUSAL:
        full_blocks = (pid_m * BLOCK_M) // BLOCK_N
        n_blocks = tl.cdiv((pid_m + 1) * BLOCK_M, BLOCK_N)
    else:
        full_blocks = tl.cdiv(seq_len, BLOCK_N)
        n_blocks = full_blocks

    for block_n in range(0, full_blocks):
        cols = block_n * BLOCK_N + tl.arange(0, BLOCK_N)
        if EVEN_N:
            k = tl.load(k_block)
        else:
            k = tl.load(k_block, boundary_check=(0,))

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))

        if not EVEN_N:
            qk = tl.where(cols[None, :] < seq_len, qk, float("-inf"))

        # Online softmax with exp2
        m_ij = tl.max(qk, axis=1)
        m_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp2(m_i - m_new)
        p = tl.exp2(qk - m_new[:, None])

        l_i = l_i * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None]

        if EVEN_N:
            v = tl.load(v_block)
        else:
            v = tl.load(v_block, boundary_check=(0,))

        acc += tl.dot(p.to(v.dtype), v)
        m_i = m_new

        k_block = tl.advance(k_block, (BLOCK_N, 0))
        v_block = tl.advance(v_block, (BLOCK_N, 0))

    for block_n in range(full_blocks, n_blocks):
        cols = block_n * BLOCK_N + tl.arange(0, BLOCK_N)

        k = tl.load(k_block, boundary_check=(0,))

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        
        if IS_CAUSAL:
            qk = tl.where(offs_m[:, None] >= cols[None, :], qk, float("-inf"))

        if not EVEN_N:
            qk = tl.where(cols[None, :] < seq_len, qk, float("-inf"))

        m_ij = tl.max(qk, axis=1)
        m_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp2(m_i - m_new)
        p = tl.exp2(qk - m_new[:, None])

        l_i = l_i * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None]

        v = tl.load(v_block, boundary_check=(0,))

        acc += tl.dot(p.to(v.dtype), v)
        m_i = m_new

        k_block = tl.advance(k_block, (BLOCK_N, 0))
        v_block = tl.advance(v_block, (BLOCK_N, 0))

    acc = acc / l_i[:, None]

    o_block = tl.make_block_ptr(
        O + o_offset, (seq_len, HEAD_DIM), (stride_om, stride_ok),
        (pid_m * BLOCK_M, 0), (BLOCK_M, HEAD_DIM), (1, 0),
    )
    tl.store(o_block, acc.to(O.dtype.element_ty), boundary_check=(0,))


def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    assert q.dim() == 4, "Expected shape (batch, n_heads, seq_len, head_dim)"
    batch, n_heads, seq_len, head_dim = q.shape
    assert head_dim in {16, 32, 64, 128,
                        256}, "head_dim must be power of 2 in [16, 256]"
    assert q.is_contiguous() and k.is_contiguous() and v.is_contiguous()

    o = torch.empty_like(q)
    sm_scale = head_dim ** -0.5
    n_batch_heads = batch * n_heads

    # 1-D grid: grouped ordering delinearizes inside the kernel
    def grid(META): return (
        triton.cdiv(seq_len, META['BLOCK_M']) * n_batch_heads,
    )

    flash_attention_fwd_kernel[grid](
        q, k, v, o,
        sm_scale,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        n_heads,
        seq_len,
        n_batch_heads,
        HEAD_DIM=head_dim,
        IS_CAUSAL=causal,
    )
    return o
  • get_autotune_config 定义了多个配置组合,覆盖了不同的序列长度和头维度,以适应各种使用场景。
  • Grouped Ordering 仍然适用,确保了在计算过程中对 K/V 块的高效缓存利用。
  • Pre-scale Q + exp2 替代了传统的缩放和 Softmax 操作,利用 Online Softmax 算法实现了数值稳定且内存高效的注意力计算。
  • Split causal loop:将内层循环拆为两阶段。Phase 1 处理完全在对角线以下的 block,可跳过 causal mask;Phase 2 仅对对角线附近的 1–2 个 block 执行 tl.where causal mask。

# 六、Benchmark:如何测试你的 Kernel 性能

写出正确的 kernel 只是第一步。在 GPU 编程中,性能才是最终的评判标准。Triton 提供了一整套 benchmark 框架。

# 6.1 torch.testing.assert_close

数值正确性是性能优化的前提。torch.testing.assert_close 是一个功能强大的工具,可以比较两个张量是否在指定的绝对误差(atol)和相对误差(rtol)范围内近似相等:

1
2
torch.testing.assert_close(
    out, ref, atol=1e-2, rtol=1e-2)

不同精度下的容差参考:

dtypeatolrtol说明
float321e-51e-5最严格
float161e-21e-2半精度累积误差较大
bfloat161e-21e-2尾数更短,精度更低
float81e-11e-1极低精度,容差必须宽松

# 6.2 triton.testing.do_bench

# 计时

1
2
3
4
5
6
7
ms = triton.testing.do_bench(
    lambda: my_kernel[grid](...),
    warmup=25,      # 预热次数(排除 JIT 编译和缓存冷启动的影响)
    rep=100,        # 测量次数
    quantiles=[0.5, 0.2, 0.8],  # 返回中位数、20th、80th 分位数
)
print(f"Median: {ms:.3f} ms  [p20={min_ms:.3f}, p80={max_ms:.3f}]")

提示

示例:测量 softmax kernel 的延迟:

1
2
ms = triton.testing.do_bench(lambda: softmax(x), warmup=25, rep=100)
print(f"Softmax latency: {ms:.3f} ms")

预热(warmup)很重要:第一次调用会触发 JIT 编译,耗时可能是实际执行时间的数百倍。do_bench 会先运行 warmup 轮丢弃结果,然后才开始正式计时。

# FLOPS 吞吐量

FLOPS(Floating-point Operations Per Second,每秒浮点运算次数)是衡量 compute-bound kernel 性能的关键指标。计算这个指标前需要知道整个算法总共需要多少次运算。理论运算量 (FLOPs) 是算法本身所需的浮点运算总数,与具体的硬件并行度无关。其中常见的 addmulsub 都是 1 FLOP,fma 是 2 FLOP。特殊的内置函数(如 sin, exp)属于 SFU (Special Function Unit) 处理,其算力换算在不同架构下不同,但在标准 FLOPS 计算中,通常只统计基础的乘加运算。

例子:一个简单的向量加法 C[i] = A[i] + B[i],若总长度为 N,则总运算量为 N 个 FLOPs。若是一个矩阵乘法 C = A @ B,其中 A 是 MxK,B 是 KxN,则总运算量为 2 * M * N * K FLOPs 。

进一步地,可以计算 Kernel 的实测性能 FLOPS = FLOPs / 执行时间(秒)。例如 GEMM Kernel :

1
2
flops = 2 * M * N * K
tflops = (flops / 1e12) / (ms / 1e3)

# 带宽利用率 GBPS

对于 memory-bound kernel,带宽利用率是更合适的性能指标。计算方法是:带宽 (GB/s) = 传输的数据量 (GB) / 执行时间 (秒)。数据量包括所有从内存读取和写入的字节数。

1
2
3
4
5
read_bytes = (M * K + K * N) * 2 
write_bytes = M * N * 2
gbps = (read_bytes + write_bytes) / (ms / 1e3) / 1e9
hw_bw = 3350  # H100 HBM Bandwidth in GB/s
utilization = (gbps / hw_bw) * 100  

# Roofline 模型

Roofline 模型是判断 kernel 是 compute-bound 还是 memory-bound 的经典工具。核心概念是算术强度(Arithmetic Intensity, AI)= 计算量 / 数据传输量。对于 H100,拐点大约在 295 FLOP/Byte:

  • AI < 295:memory-bound,优化方向是减少内存访问、提高带宽利用率
  • AI > 295:compute-bound,优化方向是提高 Tensor Core 利用率
1
2
3
4
5
arithmetic_intensity = flops / (read_bytes + write_bytes)  # FLOP/Byte
if arithmetic_intensity < 295:
    print("Memory-bound kernel")
else:
    print("Compute-bound kernel")

# 6.3 @triton.testing.perf_report

@triton.testing.perf_report 是一个装饰器,配合 triton.testing.Benchmark 使用,用于自动化多配置基准测试并生成图表。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],
        x_vals=[128 * i for i in range(2, 33)],
        line_arg="provider",
        line_vals=["triton", "torch"],
        line_names=["Triton", "PyTorch"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="GB/s",
        plot_name="softmax-performance",
        args={"M": 4096},                            
    )
)
def bench_softmax(M, N, provider):
    x = torch.randn(M, N, device="cuda", dtype=torch.float16)
    if provider == "triton":
        fn = lambda: softmax(x)
    else:
        fn = lambda: torch.softmax(x, dim=-1)
    ms = triton.testing.do_bench(fn, warmup=25, rep=100)
    gbps = 2 * M * N * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps

bench_softmax.run(show_plots=True, print_data=True)
bench_softmax.run(save_path="./benchmarks/")

triton.testing.Benchmark 参数说明:

参数类型说明
x_namesList[str]x 轴参数名称列表
x_valsList[List]每个 x 轴参数的取值列表
line_argstr用于区分不同曲线的参数名称
line_valsList[Any]不同曲线对应的参数取值列表
line_namesList[str]曲线名称列表(用于图例)
stylesList[Tuple]每条曲线的颜色和线型
ylabelstry 轴标签
plot_namestr生成的图表文件名(不含扩展名)
argsDict[str, Any]其他固定参数
Softmax Performance

Softmax Performance

# 七、前沿:FP4 与 Blackwell 架构

# 7.1 Blackwell 架构速览

站在 2026 年的时间点,NVIDIA Blackwell(B200/B300)架构已经规模化部署:

特性Hopper (H100)Blackwell (B200)
微架构SM90SM100
FP4 TOPSN/A~4500
FP8 TOPS~1979~2250
HBM 带宽3.35 TB/s~8 TB/s
Tensor Core第 4 代第 5 代

# 7.2 NVFP4 (E2M1) 格式

NVFP4 E2M1 Format

NVFP4 E2M1 Format

只有 16 个可表示值意味着 FP4 的表达能力极其有限。但配合 per-block scaling,每组 16 个 FP4 值共享一个 FP8 的缩放因子,就可以覆盖更大的数值范围。这本质上和 DeepGEMM 的 per-128-element FP8 scaling 是同一个思路——用粗粒度的高精度 scale 来弥补细粒度低精度数据的表达能力不足。

# 7.3 对 Triton 编程的影响

随着 Triton 对 SM100 支持的完善:

  1. 指令进化tl.dot 将原生支持 FP4 × FP4 -> FP32,背后映射至第五代 Tensor Core tcgen05
  2. TMA 大满贯:Hopper 时代 TMA 主要负责 Load,而 Blackwell 支持更强大的 Store 方向 TMA 及超大范围的 Multicast(多播)。
  3. 更深的流水线:得益于 256KB+ 的 Shared Memory,num_stages=5 甚至 6 成为常态,内存延迟将被完美隐藏。

# 后续

Triton 是目前最成熟的 tile-based GPU DSL,但并非唯一选择。Tilelang 走了一条截然不同的路——更显式的内存层级控制、更灵活的布局变换抽象。接下来会将讲一下 Tilelang 的编程模型,并且与 Triton 进行对比。

本博客已稳定运行
总访客数: Loading
总访问量: Loading
发表了 75 篇文章 · 总计 343.94k
使用 Hugo 构建
主题 StackJimmy 设计
基于 v3.32.0 分支版本修改