kimm240 commented on code in PR #18173:
URL: https://github.com/apache/tvm/pull/18173#discussion_r2297111596
##########
python/tvm/relax/transform/fuse_conv2d_reshape_add_relu.py:
##########
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module provides a TVM Relax pass for fusing Conv2d-Reshape-Add-ReLU pattern."""
+
+import tvm
+from tvm import IRModule, relax
+from tvm.relax.dpl.pattern import is_op, wildcard
+
+# Define a TVM module pass for fusing specific operations.
+# @tvm.transform.module_pass decorates a class to turn it into a TVM IRModule pass.
+# opt_level=0 means this pass can be run at any optimization level.
+# name="FuseConv2dReshapeAddRelu" gives a descriptive name to the pass.
+
+
[email protected]_pass(opt_level=0, name="FuseConv2dReshapeAddRelu")
+class FuseConv2dReshapeAddRelu:

Review Comment:

@yongwww Excellent point! However, after checking the actual implementation, I've confirmed that the generic FuseOps pass cannot handle this specific pattern.

# Summary

The generic relax.transform.FuseOps pass is currently unable to fuse the common conv2d + bias + activation pattern when the model is imported from PyTorch. The root cause is that the PyTorch frontend generates a conv2d -> reshape -> add sequence for the bias term, which the existing pattern matcher used by FuseOps does not recognize. This leaves a critical, common pattern unoptimized.

## The Pattern Generated by the PyTorch Frontend

When handling a torch.nn.Conv2d layer with bias=True, the PyTorch frontend consistently generates a reshape + add pattern for the bias. This is not specific to Conv2d; it is the standard behavior for other convolution types as well:

- Conv1d: see test_frontend_from_exported_program.py:1752-1753
- Conv2d: see test_frontend_from_fx.py:269-270
- Conv3d: see test_frontend_from_exported_program.py:3822-3823

## Limitation of TVM's Current Pattern Matching

The pattern designed to fuse bias and activation, make_fused_bias_activation_pattern, is defined in pattern.py:1179-1181. It is currently implemented to match only a plain relax.add operation immediately following the convolution. It cannot see past the reshape operation inserted by the frontend, so it fails to match the sequence.
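For reference, here is a minimal, hypothetical sketch of what a pattern covering the frontend's reshape + add sequence could look like, written with the same is_op/wildcard helpers that this file imports. The helper name make_conv2d_reshape_add_relu_pattern and the exact argument structure are illustrative assumptions, not the code in this PR; the point is only that the existing make_fused_bias_activation_pattern stops at a bare relax.add and never includes the relax.reshape node.

```python
# Hypothetical sketch only: a DPL pattern that also covers the reshape the
# PyTorch frontend inserts for the bias (conv2d -> reshape -> add -> relu).
from tvm.relax.dpl.pattern import is_op, wildcard


def make_conv2d_reshape_add_relu_pattern():
    # conv2d(data, weight)
    conv = is_op("relax.nn.conv2d")(wildcard(), wildcard())
    # reshape(bias, new_shape) -- the extra node emitted for the bias term
    reshaped_bias = is_op("relax.reshape")(wildcard(), wildcard())
    # add(conv2d_output, reshaped_bias), followed by relu
    add = is_op("relax.add")(conv, reshaped_bias)
    return is_op("relax.nn.relu")(add)
```

The sketch is only meant to make the missing reshape hop visible; it is not presented as the implementation used by this pass.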
## Proof by Code: A Reproducible Example

The following test case demonstrates that FuseOps fails to fuse this pattern.

```python
import torch
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx

# 1. PyTorch Conv2d model with bias and ReLU
class Conv2dWithBias(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, 3, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

# 2. Trace and convert the model to TVM Relax IR
model = Conv2dWithBias()
graph_model = torch.fx.symbolic_trace(model)
input_info = [([1, 3, 10, 10], "float32")]
mod = from_fx(graph_model, input_info)

print("### Original Relax IR (Before FuseOps):")
print(mod)

# 3. Apply the generic FuseOps pass
fused_mod = relax.transform.FuseOps()(mod)

print("\n### Relax IR After Applying FuseOps:")
print(fused_mod)
```

## Execution Results

- Converted IR (before FuseOps): a sequence of four separate operations is generated: conv2d → reshape → add → relu.
- IR after applying FuseOps: the IR remains completely unchanged, confirming that the fusion failed.

This failure is a direct result of the pattern in pattern.py:1179-1181 matching only relax.add and not the reshape + add sequence.

## Conclusion and Proposal

The generic FuseOps pass cannot handle this frontend-specific pattern, leaving a common PyTorch model structure (conv2d + bias + relu) unoptimized. Therefore, a specialized pass like FuseConv2dReshapeAddRelu is essential to correctly identify and fuse this pattern. This targeted pass bridges the gap between the PyTorch frontend's IR generation and TVM's optimization capabilities, unlocking performance for a wide range of models.

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
