This is an automated email from the ASF dual-hosted git repository.
lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 58434d16f0 [TFLite][Frontend] Generate name when tensor name is missing (#14819)
58434d16f0 is described below
commit 58434d16f0173ec1ab98edabe7a270e8d6856fac
Author: Ashutosh Parkhi <[email protected]>
AuthorDate: Mon May 15 11:44:25 2023 +0100
[TFLite][Frontend] Generate name when tensor name is missing (#14819)
After the upgrade to TFLite 2.6, some networks contain tensors with no name.
The TFLite frontend now generates a name for each such tensor, of the form tvmgen_tensor_<index>.
---
python/tvm/relay/frontend/tflite.py | 9 ++++++--
tests/python/contrib/test_cmsisnn/test_networks.py | 27 ++++++++++++++++++++++
tests/scripts/request_hook/request_hook.py | 1 +
3 files changed, 35 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index f5d9b5bbf2..9e2e244cb1 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -4091,7 +4091,12 @@ def get_tensor_name(subgraph, tensor_idx):
-------
tensor name in UTF-8 encoding
"""
- return subgraph.Tensors(tensor_idx).Name().decode("utf-8")
+ tensor_name = subgraph.Tensors(tensor_idx).Name()
+ if tensor_name is not None:
+ tensor_name = tensor_name.decode("utf-8")
+ else:
+ tensor_name = "tvmgen_tensor_" + str(tensor_idx)
+ return tensor_name
def _decode_type(n):
@@ -4125,7 +4130,7 @@ def _input_type(model):
tensor = subgraph.Tensors(input_)
input_shape = tuple(tensor.ShapeAsNumpy())
tensor_type = tensor.Type()
- input_name = tensor.Name().decode("utf8")
+ input_name = get_tensor_name(subgraph, input_)
shape_dict[input_name] = input_shape
dtype_dict[input_name] = _decode_type(tensor_type)
diff --git a/tests/python/contrib/test_cmsisnn/test_networks.py b/tests/python/contrib/test_cmsisnn/test_networks.py
index 9f64be2461..16afffdcce 100644
--- a/tests/python/contrib/test_cmsisnn/test_networks.py
+++ b/tests/python/contrib/test_cmsisnn/test_networks.py
@@ -120,5 +120,32 @@ def test_cnn_small(test_runner):
)
+@tvm.testing.requires_package("tflite")
+def test_keyword_scramble():
+    """Download keyword_scrambled and test for Relay conversion.
+    In future, this test can be extended for CMSIS-NN"""
+    # download the model
+    base_url = (
+        "https://github.com/tensorflow/tflite-micro/raw/"
+        "de8f61a074460e1fa5227d875c95aa303be01240/"
+        "tensorflow/lite/micro/models"
+    )
+    file_to_download = "keyword_scrambled.tflite"
+    file_saved = "keyword_scrambled.tflite"
+    model_file = download_testdata("{}/{}".format(base_url, file_to_download), file_saved)
+
+    with open(model_file, "rb") as f:
+        tflite_model_buf = f.read()
+
+    input_shape = (1, 96)
+    dtype = "int8"
+    in_min, in_max = get_dtype_range(dtype)
+    rng = np.random.default_rng(12345)
+    input_data = rng.integers(in_min, high=in_max, size=input_shape, dtype=dtype)
+
+    with pytest.raises(tvm.error.OpNotImplemented):
+        _, _ = _convert_to_relay(tflite_model_buf, input_data, "input")
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/scripts/request_hook/request_hook.py b/tests/scripts/request_hook/request_hook.py
index 3c193d84ae..dd92a92bc5 100644
--- a/tests/scripts/request_hook/request_hook.py
+++ b/tests/scripts/request_hook/request_hook.py
@@ -212,6 +212,7 @@ URL_MAP = {
"https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/image_classification_int8_0.npy":
f"{BASE}/tlc-pack/web-data/raw/main/testdata/microTVM/data/image_classification_int8_0.npy",
"https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/vww_sample_person.jpg":
f"{BASE}/tlc-pack/web-data/testdata/microTVM/data/vww_sample_person.jpg",
"https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/vww_sample_not_person.jpg":
f"{BASE}/tlc-pack/web-data/testdata/microTVM/data/vww_sample_not_person.jpg",
+
"https://github.com/tensorflow/tflite-micro/raw/de8f61a074460e1fa5227d875c95aa303be01240/tensorflow/lite/micro/models/keyword_scrambled.tflite":
f"{BASE}/models/tflite/keyword_scrambled_8bit.tflite",
}