saurabh-shandilya commented on issue #13543:
URL: https://github.com/apache/tvm/issues/13543#issuecomment-1347088613
@masahi Sure, I will submit a PR. By the way, I have a question about the
partition code that I used in my example above:
`mod_p = test_pattern().partition(finalout, {'Composite': 'test_pattern'})`
I wrote the following test for it, but it fails.
```
import tvm
from tvm import relay


# test_pattern() and run_mergecomposite_pass() are defined outside this snippet.
def test_partition_merge_composite():
    score = relay.var("score", shape=[7], dtype="float32")
    box = relay.var("box", shape=[7, 4], dtype="float32")
    g = relay.greater(score, relay.const(0.01))
    w = relay.argwhere(g)
    s = relay.squeeze(w, axis=[1])
    tb = (box, s)
    boxout = relay.adv_index(tb)
    scoreout = relay.adv_index((score, s))
    sourceexpand = relay.expand_dims(scoreout, axis=-1)
    finalout = relay.concatenate([boxout, sourceexpand], axis=-1)
    # Use partition
    partitioned_expr = test_pattern().partition(finalout, {"Composite": "test_pattern"})
    print(partitioned_expr)
    # Use merge composite
    mod = tvm.IRModule.from_expr(finalout)
    mod_m = run_mergecomposite_pass(mod)
    merge_composite_expr = mod_m["main"].body
    print(merge_composite_expr)
    # This is the check that currently fails
    assert tvm.ir.structural_equal(merge_composite_expr, partitioned_expr)
```
I printed the two expressions and they look almost the same, except that
`merge_composite_expr` carries type and shape annotations. Still, the check fails.
partitioned_expr is as below
> free_var %score: Tensor[(7), float32];
> %2 = fn (%FunctionVar_0_0, PartitionedFromPattern="greater_argwhere_squeeze_", Composite="test_pattern") {
> %0 = greater(%FunctionVar_0_0, 0.01f);
> %1 = argwhere(%0);
> squeeze(%1, axis=[1])
> };
> free_var %box: Tensor[(7, 4), float32];
> %3 = %2(%score);
> %4 = (%box, %3);
> %5 = (%score, %3);
> %6 = adv_index(%5);
> %7 = adv_index(%4);
> %8 = expand_dims(%6, axis=-1);
> %9 = (%7, %8);
> concatenate(%9, axis=-1)
merge_composite_expr is as below
> free_var %score: Tensor[(7), float32] /* ty=Tensor[(7), float32] */;
> %2 = fn (%FunctionVar_0_0: Tensor[(7), float32] /* ty=Tensor[(7), float32] */, PartitionedFromPattern="greater_argwhere_squeeze_", Composite="test_pattern") -> Tensor[(?), int32] {
> %0 = greater(%FunctionVar_0_0, 0.01f /* ty=float32 */) /* ty=Tensor[(7), bool] */;
> %1 = argwhere(%0) /* ty=Tensor[(?, 1), int32] */;
> squeeze(%1, axis=[1]) /* ty=Tensor[(?), int32] */
> } /* ty=fn (Tensor[(7), float32]) -> Tensor[(?), int32] */;
> free_var %box: Tensor[(7, 4), float32] /* ty=Tensor[(7, 4), float32] */;
> %3 = %2(%score) /* ty=Tensor[(?), int32] */;
> %4 = (%box, %3) /* ty=(Tensor[(7, 4), float32], Tensor[(?), int32]) */;
> %5 = (%score, %3) /* ty=(Tensor[(7), float32], Tensor[(?), int32]) */;
> %6 = adv_index(%5) /* ty=Tensor[(?), float32] */;
> %7 = adv_index(%4) /* ty=Tensor[(?, 4), float32] */;
> %8 = expand_dims(%6, axis=-1) /* ty=Tensor[(?, 1), float32] */;
> %9 = (%7, %8) /* ty=(Tensor[(?, 4), float32], Tensor[(?, 1), float32]) */;
> concatenate(%9, axis=-1) /* ty=Tensor[(?, 5), float32] */