masahi commented on a change in pull request #7334:
URL: https://github.com/apache/tvm/pull/7334#discussion_r564164916
##########
File path: python/tvm/topi/cuda/scan.py
##########
@@ -19,30 +19,36 @@
import tvm
from tvm import te
from tvm._ffi import get_global_func
-from ..transform import expand_dims, squeeze
-from ..utils import ceil_div
+from ..transform import expand_dims, squeeze, transpose, reshape
+from ..utils import ceil_div, swap, prod, get_const_int
from ..math import cast
from .. import tag
from .injective import schedule_injective_from_existing
-def exclusive_sum_scan2d_ir(data, output, reduction=None):
+binop_name_to_func = {"sum": tvm.tir.generic.add}
+
+
+def exclusive_scan_ir(data, output, reduction=None, binop="sum"):
"""Low level IR to do exclusive sum scan along rows of 2D input.
Parameters
----------
data : Buffer
- Input data. 2-D Buffer with shape [batch_size, scan_axis_size].
+ Input N-D Buffer. Scan is done over the innermost axis.
output: Buffer
- A buffer to store the output scan, of the same size as data
+ A buffer to store the output scan, of the same shape as data
reduction: Buffer, optional
- 1D Buffer of size [batch_size], to store the sum of each row.
+ (N-1)-D Buffer, to store the sum of each scan axis.
+
+ binop: string, optional
+ A string specifying which binary operator to use. Currently only "sum"
is supported.
Review comment:
Yes see the discussion at
https://github.com/apache/tvm/pull/7334#discussion_r563926496
I'll see if I can improve on this.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org