This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 51d4b6bcde [Relax] Batch norm correctness on eval mode (#17752)
51d4b6bcde is described below

commit 51d4b6bcdeef9094e721e9154ae096370a05a465
Author: Hugo Latendresse <[email protected]>
AuthorDate: Tue Mar 25 20:57:11 2025 -0400

    [Relax] Batch norm correctness on eval mode (#17752)
    
    Batch_norm is a different operator in training and eval.
    The previous interface defaulted to the training mode and required
    changing an ingested pytorch program itself to use the eval mode.
    This is suboptimal, especially since torch.export explicitly communicates
    whether batch_norm should be in training or eval in a given torch program.
    
    This PR automates the selection of training/eval mode in the exported
    program translator, and achieves correctness for eval mode.
    
    Future TODO: there is something wrong with batch_norm in training
    mode. It does not pass a correctness test when taken straight from
    the main branch (there's an issue with tensor dimensions).
    I added a note to address later as training mode is probably not high 
priority.
---
 include/tvm/relax/attrs/nn.h                       |   2 +
 .../frontend/torch/base_fx_graph_translator.py     |   2 +-
 .../frontend/torch/exported_program_translator.py  |  43 +-
 python/tvm/relax/op/nn/nn.py                       |   8 +-
 python/tvm/relax/transform/legalize_ops/nn.py      |   4 +-
 python/tvm/topi/nn/batch_norm.py                   |  26 +-
 src/relax/op/nn/nn.cc                              |   8 +-
 src/relax/op/nn/nn.h                               |   2 +-
 tests/python/relax/test_from_exported_to_cuda.py   |  59 +-
 .../python/relax/test_transform_legalize_ops_nn.py | 980 ++++++++++++---------
 10 files changed, 683 insertions(+), 451 deletions(-)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 8329344174..8f63012e09 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -462,6 +462,7 @@ struct BatchNormAttrs : public 
tvm::AttrsNode<BatchNormAttrs> {
   bool center;
   bool scale;
   double momentum;
+  bool training;
 
   TVM_DECLARE_ATTRS(BatchNormAttrs, "relax.attrs.BatchNormAttrs") {
     TVM_ATTR_FIELD(axis).describe("The axis along which the normalization is 
applied.");
@@ -470,6 +471,7 @@ struct BatchNormAttrs : public 
tvm::AttrsNode<BatchNormAttrs> {
         "Indicating if the beta offset will be added to the normalized 
tensor.");
     TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be 
multiplied.");
     TVM_ATTR_FIELD(momentum).describe("The value used for the moving_mean and 
moving_var update.");
+    TVM_ATTR_FIELD(training).describe("Whether we are training (i.e., not in 
eval mode).");
   }
 };  // struct BatchNormAttrs
 
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 4bfdb8c1bc..ecd8665b43 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1106,7 +1106,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         return self.env[node.args[0]]
 
     def _copy_(self, node: fx.Node) -> relax.Var:
-        # Copies the source tensor's to the destination tensor
+        # Copies the source tensor's into the destination tensor
         # In TVM, that means simply returning the source tensor
         return self.env[node.args[1]]
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index bf01bd6531..84821d27b5 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -45,7 +45,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
     ########## Neural Network ##########
 
-    def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
+    def _batch_norm(self, node: fx.Node, training) -> relax.Var:
         import numpy as np
 
         x = self.env[node.args[0]]
@@ -55,22 +55,43 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         bias = self.env.get(node.args[2], relax.const(np.zeros(channel), 
dtype=dtype))
         running_mean = self.env.get(node.args[3], 
relax.const(np.zeros(channel), dtype=dtype))
         running_var = self.env.get(node.args[4], relax.const(np.ones(channel), 
dtype=dtype))
-        momentum = node.args[5] if len(node.args) > 5 else 
node.kwargs.get("momentum", 0.1)
-        eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 
1e-05)
+        ignore_running_stats = (
+            node.args[5] if len(node.args) > 5 else 
node.kwargs.get("track_running_stats", True)
+        )
+        track_running_stats = not ignore_running_stats
+        momentum = node.args[6] if len(node.args) > 6 else 
node.kwargs.get("momentum", 0.1)
+        eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 
1e-05)
+
+        if track_running_stats:
+            training = True
 
         return self.block_builder.emit(
             relax.op.nn.batch_norm(
-                x,
-                weight,
-                bias,
-                running_mean,
-                running_var,
-                axis=1,
+                data=x,
+                gamma=weight,
+                beta=bias,
+                moving_mean=running_mean,
+                moving_var=running_var,
+                axis=1,  # Always over channel
                 epsilon=eps,
                 momentum=momentum,
-            )
+                training=training,
+            )[0]
         )
 
+    def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var:
+        # This method is called for batch_norm in training mode
+        # TODO does not have correctness!
+        # TODO we need to store the running mean and variance returned by the
+        # previous call to batch_norm and pass it again
+        training = True
+        return self._batch_norm(node, training)
+
+    def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
+        # This method is called for batch_norm in eval mode
+        training = False
+        return self._batch_norm(node, training)
+
     def _group_norm(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         num_groups = node.args[1]
@@ -283,7 +304,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             # linear algebra
             "linalg_vector_norm.default": self._linalg_vector_norm,
             # neural network
+            "_native_batch_norm_legit_functional.default": 
self._batch_norm_legit_functional,
             "_native_batch_norm_legit_no_training.default": 
self._batch_norm_legit_no_training,
+            "batch_norm.default": self._batch_norm_legit_no_training,
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
             "addmm.default": self._addmm,
             "avg_pool2d.default": self._avg_pool2d,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 5a1895cbc1..09a7df5149 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1393,6 +1393,7 @@ def batch_norm(
     center: bool = True,
     scale: bool = True,
     momentum: float = 0.1,
+    training: bool = True,
 ) -> Expr:
     r"""
     Batch normalization layer (Ioffe and Szegedy, 2014).
@@ -1481,13 +1482,18 @@ def batch_norm(
     momentum : float
         The value used for the moving_mean and moving_var update.
 
+    training : bool
+        A boolean value indicating whether we are in training or eval mode. By
+          default, relax batch_norm is in training mode. To transform it to
+          inference mode, use DecomposeOpsForInference.
+
     Returns
     -------
     result : relax.Expr
         The computed result.
     """
     return _ffi_api.batch_norm(  # type: ignore
-        data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, 
scale, momentum
+        data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, 
scale, momentum, training
     )
 
 
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py 
b/python/tvm/relax/transform/legalize_ops/nn.py
index d9fb4701f7..4c8bdbc661 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -551,9 +551,7 @@ def _nn_batch_norm(bb: BlockBuilder, call: Call) -> Expr:
         epsilon=call.attrs.epsilon,
         center=call.attrs.center,
         scale=call.attrs.scale,
-        # By default relax batch_norm is training mode.
-        # To transform it to inference mode, use DecomposeOpsForInference.
-        training=True,
+        training=call.attrs.training,
         momentum=call.attrs.momentum,
     )
 
diff --git a/python/tvm/topi/nn/batch_norm.py b/python/tvm/topi/nn/batch_norm.py
index 3181efd7da..8308c93eae 100644
--- a/python/tvm/topi/nn/batch_norm.py
+++ b/python/tvm/topi/nn/batch_norm.py
@@ -111,22 +111,26 @@ def batch_norm(
     shape = [1] * len(data.shape)
     shape[axis] = data.shape[axis]
 
+    reduce_axes = list(range(len(data.shape)))
+    reduce_axes.remove(axis)
+    shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in 
reduce_axes], 1)
+
+    data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
+    data_mean_rs = topi.reshape(data_mean, shape)
+    data_var = (
+        topi.sum((data - data_mean_rs) * (data - data_mean_rs), 
axis=reduce_axes) / shape_prod
+    )
+    data_var_rs = topi.reshape(data_var, shape)
+
     if training:
-        reduce_axes = list(range(len(data.shape)))
-        reduce_axes.remove(axis)
-        shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in 
reduce_axes], 1)
-        data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
-        data_mean_rs = topi.reshape(data_mean, shape)
-        data_var = (
-            topi.sum((data - data_mean_rs) * (data - data_mean_rs), 
axis=reduce_axes) / shape_prod
-        )
-        data_var_rs = topi.reshape(data_var, shape)
-        out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)
-    else:
         moving_mean_rs = topi.reshape(moving_mean, shape)
         moving_var_rs = topi.reshape(moving_var, shape)
