NicolaLancellotti commented on a change in pull request #9627:
URL: https://github.com/apache/tvm/pull/9627#discussion_r762957516



##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -194,6 +205,48 @@ def __call__(self, *args, **kwargs):
         pass
 
 
+def sigmoid_calc_func(x):

Review comment:
       ```suggestion
   def sigmoid_calc_func(x: float) -> float:
   ```

##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -125,30 +125,30 @@ def __call__(self, *args, **kwargs):
         pass
 
 
-def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
-    """Method to calculate the values of the tanh lookup table"""
+def get_lut_from_func(ifm_scale, ifm_zp, ofm_scale, ofm_zp, func):

Review comment:
       ```suggestion
   def get_lut_from_func(ifm_scale: float, ifm_zp: int, ofm_scale: float, ofm_zp: int, func: Callable[[float], float]) -> List[int]:
   ```

##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -125,30 +125,30 @@ def __call__(self, *args, **kwargs):
         pass
 
 
-def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
-    """Method to calculate the values of the tanh lookup table"""
+def get_lut_from_func(ifm_scale, ifm_zp, ofm_scale, ofm_zp, func):
+    """Method to calculate the values of the lookup table based on the 
calculation function"""
     lut_values = list()
     # Only int8 is currently supported
     dtype = np.int8
     qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
     for x in range(qmin, qmax + 1):
         x_real = ifm_scale * (x - ifm_zp)
-        out_real = math.tanh(x_real)
+        out_real = func(x_real)
         lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
         lut_result = min(qmax, max(qmin, lut_result))
         lut_values.append(lut_result)
 
     return lut_values
 
 
-class TanhRewriter(DFPatternCallback):
-    """This pass adds tanh as a LUT to the identity operator"""
+class LutActivationRewriter(DFPatternCallback):
+    """A class to create an identity operator with the LUT"""
 
-    def __init__(self):
+    def __init__(self, params_class, activation_type, calc_func):
         super().__init__(require_type=True, rewrite_once=True)
-        self.pattern = (
-            wildcard().has_attr({"Composite": 
ethosu_patterns.TanhParams.composite_name})
-        )(wildcard())
+        self.pattern = (wildcard().has_attr({"Composite": 
params_class.composite_name}))(wildcard())
+        self.activation_type = activation_type
+        self.calc_func = calc_func
 
     def callback(self, pre, post, node_map):

Review comment:
       ```suggestion
       def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
   ```

##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -125,30 +125,30 @@ def __call__(self, *args, **kwargs):
         pass
 
 
-def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
-    """Method to calculate the values of the tanh lookup table"""
+def get_lut_from_func(ifm_scale, ifm_zp, ofm_scale, ofm_zp, func):
+    """Method to calculate the values of the lookup table based on the 
calculation function"""
     lut_values = list()
     # Only int8 is currently supported
     dtype = np.int8
     qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
     for x in range(qmin, qmax + 1):
         x_real = ifm_scale * (x - ifm_zp)
-        out_real = math.tanh(x_real)
+        out_real = func(x_real)
         lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
         lut_result = min(qmax, max(qmin, lut_result))
         lut_values.append(lut_result)
 
     return lut_values
 
 
-class TanhRewriter(DFPatternCallback):
-    """This pass adds tanh as a LUT to the identity operator"""
+class LutActivationRewriter(DFPatternCallback):
+    """A class to create an identity operator with the LUT"""
 
-    def __init__(self):
+    def __init__(self, params_class, activation_type, calc_func):

Review comment:
       ```suggestion
       def __init__(self, params_class: Type, activation_type: str, calc_func: Callable[[float], float]):
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to