Triton (OpenAI 版)
今年 Triton 确实挺火,互联网大厂想用它写算子,比 CUDA 迭代周期更短;硬件厂想用它的 DSL 来推广自己的软件栈和生态。
本文总结了 2024 年大模型方向秋招面试中的高频技术问题。涵盖 Triton Kernel 优化流程及下降路径,MLIR 中 Tensor 与 Memref 抽象差异、Linalg Dialect 设计理念及图拓扑排序实现。涉及 LLM 推理优化技术如 KV Cache、Flash Attention、Page Attention 等,以及 GPU SM 架构基础知识和 C++ 核心概念。旨在帮助求职者梳理编译器底层原理与大模型系统优化知识体系。

今年 Triton 确实挺火,互联网大厂想用它写算子,比 CUDA 迭代周期更短;硬件厂想用它的 DSL 来推广自己的软件栈和生态。
推荐一些学习资源:
- 谈谈对 OpenAI Triton 的一些理解,帮助大家建立一个宏观印象
- 如何入门 OpenAI Triton 编程?帮助了解更多关于语法上的信息
- 浅析 Triton 执行流程,帮助初学者大致明白一段 Python 程序如何进入 Triton pipeline,然后跑起来。
挖坑:有一些基础的认识后,大家就可以在 Triton 中各取所需了。作为一个 MLIR Programmer,我还希望了解每个 transform pass 和每一步 conversion 是怎么做的、作用等,抽空好好读读源码。
不管啥 kernel,到我手上都是经过'两步走'来优化:
浅层优化:通过替换算子、(用 atomic op) 合并 kernel、拆循环、调 config 等方式实现初步优化。
深层优化:分析下降所得 IR,使用 perf 工具,对照算子库实现等方式,优化 kernel 的下降行为。

大部分情况下,'第一步'走完性能就接近算子库了,还是大哥们后续 codegen 的 pass 太顶级了,我成为了无情的 config 添加器。
关于这部分更详细的可以看看鄙人的记录,这里详细一些 ref:[Triton] Kernel Optim
当'第一步'走完性能还是和算子库有距离,那就继续'第二步',上 perf!看 ir!看看访存是否连续,得到的汇编是否符合预期等。(自己总有看不懂的时候,直接叫大哥)
分解出优化点,在 ir 下降过程加点美味的 pattern,大部分情况还是得看看算子库的大哥们是咋写的 kernel,然后(抄一下)启发一下编译器的 lowering 过程。如果还是打不过怎么办,这时候就真得看看 IO 这些是否打得比较足了,或者换过服务器多跑几次(别试,一般没用)。
官方:triton-lang -> triton ir -> triton gpu dialect -> llvmir -> ptx
其中 llvm ir 更标准地说法应该是 nvvm ir,相比官方的 llvm ir 要额外扩展了一些 hardware intrinsic 和 conversion。想了解可以看 llvm project 中的 llvm/include/llvm/IR/IntrinsicsNVVM.td 和 llvm/lib/Target/NVPTX/。
ptx 后序会根据硬件信息转为 sass。

triton 中的 layout 在 triton gpu dialect 才第一次出现,作为 attr 辅助 op 的 conversion 和 transform,主要分为两种:distributed layout 和 shared layout。
distributed layout:描述 tensor 应该如何被 thread 访问,又分为 block layout、mma layout 和 dotoperand layout
block layout:使用 AxisInfoAnalysis 获得 load 和 store 等指针操作 op 具体的操作 tensor(shape、layout 信息等) 以及连续性信息,这个信息后序会用来在 memory-coalesce (访存合并)。
mma layout 和 dotoperand layout:我理解都是描述了特定 op 的 operand 的数据布局,以指导后序 op 的 lowering。
shared layout:shared layout 描述了 share mem 中可能被同时访问的处于同一个 bank 的数据。share mem 中的每个 bank 会会单独相应内存访问请求,所以同一时间内,若多个 thread 访问的数据处于同一个 bank 就会产出 bank conflict,导致吞吐异常。所以根据 shared layout 进行 layout-swizzling,调整相关的数据布局。
关注开源社区的朋友们可能了解到前段时间寒武纪开源了 triton-linalg 这为其他 DSA(ASIC) 接入 triton 提供了一条不错的道路。
我了解到这条路:triton-lang -> triton ir -> linalg dialect -> hardware dialect -> llvm ir -> … ->hardware assembly
和硬件无关的 dialect 比官方的更多,优化可以在 ttir、linalg-on-tensor,linalg-on-memref,hardware dialect 做,arch 和 non-arch 的抽象隔离地比较好。但路铺得长工程量也就大。

