lhutton1 commented on a change in pull request #9442:
URL: https://github.com/apache/tvm/pull/9442#discussion_r742958078
##########
File path: tests/python/contrib/test_ethosu/test_legalize.py
##########
@@ -558,5 +558,193 @@ def verify(ext_func):
verify(mod["tvmgen_default_ethosu_main_0"])
+@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"])
+@pytest.mark.parametrize(
+ "ifm_shape, ifm2_shape, reversed_operands",
+ [
+ ([1, 2, 3, 4], [1, 2, 3, 4], False),
+ ([1, 2, 3, 4], [1, 1, 3, 1], False),
+ ([1, 1, 3, 1], [1, 2, 3, 4], True),
+ ],
+)
+@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
+def test_tflite_binary_elemwise_legalize(
+ operator_type,
+ ifm_shape,
+ ifm2_shape,
+ reversed_operands,
+ activation_function,
+):
+ dtype = "int8"
+
+ def create_tflite_graph():
+ class Model(tf.Module):
+ @tf.function
+ def tf_function(self, x, y):
+ if operator_type == "ADD":
+ op = tf.math.add(x, y)
+ elif operator_type == "SUB":
+ op = tf.math.subtract(x, y)
+ elif operator_type == "MUL":
+ op = tf.math.multiply(x, y)
+ elif operator_type == "MIN":
+ op = tf.math.minimum(x, y)
+ elif operator_type == "MAX":
+ op = tf.math.maximum(x, y)
+ if activation_function == "RELU":
+ op = tf.nn.relu(op)
+ return op
+
+ model = Model()
+ concrete_func = model.tf_function.get_concrete_function(
+ tf.TensorSpec(ifm_shape, dtype=tf.float32), tf.TensorSpec(ifm2_shape, dtype=tf.float32)
+ )
+
+ # Convert the model
+ def representative_dataset():
+ for _ in range(100):
+ data = np.random.rand(*tuple(ifm_shape))
+ data2 = np.random.rand(*tuple(ifm2_shape)) * 2
+ yield [data.astype(np.float32), data2.astype(np.float32)]
+
+ converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+ tflite_model = converter.convert()
+ return tflite_model
+
+ def verify(ext_func):
+ out_shape = ifm2_shape if reversed_operands else ifm_shape
+ shapes = [ifm_shape, ifm2_shape]
+ ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1)
+ op = ext_func.body
+ assert list(op.args[0].checked_type.shape) == shapes[ifm_index]
+ assert list(op.args[1].checked_type.shape) == shapes[ifm2_index]
+ assert op.args[0].checked_type.dtype == dtype
+ assert list(op.checked_type.shape) == out_shape
+ assert op.checked_type.dtype == dtype
+ assert op.attrs.operator_type == operator_type
+ assert op.attrs.reversed_operands == reversed_operands
+ if activation_function == "RELU":
+ assert str(op.attrs.activation) == "CLIP"
+
+ if operator_type == "ADD":
+ rewriter = legalize.AddRewriter()
+ pattern_table = [
+ (
+ ethosu.AddParams.composite_name,
+ ethosu.qnn_add_pattern(),
+ lambda pat: ethosu.AddParams(pat).is_valid(),
+ ),
+ ]
+ elif operator_type == "SUB":
+ rewriter = legalize.SubRewriter()
+ pattern_table = [
+ (
+ ethosu.SubParams.composite_name,
+ ethosu.qnn_subtract_pattern(),
+ lambda pat: ethosu.SubParams(pat).is_valid(),
+ ),
+ ]
+ elif operator_type == "MUL":
+ rewriter = legalize.MulRewriter()
+ pattern_table = [
+ (
+ ethosu.MulParams.composite_name,
+ ethosu.qnn_mul_pattern(),
+ lambda pat: ethosu.MulParams(pat).is_valid(),
+ ),
+ ]
+ elif operator_type == "MIN":
+ rewriter = legalize.MinRewriter()
+ pattern_table = [
+ (
+ ethosu.MinParams.composite_name,
+ ethosu.minimum_pattern(),
+ lambda pat: ethosu.MinParams(pat).is_valid(),
+ ),
+ ]
+ elif operator_type == "MAX":
+ rewriter = legalize.MaxRewriter()
+ pattern_table = [
+ (
+ ethosu.MaxParams.composite_name,
+ ethosu.maximum_pattern(),
+ lambda pat: ethosu.MaxParams(pat).is_valid(),
+ ),
+ ]
+
+ tflite_graph = create_tflite_graph()
+ tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+ mod, _ = relay.frontend.from_tflite(
+ tflite_model,
+ shape_dict={"x": ifm_shape, "y": ifm2_shape},
+ dtype_dict={"x": dtype, "y": dtype},
+ )
+ mod = partition_ethosu_by_table(mod, pattern_table)
+
+ mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite(
+ rewriter, mod["tvmgen_default_ethosu_main_0"]
+ )
+ verify(mod["tvmgen_default_ethosu_main_0"])
+
+
+@pytest.mark.parametrize(
+ "ifm_shape, ifm2_shape, reversed_operands",
+ [
+ ([1, 2, 3, 4], [1, 2, 3, 4], False),
+ ([1, 2, 3, 4], [1, 1, 3, 1], False),
+ ([1, 1, 3, 1], [1, 2, 3, 4], True),
+ ],
+)
+def test_ethosu_left_shift_binary_elemwise_legalize(ifm_shape, ifm2_shape, reversed_operands):
+ dtype = "int32"
+ operator_type = "SHL"
+
+ def create_graph():
+ input1 = relay.var("x1", shape=ifm_shape, dtype=dtype)
+ input2 = relay.var("x2", shape=ifm2_shape, dtype=dtype)
+ c1 = relay.left_shift(input1, input2)
Review comment:
Can we not legalize a right shift operator to right shift followed by element-wise subtract of 1?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]