zhuwenxi edited a comment on pull request #7619:
URL: https://github.com/apache/tvm/pull/7619#issuecomment-814055145
@tqchen , I ran into some problems while trying to reconstruct the UT with
`ir_builder` and `assert_structural_equal()`. This is my code:
<pre>
import tvm
from tvm import te


def assert_packed_func(target="llvm", parallel=True):
    ib = tvm.tir.ir_builder.create()
    m = n = k = 16
    #
    # Prepare buffers for a, b and c:
    #
    a = te.placeholder((m, k), name="a", dtype="float64")
    b = te.placeholder((k, n), name="b", dtype="float64")
    k = te.reduce_axis((0, k), name="k")
    c = te.compute((m, n), lambda i, j: te.sum(a[i, k] * b[k, j], axis=k), name="c")
    a_buffer = tvm.tir.decl_buffer(
        a.shape, a.dtype, name="a_buffer", offset_factor=1, strides=[te.var("s1"), 1]
    )
    b_buffer = tvm.tir.decl_buffer(
        b.shape, b.dtype, name="b_buffer", offset_factor=1, strides=[te.var("s2"), 1]
    )
    c_buffer = tvm.tir.decl_buffer(
        c.shape, c.dtype, name="c_buffer", offset_factor=1, strides=[te.var("s3"), 1]
    )
    # Use ir_builder to create a packed call in the parallel loop:
    with ib.for_range(0, 10, "i", kind="parallel"):
        ib.emit(tvm.tir.call_packed("tvm.test_matmul", a_buffer, b_buffer, c_buffer))
    stmt = ib.get()
    # Construct a valid IRModule to be lowered:
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([a_buffer, b_buffer, c_buffer], stmt))
    target = tvm.target.Target(target)
    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
    mod = tvm.tir.transform.MakePackedAPI()(mod)
    # Do the lowering:
    mod = tvm.tir.transform.LowerTVMBuiltin()(mod)
    # Get the PrimFunc from the module:
    prim_func = mod.functions.items()[0][1]
    # Recursively visit the PrimFunc body until we reach the for-loop:
    node = prim_func.body
    while isinstance(node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)):
        node = node.body
    # For-loop reached:
    assert isinstance(node, tvm.tir.stmt.For)
    alloca_tcode = node.body
    assert isinstance(alloca_tcode, tvm.tir.LetStmt)
    ...
</pre>
I suppose I should use `assert_structural_equal()` to check `alloca_tcode`
here, but I don't know how to construct the "expected" stmt. The expected stmt
should be `let stack_tcode = tir.tvm_stack_alloca("arg_tcode", 4)`, but TVM
doesn't seem to have a Python API to create a `tir.tvm_stack_alloca` stmt.
(Can this intrinsic only be generated through the C++ API?)
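One untested idea (an assumption on my part, not a confirmed API pattern): since `tvm.tir.call_intrin` resolves ops by name, it might be possible to build the intrinsic call and the surrounding `let` by hand. The `Evaluate(0)` body below is just a placeholder for the rest of the expected stmt:
<pre>
import tvm

# Hypothetical sketch: construct the expected `let` statement manually.
# `tir.tvm_stack_alloca` is looked up by name; whether structural equality
# then matches the lowered output is unverified.
stack_tcode = tvm.tir.Var("stack_tcode", "handle")
alloca = tvm.tir.call_intrin("handle", "tir.tvm_stack_alloca", "arg_tcode", 4)
expected = tvm.tir.LetStmt(stack_tcode, alloca, tvm.tir.Evaluate(0))
print(expected)
</pre>
If something like this works, the `expected` stmt could then be passed to `assert_structural_equal()` against `alloca_tcode`.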
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]