tqchen commented on a change in pull request #5289: add tensorflow cumsum
URL: https://github.com/apache/incubator-tvm/pull/5289#discussion_r407139910
##########
File path: topi/include/topi/transform.h
##########
@@ -1105,6 +1105,88 @@ inline tvm::te::Tensor matmul(const tvm::te::Tensor& A,
return tvm::te::compute(output_shape, l, name, tag);
}
+/**
+ * Compute the cumulative sum of the tensor `A` along `axis`.
+ *
+ * By default, this operation performs an inclusive cumsum, which means that the first
+ * element of the input is identical to the first element of the output:
+ *
+ * ```python
+ * cumsum([a, b, c]) # [a, a + b, a + b + c]
+ * ```
+ *
+ * By setting the `exclusive` kwarg to `True`, an exclusive cumsum is performed
+ * instead:
+ *
+ * ```python
+ * cumsum([a, b, c], exclusive=True) # [0, a, a + b]
+ * ```
+ *
+ * By setting the `reverse` kwarg to `True`, the cumsum is performed in the
+ * opposite direction:
+ *
+ * ```python
+ * cumsum([a, b, c], reverse=True) # [a + b + c, b + c, c]
+ * ```
+ *
+ * The `reverse` and `exclusive` kwargs can also be combined:
+ *
+ * ```python
+ * cumsum([a, b, c], exclusive=True, reverse=True) # [b + c, c, 0]
+ * ```
+ *
+ * @param A Input tensor
+ * @param axis The axis along which to accumulate; must be in the range `[-rank(A), rank(A))`
+ * @param exclusive If true, perform an exclusive cumsum
+ * @param reverse If true, perform the cumsum in the opposite direction
+ * @param name The name of the operation
+ * @param tag The tag to mark the operation
+ * @return A Tensor whose op member is the cumsum operation
+ */
+inline tvm::te::Tensor cumsum(const tvm::te::Tensor& A,
+ int axis,
+ bool exclusive = false,
+ bool reverse = false,
+ std::string name = "T_cumsum",
+ std::string tag = kCumsum) {
+ int totalSize = static_cast<int>(A->shape.size());
+ if (axis < 0) {
+ axis = totalSize + axis;
+ }
+ auto maxLength = A->shape[axis];
+ auto l = [&](const Array<Var>& input_indices) {
Review comment:
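This way of implementing cumsum costs additional compute: each output element
runs its own reduction over its prefix in parallel, with no reuse of the partial
sums already computed (for an axis of length n, that is n(n+1)/2 additions
instead of n). Can we use scan instead?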
https://tvm.apache.org/docs/tutorials/language/scan.html
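To make the cost difference concrete, here is a minimal standalone sketch that works on a 1-D `std::vector<double>` instead of a TVM tensor. The helpers `cumsum_naive` and `cumsum_scan` are hypothetical illustrations, not TOPI APIs, and each reports how many additions it performed. `cumsum_scan` follows the recurrence `state[t] = state[t-1] + x[t]` that the linked tutorial expresses with `te.scan`, and it honors the `exclusive` and `reverse` semantics described in the docstring above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not a TOPI API): one independent reduction per output,
// mirroring the compute-with-reduction pattern under review. Output i re-sums
// x[0..i] from scratch, so n outputs cost n(n+1)/2 additions in total.
std::vector<double> cumsum_naive(const std::vector<double>& x, std::size_t* adds) {
  const std::size_t n = x.size();
  std::vector<double> out(n, 0.0);
  *adds = 0;
  for (std::size_t i = 0; i < n; ++i) {
    double acc = 0.0;
    for (std::size_t k = 0; k <= i; ++k) {  // no reuse of out[i - 1]
      acc += x[k];
      ++*adds;
    }
    out[i] = acc;
  }
  return out;
}

// Hypothetical helper (not a TOPI API): scan-style cumsum. Each step reuses the
// running sum from the previous step, so n outputs cost n additions.
std::vector<double> cumsum_scan(const std::vector<double>& x, bool exclusive,
                                bool reverse, std::size_t* adds) {
  const std::size_t n = x.size();
  std::vector<double> out(n, 0.0);
  double running = 0.0;
  *adds = 0;
  for (std::size_t step = 0; step < n; ++step) {
    // Walk left-to-right normally, right-to-left when reverse is set.
    const std::size_t i = reverse ? n - 1 - step : step;
    if (exclusive) {
      out[i] = running;  // sum of the elements visited before i
      running += x[i];
    } else {
      running += x[i];
      out[i] = running;  // sum of the elements visited up to and including i
    }
    ++*adds;
  }
  return out;
}
```

With `x = {1, 2, 3}`, `cumsum_scan` yields `{1, 3, 6}`, `{0, 1, 3}`, `{6, 5, 3}` and `{5, 3, 0}` for the four flag combinations, matching the `[a, a + b, a + b + c]`-style examples in the docstring. The trade-off is that a scan is sequential along the cumsum axis, so parallelism has to come from the remaining axes, whereas the per-element reduction parallelizes along the axis at quadratic total cost.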
----------------------------------------------------------------
This is an automated message from the Apache Git Service.