Mousius commented on a change in pull request #10100:
URL: https://github.com/apache/tvm/pull/10100#discussion_r797819330
##########
File path: tests/python/contrib/test_cmsisnn/test_binary_ops.py
##########
@@ -131,6 +153,140 @@ def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_z
)
+# At least one of the inputs is a constant
+def parameterize_for_constant_inputs(test):
+ op = [relay.qnn.op.mul, relay.qnn.op.add]
+ input_0_type = [
+ BinaryOpInputType.Variable,
+ BinaryOpInputType.TensorConstant,
+ BinaryOpInputType.ScalarConstant,
+ ]
+ input_1_type = [
+ BinaryOpInputType.Variable,
+ BinaryOpInputType.TensorConstant,
+ BinaryOpInputType.ScalarConstant,
+ ]
+ all_combinations = itertools.product(op, input_0_type, input_1_type)
+ all_combinations = filter(
+ lambda parameters: not (
+ (
+ parameters[1] == BinaryOpInputType.Variable
+ and parameters[2] == BinaryOpInputType.Variable
+ )
+ or (
+ parameters[1] == BinaryOpInputType.ScalarConstant
+ and parameters[2] == BinaryOpInputType.ScalarConstant
+ )
+ ),
+ all_combinations,
+ )
+ return pytest.mark.parametrize(
+ ["op", "input_0_type", "input_1_type"],
+ all_combinations,
+ )(test)
+
+
+@skip_if_no_reference_system
+@tvm.testing.requires_cmsisnn
+@parameterize_for_constant_inputs
+def test_constant_input_int8(op, input_0_type, input_1_type):
+ interface_api = "c"
+ use_unpacked_api = True
+ test_runner = AOT_CORSTONE300_RUNNER
+
+ dtype = "int8"
+ shape = [1, 16, 16, 3]
+ input_0_scale = 0.256
+ input_0_zero_point = 33
+ input_1_scale = 0.128
+ input_1_zero_point = -24
+ model = make_model(
+ op,
+ shape,
+ dtype,
+ dtype,
+ input_0_scale,
+ input_0_zero_point,
+ input_1_scale,
+ input_1_zero_point,
+ input_0_type=input_0_type,
+ input_1_type=input_1_type,
+ )
+ orig_mod = make_module(model)
+
+ cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+ # validate pattern matching
+ attrs = [
+ cmsisnn_mod[var.name_hint].attrs
+ for var in cmsisnn_mod.get_global_vars()
+ if cmsisnn_mod[var.name_hint].attrs
+ ]
+ assert any(attrs), "At least one function with external attributes was expected."
+
+ compilers = [
+ key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
+ ]
+ assert any(compilers), "Module does not contain function for cmsisnn target."
+
+ assert count_num_calls(orig_mod) == count_num_calls(
+ cmsisnn_mod
+ ), "Number of calls changed during partitioning"
Review comment:
Sounds good to me!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]