+
         out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)
 
+    else:
+        out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)
+
     if scale:
         out = out * topi.reshape(gamma, shape)
     if center:
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index b4668d65d3..826711538c 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -252,13 +252,14 @@ bool NormCheckDtypeAndShape(const Call& call, const 
BlockBuilder& ctx,
 TVM_REGISTER_NODE_TYPE(BatchNormAttrs);
 
 Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr 
moving_var,  //
-                int axis, double epsilon, bool center, bool scale, double 
momentum) {
+                int axis, double epsilon, bool center, bool scale, double 
momentum, bool training) {
   ObjectPtr<BatchNormAttrs> attrs = make_object<BatchNormAttrs>();
   attrs->axis = axis;
   attrs->epsilon = epsilon;
   attrs->center = center;
   attrs->scale = scale;
   attrs->momentum = momentum;
+  attrs->training = training;
 
   static const Op& op = Op::Get("relax.nn.batch_norm");
   return Call(op,
@@ -266,7 +267,6 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr 
moving_mean, Expr moving_
                std::move(moving_var)},
               Attrs{attrs}, {});
 }
-
 TVM_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm);
 
 StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) 
{
@@ -388,7 +388,7 @@ InferLayoutOutput InferLayoutLayerNorm(const Call& call,
 TVM_REGISTER_OP("relax.nn.layer_norm")
     .set_attrs_type<LayerNormAttrs>()
     .set_num_inputs(3)
-    .add_argument("data", "Tensor", "Input to which batch_norm will be 
applied.")
+    .add_argument("data", "Tensor", "Input to which layer_norm will be 
applied.")
     .add_argument("gamma", "Tensor", "The gamma scale factor.")
     .add_argument("beta", "Tensor", "The beta offset factor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm)
@@ -500,7 +500,7 @@ InferLayoutOutput InferLayoutGroupNorm(const Call& call,
 TVM_REGISTER_OP("relax.nn.group_norm")
     .set_attrs_type<GroupNormAttrs>()
     .set_num_inputs(3)
-    .add_argument("data", "Tensor", "Input to which batch_norm will be 
applied.")
+    .add_argument("data", "Tensor", "Input to which group_norm will be 
applied.")
     .add_argument("gamma", "Tensor", "The gamma scale factor.")
     .add_argument("beta", "Tensor", "The beta offset factor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGroupNorm)
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index a3658fed54..28c14139b9 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -68,7 +68,7 @@ Expr log_softmax(Expr data, int axis);
 
 /*! \brief Compute batch normalization. */
 Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr 
moving_var,  //
-                int axis, double epsilon, bool center, bool scale, double 
momentum);
+                int axis, double epsilon, bool center, bool scale, double 
momentum, bool training);
 
 /*! \brief Compute layer normalization. */
 Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double 
epsilon, bool center,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py 
b/tests/python/relax/test_from_exported_to_cuda.py
index c120eb8981..f7501dd3b5 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -289,6 +289,25 @@ def test_linalg_vector_norm(target, dev):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, 
target, dev)
 
 
[email protected]_targets("cuda")
+def test_batch_norm_prog(target, dev):
+    # Default args, in a pytorch program (to ensure output is in proper type 
and format)
+    raw_data = np.random.randn(2, 3, 2, 2).astype(np.float32)
+
+    class BatchNormWrapper(nn.Module):
+        def __init__(self):
+            super(BatchNormWrapper, self).__init__()
+            self.bn = nn.BatchNorm2d(3)
+
+        def forward(self, x):
+            x = self.bn(x)
+            x = x + 1
+            return x
+
+    torch_module = BatchNormWrapper().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, 
target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_split_size(target, dev):
     # Test split using the split_size argument such that it is not a divisor
@@ -310,7 +329,46 @@ def test_split_size(target, dev):
             return torch.split(x, split_size_or_sections=self.split_size, 
dim=self.dim)
 
     torch_module = SplitModelSplitSize(split_size=split_size, dim=dim).eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, 
target, dev)
 
+
[email protected]_targets("cuda")
+def test_batch_norm0(target, dev):
+    # Eval, no momentum, no affine, no running stats
+    raw_data = np.random.randn(8, 3, 4, 4).astype(np.float32)
+    torch_module = nn.BatchNorm2d(
+        3, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, 
device=None, dtype=None
+    ).eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, 
target, dev)
+
+
[email protected]_targets("cuda")
+def test_batch_norm1(target, dev):
+    # Eval, with momentum, no affine, with running stats
+    raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32)
+    torch_module = nn.BatchNorm2d(
+        4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True, 
device=None, dtype=None
+    ).eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, 
target, dev)
+
+
[email protected]_targets("cuda")
+def test_batch_norm2(target, dev):
+    # Eval, with momentum, affine, no running stats
+    raw_data = np.random.randn(3, 4, 2, 2).astype(np.float32)
+    torch_module = nn.BatchNorm2d(
+        4, eps=1e-05, momentum=0.2, affine=True, track_running_stats=False
+    ).eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, 
target, dev)
+
+
[email protected]_targets("cuda")
+def test_batch_norm3(target, dev):
+    # Eval, no momentum, affine, with running stats
+    raw_data = np.random.randn(1, 2, 2, 2).astype(np.float32)
+    torch_module = nn.BatchNorm2d(
+        2, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True
+    ).eval()
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, 
target, dev)
 
 
@@ -335,7 +393,6 @@ def test_split_sections_list(target, dev):
             return torch.split(x, split_size_or_sections=self.split_size, 
dim=self.dim)
 
     torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval()