支持 triton,一是可以用户也方便自己定义 kernel,迁移成本低。二是开发人员也可以写算子库或者特定的加速库,现在很多框架中都带上了 triton 的 kernel 实现。
和官方的不同在于,(个人理解)从 ttir 开始分叉开,用不上官方的 t tgir 往后的优化 pass,本质上更贴近 SPMD 的编程范式,某些原语有自己的映射。在优化 ir 上,某些也是根据硬件特性来的。
当时比较蒙,后面面试官大哥说'更适合访存密集型任务',后来想来,或许是 codegen 这条路更好调整数据的 memory 层次,比较好优化不同 memory-space 之间的 data flow?(不知道有没有其他大哥解答下)
当时只想到了 SIMD 希望访存更加连续,SIMT 希望吞吐更大。今天回顾了一下文章,有了棍子的文章:漫谈高性能计算与性能优化:访存
突然感觉能串起来一些了。
SIMD 核心优化 latency,越快完成越好 -> 保证访存连续性,用连续指令 (非 strided,非 scalar)
其他常见优化:
SIMT 核心优化 throughtput,吞吐越大越好 -> 用好 DMA 和 TMA,打满 tensorcore
其他常见优化:
关于 memory-coalesce、layout-swizzling 再补充一点个人的看法,如有不对,烦请指正~
memory-coalesce:warp 中的 threads 在访问地址时都会发送内存访问请求到 LSU,如果这些请求是想访问一片连续的空间,那么这些请求会被合并成一个或者少数几个。
layout-swizzling:smem 中的数据是以 bank 组织的,每个 bank 可以独立地处理一个内存访问请求,如果多个请求指向同一个 bank,就会产生 bank-conflict。该 pass 就会调整数据的排布来避免 bank-conflit,至于怎么调整,是由硬件 shuffle 还是软件更改地址映射关系,我真不知道,求大佬们解答。
做算子融合的时候,从写算子的角度上,惯性是想先确定能够 fuse 到一起的 op,再把这个序列一起 tile。fuse 起来的形式一般是固定的 pattern 或者像 xla 那类,或者是贪心的。这样 tile 时就能获得一个片上一定能放得下且性能不错的 tile_size。
我熟悉的一条路是先 tile 再 fuse,找到一个锚点 op,确定好 tile 策略后再将 producer 和 consumer 给贪心的 fuse 进去,后续 hardware dialect 再做 mlu+add -> fma 这样的融合行为。今天听大哥讲,可以从 tvm ansor 的那种行为去理解,应用模版后就只需要去寻找其中的 tile_size 就好了,贪心 fuse 不进去就算了,但后续这条路应该还是会优化的,毕竟我想基于这相关的做做毕设。
好些次听到这个问题了,每次我都是回答会有一些不一样吧,毕竟推理优化和训练优化是有区别的,训练需要关心梯度传播的过程,导致算子融合的行为不能激进,且要保存很多中间结果,而且推理比较好专项根据场景优化。
面试完跟 leader 聊了下,果然是我 too young too simple 了,他说对于 codegen 这条路并不关心上层是在训练还是推理,来一个 task(或者说一段计算图)正常 lower + codegen 就好了。
软流水是大哥们做的,pass 太大了,还没细看,但我直接蒙是最内层展开,确实是这样,但是所以然呢?
大哥简单说了下,最内层展开后再排开,出现的依赖关系好分析一些,排成不同的 pipeline stage 就好了。有时候我们比算子库的性能好,是因为他们是用高级语言写出的手动排流水,而我们排流水时操作的都是很贴近硬件汇编的 dialect 了,所以更精细,理论的上限是要比算子库好。
那为什么现在挺多还是差点呢,算子库还是太懂硬件的脾气了!!
这个问题我在写代码时也思考过一些,现在总结下来,以我的理解大体如下:
mlir 的框架中主要有两种数据抽象,tensor 和 memref(aka. buffer),这两者分别对应着 ML 编译器中的高层抽象 (torch.tensor) 和传统低级编译器的低层抽象 memory buffer。tensor 通过 bufferization 转为 memref 表示,一些 dialect 中 operand 可以是 tensor 也可以是 memref,例如 linalg(tensor + linalg-on-tensor --bufferization–> memref + linalg-on-memref)。
相同点:都可以用来表示算子的 operand
不同点:tensor 语义上只能被定值一次,即声明的那一次,和 SSA 的定义有点相似。(SSA IR 要求每个变量只能有一次值域,且使用前需要先定义)
# 我们从下面的 ir 的 a、b、c 三个点去获得 extract 的值,都是相同的,都是源自于 %1 = tensor.empty 创建时获得的随机值。
%1 = tensor.empty
%extract = tensor.extract %1 // a 点
%fill = linalg.fill outs(%1)
%extract = tensor.extract %1 // b 点
%map = linalg.map outs(%1)
%extract = tensor.extract %1 // c 点
memref 是可变的,可以被多次 def,并且许多 memref 是可能存在 alias 关系,所以在 data flow analysis 中需要考虑 alias analysis。
tensor 上的 rewrite 更简单,因为 tensor 操作都没有 side-effect,而 memref 操作大概率有。
memref 中 ir 的顺序很重要,移动 ir 很可能导致程序语义改变,所以 clone 行为要注意。而 tensor 中的 clone 行为一般没问题。简而言之 ir-on-tensor 中的 clone 行为即使改变了 ir 的次序,一般也不会影响程序语意。而 ir-on-memref 就不行了,memref 上的 ir 要避免改变 ir 次序,否则可能发生下面的情况。
// 若把 %load clone 到 它的 user 前(scf.forall)内,这样程序的语意就被改变了,因为中间有对 %alloc 的 def
%load = memref.load %alloc
def %alloc
scf.forall
use %load
按我的理解,linalg 包含 linalg-on-tensor 和 linalg-on-memref,做了一个很强的中间胶水层,向上承接程序的计算描述,向下准备下降到 hardware dialect,更贴近目标硬件。后面看了看 linalg 的官方文档
上面赫然写着:Linalg is designed to solve the High-level Hierarchical Optimization
Linalg IR 也提供了许多好用的 transforms:
简而言之,Linalg Dialect 是很重要的一个层级,在这之前的 dialect 更多得是对计算的描述,表达原有的 ML 程序。而从 Linalg 开始,就会经过一系列变换 (tile, fuse, promotion, bufferize) 贴近目标硬件。
Smallvector 和 std::vector 的异同 SmallVector 会现在栈上分配一定的预留空间,当压入的元素所占空间超过其预留空间时就会退化到和 std::vector 差不多的行为(在堆上分配空间)。这样避免了小规模数据空间的申请和释放开销,而且 std::vector 在空间增长时采用的是倍增策略。
栈上的空间是由编译器自动分配释放,一般存放函数参数和局部变量啥的。
程序员申请的空间一般在堆上,需要手动申请和释放。
StringRef 和 std::string 的异同 StringRef 是个轻量化的字符串引用类,指向现有的字符串数据,而不管理这片数据的地址 (没有这片数据的所有权)。想要存储一个 StringRef 往往是不安全的。(因为 data 的真实 memory 可能随时被修改)
llvm/include/llvm/ADT/StringRef.h 中写到:This class does not own the string data。
const char *Data = nullptr; // 不能改变指针指向区域的值,但是可以改变指针指向的区域
const char * _,_指针可以改,指针指向的值不能改
char * const,指针不能改,指针指向的值可以改
std::string 完全管理自己的内存。
SmallVector 可以使用什么替换(ArrayRef、SmallVectorImpl 的使用场景) ArrayRef 表示对一个 array 的 const 引用,和 StringRef 一样,也没有真实数据的所有权。当传入函数的对象 (Smallvector) 不需要被修改时,用 ArrayRef 就可以避免不必要的拷贝。
const SmallVector & <–> ArrayRef (better)
SmallVectorImpl 在构造时不需要'预留元素个数'这个参数,所以函数可能传入的 SmallVector 实例大小不一时,常用 SmallVectorImpl,这样避免依赖或硬编码任何具体的容量信息,减少参数的拷贝。
如果数据很多,且每次增长(push 进)的量很大,或许也可以采用 std::vector,这样能减少调整空间的次数。
代码题乱入!这道题因为当时太紧张,都忘了图节点的关系应该用二维数组或者二维链表来定义了,所以就没写出来。面试官说其实是想考如何分析 def-use 链。这不巧了,我刚好写过很依赖 def-use 分析的 pass。
其实 greedily fuse produer and consumer 的行为也是在对 op 间的关系进行一个拓扑排序,最终获得的 fuse 序列应该保证原来的 ir 执行顺序的。我们以一个简单的序列为例,a -> b 表示 a 是 b 的 producer,b 是 a 的 consumer。使用一个 set 类型的数据结构作为 visited 记录,选择 vector 类型的数据结构来记录结果的拓扑序。

