maartenvds opened a new issue, #14424:
URL: https://github.com/apache/tvm/issues/14424
### Expected behavior
Consider the following code snippet:
```python
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import (
    DFPatternCallback,
    FunctionPattern,
    dominates,
    rewrite,
    wildcard,
    is_op,
)


class Rewriter(DFPatternCallback):
    def __init__(self, require_type=True, rewrite_once=True):
        super().__init__(require_type, rewrite_once)
        self.x = is_op("nn.leaky_relu")(wildcard())
        # call to a function (e.g. a partitioned reduction) with two arguments
        reduction = FunctionPattern([wildcard(), wildcard()], wildcard())(wildcard(), wildcard())
        self.pattern = dominates(self.x, wildcard(), reduction)

    def callback(self, pre, post, node_map):
        #
        # the rewrite goes here
        #
        print("reached callback")
        return post  # placeholder: return the matched expression unchanged


# example graph with diamond structure
x = relay.var('x', relay.TensorType([1, 1, 32, 32], 'float32'))
x = relay.nn.leaky_relu(x)
a = relay.nn.leaky_relu(x)
b = relay.nn.relu(x)
y = a + b
# partition add op (wrap it into a function)
y = is_op('add')(wildcard(), wildcard()).partition(y)
# rewrite diamond
r = Rewriter()
y = rewrite(r, y)
```
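I expect the `callback` method of the `Rewriter` class to be called, since the graph contains the described pattern: the first `nn.leaky_relu` reaches the partitioned `add` function through two branches (`nn.leaky_relu` and `nn.relu`).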
### Actual behavior
It crashes with this error:
```
        at /home/maarten/code/nn_compiler/sirius/tvm-fork/include/tvm/runtime/packed_func.h:1646
  8: tvm::relay::RewritePatterns(tvm::runtime::Array<tvm::relay::DFPatternCallback, void>, tvm::RelayExpr, tvm::IRModule)
        at /home/maarten/code/nn_compiler/sirius/tvm-fork/src/relay/ir/dataflow_matcher.cc:852
  7: tvm::relay::PatternRewriter::Rewrite(tvm::runtime::Array<tvm::relay::DFPatternCallback, void> const&, tvm::RelayExpr const&)
        at /home/maarten/code/nn_compiler/sirius/tvm-fork/src/relay/ir/dataflow_matcher.cc:816
  6: tvm::relay::PatternGrouper::GroupMatches(tvm::relay::DFPattern const&, tvm::RelayExpr const&)
        at /home/maarten/code/nn_compiler/sirius/tvm-fork/src/relay/ir/dataflow_matcher.cc:587
  5: tvm::relay::PatternGrouper::VisitExprs()
        at /home/maarten/code/nn_compiler/sirius/tvm-fork/src/relay/ir/dataflow_matcher.cc:605
  4: tvm::relay::PatternGrouper::CreateGroup(tvm::RelayExpr const&)
        at /home/maarten/code/nn_compiler/sirius/tvm-fork/src/relay/ir/dataflow_matcher.cc:624
  3: tvm::runtime::Map<tvm::relay::DFPattern, tvm::runtime::Array<tvm::RelayExpr, void>, void, void>::operator[](tvm::relay::DFPattern const&) const
        at /home/maarten/code/nn_compiler/sirius/tvm-fork/include/tvm/runtime/container/map.h:1346
  2: tvm::runtime::Map<tvm::relay::DFPattern, tvm::runtime::Array<tvm::RelayExpr, void>, void, void>::at(tvm::relay::DFPattern const&) const
        at /home/maarten/code/nn_compiler/sirius/tvm-fork/include/tvm/runtime/container/map.h:1340
  1: tvm::runtime::MapNode::at(tvm::runtime::ObjectRef const&)
        at /home/maarten/code/nn_compiler/sirius/tvm-fork/include/tvm/runtime/container/map.h:1177
  0: tvm::runtime::SmallMapNode::at(tvm::runtime::ObjectRef const&)
        at /home/maarten/code/nn_compiler/sirius/tvm-fork/include/tvm/runtime/container/map.h:376
  File "/home/maarten/code/nn_compiler/sirius/tvm-fork/include/tvm/runtime/container/map.h", line 671
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (!iter.IsNone()) is false: IndexError: key is not in Map
```
Note that the pattern itself does match the graph. Calling the following on the partitioned graph (before `rewrite`):
```python
r.pattern.match(y)
```
returns 1 (a match).
Also note that the code works when the `add` operator is not wrapped into a function. However, I need the rewrite to support reduction ops inside a function too.
### Environment
Ubuntu 22.04, Python 3.7, TVM 0.9.0
### Steps to reproduce
Run the script above; it crashes in the `rewrite(r, y)` call.