Anemone220 opened a new issue, #18703: URL: https://github.com/apache/tvm/issues/18703
# [Bug][Relax][ONNX] BatchNormalization ignores training_mode attribute, always uses training=True

## Expected behavior

When importing an ONNX model with a `BatchNormalization` operator that has `training_mode=0` (inference mode), TVM Relax should generate `R.nn.batch_norm(..., training=False)` and use the provided `running_mean` and `running_var` parameters for normalization.

According to the [ONNX BatchNormalization specification](https://onnx.ai/onnx/operators/onnx__BatchNormalization.html):

- `training_mode=0` (default): use running statistics (inference mode)
- `training_mode=1`: compute batch statistics from the input (training mode)

## Actual behavior

The TVM Relax ONNX frontend ignores the `training_mode` attribute and always generates `R.nn.batch_norm(..., training=True)`, causing TVM to compute batch statistics from the input tensor instead of using the provided `running_mean` and `running_var`.

**Generated IR (incorrect):**

```
lv: R.Tuple(...) = R.nn.batch_norm(X, scale, bias, mean, var, axis=1, epsilon=1e-05, training=True)
                                                                                     ^^^^^^^^^^^^^
                                                                           Should be training=False!
```

## Environment

- **TVM version**: 0.23.dev0
- **OS**: Ubuntu Linux
- **Target**: llvm (CPU)
- **Python**: 3.11
- **ONNX opset**: 15

## Steps to reproduce

### Minimal reproduction script

Save as `reproduce_bn_bug.py` and run with `python reproduce_bn_bug.py`:

```python
#!/usr/bin/env python3
"""TVM BatchNormalization training_mode bug - minimal reproduction."""
import numpy as np
import onnx
from onnx import helper, TensorProto, numpy_helper
import onnxruntime as ort

import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx


def create_minimal_bn_model():
    """Create a minimal ONNX model with only BatchNormalization (training_mode=0)."""
    batch, channels, height, width = 2, 3, 4, 4
    epsilon = 1e-5

    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [batch, channels, height, width])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [batch, channels, height, width])

    scale = numpy_helper.from_array(np.array([1.0, 2.0, 0.5], dtype=np.float32), name='scale')
    bias = numpy_helper.from_array(np.array([0.0, 1.0, -1.0], dtype=np.float32), name='bias')
    mean = numpy_helper.from_array(np.array([0.5, 1.0, 2.0], dtype=np.float32), name='mean')
    var = numpy_helper.from_array(np.array([0.25, 1.0, 4.0], dtype=np.float32), name='var')

    bn_node = helper.make_node(
        'BatchNormalization',
        inputs=['X', 'scale', 'bias', 'mean', 'var'],
        outputs=['Y'],
        epsilon=epsilon,
        momentum=0.9,
        training_mode=0,  # KEY: inference mode!
    )

    graph = helper.make_graph([bn_node], 'bn_test', [X], [Y], [scale, bias, mean, var])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 15)])
    model.ir_version = 8
    onnx.checker.check_model(model)
    return model


def main():
    print("Creating ONNX model with BatchNormalization (training_mode=0)...")
    model = create_minimal_bn_model()

    # Verify the ONNX attribute.
    for node in model.graph.node:
        if node.op_type == 'BatchNormalization':
            training_mode = next((a.i for a in node.attribute if a.name == 'training_mode'), 0)
            print(f"ONNX training_mode = {training_mode}")

    # Test input.
    np.random.seed(42)
    input_data = np.random.randn(2, 3, 4, 4).astype(np.float32)

    # ONNX Runtime (reference).
    model_bytes = model.SerializeToString()
    sess = ort.InferenceSession(model_bytes, providers=['CPUExecutionProvider'])
    ort_output = sess.run(None, {'X': input_data})[0]
    print(f"ORT output sample: {ort_output[0, 0, 0, :3]}")

    # TVM Relax.
    shape_dict = {'X': list(input_data.shape)}
    mod = from_onnx(model, shape_dict=shape_dict)

    # Check the IR for the bug.
    ir_text = mod.script()
    if 'training=True' in ir_text:
        print("\n[BUG] TVM IR contains training=True (should be False)!")
        for line in ir_text.split('\n'):
            if 'batch_norm' in line:
                print(f"  {line.strip()}")

    # Compile and run.
    target = tvm.target.Target("llvm")
    ex = tvm.compile(mod, target)
    device = tvm.cpu()
    vm = relax.VirtualMachine(ex, device)
    tvm_input = tvm.runtime.tensor(input_data, device=device)
    tvm_output = vm['main'](tvm_input).numpy()
    print(f"TVM output sample: {tvm_output[0, 0, 0, :3]}")

    # Compare.
    max_diff = np.max(np.abs(ort_output - tvm_output))
    print(f"\nMax difference (ORT vs TVM): {max_diff:.6f}")
    if max_diff > 0.001:
        print("\n[BUG CONFIRMED] TVM produces incorrect results!")
        return 1
    return 0


if __name__ == '__main__':
    exit(main())
```

### Expected output

```
Creating ONNX model with BatchNormalization (training_mode=0)...
ONNX training_mode = 0
ORT output sample: [-0.00657159 -1.2765031   0.29537117]

[BUG] TVM IR contains training=True (should be False)!
  lv: R.Tuple(...) = R.nn.batch_norm(X, ..., training=True)
TVM output sample: [ 0.66183543 -0.05753053  0.8328741 ]

Max difference (ORT vs TVM): 2.758021

[BUG CONFIRMED] TVM produces incorrect results!
```

## Triage

- needs-triage
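For reference, the inference-mode (`training_mode=0`) output that ONNX Runtime produces can be checked by hand, since it is just per-channel normalization with the stored running statistics. Below is a minimal NumPy sketch reusing the constants from the reproduction script; the helper name `batch_norm_inference` is illustrative, not part of any library API:

```python
import numpy as np

def batch_norm_inference(x, scale, bias, running_mean, running_var, epsilon=1e-5):
    """Inference-mode BatchNormalization per the ONNX spec:
    Y = scale * (X - running_mean) / sqrt(running_var + epsilon) + bias,
    with per-channel parameters broadcast over an NCHW tensor."""
    shape = (1, -1, 1, 1)  # broadcast channel vectors along axis 1
    return (x - running_mean.reshape(shape)) / np.sqrt(running_var.reshape(shape) + epsilon) \
        * scale.reshape(shape) + bias.reshape(shape)

# Same constants as the reproduction script above.
np.random.seed(42)
x = np.random.randn(2, 3, 4, 4).astype(np.float32)
scale = np.array([1.0, 2.0, 0.5], dtype=np.float32)
bias = np.array([0.0, 1.0, -1.0], dtype=np.float32)
mean = np.array([0.5, 1.0, 2.0], dtype=np.float32)
var = np.array([0.25, 1.0, 4.0], dtype=np.float32)

y = batch_norm_inference(x, scale, bias, mean, var)
```

The key point: nothing in this computation depends on batch statistics of `x`, so a `training=True` lowering (which recomputes mean/variance from the input) cannot match it except by coincidence.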
