lhutton1 commented on a change in pull request #9910:
URL: https://github.com/apache/tvm/pull/9910#discussion_r788985972
##########
File path: tests/python/contrib/test_ethosu/test_codegen.py
##########
@@ -987,5 +990,36 @@ def split_func(x):
_compare_tvm_with_tflite(split_func, [ifm_shape], accel_type)
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize(
+ "ifm_shape,ifm_scale,ifm_zp,ofm_scale,ofm_zp",
+ [
+ [(1, 8, 8, 3), 1.0, 0, 1.0, 0],
+ [(1, 20, 30, 3), 1.345, 34, 0.32, -23],
+ ],
+)
+def test_ethosu_requantize(accel_type, ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp):
+ dtype = "int8"
+ ifm_shape = [1, 8, 8, 3]
Review comment:
oops, thanks!
##########
File path: tests/python/contrib/test_ethosu/test_legalize.py
##########
@@ -1502,5 +1503,105 @@ def verify(ext_func):
verify(mod["tvmgen_default_ethos_u_main_0"])
+@pytest.mark.parametrize(
+ "ifm_shape,ifm_scale,ifm_zp,ofm_scale,ofm_zp",
+ [[(1, 8, 8, 3), 1.0, 0, 1.0, 0], [(1, 20, 30, 3), 1.345, 34, 0.32, -23]],
+)
+def test_ethosu_requantize(ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp):
+ dtype = "int8"
+ ifm_shape = [1, 8, 8, 3]
+
+ def create_model():
+ ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+ requantize = relay.qnn.op.requantize(
+ ifm,
+ relay.const(ifm_scale, dtype="float32"),
+ relay.const(ifm_zp, dtype="int32"),
+ relay.const(ofm_scale, dtype="float32"),
+ relay.const(ofm_zp, dtype="int32"),
+ )
+ return tvm.IRModule.from_expr(relay.Function([ifm], requantize))
+
+ def verify(ext_func):
+ op = ext_func.body
+
+ # Check IFM
+ ifm = op.args[0].checked_type
+ assert list(ifm.shape) == list(ifm_shape)
+ assert str(ifm.dtype) == dtype
+
+ # Check OFM
+ ofm = op.checked_type
+ assert list(ofm.shape) == list(ifm_shape)
+ assert str(ofm.dtype) == dtype
+
+ # Check quantization params
+ assert math.isclose(op.attrs.ifm_scale, ifm_scale, abs_tol=1e-7)
+ assert op.attrs.ifm_zero_point == ifm_zp
+ assert math.isclose(op.attrs.ofm_scale, ofm_scale, abs_tol=1e-7)
+ assert op.attrs.ofm_zero_point == ofm_zp
+
+ rewriter = legalize.RequantizeRewriter()
+ pattern_table = [
+ (
+ ethosu.RequantizeParams.composite_name,
+ ethosu.requantize_pattern(),
+ lambda pat: ethosu.RequantizeParams(pat).is_valid(),
+ ),
+ ]
+
+ mod = create_model()
+ mod = partition_ethosu_by_table(mod, pattern_table)
+
+ mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+ rewriter, mod["tvmgen_default_ethos_u_main_0"]
+ )
+ verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
+def test_multiple_requantize_offload():
+ """
+ Testing requantize offload in the case one requauntize operation is part of
Review comment:
good catch, thanks!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]