mbs-octoml opened a new issue, #11661:
URL: https://github.com/apache/tvm/issues/11661
I wanted to unit test TEPass with something like:
```
import tvm


def test_lower_primitive():
    input_mod = tvm.parser.parse(
        """
        #[version = "0.0.5"]
        def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
          %0 = fn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32],
                  Primitive=1) -> Tensor[(5, 7), float32] {
            add(%x, %y)
          };
          %0(%a, %a)
        }
        """,
        "from_string", None, None,
    )
    # `transform` is a placeholder for the TEPass under test; it is not defined here.
    actual_mod = transform(input_mod)
    # Intended result. As written this does not parse: the call_lowered metadata
    # cannot be spelled inline and `<lowered PrimFunc>` is only a placeholder.
    expected_mod = tvm.parser.parse(
        """
        #[version = "0.0.5"]
        def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
          %0 = (%a, %a);
          call_lowered(@test_fused_add, %0,
                       metadata={relay_attrs={Primitive=1},all_prim_fn_vars=[@test_fused_add]})
        }
        def @test_fused_add = <lowered PrimFunc>
        """,
        "from_string", None, None,
    )
    tvm.ir.assert_structural_equal(actual_mod, expected_mod, map_free_vars=True)
```
First, the call_lowered metadata attributes can't be expressed inline in the form written above, so they have to be bound to a meta table entry instead:
```
# Taken from actual_mod so the meta table can refer to the lowered function.
test_fused_add = actual_mod.get_global_var("test_fused_add")
call_lowered_attrs = {
    "relay_attrs": tvm.ir.make_node(
        "DictAttrs", Primitive=tvm.tir.IntImm("int32", 1)
    ),
    "all_prim_fn_vars": [test_fused_add],
}
metadata = {
    "attrs": [call_lowered_attrs],
}
```
That works, but the global var bound to 'test_fused_add' is the wrong object: it comes from actual_mod, whereas it needs to be the very same object that the parser creates to represent the `@test_fused_add` definition in expected_mod.
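To make the identity problem concrete, here is a toy model in plain Python (no TVM). `GlobalVar` and `undefined_callees` are hypothetical stand-ins, not TVM APIs; the sketch assumes, as the issue describes, that a reference only counts as pointing at a definition when it is the exact object keying that definition, even if the names match:

```python
class GlobalVar:
    """Toy stand-in for a global function reference; two instances are distinct objects."""

    def __init__(self, name_hint: str) -> None:
        self.name_hint = name_hint


def undefined_callees(functions: dict, callees: list) -> list:
    """Return names of callees that are not the very objects keying `functions`."""
    defined_ids = {id(gv) for gv in functions}
    return [gv.name_hint for gv in callees if id(gv) not in defined_ids]


# Hypothetical: the var taken from actual_mod and placed in the meta table.
meta_gv = GlobalVar("test_fused_add")
# Hypothetical: the var the parser creates for `def @test_fused_add` in expected_mod.
parsed_gv = GlobalVar("test_fused_add")
expected_functions = {parsed_gv: "<lowered PrimFunc>"}

print(undefined_callees(expected_functions, [meta_gv]))    # ['test_fused_add']
print(undefined_callees(expected_functions, [parsed_gv]))  # []
```

Same name, different object: the meta-table reference does not line up with the parsed definition.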
I think we should have a structural_equal mode that compares GlobalVars on their name_hint alone. We almost have that already, but somewhere along the way in the comparison logic the 'map_free_vars' option got reset to False, which triggered a failure.
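A toy sketch of the proposed mode, again in plain Python rather than TVM. `struct_equal` and its `globals_by_name` flag are hypothetical and only illustrate comparing global references by `name_hint` instead of by object identity; TVM's actual `map_free_vars` option instead allows free variables to be mapped to one another, which is a different rule:

```python
from dataclasses import dataclass
from typing import Tuple


class GlobalVar:
    """Toy stand-in for a global function reference, identified by object identity."""

    def __init__(self, name_hint: str) -> None:
        self.name_hint = name_hint


@dataclass(frozen=True)
class Call:
    """Toy call node: a global callee applied to named arguments."""

    callee: GlobalVar
    args: Tuple[str, ...]


def struct_equal(lhs: Call, rhs: Call, globals_by_name: bool = False) -> bool:
    """Compare two toy calls; callees match by identity unless globals_by_name is set."""
    if globals_by_name:
        callee_equal = lhs.callee.name_hint == rhs.callee.name_hint
    else:
        callee_equal = lhs.callee is rhs.callee
    return callee_equal and lhs.args == rhs.args


actual_gv = GlobalVar("test_fused_add")    # hypothetical: created while lowering
expected_gv = GlobalVar("test_fused_add")  # hypothetical: created by the parser
actual = Call(actual_gv, ("%a", "%a"))
expected = Call(expected_gv, ("%a", "%a"))

print(struct_equal(actual, expected))                        # False
print(struct_equal(actual, expected, globals_by_name=True))  # True
```

Matching on names alone keeps the test independent of which object the parser happened to allocate, at the cost of not detecting two distinct globals that share a name.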