sandeep-krishnamurthy commented on a change in pull request #12412: Infer dtype in SymbolBlock import from input symbol
URL: https://github.com/apache/incubator-mxnet/pull/12412#discussion_r218573269
##########
File path: python/mxnet/gluon/block.py
##########
@@ -1084,5 +1092,71 @@ def _clear_cached_op(self):
super(SymbolBlock, self)._clear_cached_op()
self._cached_graph = tmp
+ def cast(self, dtype):
+ self._clear_cached_op()
+ super(SymbolBlock, self).cast(dtype)
+
def hybrid_forward(self, F, x, *args, **kwargs):
raise NotImplementedError
+
+def _infer_param_types(in_params, out_params, arg_params, aux_params,
+                       default_dtype=mx_real_t):
+ """Utility function that helps in inferring DType of args and auxs params
+ from given input param.
+
+ Parameters
+ ----------
+ in_params: List of Symbol
+ List of input symbol variables.
+ out_params: Symbol
+ Output symbol variable.
+ arg_params: List of Str
List of names of argument parameters.
+ aux_params: List of Str
+ List of names of auxiliary parameters.
+ default_dtype: numpy.dtype or str, default 'float32'
Default data type for arg_params and aux_params, if unable to infer
+ the type.
+
+ Returns
+ -------
+ arg_types: List of numpy.dtype
+ List of arg_params type. Order is same as arg_params.
+ Defaults to 'float32', if unable to infer type.
+ aux_types: List of numpy.dtype
+ List of aux_params type. Order is same as aux_params.
+ Defaults to 'float32', if unable to infer type.
+ """
+ arg_types = None
+ aux_types = None
+
+ # Get Input symbol details. This will be used to infer types of
+ # other parameters.
+ input_sym_names = [in_param.name for in_param in in_params]
+
+ # Try to infer input types. If not successful, we will set default dtype.
Review comment:
In one of the unit tests (pasted below), inference will fail because the input
variable is declared without a shape or dtype. That need not be an issue: users
are just relying on the default type.
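```python
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

# Build a small HybridSequential model.
model = nn.HybridSequential()
model.add(nn.Dense(128, activation='tanh'))
model.add(nn.Dropout(0.5))
model.add(nn.Dense(64, activation='tanh'),
          nn.Dense(32, in_units=64))
model.add(nn.Activation('relu'))
model.initialize()

# 'data' carries no dtype, so its type cannot be inferred and the
# default type is assumed for the imported parameters.
inputs = mx.sym.var('data')
outputs = model(inputs).get_internals()
smodel = gluon.SymbolBlock(outputs, inputs,
                           params=model.collect_params())
```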
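The fallback described in the docstring and in this comment can be sketched in plain Python. This is a minimal illustration, not the PR's implementation: `fill_param_types` is a hypothetical helper, dtypes are modeled as strings rather than `numpy.dtype`, the parameter names are illustrative, and `None` stands in for "type inference could not run", as happens when `data` is declared without a dtype.

```python
from typing import List, Optional, Sequence, Tuple

# Assumption: stands in for MXNet's mx_real_t (float32).
DEFAULT_DTYPE = "float32"


def fill_param_types(
    arg_names: Sequence[str],
    aux_names: Sequence[str],
    inferred_arg_types: Optional[Sequence[str]],
    inferred_aux_types: Optional[Sequence[str]],
    default_dtype: str = DEFAULT_DTYPE,
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: keep inferred dtypes when they line up with the
    parameter names; otherwise use default_dtype for every entry."""

    def resolve(names, inferred):
        # Inference failed, or returned a list that does not match the
        # names: fall back to the default type without treating it as an error.
        if inferred is None or len(inferred) != len(names):
            return [default_dtype] * len(names)
        return list(inferred)

    return (resolve(arg_names, inferred_arg_types),
            resolve(aux_names, inferred_aux_types))


# Unit-test scenario: 'data' has no declared dtype, so nothing is inferred.
args = ["data", "weight0", "bias0"]  # illustrative names only
arg_types, aux_types = fill_param_types(args, [], None, None)
print(arg_types)  # ['float32', 'float32', 'float32']
print(aux_types)  # []

# Scenario where inference succeeded for a float16 input.
arg_types, _ = fill_param_types(args, [], ["float16"] * 3, [])
print(arg_types)  # ['float16', 'float16', 'float16']
```

Falling back for the whole list on a length mismatch mirrors the docstring's promise that the returned order matches `arg_params`/`aux_params`; a per-name fallback would be an alternative design.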