jainris commented on a change in pull request #6303:
URL: https://github.com/apache/incubator-tvm/pull/6303#discussion_r477157597



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -2652,6 +2652,77 @@ def test_forward_reverse_v2():
         _test_reverse_v2((5, 6, 4, 2), np.array([2], dtype='int32'), dtype)
 
 
+#######################################################################
+# MATRIX_SET_DIAG
+# ---------------
+
+def _test_matrix_set_diag(input_shape, input_type, quantized=False):
+    """ One iteration of MATRIX_SET_DIAG """
+    with tf.Graph().as_default():
+        diagonal_shape = list(input_shape[:-2])
+        diagonal_shape.append(min(input_shape[-2], input_shape[-1]))
+
+        if quantized:
+            # ignoring input_type as quantized requires uint8
+            input = np.random.uniform(0, 256, input_shape).astype('uint8')
+            in_input = tf.placeholder(dtype='float32', shape=input.shape, name="input")
+            inq_input = tf.quantization.fake_quant_with_min_max_args(
+                in_input,
+                min=-100,
+                max=100,
+                name="q_input")
+
+            diagonal = np.random.uniform(0, 256, diagonal_shape).astype('uint8')
+            in_diagonal = tf.placeholder(dtype='float32', shape=diagonal.shape, name="diagonal")
+            inq_diagonal = tf.quantization.fake_quant_with_min_max_args(
+                in_diagonal,
+                min=-100,
+                max=100,
+                name="q_diagonal")
+
+            input_range = {'q_input': (-100, 100), 'q_diagonal': (-100, 100)}
+
+            out = array_ops.matrix_set_diag(inq_input, inq_diagonal)
+            out = tf.quantization.fake_quant_with_min_max_args(
+                out,
+                min=-100,
+                max=100,
+                name="out")
+
+            compare_tflite_with_tvm(
+                [input, diagonal],
+                ["q_input", "q_diagonal"],
+                [inq_input, inq_diagonal],
+                [out],
+                quantized=True,
+                input_range=input_range)
+        else:
+            input = np.random.uniform(0, 100, input_shape).astype(input_type)
+            diagonal = np.random.uniform(0, 100, diagonal_shape).astype(input_type)
+
+            in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input")
+            in_diagonal = tf.placeholder(dtype=diagonal.dtype, shape=diagonal.shape, name="diagonal")
+
+            out = array_ops.matrix_set_diag(in_input, in_diagonal)
+
+            compare_tflite_with_tvm(
+                    [input, diagonal],
+                    ["input", "diagonal"],
+                    [in_input, in_diagonal],
+                    [out])
+
+def test_forward_matrix_set_diag():
+    """ MATRIX_SET_DIAG """

Review comment:
   The TensorFlow API docs suggest that `matrix_set_diag` is already available
   in version 1.0. Is there some other reason to add this check?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to