jainris commented on a change in pull request #6303:
URL: https://github.com/apache/incubator-tvm/pull/6303#discussion_r475695170
##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -3093,5 +3093,55 @@ RELAY_REGISTER_OP("sparse_to_dense")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute", SparseToDenseCompute);
+// relay.matrix_set_diag
+bool MatrixSetDiagRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ // `types` contains: [input, diagonal, result]
+ CHECK_EQ(types.size(), 3);
+
+ const auto* input = types[0].as<TensorTypeNode>();
+ CHECK(input);
+
+ const auto* diagonal = types[1].as<TensorTypeNode>();
+ CHECK(diagonal);
+
+ int d_ndims = diagonal->shape.size();
+ for (int i = 0; i < d_ndims - 1; i++) {
+ reporter->AssertEQ(input->shape[i], diagonal->shape[i]);
+ }
+ auto min_dim = if_then_else(input->shape[d_ndims - 1] >= input->shape[d_ndims],
+ input->shape[d_ndims], input->shape[d_ndims - 1]);
+ reporter->Assert(diagonal->shape[d_ndims - 1] >= min_dim);
+
+ reporter->Assign(types[2], TensorType(input->shape, input->dtype));
+ return true;
+}
+
+Array<te::Tensor> MatrixSetDiagCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ return Array<te::Tensor>{topi::matrix_set_diag(inputs[0], inputs[1])};
+}
+
+Expr MakeMatrixSetDiag(Expr input, Expr diagonal) {
+ static const Op& op = Op::Get("matrix_set_diag");
+ return Call(op, {input, diagonal}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.matrix_set_diag").set_body_typed(MakeMatrixSetDiag);
+
+RELAY_REGISTER_OP("matrix_set_diag")
+ .describe(
+ R"code(Returns a tensor with the diagonal of input tensor replaced with the provided diagonal values.
+ **input** Input tensor.
+ **diagonal** Values to be filled in the diagonal.
+ )code" TVM_ADD_FILELINE)
+ .set_num_inputs(2)
+ .add_argument("input", "Tensor", "Input Tensor.")
+ .add_argument("diagonal", "Tensor", "Values to be filled in the diagonal.")
+ .set_support_level(10)
+ .add_type_rel("MatrixSetDiag", MatrixSetDiagRel)
+ .set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
+ .set_attr<TOpPattern>("TOpPattern", kBroadcast);
Review comment:
Thanks for reviewing.
Changed the `TOpPattern` of `matrix_set_diag` from `kBroadcast` to `kInjective`.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org