Triton 语义
Triton 在大多数情况下遵守 NumPy 的语义,但也有一些例外。在本文档中,我们将介绍 Triton 支持的一些数组计算功能,并讨论 Triton 语义与 NumPy 不同的例外情况。
类型提升
类型提升 (Type Promotion) 是在不同数据类型的张量参与运算时发生的。对于与双下划线方法相关的二元运算和三元函数 tl.where
的最后两个参数,Triton 会自动将输入张量转换为一个通用的数据类型,这一转换遵循数据类型种类的层级顺序:{bool} < {integral dypes} < {floating point dtypes}
。
该算法如下:
-
类型:如果一个张量的数据类型属于更高级的类型,则另一个张量将被提升至该数据类型,例如:
(int32, bfloat16) -> bfloat16
。 -
宽度:如果两个张量的数据类型属于同一类,但其中一个张量具有更大的宽度,则另一个张量将被提升至此数据类型,例如:
(float32, float16) -> float32
。 -
上限:如果两个张量的宽度和符号相同,但数据类型不同,则它们都将被提升为下一个更大的数据类型,例如:
(float16, bfloat16) -> float32
。3.1 如果两个张量的数据类型都是不同的
fp8
类型,它们将都被统一转换为float16
。 -
无符号类型优先:在其他情况下(相同宽度,不同符号),它们将被提升为相应的无符号数据类型:
(int32, uint32) -> uint32
。
当涉及标量时,规则会有所不同。在此处,标量是指数值字面量、标记为 tl.constexpr 的变量或这些的组合。它们由 NumPy 标量表示,并具有bool
、int
和float
等类型。
当一个操作涉及张量和标量时:
-
如果标量的类型低于或等于张量的类型,则标量不会参与类型提升:
(uint8, int) -> uint8
。 -
如果标量的类型高于张量,我们将选择最适合标量的数据类型,在
int32
<uint32
<int64
<uint64
中为整数选择,在float32
<float64
中选择浮点数。然后,张量和标量都将提升为这种数据类型:(int16, 4.0) -> float32
。
广播
广播 (Broadcasting) 允许对不同形状的张量进行操作,它会自动将张量的形状扩展至兼容的大小,且在此过程中无需复制数据。其遵循以下规则:
-
如果其中一个张量的形状较短,则在左侧填充 1,直到两个张量的维数相同,例如:
((3, 4), (5, 3, 4)) -> ((1, 3, 4), (5, 3, 4))
。 -
如果两个维度相等,或者其中之一为 1,那么这两个维度是兼容的。维度为 1 的那个维度将会被扩展以匹配另一个张量的维度,例如:
((1, 3, 4), (5, 3, 4)) -> ((5, 3, 4), (5, 3, 4))
。
与 NumPy 的不同之处
为了提高效率,Triton 的整数除法运算符遵守 C 语义,而不是 Python 语义。因此,混合符号的整数的 int // int
运算实现为 C 语言中的向零舍入,而不是 Python 中的向负无穷舍入。同样,取模运算符 int % int
(定义为 a % b = a - b * (a // b)
)也遵守 C 语义,而不是 Python 语义。
在所有输入都是标量的情况下,整数除法和取模运算遵循 Python 语义,这点可能会造成迷惑。