u99127 commented on a change in pull request #4788: [FRONTEND][TFLITE]Gather, StridedSlice op support added
URL: https://github.com/apache/incubator-tvm/pull/4788#discussion_r374497747
##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -244,6 +244,74 @@ def test_forward_slice():
     _test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1])
     _test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1])
+#######################################################################
+# Gather
+# ------
+
+def _test_gather(dshape, indices, axis, dtype):
+    """ One iteration of Gather """
+    data = np.random.uniform(1, 10, size=dshape).astype(dtype)
+    indices = np.asarray(indices).astype('int32')
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        out = array_ops.gather(in_data, indices, axis=axis)
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+    # Test quantized input
+    data = np.random.uniform(1, 10, size=dshape).astype(np.uint8)
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data")
+        out = array_ops.gather(in_data, indices, axis=axis)
+        compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=True)
+
+def test_forward_gather():
+    """ GATHER """
+    _test_gather((4,), [1], 0, 'float32')
+    _test_gather((1, 4), [0], 0, 'int32')
+    _test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32')
+    _test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'int32')
+    _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
+    _test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
+    _test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32')
+    _test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32')
+    _test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')
Review comment:
I think a test case is missing that checks that out-of-bounds indices behave as TFLite expects them to. See the documentation for relay.op.transform.take, in particular its last optional parameter, "mode".
See here for more: https://docs.tvm.ai/api/python/relay/op.html
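To make the "mode" point concrete, here is a small sketch of the two out-of-bounds behaviours, "clip" and "wrap", that the take documentation describes. It uses `numpy.take`, whose `mode` values share those names; treating numpy as a stand-in for the Relay operator is an assumption, and the sketch says nothing about what TFLite itself does with such indices, which is what the missing test would need to establish.

```python
import numpy as np

# numpy.take's "clip" and "wrap" modes share their names with the `mode`
# values documented for Relay's take; using numpy as a stand-in is an assumption.
data = np.array([10, 20, 30, 40], dtype=np.float32)
indices = np.array([1, 5, -1], dtype=np.int32)  # 5 is past the end, -1 is negative

# "clip": too-large indices map to the last element, negative ones to the first.
clipped = np.take(data, indices, mode="clip")
# "wrap": indices are reduced modulo the axis length (4 here).
wrapped = np.take(data, indices, mode="wrap")

print(clipped.tolist())  # [20.0, 40.0, 10.0]
print(wrapped.tolist())  # [20.0, 20.0, 40.0]
```

The two modes already disagree on indices 5 and -1, so a test case using indices like these would reveal which behaviour the converted model has; whether that matches TFLite still has to be checked against the TFLite kernel.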
Further, axis is an optional input both to gather in TFLite and to _op.take. Do the two semantics match up? I cannot see this clearly in the TFLite documentation.
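A sketch of why the default matters, again with `numpy.take` standing in for `_op.take` (an assumption). The Relay documentation says that take with no axis operates on the flattened input, whereas `tf.gather` with no axis selects along the first dimension, so the results differ in both values and shape.

```python
import numpy as np

# numpy.take stands in for Relay's take here (assumption: same axis semantics).
data = np.arange(6, dtype=np.int32).reshape((2, 3))  # [[0 1 2], [3 4 5]]
indices = np.array([1, 0], dtype=np.int32)

# What tf.gather computes when no axis is given: select along axis 0.
gather_default = np.take(data, indices, axis=0)
# What take computes when axis is left as None: index into data.ravel().
take_default = np.take(data, indices, axis=None)

print(gather_default.tolist())  # [[3, 4, 5], [0, 1, 2]]
print(take_default.tolist())    # [1, 0]
```

If this holds for the Relay operator as well, the converter would presumably need to pass the TFLite axis explicitly (0 when the option is unset) rather than leave it as None.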
----------------------------------------------------------------
This is an automated message from the Apache Git Service.