Mousius commented on code in PR #13655:
URL: https://github.com/apache/tvm/pull/13655#discussion_r1060610992
##########
tests/python/relay/aot/test_crt_aot.py:
##########
@@ -225,6 +225,64 @@ def test_packed_global_variables():
assert f"{func}_packed" not in tvmgen_names
+def test_io_size_definition():
+ """Check network IO size definitions in the codegen output."""
+ dtype = "float32"
+ ishape = (1, 32, 14, 14)
+ wshape = (32, 32, 3, 3)
+ interface_api = "c"
+ use_unpacked_api = True
+
+ data0 = relay.var("data", shape=ishape, dtype=dtype)
+ weight0 = relay.var("weight", shape=wshape, dtype=dtype)
+ out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1),
groups=1)
+ main_f = relay.Function([data0, weight0], out)
+ mod = tvm.IRModule()
+ mod["main"] = main_f
+ mod = transform.InferType()(mod)
+
+ i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+ w1_data = np.random.uniform(0, 1, wshape).astype(dtype)
+
+ inputs = OrderedDict([("data", i_data), ("weight", w1_data)])
+
+ output_list = generate_ref_data(mod, inputs)
+ compiled_models_list = compile_models(
+ models=AOTTestModel(module=mod, inputs=inputs, outputs=output_list),
+ interface_api=interface_api,
+ use_unpacked_api=use_unpacked_api,
+ workspace_byte_alignment=8,
+ enable_op_fusion=True,
+ pass_config=AOT_DEFAULT_RUNNER.pass_config,
+ use_runtime_executor=True,
+ target=tvm.target.Target("c"),
+ )
+ ref_output_size = output_list["output"].size * np.dtype(dtype).itemsize
+ compiled_model = compiled_models_list[0]
+
+ tmp_path = utils.tempdir()
+ base_path = tmp_path.temp_dir
+
+ model = compiled_model.model
+ tar_file = os.path.join(base_path, f"{model.name}.tar")
+ export_model_library_format(compiled_model.executor_factory, tar_file)
+ t = tarfile.open(tar_file)
+ t.extractall(base_path)
+
+ file_list = []
+ for path in (pathlib.Path(base_path) / "codegen" / "host" /
"include").iterdir():
+ if path.is_file():
+ file_list.append(path)
+ assert len(file_list) > 0
+
+ for path in file_list:
+ with open(path, "r") as header:
+ contents = header.readlines()
+ contents = "".join(map(str, contents))
+ assert contents.count("_SIZE") == 4
+ assert str(ref_output_size) in contents
Review Comment:
Something like:
```
assert contents.count("_SIZE") == 4
assert f"INPUT_1_SIZE {ref_output_size}" in contents
```
?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]