u99127 commented on a change in pull request #5508:
URL: https://github.com/apache/incubator-tvm/pull/5508#discussion_r420955608



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -343,6 +343,36 @@ def test_forward_gather():
         _test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True)
         _test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True)
 
+#######################################################################
+# Gather_ND
+# ---------
+
+def _test_gather_nd(data, indices):
+    """ One iteration of GATHER_ND """
+    with tf.Graph().as_default():
+        in_data = tf.placeholder(shape=data.shape, dtype=data.dtype, name="data")
+        indices_data = tf.placeholder(shape=indices.shape, dtype=indices.dtype,
+                                        name="indices")
+        out = tf.gather_nd(in_data, indices_data)
+
+        compare_tflite_with_tvm([data, indices], ['data:0', 'indices:0'],
+                                  [in_data, indices_data], [out])
+
+def test_forward_gather_nd():
+    """ GATHER_ND """
+    _test_gather_nd(
+        np.array([[[1.2, 2.0], [3.1, 4.1]], [[5.1, 6.1], [7.1, 8.1]]]).astype('float32'),
+        np.asarray([[0, 1], [1, 0]]).astype('int32')

Review comment:
       Can we put in a test for int8 input tensor types as well?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to