This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5dc4e785d9 [Relax] Fix batch normalization computation logic (#18609)
5dc4e785d9 is described below

commit 5dc4e785d98892d4061b5c95d741f44b81261bf4
Author: Asuka <[email protected]>
AuthorDate: Fri Dec 26 20:57:48 2025 +0800

    [Relax] Fix batch normalization computation logic (#18609)
    
    Dear reviewers,
    
    **Why**
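    The previous implementation of batch_norm had the training and
    evaluation modes swapped: it normalized with the running (moving)
    mean and variance in training mode and with the batch statistics in
    evaluation mode, and it computed the batch statistics unconditionally.
    Additionally, `_native_batch_norm_legit.no_stats` was lowered to
    instance_norm with reduction axes that excluded the batch dimension,
    so the statistics were computed per sample instead of across the
    batch, leading to incorrect results.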
    
    **How**
    This PR includes the following fixes:
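    1. Separates training and evaluation modes: the batch mean and
    variance are computed and used only in training mode, while
    evaluation mode normalizes with the running (moving) mean and
    variance.
    2. Includes the batch dimension (axis 0) in the reduction axes used
    when lowering `_native_batch_norm_legit.no_stats` to instance_norm.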
    
    **Environment**
    GPU: NVIDIA A100-SXM4-80GB
---
 .../frontend/torch/exported_program_translator.py  |  18 +-
 python/tvm/topi/nn/batch_norm.py                   |  31 +-
 .../relax/test_frontend_from_exported_program.py   |   6 +-
 .../python/relax/test_transform_legalize_ops_nn.py | 490 ++++++++++-----------
 4 files changed, 260 insertions(+), 285 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 94df0282c8..b6b9723c13 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -121,6 +121,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         x = self.env[node.args[0]]
         channel = int(self.shape_of(x)[1])
         dtype = x.struct_info.dtype
+        scale = node.args[1] is not None
+        center = node.args[2] is not None
         weight = self.env.get(node.args[1], relax.const(np.ones(channel), 
dtype=dtype))
         bias = self.env.get(node.args[2], relax.const(np.zeros(channel), 
dtype=dtype))
         running_mean = self.env.get(node.args[3], 
relax.const(np.zeros(channel), dtype=dtype))
@@ -134,10 +136,6 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             eps = node.args[6] if len(node.args) > 6 else 
node.kwargs.get("eps", 1e-05)
             training = False
         elif target_name.startswith("_native_batch_norm_legit_functional"):
-            momentum = node.args[5] if len(node.args) > 5 else 
node.kwargs.get("momentum", 0.1)
-            eps = node.args[6] if len(node.args) > 6 else 
node.kwargs.get("eps", 1e-05)
-            training = True
-        else:
             ignore_running_stats = (
                 node.args[5] if len(node.args) > 5 else 
node.kwargs.get("track_running_stats", True)
             )
@@ -147,6 +145,10 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
             if track_running_stats:
                 training = True
+        else:
+            momentum = node.args[5] if len(node.args) > 5 else 
node.kwargs.get("momentum", 0.1)
+            eps = node.args[6] if len(node.args) > 6 else 
node.kwargs.get("eps", 1e-05)
+            training = True
 
         bn_result = self.block_builder.emit(
             relax.op.nn.batch_norm(
@@ -157,6 +159,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
                 moving_var=running_var,
                 axis=1,  # Always over channel
                 epsilon=eps,
+                scale=scale,
+                center=center,
                 momentum=momentum,
                 training=training,
             )
@@ -197,9 +201,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         bias = self.env.get(node.args[2], relax.const(np.zeros(channel), 
dtype=dtype))
         eps = node.args[5] if len(node.args) > 5 else node.kwargs.get("eps", 
1e-05)
 
-        # Determine axes for instance norm (all spatial dimensions after 
channel)
+        # Shared by InstanceNorm (view as [1, N*C, H, W])
+        # and eval-mode BatchNorm without track_running_stats
+        # Determine axes for instance norm (all spatial dimensions after 
channel and batch dim)
         dim = len(self.shape_of(x))
-        axes = list(range(2, dim))
+        axes = [0] + list(range(2, dim))
 
         return self.block_builder.emit(
             relax.op.nn.instance_norm(
diff --git a/python/tvm/topi/nn/batch_norm.py b/python/tvm/topi/nn/batch_norm.py
index 8308c93eae..6bd43d29cb 100644
--- a/python/tvm/topi/nn/batch_norm.py
+++ b/python/tvm/topi/nn/batch_norm.py
@@ -110,27 +110,25 @@ def batch_norm(
 
     shape = [1] * len(data.shape)
     shape[axis] = data.shape[axis]
-
-    reduce_axes = list(range(len(data.shape)))
-    reduce_axes.remove(axis)
-    shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in 
reduce_axes], 1)
-
-    data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
-    data_mean_rs = topi.reshape(data_mean, shape)
-    data_var = (
-        topi.sum((data - data_mean_rs) * (data - data_mean_rs), 
axis=reduce_axes) / shape_prod
-    )
-    data_var_rs = topi.reshape(data_var, shape)
+    data_mean = None
+    data_var = None
 
     if training:
+        reduce_axes = list(range(len(data.shape)))
+        reduce_axes.remove(axis)
+        shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in 
reduce_axes], 1)
+        data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
+        data_mean_rs = topi.reshape(data_mean, shape)
+        data_var = (
+            topi.sum((data - data_mean_rs) * (data - data_mean_rs), 
axis=reduce_axes) / shape_prod
+        )
+        data_var_rs = topi.reshape(data_var, shape)
+        out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)
+    else:
         moving_mean_rs = topi.reshape(moving_mean, shape)
         moving_var_rs = topi.reshape(moving_var, shape)
-
         out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)
 
-    else:
-        out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)
-
     if scale:
         out = out * topi.reshape(gamma, shape)
     if center:
@@ -138,9 +136,6 @@ def batch_norm(
 
     if training:
         assert 0 <= momentum <= 1, "the valid momentum range is [0, 1]."
-        data_var = (
-            topi.sum((data - data_mean_rs) * (data - data_mean_rs), 
axis=reduce_axes) / shape_prod
-        )
         return [
             out,
             (1 - momentum) * moving_mean + momentum * data_mean,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index 7894a9fb6d..9f8842ddcb 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1913,10 +1913,10 @@ def test_batchnorm2d():
                     w3,
                     w4,
                     axis=1,
-                    epsilon=0.1,
+                    epsilon=1e-5,
                     center=True,
                     scale=True,
-                    momentum=1.0,
+                    momentum=0.1,
                     training=True,
                 )
                 lv2: R.Tensor((2, 3, 4, 4), dtype="float32") = lv1[0]
@@ -3607,7 +3607,7 @@ def test_instancenorm2d():
                     w1,
                     w2,
                     channel_axis=1,
-                    axes=[2, 3],
+                    axes=[0, 2, 3],
                     epsilon=1e-05,
                     center=True,
                     scale=True,
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py 
b/tests/python/relax/test_transform_legalize_ops_nn.py
index de2f183a10..e81e1bab2a 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2046,27 +2046,46 @@ def test_batch_norm():
             with T.block("root"):
                 T.reads()
                 T.writes()
+                x_red = T.alloc_buffer((T.int64(3),))
+                T_divide = T.alloc_buffer((T.int64(3),))
                 T_reshape = T.alloc_buffer((T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)))
                 T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
+                T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
+                T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
+                T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
+                T_multiply_red = T.alloc_buffer((T.int64(3),))
+                T_divide_1 = T.alloc_buffer((T.int64(3),))
                 T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)))
                 T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), 
T.int64(1)))
                 compute = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), 
T.int64(1)))
-                T_divide = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
+                T_divide_2 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
                 T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)))
-                T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
+                T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
                 T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)))
-                T_multiply_1 = T.alloc_buffer((T.int64(3),))
-                x_red = T.alloc_buffer((T.int64(3),))
-                T_divide_1 = T.alloc_buffer((T.int64(3),))
                 T_multiply_2 = T.alloc_buffer((T.int64(3),))
                 T_multiply_3 = T.alloc_buffer((T.int64(3),))
-                T_reshape_4 = T.alloc_buffer((T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)))
-                T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
-                T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
-                T_multiply_4 = T.alloc_buffer((T.int64(2), T.int64(3), 
T.int64(28), T.int64(28)))
-                T_multiply_red = T.alloc_buffer((T.int64(3),))
-                T_divide_2 = T.alloc_buffer((T.int64(3),))
+                T_multiply_4 = T.alloc_buffer((T.int64(3),))
                 T_multiply_5 = T.alloc_buffer((T.int64(3),))
+                for ax0 in range(T.int64(3)):
+                    for k0 in range(T.int64(2)):
+                        for k2 in range(T.int64(28)):
+                            for k3 in range(T.int64(28)):
+                                with T.block("x_red"):
+                                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                                    v_k0 = T.axis.reduce(T.int64(2), k0)
+                                    v_k2 = T.axis.reduce(T.int64(28), k2)
+                                    v_k3 = T.axis.reduce(T.int64(28), k3)
+                                    T.reads(x[v_k0, v_ax0, v_k2, v_k3])
+                                    T.writes(x_red[v_ax0])
+                                    with T.init():
+                                        x_red[v_ax0] = T.float32(0.0)
+                                    x_red[v_ax0] = x_red[v_ax0] + x[v_k0, 
v_ax0, v_k2, v_k3]
+                for ax0 in range(T.int64(3)):
+                    with T.block("T_divide"):
+                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                        T.reads(x_red[v_ax0])
+                        T.writes(T_divide[v_ax0])
+                        T_divide[v_ax0] = x_red[v_ax0] / T.float32(1568.0)
                 for ax0 in range(T.int64(1)):
                     for ax1 in range(T.int64(3)):
                         for ax2 in range(T.int64(1)):
@@ -2076,9 +2095,9 @@ def test_batch_norm():
                                     v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                     v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                     v_ax3 = T.axis.spatial(T.int64(1), ax3)
-                                    T.reads(moving_mean[(v_ax1 + v_ax2 + 
v_ax3) % T.int64(3)])
+                                    T.reads(T_divide[(v_ax1 + v_ax2 + v_ax3) % 
T.int64(3)])
                                     T.writes(T_reshape[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = 
moving_mean[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
+                                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
                 for ax0 in range(T.int64(2)):
                     for ax1 in range(T.int64(3)):
                         for ax2 in range(T.int64(28)):
@@ -2091,6 +2110,62 @@ def test_batch_norm():
                                     T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                     T.writes(T_subtract[v_ax0, v_ax1, v_ax2, 
v_ax3])
                                     T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                for ax0 in range(T.int64(2)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(28)):
+                            for ax3 in range(T.int64(28)):
+                                with T.block("T_subtract_1"):
+                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
+                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                for ax0 in range(T.int64(2)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(28)):
+                            for ax3 in range(T.int64(28)):
+                                with T.block("T_subtract_2"):
+                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
+                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                for ax0 in range(T.int64(2)):
+                    for ax1 in range(T.int64(3)):
+                        for ax2 in range(T.int64(28)):
+                            for ax3 in range(T.int64(28)):
+                                with T.block("T_multiply"):
+                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
+                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
+                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
+                                    T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, 
v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
+                                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, 
v_ax3]
+                for ax0 in range(T.int64(3)):
+                    for k0 in range(T.int64(2)):
+                        for k2 in range(T.int64(28)):
+                            for k3 in range(T.int64(28)):
+                                with T.block("T_multiply_red"):
+                                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                                    v_k0 = T.axis.reduce(T.int64(2), k0)
+                                    v_k2 = T.axis.reduce(T.int64(28), k2)
+                                    v_k3 = T.axis.reduce(T.int64(28), k3)
+                                    T.reads(T_multiply[v_k0, v_ax0, v_k2, 
v_k3])
+                                    T.writes(T_multiply_red[v_ax0])
+                                    with T.init():
+                                        T_multiply_red[v_ax0] = T.float32(0.0)
+                                    T_multiply_red[v_ax0] = 
T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3]
+                for ax0 in range(T.int64(3)):
+                    with T.block("T_divide_1"):
+                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                        T.reads(T_multiply_red[v_ax0])
+                        T.writes(T_divide_1[v_ax0])
+                        T_divide_1[v_ax0] = T_multiply_red[v_ax0] / 
T.float32(1568.0)
                 for ax0 in range(T.int64(1)):
                     for ax1 in range(T.int64(3)):
                         for ax2 in range(T.int64(1)):
@@ -2100,9 +2175,9 @@ def test_batch_norm():
                                     v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                     v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                     v_ax3 = T.axis.spatial(T.int64(1), ax3)
-                                    T.reads(moving_var[(v_ax1 + v_ax2 + v_ax3) 
% T.int64(3)])
+                                    T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3) 
% T.int64(3)])
                                     T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
moving_var[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
+                                    T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
                 for ax0 in range(T.int64(1)):
                     for ax1 in range(T.int64(3)):
                         for ax2 in range(T.int64(1)):
@@ -2131,14 +2206,14 @@ def test_batch_norm():
                     for ax1 in range(T.int64(3)):
                         for ax2 in range(T.int64(28)):
                             for ax3 in range(T.int64(28)):
-                                with T.block("T_divide"):
+                                with T.block("T_divide_2"):
                                     v_ax0 = T.axis.spatial(T.int64(2), ax0)
                                     v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                     v_ax2 = T.axis.spatial(T.int64(28), ax2)
                                     v_ax3 = T.axis.spatial(T.int64(28), ax3)
                                     T.reads(T_subtract[v_ax0, v_ax1, v_ax2, 
v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                                    T.writes(T_divide[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                                    T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
                 for ax0 in range(T.int64(1)):
                     for ax1 in range(T.int64(3)):
                         for ax2 in range(T.int64(1)):
@@ -2155,14 +2230,14 @@ def test_batch_norm():
                     for ax1 in range(T.int64(3)):
                         for ax2 in range(T.int64(28)):
                             for ax3 in range(T.int64(28)):
-                                with T.block("T_multiply"):
+                                with T.block("T_multiply_1"):
                                     v_ax0 = T.axis.spatial(T.int64(2), ax0)
                                     v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                     v_ax2 = T.axis.spatial(T.int64(28), ax2)
                                     v_ax3 = T.axis.spatial(T.int64(28), ax3)
-                                    T.reads(T_divide[v_ax0, v_ax1, v_ax2, 
v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
+                                    T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, 
v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
                 for ax0 in range(T.int64(1)):
                     for ax1 in range(T.int64(3)):
                         for ax2 in range(T.int64(1)):
@@ -2184,133 +2259,45 @@ def test_batch_norm():
                                     v_ax1 = T.axis.spatial(T.int64(3), ax1)
                                     v_ax2 = T.axis.spatial(T.int64(28), ax2)
                                     v_ax3 = T.axis.spatial(T.int64(28), ax3)
-                                    T.reads(T_multiply[v_ax0, v_ax1, v_ax2, 
v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, 
v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                     T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
-                                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
+                                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
                 for ax0 in range(T.int64(3)):
-                    with T.block("T_multiply_1"):
+                    with T.block("T_multiply_2"):
                         v_ax0 = T.axis.spatial(T.int64(3), ax0)
                         T.reads(moving_mean[v_ax0])
-                        T.writes(T_multiply_1[v_ax0])
-                        T_multiply_1[v_ax0] = T.float32(0.90000000000000002) * 
moving_mean[v_ax0]
-                for ax0 in range(T.int64(3)):
-                    for k0 in range(T.int64(2)):
-                        for k2 in range(T.int64(28)):
-                            for k3 in range(T.int64(28)):
-                                with T.block("x_red"):
-                                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                                    v_k0 = T.axis.reduce(T.int64(2), k0)
-                                    v_k2 = T.axis.reduce(T.int64(28), k2)
-                                    v_k3 = T.axis.reduce(T.int64(28), k3)
-                                    T.reads(x[v_k0, v_ax0, v_k2, v_k3])
-                                    T.writes(x_red[v_ax0])
-                                    with T.init():
-                                        x_red[v_ax0] = T.float32(0.0)
-                                    x_red[v_ax0] = x_red[v_ax0] + x[v_k0, 
v_ax0, v_k2, v_k3]
-                for ax0 in range(T.int64(3)):
-                    with T.block("T_divide_1"):
-                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                        T.reads(x_red[v_ax0])
-                        T.writes(T_divide_1[v_ax0])
-                        T_divide_1[v_ax0] = x_red[v_ax0] / T.float32(1568)
+                        T.writes(T_multiply_2[v_ax0])
+                        T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * 
moving_mean[v_ax0]
                 for ax0 in range(T.int64(3)):
-                    with T.block("T_multiply_2"):
+                    with T.block("T_multiply_3"):
                         v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                        T.reads(T_divide_1[v_ax0])
-                        T.writes(T_multiply_2[v_ax0])
-                        T_multiply_2[v_ax0] = T.float32(0.10000000000000001) * 
T_divide_1[v_ax0]
+                        T.reads(T_divide[v_ax0])
+                        T.writes(T_multiply_3[v_ax0])
+                        T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * 
T_divide[v_ax0]
                 for ax0 in range(T.int64(3)):
                     with T.block("T_add_2"):
                         v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                        T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0])
+                        T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0])
                         T.writes(T_add_1[v_ax0])
-                        T_add_1[v_ax0] = T_multiply_1[v_ax0] + 
T_multiply_2[v_ax0]
+                        T_add_1[v_ax0] = T_multiply_2[v_ax0] + 
T_multiply_3[v_ax0]
                 for ax0 in range(T.int64(3)):
-                    with T.block("T_multiply_3"):
+                    with T.block("T_multiply_4"):
                         v_ax0 = T.axis.spatial(T.int64(3), ax0)
                         T.reads(moving_var[v_ax0])
-                        T.writes(T_multiply_3[v_ax0])
-                        T_multiply_3[v_ax0] = T.float32(0.90000000000000002) * 
moving_var[v_ax0]
-                for ax0 in range(T.int64(1)):
-                    for ax1 in range(T.int64(3)):
-                        for ax2 in range(T.int64(1)):
-                            for ax3 in range(T.int64(1)):
-                                with T.block("T_reshape_4"):
-                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
-                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
-                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
-                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
-                                    T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3) 
% T.int64(3)])
-                                    T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
-                for ax0 in range(T.int64(2)):
-                    for ax1 in range(T.int64(3)):
-                        for ax2 in range(T.int64(28)):
-                            for ax3 in range(T.int64(28)):
-                                with T.block("T_subtract_1"):
-                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
-                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
-                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
-                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
-                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                                    T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
-                for ax0 in range(T.int64(2)):
-                    for ax1 in range(T.int64(3)):
-                        for ax2 in range(T.int64(28)):
-                            for ax3 in range(T.int64(28)):
-                                with T.block("T_subtract_2"):
-                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
-                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
-                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
-                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
-                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                                    T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
-                for ax0 in range(T.int64(2)):
-                    for ax1 in range(T.int64(3)):
-                        for ax2 in range(T.int64(28)):
-                            for ax3 in range(T.int64(28)):
-                                with T.block("T_multiply_4"):
-                                    v_ax0 = T.axis.spatial(T.int64(2), ax0)
-                                    v_ax1 = T.axis.spatial(T.int64(3), ax1)
-                                    v_ax2 = T.axis.spatial(T.int64(28), ax2)
-                                    v_ax3 = T.axis.spatial(T.int64(28), ax3)
-                                    T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, 
v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
-                                    T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, 
v_ax3]
-                for ax0 in range(T.int64(3)):
-                    for k0 in range(T.int64(2)):
-                        for k2 in range(T.int64(28)):
-                            for k3 in range(T.int64(28)):
-                                with T.block("T_multiply_red"):
-                                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                                    v_k0 = T.axis.reduce(T.int64(2), k0)
-                                    v_k2 = T.axis.reduce(T.int64(28), k2)
-                                    v_k3 = T.axis.reduce(T.int64(28), k3)
-                                    T.reads(T_multiply_4[v_k0, v_ax0, v_k2, 
v_k3])
-                                    T.writes(T_multiply_red[v_ax0])
-                                    with T.init():
-                                        T_multiply_red[v_ax0] = T.float32(0.0)
-                                    T_multiply_red[v_ax0] = 
T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3]
-                for ax0 in range(T.int64(3)):
-                    with T.block("T_divide_2"):
-                        v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                        T.reads(T_multiply_red[v_ax0])
-                        T.writes(T_divide_2[v_ax0])
-                        T_divide_2[v_ax0] = T_multiply_red[v_ax0] / 
T.float32(1568)
+                        T.writes(T_multiply_4[v_ax0])
+                        T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * 
moving_var[v_ax0]
                 for ax0 in range(T.int64(3)):
                     with T.block("T_multiply_5"):
                         v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                        T.reads(T_divide_2[v_ax0])
+                        T.reads(T_divide_1[v_ax0])
                         T.writes(T_multiply_5[v_ax0])
-                        T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * 
T_divide_2[v_ax0]
+                        T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * 
T_divide_1[v_ax0]
                 for ax0 in range(T.int64(3)):
                     with T.block("T_add_3"):
                         v_ax0 = T.axis.spatial(T.int64(3), ax0)
-                        T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0])
+                        T.reads(T_multiply_4[v_ax0], T_multiply_5[v_ax0])
                         T.writes(T_add_2[v_ax0])
-                        T_add_2[v_ax0] = T_multiply_3[v_ax0] + 
T_multiply_5[v_ax0]
+                        T_add_2[v_ax0] = T_multiply_4[v_ax0] + 
T_multiply_5[v_ax0]
 
         @R.function
         def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: 
R.Tensor((3,), dtype="float32"), beta: R.Tensor((3,), dtype="float32"), 
moving_mean: R.Tensor((3,), dtype="float32"), moving_var: R.Tensor((3,), 
dtype="float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), dtype="float32"), 
R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")):
@@ -2353,27 +2340,46 @@ def test_batch_norm_symbolic():
             with T.block("root"):
                 T.reads()
                 T.writes()
+                x_red = T.alloc_buffer((h,))
+                T_divide = T.alloc_buffer((h,))
                 T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
                 T_subtract = T.alloc_buffer((n, h, w, c))
+                T_subtract_1 = T.alloc_buffer((n, h, w, c))
+                T_subtract_2 = T.alloc_buffer((n, h, w, c))
+                T_multiply = T.alloc_buffer((n, h, w, c))
+                T_multiply_red = T.alloc_buffer((h,))
+                T_divide_1 = T.alloc_buffer((h,))
                 T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
                 T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
                 compute = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
-                T_divide = T.alloc_buffer((n, h, w, c))
+                T_divide_2 = T.alloc_buffer((n, h, w, c))
                 T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
-                T_multiply = T.alloc_buffer((n, h, w, c))
+                T_multiply_1 = T.alloc_buffer((n, h, w, c))
                 T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
-                T_multiply_1 = T.alloc_buffer((c,))
-                x_red = T.alloc_buffer((h,))
-                T_divide_1 = T.alloc_buffer((h,))
-                T_multiply_2 = T.alloc_buffer((h,))
-                T_multiply_3 = T.alloc_buffer((c,))
-                T_reshape_4 = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1)))
-                T_subtract_1 = T.alloc_buffer((n, h, w, c))
-                T_subtract_2 = T.alloc_buffer((n, h, w, c))
-                T_multiply_4 = T.alloc_buffer((n, h, w, c))
-                T_multiply_red = T.alloc_buffer((h,))
-                T_divide_2 = T.alloc_buffer((h,))
+                T_multiply_2 = T.alloc_buffer((c,))
+                T_multiply_3 = T.alloc_buffer((h,))
+                T_multiply_4 = T.alloc_buffer((c,))
                 T_multiply_5 = T.alloc_buffer((h,))
+                for ax0 in range(h):
+                    for k0 in range(n):
+                        for k2 in range(w):
+                            for k3 in range(c):
+                                with T.block("x_red"):
+                                    v_ax0 = T.axis.spatial(h, ax0)
+                                    v_k0 = T.axis.reduce(n, k0)
+                                    v_k2 = T.axis.reduce(w, k2)
+                                    v_k3 = T.axis.reduce(c, k3)
+                                    T.reads(x[v_k0, v_ax0, v_k2, v_k3])
+                                    T.writes(x_red[v_ax0])
+                                    with T.init():
+                                        x_red[v_ax0] = T.float32(0.0)
+                                    x_red[v_ax0] = x_red[v_ax0] + x[v_k0, 
v_ax0, v_k2, v_k3]
+                for ax0 in range(h):
+                    with T.block("T_divide"):
+                        v_ax0 = T.axis.spatial(h, ax0)
+                        T.reads(x_red[v_ax0])
+                        T.writes(T_divide[v_ax0])
+                        T_divide[v_ax0] = x_red[v_ax0] / T.Cast("float32", n * 
w * c)
                 for ax0 in range(T.int64(1)):
                     for ax1 in range(h):
                         for ax2 in range(T.int64(1)):
@@ -2383,9 +2389,9 @@ def test_batch_norm_symbolic():
                                     v_ax1 = T.axis.spatial(h, ax1)
                                     v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                     v_ax3 = T.axis.spatial(T.int64(1), ax3)
-                                    T.reads(moving_mean[(v_ax0 * h + v_ax1 + 
v_ax2 + v_ax3) % c])
+                                    T.reads(T_divide[(v_ax0 * h + v_ax1 + 
v_ax2 + v_ax3) % h])
                                     T.writes(T_reshape[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = 
moving_mean[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
+                                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]
                 for ax0 in range(n):
                     for ax1 in range(h):
                         for ax2 in range(w):
@@ -2398,6 +2404,62 @@ def test_batch_norm_symbolic():
                                     T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                     T.writes(T_subtract[v_ax0, v_ax1, v_ax2, 
v_ax3])
                                     T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                for ax0 in range(n):
+                    for ax1 in range(h):
+                        for ax2 in range(w):
+                            for ax3 in range(c):
+                                with T.block("T_subtract_1"):
+                                    v_ax0 = T.axis.spatial(n, ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(w, ax2)
+                                    v_ax3 = T.axis.spatial(c, ax3)
+                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                for ax0 in range(n):
+                    for ax1 in range(h):
+                        for ax2 in range(w):
+                            for ax3 in range(c):
+                                with T.block("T_subtract_2"):
+                                    v_ax0 = T.axis.spatial(n, ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(w, ax2)
+                                    v_ax3 = T.axis.spatial(c, ax3)
+                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                for ax0 in range(n):
+                    for ax1 in range(h):
+                        for ax2 in range(w):
+                            for ax3 in range(c):
+                                with T.block("T_multiply"):
+                                    v_ax0 = T.axis.spatial(n, ax0)
+                                    v_ax1 = T.axis.spatial(h, ax1)
+                                    v_ax2 = T.axis.spatial(w, ax2)
+                                    v_ax3 = T.axis.spatial(c, ax3)
+                                    T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, 
v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
+                                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, 
v_ax3]
+                for ax0 in range(h):
+                    for k0 in range(n):
+                        for k2 in range(w):
+                            for k3 in range(c):
+                                with T.block("T_multiply_red"):
+                                    v_ax0 = T.axis.spatial(h, ax0)
+                                    v_k0 = T.axis.reduce(n, k0)
+                                    v_k2 = T.axis.reduce(w, k2)
+                                    v_k3 = T.axis.reduce(c, k3)
+                                    T.reads(T_multiply[v_k0, v_ax0, v_k2, 
v_k3])
+                                    T.writes(T_multiply_red[v_ax0])
+                                    with T.init():
+                                        T_multiply_red[v_ax0] = T.float32(0.0)
+                                    T_multiply_red[v_ax0] = 
T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3]
+                for ax0 in range(h):
+                    with T.block("T_divide_1"):
+                        v_ax0 = T.axis.spatial(h, ax0)
+                        T.reads(T_multiply_red[v_ax0])
+                        T.writes(T_divide_1[v_ax0])
+                        T_divide_1[v_ax0] = T_multiply_red[v_ax0] / 
T.Cast("float32", n * w * c)
                 for ax0 in range(T.int64(1)):
                     for ax1 in range(h):
                         for ax2 in range(T.int64(1)):
@@ -2407,9 +2469,9 @@ def test_batch_norm_symbolic():
                                     v_ax1 = T.axis.spatial(h, ax1)
                                     v_ax2 = T.axis.spatial(T.int64(1), ax2)
                                     v_ax3 = T.axis.spatial(T.int64(1), ax3)
-                                    T.reads(moving_var[(v_ax0 * h + v_ax1 + 
v_ax2 + v_ax3) % c])
+                                    T.reads(T_divide_1[(v_ax0 * h + v_ax1 + 
v_ax2 + v_ax3) % h])
                                     T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
moving_var[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
+                                    T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]
                 for ax0 in range(T.int64(1)):
                     for ax1 in range(h):
                         for ax2 in range(T.int64(1)):
@@ -2438,14 +2500,14 @@ def test_batch_norm_symbolic():
                     for ax1 in range(h):
                         for ax2 in range(w):
                             for ax3 in range(c):
-                                with T.block("T_divide"):
+                                with T.block("T_divide_2"):
                                     v_ax0 = T.axis.spatial(n, ax0)
                                     v_ax1 = T.axis.spatial(h, ax1)
                                     v_ax2 = T.axis.spatial(w, ax2)
                                     v_ax3 = T.axis.spatial(c, ax3)
                                     T.reads(T_subtract[v_ax0, v_ax1, v_ax2, 
v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                                    T.writes(T_divide[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+                                    T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
                 for ax0 in range(T.int64(1)):
                     for ax1 in range(h):
                         for ax2 in range(T.int64(1)):
@@ -2462,14 +2524,14 @@ def test_batch_norm_symbolic():
                     for ax1 in range(h):
                         for ax2 in range(w):
                             for ax3 in range(c):
-                                with T.block("T_multiply"):
+                                with T.block("T_multiply_1"):
                                     v_ax0 = T.axis.spatial(n, ax0)
                                     v_ax1 = T.axis.spatial(h, ax1)
                                     v_ax2 = T.axis.spatial(w, ax2)
                                     v_ax3 = T.axis.spatial(c, ax3)
-                                    T.reads(T_divide[v_ax0, v_ax1, v_ax2, 
v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
+                                    T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, 
v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                    T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
                 for ax0 in range(T.int64(1)):
                     for ax1 in range(h):
                         for ax2 in range(T.int64(1)):
@@ -2491,133 +2553,45 @@ def test_batch_norm_symbolic():
                                     v_ax1 = T.axis.spatial(h, ax1)
                                     v_ax2 = T.axis.spatial(w, ax2)
                                     v_ax3 = T.axis.spatial(c, ax3)
-                                    T.reads(T_multiply[v_ax0, v_ax1, v_ax2, 
v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                                    T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, 
v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
                                     T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
-                                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
+                                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]
                 for ax0 in range(c):
-                    with T.block("T_multiply_1"):
+                    with T.block("T_multiply_2"):
                         v_ax0 = T.axis.spatial(c, ax0)
                         T.reads(moving_mean[v_ax0])
-                        T.writes(T_multiply_1[v_ax0])
-                        T_multiply_1[v_ax0] = T.float32(0.90000000000000002) * 
moving_mean[v_ax0]
-                for ax0 in range(h):
-                    for k0 in range(n):
-                        for k2 in range(w):
-                            for k3 in range(c):
-                                with T.block("x_red"):
-                                    v_ax0 = T.axis.spatial(h, ax0)
-                                    v_k0 = T.axis.reduce(n, k0)
-                                    v_k2 = T.axis.reduce(w, k2)
-                                    v_k3 = T.axis.reduce(c, k3)
-                                    T.reads(x[v_k0, v_ax0, v_k2, v_k3])
-                                    T.writes(x_red[v_ax0])
-                                    with T.init():
-                                        x_red[v_ax0] = T.float32(0.0)
-                                    x_red[v_ax0] = x_red[v_ax0] + x[v_k0, 
v_ax0, v_k2, v_k3]
-                for ax0 in range(h):
-                    with T.block("T_divide_1"):
-                        v_ax0 = T.axis.spatial(h, ax0)
-                        T.reads(x_red[v_ax0])
-                        T.writes(T_divide_1[v_ax0])
-                        T_divide_1[v_ax0] = x_red[v_ax0] / T.Cast("float32", n 
* w * c)
+                        T.writes(T_multiply_2[v_ax0])
+                        T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * 
moving_mean[v_ax0]
                 for ax0 in range(h):
-                    with T.block("T_multiply_2"):
+                    with T.block("T_multiply_3"):
                         v_ax0 = T.axis.spatial(h, ax0)
-                        T.reads(T_divide_1[v_ax0])
-                        T.writes(T_multiply_2[v_ax0])
-                        T_multiply_2[v_ax0] = T.float32(0.10000000000000001) * 
T_divide_1[v_ax0]
+                        T.reads(T_divide[v_ax0])
+                        T.writes(T_multiply_3[v_ax0])
+                        T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * 
T_divide[v_ax0]
                 for ax0 in range(T.max(c, h)):
                     with T.block("T_add_2"):
                         v_ax0 = T.axis.spatial(T.max(c, h), ax0)
-                        T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0])
+                        T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0])
                         T.writes(T_add_1[v_ax0])
-                        T_add_1[v_ax0] = T_multiply_1[v_ax0] + 
T_multiply_2[v_ax0]
+                        T_add_1[v_ax0] = T_multiply_2[v_ax0] + 
T_multiply_3[v_ax0]
                 for ax0 in range(c):
-                    with T.block("T_multiply_3"):
+                    with T.block("T_multiply_4"):
                         v_ax0 = T.axis.spatial(c, ax0)
                         T.reads(moving_var[v_ax0])
-                        T.writes(T_multiply_3[v_ax0])
-                        T_multiply_3[v_ax0] = T.float32(0.90000000000000002) * 
moving_var[v_ax0]
-                for ax0 in range(T.int64(1)):
-                    for ax1 in range(h):
-                        for ax2 in range(T.int64(1)):
-                            for ax3 in range(T.int64(1)):
-                                with T.block("T_reshape_4"):
-                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
-                                    v_ax1 = T.axis.spatial(h, ax1)
-                                    v_ax2 = T.axis.spatial(T.int64(1), ax2)
-                                    v_ax3 = T.axis.spatial(T.int64(1), ax3)
-                                    T.reads(T_divide_1[(v_ax0 * h + v_ax1 + 
v_ax2 + v_ax3) % h])
-                                    T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]
-                for ax0 in range(n):
-                    for ax1 in range(h):
-                        for ax2 in range(w):
-                            for ax3 in range(c):
-                                with T.block("T_subtract_1"):
-                                    v_ax0 = T.axis.spatial(n, ax0)
-                                    v_ax1 = T.axis.spatial(h, ax1)
-                                    v_ax2 = T.axis.spatial(w, ax2)
-                                    v_ax3 = T.axis.spatial(c, ax3)
-                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                                    T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
-                for ax0 in range(n):
-                    for ax1 in range(h):
-                        for ax2 in range(w):
-                            for ax3 in range(c):
-                                with T.block("T_subtract_2"):
-                                    v_ax0 = T.axis.spatial(n, ax0)
-                                    v_ax1 = T.axis.spatial(h, ax1)
-                                    v_ax2 = T.axis.spatial(w, ax2)
-                                    v_ax3 = T.axis.spatial(c, ax3)
-                                    T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
-                                    T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
-                for ax0 in range(n):
-                    for ax1 in range(h):
-                        for ax2 in range(w):
-                            for ax3 in range(c):
-                                with T.block("T_multiply_4"):
-                                    v_ax0 = T.axis.spatial(n, ax0)
-                                    v_ax1 = T.axis.spatial(h, ax1)
-                                    v_ax2 = T.axis.spatial(w, ax2)
-                                    v_ax3 = T.axis.spatial(c, ax3)
-                                    T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, 
v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
-                                    T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2, 
v_ax3])
-                                    T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, 
v_ax3]
-                for ax0 in range(h):
-                    for k0 in range(n):
-                        for k2 in range(w):
-                            for k3 in range(c):
-                                with T.block("T_multiply_red"):
-                                    v_ax0 = T.axis.spatial(h, ax0)
-                                    v_k0 = T.axis.reduce(n, k0)
-                                    v_k2 = T.axis.reduce(w, k2)
-                                    v_k3 = T.axis.reduce(c, k3)
-                                    T.reads(T_multiply_4[v_k0, v_ax0, v_k2, 
v_k3])
-                                    T.writes(T_multiply_red[v_ax0])
-                                    with T.init():
-                                        T_multiply_red[v_ax0] = T.float32(0.0)
-                                    T_multiply_red[v_ax0] = 
T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3]
-                for ax0 in range(h):
-                    with T.block("T_divide_2"):
-                        v_ax0 = T.axis.spatial(h, ax0)
-                        T.reads(T_multiply_red[v_ax0])
-                        T.writes(T_divide_2[v_ax0])
-                        T_divide_2[v_ax0] = T_multiply_red[v_ax0] / 
T.Cast("float32", n * w * c)
+                        T.writes(T_multiply_4[v_ax0])
+                        T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * 
moving_var[v_ax0]
                 for ax0 in range(h):
                     with T.block("T_multiply_5"):
                         v_ax0 = T.axis.spatial(h, ax0)
-                        T.reads(T_divide_2[v_ax0])
+                        T.reads(T_divide_1[v_ax0])
                         T.writes(T_multiply_5[v_ax0])
-                        T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * 
T_divide_2[v_ax0]
+                        T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * 
T_divide_1[v_ax0]
                 for ax0 in range(T.max(c, h)):
                     with T.block("T_add_3"):
                         v_ax0 = T.axis.spatial(T.max(c, h), ax0)
-                        T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0])
+                        T.reads(T_multiply_4[v_ax0], T_multiply_5[v_ax0])
                         T.writes(T_add_2[v_ax0])
-                        T_add_2[v_ax0] = T_multiply_3[v_ax0] + 
T_multiply_5[v_ax0]
+                        T_add_2[v_ax0] = T_multiply_4[v_ax0] + 
T_multiply_5[v_ax0]
 
         @R.function
         def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: 
R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), 
moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), 
dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), 
R.Tensor(("T.max(c, h)",), dtype="float32"), R.Tensor(("T.max(c, h)",), 
dtype="float32")):

Reply via email to