rafzi commented on PR #11880:
URL: https://github.com/apache/tvm/pull/11880#issuecomment-1167499812
How about adding the following test to
`tests/python/unittest/test_tir_usmp_algo.py`?
``` python
def test_custom_algo():
    """Check that the USMP pass dispatches to a custom algorithm chosen via PassContext.

    A trivial algorithm is registered as the packed function
    ``tir.usmp.algo.trivial``, and the test then verifies that:

    * the pass does not call it when ``tir.usmp.custom_algorithm`` is unset;
    * the pass calls it when the config option is ``"trivial"``;
    * the pass raises ``tvm.TVMError`` when the config option names an
      algorithm that is not registered (``"invalid"``).
    """
    target = Target("c")
    global_workspace_pool = usmp_utils.PoolInfo(
        pool_name="global_workspace",
        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
    )
    tir_mod = ResnetStructure
    tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
    tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
    tir_mod = tir_mod.with_attr("executor", tvm.relay.backend.Executor("aot"))
    tir_mod = tir_mod.with_attr("runtime", tvm.relay.backend.Runtime("crt"))
    # Expose one fused primfunc under the "__tvm_main__" name.
    # NOTE(review): presumably the planner expects this entry point to exist -- confirm.
    tir_mod["__tvm_main__"] = tir_mod[
        "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast"
    ]

    # Flipped by the custom algorithm so the test can observe whether it ran.
    algo_called = False

    # NOTE(review): register_func without override=True raises if the name is
    # already registered, e.g. when this test runs twice in one process.
    @tvm.register_func("tir.usmp.algo.trivial")
    def _trivial_algo(buf_infos, mem_pressure):
        """Place each buffer in its first candidate pool, packed back to back.

        ``mem_pressure`` is accepted to match the expected signature but is unused.
        Offsets are not aligned, and no buffer memory is reused.
        """
        nonlocal algo_called
        algo_called = True
        out_layout = {}
        offset = 0
        for buf_info in buf_infos:
            pool_info = buf_info.pool_candidates[0]
            out_layout[buf_info] = usmp_utils.PoolAllocation(pool_info, offset)
            offset += buf_info.size_bytes
        return out_layout

    usmp_pass = tvm.get_global_func("tir.transform.UnifiedStaticMemoryPlanner")

    # Without the config option, the custom algorithm must not be invoked.
    usmp_pass()(tir_mod)
    assert not algo_called

    # Selecting "trivial" must route planning through the registered function.
    with tvm.transform.PassContext(config={"tir.usmp.custom_algorithm": "trivial"}):
        usmp_pass()(tir_mod)
    assert algo_called

    # Selecting an unregistered algorithm name must fail loudly.
    with pytest.raises(
        tvm.TVMError, match="The selected custom USMP algorithm : invalid is not defined"
    ):
        with tvm.transform.PassContext(config={"tir.usmp.custom_algorithm": "invalid"}):
            usmp_pass()(tir_mod)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]