Anemone220 opened a new issue, #18703:
URL: https://github.com/apache/tvm/issues/18703

   # [Bug][Relax][ONNX] BatchNormalization ignores training_mode attribute, always uses training=True
   
   ## Expected behavior
   
   When importing an ONNX model whose `BatchNormalization` operator has `training_mode=0` (inference mode), TVM Relax should generate `R.nn.batch_norm(..., training=False)` and use the provided `running_mean` and `running_var` for normalization.
   
   According to the [ONNX BatchNormalization specification](https://onnx.ai/onnx/operators/onnx__BatchNormalization.html):
   - `training_mode=0` (default): Use running statistics (inference mode)
   - `training_mode=1`: Compute batch statistics from input (training mode)
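To make the difference concrete, here is a toy, pure-Python sketch (numbers invented for illustration, not taken from the model below): inference mode applies the stored running statistics directly, while training mode recomputes the statistics from the batch itself, so the two modes diverge whenever the running statistics differ from the batch's own.

```python
import math

# Toy single-channel batch; the stored (running) statistics deliberately
# differ from the batch's own statistics.
x = [0.5, 1.5, 2.5, 3.5]
scale, bias, eps = 1.0, 0.0, 1e-5
running_mean, running_var = 0.0, 1.0  # stored statistics

# Inference mode (training_mode=0): normalize with the stored statistics.
y_infer = [scale * (v - running_mean) / math.sqrt(running_var + eps) + bias
           for v in x]

# Training mode (training_mode=1): recompute statistics from the batch.
batch_mean = sum(x) / len(x)                                # 2.0
batch_var = sum((v - batch_mean) ** 2 for v in x) / len(x)  # 1.25
y_train = [scale * (v - batch_mean) / math.sqrt(batch_var + eps) + bias
           for v in x]

print(y_infer[0], y_train[0])  # roughly 0.500 vs -1.342: the modes disagree
```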
   
   ## Actual behavior
   
   The TVM Relax ONNX frontend ignores the `training_mode` attribute and always generates `R.nn.batch_norm(..., training=True)`, so TVM computes batch statistics from the input tensor instead of using the provided `running_mean` and `running_var`.
   
   **Generated IR (incorrect):**
   ```
   lv: R.Tuple(...) = R.nn.batch_norm(X, scale, bias, mean, var, axis=1,
                                      epsilon=1e-05, training=True)
                                                     ^^^^^^^^^^^^^ should be training=False!
   ```
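The expected lowering follows directly from the ONNX default (an absent `training_mode` means 0): the frontend should derive the `training` flag from the attribute instead of hard-coding it. A minimal sketch of that attribute handling; the helper name is hypothetical, and `attr` is a plain dict standing in for wherever the converter parses node attributes (the real converter API may differ):

```python
def resolve_training_flag(attr):
    """Map the ONNX training_mode attribute to a boolean training flag.

    `attr` is a plain dict standing in for the converter's parsed ONNX
    attributes (hypothetical; the actual TVM frontend API may differ).
    """
    # ONNX defaults training_mode to 0 (inference) when the attribute is absent.
    return bool(attr.get("training_mode", 0))


print(resolve_training_flag({}))                    # False (inference mode)
print(resolve_training_flag({"training_mode": 1}))  # True  (training mode)
```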
   
   ## Environment
   
   - **TVM version**: 0.23.dev0
   - **OS**: Ubuntu Linux
   - **Target**: llvm (CPU)
   - **Python**: 3.11
   - **ONNX opset**: 15
   
   ## Steps to reproduce
   
   ### Minimal reproduction script
   
   Save as `reproduce_bn_bug.py` and run with `python reproduce_bn_bug.py`:
   
   ```python
   #!/usr/bin/env python3
   """
   TVM BatchNormalization training_mode Bug - Minimal Reproduction
   """
   
   import numpy as np
   import onnx
   from onnx import helper, TensorProto, numpy_helper
   import onnxruntime as ort
   import tvm
   from tvm.relax.frontend.onnx import from_onnx
   from tvm import relax
   
   
   def create_minimal_bn_model():
       """Create a minimal ONNX model with only BatchNormalization 
(training_mode=0)."""
       batch, channels, height, width = 2, 3, 4, 4
       epsilon = 1e-5
   
        X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [batch, channels, height, width])
        Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [batch, channels, height, width])

        scale = numpy_helper.from_array(np.array([1.0, 2.0, 0.5], dtype=np.float32), name='scale')
        bias = numpy_helper.from_array(np.array([0.0, 1.0, -1.0], dtype=np.float32), name='bias')
        mean = numpy_helper.from_array(np.array([0.5, 1.0, 2.0], dtype=np.float32), name='mean')
        var = numpy_helper.from_array(np.array([0.25, 1.0, 4.0], dtype=np.float32), name='var')
   
       bn_node = helper.make_node(
           'BatchNormalization',
           inputs=['X', 'scale', 'bias', 'mean', 'var'],
           outputs=['Y'],
           epsilon=epsilon,
           momentum=0.9,
           training_mode=0  # KEY: inference mode!
       )
   
        graph = helper.make_graph([bn_node], 'bn_test', [X], [Y], [scale, bias, mean, var])
        model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 15)])
       model.ir_version = 8
       onnx.checker.check_model(model)
       return model
   
   
   def main():
       print("Creating ONNX model with BatchNormalization (training_mode=0)...")
       model = create_minimal_bn_model()
   
       # Verify ONNX attribute
       for node in model.graph.node:
           if node.op_type == 'BatchNormalization':
                training_mode = next((a.i for a in node.attribute if a.name == 'training_mode'), 0)
               print(f"ONNX training_mode = {training_mode}")
   
       # Test input
       np.random.seed(42)
       input_data = np.random.randn(2, 3, 4, 4).astype(np.float32)
   
       # ONNX Runtime (reference)
       model_bytes = model.SerializeToString()
        sess = ort.InferenceSession(model_bytes, providers=['CPUExecutionProvider'])
       ort_output = sess.run(None, {'X': input_data})[0]
       print(f"ORT output sample: {ort_output[0, 0, 0, :3]}")
   
       # TVM Relax
       shape_dict = {'X': list(input_data.shape)}
       mod = from_onnx(model, shape_dict=shape_dict)
   
       # Check IR for the bug
       ir_text = mod.script()
       if 'training=True' in ir_text:
           print("\n[BUG] TVM IR contains training=True (should be False)!")
           for line in ir_text.split('\n'):
               if 'batch_norm' in line:
                   print(f"  {line.strip()}")
   
       # Compile and run
       target = tvm.target.Target("llvm")
       ex = tvm.compile(mod, target)
       device = tvm.cpu()
       vm = relax.VirtualMachine(ex, device)
       tvm_input = tvm.runtime.tensor(input_data, device=device)
       tvm_output = vm['main'](tvm_input).numpy()
       print(f"TVM output sample: {tvm_output[0, 0, 0, :3]}")
   
       # Compare
       max_diff = np.max(np.abs(ort_output - tvm_output))
       print(f"\nMax difference (ORT vs TVM): {max_diff:.6f}")
   
       if max_diff > 0.001:
           print("\n[BUG CONFIRMED] TVM produces incorrect results!")
           return 1
       return 0
   
   
   if __name__ == '__main__':
       exit(main())
   ```
   
   ### Expected output
   
   ```
   Creating ONNX model with BatchNormalization (training_mode=0)...
   ONNX training_mode = 0
   ORT output sample: [-0.00657159 -1.2765031   0.29537117]
   
   [BUG] TVM IR contains training=True (should be False)!
     lv: R.Tuple(...) = R.nn.batch_norm(X, ..., training=True)
   TVM output sample: [ 0.66183543 -0.05753053  0.8328741 ]
   
   Max difference (ORT vs TVM): 2.758021
   
   [BUG CONFIRMED] TVM produces incorrect results!
   ```
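As a sanity check that the ORT numbers above really are the correct inference-mode results (not an ONNX Runtime quirk), the channel-0 sample can be recomputed by hand from the script's constants using the inference formula `Y = scale * (X - mean) / sqrt(var + eps) + bias`:

```python
import numpy as np

# Same seed and input shape as the reproduction script above.
np.random.seed(42)
x = np.random.randn(2, 3, 4, 4).astype(np.float32)

# Channel-0 constants from the script: scale=1.0, bias=0.0, mean=0.5, var=0.25.
scale, bias, mean, var, eps = 1.0, 0.0, 0.5, 0.25, 1e-5
y = scale * (x[0, 0, 0, :3] - mean) / np.sqrt(var + eps) + bias

print(y)  # matches the ORT sample, approximately [-0.00657 -1.27650  0.29537]
```

This agrees with the ONNX Runtime output and differs from the TVM output, confirming that TVM, not ORT, is the side deviating from the specification.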
   
   ## Triage
   - needs-triage


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

