This is an automated email from the ASF dual-hosted git repository.
lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9899f9cd28 [AOT][Testing] Improve output mismatch information on test failure (#16765)
9899f9cd28 is described below
commit 9899f9cd2801b3234437df5cd8ab10504b9608bc
Author: Andrei Hutu <[email protected]>
AuthorDate: Mon Mar 25 09:06:49 2024 +0000
[AOT][Testing] Improve output mismatch information on test failure (#16765)
Enhanced the AOT test harness to include the overall mismatch percentage and
the individual mismatch positions from the output tensor, to help debug test
failures. Both of these are still gated behind
`print_output_on_mismatch == True`.
I also added tests to check for the presence and correctness of this new
debug information.
Sample output:
```
Element [Position]: Actual, Reference
-------------------------------------
Element [0, 8, 8, 7]: 521.846313, 521.847412
Element [0, 9, 8, 51]: 478.874359, 478.875549
Element [0, 9, 9, 6]: 462.901001, 462.899658
Mismatched elements: 3 / 16384 (0.02%)
...
```
---
python/tvm/testing/aot.py | 48 ++++++++++++++++-------
tests/python/relay/aot/test_aot_test_harness.py | 52 ++++++++++++++++++++++++-
2 files changed, 85 insertions(+), 15 deletions(-)
diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py
index 3a117624df..959d1cf58e 100644
--- a/python/tvm/testing/aot.py
+++ b/python/tvm/testing/aot.py
@@ -476,20 +476,40 @@ def _emit_main_compare(
if print_output_on_mismatch:
main_file.write(
- f"int mismatch = 0;"
- f'printf("Actual, Reference\\n");\n'
- f"for (int i = 0; i<{data_length_var_name}; i++) {{\n"
- f"\tif ({comparison_function}({actual_data_name}[i]-"
- f"{expected_data_name}[i]) > {tolerance}) {{\n"
- f'\t\tprintf("{value_format_specifier},
{value_format_specifier}\\n"'
- f", {actual_data_name}[i], {expected_data_name}[i]);\n"
- f"\t\tmismatch = 1;\n"
- f"\t}}\n"
- f"}}"
- f"if (mismatch == 1) {{\n"
- f'\tprintf("{AOT_FAILURE_TOKEN}\\n");\n'
- f"\treturn -1;\n"
- f"}}"
+ f"""
+ {{
+ int mismatch = 0;
+ int out_ndim = {outputs[key].ndim};
+ int out_shape[] = {{{','.join(map(str, outputs[key].shape))}}};
+ int out_indices[out_ndim];
+ printf("Element [Position]: Actual, Reference\\n");
+ printf("-------------------------------------\\n");
+ for (int i = 0; i<{data_length_var_name}; i++) {{
+ if ({comparison_function}({actual_data_name}[i] -
+ {expected_data_name}[i]) > {tolerance}) {{
+ int flat_index = i;
+ for (int j = out_ndim - 1; j >= 0; j--){{
+ out_indices[j] = flat_index % out_shape[j];
+ flat_index /= out_shape[j];
+ }}
+ printf("Element [%d", out_indices[0]);
+ for (int j = 1; j < out_ndim; j++)
+ printf(", %d", out_indices[j]);
+ printf("]: {value_format_specifier},
{value_format_specifier}\\n",
+ {actual_data_name}[i], {expected_data_name}[i]);
+ mismatch += 1;
+ }}
+ }}
+ if (mismatch >= 1) {{
+ float percent_mismatched =
+ ((float) mismatch) / ((float) {data_length_var_name}) *
100;
+ printf("\\nMismatched elements: %d / %zu (%.2f%%)\\n",
+ mismatch, {data_length_var_name}, percent_mismatched);
+ printf("{AOT_FAILURE_TOKEN}\\n");
+ return -1;
+ }}
+ }}
+ """
)
else:
main_file.write(
diff --git a/tests/python/relay/aot/test_aot_test_harness.py b/tests/python/relay/aot/test_aot_test_harness.py
index 8ec9506f9f..3d10f15d4a 100644
--- a/tests/python/relay/aot/test_aot_test_harness.py
+++ b/tests/python/relay/aot/test_aot_test_harness.py
@@ -46,7 +46,57 @@ def test_output_on_mismatch_option():
).astype(dtype)
}
- msg = ".*Actual, Reference\n2.000000, 0.000000\nAOT_TEST_FAILURE.*"
+ msg = ".*Actual, Reference(\n|.)*2.000000,
0.000000(\n|.)*AOT_TEST_FAILURE.*"
+ with pytest.raises(RuntimeError, match=msg):
+ compile_and_run(
+ AOTTestModel(module=tvm.IRModule.from_expr(func), inputs={},
outputs=outputs),
+ test_runner,
+ interface_api,
+ use_unpacked_api,
+ print_output_on_mismatch=True,
+ )
+
+
+def test_output_position_on_mismatch():
+ """
+ Test the mismatch position output for the print_output_on_mismatch option.
+ """
+ interface_api = "packed"
+ use_unpacked_api = True
+ test_runner = AOTTestRunner()
+ dtype = "float32"
+
+ x = np.zeros(shape=(2, 2), dtype=dtype)
+ x[-1, -1] = 1
+ func = relay.Function([], relay.const(x, dtype=dtype))
+ outputs = {"output": np.zeros(shape=(2, 2), dtype=dtype)}
+
+ msg = ".*Element \\[1, 1\\]:.*"
+ with pytest.raises(RuntimeError, match=msg):
+ compile_and_run(
+ AOTTestModel(module=tvm.IRModule.from_expr(func), inputs={},
outputs=outputs),
+ test_runner,
+ interface_api,
+ use_unpacked_api,
+ print_output_on_mismatch=True,
+ )
+
+
+def test_mismatch_percentage():
+ """
+ Test the mismatch percentage for the print_output_on_mismatch option.
+ """
+ interface_api = "packed"
+ use_unpacked_api = True
+ test_runner = AOTTestRunner()
+ dtype = "float32"
+
+ x = np.zeros(shape=(8,), dtype=dtype)
+ x[0] = 1
+ func = relay.Function([], relay.const(x, dtype=dtype))
+ outputs = {"output": np.zeros(shape=(8,), dtype=dtype)}
+
+ msg = ".*Mismatched elements: 1 / 8 \\(12.50%\\).*"
with pytest.raises(RuntimeError, match=msg):
compile_and_run(
AOTTestModel(module=tvm.IRModule.from_expr(func), inputs={},
outputs=outputs),