qsqqsqqsq opened a new pull request #9784:
URL: https://github.com/apache/tvm/pull/9784
This PR adds a get_packed_func interface to GenericFunc. The interface returns
the packed function registered for the current target. Third-party users can
combine it with a Python context manager to temporarily override an op strategy
with their own device strategy.
We can create a temporary strategy registration as below:
```python
with TempOpStrategy("nn.conv2d", "cuda", general_strategy):
    lib = relay.build(...)
```
We save the current strategy function when TempOpStrategy is initialized and
register that function back when TempOpStrategy exits.
```python
import tvm
from tvm import relay


class TempOpStrategy(object):
    def __init__(self, op_name, target, fstrategy):
        generic_fstrategy = relay.op.get(op_name).get_attr("FTVMStrategy")
        self.op_name = op_name
        self.target = target
        with tvm.target.Target(target) as target_obj:
            # Save the strategy currently registered for this target
            # (get_packed_func is the interface added by this PR).
            self.origin_func = generic_fstrategy.get_packed_func()
            # Override the strategy for every key of the target.
            for tgt_key in target_obj.keys:
                generic_fstrategy.register(fstrategy, tgt_key,
                                           allow_override=True)

    def __enter__(self):
        return self

    def __exit__(self, typ, value, traceback):
        generic_fstrategy = relay.op.get(self.op_name).get_attr("FTVMStrategy")
        with tvm.target.Target(self.target) as target_obj:
            # Register the saved strategy back for every key of the target.
            for tgt_key in target_obj.keys:
                generic_fstrategy.register(self.origin_func, tgt_key,
                                           allow_override=True)
```
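The save, override, restore pattern used by TempOpStrategy can also be shown without TVM. The following is a minimal sketch only: `REGISTRY`, `temp_override`, and the strategy names are hypothetical stand-ins for illustration, not TVM APIs. Unlike the class above, this sketch performs the override on entry to the `with` block rather than at construction, and restores in a `finally` clause so the original entries come back even if the body raises.

```python
from contextlib import contextmanager

# Hypothetical stand-in for a per-target-key strategy registry (not a TVM API).
REGISTRY = {"cuda": "default_cuda_strategy", "gpu": "default_gpu_strategy"}


@contextmanager
def temp_override(registry, keys, new_value):
    """Temporarily map each key in `keys` to `new_value`; restore on exit."""
    missing = object()  # sentinel for keys that had no entry before
    saved = {k: registry.get(k, missing) for k in keys}
    for k in keys:
        registry[k] = new_value
    try:
        yield registry
    finally:
        # Put back whatever was registered before the override.
        for k, old in saved.items():
            if old is missing:
                registry.pop(k, None)
            else:
                registry[k] = old


with temp_override(REGISTRY, ["cuda", "gpu"], "my_strategy"):
    inside = dict(REGISTRY)  # snapshot while the override is active
after = dict(REGISTRY)       # snapshot once the original entries are restored
```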