Triton
Triton 是 OpenAI 开发的 GPU 编程语言,旨在提供比 CUDA 更高层抽象的同时,保持接近手写 CUDA 的性能。它的设计哲学是"让 AI 研究者能够轻松编写高性能 GPU kernel",无需深入了解 GPU 架构、shared memory、warp shuffle 等底层概念。
为什么需要 Triton
CUDA 编程的学习曲线陡峭,需要理解 GPU 的硬件架构(SM、warp、shared memory bank conflict)、手动管理线程块调度、优化内存访问模式(coalescing、padding)。编写一个高效的矩阵乘法 kernel 需要数百行代码,且针对不同 GPU 架构(Ampere vs Hopper)需要重新调优。
Triton 将抽象级别提高到更接近数学表达,开发者只需编写"每个输出元素如何计算",编译器自动处理线程映射、内存分块、shared memory 管理。这大幅降低了开发门槛,一个矩阵乘法 kernel 仅需约 50 行 Triton 代码,且性能可达手写 CUDA 的 90% 以上。
核心概念
Triton 的程序以 @triton.jit 装饰的函数表示,该函数将被编译为 GPU kernel。函数的参数包括输入输出张量的指针、形状、步长(stride),以及编译时常量(如 BLOCK_SIZE: tl.constexpr)。kernel 内部使用 Triton 语言(基于 Python 的 DSL)编写,包括 tl.load(加载数据)、tl.store(存储数据)、tl.dot(矩阵乘法)、tl.reduce(归约操作)等原语。
import triton
import triton.language as tl
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
# 每个 program 处理一个 BLOCK_SIZE 的数据
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# 加载数据
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
# 存储数据
tl.store(output_ptr + offsets, output, mask=mask)Triton 的抽象模型是"program"而非"thread"。一个 program 处理一个数据块(block),编译器自动将 program 映射到 GPU 的线程块,并管理线程同步。这简化了并行编程,因为开发者无需手动计算 thread ID、block ID、grid size。
自动调优
Triton 的 triton.autotune 可自动搜索最优配置(block size、num_stages、num_warps)。对于矩阵乘法,block size 影响 shared memory 使用率,num_stages 影响流水线深度,num_warps 影响 warp 占用。这些参数的最优值依赖于硬件架构和数据形状,手动调优需要大量实验。autotune 通过运行不同配置,选择性能最优的配置,无需人工干预。
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
triton.Config({"BLOCK_SIZE": 256}, num_warps=8),
triton.Config({"BLOCK_SIZE": 512}, num_warps=16),
],
key=["n_elements"],
)
@triton.jit
def add_kernel(...):
...PyTorch 一等公民
Triton kernel 可通过 torch.compile 或 torch.ops.load_custom_op 集成到 PyTorch 模型中。PyTorch 2.0 原生支持 Triton,torch.compile 会自动将符合模式的算子编译为 Triton kernel(如逐元素操作、归约操作)。这无需修改模型代码,只需添加 torch.compile(model) 即可享受 Triton 的性能提升。自此,PyTorch 的强势推行,Triton 已成为 PyTorch 平台的一等公民。
对于自定义算子,可通过 torch.library.custom_op 注册 Triton kernel:
import torch
def triton_add(x, y):
output = torch.empty_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=256)
return output
# 注册为 PyTorch 算子
torch.library.define("mylib::triton_add", "(Tensor, Tensor) -> Tensor")
torch.library.impl("mylib::triton_add", "CUDA", triton_add)性能对比
Triton 官方 benchmark 显示,对于常见算子(逐元素操作、Softmax、LayerNorm),Triton 的性能可达手写 CUDA 的 85-95%。对于矩阵乘法,Triton 的性能约为 cuBLAS(NVIDIA 官方库)的 70-80%,差距主要来自 cuBLAS 针对特定 GPU 架构的手写汇编优化。
Triton 的优势在于开发效率。一个 Triton kernel 从编写、调试、调优到部署,通常仅需 1-2 天,而手写 CUDA 需要 1-2 周。对于快速迭代的研究场景(尝试新的 Attention 机制、新的归一化方法),Triton 是更实用的选择。
局限性
Triton 的抽象级别高于 CUDA,但也因此牺牲了部分性能控制力。对于需要极致优化的算子(如 FlashAttention-2 的手写 assembly),Triton 难以达到同等性能。此外,Triton 的生态系统仍在发展中,调试工具(triton.testing)和性能分析工具(triton.testing.do_bench)不如 CUDA 成熟。
Triton 目前支持 NVIDIA GPU 和 AMD GPU(通过 ROCm),不支持 CPU 和其他加速器(如 TPU、NPU)。对于非 CUDA 平台,需要考虑其他方案(如 OpenCL、SYCL)。
Triton 只定义了数据面 DSL,未能实现控制面代码的封装屏蔽,对于不同的厂商硬件上的控制面代码,例如:显存管理和通信控制等,依然需要上层的 AI 引擎层来管理和调度。
未来展望
Triton 的发展方向包括:更好的自动调优(基于机器学习的配置搜索)、更丰富的标准库(卷积、RNN、Transformer)、更完善的调试工具(interleaved execution、race condition 检测)。OpenAI 正在使用 Triton 重写 PyTorch 的标准算子,未来 torch.nn.functional 的大部分算子可能由 Triton 实现,而非 cuDNN。
Triton 有望成为 AI 领域的"Vulkan"——跨平台、高性能、易用性兼备的 GPU 编程语言。与 Vulkan 一样,Triton 的成功依赖于生态建设(IDE 支持、性能分析工具、第三方库),这需要社区的共同努力。