(1)先找到当前无依赖(无前驱/无 producer)的节点作为起点,这里选 1,作为当前的 candidantOp
(2)对于 candidantOp,首先看 visited 中是否已经处理过 visited.insert(candidantOp).second,如果没有处理过,如果 visited 中没有则进入下一步,反正当前对该
(3)遍历 candidantOp 的 producer,若存在 producer,则递归先把 producer 当作新的 candidantOp 去处理;若没有 producer 则把该 candidantOp 压入 topological seq vector。
(4)然后遍历 candidantOp 的 consumer,若 consumer 存在,则继续把 consumer 当作新的 candidantOp 去处理。
直到所有 op 都处理完,(2)(3)(4)步会多次处理。
结合上面的例子和算法,那么我们访问的顺序是:
遇见 1,1 没访问过,1 插入 visited;1 无 producer,1 入 topological seq vector。访问 1 的 consumer,首先是 3。3 没访问过,3 插入 visited;3 的 producer 还没处理,则当前去处理 2。2 没有被访问过,则开始处理 2。2 无 producer,2 入 topological seq vector。访问 2 的 consumer,2 的 consumer 只有 3,且 3 在 visited 中,2 的处理结束。返回 3 的处理,现在 3 的 producer 已在 topological seq vector 中,3 无其他依赖,将 3 加入 topological seq vector。访问 3 的 consumer,4 没访问过,4 插入 visited。4 的 producer 是 1 和 3,1 和 3 都已经被 visited,所以将 4 加入 topological seq vector。4 没有 consumer。继续访问 1 的 consumer,当前是 4,但是 4 已经被 visited,所以结束。最终获得序列 1->2->3->4。

