siju-samuel commented on a change in pull request #4788:
[FRONTEND][TFLITE]Gather, StridedSlice op support added
URL: https://github.com/apache/incubator-tvm/pull/4788#discussion_r379762146
##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -244,6 +244,74 @@ def test_forward_slice():
     _test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1])
     _test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1])
+#######################################################################
+# Gather
+# ------
+
+def _test_gather(dshape, indices, axis, dtype):
+    """ One iteration of Gather """
+    data = np.random.uniform(1, 10, size=dshape).astype(dtype)
+    indices = np.asarray(indices).astype('int32')
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        out = array_ops.gather(in_data, indices, axis=axis)
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+    # Test quantized input
+    data = np.random.uniform(1, 10, size=dshape).astype(np.uint8)
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype,
+                                        name="in_data")
+        out = array_ops.gather(in_data, indices, axis=axis)
+        compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out],
+                                quantized=True)
+
+def test_forward_gather():
+    """ GATHER """
+    _test_gather((4,), [1], 0, 'float32')
+    _test_gather((1, 4), [0], 0, 'int32')
+    _test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32')
+    _test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'int32')
+    _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
+    _test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
+    _test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32')
+    _test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32')
+    _test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')
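For reference, with the default `batch_dims=0`, `tf.gather(params, indices, axis=axis)` selects slices along `axis`, so the result has shape `params.shape[:axis] + indices.shape + params.shape[axis + 1:]`. The sketch below uses `np.take` as a NumPy stand-in for in-bounds indices (an assumption for illustration, not part of the PR) and checks the expected output shapes of a few of the test cases above. The helper `expected_gather_shape` is hypothetical.

```python
import numpy as np


def expected_gather_shape(dshape, indices, axis):
    """Hypothetical helper: shape rule for gather with batch_dims=0,
    i.e. data.shape[:axis] + indices.shape + data.shape[axis + 1:]."""
    idx = np.asarray(indices, dtype=np.int32)
    return tuple(dshape[:axis]) + idx.shape + tuple(dshape[axis + 1:])


# A few (dshape, indices, axis) cases taken from test_forward_gather.
cases = [
    ((4,), [1], 0),
    ((2, 2), [[[1, 0], [0, 1]]], 1),
    ((3, 3, 3), [[[1, 0]]], 2),
    ((4, 3, 5, 6), [[2, 1, 0, 0]], 0),
]

for dshape, indices, axis in cases:
    data = np.arange(np.prod(dshape)).reshape(dshape)
    # np.take with an explicit axis follows the same shape rule as tf.gather
    # when every index is in bounds.
    out = np.take(data, np.asarray(indices, dtype=np.int32), axis=axis)
    assert out.shape == expected_gather_shape(dshape, indices, axis)
    print(dshape, "axis", axis, "->", out.shape)
```

For example, gathering `[[[1, 0], [0, 1]]]` (shape `(1, 2, 2)`) along axis 1 of a `(2, 2)` tensor yields shape `(2, 1, 2, 2)`.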
Review comment:
TFLite has separate implementations for out-of-bounds (OOB) indices: the
TFLite CPU kernel returns an error, while the GPU kernel returns 0 for OOB
indices. In my testing, however, the OOB results were not predictable.
TVM doesn't support returning zero for OOB indices in `take`.
So, for now, the parser checks whether the indices are out of bounds and
throws an exception.
Test cases have been added for this scenario.
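A minimal sketch of the parse-time check described above, assuming the indices are available as a constant NumPy array when the operator is converted. The helper name `check_gather_indices` and the choice of `ValueError` are hypothetical, for illustration only, and are not the actual TVM frontend code. It treats `[0, dim)` along `axis` as the valid range, so negative indices are also rejected; that is an assumption.

```python
import numpy as np


def check_gather_indices(data_shape, indices, axis):
    """Hypothetical parse-time guard: raise if any gather index is out of bounds.

    Assumes `indices` is a compile-time constant and treats [0, dim) as the
    valid range along `axis` (negative indices are rejected as OOB here).
    """
    ndim = len(data_shape)
    # Normalize a negative axis the usual way.
    if axis < 0:
        axis += ndim
    if not 0 <= axis < ndim:
        raise ValueError("Gather axis %d is out of range for rank %d" % (axis, ndim))

    idx = np.asarray(indices)
    dim = data_shape[axis]
    # Boolean mask of every index that falls outside [0, dim).
    oob = (idx < 0) | (idx >= dim)
    if oob.any():
        bad = idx[oob].tolist()
        raise ValueError(
            "Out-of-bounds gather indices %s for axis %d of size %d" % (bad, axis, dim)
        )


# Usage: index 3 is valid along axis 0 (size 4) but not along axis 1 (size 3).
check_gather_indices((4, 3), [[0, 3]], 0)
try:
    check_gather_indices((4, 3), [[0, 3]], 1)
except ValueError as err:
    print(err)  # prints: Out-of-bounds gather indices [3] for axis 1 of size 3
```

Rejecting OOB indices at conversion time, rather than letting `take` clip or wrap them, avoids silently producing results that differ from TFLite's CPU kernel, which reports an error in this case.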
----------------------------------------------------------------
This is an automated message from the Apache Git Service.