manupa-arm commented on a change in pull request #9589:
URL: https://github.com/apache/tvm/pull/9589#discussion_r760957009



##########
File path: tests/python/contrib/test_ethosu/test_codegen.py
##########
@@ -1070,5 +1070,78 @@ def representative_dataset():
     infra.verify_source(compiled_models, accel_type)
 
 
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize(
+    "shapes, axis",
+    [
+        ([(2, 3), (4, 3)], 0),
+        ([(3, 2, 1), (3, 1, 1)], 1),
+        ([(10,), (13,), (14,)], 0),
+        ([(1, 5, 2, 1), (1, 5, 7, 1), (1, 5, 3, 1)], 2),
+    ],
+)
+def test_tflite_concat(shapes, axis, accel_type):
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, shapes, axis):
+                op = tf.concat(shapes, axis)
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            [tf.TensorSpec(shape, tf.float32) for shape in shapes], axis
+        )
+
+        def representative_dataset():
+            for _ in range(100):
+                datas = [np.random.rand(*shape) for shape in shapes]
+                yield [data.astype(np.float32) for data in datas]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+
+        return tflite_model
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    relay_module, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={("ifm" + str(i)): shape for i, shape in enumerate(shapes)},
+        dtype_dict={("ifm" + str(i)): "int8" for i, _ in enumerate(shapes)},
+    )
+
+    mod = partition_for_ethosu(relay_module, params)
+
+    # Generate reference data
+    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+    compiled_models = infra.build_source(
+        mod,
+        input_data,
+        output_data,
+        accel_type,
+    )
+
+    # Assumes only two runtime.Modules are created -- i.e. single offload module
+    imported_modules = compiled_models[0].executor_factory.lib.imported_modules

Review comment:
       I think this needs to change once the target hooks land, but you'll notice that anyway after a rebase :)




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to