triton_language.where
triton.language.where(condition, x, y)
根据 condition 返回来自 x 或 y 元素的张量。
注意:无论 condition 的值是什么,x 和 y 总是会被求值。
如果希望避免意外的内存操作,请使用 triton.load 和 triton.store 中的 mask 参数。
x 和 y 的形状都会被广播到 condition 的形状。x 和 y 必须具有相同的数据类型。
参数**:**
- condition(triton.bool 的块)- 当为 True(非零)时,产生 x,否则产生 y。
- x - 在条件为 True 的索引处选择的值。
- y - 在条件为 False 的索引处选择的值。