codeislife99 commented on a change in pull request #7126:
URL: https://github.com/apache/tvm/pull/7126#discussion_r547197106
##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -1553,6 +1553,63 @@ RELAY_REGISTER_OP("meshgrid")
.set_attr<FTVMCompute>("FTVMCompute", MeshgridCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
+TVM_REGISTER_NODE_TYPE(SparseFillEmptyRowsAttrs);
+
+bool SparseFillEmptyRowsRel(const Array<Type>& types, int num_inputs, const
Attrs& attrs,
+ const TypeReporter& reporter) {
+ // types: [ sparse_indices, sparse_values, default_values, result]
+ ICHECK_EQ(types.size(), 4);
+ ICHECK_EQ(num_inputs, 3);
+ std::vector<Type> fields;
+ auto sparse_indices = types[0].as<TensorTypeNode>();
+ auto default_value = types[2].as<TensorTypeNode>();
+ const auto* param = attrs.as<SparseFillEmptyRowsAttrs>();
+ CHECK(param != nullptr);
+
+ Array<IndexExpr> sp_ordered_output_shape;
+ sp_ordered_output_shape.push_back(param->dense_shape[0] +
sparse_indices->shape[0]);
+ if (sparse_indices->shape.size() > 1) {
+ sp_ordered_output_shape.push_back(sparse_indices->shape[1]);
+ }
+ fields.push_back(TensorType(sp_ordered_output_shape, sparse_indices->dtype));
+ fields.push_back(TensorType(Array<PrimExpr>{param->dense_shape[0]},
tvm::DataType::Bool()));
+ fields.push_back(TensorType(Array<PrimExpr>{sp_ordered_output_shape[0]},
default_value->dtype));
+ fields.push_back(TensorType(Array<PrimExpr>{1}, tvm::DataType::Int(32)));
+ reporter->Assign(types[3], TupleType(Array<Type>(fields)));
+ return true;
+}
+
+Array<te::Tensor> SparseFillEmptyRowsCompute(const Attrs& attrs, const
Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ CHECK_EQ(inputs.size(), 3);
Review comment:
Done.
##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -1553,6 +1553,63 @@ RELAY_REGISTER_OP("meshgrid")
.set_attr<FTVMCompute>("FTVMCompute", MeshgridCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
+TVM_REGISTER_NODE_TYPE(SparseFillEmptyRowsAttrs);
+
+bool SparseFillEmptyRowsRel(const Array<Type>& types, int num_inputs, const
Attrs& attrs,
+ const TypeReporter& reporter) {
+ // types: [ sparse_indices, sparse_values, default_values, result]
+ ICHECK_EQ(types.size(), 4);
Review comment:
Done
##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -1553,6 +1553,63 @@ RELAY_REGISTER_OP("meshgrid")
.set_attr<FTVMCompute>("FTVMCompute", MeshgridCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
+TVM_REGISTER_NODE_TYPE(SparseFillEmptyRowsAttrs);
+
+bool SparseFillEmptyRowsRel(const Array<Type>& types, int num_inputs, const
Attrs& attrs,
+ const TypeReporter& reporter) {
+ // types: [ sparse_indices, sparse_values, default_values, result]
+ ICHECK_EQ(types.size(), 4);
+ ICHECK_EQ(num_inputs, 3);
+ std::vector<Type> fields;
+ auto sparse_indices = types[0].as<TensorTypeNode>();
+ auto default_value = types[2].as<TensorTypeNode>();
+ const auto* param = attrs.as<SparseFillEmptyRowsAttrs>();
+ CHECK(param != nullptr);
Review comment:
Done
##########
File path: python/tvm/relay/op/transform.py
##########
@@ -1320,3 +1320,84 @@ def adv_index(inputs):
Output tensor.
"""
return _make.adv_index(Tuple(inputs))
+
+
+def sparsefillemptyrows(sparse_indices, sparse_values, dense_shape,
default_value):
+ """
+ Fill first column of the empty rows with default values for a sparse array.
+
+ Parameters
+ ----------
+ sparse_indices : relay.Expr
+ A 2-D tensor[N, n_dim] of integers containing location of sparse
values, where N is the
+ number of sparse values and n_dim is the number of dimensions of the
dense_shape
+
+ sparse_values : relay.Expr
+ A 1-D tensor[N] containing the sparse values for the sparse indices.
+
+ dense_shape : relay.Expr
+ A list of integers. Shape of the dense output tensor.
+
+ default_value : relay.Expr
+ A 0-D tensor containing the default value for the remaining locations.
+ Defaults to 0.
+
+ Returns
+ -------
+ TupleWrapper with the following four outputs
+
+ new_sparse_indices : relay.Expr
+ A 2-D tensor[N + dense_shape[0], n_dim] of integers containing
location of new sparse
+ indices where N is the number of sparse values. It is filled with -1
at to_be_discarded
Review comment:
Changed.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org