manupa-arm commented on a change in pull request #9589:
URL: https://github.com/apache/tvm/pull/9589#discussion_r760957009
##########
File path: tests/python/contrib/test_ethosu/test_codegen.py
##########
@@ -1070,5 +1070,78 @@ def representative_dataset():
infra.verify_source(compiled_models, accel_type)
[email protected]("accel_type", ACCEL_TYPES)
[email protected](
+ "shapes, axis",
+ [
+ ([(2, 3), (4, 3)], 0),
+ ([(3, 2, 1), (3, 1, 1)], 1),
+ ([(10,), (13,), (14,)], 0),
+ ([(1, 5, 2, 1), (1, 5, 7, 1), (1, 5, 3, 1)], 2),
+ ],
+)
+def test_tflite_concat(shapes, axis, accel_type):
+ def create_tflite_graph():
+ class Model(tf.Module):
+ @tf.function
+ def tf_function(self, shapes, axis):
+ op = tf.concat(shapes, axis)
+ return op
+
+ model = Model()
+ concrete_func = model.tf_function.get_concrete_function(
+ [tf.TensorSpec(shape, tf.float32) for shape in shapes], axis
+ )
+
+ def representative_dataset():
+ for _ in range(100):
+ datas = [np.random.rand(*shape) for shape in shapes]
+ yield [data.astype(np.float32) for data in datas]
+
+ converter =
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset
+ converter.target_spec.supported_ops =
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+ tflite_model = converter.convert()
+
+ return tflite_model
+
+ tflite_graph = create_tflite_graph()
+ tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+ relay_module, params = relay.frontend.from_tflite(
+ tflite_model,
+ shape_dict={("ifm" + str(i)): shape for i, shape in enumerate(shapes)},
+ dtype_dict={("ifm" + str(i)): "int8" for i, _ in enumerate(shapes)},
+ )
+
+ mod = partition_for_ethosu(relay_module, params)
+
+ # Generate reference data
+ input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+ compiled_models = infra.build_source(
+ mod,
+ input_data,
+ output_data,
+ accel_type,
+ )
+
+ # Assumes only two runtime.Modules are created -- i.e. single offload
module
+ imported_modules = compiled_models[0].executor_factory.lib.imported_modules
Review comment:
I think this needs changing after target hooks, but you'll find that anyway
after a rebase :)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]