This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 474cde494d [Optimization][Operator] Implement and enable 
Conv2d-Reshape-Add-ReLU fusion (#18240)
474cde494d is described below

commit 474cde494dac47cab60d8cd4181ae22bfee98bf5
Author: kimm240 <[email protected]>
AuthorDate: Mon Mar 16 13:10:50 2026 +0900

    [Optimization][Operator] Implement and enable Conv2d-Reshape-Add-ReLU 
fusion (#18240)
    
    This commit extends the make_fused_bias_activation_pattern function to
    support
    PyTorch frontend's specific IR generation pattern for convolution
    operations
    with bias. When PyTorch models with bias=True are converted to Relax IR,
    the
    frontend generates a conv2d -> reshape -> add -> relu sequence instead
    of the
    simpler conv2d -> add -> relu pattern that existing fusion logic
    expected.
    
    The key changes include:
    
    1. Add allow_reshape parameter to make_fused_bias_activation_pattern in
    both
    dpl/pattern.py and backend/patterns.py with default value False to
    maintain
       backward compatibility.
    
    2. When allow_reshape=True, the pattern matcher now recognizes and fuses
    the
    complete conv2d -> reshape -> add -> relu sequence into a single
    composite
    function, eliminating intermediate tensor allocations and kernel launch
       overhead.
    
    3. The original pattern (allow_reshape=False) only fuses conv2d -> add
    -> relu,
    leaving the reshape operation outside the fused function, which results
    in
       suboptimal performance for PyTorch-originated models.
    
    This enhancement enables more efficient operator fusion for PyTorch
    models,
    reducing memory usage and improving execution performance by capturing
    the
    complete computation pattern in a single fused kernel. The
    implementation
    maintains full backward compatibility while extending support for
    PyTorch
    frontend's specific IR generation patterns.
    
    Comprehensive tests are added to verify the fusion behavior with both
    old and
    new patterns, ensuring correctness across different convolution types
    (Conv1d,
    Conv2d, Conv3d) and validating that fusion only occurs when appropriate
    conditions are met.
    
    ---------
    
    Co-authored-by: kim hyun gyu <[email protected]>
---
 python/tvm/relax/backend/patterns.py               |  10 +-
 python/tvm/relax/dpl/pattern.py                    |  13 +-
 .../relax/test_fuse_pytorch_conv2d_bias_pattern.py | 158 +++++++++++++++++++++
 3 files changed, 177 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/backend/patterns.py 
b/python/tvm/relax/backend/patterns.py
index 9d47bb8354..c9c853d702 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -37,10 +37,15 @@ def _with_bias_activation_pattern(
     annotations: dict[str, DFPattern],
     with_bias: bool = False,
     activation: str | None = None,
+    allow_reshape: bool = False,
 ) -> tuple[DFPattern, Mapping[str, DFPattern]]:
     if with_bias:
         annotations["bias"] = bias = wildcard()
-        out = is_op("relax.add")(out, bias)
+        if allow_reshape:
+            reshaped_bias = is_op("relax.reshape")(bias, wildcard(), 
varg_default_wildcard=True)
+            out = is_op("relax.add")(out, reshaped_bias, 
varg_default_wildcard=True)
+        else:
+            out = is_op("relax.add")(out, bias)
 
     if activation:
         out = is_op(activation)(out)
@@ -52,6 +57,7 @@ def make_fused_bias_activation_pattern(
     op_name: str,
     with_bias: bool = False,
     activation: str | None = None,
+    allow_reshape: bool = False,
 ) -> tuple[DFPattern, Mapping[str, DFPattern]]:
     """
     A simple utility to create patterns for an operation fused with bias 
addition and activation.
@@ -82,7 +88,7 @@ def make_fused_bias_activation_pattern(
     out = is_op(op_name)(lhs, rhs)
     annotations = {"lhs": lhs, "rhs": rhs, "root": out}
 
-    return _with_bias_activation_pattern(out, annotations, with_bias, 
activation)
+    return _with_bias_activation_pattern(out, annotations, with_bias, 
activation, allow_reshape)
 
 
 def make_residual_block_pattern(
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index ad7d1b7421..f485d8bbbd 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -1121,7 +1121,9 @@ def _only_used_by(lhs: DFPattern | PatternSeq, rhs: 
DFPattern | PatternSeq, inde
     return ffi.only_used_by(lhs, rhs, index)  # type: ignore
 
 
-def make_fused_bias_activation_pattern(op_name, with_bias=False, 
activation=None):
+def make_fused_bias_activation_pattern(
+    op_name, with_bias=False, activation=None, allow_reshape=False
+):
     """
     A simple utility to create patterns for an operation fused with bias 
addition and activation.
 
@@ -1136,6 +1138,9 @@ def make_fused_bias_activation_pattern(op_name, 
with_bias=False, activation=None
     activation: str
         The name of an activation Relax op, such as "relax.nn.relu"
 
+    allow_reshape: bool
+        Whether to allow reshape operation before bias addition (for PyTorch 
frontend)
+
     Returns
     -------
     pattern: DFPattern
@@ -1147,7 +1152,11 @@ def make_fused_bias_activation_pattern(op_name, 
with_bias=False, activation=None
 
     if with_bias:
         bias = wildcard()
-        out = is_op("relax.add")(out, bias)
+        if allow_reshape:
+            reshaped_bias = is_op("relax.reshape")(bias, wildcard(), 
varg_default_wildcard=True)
+            out = is_op("relax.add")(out, reshaped_bias, 
varg_default_wildcard=True)
+        else:
+            out = is_op("relax.add")(out, bias)
 
     if activation:
         return is_op(activation)(out)
diff --git a/tests/python/relax/test_fuse_pytorch_conv2d_bias_pattern.py 
b/tests/python/relax/test_fuse_pytorch_conv2d_bias_pattern.py
new file mode 100644
index 0000000000..ca9736af9f
--- /dev/null
+++ b/tests/python/relax/test_fuse_pytorch_conv2d_bias_pattern.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import torch
+
+from tvm import relax
+from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern
+from tvm.relax.frontend.torch import from_fx
+
+
+def test_conv2d_bias_relu_fusion():
+    """Test PyTorch conv2d + bias + relu fusion with reshape pattern"""
+
+    class Conv2dBiasRelu(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 3, bias=True)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            return self.relu(self.conv(x))
+
+    # Convert PyTorch model to Relax IR
+    model = Conv2dBiasRelu()
+    graph_model = torch.fx.symbolic_trace(model)
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    with torch.no_grad():
+        mod = from_fx(graph_model, input_info)
+
+    # Apply fusion with modified pattern
+    patterns = [
+        (
+            "conv2d_bias_activation_with_reshape",
+            make_fused_bias_activation_pattern(
+                "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", 
allow_reshape=True
+            ),
+        )
+    ]
+
+    fused_mod = relax.transform.FuseOpsByPattern(patterns, 
bind_constants=False)(mod)
+
+    # Verify fusion occurred
+    fused_functions = [name for name in fused_mod.functions.keys() if "fused" 
in str(name)]
+
+    assert len(fused_functions) == 1, "Expected exactly one fused function"
+
+    # Verify the fused function contains all operations
+    fused_func = fused_mod[fused_functions[0]]
+    assert hasattr(fused_func, "attrs"), "Fused function should have 
attributes"
+    assert "Composite" in fused_func.attrs, "Fused function should have 
Composite attribute"
+
+
+def test_conv2d_bias_relu_fusion_comparison():
+    """Compare fusion with and without allow_reshape option"""
+
+    class Conv2dBiasRelu(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 3, bias=True)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            return self.relu(self.conv(x))
+
+    model = Conv2dBiasRelu()
+    graph_model = torch.fx.symbolic_trace(model)
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    with torch.no_grad():
+        mod = from_fx(graph_model, input_info)
+
+    # Test with allow_reshape=False
+    old_patterns = [
+        (
+            "conv2d_bias_activation_old",
+            make_fused_bias_activation_pattern(
+                "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", 
allow_reshape=False
+            ),
+        )
+    ]
+
+    old_fused_mod = relax.transform.FuseOpsByPattern(old_patterns, 
bind_constants=False)(mod)
+
+    # Test with allow_reshape=True
+    new_patterns = [
+        (
+            "conv2d_bias_activation_new",
+            make_fused_bias_activation_pattern(
+                "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", 
allow_reshape=True
+            ),
+        )
+    ]
+
+    new_fused_mod = relax.transform.FuseOpsByPattern(new_patterns, 
bind_constants=False)(mod)
+
+    # Both should create fused functions
+    old_fused_functions = [name for name in old_fused_mod.functions.keys() if 
"fused" in str(name)]
+    new_fused_functions = [name for name in new_fused_mod.functions.keys() if 
"fused" in str(name)]
+
+    assert len(old_fused_functions) >= 1, "Old pattern should create at least 
one fused function"
+    assert len(new_fused_functions) >= 1, "New pattern should create at least 
one fused function"
+
+
+def test_conv2d_no_fusion_case():
+    """Test case where fusion should not occur"""
+
+    class Conv2dNoBias(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 3, bias=False)
+
+        def forward(self, x):
+            return self.conv(x)
+
+    model = Conv2dNoBias()
+    graph_model = torch.fx.symbolic_trace(model)
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    with torch.no_grad():
+        mod = from_fx(graph_model, input_info)
+
+    # Apply fusion pattern
+    patterns = [
+        (
+            "conv2d_bias_activation",
+            make_fused_bias_activation_pattern(
+                "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", 
allow_reshape=True
+            ),
+        )
+    ]
+
+    fused_mod = relax.transform.FuseOpsByPattern(patterns, 
bind_constants=False)(mod)
+
+    # No fusion should occur
+    fused_functions = [name for name in fused_mod.functions.keys() if "fused" 
in str(name)]
+
+    assert len(fused_functions) == 0, "No fusion should occur for conv2d 
without bias and relu"
+
+
+if __name__ == "__main__":
+    test_conv2d_bias_relu_fusion()
+    test_conv2d_bias_relu_fusion_comparison()
+    test_conv2d_no_fusion_case()

Reply via email to