Skip to main content

triton_language.where

triton.language.where(condition, x, y)

根据 condition 返回来自 xy 元素的张量。

注意:无论 condition 的值是什么,xy 总是会被求值。

如果希望避免意外的内存操作,请使用 triton.loadtriton.store 中的 mask 参数。

xy 的形状都会被广播到 condition 的形状。xy 必须具有相同的数据类型。

参数**:**

  • conditiontriton.bool 的块)- 当为 True(非零)时,产生 x,否则产生 y。
  • x - 在条件为 True 的索引处选择的值。
  • y - 在条件为 False 的索引处选择的值。