This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 47ef9466b2 [Pytorch] Add quantized::leaky_relu (#11729)
47ef9466b2 is described below

commit 47ef9466b2751acb8c545ea0c3124c70870ec399
Author: yuanfz <[email protected]>
AuthorDate: Thu Jun 16 05:48:52 2022 +0200

    [Pytorch] Add quantized::leaky_relu (#11729)
    
    * emptycommit 2nd try
    
    * add operator and test
    
    * example output
    
    * lint with black
    
    * register param index
    
    * remove assert as it is a warning in torch
    
    * fix algo bug
    
    Co-authored-by: yuanfz <[email protected]>
---
 python/tvm/relay/frontend/qnn_torch.py    | 21 +++++++++++++++++++++
 tests/python/frontend/pytorch/qnn_test.py | 14 ++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
index 41543ec611..63ee6ea96f 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -271,6 +271,7 @@ def _get_quant_param_for_input(input_value):
         "quantized::add_scalar": (2, 3),
         "quantized::hardswish": (1, 2),
         "quantized::conv_transpose2d": qconv_indices,
+        "quantized::leaky_relu": (3, 4),
     }
 
     def dfs(current_node):
@@ -443,6 +444,7 @@ def add_input_quant_params_to_op_inputs(graph):
         "quantized::hardswish": 1,
         "aten::hardsigmoid": 1,
         "quantized::conv_transpose2d": 1,
+        "quantized::leaky_relu": 1,
     }
 
     need_input_quant_param = set(num_quantized_inputs.keys())
@@ -935,6 +937,24 @@ def _relu6():
     return _impl
 
 
+def _leaky_relu():
+    # refer to src/ATen/native/quantized/cpu/qrelu.cpp
+    def _impl(inputs, _):
+        assert len(inputs) == 7, "Input quant params not found in op inputs"
+        alpha = inputs[1]
+        output_scale = _expr.const(inputs[3])
+        output_zero_point = _expr.const(inputs[4])
+        input_scale = _expr.const(inputs[5])
+        input_zero_point = _expr.const(inputs[6])
+        dequant = relay.qnn.op.dequantize(inputs[0], input_scale, input_zero_point)
+        dequantized = _op.nn.leaky_relu(dequant, alpha)
+        return relay.qnn.op.quantize(
+            dequantized, output_scale, output_zero_point, out_dtype="uint8"
+        )
+
+    return _impl
+
+
 def _mul_scalar():
     # this is used for mobilenet v3
     def _impl(inputs, _):
@@ -1131,6 +1151,7 @@ convert_map = {
     "quantized::add_scalar": _add_scalar(),
     "quantized::mul_scalar": _mul_scalar(),
     "quantized::relu6": _relu6(),
+    "quantized::leaky_relu": _leaky_relu(),
     "quantized::linear_dynamic": _linear_dynamic(),
     "quantized::hardswish": _hswish(),
     "quantized::conv_transpose2d": _quantized_conv_transpose2d(),
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index 6e87b9ee4f..ef7f3f769c 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -148,6 +148,18 @@ class ReLU(nn.Module):
         pass
 
 
+class LeakyReLU(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.leaky_relu = QuantWrapper(nn.LeakyReLU())
+
+    def forward(self, x):
+        return self.leaky_relu(x)
+
+    def fuse_model(self):
+        pass
+
+
 # Mobilenet V3 related modules
 class Hsigmoid(nn.Module):
     def __init__(self, add_stub=False):
@@ -302,6 +314,7 @@ def test_quantized_modules():
             ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
             ("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True),
             ("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False),
+            ("leaky_relu", imagenet_ishape, LeakyReLU(), False),
         ]
 
     for (module_name, ishape, raw_module, per_channel) in qmodules:
@@ -347,6 +360,7 @@ def test_quantized_modules():
         # sample outputs
         """
         relu 0.0039215684 2.6052087e-08 0.9999933567176871
+        leaky_relu 0.0 0.0 1.0
         upsample bilinear 0.0 0.0 1.0
         conv_bn 0.22062653 0.011478779 0.6909348115006899
         conv_bn_relu 0.3700896 0.010921672 0.7489366477964451

Reply via email to