This is an automated email from the ASF dual-hosted git repository.
kazum pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 8a63b7f [TFLITE]GATHER_ND (#5508)
8a63b7f is described below
commit 8a63b7f37c64d04236e915a30f8df7e0e61ffefa
Author: Dhruva Ray <[email protected]>
AuthorDate: Mon May 18 08:18:17 2020 +0530
[TFLITE]GATHER_ND (#5508)
Signed-off-by: Dhruva Ray <[email protected]>
---
python/tvm/relay/frontend/tflite.py | 26 +++++++++++++++++++++++
tests/python/frontend/tflite/test_forward.py | 31 ++++++++++++++++++++++++++++
2 files changed, 57 insertions(+)
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 5a645c6..cb10ce5 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -86,6 +86,7 @@ class OperatorConverter(object):
'FLOOR': self.convert_floor,
'FULLY_CONNECTED': self.convert_fully_connected,
'GATHER': self.convert_gather,
+ 'GATHER_ND' : self.convert_gather_nd,
'GREATER_EQUAL': self.convert_greater_equal,
'GREATER': self.convert_greater,
'HARD_SWISH': self.convert_hard_swish,
@@ -1113,6 +1114,31 @@ class OperatorConverter(object):
out = _op.take(data, indices, axis=axis, mode="fast")
return out
+    def convert_gather_nd(self, op):
+        """Convert a TFLite GATHER_ND operator to Relay ``gather_nd``.
+
+        TFLite stores each index tuple along the LAST axis of ``indices``,
+        while Relay's ``gather_nd`` expects it along the FIRST axis, so the
+        indices are transposed (last axis moved to the front) first.
+        """
+        try:
+            from tflite.TensorType import TensorType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+        for t in input_tensors:
+            assert not t.qnn_params, "Quantized input is not expected."
+        data = self.get_tensor_expr(input_tensors[0])
+        indices = self.get_tensor_expr(input_tensors[1])
+
+        indices_type = input_tensors[1].tensor.Type()
+        assert indices_type in (TensorType.INT32, TensorType.INT64), (
+            "GATHER_ND indices must be int32 or int64")
+        indices_dims = len(_infer_shape(indices))
+        indices_t = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
+        return _op.gather_nd(data, indices_t)
+
def convert_strided_slice(self, op):
"""Method to Convert TFLite STRIDED_SLICE operator.
NOTE: Eventhough tensorflow supports begin_mask, end_mask,
ellipsis_mask, new_axis_mask
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 9963479..2319904 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -355,6 +355,36 @@ def test_forward_gather():
_test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True)
#######################################################################
+# Gather_ND
+# ---------
+
+def _test_gather_nd(data, indices):
+    """ One iteration of GATHER_ND, with data and indices fed as graph inputs """
+    with tf.Graph().as_default():
+        in_data = tf.placeholder(shape=data.shape, dtype=data.dtype,
+                                 name="data")
+        indices_data = tf.placeholder(shape=indices.shape, dtype=indices.dtype,
+                                      name="indices")
+        out = tf.gather_nd(in_data, indices_data)
+        compare_tflite_with_tvm([data, indices], ['data:0', 'indices:0'],
+                                [in_data, indices_data], [out])
+
+def test_forward_gather_nd():
+    """ GATHER_ND """
+    _test_gather_nd(
+        np.array([[[1.2, 2.0], [3.1, 4.1]],
+                  [[5.1, 6.1], [7.1, 8.1]]]).astype('float32'),
+        np.asarray([[0, 1], [1, 0]]).astype('int32'))
+    _test_gather_nd(np.reshape(np.arange(30), [5, 6]).astype('int32'),
+                    np.asarray([[1, 2]]).astype('int32'))
+    _test_gather_nd(np.reshape(np.arange(12), [2, 3, 2]).astype('int32'),
+                    np.asarray([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]).astype('int32'))
+    # Rank-1 indices (a single index tuple): edge case of the converter's
+    # last-axis-to-front transpose of the indices tensor.
+    _test_gather_nd(np.reshape(np.arange(4), [1, 4]).astype('float32'),
+                    np.asarray([0]).astype('int32'))
+
+#######################################################################
# StridedSlice
# ------------
@@ -2217,6 +2247,7 @@ if __name__ == '__main__':
test_forward_slice()
test_forward_topk()
test_forward_gather()
+ test_forward_gather_nd()
test_forward_stridedslice()
test_forward_depthtospace()
test_forward_spacetodepth()