ptrendx opened a new pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269


   ## Description ##
   This PR greatly increases the speed of the pointwise fusion graph pass. As a test case I used XLNet from Gluon-NLP, which had previously exposed the slowness of this graph pass (#17105). After the fix from that issue, the total time of fwd+bwd fusion was ~8 s. With this PR it is ~11 ms, of which 3 ms is spent constructing the IndexedGraph from the original graph. The fusion graph pass itself therefore takes ~8 ms in that case, roughly a 1000x improvement.
   
   The motivation for this PR is to make the pointwise fusion graph pass lightweight enough that it could be run every time the shapes in the graph change, enabling shape-aware fusion. This matters because in 2.0 the NumPy semantics make many operations (like add, sub, etc.) broadcast by default, even when they end up being simple elementwise operations. This PR does not fully get us there. I would like the pass to take <5 ms for a network like XLNet, to be sure it is lightweight enough not to become a bottleneck in any use case. However, some parts would no longer be needed if this pass could run after shape inference is done. For example, we would not need to insert `FusedOpHelper`/`FusedOpOutHelper`, which currently takes a little over 1 ms. That said, it is a big step in that direction.
   
   The main problem the fusion graph pass needs to solve is preventing cycles from forming. A cycle occurs when a node outside a subgraph both consumes one of that subgraph's outputs and provides one of its inputs. To avoid cycles, the original algorithm first constructed a mapping of which nodes are excluded from being in the same subgraph as a given node. Constructing that mapping was simple but inefficient (it was improved in #17114, but was still pretty slow):
   ```
   for each node n that does not qualify to be in a subgraph:
       outputs = nodes reached by DFS from n in the direction of n's outputs
       inputs = nodes reached by DFS from n in the direction of n's inputs
       for each output in outputs:
           for each input in inputs:
               put (input, output) and (output, input) pair into the exclusion mapping
   ```
   This was an `O(n^3)` algorithm (improved in #17114 to `O(n^2)` on average), and it required a separate `BidirectionalGraph` data structure to enable DFS in both directions.
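
For illustration, here is a direct Python transcription of the pseudocode above. The edge-list input and the names `reachable` and `build_exclusion_mapping` are assumptions; this is not MXNet's C++ implementation. The transcription makes the cost visible: every ineligible node triggers two DFS traversals, followed by a pairwise product of the nodes they reach.

```python
from collections import defaultdict


def reachable(start, adjacency):
    """Nodes reachable from `start` (excluding `start` itself) via iterative DFS."""
    seen = set()
    stack = list(adjacency[start])
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            stack.extend(adjacency[node])
    return seen


def build_exclusion_mapping(edges, ineligible):
    """For each ineligible node n, pair every node upstream of n with every
    node downstream of n: such pairs must never share a subgraph."""
    # Adjacency in both directions, playing the role of a bidirectional graph.
    outputs_of = defaultdict(list)
    inputs_of = defaultdict(list)
    for src, dst in edges:
        outputs_of[src].append(dst)
        inputs_of[dst].append(src)

    excluded = defaultdict(set)
    for n in ineligible:
        downstream = reachable(n, outputs_of)
        upstream = reachable(n, inputs_of)
        # |upstream| * |downstream| insertions for every ineligible node.
        for out in downstream:
            for inp in upstream:
                excluded[inp].add(out)
                excluded[out].add(inp)
    return excluded
```

For a chain `a -> b -> x -> c -> d` with `x` ineligible, `a` and `b` are each excluded from sharing a subgraph with `c` and `d`.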
   
   The new algorithm for the pointwise fusion graph pass is based on 2 observations:
    - most of the entries in the exclusion mapping are not useful, as those pairs of nodes would never be considered for the same subgraph anyway
    - we can traverse the graph in topological order, which the original algorithm did not take advantage of
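
For reference, a standard way to obtain a topological order of a DAG is Kahn's algorithm, sketched below in Python with a hypothetical node-list and edge-list representation. It only illustrates the guarantee the new algorithm relies on: when a node is visited, all of its inputs have already been visited.

```python
from collections import deque


def topological_order(nodes, edges):
    """Kahn's algorithm: every node is emitted only after all of its inputs."""
    pending = {n: 0 for n in nodes}  # number of inputs not yet emitted
    consumers = {n: [] for n in nodes}
    for src, dst in edges:
        consumers[src].append(dst)
        pending[dst] += 1
    ready = deque(n for n in nodes if pending[n] == 0)
    order = []
    while ready:
        n = ready.popleft()
        order.append(n)
        for c in consumers[n]:
            pending[c] -= 1
            if pending[c] == 0:
                ready.append(c)
    if len(order) != len(nodes):
        raise ValueError("graph contains a cycle")
    return order
```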
   
   In the new algorithm the graph is traversed in topological order, so all the inputs of the current node are guaranteed to have been processed already. Each node has its own exclusion set of subgraphs that it can't be a part of. The exclusion set of a node is the union of the exclusion sets of its inputs. If the node itself is ineligible to be in a subgraph, all the subgraphs its inputs are part of are added to that union. Subgraphs can merge. For example, for the operation `a + b`, where `a` is part of subgraph s1 and `b` is part of subgraph s2, s1 and s2 need to be merged into a single subgraph containing everything from s1, s2 and the `+` operator. For this reason, a mapping is maintained to track which other subgraphs the current subgraph has been merged with.
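
A minimal Python sketch of this idea follows, under simplifying assumptions. Nodes are given in topological order, `inputs[n]` lists a node's inputs, and `eligible[n]` says whether the node may be fused. The merge mapping is a union-find structure. `DisjointSets` and `assign_subgraphs` are hypothetical names, not MXNet's C++ code.

This is a swapped-in simplification, not the PR's data structure. Rather than materializing per-node exclusion sets, the sketch stores only the subgraphs each ineligible node and each subgraph directly depend on, and expands them on demand. That is slower but stays correct when subgraphs later grow or merge.

```python
class DisjointSets:
    """Union-find over subgraph ids: the 'which subgraphs were merged' mapping."""

    def __init__(self):
        self.parent = []

    def new(self):
        self.parent.append(len(self.parent))
        return len(self.parent) - 1

    def find(self, s):
        while self.parent[s] != s:
            self.parent[s] = self.parent[self.parent[s]]  # path halving
            s = self.parent[s]
        return s


def assign_subgraphs(order, inputs, eligible):
    """Visit nodes in topological order; return {node: subgraph root or None}."""
    sets = DisjointSets()
    sub_of = {}    # node -> subgraph id, or None for an ineligible node
    dep_node = {}  # ineligible node -> subgraph ids it (transitively) depends on
    dep_sub = {}   # subgraph root -> subgraph ids feeding it from outside

    def upstream(ids):
        """All subgraph roots reachable by walking dependencies backwards."""
        seen, stack = set(), [sets.find(s) for s in ids]
        while stack:
            r = stack.pop()
            if r not in seen:
                seen.add(r)
                stack.extend(sets.find(d) for d in dep_sub[r])
        return seen

    for n in order:
        input_subs, other_deps = set(), set()
        for i in inputs[n]:
            if sub_of[i] is not None:
                input_subs.add(sets.find(sub_of[i]))
            else:
                other_deps |= dep_node[i]
        if not eligible[n]:
            # Consumers of n depend on every subgraph feeding n.
            sub_of[n] = None
            dep_node[n] = other_deps | input_subs
            continue
        chosen = []
        for s in sorted(input_subs):
            trial = chosen + [s]
            external = other_deps | (input_subs - set(trial))
            for m in trial:
                external |= dep_sub[m]
            # Fusing `trial` with n is unsafe if any of them feeds its own inputs.
            if upstream(external).isdisjoint(trial):
                chosen = trial
        target = chosen[0] if chosen else sets.new()
        dep_sub.setdefault(target, set())
        for m in chosen[1:]:
            sets.parent[m] = target  # merge subgraph m into target
            dep_sub[target] |= dep_sub.pop(m)
        dep_sub[target] |= other_deps | (input_subs - set(chosen))
        sub_of[n] = target
    return {n: None if s is None else sets.find(s) for n, s in sub_of.items()}
```

Consider `a -> x -> c` plus `a -> c`, where only `x` is ineligible. Here `c` is placed in a different subgraph from `a`, because fusing them would make `x` both consume from and feed the same subgraph. For `c = a + b` with both inputs eligible, the subgraphs of `a` and `b` are merged.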
   
   There are a number of additional optimizations:
    - because the graph is traversed in topological order, the subgraph ids in an exclusion set are typically consecutive numbers (or a few small groups of consecutive numbers), so the exclusion set actually stores intervals of numbers instead of the numbers themselves
    - in most cases the union of exclusion sets is equal to one of the input sets, so the exclusion sets are shared between nodes to avoid costly, unnecessary memory allocations
    - the fwd and bwd fusions are done together in a single pass to avoid overheads
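
A minimal Python sketch of the first two optimizations follows. It assumes exclusion sets are immutable, so several nodes can safely point at the same object. `IntervalSet` and its methods are hypothetical names, and the real pass is written in C++.

```python
from bisect import bisect_right


class IntervalSet:
    """Immutable set of ints stored as sorted, disjoint, non-adjacent
    (lo, hi) intervals; callers must pass intervals already in that form."""

    __slots__ = ("intervals",)

    def __init__(self, intervals=()):
        self.intervals = tuple(intervals)

    def __contains__(self, x):
        # Index of the last interval whose lower bound is <= x.
        i = bisect_right(self.intervals, (x, float("inf"))) - 1
        return i >= 0 and self.intervals[i][0] <= x <= self.intervals[i][1]

    def union(self, other):
        # Cheap paths: if one side adds nothing, share the existing object.
        if not other.intervals or other.intervals == self.intervals:
            return self
        if not self.intervals:
            return other
        merged = []
        for lo, hi in sorted(self.intervals + other.intervals):
            if merged and lo <= merged[-1][1] + 1:  # overlapping or adjacent
                merged[-1] = (merged[-1][0], max(merged[-1][1], hi))
            else:
                merged.append((lo, hi))
        merged = tuple(merged)
        # Reuse an operand when the union turned out to equal it.
        if merged == self.intervals:
            return self
        if merged == other.intervals:
            return other
        return IntervalSet(merged)

    def add(self, x):
        return self.union(IntervalSet([(x, x)]))
```

A union that adds nothing new returns one of its operands. As a result, many nodes can end up referencing an existing set rather than allocating their own.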
   
   The second part of the PR is an overhaul of the actual subgraph substitution. Previously it was done 1 subgraph at a time, with multiple `DFSVisit` calls per subgraph. Unfortunately `DFSVisit` is costly, and most of that work was wasted, because a DFS over the entire graph was needed to substitute just a few nodes. In the new approach, a new graph is created from the subgraph assignment generated in the first part. This requires only a single pass over the graph to apply all the subgraphs.
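
Here is a simplified Python sketch of single-pass substitution. It assumes nodes are visited in topological order and the subgraph assignment is a dict from node to subgraph id (or `None`). `substitute_subgraphs` and the `("fused", id)` naming of the new nodes are hypothetical. Unlike a real graph pass, it ignores multi-output nodes and does not build each fused operator's internal subgraph. It only shows that every node and input edge is visited once, with no per-subgraph DFS.

```python
def substitute_subgraphs(order, inputs, subgraph_of):
    """Collapse every subgraph into one fused node in a single pass.
    Returns (new_inputs, members): new_inputs maps each node of the new graph
    to its inputs; members maps each fused node to the nodes it replaces."""

    def new_name(node):
        sub = subgraph_of.get(node)
        return node if sub is None else ("fused", sub)

    new_inputs = {}
    members = {}
    for node in order:
        target = new_name(node)
        deps = new_inputs.setdefault(target, [])
        if subgraph_of.get(node) is not None:
            members.setdefault(target, []).append(node)
        for src in inputs[node]:
            mapped = new_name(src)
            # Edges between members of the same subgraph disappear;
            # an external input shared by several members is kept once.
            if mapped != target and mapped not in deps:
                deps.append(mapped)
    return new_inputs, members
```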
   
   @samskalicky @Caenorst @mk-61
   
   
   ## Checklist ##
   ### Essentials ###
   - [ ] Changes are complete (i.e. I finished coding on this PR)
   - [x] All changes have test coverage
   - [x] Code is well-documented
   
   ## Comments ##
   - If this change is a backward incompatible change, why must this change be 
made.
   - Interesting edge cases to note here
   

