tqchen commented on a change in pull request #5289: add tensorflow cumsum
URL: https://github.com/apache/incubator-tvm/pull/5289#discussion_r407220082
##########
File path: topi/include/topi/transform.h
##########
@@ -1105,6 +1105,88 @@ inline tvm::te::Tensor matmul(const tvm::te::Tensor& A,
return tvm::te::compute(output_shape, l, name, tag);
}
+/**
+ * Compute the cumulative sum of the tensor `A` along `axis`.
+ *
+ * By default, this operation performs an inclusive cumsum, which means that the first
+ * element of the input is identical to the first element of the output:
+ *
+ * ```python
+ * cumsum([a, b, c]) # [a, a + b, a + b + c]
+ * ```
+ *
+ * By setting the `exclusive` kwarg to `True`, an exclusive cumsum is performed
+ * instead:
+ *
+ * ```python
+ * cumsum([a, b, c], exclusive=True) # [0, a, a + b]
+ * ```
+ *
+ * By setting the `reverse` kwarg to `True`, the cumsum is performed in the
+ * opposite direction:
+ *
+ * ```python
+ * cumsum([a, b, c], reverse=True) # [a + b + c, b + c, c]
+ * ```
+ *
+ * The `reverse` and `exclusive` kwargs can also be combined:
+ *
+ * ```python
+ * cumsum([a, b, c], exclusive=True, reverse=True) # [b + c, c, 0]
+ * ```
+ *
+ * @param A Input tensor
+ * @param axis The axis along which to sum; must be in the range `[-rank(A), rank(A))`
+ * @param exclusive Perform exclusive cumsum
+ * @param reverse Perform the cumsum in the opposite direction
+ * @param name The name of the operation
+ * @param tag The tag to mark the operation
+ * @return A Tensor whose op member is the cumsum operation
+ */
+inline tvm::te::Tensor cumsum(const tvm::te::Tensor& A,
+ int axis,
+ bool exclusive = false,
+ bool reverse = false,
+ std::string name = "T_cumsum",
+ std::string tag = kCumsum) {
+ int totalSize = static_cast<int>(A->shape.size());
+ if (axis < 0) {
+ axis = totalSize + axis;
+ }
+ auto maxLength = A->shape[axis];
+ auto l = [&](const Array<Var>& input_indices) {
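The inclusive/exclusive and forward/reverse semantics described in the doc comment can be pinned down with a plain C++ reference over a 1-D `std::vector<int>`. This is a minimal sketch for checking expected values, not the TE/TOPI implementation under review. `reference_cumsum` and `normalize_axis` are hypothetical names, and the range check in `normalize_axis` is an assumption based on the `[-rank(A), rank(A))` range stated for `@param axis` (the diff itself only shifts negative axes).

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper mirroring the axis handling in the diff:
// a negative axis counts from the end, so -1 means the last axis.
// The range check is an assumption taken from the documented range.
int normalize_axis(int axis, int rank) {
  if (axis < -rank || axis >= rank) {
    throw std::out_of_range("axis must be in [-rank, rank)");
  }
  return axis < 0 ? axis + rank : axis;
}

// Hypothetical 1-D reference for the documented cumsum semantics.
// exclusive: out[i] does not include x[i], so the first visited slot is 0.
// reverse:   accumulate from the last element toward the first.
std::vector<int> reference_cumsum(const std::vector<int>& x,
                                  bool exclusive = false,
                                  bool reverse = false) {
  const std::size_t n = x.size();
  std::vector<int> out(n, 0);
  int running = 0;
  for (std::size_t step = 0; step < n; ++step) {
    // Visit indices 0..n-1, or n-1..0 when reverse is set.
    const std::size_t i = reverse ? n - 1 - step : step;
    if (exclusive) {
      out[i] = running;  // sum of the elements visited before i
      running += x[i];
    } else {
      running += x[i];
      out[i] = running;  // sum including x[i]
    }
  }
  return out;
}
```

For example, with `a = 1, b = 2, c = 3`, `reference_cumsum({1, 2, 3}, true, true)` returns `{5, 3, 0}`, which is `[b + c, c, 0]` from the combined `exclusive`/`reverse` case.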
Review comment:
Hmm, as a temporary workaround, we could transpose (moving the target axis to
the highest position), run the scan, and transpose back. In terms of
algorithmic complexity, this approach would be lower. In the long run, we
should add support for scanning over a non-top axis.
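The transpose, scan, transpose-back workaround can be illustrated on a small row-major matrix in plain C++. This is an analogy only, not TVM's `te::scan` or `topi::transpose` API: `Matrix`, `transpose2d`, `scan_axis0`, `cumsum_axis1_via_transpose`, and `naive_cumsum_axis1` are hypothetical. The naive variant assumes the per-element lambda under review re-sums its whole prefix, which is quadratic in the axis length; the scan reuses the previous prefix and is linear.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major 2-D matrix; a stand-in for a TE tensor in this sketch.
struct Matrix {
  std::size_t rows, cols;
  std::vector<int> data;
  int& at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
  int at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Swap axes 0 and 1 (hypothetical helper, analogous to a transpose op).
Matrix transpose2d(const Matrix& m) {
  Matrix t{m.cols, m.rows, std::vector<int>(m.data.size())};
  for (std::size_t r = 0; r < m.rows; ++r)
    for (std::size_t c = 0; c < m.cols; ++c)
      t.at(c, r) = m.at(r, c);
  return t;
}

// Inclusive scan along axis 0 (the "top" axis): row r = row r-1 + input row r.
// Each step reuses the previous prefix, so the work is linear in the axis length.
Matrix scan_axis0(const Matrix& m) {
  Matrix out = m;
  for (std::size_t r = 1; r < m.rows; ++r)
    for (std::size_t c = 0; c < m.cols; ++c)
      out.at(r, c) = out.at(r - 1, c) + m.at(r, c);
  return out;
}

// The workaround: move the target axis (1) to the top, scan there, move it back.
Matrix cumsum_axis1_via_transpose(const Matrix& m) {
  return transpose2d(scan_axis0(transpose2d(m)));
}

// Assumed stand-in for a per-element reduction: each output element re-sums
// its prefix, which is quadratic in the axis length.
Matrix naive_cumsum_axis1(const Matrix& m) {
  Matrix out{m.rows, m.cols, std::vector<int>(m.data.size(), 0)};
  for (std::size_t r = 0; r < m.rows; ++r)
    for (std::size_t c = 0; c < m.cols; ++c)
      for (std::size_t k = 0; k <= c; ++k)
        out.at(r, c) += m.at(r, k);
  return out;
}
```

For a 2x3 input `{{1, 2, 3}, {4, 5, 6}}`, both variants produce `{{1, 3, 6}, {4, 9, 15}}`. The two transposes are single linear passes over the data, so the overall cost stays linear in the axis length.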