zhuwenxi commented on pull request #7619:
URL: https://github.com/apache/tvm/pull/7619#issuecomment-814055145


@tqchen, I ran into some problems while trying to rewrite the unit test using ir_builder and assert_structural_equal(). Here is my code:
<pre>
import tvm
from tvm import te


def assert_packed_func(target="llvm", parallel=True):
    ib = tvm.tir.ir_builder.create()

    m = n = k = 16

    #
    # Prepare buffers for a, b and c:
    #
    a = te.placeholder((m, k), name="a", dtype="float64")
    b = te.placeholder((k, n), name="b", dtype="float64")
    # Reduction axis over k (named rk so it does not shadow the extent k):
    rk = te.reduce_axis((0, k), name="k")
    c = te.compute((m, n), lambda i, j: te.sum(a[i, rk] * b[rk, j], axis=rk), name="c")

    a_buffer = tvm.tir.decl_buffer(
        a.shape, a.dtype, name="a_buffer", offset_factor=1, strides=[te.var("s1"), 1]
    )
    b_buffer = tvm.tir.decl_buffer(
        b.shape, b.dtype, name="b_buffer", offset_factor=1, strides=[te.var("s2"), 1]
    )
    c_buffer = tvm.tir.decl_buffer(
        c.shape, c.dtype, name="c_buffer", offset_factor=1, strides=[te.var("s3"), 1]
    )

    # Use ir_builder to emit a packed call inside the (parallel) loop:
    kind = "parallel" if parallel else "serial"
    with ib.for_range(0, 10, "i", kind=kind):
        ib.emit(tvm.tir.call_packed("tvm.test_matmul", a_buffer, b_buffer, c_buffer))

    stmt = ib.get()

    # Construct a valid IRModule to be lowered:
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([a_buffer, b_buffer, c_buffer], stmt))
    target = tvm.target.Target(target)
    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
    mod = tvm.tir.transform.MakePackedAPI()(mod)

    # Do the lowering:
    mod = tvm.tir.transform.LowerTVMBuiltin()(mod)

    # Get the PrimFunc from the module:
    prim_func = mod.functions.items()[0][1]

    # Walk down the PrimFunc body until we reach the for-loop:
    node = prim_func.body
    while isinstance(node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)):
        node = node.body

    # For-loop reached:
    assert isinstance(node, tvm.tir.stmt.For)

    alloca_tcode = node.body
    assert isinstance(alloca_tcode, tvm.tir.LetStmt)

    return alloca_tcode
</pre>
   
I assume I should use assert_structural_equal() to check `alloca_tcode` here, but I don't know how to construct the expected stmt. It should be `let stack_tcode = tir.tvm_stack_alloca("arg_tcode", 4)`, but TVM doesn't seem to have a Python API for creating a `tir.tvm_stack_alloca` call. (Can this intrinsic only be generated through the C++ API?)

