
Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's static random-access memory (SRAM).

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.
  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable, but they won't get you very far in practice.

Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch


import triton
import triton.language as tl
from triton.runtime import driver




def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for $x \in \mathbb{R}^{M \times N}$ requires reading $5MN + 2M$ elements from DRAM and writing back $3MN + 2M$ elements.

This is obviously wasteful; we would prefer a custom "fused" kernel that reads X only once and does all the necessary computations on-chip.

Doing so would require reading X and writing Y only once, i.e. $2MN$ elements of traffic, so we could expect a theoretical speed-up of about 4x (i.e., $\frac{8MN + 4M}{2MN}$).
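As a quick back-of-the-envelope check of that bound (a sketch, not part of the tutorial code; the shapes below are hypothetical), the ratio works out to $4 + 2/N$ and approaches 4x as N grows:

# Back-of-the-envelope check of the theoretical speed-up (hypothetical shapes).
M, N = 4096, 1024
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements read + written by naive_softmax
fused_traffic = 2 * M * N                                  # fused kernel: read X once, write Y once
print(naive_traffic / fused_traffic)                       # ~4.002, i.e. 4 + 2/N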

The torch.jit.script flag aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
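For reference, here is a minimal sketch of how one might script the naive version for such a comparison; the naive_softmax_jit wrapper below is an illustrative assumption, not part of the tutorial code.

# Hypothetical illustration: scripting the naive softmax so the Torch JIT can try to fuse it.
@torch.jit.script
def naive_softmax_jit(x: torch.Tensor) -> torch.Tensor:
    x_max = x.max(dim=1)[0]
    z = x - x_max[:, None]
    numerator = torch.exp(z)
    denominator = numerator.sum(dim=1)
    return numerator / denominator[:, None]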

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the result back to the output Y.

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so if we want to handle any possible input shape, we need to internally "pad" each row and guard the memory operations properly:
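As a concrete illustration of what this padding means (illustrative values only, using triton.next_power_of_2 as in the helper function later in this tutorial):

# Illustration only: a row of 1000 columns is processed in a block of 1024 lanes.
n_cols = 1000                                # hypothetical row length
BLOCK_SIZE = triton.next_power_of_2(n_cols)  # -> 1024
# Inside the kernel, col_offsets = tl.arange(0, BLOCK_SIZE) spans 0..1023;
# lanes with col_offsets >= n_cols are masked out of the load/store, and the
# masked load fills them with -inf so they do not affect the row maximum.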

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # Pre-compile the kernel to get register usage and compute thread occupancy.
    kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
    if kernel is None:
        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
        kernel._init_handles()
        n_regs = kernel.n_regs
        size_smem = kernel.metadata.shared
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
        occupancy = min(occupancy, SIZE_SMEM // size_smem)
        num_programs = NUM_SM * occupancy
        kernels[BLOCK_SIZE] = (kernel, num_programs)

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_rows,
        n_cols,
    )
    return y
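To make the occupancy arithmetic above concrete, here is a worked example with made-up numbers (register counts and shared-memory usage vary by GPU and compiler version; in the real code they come from the pre-compiled kernel):

# Hypothetical numbers for illustration of the occupancy computation above.
NUM_REGS_example = 65536     # registers available per SM
n_regs_example = 32          # registers per thread reported by the compiled kernel
WARP_SIZE_example = 32
num_warps_example = 8
SIZE_SMEM_example = 101376   # shared memory per SM, in bytes
size_smem_example = 4096     # shared memory used by one program

occupancy = NUM_REGS_example // (n_regs_example * WARP_SIZE_example * num_warps_example)  # 8 programs per SM by registers
occupancy = min(occupancy, SIZE_SMEM_example // size_smem_example)                        # shared memory allows 24, so still 8
num_programs = 108 * occupancy  # e.g. 108 SMs -> 864 persistent programs, later capped by n_rows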

Unit Test

We make sure to test our kernel on a matrix with an irregular number of rows and columns.

This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results are identical to those of torch.softmax.
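As an optional extra check (a sketch, not part of the original tutorial), the naive PyTorch version defined earlier should agree as well, up to floating-point tolerance:

# Optional: also compare against the naive PyTorch implementation defined above.
y_naive = naive_softmax(x)
assert torch.allclose(y_triton, y_naive, atol=1e-6)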

Benchmark

Here we benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows.

We then compare its performance against that of torch.softmax.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    stream = torch.cuda.Stream()
    torch.cuda.set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)




benchmark.run(show_plots=True, print_data=True)
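The gbps lambda above converts runtime to effective bandwidth by counting every element of the matrix once for the read and once for the write. A worked example of the arithmetic (the 0.1 ms runtime below is hypothetical):

# Worked example of the GB/s conversion used in the benchmark above.
M, N, element_size = 4096, 1024, 4        # fp32 -> 4 bytes per element
bytes_moved = 2 * M * N * element_size    # one read + one write of the whole matrix
ms = 0.1                                  # hypothetical measured runtime in milliseconds
print(bytes_moved * 1e-9 / (ms * 1e-3))   # ~335.5 GB/s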

[Plot: softmax-performance, GB/s as a function of N for Triton and Torch]

Out:

softmax-performance:

       N       Triton        Torch
   256.0   475.581977   708.619322
   384.0   619.872425   812.799315
   512.0   752.326527   927.222924
   640.0   788.217790   946.386719
   768.0   880.887679  1014.912158
   896.0   937.344158  1074.519017
  1024.0   994.049328  1120.599053
  1152.0  1096.160464   616.484209
  1280.0  1136.037680   669.424776
  1408.0  1150.661622   725.262518
  1536.0  1195.385896   783.556680
  1664.0  1218.037815   812.802866
  1792.0  1240.453775   857.206087
  1920.0  1249.594057   910.379759
  2048.0  1281.002942   960.369226
  2176.0  1258.141618   976.327061
  2304.0  1268.029374  1013.671493
  2432.0  1295.384387  1059.587886
  2560.0  1306.614187  1084.683454
  2688.0  1317.169033  1104.769558
  2816.0  1327.217578  1127.015242
  2944.0  1321.850846  1164.100210
  3072.0  1351.140419  1185.534776
  3200.0  1355.270950  1195.189132
  3328.0  1350.926797  1219.700403
  3456.0  1370.851095  1249.846232
  3584.0  1370.733345  1257.045186
  3712.0  1380.222691  1272.332674
  3840.0  1386.847005  1304.931759
  3968.0  1390.096765  1314.800917
  4096.0  1395.691852  1329.296474
  4224.0  1336.892109  1157.837774
  4352.0  1338.490269  1173.375508
  4480.0  1350.203136  1183.423201
  4608.0  1361.692557  1198.281856
  4736.0  1359.538511  1196.113447
  4864.0  1374.159725  1224.748171
  4992.0  1370.339012  1237.542346
  5120.0  1371.061881  1250.239195
  5248.0  1373.741013  1256.002531
  5376.0  1382.862170  1286.354639
  5504.0  1377.679797  1300.142739
  5632.0  1378.558008  1311.940458
  5760.0  1393.962179  1329.921162
  5888.0  1395.824888  1346.085280
  6016.0  1401.488037  1355.059080
  6144.0  1406.345907  1374.157489
  6272.0  1412.687019  1376.883517
  6400.0  1415.309106  1389.410912
  6528.0  1417.204727  1392.583463
  6656.0  1422.082775  1405.407043
  6784.0  1416.999653  1415.459830
  6912.0  1427.997548  1424.580919
  7040.0  1420.079821  1433.713238
  7168.0  1428.226868  1434.182051
  7296.0  1426.907241  1443.570904
  7424.0  1431.245969  1444.524696
  7552.0  1429.852775  1455.236120
  7680.0  1438.222846  1459.114601
  7808.0  1432.084205  1467.194446
  7936.0  1435.612336  1467.986631
  8064.0  1434.118461  1472.734245
  8192.0  1442.312192  1483.740088
  8320.0  1388.784296  1401.371945
  8448.0  1380.648971  1407.791889
  8576.0  1397.384833  1396.228603
  8704.0  1393.329000  1400.798649
  8832.0  1382.315605  1401.590681
  8960.0  1396.830311  1413.000037
  9088.0  1409.916407  1418.336829
  9216.0  1406.385628  1423.454382
  9344.0  1399.874528  1424.049331
  9472.0  1399.237655  1435.753027
  9600.0  1397.459628  1430.972090
  9728.0  1397.891660  1440.957731
  9856.0  1413.115159  1443.096680
  9984.0  1403.193995  1448.557274
 10112.0  1410.619129  1460.444766
 10240.0  1419.479137  1469.013209
 10368.0  1410.951646  1462.658968
 10496.0  1418.695729  1464.967769
 10624.0  1408.428065  1471.324774
 10752.0  1406.421698  1472.142310
 10880.0  1400.046054  1480.109648
 11008.0  1420.714401  1480.192162
 11136.0  1419.632677  1486.624982
 11264.0  1431.403761  1485.485510
 11392.0  1414.297153  1487.808132
 11520.0  1424.387130  1492.888886
 11648.0  1420.436500  1499.220597
 11776.0  1426.911001  1499.184

In the plot above, we can see that:

  • Triton is 4x faster than the Torch JIT. This confirms our suspicion that the Torch JIT does not do any fusion here.
  • Triton is noticeably faster than torch.softmax, in addition to being easier to read, understand, and maintain.

Download Jupyter notebook: 02-fused-softmax.ipynb

Download Python source code: 02-fused-softmax.py

Download zipped: 02-fused-softmax.zip