Skip to main content

triton_language.split

triton.language.split(a)tuple[tensor, tensor]

将张量沿着其最后 1 个维度分成 2 部分,该维度的大小必须为 2。

例如,给定 1 个形状为 (4,8,2) 的张量,生成 2 个形状为 (4,8) 的张量。给定 1 个形状为 (2) 的张量,返回 2 个标量。

如果希望拆分成多个部分,可以多次调用这个函数(可能还需要调用 reshape 函数)。这反映了 Triton 中的约束,即张量必须具有 2 的幂次方大小。

Split 是 join 的逆操作。

参数**:**

  • a (Tensor) - 要被分割的张量。

这个函数也可作为 tensor 的成员函数调用,作为 x.split(...) 而不是 split(x, ...)