This is an automated email from the ASF dual-hosted git repository. andrewzhaoluo pushed a commit to branch aluo/rebase-09192022-autotensorization in repository https://gitbox.apache.org/repos/asf/tvm.git
commit ca5a7b0a6b081bd84ca797495ec8b4eb97cfd248
Author: Andrew Zhao Luo <[email protected]>
AuthorDate: Fri Sep 16 16:58:09 2022 -0700

    pattern matching
---
 python/tvm/relay/qnn/transform.py             | 78 ++++++++++++++++++++++
 .../transform/fake_quantization_to_integer.py |  2 +-
 2 files changed, 79 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/qnn/transform.py b/python/tvm/relay/qnn/transform.py
index 0485cecb99..7b42942c8b 100644
--- a/python/tvm/relay/qnn/transform.py
+++ b/python/tvm/relay/qnn/transform.py
@@ -114,3 +114,81 @@ def Legalize():
     """
 
     return relay.transform.Legalize("FTVMQnnLegalize")
+
+
+from tvm.relay.dataflow_pattern import (
+    DFPatternCallback,
+    is_constant,
+    is_expr,
+    is_op,
+    rewrite,
+    wildcard,
+)
+
+
+class RSqrtPattern(DFPatternCallback):
+    """
+    Rewrites QNN RSQRT Pattern
+    """
+
+    def __init__(self):
+        super(RSqrtPattern, self).__init__()
+
+        self.sqrt_data = wildcard()
+        self.sqrt_data_input_scale = wildcard()
+        self.sqrt_data_input_zp = wildcard()
+
+        self.numerator = wildcard()
+        self.numerator_scale = wildcard()
+        self.numerator_zp = wildcard()
+
+        self.output_scale = wildcard()
+        self.output_zp = wildcard()
+
+        self.sqrt = is_op("qnn.sqrt")(
+            self.sqrt_data,
+            self.sqrt_data_input_scale,
+            self.sqrt_data_input_zp,
+            wildcard(),
+            wildcard(),
+        )
+
+        # TODO: match axis properly
+        self.rsqrt = is_op("qnn.div")(
+            self.numerator,
+            self.sqrt,
+            self.numerator_scale,
+            self.numerator_zp,
+            wildcard(),
+            wildcard(),
+            self.output_scale,
+            self.output_zp,
+        )
+
+        self.pattern = self.rsqrt
+
+    def callback(self, pre, post, node_map):
+        sqrt_data = node_map[self.sqrt_data][0]
+        sqrt_data_scale = node_map[self.sqrt_data_input_scale][0]
+        sqrt_data_zp = node_map[self.sqrt_data_input_zp][0]
+
+        numerator = node_map[self.numerator][0]
+        numerator_scale = node_map[self.numerator_scale][0]
+        numerator_zp = node_map[self.numerator_zp][0]
+
+        output_scale = node_map[self.output_scale][0]
+        output_zp = node_map[self.output_zp][0]
+
+        rsqrt = relay.qnn.op.rsqrt(
+            sqrt_data, sqrt_data_scale, sqrt_data_zp, numerator_scale, numerator_zp
+        )
+        return relay.qnn.op.mul(
+            numerator,
+            rsqrt,
+            numerator_scale,
+            numerator_zp,
+            numerator_scale,
+            numerator_zp,
+            output_scale,
+            output_zp,
+        )
diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 82afb1b4c3..d05de3a50f 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -671,4 +671,4 @@ register_unary_qnn("sigmoid", relay.qnn.op.sigmoid)
 register_unary_qnn("hardswish", relay.qnn.op.hardswish)
 register_unary_qnn("tanh", relay.qnn.op.tanh)
 register_unary_qnn("abs", relay.qnn.op.abs)
-register_unary_qnn("log", relay.qnn.op.log)
\ No newline at end of file
+register_unary_qnn("log", relay.qnn.op.log)
