samskalicky commented on a change in pull request #17623: Dynamic subgraph compile support
URL: https://github.com/apache/incubator-mxnet/pull/17623#discussion_r388636626
 
 

 ##########
 File path: python/mxnet/gluon/block.py
 ##########
 @@ -1026,6 +1030,69 @@ def _call_cached_op(self, *args):
             out = [out]
         return _regroup(out, self._out_format)
 
+    def optimize_for(self, x, *args, backend=None, backend_opts=None, **kwargs):
+        """Partitions the current HybridBlock and optimizes it for a given backend
+        without executing a forward pass. Modifies the HybridBlock in-place.
+
+        Immediately partitions a HybridBlock using the specified backend. Combines
+        the work done in the hybridize API with part of the work done in the forward
+        pass without calling the CachedOp. Can be used in place of hybridize,
+        afterwards `export` can be called or inference can be run. See README.md in
+        example/extensions/lib_subgraph/README.md for more details.
+
+        Examples
+        --------
+        # partition and then export to file
+        block.optimize_for(x, backend='myPart')
+        block.export('partitioned')
+
+        # partition and then run inference
+        block.optimize_for(x, backend='myPart')
+        block(x)
+
+        Parameters
+        ----------
+        x : NDArray
+            first input to model
+        *args : NDArray
+            other inputs to model
+        backend : str
+            The name of backend, as registered in `SubgraphBackendRegistry`, default None
+        backend_opts : dict of user-specified options to pass to the backend for partitioning, optional
+            Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
+        static_alloc : bool, default False
+            Statically allocate memory to improve speed. Memory usage may increase.
+        static_shape : bool, default False
+            Optimize for invariant input shapes between iterations. Must also
+            set static_alloc to True. Change of input shapes is still allowed
+            but slower.
+        """
+
+        # do hybrize API call
+        self.hybridize(True, backend, backend_opts, **kwargs)
+
+        # do part of forward API call
+        has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args))
+        if has_symbol:
+            raise ValueError('Inputs must be NDArrays for the optimize_for API'
+                             ' Please check the type of the args.\n')
+        if not has_symbol and not has_ndarray:
+            raise ValueError('In HybridBlock, there must be one NDArray as input.'
+                             ' Please check the type of the args.\n')
+        if len(ctx_set) > 1:
+            raise ValueError('Find multiple contexts in the input, '
+                             'After hybridized, the HybridBlock only supports one input '
+                             'context. You can print the ele.ctx in the '
+                             'input arguments to inspect their contexts. '
+                             'Find all contexts = {}'.format(ctx_set))
+
+        self._build_cache(x, *args)
 
 Review comment:
  `_build_cache` combines `(x, args)` back into a single tuple:
   
https://github.com/apache/incubator-mxnet/blob/ade7d48fc7b977e28213ec24b628c61dc1c0c6f0/python/mxnet/gluon/block.py#L923
   
   `_call_cached_op` does the same:
   
https://github.com/apache/incubator-mxnet/blob/ade7d48fc7b977e28213ec24b628c61dc1c0c6f0/python/mxnet/gluon/block.py#L993
   
  So passing `(x, *args)` to `_build_cache` means that, inside the function, `args` is a single tuple that already contains all of the inputs.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to