siju-samuel commented on a change in pull request #6303:
URL: https://github.com/apache/incubator-tvm/pull/6303#discussion_r475647341
##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -3093,5 +3093,55 @@ RELAY_REGISTER_OP("sparse_to_dense")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute", SparseToDenseCompute);
+// relay.matrix_set_diag
+bool MatrixSetDiagRel(const Array<Type>& types, int num_inputs, const Attrs&
attrs,
+ const TypeReporter& reporter) {
+ // `types` contains: [input, diagonal, result]
+ CHECK_EQ(types.size(), 3);
+
+ const auto* input = types[0].as<TensorTypeNode>();
+ CHECK(input);
+
+ const auto* diagonal = types[1].as<TensorTypeNode>();
+ CHECK(diagonal);
+
+ int d_ndims = diagonal->shape.size();
+ for (int i = 0; i < d_ndims - 1; i++) {
+ reporter->AssertEQ(input->shape[i], diagonal->shape[i]);
+ }
+ auto min_dim = if_then_else(input->shape[d_ndims - 1] >=
input->shape[d_ndims],
+ input->shape[d_ndims], input->shape[d_ndims -
1]);
+ reporter->Assert(diagonal->shape[d_ndims - 1] >= min_dim);
+
+ reporter->Assign(types[2], TensorType(input->shape, input->dtype));
+ return true;
+}
+
+Array<te::Tensor> MatrixSetDiagCompute(const Attrs& attrs, const
Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ return Array<te::Tensor>{topi::matrix_set_diag(inputs[0], inputs[1])};
+}
+
+Expr MakeMatrixSetDiag(Expr input, Expr diagonal) {
+ static const Op& op = Op::Get("matrix_set_diag");
+ return Call(op, {input, diagonal}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.matrix_set_diag").set_body_typed(MakeMatrixSetDiag);
+
+RELAY_REGISTER_OP("matrix_set_diag")
+ .describe(
+ R"code(Returns a tensor with the diagonal of input tensor replaced
with the provided diagonal values.
+ **input** Input tensor.
+ **diagonal** Values to be filled in the diagonal.
+ )code" TVM_ADD_FILELINE)
+ .set_num_inputs(2)
+ .add_argument("input", "Tensor", "Input Tensor.")
+ .add_argument("diagonal", "Tensor", "Values to be filled in the diagonal.")
+ .set_support_level(10)
+ .add_type_rel("MatrixSetDiag", MatrixSetDiagRel)
+ .set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
+ .set_attr<TOpPattern>("TOpPattern", kBroadcast);
Review comment:
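Why kBroadcast? I think it should be kInjective: the diagonal input has a different rank from the input tensor, so this is not a broadcast pattern. Each output element is simply copied from either `input` or `diagonal`.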
##########
File path: include/tvm/topi/transform.h
##########
@@ -1511,6 +1511,35 @@ inline Tensor sparse_to_dense(const Tensor&
sparse_indices, const Array<Integer>
name, tag);
}
+/*!
+ * \brief Returns a tensor with the diagonal of input tensor replaced with the
provided diagonal.
+ * \param input input tensor.
+ * \param diagonal values to be filled in the diagonal.
+ * \param name output tensor name.
+ * \param tag output tensor tag.
+ * \return new tensor with given diagonal values.
+ */
+inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal,
Review comment:
Suggestion: it would be good to also support `align` (alignment) and `k` (diagonal
offset), similar to `MatrixSetDiagV3` in TensorFlow. Then we could support the
TensorFlow ops directly as well.
##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -2652,6 +2652,77 @@ def test_forward_reverse_v2():
_test_reverse_v2((5, 6, 4, 2), np.array([2], dtype='int32'), dtype)
+#######################################################################
+# MATRIX_SET_DIAG
+# ---------------
+
+def _test_matrix_set_diag(input_shape, input_type, quantized=False):
+ """ One iteration of MATRIX_SET_DIAG """
+ with tf.Graph().as_default():
+ diagonal_shape = list(input_shape[:-2])
+ diagonal_shape.append(min(input_shape[-2], input_shape[-1]))
+
+ if quantized:
+ # ignoring input_type as quantized requires uint8
+ input = np.random.uniform(0, 256, input_shape).astype('uint8')
+ in_input = tf.placeholder(dtype='float32', shape=input.shape,
name="input")
+ inq_input = tf.quantization.fake_quant_with_min_max_args(
+ in_input,
+ min=-100,
+ max=100,
+ name="q_input")
+
+ diagonal = np.random.uniform(0, 256,
diagonal_shape).astype('uint8')
+ in_diagonal = tf.placeholder(dtype='float32',
shape=diagonal.shape, name="diagonal")
+ inq_diagonal = tf.quantization.fake_quant_with_min_max_args(
+ in_diagonal,
+ min=-100,
+ max=100,
+ name="q_diagonal")
+
+ input_range = {'q_input': (-100, 100), 'q_diagonal': (-100, 100)}
+
+ out = array_ops.matrix_set_diag(inq_input, inq_diagonal)
+ out = tf.quantization.fake_quant_with_min_max_args(
+ out,
+ min=-100,
+ max=100,
+ name="out")
+
+ compare_tflite_with_tvm(
+ [input, diagonal],
+ ["q_input", "q_diagonal"],
+ [inq_input, inq_diagonal],
+ [out],
+ quantized=True,
+ input_range=input_range)
+ else:
+ input = np.random.uniform(0, 100, input_shape).astype(input_type)
+ diagonal = np.random.uniform(0, 100,
diagonal_shape).astype(input_type)
+
+ in_input = tf.placeholder(dtype=input.dtype, shape=input.shape,
name="input")
+ in_diagonal = tf.placeholder(dtype=diagonal.dtype,
shape=diagonal.shape, name="diagonal")
+
+ out = array_ops.matrix_set_diag(in_input, in_diagonal)
+
+ compare_tflite_with_tvm(
+ [input, diagonal],
+ ["input", "diagonal"],
+ [in_input, in_diagonal],
+ [out])
+
+def test_forward_matrix_set_diag():
+ """ MATRIX_SET_DIAG """
Review comment:
Add a package version check (version > '1.14.0').
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]