So with the following rewrites and passes

```python
class ZeroZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    """Rewrite ``zeros(...) + x`` / ``x + zeros(...)`` to ``x``.

    When the type of ``x`` differs from the type of the sum (i.e. the add
    broadcast ``x``), ``x`` is wrapped in ``broadcast_to`` so the result keeps
    the shape of the original expression.

    NOTE: a second class with the same name is defined further down in this
    file; at import time that later definition replaces this one.
    """

    def __init__(self):
        # Initialise the base class so attributes it may define are present
        # before the pattern is assigned.
        super().__init__()
        dfp = tvm.relay.dataflow_pattern
        self.zeros = dfp.is_op("zeros")(dfp.wildcard())
        self.other_tensor = dfp.wildcard()
        # Addition is commutative, so match the zeros operand on either side.
        self.pattern = ((self.zeros + self.other_tensor)
                        | (self.other_tensor + self.zeros))

    def callback(self, pre, post, node_map):
        """Return the replacement for a matched addition.

        Parameters
        ----------
        pre, post :
            Expressions handed in by the rewriter; not used directly.
        node_map :
            Mapping from sub-pattern to the list of matched expressions.

        Returns
        -------
        The non-zero operand (possibly broadcast), or the matched expression
        unchanged when no type information is available.
        """
        rt = node_map[self.pattern][0]
        ot = node_map[self.other_tensor][0]
        rt_type = rt._checked_type_
        if rt_type is None:
            # Without a checked type the result shape is unknown; leave the
            # expression alone instead of failing on ``None.shape``.
            return rt
        if ot._checked_type_ == rt_type:
            return ot
        return tvm.relay.broadcast_to(ot, list(rt_type.shape))

class ZeroZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    """Rewrite ``0 + x`` / ``x + 0`` to ``x``.

    The zero operand may be a ``zeros`` op call or a constant whose elements
    are all zero. When the add broadcast ``x``, ``x`` is wrapped in
    ``broadcast_to`` so the result keeps the shape of the original sum.
    """

    def __init__(self):
        # Initialise the base class so attributes it may define are present
        # before the pattern is assigned.
        super().__init__()
        dfp = tvm.relay.dataflow_pattern
        # The attribute is called ``ones`` (mirroring OneZapp) although it
        # matches the *zero* operand; the name is kept for compatibility.
        self.ones = dfp.is_op("zeros")(dfp.wildcard()) | dfp.is_constant()
        self.other_tensor = dfp.wildcard()
        # Addition is commutative, so match the zero operand on either side.
        self.pattern = ((self.ones + self.other_tensor)
                        | (self.other_tensor + self.ones))

    def callback(self, pre, post, node_map):
        """Return the replacement for a matched addition.

        Parameters
        ----------
        pre, post :
            Expressions handed in by the rewriter; not used directly.
        node_map :
            Mapping from sub-pattern to the list of matched expressions.

        Returns
        -------
        The other operand (possibly broadcast) when the matched operand is
        really zero, otherwise the matched expression unchanged.
        """
        rt = node_map[self.pattern][0]
        ones = node_map[self.ones][0]
        ot = node_map[self.other_tensor][0]
        # ``is_constant()`` matches *any* constant, so the value check has to
        # look at the operand that was matched as the zero candidate.
        if isinstance(ones, tvm.relay.Constant):
            # ndarray.all() handles 0-d and multi-dimensional constants;
            # the builtin all() does not.
            if not (ones.data.asnumpy() == 0).all():
                return rt
        rt_type = rt._checked_type_
        # NOTE(review): checked types are not always present here, presumably
        # because the rewriter hands over freshly rebuilt nodes; newer TVM
        # versions appear to offer a ``require_type`` option on
        # DFPatternCallback for this -- TODO confirm against the TVM version
        # in use.
        if rt_type is not None and ot._checked_type_ == rt_type:
            return ot
        # Fallback when the result type is missing: identical argument types
        # on the call mean the add did not broadcast.
        if len(rt.type_args) == 2 and rt.type_args[0] == rt.type_args[1]:
            return ot
        if rt_type is not None:
            return tvm.relay.broadcast_to(ot, list(rt_type.shape))
        return rt

class OneZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    """Rewrite ``1 * x`` / ``x * 1`` to ``x``.

    The unit operand may be a ``ones`` op call or a constant whose elements
    are all one. When the multiply broadcast ``x``, ``x`` is wrapped in
    ``broadcast_to`` so the result keeps the shape of the original product.
    """

    def __init__(self):
        # Initialise the base class so attributes it may define are present
        # before the pattern is assigned.
        super().__init__()
        dfp = tvm.relay.dataflow_pattern
        self.ones = dfp.is_op("ones")(dfp.wildcard()) | dfp.is_constant()
        self.other_tensor = dfp.wildcard()
        # Multiplication is commutative, so match the unit operand on either
        # side.
        self.pattern = ((self.ones * self.other_tensor)
                        | (self.other_tensor * self.ones))

    def callback(self, pre, post, node_map):
        """Return the replacement for a matched multiplication.

        Parameters
        ----------
        pre, post :
            Expressions handed in by the rewriter; not used directly.
        node_map :
            Mapping from sub-pattern to the list of matched expressions.

        Returns
        -------
        The other operand (possibly broadcast) when the matched operand is
        really one, otherwise the matched expression unchanged.
        """
        rt = node_map[self.pattern][0]
        ones = node_map[self.ones][0]
        ot = node_map[self.other_tensor][0]
        # ``is_constant()`` matches *any* constant, so the value check has to
        # look at the operand that was matched as the unit candidate.
        if isinstance(ones, tvm.relay.Constant):
            # ndarray.all() handles 0-d and multi-dimensional constants;
            # the builtin all() does not.
            if not (ones.data.asnumpy() == 1).all():
                return rt
        rt_type = rt._checked_type_
        if rt_type is None:
            # Without a checked type the result shape is unknown; leave the
            # expression alone instead of failing on ``None.shape``.
            return rt
        if ot._checked_type_ == rt_type:
            return ot
        return tvm.relay.broadcast_to(ot, list(rt_type.shape))


class LikeZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    """Replace ``*_like`` ops by equivalents that do not need a second tensor.

    * ``zeros_like`` / ``ones_like`` become ``zeros`` / ``ones`` with the
      shape and dtype taken from the checked type of the matched call.
    * ``collapse_sum_like`` / ``broadcast_to_like`` are dropped when the data
      tensor already has the result type; otherwise ``broadcast_to_like``
      becomes ``broadcast_to`` and ``collapse_sum_like`` is left unchanged.
    """

    def __init__(self):
        # Initialise the base class so attributes it may define are present
        # before the pattern is assigned.
        super().__init__()
        dfp = tvm.relay.dataflow_pattern
        # Ops whose replacement is built from (shape, dtype) alone.
        self.translations_with_dt = {'zeros_like': tvm.relay.zeros,
                                     'ones_like': tvm.relay.ones}
        self.data_tensor = dfp.wildcard()
        self.pattern_tensor = dfp.wildcard()
        one_arg_like = (dfp.is_op("zeros_like")
                        | dfp.is_op("ones_like"))(self.data_tensor)
        two_arg_like = (dfp.is_op("collapse_sum_like")
                        | dfp.is_op("broadcast_to_like")
                        )(self.data_tensor, self.pattern_tensor)
        self.pattern = one_arg_like | two_arg_like

    def callback(self, pre, post, node_map):
        """Return the replacement for a matched ``*_like`` call.

        Parameters
        ----------
        pre, post :
            Expressions handed in by the rewriter; not used directly.
        node_map :
            Mapping from sub-pattern to the list of matched expressions.

        Returns
        -------
        The rewritten expression, or the matched call unchanged when it has
        no checked type or no cheaper equivalent applies.
        """
        data = node_map[self.data_tensor][0]
        res = node_map[self.pattern][0]
        res_type = res._checked_type_
        if res_type is None:
            # Every rewrite below needs the result shape/dtype; without type
            # information keep the call as it is.
            return res
        make_filled = self.translations_with_dt.get(res.op.name)
        if make_filled is not None:
            return make_filled(list(res_type.shape), res_type.dtype)
        if data._checked_type_ == res_type:
            # The op is a no-op when the data already has the target type.
            return data
        if res.op.name == 'broadcast_to_like':
            return tvm.relay.broadcast_to(data, list(res_type.shape))
        return res


    # Rewrite pipeline applied to the "main" function of `grmod`.
    # 1. Replace the *_like ops. LikeZapp reads `_checked_type_` of the matched
    #    nodes, so this presumably relies on `grmod` already carrying inferred
    #    types at this point -- TODO confirm.
    grmod["main"] = tvm.relay.dataflow_pattern.rewrite(LikeZapp(), 
grmod["main"])
    # 2. Run constant folding, then re-run type inference before the
    #    remaining rewrites.
    grmod = tvm.relay.transform.FoldConstant()(grmod)
    grmod = tvm.relay.transform.InferType()(grmod)
    # 3. Apply the ZeroZapp rewrite (the `x + 0` pattern), then the OneZapp
    #    rewrite (the `x * 1` pattern).
    grmod["main"] = tvm.relay.dataflow_pattern.rewrite(ZeroZapp(), 
grmod["main"])
    grmod["main"] = tvm.relay.dataflow_pattern.rewrite(OneZapp(), grmod["main"])
```

I get what looks realistic:

![image|690x184](upload://hdIh0laFwoahHVCGdh649tdCQv8.png)

But this is just a trivial case and if you had a hint whether some of these 
patterns are readily available, I would be most grateful.

Also, I have no idea why I don't reliably get `_checked_type_` attributes 
in the ZeroZapp callback... If you have an idea...

Best regards

Thomas





---
[Visit Topic](https://discuss.tvm.ai/t/same-shape-pattern/7012/4) to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click 
here](https://discuss.tvm.ai/email/unsubscribe/edf1a3df04c4e401b4a9bae6d2453278776471a7e08ef58e6b9932a4d4cbb1f8).

Reply via email to