ekalda commented on a change in pull request #9547:
URL: https://github.com/apache/tvm/pull/9547#discussion_r757438005



##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -123,6 +124,80 @@ def __call__(self, *args, **kwargs):
         pass
 
 
+def round_away_zero(f):
+    r = -0.5 if (f < 0) else 0.5
+    return np.trunc(f + r)
+
+
+def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
+    """Method to calculate the values of the tanh lookup table"""
+    lut_values = list()
+    # Only int8 is currently supported
+    dtype = np.int8
+    qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
+    for x in range(qmin, qmax + 1):
+        x_real = ifm_scale * (x - ifm_zp)
+        out_real = math.tanh(x_real)
+        lut_result = int(round_away_zero(ofm_zp + out_real / ofm_scale))
+        lut_result = min(qmax, max(qmin, lut_result))
+        lut_values.append(lut_result)
+
+    return lut_values
+
+
+class TanhRewriter(DFPatternCallback):
+    """This pass adds tanh as a LUT to the identity operator"""
+
+    def __init__(self):
+        super().__init__(require_type=True, rewrite_once=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": 
ethosu_patterns.TanhParams.composite_name})
+        )(wildcard())
+
+    def callback(self, pre, post, node_map):
+        id_input = post.args[0]
+
+        quantize_args = post.op.body.args
+        output_scale = float(quantize_args[1].data.asnumpy())
+        output_zp = int(quantize_args[2].data.asnumpy())
+
+        dequantize_args = quantize_args[0].args[0].args
+        input_scale = float(dequantize_args[1].data.asnumpy())
+        input_zp = int(dequantize_args[2].data.asnumpy())
+
+        lut_values = find_tanh_values(input_scale, input_zp, output_scale, 
output_zp)
+        lut = relay.const(lut_values, dtype="uint8")

Review comment:
       Currently, the values are calculated for an `int8` activation and then 
cast to `uint8` because of that can-of-worms dtype problem in 
`tir_to_cs_translator.py`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to