codeislife99 commented on a change in pull request #7126:
URL: https://github.com/apache/tvm/pull/7126#discussion_r547596393



##########
File path: include/tvm/topi/transform.h
##########
@@ -1386,6 +1386,96 @@ inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& in
   return result;
 }
 
+/*!
+ * \brief Fill empty rows of a sparse tensor with a default value
+ *
+ * \param sparse_indices Indices where values of the dense tensor exist
+ * \param sparse_values Values at the above indices respectively
+ * \param default_value Default value to be used at empty rows
+ * \param dense_shape Dense shape of the sparse tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return An array of Tensors produced by the SparseFillEmptyRows operation
+ */
+inline Array<Tensor> SparseFillEmptyRows(const Tensor& sparse_indices, const Tensor& sparse_values,
+                                         const Tensor& default_value,
+                                         const Array<Integer>& dense_shape,
+                                         const std::string name = "T_sparse_fill_empty_rows",
+                                         std::string tag = kInjective) {
+  Array<Tensor> result;
+  Array<PrimExpr> sp_ordered_output_shape;
+  sp_ordered_output_shape.push_back(dense_shape[0] + sparse_indices->shape[0]);
+  if (sparse_indices->shape.size() > 1) {
+    sp_ordered_output_shape.push_back(sparse_indices->shape[1]);
+  }
+  auto empty_row_indicator =
+      compute(Array<PrimExpr>{dense_shape[0]}, [&](const Array<Var>& indices) {
+        PrimExpr ret = PrimExpr(Bool(1));
+        for (int i = 0; i < GetConstInt(sparse_indices->shape[0]); ++i) {

Review comment:
   Yes, there are four sparse ops that we are explicitly targeting for a customer's TF model:
   1. [sparse_reshape](https://www.tensorflow.org/api_docs/python/tf/sparse/reshape)
   2. [sparse_segment_sum](https://www.tensorflow.org/api_docs/python/tf/sparse/segment_sum?hl=bn)
   3. [sparse_fill_empty_rows](https://www.tensorflow.org/api_docs/python/tf/sparse/fill_empty_rows)
   4. [sparse_segment_sum_sqrt_n](https://www.tensorflow.org/api_docs/python/tf/sparse/segment_sqrt_n?hl=bn)



