yangulei opened a new pull request, #13213: URL: https://github.com/apache/tvm/pull/13213
This PR enables simplification and folding of a subgraph containing adjacent `mul`s and `add`s with constant inputs.

### Motivation

Workloads like [densenet-121](https://github.com/onnx/models/blob/main/vision/classification/densenet-121/model/densenet-7.onnx) have several partitions with the `conv-bn-mul-add-relu` pattern, for example:

``` go
def @main(%data_0: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */) -> Tensor[(1, 1000, 1, 1), float32] {
  %0 = nn.conv2d(%data_0, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %1 = nn.batch_norm(%0, meta[relay.Constant][1] /* ty=Tensor[(64), float32] */, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */, meta[relay.Constant][4] /* ty=Tensor[(64), float32] */) /* ty=(Tensor[(1, 64, 112, 112), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %2 = %1.0 /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %3 = multiply(%2, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %4 = add(%3, meta[relay.Constant][6] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %5 = nn.relu(%4) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  // ...
}
```

The current transforms on this pattern are:

1. `conv-bn-mul-add-relu` as the original pattern.
2. to `conv-mul-add-mul-add-relu` as `bn` is expanded to `mul-add`.
3. to `conv-add-mul-add-relu` as the first `mul` is folded into `conv`.

Since all the `mul`s and `add`s have constant second inputs, they should be folded into a single `mul-add`, and the preferred transform sequence is:

1. `conv-bn-mul-add-relu` as the original pattern.
2. to `conv-mul-add-mul-add-relu` as `bn` is expanded to `mul-add`.
3. to `conv-mul-add-relu` as the `mul`s and `add`s are folded into a single `mul-add`.
4. to `conv-add-relu` as `mul` is folded into `conv`.

### Solution

In fact, any series of `mul`s and `add`s with constant inputs can be folded into a single `mul-add`. Three rewrite rules are added to make this happen:

1. `mul-mul` -> `mul`
2. `add-add` -> `add`
3. `add-mul` -> `mul-add`

Because `SimplifyExpr` applies simplifications iteratively until the graph stops changing, any series of `mul`s and `add`s is rewritten to a single `mul`, `add` or `mul-add`, in which one input of each binary op evaluates to a constant in the subsequent `FoldConstant` pass.
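
The three rules can be illustrated with a minimal sketch in plain Python. This is not the Relay implementation in this PR: it assumes scalar constants instead of broadcast tensors, and the names `simplify_chain` and `evaluate` are hypothetical helpers introduced only for illustration. A chain is modeled as a list of `("mul", c)` / `("add", c)` steps applied to an input `x` in order, and the rules are applied until nothing changes, mirroring the iterate-to-fixpoint behaviour described above.

```python
def simplify_chain(ops):
    """Rewrite a chain of ("mul", c) / ("add", c) steps until no rule applies."""
    ops = list(ops)
    changed = True
    while changed:
        changed = False
        for i in range(len(ops) - 1):
            (k1, c1), (k2, c2) = ops[i], ops[i + 1]
            if k1 == "mul" and k2 == "mul":
                # mul-mul -> mul: (x * c1) * c2 == x * (c1 * c2)
                new = [("mul", c1 * c2)]
            elif k1 == "add" and k2 == "add":
                # add-add -> add: (x + c1) + c2 == x + (c1 + c2)
                new = [("add", c1 + c2)]
            elif k1 == "add" and k2 == "mul":
                # add-mul -> mul-add: (x + c1) * c2 == x * c2 + (c1 * c2)
                new = [("mul", c2), ("add", c1 * c2)]
            else:
                # mul followed by add is already in the target form
                continue
            ops[i:i + 2] = new
            changed = True
            break  # restart the scan on the rewritten chain
    return ops


def evaluate(ops, x):
    """Apply the chain to a scalar input, step by step."""
    for kind, c in ops:
        x = x * c if kind == "mul" else x + c
    return x


# mul-add-mul-add, the shape left after bn is expanded to mul-add
chain = [("mul", 2.0), ("add", 3.0), ("mul", 4.0), ("add", 5.0)]
folded = simplify_chain(chain)
print(folded)  # [('mul', 8.0), ('add', 17.0)]
```

In this toy model the products and sums of constants are computed eagerly; in the PR's description that step is left to the `FoldConstant` pass, which evaluates the constant subexpressions produced by the rewrites.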