以下是拓扑排序的 Python 实现示例:
def topological_sort(graph):
# graph: dict {node: [neighbors]}, neighbors are consumers of node
in_degree = {node: 0 for node in graph}
for u in graph:
for v in graph[u]:
in_degree[v] += 1
queue = [node for node in graph if in_degree[node] == 0]
result = []
while queue:
node = queue.pop(0)
result.append(node)
for neighbor in graph[node]:
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
queue.append(neighbor)
return result if len(result) == len(graph) else []
就只知道这些常见的了,更高级的还是没多关注下,之前看到 Mamba,最近看到了一个 linear-attention,但都还没来得及看看是怎么回事。
每个 SM 都有独立的 smem, constant cache, register mem,SM 之间共享 L2 Cache 和 gdram。一个 SM 包含多个 SP(即 cuda core)和 tensor core。
一个 SM 可以处理多个 thread block(或者说 CTA),当其中有 block 的所有 thread 都处理完后,他就会再去找其他还没处理的 block 来处理。
自从 Volta 架构后,smem 和 L1 Cache 就合成一块 memory unit 了,程序员可以根据任务场景自己配置对应的大小。
例如,在放存密集且连续的场景下(例如 matmul),smem 大一些性能更好。但是 smem 和 L1 Cache 的总大小是一定的。
L1 Cache 保留的原因:L1 在某些场景下也是必要的,例如以 sparse computing 中;smem 是很快会用到的,L1 是从 dram 上取来的,cache 是防止低速访存必要的,smem 能防止污染 cache。
关于 smem 和 L1 cache 合并起来我一直稀里糊涂的,在想程序猿咋能将 cache 和 smem 统一编程呢,cache 咋看得到的。稀里糊涂下,私下请教了下问 @Antinomi,清晰了许多。
关于 arch 的知识还是太薄弱了!

网络图图
还没人问我八股其实。。。但问到我也大概率不咋记得,平时用的多的就继承、虚函数、智能指针、constexpr、override 等一些概念,其他用得少记忆模糊.jpg
考一次被代码手撕一次,肥肠残忍,这就是我这个短板了。每次想到要刷题世界就美好起来了,读源码都十分有意思。
希望接下来的面试官轻点敲打
虽然到现在还没拿到梦厂的 offer 很让人痛苦,但是知道不去实习拿到梦厂的 offer 概率比较难。继续跟着大哥们学习吧,祝大家万事胜意~


微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online
将 Markdown(GFM)转为 HTML 片段,浏览器内 marked 解析;与 HTML转Markdown 互为补充。 在线工具,Markdown转HTML在线工具,online