qsqqsqqsq opened a new pull request #9784:
URL: https://github.com/apache/tvm/pull/9784


   This PR adds a get_packed_func interface to GenericFunc. The interface 
returns the packed function registered for the current target. As third-party 
users, we can combine it with a Python context manager to temporarily 
override an op strategy with our own device strategy. 
   We can create a temporary strategy registration as below:
   ```python
   with TempOpStrategy("nn.conv2d", "cuda", general_strategy):
       lib = relay.build(...)
   ```
   We save the current strategy function when TempOpStrategy is initialized 
and register it back when TempOpStrategy exits.
   ```python
   import tvm
   from tvm import relay

   class TempOpStrategy(object):
       def __init__(self, op_name, target, fstrategy):
           generic_fstrategy = relay.op.get(op_name).get_attr("FTVMStrategy")
           self.op_name = op_name
           self.target = target
           with tvm.target.Target(target) as target_obj:
               # Save the strategy currently dispatched for this target.
               self.origin_func = generic_fstrategy.get_packed_func()
               # Override the strategy for every key of the target.
               for tgt_key in target_obj.keys:
                   generic_fstrategy.register(
                       fstrategy, tgt_key, allow_override=True
                   )

       def __enter__(self):
           return self

       def __exit__(self, typ, value, traceback):
           # Use self.op_name here; a bare `name` is undefined in this scope.
           generic_fstrategy = relay.op.get(self.op_name).get_attr("FTVMStrategy")
           with tvm.target.Target(self.target) as target_obj:
               # Put back the strategy saved in __init__.
               for tgt_key in target_obj.keys:
                   generic_fstrategy.register(
                       self.origin_func, tgt_key, allow_override=True
                   )
   ```
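
The save-and-restore pattern can also be tried without a TVM install. The sketch below is only an illustration under stated assumptions: the plain dict is a hypothetical stand-in for the per-target strategy table, and `temp_override` is a made-up helper name, not part of TVM. It saves the old entry on entry (the role of get_packed_func), overrides it (the role of register with allow_override=True), and restores it in a `finally` clause so the override does not leak if the body raises.

```python
from contextlib import contextmanager

# Sentinel marking "no entry existed before the override".
_MISSING = object()


@contextmanager
def temp_override(registry, key, value):
    """Temporarily set registry[key] = value; restore the old state on exit.

    `registry` is a plain dict standing in for a strategy table (an
    assumption made for illustration, not a TVM object).
    """
    saved = registry.get(key, _MISSING)  # save the current entry
    registry[key] = value                # install the temporary override
    try:
        yield registry
    finally:
        # Runs on normal exit and on exceptions alike.
        if saved is _MISSING:
            del registry[key]
        else:
            registry[key] = saved


if __name__ == "__main__":
    strategies = {"cuda": "default_cuda_strategy"}
    with temp_override(strategies, "cuda", "general_strategy"):
        print(strategies["cuda"])
    print(strategies["cuda"])
```

One design difference from the class above: here the override is applied when the `with` block is entered, not when the object is constructed, so building the context manager without entering it changes nothing.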
   
   
   

