masahi commented on a change in pull request #7334:
URL: https://github.com/apache/tvm/pull/7334#discussion_r564166044



##########
File path: python/tvm/topi/cuda/scan.py
##########
@@ -353,28 +364,83 @@ def exclusive_scan(data, axis=-1, return_reduction=False, 
output_dtype=None):
             output = te.extern(
                 [data.shape],
                 [data],
-                lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]),
+                lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], 
binop=binop),
                 dtype=[output_dtype],
                 in_buffers=[data_buf],
                 out_buffers=[output_buf],
                 name="exclusive_scan",
                 tag="exclusive_scan_gpu",
             )
             reduction = None
-    else:
-        assert False, "Unsupported dimension {}".format(ndim)
 
-    if ndim == 1:
-        output = squeeze(output, 0)
+        if ndim == 1:
+            output = squeeze(output, 0)
+            if return_reduction:
+                reduction = squeeze(reduction, 0)
+
         if return_reduction:
-            reduction = squeeze(reduction, 0)
+            return output, reduction
+
+        return output
+
+    if output_dtype is None or output_dtype == "":
+        output_dtype = data.dtype
+
+    ndim = len(data.shape)
+    if axis < 0:
+        axis += ndim
+
+    # If scan axis is not the innermost one, swap the scan and the innermost 
axes
+    # Scan is always done on the innermost axis, for performance reason.
+    if axis != ndim - 1:
+        axes = swap(list(range(ndim)), axis)
+        data = transpose(data, axes)
+
+    if return_reduction:
+        output, reduction = do_scan(data, output_dtype)
+    else:
+        output = do_scan(data, output_dtype)
+
+    if axis != ndim - 1:
+        axes = swap(list(range(ndim)), axis)
+        output = transpose(output, axes)
 
     if return_reduction:
         return output, reduction
 
     return output
 
 
+def inclusive_scan(data, axis=-1, output_dtype=None, binop="sum"):
+    """Do inclusive scan on 1D or multidimensional input.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input data of any shape.
+
+    axis: int, optional
+        The axis to do scan on. By default, scan is done on the innermost axis.
+
+    output_dtype: string, optional
+        The dtype of the output scan tensor. If not provided, the dtype of the 
input is used.
+
+    biop: string, optional
+        A string specifying which binary operator to use. Currently only "sum" 
is supported.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        A N-D tensor of the same rank N as the input data.
+    """
+    ex_scan = exclusive_scan(data, axis, output_dtype=output_dtype, 
binop=binop)
+
+    if output_dtype is not None and data.dtype != output_dtype and 
output_dtype != "":
+        data = cast(data, output_dtype)
+
+    return binop_name_to_func[binop](data, ex_scan)

Review comment:
       `prod` makes sense, to support `cumprod`. But currently only "sum" is 
supported. If I manage to convert the `binop` argument to a function, we won't 
need to worry about this issue.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to