abergeron opened a new issue #4758: [relay] ADT match becomes not well formed during VM optimization
URL: https://github.com/apache/incubator-tvm/issues/4758
 
 
   I have this sample program:
   
   ```python
   # Reproduction for apache/incubator-tvm issue #4758: compiling this module
   # with the Relay VM executor (kind="vm") fails a WellFormed check, while
   # kind="debug" works.
   import tvm
   from tvm import relay
   from tvm.relay import adt
   
   # CPU device 0; passed to create_executor at the end of the script.
   ctx = tvm.ndarray.context('cpu', 0)
   
   mod = relay.Module({})
   
   # Recursive ADT `u`: constructor c0 carries an (int32, u) tuple and
   # constructor c1 carries an empty tuple.
   union_type = relay.GlobalTypeVar("u")
   c0_type = relay.ty.TupleType([relay.ty.scalar_type('int32'), union_type()])
   c0 = adt.Constructor("c0", [c0_type], union_type)
   c1 = adt.Constructor("c1", [relay.ty.TupleType([])], union_type)
   
   mod[union_type] = adt.TypeData(union_type, [], [c0, c1])
   
   gv = relay.GlobalVar("fn")
   p = relay.var('p', union_type())
   # Pattern variable bound by `mm` below; declared without a type annotation.
   v = relay.Var('v')
   # `cond` evaluates to True when `p` was built with c0, and False otherwise.
   cond = adt.Match(p, [adt.Clause(adt.PatternConstructor(c0, 
[adt.PatternWildcard\
   ()]), relay.const(True)),
                        adt.Clause(adt.PatternWildcard(), relay.const(False))])
   
   # Incomplete match (complete=False) that only handles c0 and yields its
   # tuple payload `v`. This single `mm` node is used twice below (fields 0
   # and 1). In the kind="vm" error output it appears as two separate
   # matches that both bind `%v`.
   # NOTE(review): the else branch runs only when `p` is NOT c0, yet `mm`
   # only matches c0, so evaluating it would presumably fail at runtime.
   # This does not affect the compile-time error being reported, but the
   # If branches may be swapped relative to the intent -- confirm.
   mm = adt.Match(p, [adt.Clause(adt.PatternConstructor(c0, 
[adt.PatternVar(v)]), \
   v)], complete=False)
   
   # fn(p) returns 0 if p is c0, else mm.0 + fn(mm.1); declared result int32.
   fn = relay.Function(
       [p],
       relay.If(
           cond,
           relay.const(0),
           relay.add(relay.TupleGetItem(mm, 0),
                     relay.Call(gv, [relay.TupleGetItem(mm, 1)]))
       ),
       ret_type=relay.ty.scalar_type('int32')
   )
   
   mod[gv] = fn
   
   # Entry point: main(q) simply forwards to fn(q).
   q = relay.var("q", union_type())
   mod["main"] = relay.Function([q], relay.Call(gv, [q]))
   
   # Prints the module text shown in the report.
   print(str(mod))
   
   # This call raises the TVMError (WellFormed check failure) during VM
   # compilation.
   vm = relay.create_executor(mod=mod, ctx=ctx, target='llvm', kind="vm")
   ```
   
   The module looks like this before compiling:
   
   ```
   v0.0.4
   type u {
     c0((int32, u[])),
     c1(()),
   }
   
   def @main(%q: u[]) -> int32 {
     @fn(%q) /* ty=int32 */
   }
   
   def @fn(%p: u[]) -> int32 {
     %0 = match (%p) {
       c0(_) => True /* ty=bool */,
       _ => False /* ty=bool */,
     };
     if (%0) {
       0 /* ty=int32 */
     } else {
       %1 = match? (%p) {
         c0(%v: (int32, u[])) => %v,
       };
       %2 = %1.0;
       %3 = %1.1;
       %4 = @fn(%3) /* ty=int32 */;
       add(%2, %4) /* ty=int32 */
     }
   }
   ```
   
   Compiling with kind="vm" fails with the error below, but the same module
   runs correctly with kind="debug":
   
   ```
   Traceback (most recent call last):
   
     File "tst.py", line 42, in <module>
       vm = relay.create_executor(mod=mod, ctx=ctx, target='llvm', kind="vm")
   
     File "/home/anakha/ext/tvm/python/tvm/relay/build_module.py", line 411, in create_executor
       return VMExecutor(mod, ctx, target)
   
     File "/home/anakha/ext/tvm/python/tvm/relay/backend/vm.py", line 540, in __init__
       self.executable = compile(mod, target)
   
     File "/home/anakha/ext/tvm/python/tvm/relay/backend/vm.py", line 399, in compile
       compiler.lower(mod, target, target_host)
   
     File "/home/anakha/ext/tvm/python/tvm/relay/backend/vm.py", line 455, in lower
       self._lower(mod, target, target_host)
   
     File "tvm/_ffi/_cython/./function.pxi", line 304, in tvm._ffi._cy3.core.FunctionBase.__call__
   
     File "tvm/_ffi/_cython/./function.pxi", line 239, in tvm._ffi._cy3.core.FuncCall
   
     File "tvm/_ffi/_cython/./function.pxi", line 228, in tvm._ffi._cy3.core.FuncCall3
   
     File "tvm/_ffi/_cython/./base.pxi", line 157, in tvm._ffi._cy3.core.CALL
   
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     [bt] (8) /home/anakha/ext/tvm/build/libtvm.so(+0x2097b7d) [0x7f0bebd40b7d]
     [bt] (7) /home/anakha/ext/tvm/build/libtvm.so(+0x2097d9c) [0x7f0bebd40d9c]
     [bt] (6) /home/anakha/ext/tvm/build/libtvm.so(+0x2097f08) [0x7f0bebd40f08]
     [bt] (5) /home/anakha/ext/tvm/build/libtvm.so(+0x2097579) [0x7f0bebd40579]
     [bt] (4) /home/anakha/ext/tvm/build/libtvm.so(tvm::relay::vm::PrimitiveInliner::Inline()+0x276) [0x7f0bebd41a1c]
     [bt] (3) /home/anakha/ext/tvm/build/libtvm.so(tvm::IRModuleNode::Add(tvm::GlobalVar const&, tvm::BaseFunc const&, bool)+0xcb) [0x7f0beb3daca3]
     [bt] (2) /home/anakha/ext/tvm/build/libtvm.so(tvm::RunTypeCheck(tvm::IRModule const&, tvm::GlobalVar const&, tvm::relay::Function)+0x5b) [0x7f0beb3da662]
     [bt] (1) /home/anakha/ext/tvm/build/libtvm.so(tvm::relay::DeDup(tvm::RelayExpr const&)+0x106) [0x7f0bebb460cd]
     [bt] (0) /home/anakha/ext/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x4a) [0x7f0beb36f884]
     File "/home/anakha/ext/tvm/src/relay/pass/de_duplicate.cc", line 109
   TVMError: Check failed: WellFormed(e): v0.0.4
   fn (%p: u[]) -> int32 {
     %0 = match (%p) {
       c0(_) => True /* ty=bool */,
       _ => False /* ty=bool */,
     };
     if (%0) {
       0 /* ty=int32 */
     } else {
       %1 = match? (%p) {
         c0(%v: (int32, u[])) => %v,
       };
       %2 = %1.0;
       %3 = match? (%p) {
         c0(%v: (int32, u[])) => %v,
       };
       %4 = %3.1;
       %5 = @fn(%4);
       add(%2, %5)
     }
   }
   ```
   
   The only thing I notice is that something duplicated the match in the else
   branch. `%1` and `%3` are now two separate matches that both bind the same
   `%v`, which probably makes the code ill-formed. I traced the WellFormed
   error and it comes from `%v`, though I may be wrong about this.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to