saurabh-shandilya commented on issue #13543:
URL: https://github.com/apache/tvm/issues/13543#issuecomment-1347088613

   @masahi Sure, I will submit a PR. By the way, I have a question about the partition code that I used in my example above:
   ` mod_p = test_pattern().partition(finalout,{'Composite': 'test_pattern'})`
   
   I wrote the following test for this, but it fails.
   
   ```
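   import tvm
   from tvm import relay

   # test_pattern() and run_mergecomposite_pass() are assumed to be defined as in the example above.
   def test_partition_merge_composite():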
       score = relay.var("score", shape=[7], dtype="float32")
       box = relay.var("box", shape=[7, 4], dtype="float32")
       g = relay.greater(score, relay.const(0.01))
       w = relay.argwhere(g)
       s = relay.squeeze(w, axis=[1])
       tb = (box, s)
       boxout = relay.adv_index(tb)
       scoreout = relay.adv_index((score, s))
       sourceexpand = relay.expand_dims(scoreout, axis=-1)
       finalout = relay.concatenate([boxout, sourceexpand], axis=-1)
   
       # Use partition
       partitioned_expr = test_pattern().partition(finalout, {'Composite': 'test_pattern'})
       print(partitioned_expr)
   
       # Use merge composite
       mod = tvm.IRModule.from_expr(finalout)
       mod_m = run_mergecomposite_pass(mod)
       merge_composite_expr=mod_m["main"].body
       print(merge_composite_expr)    
       
       assert tvm.ir.structural_equal(merge_composite_expr,partitioned_expr)
   ```
   
   I printed the two expressions and they look almost the same, except that the second one (merge_composite_expr) carries type and shape annotations. The structural_equal check fails nevertheless.
   
   partitioned_expr is as below
   > free_var %score: Tensor[(7), float32];
   > %2 = fn (%FunctionVar_0_0, PartitionedFromPattern="greater_argwhere_squeeze_", Composite="test_pattern") {
   >   %0 = greater(%FunctionVar_0_0, 0.01f);
   >   %1 = argwhere(%0);
   >   squeeze(%1, axis=[1])
   > };
   > free_var %box: Tensor[(7, 4), float32];
   > %3 = %2(%score);
   > %4 = (%box, %3);
   > %5 = (%score, %3);
   > %6 = adv_index(%5);
   > %7 = adv_index(%4);
   > %8 = expand_dims(%6, axis=-1);
   > %9 = (%7, %8);
   > concatenate(%9, axis=-1)
   
   merge_composite_expr is as below
   > free_var %score: Tensor[(7), float32] /* ty=Tensor[(7), float32] */;
   > %2 = fn (%FunctionVar_0_0: Tensor[(7), float32] /* ty=Tensor[(7), float32] */, PartitionedFromPattern="greater_argwhere_squeeze_", Composite="test_pattern") -> Tensor[(?), int32] {
   >   %0 = greater(%FunctionVar_0_0, 0.01f /* ty=float32 */) /* ty=Tensor[(7), bool] */;
   >   %1 = argwhere(%0) /* ty=Tensor[(?, 1), int32] */;
   >   squeeze(%1, axis=[1]) /* ty=Tensor[(?), int32] */
   > } /* ty=fn (Tensor[(7), float32]) -> Tensor[(?), int32] */;
   > free_var %box: Tensor[(7, 4), float32] /* ty=Tensor[(7, 4), float32] */;
   > %3 = %2(%score) /* ty=Tensor[(?), int32] */;
   > %4 = (%box, %3) /* ty=(Tensor[(7, 4), float32], Tensor[(?), int32]) */;
   > %5 = (%score, %3) /* ty=(Tensor[(7), float32], Tensor[(?), int32]) */;
   > %6 = adv_index(%5) /* ty=Tensor[(?), float32] */;
   > %7 = adv_index(%4) /* ty=Tensor[(?, 4), float32] */;
   > %8 = expand_dims(%6, axis=-1) /* ty=Tensor[(?, 1), float32] */;
   > %9 = (%7, %8) /* ty=(Tensor[(?, 4), float32], Tensor[(?, 1), float32]) */;
   > concatenate(%9, axis=-1) /* ty=Tensor[(?, 5), float32] */
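
One possible explanation, stated as an assumption rather than a confirmed diagnosis: a structural comparison that also looks at type annotations will treat an un-annotated variable (`%FunctionVar_0_0` in the first printout) as different from an annotated one (`%FunctionVar_0_0: Tensor[(7), float32]` in the second). The toy model below is plain Python, not TVM code; `Var`, `Call` and `annotate` are hypothetical names used only to illustrate that effect.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class Var:
    name: str
    # None models a variable that carries no type information.
    type_annotation: Optional[str] = None


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[object, ...]


def annotate(expr, types):
    """Return a copy of expr in which every Var is annotated from `types`."""
    if isinstance(expr, Var):
        return replace(expr, type_annotation=types[expr.name])
    return Call(expr.op, tuple(annotate(a, types) for a in expr.args))


# Same operator chain as the composite body: greater -> argwhere -> squeeze.
untyped = Call("squeeze", (Call("argwhere", (Call("greater", (Var("FunctionVar_0_0"),)),)),))
types = {"FunctionVar_0_0": "Tensor[(7), float32]"}
typed = annotate(untyped, types)

# Dataclass equality compares every field, including the annotation.
print(untyped == typed)                   # False: annotations differ
print(annotate(untyped, types) == typed)  # True: both sides annotated
```

If this is indeed the cause, running type inference (for example `relay.transform.InferType()`) on the partitioned expression before the comparison might make the two sides line up. That step is untested here.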

