mikepapadim opened a new pull request #10759:
URL: https://github.com/apache/tvm/pull/10759


   This PR adds `DFPattern` support to the TensorRT (TRT) backend while keeping the existing predicate registry.
   
   Adds and extends the following:
   * In `tensorrt.py`: adds a `pattern_table` covering all supported ops, whose predicates consume the pre-existing op-registry checks.
   * Adds a new pass, `UnmergeComposites`, in `unmerge_composites.cc`. The TRT backend needs it because it expects to work with a single primitive function, whereas `MergeComposite` followed by `PartitionGraph` produces a separate function for each `Composite` pattern.
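
A minimal, framework-free sketch of the idea behind the first bullet: each pattern-table entry pairs a composite name and a pattern with a predicate, and that predicate is simply the pre-existing per-op check, so the support logic is not duplicated. Everything here (`Call`, `OP_CHECKS`, `make_pattern_table`, `match`, the `tensorrt.<op>` naming and the check bodies) is hypothetical and only models the shape of a pattern table; it does not use TVM.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple


@dataclass
class Call:
    """Toy stand-in for a Relay call node: an op name plus its attributes."""

    op: str
    attrs: Dict[str, object] = field(default_factory=dict)


# Hypothetical stand-in for the pre-existing predicate registry:
# op name -> "can TensorRT offload this call?" check.
OP_CHECKS: Dict[str, Callable[[Call], bool]] = {
    "nn.relu": lambda call: True,
    "nn.conv2d": lambda call: call.attrs.get("data_layout") == "NCHW",
}

PatternEntry = Tuple[str, str, Callable[[Call], bool]]


def make_pattern_table() -> List[PatternEntry]:
    """Build (composite name, op pattern, predicate) entries.

    Each predicate is the registered check itself, so both the
    predicate-based and the pattern-based paths share one source of truth.
    """
    return [(f"tensorrt.{op}", op, check) for op, check in OP_CHECKS.items()]


def match(call: Call, table: List[PatternEntry]) -> Optional[str]:
    """Return the name of the first entry whose pattern and predicate accept `call`."""
    for name, pattern, check in table:
        if call.op == pattern and check(call):
            return name
    return None


table = make_pattern_table()
print(match(Call("nn.conv2d", {"data_layout": "NCHW"}), table))  # tensorrt.nn.conv2d
print(match(Call("nn.conv2d", {"data_layout": "NHWC"}), table))  # None
```

In the real backend the pattern element would be a `DFPattern` (for example one built with `is_op` and `wildcard` from `tvm.relay.dataflow_pattern`) rather than a plain op-name string.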
   
   
   Both the pattern-based and the predicate-based pass sequences produce syntactically equivalent `IRModule`s, which preserves backward compatibility.
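
The equivalence can be pictured with a toy model (not TVM code): `merge_composites` wraps each op of an already-partitioned region in its own single-op composite, mirroring how `MergeComposite` plus `PartitionGraph` yield one function per `Composite` pattern, and `unmerge_composites` inlines those composites back into one flat body, the single primitive function the TRT backend expects. The IR encoding and all function names are hypothetical, and the model assumes every pattern matches exactly one op.

```python
from typing import List, Tuple, Union

# Toy IR: a partitioned region body is a list whose items are either a bare
# op name or a composite, modelled as (composite_name, [op names]).
Item = Union[str, Tuple[str, List[str]]]


def merge_composites(region: List[str]) -> List[Item]:
    """Pattern-based path: wrap each op in its own single-op composite."""
    return [(f"tensorrt.{op}", [op]) for op in region]


def unmerge_composites(body: List[Item]) -> List[str]:
    """Inline every composite's ops back into one flat primitive body."""
    flat: List[str] = []
    for item in body:
        if isinstance(item, tuple):
            _name, inner_ops = item
            flat.extend(inner_ops)
        else:
            flat.append(item)
    return flat


# Predicate-based path: the partitioned region is already one flat body.
region = ["nn.conv2d", "nn.relu", "nn.max_pool2d"]
merged = merge_composites(region)
print(merged[0])  # ('tensorrt.nn.conv2d', ['nn.conv2d'])
print(unmerge_composites(merged) == region)  # True
```

On real `IRModule`s, the corresponding check would be `tvm.ir.structural_equal` applied to the outputs of the two pass sequences.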
   
   Original pass ordering:
   ```
   seq = tvm.transform.Sequential(
       [
           transform.InferType(),
           RemoveDropoutPass(),
           transform.RemoveUnusedFunctions(),
           transform.ConvertLayout(
               {
                   "nn.conv1d": ["NCW", "default"],
                   "nn.conv2d": ["NCHW", "default"],
                   "nn.conv3d": ["NCDHW", "default"],
                   "nn.conv2d_transpose": ["NCHW", "default"],
               }
           ),
           transform.FoldConstant(),
           transform.AnnotateTarget("tensorrt"),
           transform.MergeCompilerRegions(),
           transform.PartitionGraph(),
           transform.InferType(),
       ]
   )
   ```
   
   Pass ordering with `MergeComposite` and `UnmergeComposites`:
   
   ```
   seq = tvm.transform.Sequential(
       [
           transform.InferType(),
           RemoveDropoutPass(),
           transform.RemoveUnusedFunctions(),
           transform.ConvertLayout(
               {
                   "nn.conv1d": ["NCW", "default"],
                   "nn.conv2d": ["NCHW", "default"],
                   "nn.conv3d": ["NCDHW", "default"],
                   "nn.conv2d_transpose": ["NCHW", "default"],
               }
           ),
           transform.FoldConstant(),
           transform.MergeComposite(pattern_table()),  # <-- Change #1
           transform.AnnotateTarget("tensorrt"),
           transform.MergeCompilerRegions(),
           transform.PartitionGraph(),
           transform.UnmergeComposites("tensorrt"),  # <-- Change #2
           transform.InferType(),
       ]
   )
   ```
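
The two changes amount to inserting one pass before target annotation and one directly after partitioning. A small, hypothetical helper that builds the pass-name sequence for either mode makes that ordering explicit; the names are plain strings rather than TVM pass objects, and `use_patterns` is an illustrative flag, not an existing option.

```python
from typing import List

# Passes shared by both pipelines, in order (names only).
COMMON_PREFIX = [
    "InferType",
    "RemoveDropoutPass",
    "RemoveUnusedFunctions",
    "ConvertLayout",
    "FoldConstant",
]


def tensorrt_pass_names(use_patterns: bool) -> List[str]:
    """Return the pass order for the predicate- or pattern-based pipeline."""
    passes = list(COMMON_PREFIX)
    if use_patterns:
        passes.append("MergeComposite")  # Change #1: group ops by pattern
    passes += ["AnnotateTarget", "MergeCompilerRegions", "PartitionGraph"]
    if use_patterns:
        passes.append("UnmergeComposites")  # Change #2: flatten composites
    passes.append("InferType")
    return passes


old = tensorrt_pass_names(use_patterns=False)
new = tensorrt_pass_names(use_patterns=True)
# The pattern-based pipeline is the original plus exactly two passes.
print([p for p in new if p not in old])  # ['MergeComposite', 'UnmergeComposites']
```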
   
   @mbs-octoml @mbaret @masahi 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
