ptrendx opened a new pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269
## Description ##
This PR greatly increases the speed of the pointwise fusion graph pass. As a
test case I used XLNet from Gluon-NLP, which had previously exposed slowness in
this graph pass (#17105). After the fix from that issue, the total time of
fwd+bwd fusion was ~8 s. With this PR it is ~11 ms, of which 3 ms is spent
constructing the IndexedGraph from the original graph (so the fusion graph pass
itself takes ~8 ms in that case, roughly a 1000x improvement).
The motivation for this PR is to make the pointwise fusion graph pass
lightweight enough that it could be rerun every time the shapes in the graph
change, enabling shape-aware fusion. This is important because in 2.0 the NumPy
semantics make multiple operations (like add, sub etc.) broadcast by default,
even if in the end they are simple elementwise operations. This PR does not
fully get us there (I would like it to take <5 ms for a network like XLNet, to
be sure it is lightweight enough not to be a bottleneck in any use case),
although some parts would no longer be needed if we could run this pass after
infershape is already done (e.g. we would not need to insert
`FusedOpHelper`/`FusedOpOutHelper`, which currently takes a little over 1 ms).
That said, it is a big step in that direction.
The main problem that the fusion graph pass needs to solve is preventing
cycles: there must be no node outside a subgraph that (directly or indirectly)
both consumes an output of that subgraph and provides an input to it. The
original pointwise fusion graph pass avoided cycles by first constructing a
mapping of which nodes must not be placed in the same subgraph as a given node.
Constructing that mapping was simple but inefficient (it was improved in
#17114, but was still pretty slow):
```
for each node n that does not qualify to be in a subgraph:
    outputs = nodes reached by DFS from n in the direction of n's outputs
    inputs = nodes reached by DFS from n in the direction of n's inputs
    for each output in outputs:
        for each input in inputs:
            put (input, output) and (output, input) pairs into the exclusion mapping
```
This was an `O(n^3)` algorithm (improved in #17114 to `O(n^2)` on average),
and it required a separate `BidirectionalGraph` data structure to enable DFS in
both directions.
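For comparison, the original approach can be sketched in Python as below. This is a simplified illustration, not the removed MXNet code: the graph is assumed to be a dict from node id to its input ids, a reverse adjacency list stands in for `BidirectionalGraph`, and `reachable`/`exclusion_mapping` are hypothetical names. The nested loops over upstream and downstream nodes, run once per ineligible node, are what make it cubic in the worst case.

```python
from typing import Dict, List, Set, Tuple


def reachable(start: int, adj: Dict[int, List[int]]) -> Set[int]:
    """Iterative DFS: every node reachable from start along adj (start excluded)."""
    seen: Set[int] = set()
    stack = list(adj.get(start, []))
    while stack:
        v = stack.pop()
        if v not in seen:
            seen.add(v)
            stack.extend(adj.get(v, []))
    return seen


def exclusion_mapping(inputs: Dict[int, List[int]],
                      eligible: Set[int]) -> Set[Tuple[int, int]]:
    # Reverse adjacency so DFS can also walk toward outputs (BidirectionalGraph's role).
    outputs: Dict[int, List[int]] = {n: [] for n in inputs}
    for n, ins in inputs.items():
        for i in ins:
            outputs[i].append(n)
    excluded: Set[Tuple[int, int]] = set()
    for n in inputs:
        if n in eligible:
            continue
        downstream = reachable(n, outputs)
        upstream = reachable(n, inputs)
        for o in downstream:      # every (upstream, downstream) pair around an
            for i in upstream:    # ineligible node must stay in different subgraphs
                excluded.add((i, o))
                excluded.add((o, i))
    return excluded
```

For example, with `inputs = {0: [], 1: [0], 2: [0, 1]}` and only node 1 ineligible, nodes 0 and 2 end up excluded from each other.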
The new algorithm for the pointwise fusion graph pass is based on 2
observations:
- most of the entries in the exclusion mapping are not useful, as those pairs
  of nodes would never be considered for the same subgraph anyway
- the graph can be traversed in topological order, which the original
  algorithm did not take advantage of
In the new algorithm the graph is traversed in topological order (so all the
inputs of the current node are guaranteed to have been processed already) and
each node has its own exclusion set of subgraphs that it can't be a part of.
The exclusion set of a node is the union of the exclusion sets of its inputs,
plus, if the node itself is ineligible to be in a subgraph, all the subsets
that its inputs are part of. Because subsets can merge (e.g. for the operation
`a + b`, where `a` is part of subgraph s1 and `b` is part of subgraph s2, s1 and
s2 need to be merged into a single subgraph containing everything from s1, s2
and the `+` operator), a mapping is maintained that records which other
subsets each subset has been merged with.
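The traversal can be sketched in Python as below. This is a deliberately naive sketch, not the PR's implementation: nodes are numbered in topological order, subset ids are plain integers, the mapping of merged subsets is a small union-find (`DisjointSet`), and all names are hypothetical. As an assumption of our own, it also keeps an exclusion set per subset and follows those sets transitively when deciding whether a node may join (or merge) subsets, because a subset can grow after a consumer's exclusion set was computed. The interval encoding and set sharing are left out.

```python
from typing import Dict, List, Optional, Set


class DisjointSet:
    """Union-find over subset ids: records which subsets were merged."""

    def __init__(self) -> None:
        self.parent: List[int] = []

    def make(self) -> int:
        self.parent.append(len(self.parent))
        return len(self.parent) - 1

    def find(self, x: int) -> int:
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x


def assign_subsets(inputs: List[List[int]],
                   eligible: List[bool]) -> List[Optional[int]]:
    """inputs[n] lists the producers of node n; every producer index is < n."""
    ds = DisjointSet()
    subset: List[Optional[int]] = [None] * len(inputs)
    node_excl: List[Set[int]] = []       # per-node exclusion sets (raw subset ids)
    set_excl: Dict[int, Set[int]] = {}   # per-subset exclusion sets, keyed by representative

    def closure(ids: Set[int]) -> Set[int]:
        # All subsets reachable through exclusion sets, resolved through merges.
        seen: Set[int] = set()
        stack = [ds.find(i) for i in ids]
        while stack:
            r = stack.pop()
            if r not in seen:
                seen.add(r)
                stack.extend(ds.find(i) for i in set_excl[r])
        return seen

    for n, ins in enumerate(inputs):
        excl: Set[int] = set()
        for i in ins:
            excl |= node_excl[i]          # union of the inputs' exclusion sets
        cands = {ds.find(subset[i]) for i in ins if subset[i] is not None}
        if not eligible[n]:
            node_excl.append(excl | cands)  # consumers of n must not join these subsets
            continue
        base = closure(excl)
        # Join a candidate only if no other route makes it upstream of n or of another candidate.
        accepted = [c for c in sorted(cands)
                    if c not in base
                    and all(c not in closure(set_excl[d]) for d in cands if d != c)]
        excl |= cands.difference(accepted)  # n only reads from the rejected subsets
        if accepted:
            root = accepted[0]
            merged = set(excl)
            for a in accepted:
                merged |= set_excl.pop(a)
                ds.parent[a] = root         # record the merge
        else:
            root = ds.make()                # start a new subset
            merged = set(excl)
        set_excl[root] = merged
        subset[n] = root
        node_excl.append(excl)

    return [None if s is None else ds.find(s) for s in subset]
```

For `a + b` with all three nodes fusable, `assign_subsets([[], [], [0, 1]], [True, True, True])` puts everything in one subset, while a node that reads `a` both directly and through an ineligible node (`assign_subsets([[], [0], [0, 1]], [True, False, True])`) is placed in a new subset.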
There are a number of additional optimizations:
- because of the topological traversal order, the subset ids in an exclusion
  set typically form a run of consecutive numbers (or a few such runs), so the
  exclusion set stores intervals of numbers instead of the numbers themselves
- in most cases the union of the exclusion sets is equal to one of the sets
  being combined, so to avoid costly and unnecessary memory allocations, the
  exclusion sets are shared between nodes
- the fwd and bwd fusions are done together in a single pass to avoid
  overheads
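The first two optimizations can be illustrated together with a small Python sketch (our own illustration under stated assumptions, not the PR's C++ data structure): subset ids are integers, an exclusion set is an immutable, sorted tuple of inclusive intervals, and `union` returns one of its arguments unchanged when that argument already contains the other, so nodes can share one set object instead of allocating a new one. The helper names are hypothetical.

```python
from bisect import bisect_right
from typing import Iterable, List, Tuple

Interval = Tuple[int, int]          # inclusive range [lo, hi] of subset ids
IntervalSet = Tuple[Interval, ...]  # sorted, disjoint, non-adjacent; immutable so it can be shared


def normalize(intervals: Iterable[Interval]) -> IntervalSet:
    """Sort intervals and coalesce overlapping or adjacent ones."""
    out: List[List[int]] = []
    for lo, hi in sorted(intervals):
        if out and lo <= out[-1][1] + 1:
            out[-1][1] = max(out[-1][1], hi)
        else:
            out.append([lo, hi])
    return tuple((lo, hi) for lo, hi in out)


def covers(s: IntervalSet, iv: Interval) -> bool:
    """True if the range iv lies inside a single interval of s (binary search)."""
    lo, hi = iv
    k = bisect_right(s, (lo, float("inf"))) - 1  # last interval starting at or before lo
    return k >= 0 and s[k][1] >= hi


def contains(s: IntervalSet, x: int) -> bool:
    return covers(s, (x, x))


def union(a: IntervalSet, b: IntervalSet) -> IntervalSet:
    """Union that returns one of its arguments unchanged whenever possible."""
    if all(covers(a, iv) for iv in b):
        return a  # common case: no allocation, the existing set is shared
    if all(covers(b, iv) for iv in a):
        return b
    return normalize(a + b)
```

Because the sets are immutable, returning `a` or `b` from `union` lets many nodes point at the same object safely; only a genuinely new union allocates.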
The second part of the PR is an overhaul of the actual subgraph
substitution. Previously it was done one subgraph at a time, with multiple
`DFSVisit` calls per subgraph. Unfortunately `DFSVisit` is costly and most of
that work was wasted (a DFS over the entire graph was needed to substitute just
a few nodes). In the new approach a new graph is created from the subgraph
assignment generated in the first part, which requires only a single pass over
the graph to apply all subgraphs.
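A minimal sketch of such a single-pass rewrite, assuming each node has exactly one output, that subgraph ids are already resolved to their merged representative, and that a fused node is named `fused_<id>` (a hypothetical naming scheme). It is not the PR's NNVM code: nodes reference their inputs by name, so the rewrite never needs to revisit the graph, and every edge is remapped exactly once.

```python
from typing import Dict, List, Optional, Tuple

Entry = Tuple[str, int]  # (new node name, output index)


def substitute(order: List[str], inputs: Dict[str, List[str]],
               assignment: Dict[str, int], heads: List[str]):
    """Rewrite the graph in one pass over `order` (a topological order).

    inputs[n] lists the producers of node n; assignment maps fused nodes to subgraph ids.
    """
    new_inputs: Dict[str, List[Entry]] = {}
    fused_outputs: Dict[int, Dict[str, int]] = {}  # subgraph id -> {member: output index}

    def remap(producer: str, consumer_sid: Optional[int]) -> Optional[Entry]:
        sid = assignment.get(producer)
        if sid is None:
            return (producer, 0)                  # unfused producer keeps its name
        if sid == consumer_sid:
            return None                           # edge internal to a fused subgraph disappears
        outs = fused_outputs.setdefault(sid, {})
        idx = outs.setdefault(producer, len(outs))  # expose each member as an output once
        return (f"fused_{sid}", idx)

    for node in order:
        sid = assignment.get(node)
        target = node if sid is None else f"fused_{sid}"
        ins = new_inputs.setdefault(target, [])
        for p in inputs[node]:
            entry = remap(p, sid)
            # Fused nodes read each external input once; unfused nodes keep duplicates (e.g. x * x).
            if entry is not None and (sid is None or entry not in ins):
                ins.append(entry)
    new_heads = [remap(h, None) for h in heads]
    return new_inputs, fused_outputs, new_heads
```

With `a` and `b = f(a)` fused into subgraph 0, a consumer of `a` outside the subgraph reads `("fused_0", 0)` and a consumer of `b` reads `("fused_0", 1)`.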
@samskalicky @Caenorst @mk-61
## Checklist ##
### Essentials ###
- [ ] Changes are complete (i.e. I finished coding on this PR)
- [x] All changes have test coverage
- [x] Code is well-documented
## Comments ##
- If this change is a backward incompatible change, why must this change be
made.
- Interesting edge cases to note here