ekalda commented on a change in pull request #9589:
URL: https://github.com/apache/tvm/pull/9589#discussion_r760994549
##########
File path: tests/python/contrib/test_ethosu/test_codegen.py
##########
@@ -1070,5 +1070,78 @@ def representative_dataset():
infra.verify_source(compiled_models, accel_type)
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize(
+ "shapes, axis",
+ [
+ ([(2, 3), (4, 3)], 0),
+ ([(3, 2, 1), (3, 1, 1)], 1),
+ ([(10,), (13,), (14,)], 0),
+ ([(1, 5, 2, 1), (1, 5, 7, 1), (1, 5, 3, 1)], 2),
+ ],
+)
+def test_tflite_concat(shapes, axis, accel_type):
+ def create_tflite_graph():
+ class Model(tf.Module):
+ @tf.function
+ def tf_function(self, shapes, axis):
+ op = tf.concat(shapes, axis)
+ return op
+
+ model = Model()
+ concrete_func = model.tf_function.get_concrete_function(
+ [tf.TensorSpec(shape, tf.float32) for shape in shapes], axis
+ )
+
+ def representative_dataset():
+ for _ in range(100):
+ datas = [np.random.rand(*shape) for shape in shapes]
+ yield [data.astype(np.float32) for data in datas]
+
+ converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+ tflite_model = converter.convert()
+
+ return tflite_model
+
+ tflite_graph = create_tflite_graph()
+ tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+ relay_module, params = relay.frontend.from_tflite(
+ tflite_model,
+ shape_dict={("ifm" + str(i)): shape for i, shape in enumerate(shapes)},
+ dtype_dict={("ifm" + str(i)): "int8" for i, _ in enumerate(shapes)},
+ )
+
+ mod = partition_for_ethosu(relay_module, params)
+
+ # Generate reference data
+ input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+ compiled_models = infra.build_source(
+ mod,
+ input_data,
+ output_data,
+ accel_type,
+ )
+
+ # Assumes only two runtime.Modules are created -- i.e. single offload module
+ imported_modules = compiled_models[0].executor_factory.lib.imported_modules
Review comment:
I have resolved all the conflicts (for now) and all the tests pass
(except the mean tests which we will fix)!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]