Skip to main content

triton_language.device_print

triton.language.device_print(prefix, *args, hex=False)

在运行时从设备打印值。字符串格式化不适用于运行时的值,因此你应该将要打印的值作为参数提供。第一个值必须是字符串,所有后续的值必须是标量或张量。

调用 Python 内置的 print 等同于调用这个函数,并且参数的要求将与这个函数匹配(而不是普通的 print 要求)。

tl.device_print("pid", pid)
print("pid", pid)

在 CUDA 中,printfs 的输出通过一个大小有限的 buffer 进行流传输在一个主机上,我们测量默认为 6912 KiB,但这可能在不同的 GPU 和 CUDA 版本上有所不同)。如果你注意到一些 printf 被丢弃了,你可以通过调用增加 buffer 大小。

triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes)

在运行使用 printf 的内核后,CUDA 可能会在尝试更改这个值时引发错误。在这里设置的值可能只影响当前设备(因此如果你有多个 GPU,你需要多次调用)。

参数**:**

  • prefix - 在值之前打印的前缀,必须是字符串字面值。
  • args - 要打印的值,可以是任何张量或标量。
  • hex - 以十六进制打印所有值,而不是十进制。