u99127 commented on a change in pull request #4788: [FRONTEND][TFLITE]Gather, 
StridedSlice op support added
URL: https://github.com/apache/incubator-tvm/pull/4788#discussion_r374497747
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -244,6 +244,74 @@ def test_forward_slice():
         _test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 
1], size=[-1, -1])
         _test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], 
size=[-1])
 
+#######################################################################
+# Gather
+# ------
+
+def _test_gather(dshape, indices, axis, dtype):
+    """ One iteration of Gather """
+    data = np.random.uniform(1, 10, size=dshape).astype(dtype)
+    indices = np.asarray(indices).astype('int32')
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        out = array_ops.gather(in_data, indices, axis=axis)
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+    #Test quantized input
+    data = np.random.uniform(1, 10, size=dshape).astype(np.uint8)
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, 
name="in_data")
+        out = array_ops.gather(in_data, indices, axis=axis)
+        compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], 
quantized=True)
+
+def test_forward_gather():
+    """ GATHER """
+    _test_gather((4,), [1], 0, 'float32')
+    _test_gather((1, 4), [0], 0, 'int32')
+    _test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32')
+    _test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'int32')
+    _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
+    _test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
+    _test_gather((3, 3, 3),  [[[1, 0]]], 0, 'int32')
+    _test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32')
+    _test_gather((4, 3, 5, 6),  [[2, 1, 0, 0]], 0, 'float32')
 
 Review comment:
   I think a test case is missing that checks out-of-bounds indices behave the way
TFLite expects. See the documentation for `relay.op.transform.take`, in particular
its last optional parameter, `mode`.
   
   See here for more: https://docs.tvm.ai/api/python/relay/op.html
   
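   Further, `axis` is an optional input to both TFLite's gather and `_op.take`. Do
the two semantics match up? I cannot see this clearly in the TFLite documentation.
(For example, `relay.take` flattens the input when `axis` is `None`, whereas
TensorFlow's `gather` defaults to `axis=0`.)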

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to