This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 58434d16f0 [TFLite][Frontend] Generate name when tensor name is missing (#14819)
58434d16f0 is described below

commit 58434d16f0173ec1ab98edabe7a270e8d6856fac
Author: Ashutosh Parkhi <[email protected]>
AuthorDate: Mon May 15 11:44:25 2023 +0100

    [TFLite][Frontend] Generate name when tensor name is missing (#14819)
    
    After the upgrade to TFLite 2.6, some networks contain tensors with missing names.
    With this commit, the TFLite frontend generates names with the prefix tvmgen_
    for such tensors.
---
 python/tvm/relay/frontend/tflite.py                |  9 ++++++--
 tests/python/contrib/test_cmsisnn/test_networks.py | 27 ++++++++++++++++++++++
 tests/scripts/request_hook/request_hook.py         |  1 +
 3 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index f5d9b5bbf2..9e2e244cb1 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -4091,7 +4091,12 @@ def get_tensor_name(subgraph, tensor_idx):
     -------
         tensor name in UTF-8 encoding
     """
-    return subgraph.Tensors(tensor_idx).Name().decode("utf-8")
+    tensor_name = subgraph.Tensors(tensor_idx).Name()
+    if tensor_name is not None:
+        tensor_name = tensor_name.decode("utf-8")
+    else:
+        tensor_name = "tvmgen_tensor_" + str(tensor_idx)
+    return tensor_name
 
 
 def _decode_type(n):
@@ -4125,7 +4130,7 @@ def _input_type(model):
             tensor = subgraph.Tensors(input_)
             input_shape = tuple(tensor.ShapeAsNumpy())
             tensor_type = tensor.Type()
-            input_name = tensor.Name().decode("utf8")
+            input_name = get_tensor_name(subgraph, input_)
             shape_dict[input_name] = input_shape
             dtype_dict[input_name] = _decode_type(tensor_type)
 
diff --git a/tests/python/contrib/test_cmsisnn/test_networks.py b/tests/python/contrib/test_cmsisnn/test_networks.py
index 9f64be2461..16afffdcce 100644
--- a/tests/python/contrib/test_cmsisnn/test_networks.py
+++ b/tests/python/contrib/test_cmsisnn/test_networks.py
@@ -120,5 +120,32 @@ def test_cnn_small(test_runner):
     )
 
 
+@tvm.testing.requires_package("tflite")
+def test_keyword_scramble():
+    """Download keyword_scrambled and test for Relay conversion.
+    In future, this test can be extended for CMSIS-NN"""
+    # download the model
+    base_url = (
+        "https://github.com/tensorflow/tflite-micro/raw/"
+        "de8f61a074460e1fa5227d875c95aa303be01240/"
+        "tensorflow/lite/micro/models"
+    )
+    file_to_download = "keyword_scrambled.tflite"
+    file_saved = "keyword_scrambled.tflite"
+    model_file = download_testdata("{}/{}".format(base_url, file_to_download), file_saved)
+
+    with open(model_file, "rb") as f:
+        tflite_model_buf = f.read()
+
+    input_shape = (1, 96)
+    dtype = "int8"
+    in_min, in_max = get_dtype_range(dtype)
+    rng = np.random.default_rng(12345)
+    input_data = rng.integers(in_min, high=in_max, size=input_shape, dtype=dtype)
+
+    with pytest.raises(tvm.error.OpNotImplemented):
+        _, _ = _convert_to_relay(tflite_model_buf, input_data, "input")
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/scripts/request_hook/request_hook.py b/tests/scripts/request_hook/request_hook.py
index 3c193d84ae..dd92a92bc5 100644
--- a/tests/scripts/request_hook/request_hook.py
+++ b/tests/scripts/request_hook/request_hook.py
@@ -212,6 +212,7 @@ URL_MAP = {
     "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/image_classification_int8_0.npy": f"{BASE}/tlc-pack/web-data/raw/main/testdata/microTVM/data/image_classification_int8_0.npy",
     "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/vww_sample_person.jpg": f"{BASE}/tlc-pack/web-data/testdata/microTVM/data/vww_sample_person.jpg",
     "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/vww_sample_not_person.jpg": f"{BASE}/tlc-pack/web-data/testdata/microTVM/data/vww_sample_not_person.jpg",
+    "https://github.com/tensorflow/tflite-micro/raw/de8f61a074460e1fa5227d875c95aa303be01240/tensorflow/lite/micro/models/keyword_scrambled.tflite": f"{BASE}/models/tflite/keyword_scrambled_8bit.tflite",
 }
 
 

Reply via email to