kimm240 commented on code in PR #18173:
URL: https://github.com/apache/tvm/pull/18173#discussion_r2297111596


##########
python/tvm/relax/transform/fuse_conv2d_reshape_add_relu.py:
##########
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module provides a TVM Relax pass for fusing Conv2d-Reshape-Add-ReLU pattern."""
+
+import tvm
+from tvm import IRModule, relax
+from tvm.relax.dpl.pattern import is_op, wildcard
+
+# Define a TVM module pass for fusing specific operations.
+# @tvm.transform.module_pass decorates a class to turn it into a TVM IRModule pass.
+# opt_level=0 means this pass can be run at any optimization level.
+# name="FuseConv2dReshapeAddRelu" gives a descriptive name to the pass.
+
+
+@tvm.transform.module_pass(opt_level=0, name="FuseConv2dReshapeAddRelu")
+class FuseConv2dReshapeAddRelu:

Review Comment:
   @yongwww 
   Excellent point! However, after checking the actual implementation, I've 
confirmed that the generic FuseOps cannot handle this specific pattern.
   
   # Summary
   The generic relax.transform.FuseOps pass does not currently fuse the common 
conv2d + bias + activation pattern when the model is imported from PyTorch. The 
root cause is that the PyTorch frontend emits a conv2d -> reshape -> add 
sequence for the bias term, which the existing pattern matcher in FuseOps does 
not recognize. As a result, a very common pattern is left unfused.
   
   ## The Pattern Generated by the PyTorch Frontend
   When handling a torch.nn.Conv2d layer with bias=True, the PyTorch frontend 
consistently generates a reshape + add pattern for the bias. This is not 
specific to Conv2d and is standard behavior for other convolution types as well:
   
   Conv1d: See test_frontend_from_exported_program.py:1752-1753
   
   Conv2d: See test_frontend_from_fx.py:269-270
   
   Conv3d: See test_frontend_from_exported_program.py:3822-3823
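A short sketch of why a reshape is needed before the bias add at all. It assumes the conv2d output is in NCHW layout and that the add follows NumPy-style broadcasting; the helper `broadcast_shape` is hypothetical and written only for this illustration, it is not a TVM API. The target shape (1, 6, 1, 1) is likewise an assumption about what the frontend reshapes the bias to.

```python
def broadcast_shape(a, b):
    """Return the NumPy-style broadcast of two shapes, or None if incompatible."""
    # Right-align both shapes by padding the shorter one with leading 1s.
    ndim = max(len(a), len(b))
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    result = []
    for x, y in zip(a, b):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            return None  # mismatched, non-broadcastable dimension
    return tuple(result)


# NCHW output of Conv2d(3, 6, 3) applied to a 1x3x10x10 input.
conv_out = (1, 6, 8, 8)

# A 1-D bias of shape (6,) would line up with the last axis (width), not channels.
print(broadcast_shape(conv_out, (6,)))

# After reshaping to (1, 6, 1, 1) the bias lines up with the channel axis.
print(broadcast_shape(conv_out, (1, 6, 1, 1)))
```

Under these assumptions the raw bias cannot be added directly, which is consistent with the frontend inserting a reshape ahead of the add.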
   
   ## Limitation of TVM's Current Pattern Matching
   The pattern designed to fuse bias and activation, 
make_fused_bias_activation_pattern, is defined in pattern.py:1179-1181. As 
currently implemented, it matches only a plain relax.add that directly follows 
the convolution. Because the frontend inserts a reshape before the add, the 
pattern does not match the generated sequence.
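To make the argument concrete, here is a deliberately simplified toy. It models the operators as a flat list of names, exactly as they are described in this comment, and checks for a contiguous match. The function `match_chain` and the op lists are hypothetical and are not TVM's matcher: the real Relax pattern language works on dataflow graphs and may behave differently (for example, a wildcard operand can match the output of another op).

```python
def match_chain(ops, pattern):
    """Return the start index of `pattern` as a contiguous run in `ops`, else -1."""
    n = len(pattern)
    for i in range(len(ops) - n + 1):
        if ops[i:i + n] == pattern:
            return i
    return -1


# Operator sequence as described for the PyTorch frontend output.
frontend_ops = ["conv2d", "reshape", "add", "relu"]

# A pattern that expects the add to follow the convolution directly.
plain_pattern = ["conv2d", "add", "relu"]

# A pattern that accounts for the intermediate reshape.
reshape_aware_pattern = ["conv2d", "reshape", "add", "relu"]

print(match_chain(frontend_ops, plain_pattern))          # no match
print(match_chain(frontend_ops, reshape_aware_pattern))  # match at index 0
```

In this toy model the plain pattern finds nothing while the reshape-aware one matches the whole sequence, which is the gap the proposed pass is meant to close.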
   
   ## Proof by Code: A Reproducible Example
   The following test case demonstrates that FuseOps fails to fuse this pattern.
   
   ```python
   import torch
   import tvm
   from tvm import relax
   from tvm.relax.frontend.torch import from_fx
   
   # 1. PyTorch Conv2d model with bias and ReLU
   class Conv2dWithBias(torch.nn.Module):
       def __init__(self):
           super().__init__()
           self.conv = torch.nn.Conv2d(3, 6, 3, bias=True)
           self.relu = torch.nn.ReLU()
   
       def forward(self, x):
           return self.relu(self.conv(x))
   
   # 2. Trace and convert the model to TVM Relax IR
   model = Conv2dWithBias()
   graph_model = torch.fx.symbolic_trace(model)
   input_info = [([1, 3, 10, 10], "float32")]
   mod = from_fx(graph_model, input_info)
   
   print("### Original Relax IR (Before FuseOps):")
   print(mod)
   
   # 3. Apply the generic FuseOps pass
   fused_mod = relax.transform.FuseOps()(mod)
   
   print("\n### Relax IR After Applying FuseOps:")
   print(fused_mod)
   ```
   
   ## Execution Results
   Original Relax IR (before FuseOps): four separate operations are 
generated: conv2d → reshape → add → relu.
   
   Relax IR after FuseOps: the IR is unchanged, confirming that no fusion 
took place.
   
   This failure is a direct result of the pattern in pattern.py:1179-1181 
matching only relax.add and not the reshape + add sequence.
   
   ## Conclusion and Proposal
   The generic FuseOps pass cannot handle this frontend-specific pattern, 
leaving a common PyTorch model structure (conv2d + bias + relu) unoptimized.
   
   Therefore, a specialized pass such as FuseConv2dReshapeAddRelu is needed to 
identify and fuse this pattern. It bridges the gap between the IR that the 
PyTorch frontend generates and the patterns that TVM's fusion currently 
recognizes, so that models containing conv2d + bias + relu can be fused.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
