Skip to main content

triton.language

编程模式

tensor表示一个值或指针的 N 维数组
program_id沿指定轴返回当前程序实例的 id
num_programs沿指定轴返回当前程序实例的数量

创建操作

arange返回半开区间 [start, end) 内的连续值
cat连接给定的块
full返回一个张量,该张量填充了指定 shape 和 dtype 的标量值
zeros返回一个张量,该张量用指定 shape 和 dtype 填充了标量值 0
zeros_like返回一个 shape 和 dtype 与给定张量相同的全零张量
cast将张量转换为指定的 dtype

Shape Manipulation Ops

broadcast尝试将两个给定的块广播到一个共同兼容的 shape
broadcast_to尝试将给定的张量广播到新的 shape
expand_dims通过插入新的长度为 1 的维度来扩展张量的形状
interleave沿着最后一个维度交错两个张量的值
join在一个新的次要维度中连接给定的张量
permute排列张量的维度
ravel返回 x 的连续扁平视图
reshape返回一个具有与输入相同元素数但具有提供的形状的张量
split将张量沿其最后一个维度分成两部分,该维度大小必须为 2
trans排列张量的维度。
view返回具有与输入相同元素但形状不同的张量

Linear Algebra Ops

dot返回两个块的矩阵乘积

内存/指针操作

load返回一个张量,其值从由指针定义的内存位置加载
store将数据张量存储到由指针定义的内存位置
make_block_ptr返回指向父张量中某个块的指针
advance推进一个块指针

索引操作

flip沿着维度 dim 翻转张量 x
where根据 condition 返回来自 x 或 y 的元素组成的张量
swizzle2d将行主序排列为 size_i ***size_j 的矩阵的索引,转换为每组 size_g 行的列主序矩阵的索引

数学操作

abs计算 x 的逐元素绝对值
cdiv计算 x 除以 div 的向上取整除法
ceil计算 x 的逐元素向上取整值
clamp将输入张量 x 的值限制在 [min, max] 范围内
cos计算 x 的逐元素余弦值
div_rn计算 x 和 y 的逐元素精确除法(根据 IEEE 标准四舍五入到最近的值)
erf计算 x 的逐元素误差函数
exp计算 x 的逐元素指数
exp2计算 x 的逐元素指数(以 2 为底)
fdiv计算 x 和 y 的逐元素快速除法
floor计算 x 的逐元素向下取整
fma计算 x、y 和 z 的逐元素融合乘加运算
log计算 x 的逐元素自然对数
log2计算 x 的逐元素对数(以 2 为底)
maximum计算 x 和 y 的逐元素最大值
minimum计算 x 和 y 的逐元素最小值
rsqrt计算 x 的逐元素的平方根倒数
sigmoid计算 x 的逐元素 sigmoid 函数值
sinComputes the element-wise sine of x. 计算 x 的逐元素正弦值
softmax计算 x 的逐元素 softmax 值
sqrt计算 x 的逐元素快速平方根
sqrt_rn计算 x 的逐元素精确平方根(根据 IEEE 标准四舍五入到最近的值)
umulhi计算 x 和 y 的 2N 位乘积的逐元素最高有效 N 位

归约操作

argmax返回沿指定 axis 轴上 input 张量中所有元素的最大索引
argmin返回沿指定 axis 轴上 input 张量中所有元素的最小索引
max返回沿指定 axis 轴上 input 张量中所有元素的最大值
min返回沿指定 axis 轴上 input 张量中所有元素的最小值
reduce将 combine_fn 应用于沿指定 axis 的 input 张量中的所有元素
sum返回 input 张量中,沿指定 axis 的所有元素的总和
xor_sum返回 input 张量中,沿指定 axis 的所有元素的异或和

扫描/排序操作

associative_scan沿指定 axis 将 combine_fn 应用于 input 张量的每个元素和携带的值,并更新携带的值
cumprod返回沿指定 axis 的 input 张量中所有元素的累积乘积
cumsum返回沿指定 axis 的 input 张量中所有元素的累积和
histogram基于 input 张量计算 1 个具有 num_bins 个 bin 的直方图,每个 bin 宽度为 1,起始于 0
sort沿着指定维度对张量进行排序

原子操作

atomic_add在由 pointer 指定的内存位置执行原子加法
atomic_and在由 pointer 指定的内存位置执行原子逻辑和操作
atomic_cas在由 pointer 指定的内存位置执行 1 个原子比较并交换操作
atomic_max在由 pointer 指定的内存位置执行 1 个原子最大值操作
atomic_min在由 pointer 指定的内存位置执行 1 个原子最小值操作
atomic_or在由 pointer 指定的内存位置执行 1 个原子逻辑或操作
atomic_xchg在由 pointer 指定的内存位置执行 1 个原子交换操作
atomic_xor在由 pointer 指定的内存位置执行原子逻辑异或操作

随机数生成

randint4x给定 1 个seed 标量和 1 个offset 块,返回 4 个 int32 类型的随机块
randint给定 1 个 seed 标量和 1 个 offset 块,返回 1 个 int32 类型的随机块
rand给定 1 个 seed 标量和 1 个 offset 块,返回 1 个在 U(0,1)U(0,1) 中的 float32 类型的随机块
randn给定 1 个 seed 标量和 1 个 offset 块,返回 1 个在 N(0,1) 中的 float32 类型的随机块

迭代器

range永远向上计数的迭代器
static_range永远向上计数的迭代器

内联汇编

inline_asm_elementwise在张量上执行内联汇编

编译器提示操作

debug_barrier插入 1 个屏障以同步 1 个块中的所有线程
max_constancy告知编译器 input 中的第 1 个值是常量
max_contiguous告知编译器 input 中的第 1 个值是连续
multiple_of告知编译器 input 中的所有值都是 value 的倍数

调试操作

static_print在编译时打印数值
static_assert在编译时断言条件
device_print在运行时从设备打印数值
device_assert在运行时从设备上断言条件