-
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, 
target, dev)
 
 
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py 
b/tests/python/relax/test_transform_legalize_ops_nn.py
index d83d0567e4..4ac4b57b91 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -1955,212 +1955,289 @@ def test_batch_norm():
     @tvm.script.ir_module
     class Expected:
         @T.prim_func(private=True)
-        def batch_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3),), 
"float32"), rxplaceholder_2: T.Buffer((T.int64(3),), "float32"), 
rxplaceholder_3: T.Buffer((T.int64(3),), "float32"), rxplaceholder_4: 
T.Buffer((T.int64(3),), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)), "float32"), T_add_1: T.Buffer((T.int64(3),), 
"float32"), T_add_2: T.Buffer((T.int64(3),), "float32")):
-            T.func_attr({"tir.noalias": True})
-            # with T.block("root"):
-            rxplaceholder_red = T.alloc_buffer((T.int64(3),))
-            T_divide = T.alloc_buffer((T.int64(3),))
-            T_reshape = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), 
T.int64(1)))
-            T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), 
T.int64(28)))
-            T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
-            T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
-            T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), 
T.int64(28)))
-            T_multiply_red = T.alloc_buffer((T.int64(3),))
-            T_divide_1 = T.alloc_buffer((T.int64(3),))
-            T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), 
T.int64(1)))
-            T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), 
T.int64(1)))
-            compute = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), 
T.int64(1)))
-            T_divide_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), 
T.int64(28)))
-            T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), 
T.int64(1)))
-            T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
-            T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), 
T.int64(1)))
-            T_multiply_2 = T.alloc_buffer((T.int64(3),))
-            T_multiply_3 = T.alloc_buffer((T.int64(3),))
-            T_multiply_4 = T.alloc_buffer((T.int64(3),))
-            T_subtract_3 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
-            T_subtract_4 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
-            T_multiply_5 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
-            T_multiply_red_1 = T.alloc_buffer((T.int64(3),))
-            T_divide_3 = T.alloc_buffer((T.int64(3),))
-            T_multiply_6 = T.alloc_buffer((T.int64(3),))
-            for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28), 
T.int64(28)):
-                with T.block("rxplaceholder_red"):
-                    v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, 
k2, k3])
-                    T.reads(rxplaceholder[v_k0, v_ax0, v_k2, v_k3])
-                    T.writes(rxplaceholder_red[v_ax0])
-                    with T.init():
-                        rxplaceholder_red[v_ax0] = T.float32(0)
-                    rxplaceholder_red[v_ax0] = rxplaceholder_red[v_ax0] + 
rxplaceholder[v_k0, v_ax0, v_k2, v_k3]
-            for ax0 in range(T.int64(3)):
-                with T.block("T_divide"):
-                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                    T.reads(rxplaceholder_red[v_ax0])
-                    T.writes(T_divide[v_ax0])
-                    T_divide[v_ax0] = rxplaceholder_red[v_ax0] * 
T.float32(0.00063775510204081628)
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)):
-                with T.block("T_reshape"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_divide[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)])
-                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[(v_ax1 + 
v_ax2 + v_ax3) % T.int64(3)]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)):
-                with T.block("T_subtract"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)):
-                with T.block("T_subtract_1"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)):
-                with T.block("T_subtract_2"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)):
-                with T.block("T_multiply"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], 
T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, 
v_ax3]
-            for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28), 
T.int64(28)):
-                with T.block("T_multiply_red"):
-                    v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, 
k2, k3])
-                    T.reads(T_multiply[v_k0, v_ax0, v_k2, v_k3])
-                    T.writes(T_multiply_red[v_ax0])
-                    with T.init():
-                        T_multiply_red[v_ax0] = T.float32(0)
-                    T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + 
T_multiply[v_k0, v_ax0, v_k2, v_k3]
-            for ax0 in range(T.int64(3)):
-                with T.block("T_divide_1"):
-                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                    T.reads(T_multiply_red[v_ax0])
-                    T.writes(T_divide_1[v_ax0])
-                    T_divide_1[v_ax0] = T_multiply_red[v_ax0] * 
T.float32(0.00063775510204081628)
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)):
-                with T.block("T_reshape_1"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)])
-                    T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)):
-                with T.block("T_add"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, 
v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05)
-            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), 
T.int64(1)):
-                with T.block("compute"):
-                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
-                    T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3])
-                    T.writes(compute[v_i0, v_i1, v_i2, v_i3])
-                    compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, 
v_i1, v_i2, v_i3])
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)):
-                with T.block("T_divide_2"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], 
compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, 
v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)):
-                with T.block("T_reshape_2"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(rxplaceholder_1[(v_ax1 + v_ax2 + v_ax3) % 
T.int64(3)])
-                    T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)):
-                with T.block("T_multiply_1"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)):
-                with T.block("T_reshape_3"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(rxplaceholder_2[(v_ax1 + v_ax2 + v_ax3) % 
T.int64(3)])
-                    T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder_2[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)):
-                with T.block("T_add_1"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0, 
v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
-            for ax0 in range(T.int64(3)):
-                with T.block("T_multiply_2"):
-                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                    T.reads(rxplaceholder_3[v_ax0])
-                    T.writes(T_multiply_2[v_ax0])
-                    T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * 
rxplaceholder_3[v_ax0]
-            for ax0 in range(T.int64(3)):
-                with T.block("T_multiply_3"):
-                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                    T.reads(T_divide[v_ax0])
-                    T.writes(T_multiply_3[v_ax0])
-                    T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * 
T_divide[v_ax0]
-            for ax0 in range(T.int64(3)):
-                with T.block("T_add_2"):
-                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                    T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0])
-                    T.writes(T_add_1[v_ax0])
-                    T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0]
-            for ax0 in range(T.int64(3)):
-                with T.block("T_multiply_4"):
-                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                    T.reads(rxplaceholder_4[v_ax0])
-                    T.writes(T_multiply_4[v_ax0])
-                    T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * 
rxplaceholder_4[v_ax0]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)):
-                with T.block("T_subtract_3"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)):
-                with T.block("T_subtract_4"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)):
-                with T.block("T_multiply_5"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3], 
T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_4[v_ax0, v_ax1, v_ax2, 
v_ax3]
-            for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28), 
T.int64(28)):
-                with T.block("T_multiply_red_1"):
-                    v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, 
k2, k3])
-                    T.reads(T_multiply_5[v_k0, v_ax0, v_k2, v_k3])
-                    T.writes(T_multiply_red_1[v_ax0])
-                    with T.init():
-                        T_multiply_red_1[v_ax0] = T.float32(0)
-                    T_multiply_red_1[v_ax0] = T_multiply_red_1[v_ax0] + 
T_multiply_5[v_k0, v_ax0, v_k2, v_k3]
-            for ax0 in range(T.int64(3)):
-                with T.block("T_divide_3"):
-                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                    T.reads(T_multiply_red_1[v_ax0])
-                    T.writes(T_divide_3[v_ax0])
-                    T_divide_3[v_ax0] = T_multiply_red_1[v_ax0] * 
T.float32(0.00063775510204081628)
-            for ax0 in range(T.int64(3)):
-                with T.block("T_multiply_6"):
-                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                    T.reads(T_divide_3[v_ax0])
-                    T.writes(T_multiply_6[v_ax0])
-                    T_multiply_6[v_ax0] = T.float32(0.10000000000000001) * 
T_divide_3[v_ax0]
-            for ax0 in range(T.int64(3)):
-                with T.block("T_add_3"):
-                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                    T.reads(T_multiply_4[v_ax0], T_multiply_6[v_ax0])
-                    T.writes(T_add_2[v_ax0])
-                    T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0]
+        def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: 
T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add: 
T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            x = T.match_buffer(var_x, (T.int64(2), T.int64(3), T.int64(28), 
T.int64(28)))
+            gamma = T.match_buffer(var_gamma, (T.int64(3),))
+            beta = T.match_buffer(var_beta, (T.int64(3),))
+            moving_mean = T.match_buffer(var_moving_mean, (T.int64(3),))
+            moving_var = T.match_buffer(var_moving_var, (T.int64(3),))
+            T_add = T.match_buffer(var_T_add, (T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
+            T_add_1 = T.match_buffer(var_T_add_1, (T.int64(3),))
+            T_add_2 = T.match_buffer(var_T_add_2, (T.int64(3),))
+            with T.block("root"):
+                T.reads()
+                T.writes()
+                T_reshape = T.alloc_buffer((T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)))
+                T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
+                T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)))
+                T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), 
T.int64(1)))
+                compute = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), 
T.int64(1)))
+                T_divide = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
+                T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)))
+                T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
+                T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)))
+                T_multiply_1 = T.alloc_buffer((T.int64(3),))
+                x_red = T.alloc_buffer((T.int64(3),))
+                T_divide_1 = T.alloc_buffer((T.int64(3),))
+                T_multiply_2 = T.alloc_buffer((T.int64(3),))
+                T_multiply_3 = T.alloc_buffer((T.int64(3),))
+                T_reshape_4 = T.alloc_buffer((T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)))
+                T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
+                T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
+                T_multiply_4 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
+                T_multiply_red = T.alloc_buffer((T.int64(3),))
+                T_divide_2 = T.alloc_buffer((T.int64(3),))
+                T_multiply_5 = T.alloc_buffer((T.int64(3),))
+                for ax0 in range(T.int64(1)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(1)):
+                            for ax3 in range(T.int64(1)):
+                                with T.block("T_reshape"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
+                                    T.reads(moving_mean[(v_ax1 + v_ax2 + 
v_ax3) % T.int64(3)])
+                                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = 
moving_mean[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
+                for ax0 in range(T.int64(2)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(28)):
+                            for ax3 in range(T.int64(28)):
+                                with T.block("T_subtract"):
+                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
+                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_subtract[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                for ax0 in range(T.int64(1)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(1)):
+                            for ax3 in range(T.int64(1)):
+                                with T.block("T_reshape_1"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
+                                    T.reads(moving_var[(v_ax1 + v_ax2 + v_ax3) 
% T.int64(3)])
+                                    T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
moving_var[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
+                for ax0 in range(T.int64(1)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(1)):
+                            for ax3 in range(T.int64(1)):
+                                with T.block("T_add"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
+                                    T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T.writes(T_add_3[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05)
+                for i0 in range(T.int64(1)):
+                    for i1 in range(T.int64(3)):
+                        for i2 in range(T.int64(1)):
+                            for i3 in range(T.int64(1)):
+                                with T.block("compute"):
+                                    v_i0 = T.axis.spatial(T.int64(1), i0)
+                                    v_i1 = T.axis.spatial(T.int64(3), i1)
+                                    v_i2 = T.axis.spatial(T.int64(1), i2)
+                                    v_i3 = T.axis.spatial(T.int64(1), i3)
+                                    T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3])
+                                    T.writes(compute[v_i0, v_i1, v_i2, v_i3])
+                                    compute[v_i0, v_i1, v_i2, v_i3] = 
T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3])
+                for ax0 in range(T.int64(2)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(28)):
+                            for ax3 in range(T.int64(28)):
+                                with T.block("T_divide"):
+                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
+                                    T.reads(T_subtract[v_ax0, v_ax1, v_ax2, 
v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_divide[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                for ax0 in range(T.int64(1)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(1)):
+                            for ax3 in range(T.int64(1)):
+                                with T.block("T_reshape_2"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
+                                    T.reads(gamma[(v_ax1 + v_ax2 + v_ax3) % 
T.int64(3)])
+                                    T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = 
gamma[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
+                for ax0 in range(T.int64(2)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(28)):
+                            for ax3 in range(T.int64(28)):
+                                with T.block("T_multiply"):
+                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
+                                    T.reads(T_divide[v_ax0, v_ax1, v_ax2, 
v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
+                for ax0 in range(T.int64(1)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(1)):
+                            for ax3 in range(T.int64(1)):
+                                with T.block("T_reshape_3"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
+                                    T.reads(beta[(v_ax1 + v_ax2 + v_ax3) % 
T.int64(3)])
+                                    T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = 
beta[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
+                for ax0 in range(T.int64(2)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(28)):
+                            for ax3 in range(T.int64(28)):
+                                with T.block("T_add_1"):
+                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
+                                    T.reads(T_multiply[v_ax0, v_ax1, v_ax2, 
v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
+                                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
+                for ax0 in range(T.int64(3)):
+                    with T.block("T_multiply_1"):
+                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                        T.reads(moving_mean[v_ax0])
+                        T.writes(T_multiply_1[v_ax0])
+                        T_multiply_1[v_ax0] = T.float32(0.90000000000000002) * 
moving_mean[v_ax0]
+                for ax0 in range(T.int64(3)):
+                    for k0 in range(T.int64(2)):
+                        for k2 in range(T.int64(28)):
+                            for k3 in range(T.int64(28)):
+                                with T.block("x_red"):
+                                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                                    v_k0 = T.axis.reduce(T.int64(2), k0)
+                                    v_k2 = T.axis.reduce(T.int64(28), k2)
+                                    v_k3 = T.axis.reduce(T.int64(28), k3)
+                                    T.reads(x[v_k0, v_ax0, v_k2, v_k3])
+                                    T.writes(x_red[v_ax0])
+                                    with T.init():
+                                        x_red[v_ax0] = T.float32(0.0)
+                                    x_red[v_ax0] = x_red[v_ax0] + x[v_k0, 
v_ax0, v_k2, v_k3]
+                for ax0 in range(T.int64(3)):
+                    with T.block("T_divide_1"):
+                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                        T.reads(x_red[v_ax0])
+                        T.writes(T_divide_1[v_ax0])
+                        T_divide_1[v_ax0] = x_red[v_ax0] * 
T.float32(0.00063775510204081628)
+                for ax0 in range(T.int64(3)):
+                    with T.block("T_multiply_2"):
+                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                        T.reads(T_divide_1[v_ax0])
+                        T.writes(T_multiply_2[v_ax0])
+                        T_multiply_2[v_ax0] = T.float32(0.10000000000000001) * 
T_divide_1[v_ax0]
+                for ax0 in range(T.int64(3)):
+                    with T.block("T_add_2"):
+                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                        T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0])
+                        T.writes(T_add_1[v_ax0])
+                        T_add_1[v_ax0] = T_multiply_1[v_ax0] + 
T_multiply_2[v_ax0]
+                for ax0 in range(T.int64(3)):
+                    with T.block("T_multiply_3"):
+                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                        T.reads(moving_var[v_ax0])
+                        T.writes(T_multiply_3[v_ax0])
+                        T_multiply_3[v_ax0] = T.float32(0.90000000000000002) * 
moving_var[v_ax0]
+                for ax0 in range(T.int64(1)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(1)):
+                            for ax3 in range(T.int64(1)):
+                                with T.block("T_reshape_4"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
+                                    T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3) 
% T.int64(3)])
+                                    T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
+                for ax0 in range(T.int64(2)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(28)):
+                            for ax3 in range(T.int64(28)):
+                                with T.block("T_subtract_1"):
+                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
+                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                for ax0 in range(T.int64(2)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(28)):
+                            for ax3 in range(T.int64(28)):
+                                with T.block("T_subtract_2"):
+                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
+                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                for ax0 in range(T.int64(2)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(28)):
+                            for ax3 in range(T.int64(28)):
+                                with T.block("T_multiply_4"):
+                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
+                                    T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, 
v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
+                                    T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, 
v_ax3]
+                for ax0 in range(T.int64(3)):
+                    for k0 in range(T.int64(2)):
+                        for k2 in range(T.int64(28)):
+                            for k3 in range(T.int64(28)):
+                                with T.block("T_multiply_red"):
+                                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                                    v_k0 = T.axis.reduce(T.int64(2), k0)
+                                    v_k2 = T.axis.reduce(T.int64(28), k2)
+                                    v_k3 = T.axis.reduce(T.int64(28), k3)
+                                    T.reads(T_multiply_4[v_k0, v_ax0, v_k2, 
v_k3])
+                                    T.writes(T_multiply_red[v_ax0])
+                                    with T.init():
+                                        T_multiply_red[v_ax0] = T.float32(0.0)
+                                    T_multiply_red[v_ax0] = 
T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3]
+                for ax0 in range(T.int64(3)):
+                    with T.block("T_divide_2"):
+                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                        T.reads(T_multiply_red[v_ax0])
+                        T.writes(T_divide_2[v_ax0])
+                        T_divide_2[v_ax0] = T_multiply_red[v_ax0] * 
T.float32(0.00063775510204081628)
+                for ax0 in range(T.int64(3)):
+                    with T.block("T_multiply_5"):
+                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                        T.reads(T_divide_2[v_ax0])
+                        T.writes(T_multiply_5[v_ax0])
+                        T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * 
T_divide_2[v_ax0]
+                for ax0 in range(T.int64(3)):
+                    with T.block("T_add_3"):
+                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                        T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0])
+                        T.writes(T_add_2[v_ax0])
+                        T_add_2[v_ax0] = T_multiply_3[v_ax0] + 
T_multiply_5[v_ax0]
 
         @R.function
         def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: 
R.Tensor((3,), dtype="float32"), beta: R.Tensor((3,), dtype="float32"), 
moving_mean: R.Tensor((3,), dtype="float32"), moving_var: R.Tensor((3,), 
dtype="float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), dtype="float32"), 
R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")):
-            gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean, 
moving_var), out_sinfo=[R.Tensor((2, 3, 28, 28), dtype="float32"), 
R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")])
+            cls = Expected
+            gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, 
moving_var), out_sinfo=[R.Tensor((2, 3, 28, 28), dtype="float32"), 
R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")])
             return gv
     # fmt: on
 
@@ -2184,230 +2261,295 @@ def test_batch_norm_symbolic():
     @tvm.script.ir_module
     class Expected:
         @T.prim_func(private=True)
-        def batch_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: 
T.handle, var_rxplaceholder_2: T.handle, var_rxplaceholder_3: T.handle, 
var_rxplaceholder_4: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, 
var_T_add_2: T.handle):
-            T.func_attr({"tir.noalias": True})
-            n = T.int64()
-            h = T.int64()
-            w = T.int64()
-            c = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, (n, h, w, c))
-            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (c,))
-            rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (c,))
-            rxplaceholder_3 = T.match_buffer(var_rxplaceholder_3, (c,))
-            rxplaceholder_4 = T.match_buffer(var_rxplaceholder_4, (c,))
+        def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: 
T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add: 
T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            n, h, w, c = T.int64(), T.int64(), T.int64(), T.int64()
+            x = T.match_buffer(var_x, (n, h, w, c))
+            gamma = T.match_buffer(var_gamma, (c,))
+            beta = T.match_buffer(var_beta, (c,))
+            moving_mean = T.match_buffer(var_moving_mean, (c,))
+            moving_var = T.match_buffer(var_moving_var, (c,))
             T_add = T.match_buffer(var_T_add, (n, h, w, c))
             T_add_1 = T.match_buffer(var_T_add_1, (T.max(c, h),))
             T_add_2 = T.match_buffer(var_T_add_2, (T.max(c, h),))
-            # with T.block("root"):
-            rxplaceholder_red = T.alloc_buffer((h,))
-            T_divide = T.alloc_buffer((h,))
-            T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1)))
-            T_subtract = T.alloc_buffer((n, h, w, c))
-            T_subtract_1 = T.alloc_buffer((n, h, w, c))
-            T_subtract_2 = T.alloc_buffer((n, h, w, c))
-            T_multiply = T.alloc_buffer((n, h, w, c))
-            T_multiply_red = T.alloc_buffer((h,))
-            T_divide_1 = T.alloc_buffer((h,))
-            T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
-            T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1)))
-            compute = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1)))
-            T_divide_2 = T.alloc_buffer((n, h, w, c))
-            T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
-            T_multiply_1 = T.alloc_buffer((n, h, w, c))
-            T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
-            T_multiply_2 = T.alloc_buffer((c,))
-            T_multiply_3 = T.alloc_buffer((h,))
-            T_multiply_4 = T.alloc_buffer((c,))
-            T_subtract_3 = T.alloc_buffer((n, h, w, c))
-            T_subtract_4 = T.alloc_buffer((n, h, w, c))
-            T_multiply_5 = T.alloc_buffer((n, h, w, c))
-            T_multiply_red_1 = T.alloc_buffer((h,))
-            T_divide_3 = T.alloc_buffer((h,))
-            T_multiply_6 = T.alloc_buffer((h,))
-            for ax0, k0, k2, k3 in T.grid(h, n, w, c):
-                with T.block("rxplaceholder_red"):
-                    v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, 
k2, k3])
-                    T.reads(rxplaceholder[v_k0, v_ax0, v_k2, v_k3])
-                    T.writes(rxplaceholder_red[v_ax0])
-                    with T.init():
-                        rxplaceholder_red[v_ax0] = T.float32(0)
-                    rxplaceholder_red[v_ax0] = rxplaceholder_red[v_ax0] + 
rxplaceholder[v_k0, v_ax0, v_k2, v_k3]
-            for ax0 in range(h):
-                with T.block("T_divide"):
-                    v_ax0 = T.axis.spatial(h, ax0)
-                    T.reads(rxplaceholder_red[v_ax0])
-                    T.writes(T_divide[v_ax0])
-                    T_divide[v_ax0] = rxplaceholder_red[v_ax0] / 
T.Cast("float32", n * w * c)
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), 
T.int64(1)):
-                with T.block("T_reshape"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h])
-                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[(v_ax0 * 
h + v_ax1 + v_ax2 + v_ax3) % h]
-            for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
-                with T.block("T_subtract"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
-            for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
-                with T.block("T_subtract_1"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
-            for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
-                with T.block("T_subtract_2"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
-            for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
-                with T.block("T_multiply"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], 
T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, 
v_ax3]
-            for ax0, k0, k2, k3 in T.grid(h, n, w, c):
-                with T.block("T_multiply_red"):
-                    v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, 
k2, k3])
-                    T.reads(T_multiply[v_k0, v_ax0, v_k2, v_k3])
-                    T.writes(T_multiply_red[v_ax0])
-                    with T.init():
-                        T_multiply_red[v_ax0] = T.float32(0)
-                    T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + 
T_multiply[v_k0, v_ax0, v_k2, v_k3]
-            for ax0 in range(h):
-                with T.block("T_divide_1"):
-                    v_ax0 = T.axis.spatial(h, ax0)
-                    T.reads(T_multiply_red[v_ax0])
-                    T.writes(T_divide_1[v_ax0])
-                    T_divide_1[v_ax0] = T_multiply_red[v_ax0] / 
T.Cast("float32", n * w * c)
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), 
T.int64(1)):
-                with T.block("T_reshape_1"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % 
h])
-                    T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), 
T.int64(1)):
-                with T.block("T_add"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, 
v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05)
-            for i0, i1, i2, i3 in T.grid(T.int64(1), h, T.int64(1), 
T.int64(1)):
-                with T.block("compute"):
-                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
-                    T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3])
-                    T.writes(compute[v_i0, v_i1, v_i2, v_i3])
-                    compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, 
v_i1, v_i2, v_i3])
-            for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
-                with T.block("T_divide_2"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], 
compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, 
v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), 
T.int64(1)):
-                with T.block("T_reshape_2"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + 
v_ax3) % c])
-                    T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
-            for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
-                with T.block("T_multiply_1"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
-            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), 
T.int64(1)):
-                with T.block("T_reshape_3"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + 
v_ax3) % c])
-                    T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
-            for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
-                with T.block("T_add_1"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0, 
v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
-            for ax0 in range(c):
-                with T.block("T_multiply_2"):
-                    v_ax0 = T.axis.spatial(c, ax0)
-                    T.reads(rxplaceholder_3[v_ax0])
-                    T.writes(T_multiply_2[v_ax0])
-                    T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * 
rxplaceholder_3[v_ax0]
-            for ax0 in range(h):
-                with T.block("T_multiply_3"):
-                    v_ax0 = T.axis.spatial(h, ax0)
-                    T.reads(T_divide[v_ax0])
-                    T.writes(T_multiply_3[v_ax0])
-                    T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * 
T_divide[v_ax0]
-            for ax0 in range(T.max(c, h)):
-                with T.block("T_add_2"):
-                    v_ax0 = T.axis.spatial(T.max(c, h), ax0)
-                    T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0])
-                    T.writes(T_add_1[v_ax0])
-                    T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0]
-            for ax0 in range(c):
-                with T.block("T_multiply_4"):
-                    v_ax0 = T.axis.spatial(c, ax0)
-                    T.reads(rxplaceholder_4[v_ax0])
-                    T.writes(T_multiply_4[v_ax0])
-                    T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * 
rxplaceholder_4[v_ax0]
-            for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
-                with T.block("T_subtract_3"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
-            for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
-                with T.block("T_subtract_4"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                    T.writes(T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
-            for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
-                with T.block("T_multiply_5"):
-                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3], 
T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_4[v_ax0, v_ax1, v_ax2, 
v_ax3]
-            for ax0, k0, k2, k3 in T.grid(h, n, w, c):
-                with T.block("T_multiply_red_1"):
-                    v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, 
k2, k3])
-                    T.reads(T_multiply_5[v_k0, v_ax0, v_k2, v_k3])
-                    T.writes(T_multiply_red_1[v_ax0])
-                    with T.init():
-                        T_multiply_red_1[v_ax0] = T.float32(0)
-                    T_multiply_red_1[v_ax0] = T_multiply_red_1[v_ax0] + 
T_multiply_5[v_k0, v_ax0, v_k2, v_k3]
-            for ax0 in range(h):
-                with T.block("T_divide_3"):
-                    v_ax0 = T.axis.spatial(h, ax0)
-                    T.reads(T_multiply_red_1[v_ax0])
-                    T.writes(T_divide_3[v_ax0])
-                    T_divide_3[v_ax0] = T_multiply_red_1[v_ax0] / 
T.Cast("float32", n * w * c)
-            for ax0 in range(h):
-                with T.block("T_multiply_6"):
-                    v_ax0 = T.axis.spatial(h, ax0)
-                    T.reads(T_divide_3[v_ax0])
-                    T.writes(T_multiply_6[v_ax0])
-                    T_multiply_6[v_ax0] = T.float32(0.10000000000000001) * 
T_divide_3[v_ax0]
-            for ax0 in range(T.max(c, h)):
-                with T.block("T_add_3"):
-                    v_ax0 = T.axis.spatial(T.max(c, h), ax0)
-                    T.reads(T_multiply_4[v_ax0], T_multiply_6[v_ax0])
-                    T.writes(T_add_2[v_ax0])
-                    T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0]
-
-        @R.function
-        def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: 
R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), 
moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), 
dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), 
R.Tensor(("T.max(c,h)",), dtype="float32"), R.Tensor(("T.max(c,h)",), 
dtype="float32")):
+            with T.block("root"):
+                T.reads()
+                T.writes()
+                T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
+                T_subtract = T.alloc_buffer((n, h, w, c))
+                T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
+                T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
+                compute = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
+                T_divide = T.alloc_buffer((n, h, w, c))
+                T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
+                T_multiply = T.alloc_buffer((n, h, w, c))
+                T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
+                T_multiply_1 = T.alloc_buffer((c,))
+                x_red = T.alloc_buffer((h,))
+                T_divide_1 = T.alloc_buffer((h,))
+                T_multiply_2 = T.alloc_buffer((h,))
+                T_multiply_3 = T.alloc_buffer((c,))
+                T_reshape_4 = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
+                T_subtract_1 = T.alloc_buffer((n, h, w, c))
+                T_subtract_2 = T.alloc_buffer((n, h, w, c))
+                T_multiply_4 = T.alloc_buffer((n, h, w, c))
+                T_multiply_red = T.alloc_buffer((h,))
+                T_divide_2 = T.alloc_buffer((h,))
+                T_multiply_5 = T.alloc_buffer((h,))
+                for ax0 in range(T.int64(1)):
+                    for ax1 in range(h):
+                        for ax2 in range(T.int64(1)):
+                            for ax3 in range(T.int64(1)):
+                                with T.block("T_reshape"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
+                                    T.reads(moving_mean[(v_ax0 * h + v_ax1 + 
v_ax2 + v_ax3) % c])
+                                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = 
moving_mean[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
+                for ax0 in range(n):
+                    for ax1 in range(h):
+                        for ax2 in range(w):
+                            for ax3 in range(c):
+                                with T.block("T_subtract"):
+                                    v_ax0 = T.axis.spatial(n, ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(w, ax2)
+                                    v_ax3 = T.axis.spatial(c, ax3)
+                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_subtract[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                for ax0 in range(T.int64(1)):
+                    for ax1 in range(h):
+                        for ax2 in range(T.int64(1)):
+                            for ax3 in range(T.int64(1)):
+                                with T.block("T_reshape_1"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
+                                    T.reads(moving_var[(v_ax0 * h + v_ax1 + 
v_ax2 + v_ax3) % c])
+                                    T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
moving_var[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
+                for ax0 in range(T.int64(1)):
+                    for ax1 in range(h):
+                        for ax2 in range(T.int64(1)):
+                            for ax3 in range(T.int64(1)):
+                                with T.block("T_add"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
+                                    T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T.writes(T_add_3[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05)
+                for i0 in range(T.int64(1)):
+                    for i1 in range(h):
+                        for i2 in range(T.int64(1)):
+                            for i3 in range(T.int64(1)):
+                                with T.block("compute"):
+                                    v_i0 = T.axis.spatial(T.int64(1), i0)
+                                    v_i1 = T.axis.spatial(h, i1)
+                                    v_i2 = T.axis.spatial(T.int64(1), i2)
+                                    v_i3 = T.axis.spatial(T.int64(1), i3)
+                                    T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3])
+                                    T.writes(compute[v_i0, v_i1, v_i2, v_i3])
+                                    compute[v_i0, v_i1, v_i2, v_i3] = 
T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3])
+                for ax0 in range(n):
+                    for ax1 in range(h):
+                        for ax2 in range(w):
+                            for ax3 in range(c):
+                                with T.block("T_divide"):
+                                    v_ax0 = T.axis.spatial(n, ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(w, ax2)
+                                    v_ax3 = T.axis.spatial(c, ax3)
+                                    T.reads(T_subtract[v_ax0, v_ax1, v_ax2, 
v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_divide[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                for ax0 in range(T.int64(1)):
+                    for ax1 in range(h):
+                        for ax2 in range(T.int64(1)):
+                            for ax3 in range(T.int64(1)):
+                                with T.block("T_reshape_2"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
+                                    T.reads(gamma[(v_ax0 * h + v_ax1 + v_ax2 + 
v_ax3) % c])
+                                    T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = 
gamma[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
+                for ax0 in range(n):
+                    for ax1 in range(h):
+                        for ax2 in range(w):
+                            for ax3 in range(c):
+                                with T.block("T_multiply"):
+                                    v_ax0 = T.axis.spatial(n, ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(w, ax2)
+                                    v_ax3 = T.axis.spatial(c, ax3)
+                                    T.reads(T_divide[v_ax0, v_ax1, v_ax2, 
v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
+                for ax0 in range(T.int64(1)):
+                    for ax1 in range(h):
+                        for ax2 in range(T.int64(1)):
+                            for ax3 in range(T.int64(1)):
+                                with T.block("T_reshape_3"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
+                                    T.reads(beta[(v_ax0 * h + v_ax1 + v_ax2 + 
v_ax3) % c])
+                                    T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = 
beta[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
+                for ax0 in range(n):
+                    for ax1 in range(h):
+                        for ax2 in range(w):
+                            for ax3 in range(c):
+                                with T.block("T_add_1"):
+                                    v_ax0 = T.axis.spatial(n, ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(w, ax2)
+                                    v_ax3 = T.axis.spatial(c, ax3)
+                                    T.reads(T_multiply[v_ax0, v_ax1, v_ax2, 
v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
+                                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
+                for ax0 in range(c):
+                    with T.block("T_multiply_1"):
+                        v_ax0 = T.axis.spatial(c, ax0)
+                        T.reads(moving_mean[v_ax0])
+                        T.writes(T_multiply_1[v_ax0])
+                        T_multiply_1[v_ax0] = T.float32(0.90000000000000002) * 
moving_mean[v_ax0]
+                for ax0 in range(h):
+                    for k0 in range(n):
+                        for k2 in range(w):
+                            for k3 in range(c):
+                                with T.block("x_red"):
+                                    v_ax0 = T.axis.spatial(h, ax0)
+                                    v_k0 = T.axis.reduce(n, k0)
+                                    v_k2 = T.axis.reduce(w, k2)
+                                    v_k3 = T.axis.reduce(c, k3)
+                                    T.reads(x[v_k0, v_ax0, v_k2, v_k3])
+                                    T.writes(x_red[v_ax0])
+                                    with T.init():
+                                        x_red[v_ax0] = T.float32(0.0)
+                                    x_red[v_ax0] = x_red[v_ax0] + x[v_k0, 
v_ax0, v_k2, v_k3]
+                for ax0 in range(h):
+                    with T.block("T_divide_1"):
+                        v_ax0 = T.axis.spatial(h, ax0)
+                        T.reads(x_red[v_ax0])
+                        T.writes(T_divide_1[v_ax0])
+                        T_divide_1[v_ax0] = x_red[v_ax0] / T.Cast("float32", n 
* w * c)
+                for ax0 in range(h):
+                    with T.block("T_multiply_2"):
+                        v_ax0 = T.axis.spatial(h, ax0)
+                        T.reads(T_divide_1[v_ax0])
+                        T.writes(T_multiply_2[v_ax0])
+                        T_multiply_2[v_ax0] = T.float32(0.10000000000000001) * 
T_divide_1[v_ax0]
+                for ax0 in range(T.max(c, h)):
+                    with T.block("T_add_2"):
+                        v_ax0 = T.axis.spatial(T.max(c, h), ax0)
+                        T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0])
+                        T.writes(T_add_1[v_ax0])
+                        T_add_1[v_ax0] = T_multiply_1[v_ax0] + 
T_multiply_2[v_ax0]
+                for ax0 in range(c):
+                    with T.block("T_multiply_3"):
+                        v_ax0 = T.axis.spatial(c, ax0)
+                        T.reads(moving_var[v_ax0])
+                        T.writes(T_multiply_3[v_ax0])
+                        T_multiply_3[v_ax0] = T.float32(0.90000000000000002) * 
moving_var[v_ax0]
+                for ax0 in range(T.int64(1)):
+                    for ax1 in range(h):
+                        for ax2 in range(T.int64(1)):
+                            for ax3 in range(T.int64(1)):
+                                with T.block("T_reshape_4"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
+                                    T.reads(T_divide_1[(v_ax0 * h + v_ax1 + 
v_ax2 + v_ax3) % h])
+                                    T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]
+                for ax0 in range(n):
+                    for ax1 in range(h):
+                        for ax2 in range(w):
+                            for ax3 in range(c):
+                                with T.block("T_subtract_1"):
+                                    v_ax0 = T.axis.spatial(n, ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(w, ax2)
+                                    v_ax3 = T.axis.spatial(c, ax3)
+                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                for ax0 in range(n):
+                    for ax1 in range(h):
+                        for ax2 in range(w):
+                            for ax3 in range(c):
+                                with T.block("T_subtract_2"):
+                                    v_ax0 = T.axis.spatial(n, ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(w, ax2)
+                                    v_ax3 = T.axis.spatial(c, ax3)
+                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                for ax0 in range(n):
+                    for ax1 in range(h):
+                        for ax2 in range(w):
+                            for ax3 in range(c):
+                                with T.block("T_multiply_4"):
+                                    v_ax0 = T.axis.spatial(n, ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(w, ax2)
+                                    v_ax3 = T.axis.spatial(c, ax3)
+                                    T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, 
v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
+                                    T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, 
v_ax3]
+                for ax0 in range(h):
+                    for k0 in range(n):
+                        for k2 in range(w):
+                            for k3 in range(c):
+                                with T.block("T_multiply_red"):
+                                    v_ax0 = T.axis.spatial(h, ax0)
+                                    v_k0 = T.axis.reduce(n, k0)
+                                    v_k2 = T.axis.reduce(w, k2)
+                                    v_k3 = T.axis.reduce(c, k3)
+                                    T.reads(T_multiply_4[v_k0, v_ax0, v_k2, 
v_k3])
+                                    T.writes(T_multiply_red[v_ax0])
+                                    with T.init():
+                                        T_multiply_red[v_ax0] = T.float32(0.0)
+                                    T_multiply_red[v_ax0] = 
T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3]
+                for ax0 in range(h):
+                    with T.block("T_divide_2"):
+                        v_ax0 = T.axis.spatial(h, ax0)
+                        T.reads(T_multiply_red[v_ax0])
+                        T.writes(T_divide_2[v_ax0])
+                        T_divide_2[v_ax0] = T_multiply_red[v_ax0] / 
T.Cast("float32", n * w * c)
+                for ax0 in range(h):
+                    with T.block("T_multiply_5"):
+                        v_ax0 = T.axis.spatial(h, ax0)
+                        T.reads(T_divide_2[v_ax0])
+                        T.writes(T_multiply_5[v_ax0])
+                        T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * 
T_divide_2[v_ax0]
+                for ax0 in range(T.max(c, h)):
+                    with T.block("T_add_3"):
+                        v_ax0 = T.axis.spatial(T.max(c, h), ax0)
+                        T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0])
+                        T.writes(T_add_2[v_ax0])
+                        T_add_2[v_ax0] = T_multiply_3[v_ax0] + 
T_multiply_5[v_ax0]
+
+        @R.function
+        def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: 
R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), 
moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), 
dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), 
R.Tensor(("T.max(c, h)",), dtype="float32"), R.Tensor(("T.max(c, h)",), 
dtype="float32")):
             n = T.int64()
             h = T.int64()
             w = T.int64()
             c = T.int64()
-            gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean, 
moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"), 
R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), 
dtype="float32")])
+            cls = Expected
+            gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, 
moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"), 
R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), 
dtype="float32")])
             return gv
-    # fmt: on
 
     mod = LegalizeOps()(BatchNorm)
     tvm.ir.assert_structural_equal(mod, Expected)


Reply via email to