mbaret commented on a change in pull request #9442:
URL: https://github.com/apache/tvm/pull/9442#discussion_r743845777
##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -423,11 +640,20 @@ class LegalizeEthosU:
def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
+ """This is the method that replace the operations with
hardware/codegen supported
Review comment:
that replaces
##########
File path: tests/python/contrib/test_ethosu/test_codegen.py
##########
@@ -343,5 +343,256 @@ def representative_dataset():
infra.verify_source(compiled_models, accel_type)
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"])
+@pytest.mark.parametrize(
+ "ifm_shape, ifm2_shape",
+ [
+ ([1, 2, 3, 4], [1, 2, 3, 4]),
+ ([1, 2, 3, 4], [1, 1, 3, 1]),
+ ([1, 1, 3, 1], [1, 2, 3, 4]),
+ ],
+)
+@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
+def test_ethosu_binary_elementwise(
+ accel_type,
+ operator_type,
+ ifm_shape,
+ ifm2_shape,
+ activation_function,
+):
+ dtype = "int8"
+
+ def create_tflite_graph():
+ tf.config.run_functions_eagerly(True)
+
+ class Model(tf.Module):
+ @tf.function
+ def tf_function(self, lhs, rhs):
+ if operator_type == "ADD":
+ op = tf.math.add(lhs, rhs)
+ elif operator_type == "SUB":
+ op = tf.math.subtract(lhs, rhs)
+ elif operator_type == "MUL":
+ op = tf.math.multiply(lhs, rhs)
+ elif operator_type == "MIN":
+ op = tf.math.minimum(lhs, rhs)
+ elif operator_type == "MAX":
+ op = tf.math.maximum(lhs, rhs)
+ if activation_function == "RELU":
+ op = tf.nn.relu(op)
+ return op
+
+ model = Model()
+ concrete_func = model.tf_function.get_concrete_function(
Review comment:
Not necessarily something to resolve in this patch, but I wonder if we
can abstract some of the boilerplate code here into a utility function.
##########
File path: python/tvm/relay/op/contrib/ethosu.py
##########
@@ -458,6 +458,316 @@ def qnn_avgpool2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
return pattern
+class BinaryElementwiseParams:
+ """
+ This class will parse a call to a ethosu.binary_elementwise composite
function
+ and extract the parameter information.
+ """
+
+ def __init__(self, func_body: Call, operator_type: str,
has_quantization_parameters: bool):
+ clip = None
+ if str(func_body.op) == "clip":
+ clip = func_body
+ binary_op = clip.args[0]
+ else:
+ binary_op = func_body
+
+ layout = "NHWC"
+
+ if has_quantization_parameters:
+ self.ifm = TensorParams(
+ binary_op.args[BinaryElementwiseArgs.ifm.value],
+ layout,
+ binary_op.args[BinaryElementwiseArgs.ifm_scale.value],
+ binary_op.args[BinaryElementwiseArgs.ifm_zero_point.value],
+ )
+ self.ifm2 = TensorParams(
+ binary_op.args[BinaryElementwiseArgs.ifm2.value],
+ layout,
+ binary_op.args[BinaryElementwiseArgs.ifm2_scale.value],
+ binary_op.args[BinaryElementwiseArgs.ifm2_zero_point.value],
+ )
+ self.ofm = TensorParams(
+ binary_op,
+ layout,
+ binary_op.args[BinaryElementwiseArgs.ofm_scale.value],
+ binary_op.args[BinaryElementwiseArgs.ofm_zero_point.value],
+ )
+ else:
+ self.ifm = TensorParams(
+ binary_op.args[BinaryElementwiseArgs.ifm.value],
+ layout,
+ )
+ self.ifm2 = TensorParams(
+ binary_op.args[BinaryElementwiseArgs.ifm2.value],
+ layout,
+ )
+ self.ofm = TensorParams(
+ binary_op,
+ layout,
+ )
+ self.activation = clip
+ self.operator_type = operator_type
+
+ def brodcastable(x, y):
Review comment:
broadcastable (or 'can_broadcast' might be better)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: reviews-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org