psrivas2 commented on code in PR #14257: URL: https://github.com/apache/tvm/pull/14257#discussion_r1137041891
########## tests/python/relax/test_transform_convert_layout.py: ########## @@ -0,0 +1,1211 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.relax.transform import ConvertLayout, Normalize +from tvm.script.parser import ir as I, relax as R, tir as T + + +def verify(input, expected): + mod = ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(input) + mod = Normalize()(mod) + print(mod.script()) + tvm.ir.assert_structural_equal(mod, expected) + + +def test_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv2: R.Tensor((2, 26, 26, 4), 
dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv) + return gv + + verify(Input, Expected) + + +def test_conv2d_relu(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_relu_conv2d_relu(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 3, 28, 28), "float32") = R.nn.relu(x) + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w, out_dtype="float32") + gv2: 
R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.relu(x) + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims( + x0, axes=[0, 2, 3, 1] + ) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_relu_tanh(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.tanh(gv2) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + 
dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.tanh(gv2) + gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv3) + return gv3 + + verify(Input, Expected) + + +def test_conv2d_add(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.permute_dims( + bias, axes=[0, 2, 3, 1] + ) + lv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv3, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_add_relu_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), "float32"), + 
w: R.Tensor((4, 4, 3, 3), "float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2) + gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32") + R.output(gv4) + return gv4 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), dtype="float32"), + w: R.Tensor((4, 4, 3, 3), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.permute_dims( + bias, axes=[0, 2, 3, 1] + ) + gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.add(gv, lv2) + gv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv2) + lv3: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv4: R.Tensor((2, 24, 24, 4), dtype="float32") = R.nn.conv2d( + gv3, + lv3, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.permute_dims( + lv4, axes=[0, 3, 1, 2] + ) + R.output(gv4) + return gv4 + + verify(Input, Expected) + + +def test_conv2d_fma_relu_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), 
"float32"), + w: R.Tensor((4, 4, 3, 3), "float32"), + scale: R.Tensor((2, 4, 26, 26), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.ewise_fma(gv, scale, bias) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2) + gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32") + R.output(gv4) + return gv4 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), dtype="float32"), + w: R.Tensor((4, 4, 3, 3), dtype="float32"), + scale: R.Tensor((2, 4, 26, 26), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + gv, axes=[0, 3, 1, 2] + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.ewise_fma(lv2, scale, bias) + gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.relu(gv2) + lv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.permute_dims( + gv3, axes=[0, 2, 3, 1] + ) + lv4: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv5: R.Tensor((2, 24, 24, 4), dtype="float32") = R.nn.conv2d( + lv3, + lv4, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv4: R.Tensor((2, 4, 24, 24), 
dtype="float32") = R.permute_dims( + lv5, axes=[0, 3, 1, 2] + ) + R.output(gv4) + return gv4 + + verify(Input, Expected) + + +def test_conv2d_sum(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4), "float32") = R.sum(gv, axis=[2, 3]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=2): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 4), dtype="float32") = R.sum(gv, axis=[1, 2], keepdims=False) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_sum_keepdim(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 1, 1), "float32") = R.sum(gv, axis=[2, 3], keepdims=True) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = 
R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 1, 4), dtype="float32") = R.sum(gv, axis=[1, 2], keepdims=True) + gv2: R.Tensor((2, 4, 1, 1), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_transpose(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv, axes=[3, 2, 1, 0]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((26, 26, 4, 2), dtype="float32") = R.permute_dims( + gv, axes=[2, 1, 3, 0] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_expand_dims(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> 
R.Tensor(None, "float32", ndim=6): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 1, 4, 1, 26, 26), "float32") = R.expand_dims(gv, axis=(-3, 1)) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=6): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 26, 1, 26, 4), dtype="float32") = R.expand_dims( + gv, axis=[-3, 1] + ) + gv2: R.Tensor((2, 1, 4, 1, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 1, 5, 3, 2, 4] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_expand_dims_squeeze(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 1, 4, 1, 26, 26), "float32") = R.expand_dims(gv, axis=(-3, 1)) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.squeeze(gv2, axis=[1, 3]) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), 
dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 1, 26, 1, 26, 4), dtype="float32") = R.expand_dims( + gv, axis=[-3, 1] + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.squeeze(gv2, axis=[1, 3]) + gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv3) + return gv3 + + verify(Input, Expected) + + +def test_conv2d_strided_slice(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.strided_slice( + gv, begin=[0, 0, 0], end=[4, 26, 26], strides=[2, 3, 4], axes=[1, 2, 3] + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 9, 7, 2), dtype="float32") = R.strided_slice( + gv, axes=[3, 1, 2], begin=[0, 0, 0], end=[4, 26, 26], strides=[2, 3, 4] + ) + gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + 
R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_relu_concat(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 26, 26, 8), dtype="float32") = R.concat((gv, gv2), axis=3) + gv3: R.Tensor((2, 8, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv3) + return gv3 + + verify(Input, Expected) + + +def test_conv2d_relu_concat_split(): + @I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + gv4 = R.split(gv3, indices_or_sections=2, axis=1) + R.output(gv4) + return gv4 + + @I.ir_module + class Expected: + 
@R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 26, 26, 8), dtype="float32") = R.concat((gv, gv2), axis=3) + gv4: R.Tuple( + R.Tensor((2, 26, 26, 4), dtype="float32"), + R.Tensor((2, 26, 26, 4), dtype="float32"), + ) = R.split(gv3, indices_or_sections=2, axis=3) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = gv4[0] + lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + lv4: R.Tensor((2, 26, 26, 4), dtype="float32") = gv4[1] + lv5: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv4, axes=[0, 3, 1, 2] + ) + gv5 = (lv3, lv5) + R.output(gv5) + return gv5 + + verify(Input, Expected) + + +def test_conv2d_maxpool2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.max_pool2d( + gv, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0], + layout="NCHW", + out_layout="NCHW", + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = 
R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 13, 13, 4), dtype="float32") = R.nn.max_pool2d( + gv, + pool_size=[2, 2], + strides=[2, 2], + dilation=[1, 1], + padding=[0, 0, 0, 0], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + gv2: R.Tensor((2, 4, 13, 13), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_avgpool2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.adaptive_avg_pool2d(gv, output_size=[13, 13], layout="NCHW") + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 13, 13, 4), dtype="float32") = R.nn.adaptive_avg_pool2d( + gv, output_size=[13, 13], layout="NHWC", out_layout="NHWC" + ) + gv2: R.Tensor((2, 4, 13, 13), dtype="float32") = R.permute_dims( 
+ lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_softmax(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.softmax(gv, axis=1) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.softmax(gv, axis=3) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_batchnorm(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + gamma: R.Tensor((4,), dtype="float32"), + beta: R.Tensor((4,), dtype="float32"), + moving_mean: R.Tensor((4,), dtype="float32"), + moving_var: R.Tensor((4,), dtype="float32"), + ): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ) = R.nn.batch_norm(gv, gamma, beta, moving_mean, moving_var, axis=1) + 
R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + gamma: R.Tensor((4,), dtype="float32"), + beta: R.Tensor((4,), dtype="float32"), + moving_mean: R.Tensor((4,), dtype="float32"), + moving_var: R.Tensor((4,), dtype="float32"), + ): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tuple( + R.Tensor((2, 26, 26, 4), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ) = R.nn.batch_norm( + gv, + gamma, + beta, + moving_mean, + moving_var, + axis=3, + epsilon=1.0000000000000001e-05, + center=True, + scale=True, + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = gv2[0] + lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + lv4: R.Tensor((4,), dtype="float32") = gv2[1] + lv5: R.Tensor((4,), dtype="float32") = gv2[2] + gv3 = (lv3, lv4, lv5) + R.output(gv3) + return gv3 + + verify(Input, Expected) + + +def test_conv2d_layernorm(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + gamma: R.Tensor((26, 26), dtype="float32"), + beta: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.layer_norm( + gv, gamma, beta, axes=[-2, -1] + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + 
@R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + gamma: R.Tensor((26, 26), dtype="float32"), + beta: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.layer_norm( + gv, + gamma, + beta, + axes=[1, 2], + epsilon=1.0000000000000001e-05, + center=True, + scale=True, + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_resize2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.image.resize2d(gv, (52, 52), layout="NCHW") + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + 
kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 52, 52, 4), dtype="float32") = R.image.resize2d( + gv, + (52, 52), + roi=[T.float32(0), T.float32(0), T.float32(0), T.float32(0)], + layout="NHWC", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="round", + cubic_alpha=-0.5, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="void", + ) + gv2: R.Tensor((2, 4, 52, 52), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_unknown_dim(): Review Comment: should we rename this test to `test_conv2d_unknown_bias_dim`? ########## tests/python/relax/test_transform_convert_layout.py: ########## @@ -0,0 +1,1211 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +import tvm.testing +from tvm.relax.transform import ConvertLayout, Normalize +from tvm.script.parser import ir as I, relax as R, tir as T + + +def verify(input, expected): + mod = ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(input) + mod = Normalize()(mod) + print(mod.script()) + tvm.ir.assert_structural_equal(mod, expected) + + +def test_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv) + return gv + + verify(Input, Expected) + + +def test_conv2d_relu(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 
3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_relu_conv2d_relu(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 3, 28, 28), "float32") = R.nn.relu(x) + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.relu(x) + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims( Review Comment: Does this pass only attempt to flow the new layouts downstream from producers to consumers? It would be nice to document it somewhere. 
########## python/tvm/script/parser/core/evaluator.py: ########## @@ -325,7 +325,6 @@ def _eval_slice(self, fields: Dict[str, Any]) -> slice: lower = self._eval_expr(lower) if lower is not None else None upper = self._eval_expr(upper) if upper is not None else None step = self._eval_expr(step) if step is not None else None Review Comment: left over change? ########## tests/python/relax/test_transform_convert_layout.py: ########## @@ -0,0 +1,1211 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.relax.transform import ConvertLayout, Normalize +from tvm.script.parser import ir as I, relax as R, tir as T + + +def verify(input, expected): + mod = ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(input) + mod = Normalize()(mod) + print(mod.script()) + tvm.ir.assert_structural_equal(mod, expected) + + Review Comment: Can you add a test with R.match_cast between let's say conv2d and bias_add? ########## tests/python/relax/test_transform_convert_layout.py: ########## @@ -0,0 +1,1211 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.relax.transform import ConvertLayout, Normalize +from tvm.script.parser import ir as I, relax as R, tir as T + + +def verify(input, expected): + mod = ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(input) + mod = Normalize()(mod) + print(mod.script()) + tvm.ir.assert_structural_equal(mod, expected) + Review Comment: Can you add a few tests with symbolic shapes? ########## tests/python/relax/test_transform_convert_layout.py: ########## @@ -0,0 +1,1211 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +import tvm.testing +from tvm.relax.transform import ConvertLayout, Normalize +from tvm.script.parser import ir as I, relax as R, tir as T + + +def verify(input, expected): + mod = ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(input) + mod = Normalize()(mod) + print(mod.script()) + tvm.ir.assert_structural_equal(mod, expected) + + +def test_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv) + return gv + + verify(Input, Expected) + + +def test_conv2d_relu(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 
3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_relu_conv2d_relu(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 3, 28, 28), "float32") = R.nn.relu(x) + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.relu(x) + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims( + x0, axes=[0, 2, 3, 1] + ) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), 
dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_relu_tanh(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.tanh(gv2) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.tanh(gv2) + gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv3) + return gv3 + + verify(Input, Expected) + + +def test_conv2d_add(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + 
def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), + w: R.Tensor((4, 3, 3, 3), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.permute_dims( + bias, axes=[0, 2, 3, 1] + ) + lv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + lv3, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_add_relu_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), "float32"), + w: R.Tensor((4, 4, 3, 3), "float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2) + gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32") + R.output(gv4) + return gv4 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), dtype="float32"), + w: R.Tensor((4, 4, 3, 3), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, 
axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.permute_dims( + bias, axes=[0, 2, 3, 1] + ) + gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.add(gv, lv2) + gv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv2) + lv3: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv4: R.Tensor((2, 24, 24, 4), dtype="float32") = R.nn.conv2d( + gv3, + lv3, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.permute_dims( + lv4, axes=[0, 3, 1, 2] + ) + R.output(gv4) + return gv4 + + verify(Input, Expected) + + +def test_conv2d_fma_relu_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), "float32"), + w: R.Tensor((4, 4, 3, 3), "float32"), + scale: R.Tensor((2, 4, 26, 26), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.ewise_fma(gv, scale, bias) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2) + gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32") + R.output(gv4) + return gv4 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), dtype="float32"), + w: R.Tensor((4, 4, 3, 3), dtype="float32"), + scale: R.Tensor((2, 4, 26, 26), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 
28, 4), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( + gv, axes=[0, 3, 1, 2] + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.ewise_fma(lv2, scale, bias) Review Comment: Can you add a comment here or in InferLayoutEwiseFMA explaining why modifying `R.ewise_fma` and its operands to be compliant with the new layout is not feasible? ########## include/tvm/relax/transform.h: ########## @@ -279,6 +279,14 @@ TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String, ObjectRef>>> target_opt * \return The Pass. */ TVM_DLL Pass SimplifyNormInference(); + +/*! + * \brief Automatic layout conversion pass. Review Comment: Should we remove the word "automatic" from the description and simply call it the Convert Layout pass or Layout conversion pass? "Automatic" does not convey much here IMO. ########## src/relax/transform/convert_layout.cc: ########## @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/convert_layout.cc + * \brief Automatic layout conversion pass, especially for axis swapping. + */ + +#include <tvm/relax/expr_functor.h> +#include <tvm/relax/nested_msg.h> +#include <tvm/relax/op_attr_types.h> +#include <tvm/relax/transform.h> + +#include "../op/tensor/manipulate.h" +#include "infer_layout_utils.h" +#include "utils.h" + +namespace tvm { +namespace relax { + +using tir::Layout; + +/*! + * \brief Main logic to automatically convert the layout of conv2d. Other ops + * can adapt to such layout conversion following conv2d accordingly. + * + * Structurally speaking, a Relax function is composed of a series of VarBinding and + * MatchCast. And a specific class of VarBindings is the basic unit we want to rewrite. + * Formally, they are of the form: + * + * var = Call(Op, [args], attrs) + * + * where Op is a specific op we want to rewrite, and attrs is the attributes of the op. + * var and args are all exprs with type Tensor or Tuple of Tensors. They might + * be vars, constants, or Tuple of vars and constants. We register the layout inference function for + * each op (FRelaxInferLayout), which accepts the current call, the desired layout of conv2d ops, + * and the layout map of previous vars. The result of the layout inference function is contained in + * an InferLayoutOutput object, which contains 3 fields: input_layouts, output_layouts, and attr, + * which represents the expected input layout, output_layout and converted attrs of the new op call. + * The rewriter will use these info to convert the layout of inputs and attrs of the op call, and + * note down the new layout of the output. + * + * The desired layout of conv2d ops is a map from the name of the op to the desired layout of the + * desired feature map, weight and output. 
For example, if we want to convert the layout of conv2d + * from NCHW to NHWC, we can set the desired layout of conv2d to be {"conv2d": ["NHWC", "OHWI"]}. + * + * The way we represent the layout of a var is a NLayout object, which is a nested tuple of Layout. + * The incoming layout of the module will be set as the default layout (We use ABCD... as the + * default) Note that for operators like conv, pool, people typically use NHWC to refer to the axes. + * But to be generic and support more operators, we use ABCD... to refer to the axes. + * + * Note that currently the layout conversion of conv2d only support axis swapping, such as NCHW to + * NWHC. Packed layout such as NCHW to NCHW4c is not supported now. + */ +class LayoutConvertMutator : public ExprMutator { + public: + explicit LayoutConvertMutator(const Map<String, Array<String>>& desired_layouts) + : desired_layouts_(desired_layouts) {} + + void InitVarMap(const Function& func) { + for (const auto& param : func->params) { + if (IsNestedTensor(param)) { + var_layout_map_[param] = InitialNLayout(param); + } + } + } + + private: + Array<Integer> LayoutToIntegers(const Layout& layout) { + Array<Integer> ret; + LayoutDecision src = InitialLayoutDecision(layout.ndim()); + for (size_t i = 0; i < layout.ndim(); ++i) { + ret.push_back(Integer(src->layout.IndexOf(layout[i]))); + } + return ret; + } + + Expr RewriteExpr(const Expr& expr, const NLayout& to) { + auto fvisitleaf = [&](const Expr& expr, std::array<NLayout, 2> layouts) -> Expr { + NLayout from = layouts[0], to = layouts[1]; + if (NLayoutEqual()(from, to)) return expr; + // If not both from and to are dynamic, then none of them can be dynamic. 
+ ICHECK(!NLayoutEqual()(from, LayoutDecision::InitUnknownDim()) && + !NLayoutEqual()(to, LayoutDecision::InitUnknownDim())) + << "Cannot convert when exactly one of the layouts is dynamic"; + const auto* tensor = GetStructInfoAs<TensorStructInfoNode>(expr); + ICHECK(tensor != nullptr) << "Expect a tensor, but got: " << expr; + Layout axes = TransposeLike(InitialLayoutDecision(tensor->ndim)->layout, + from.LeafValue()->layout, to.LeafValue()->layout); + return permute_dims(expr, LayoutToIntegers(axes)); + }; + return TransformTupleLeaf<LayoutDecision>( + VarReplacer::Replace(expr, var_remap_), + std::array<NLayout, 2>({GetNLayout(var_layout_map_, expr), to}), fvisitleaf); + } + + Array<Expr> RewriteArgs(const Array<Expr>& args, const Array<NLayout>& to) { + ICHECK(args.size() == to.size()); + std::vector<Expr> new_args; + for (size_t i = 0; i < args.size(); ++i) { + if (IsNestedTensor(args[i])) { + new_args.push_back(RewriteExpr(args[i], to[i])); + } else { + new_args.push_back(args[i]); + } + } + return std::move(new_args); + } + + void VisitBinding(const Binding& binding) final { + // Emit the binding + ExprMutator::VisitBinding(binding); + if (!builder_->CurrentBlockIsDataFlow()) { + return; + } + // The layout is default to be initial if not rewritten. + if (IsNestedTensor(binding->var)) { + if (var_layout_map_.find(binding->var) == var_layout_map_.end()) { + var_layout_map_[binding->var] = InitialNLayout(binding->var); + } + } + } + + Expr VisitVars_(const Var& var) { + if (IsNestedTensor(var)) { + // We encounter a var use outside of Call, we rewrite it to initial layout. 
+ return RewriteExpr(var, InitialNLayout(var)); + } + return ExprMutator::VisitExpr_(var.get()); + } + + Expr VisitExpr_(const VarNode* op) final { + if (!builder_->CurrentBlockIsDataFlow()) { + return ExprMutator::VisitExpr_(op); + } + return VisitVars_(GetRef<Var>(op)); + } + + Expr VisitExpr_(const DataflowVarNode* op) final { + if (!builder_->CurrentBlockIsDataFlow()) { + return ExprMutator::VisitExpr_(op); + } + return VisitVars_(GetRef<Var>(op)); + } + + bool HasUnknownDimTensor(const NLayout& nlayout) { + bool find = false; + auto fvisit = [&](const LayoutDecision& layout) { + find = find | (NLayoutEqual()(layout, LayoutDecision::InitUnknownDim())); + }; + ForEachLeaf<LayoutDecision>(nlayout, fvisit); + return find; + } + + bool HasUnknownDimTensor(const Array<Expr>& args) { + for (const auto& arg : args) { + if (IsNestedTensor(arg)) { + if (HasUnknownDimTensor(GetNLayout(var_layout_map_, arg))) { + return true; + } + } + } + return false; + } + + Optional<InferLayoutOutput> GetInferLayoutInfo(const CallNode* call_node, + const Map<String, Array<String>>& desired_layouts, + const VarLayoutMap& var_layout_map) { + const OpNode* op_node = call_node->op.as<OpNode>(); + if (op_node == nullptr) return NullOpt; + Op op = Downcast<Op>(GetRef<Op>(op_node)); + const auto attr_map = Op::GetAttrMap<FRelaxInferLayout>("FRelaxInferLayout"); + if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) { + // If the op has FRelaxInferLayout, and all the input tensors have known ndim + FRelaxInferLayout f = attr_map[op]; + return f(GetRef<Call>(call_node), desired_layouts, var_layout_map); + } else { + // Otherwise, we use the default policy. 
+ return NullOpt; + } + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { + if (!builder_->CurrentBlockIsDataFlow()) { + ExprMutator::VisitBinding_(binding, call_node); + return; + } + Optional<InferLayoutOutput> res = + GetInferLayoutInfo(call_node, desired_layouts_, var_layout_map_); + ObjectPtr<CallNode> new_call = make_object<CallNode>(*call_node); + new_call->struct_info_ = NullOpt; + if (!res.defined() || + (!IsNestedTensor(binding->var) && !binding->var->IsInstance<DataflowVarNode>())) { + // Default policy: use the initial layout. + // When we don't have the infer layout info, or it's a non-tensor global var binding. + std::vector<NLayout> input_layout; + for (const auto& arg : call_node->args) { + if (IsNestedTensor(arg)) { + input_layout.push_back(InitialNLayout(arg)); + } else { + input_layout.push_back(LayoutDecision::InitUnknownDim()); + } + } + Array<Expr> new_args = RewriteArgs(call_node->args, std::move(input_layout)); + new_call->args = std::move(new_args); + ReEmitBinding(binding, builder_->Normalize(Call(new_call))); + // update the layout map + if (IsNestedTensor(binding->var)) { + var_layout_map_[binding->var] = InitialNLayout(binding->var); + } + } else { + // Convert the layout according to the inferred layout output. + Array<Expr> new_args = RewriteArgs(call_node->args, res.value()->input_layouts); + new_call->args = std::move(new_args); + new_call->attrs = std::move(res.value()->new_attrs); + Expr cur_call = builder_->Normalize(Call(new_call)); + if (binding->var->IsInstance<DataflowVarNode>()) { + // Dataflow var, we emit the rewritten call. 
+ ReEmitBinding(binding, cur_call); + // update the layout map + if (IsNestedTensor(binding->var)) { + var_layout_map_[binding->var] = res.value()->output_layouts[0]; + } + } else { + // Global var (tensor), we rewrite it to initial layout + ICHECK(IsNestedTensor(binding->var)); + if (!NLayoutEqual()(res.value()->output_layouts[0], InitialNLayout(binding->var))) { + Var new_var = builder_->Emit(cur_call); + var_layout_map_[new_var] = res.value()->output_layouts[0]; + cur_call = builder_->Normalize(RewriteExpr(new_var, InitialNLayout(binding->var))); + } + ReEmitBinding(binding, cur_call); + // update the layout map + var_layout_map_[binding->var] = InitialNLayout(binding->var); + } + } + } + + void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) final { + if (!builder_->CurrentBlockIsDataFlow()) { + ExprMutator::VisitBinding_(binding, val); + return; + } + std::vector<NLayout> input_layout; + for (const auto& field : val->fields) { + if (IsNestedTensor(field)) { + if (binding->var->IsInstance<DataflowVarNode>()) { + // Df var: Use the current realized layout to group the tuple; + input_layout.push_back(GetNLayout(var_layout_map_, field)); + } else { + // Global var: Use the initial layout to group the tuple; + input_layout.push_back(InitialNLayout(field)); + } + } else { + input_layout.push_back(LayoutDecision::InitUnknownDim()); + } + } + Array<Expr> new_fields = RewriteArgs(val->fields, std::move(input_layout)); + if (IsNestedTensor(binding->var)) { + ReEmitBinding(binding, builder_->Normalize(Tuple(new_fields))); + var_layout_map_[binding->var] = input_layout; + } + } + + void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final { + if (!builder_->CurrentBlockIsDataFlow()) { + ExprMutator::VisitBinding_(binding, val); + return; + } + if (IsNestedTensor(val->tuple)) { + // Use the current realized layout to retrieve the field; + NLayout input_layout = binding->var->IsInstance<DataflowVarNode>() + ? 
GetNLayout(var_layout_map_, val->tuple) + : InitialNLayout(val->tuple); + ReEmitBinding(binding, builder_->Normalize( + TupleGetItem(RewriteExpr(val->tuple, input_layout), val->index))); + // update the layout map + var_layout_map_[binding->var] = input_layout.NestedArray()[val->index]; + } else { + ExprMutator::VisitBinding_(binding, val); + } + } + + std::unordered_map<Var, NLayout, ObjectPtrHash, ObjectPtrEqual> var_layout_map_; + Map<String, Array<String>> desired_layouts_; +}; + +Expr ConvertLayoutPass(const Function& f, Map<String, Array<String>> desired_layouts) { + LayoutConvertMutator mutator(desired_layouts); + mutator.InitVarMap(f); + return mutator.VisitExpr(f); +} + +namespace transform { + +Pass ConvertLayout(Map<String, Array<String>> desired_layouts) { + runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast<Function>(ConvertLayoutPass(f, desired_layouts)); + }; + return CreateFunctionPass(pass_func, 0, "ConvertLayout", {}); Review Comment: Should this be a DataflowBlock Pass instead of a Function pass? Then we can avoid the dataflow block guard everywhere ``` if (!builder_->CurrentBlockIsDataFlow()) { ExprMutator::Visit...; return; } ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
