anirudh2290 commented on a change in pull request #15118: Conversion from FP32 
model to Mixed Precision model
URL: https://github.com/apache/incubator-mxnet/pull/15118#discussion_r293171602
 
 

 ##########
 File path: python/mxnet/contrib/amp/amp.py
 ##########
 @@ -342,3 +349,320 @@ def unscale(optimizer_or_trainer):
     else:
         raise TypeError("optimizer_or_trainer should be a Gluon Trainer or "
                         "an optimizer, instead is %s" % 
type(optimizer_or_trainer))
+
def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None,
                   fp32_ops=None, conditional_fp32_ops=None,
                   excluded_sym_names=None, data_names=None):
    """Given a symbol object representing a neural network of data type FP32 and target_dtype,
    add cast layers according to the op lists (target_dtype_ops, fp32_ops,
    conditional_fp32_ops) if provided, otherwise use the default
    lists provided by the framework.

    Parameters
    ----------
    sym : Symbol
        FP32 neural network symbol
    target_dtype : str or numpy, optional defaults to float16
        currently only supports float16. The target dtype indicates to add cast layers
        when possible so that lower precision computation can be leveraged.
    target_dtype_ops : list of strs, optional
        Override the list of operator names casted to the target_dtype.
        If None, uses the framework's default list to be casted to target_dtype.
    fp32_ops : list of strs, optional
        Override the list of operator names casted to FP32.
        If None, uses the framework's default list to be casted to FP32.
    conditional_fp32_ops : list of (string, string, list of string), optional
        Override the list of functions to be casted to FP32.
        The format of the list is
        (name of the function, name of the parameter,
         list of values of the parameter that make the operator to be casted to FP32)
    excluded_sym_names : list of strs, optional
        A list of strings that represent the names of symbols that users want to exclude
        from being casted to FP16 or FP32.
    data_names : list of strs, optional
        A list of strings that represent input data tensor names to the model

    Returns
    -------
    Symbol
        A new symbol with amp_cast/amp_multicast layers inserted according to
        the resolved op lists.
    """
    if target_dtype != "float16":
        raise ValueError("Only target_dtype float16 is supported currently")

    # Fall back to the framework-provided default op lists when the user
    # does not override them.
    if target_dtype_ops is not None:
        assert isinstance(target_dtype_ops, list), "target_dtype_ops should be a list of strs"
    else:
        target_dtype_ops = lists.symbol.FP16_FUNCS

    if fp32_ops is not None:
        assert isinstance(fp32_ops, list), "fp32_ops should be a list of strs"
    else:
        fp32_ops = lists.symbol.FP32_FUNCS

    if conditional_fp32_ops is not None:
        assert isinstance(conditional_fp32_ops, list), "conditional_fp32_ops should be a list"
    else:
        conditional_fp32_ops = lists.symbol.CONDITIONAL_FP32_FUNCS

    # Flatten the (op name, param name, param values) triples into parallel
    # arrays for the C API; `indptr` holds CSR-style offsets marking where
    # each op's slice of `param_vals` begins and ends.
    original_conditional_op_names = []
    conditional_op_names = []
    param_names = []
    param_vals = []
    indptr = [0]
    for conditional_fp32_op in conditional_fp32_ops:
        assert isinstance(conditional_fp32_op[0], str) and isinstance(conditional_fp32_op[1], str) \
            and isinstance(conditional_fp32_op[2], list), "conditional_fp32_ops should be a list of " \
                                                          "(str, str, list of str)"
        param_vals += conditional_fp32_op[2]
        indptr.append(len(param_vals))
        param_names.append(conditional_fp32_op[1])
        conditional_op_names.append(conditional_fp32_op[0])

    if excluded_sym_names is not None:
        assert isinstance(excluded_sym_names, list), "excluded_sym_names should be a list of strs"
    else:
        excluded_sym_names = []

    for original_conditional_fp32_op in lists.symbol.CONDITIONAL_FP32_FUNCS:
        original_conditional_op_names.append(original_conditional_fp32_op[0])

    # Op lists should not have intersection
    common_ops = set(target_dtype_ops) & set(fp32_ops)
    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
                                 "Common ops in target_dtype_ops and fp32_ops {}".format(common_ops)
    common_ops = set(target_dtype_ops) & set(conditional_op_names)
    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
                                 "Common ops in target_dtype_ops and conditional_fp32_ops {}".format(common_ops)
    common_ops = set(conditional_op_names) & set(fp32_ops)
    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
                                 "Common ops in fp32_ops and conditional_fp32_ops {}".format(common_ops)

    combined_ops = set(target_dtype_ops + fp32_ops + conditional_op_names)
    all_fp16_fp32_ops = set(lists.symbol.FP16_FUNCS + lists.symbol.FP32_FUNCS
                            + lists.symbol.FP16_FP32_FUNCS + original_conditional_op_names)

    illegal_ops = combined_ops - all_fp16_fp32_ops
    # BUGFIX: message said "three lists" while enumerating four, and used the
    # singular "Op" for a set of op names.
    assert not illegal_ops, '''Can only choose ops from one of the four lists
                            for fp16_ops and fp32_ops
                            1. amp.list_fp16_ops()
                            2. amp.list_fp32_ops()
                            3. amp.list_fp16_fp32_ops()
                            4. amp.list_conditional_fp32_ops()
                            Ops %s not in any of them''' % (illegal_ops)

    widest_dtype_ops = lists.symbol.WIDEST_TYPE_CASTS
    target_dtype = _DTYPE_NP_TO_MX[np.dtype(target_dtype).type]

    # Prepare a data_names list based on list_inputs if its not provided
    # Add all names in list for the nodes in the symbol which don't have
    # __dtype__ set
    attr_dict = sym.attr_dict()
    if not data_names:
        data_names = []
        for sym_name in sym.list_inputs():
            if sym_name not in attr_dict:
                data_names.append(sym_name)
                continue
            if "__dtype__" not in attr_dict[sym_name]:
                data_names.append(sym_name)
    model_param_names = list(set(sym.list_inputs()) - set(data_names))

    # Since assumption is that it is a FP32 model, set dtypes for all
    # data_names to float32 (0 is the MXNet type code for float32 here).
    str_keys = []
    sdata = []
    for k in data_names:
        str_keys.append(k)
        sdata.append(0)
    keys = c_str_array(str_keys)

    out = SymbolHandle()
    check_call(_LIB.MXReducePrecisionSymbol(sym.handle,
                                            ctypes.byref(out),
                                            mx_uint(len(sdata)),
                                            c_array_buf(ctypes.c_int, array('i', sdata)),
                                            mx_uint(len(indptr)),
                                            c_array_buf(ctypes.c_int, array('i', indptr)),
                                            ctypes.byref(ctypes.c_int(target_dtype)),
                                            mx_uint(len(target_dtype_ops)),
                                            mx_uint(len(fp32_ops)),
                                            mx_uint(len(widest_dtype_ops)),
                                            mx_uint(len(conditional_op_names)),
                                            mx_uint(len(excluded_sym_names)),
                                            mx_uint(len(model_param_names)),
                                            c_str_array(target_dtype_ops),
                                            c_str_array(fp32_ops),
                                            c_str_array(widest_dtype_ops),
                                            c_str_array(conditional_op_names),
                                            c_str_array(excluded_sym_names),
                                            c_str_array(param_names),
                                            c_str_array(param_vals),
                                            c_str_array(model_param_names),
                                            keys))
    return Symbol(out)
+
def convert_model(sym, arg_params, aux_params, target_dtype="float16", target_dtype_ops=None,
                  fp32_ops=None, conditional_fp32_ops=None, excluded_sym_names=None):
    """API for converting a model from FP32 model to a mixed precision model.
    MXNet tries to convert the FP32 model to mixed precision model by adding
    cast layers using amp_cast and amp_multicast operators which can be used for inference use cases.
    The decision on which cast layer to add is based on hardcoded lists for Automatic Mixed Precision
    in MXNet. These lists can be overridden by the user by providing their own lists
    using : target_dtype_ops, fp32_ops, conditional_fp32_ops

    Parameters
    ----------
    sym : Symbol
        FP32 neural network symbol
    arg_params : dict
        Dictionary of name to `NDArray`.
    aux_params : dict
        Dictionary of name to `NDArray`.
    target_dtype : str
        Currently only supports float16. The target dtype indicates to add cast layers
        when possible so that lower precision computation can be leveraged.
    target_dtype_ops : list of strs
        Override the list of operator names casted to target_dtype.
        If None, uses the framework's default list to be casted to target dtype.
    fp32_ops : list of strs
        Override the lists of operator names casted to FP32.
        If None, uses the framework's default list to be casted to FP32.
    conditional_fp32_ops : list of (string, string, list of string)
        Override the list of operators to be casted to FP32.
        The format of the list is
        (name of the function, name of the parameter,
         list of values of the parameter that make the operator to be casted to
        fp32)
    excluded_sym_names : list of strs
        A list of strings that represent the names of symbols that users want to exclude
        from being quantized.

    Returns
    -------
    (Symbol, dict, dict)
        The converted symbol plus arg_params and aux_params with each param
        casted to the dtype recorded in the converted symbol's attributes.
    """
    # BUGFIX: the isinstance check was nested inside the `is None` branch
    # (after rebinding to []), so it could never trigger and a non-list
    # argument silently slipped through to convert_symbol.
    if excluded_sym_names is None:
        excluded_sym_names = []
    if not isinstance(excluded_sym_names, list):
        raise ValueError('excluded_sym_names must be a list of strings representing'
                         ' the names of the symbols that should not be casted,'
                         ' while received type %s' % str(type(excluded_sym_names)))

    if target_dtype != "float16":
        raise ValueError("Only target_dtype float16 is supported currently")
    param_names = list(arg_params.keys()) + list(aux_params.keys())

    # Only pass non params as data_names, param types can be inferred
    data_names = list(set(sym.list_inputs()) - set(param_names))

    sym = convert_symbol(sym, target_dtype, target_dtype_ops,
                         fp32_ops, conditional_fp32_ops,
                         excluded_sym_names, data_names)

    # If dtype is set for params, cast the param to that dtype
    # ("-1" marks an unset dtype attribute).
    attr_dict = sym.attr_dict()
    for sym_name in sym.list_arguments():
        if sym_name in attr_dict and "__dtype__" in attr_dict[sym_name]:
            if attr_dict[sym_name]["__dtype__"] != "-1":
                typ = _DTYPE_MX_TO_NP[int(attr_dict[sym_name]["__dtype__"])]
                arg_params[sym_name] = arg_params[sym_name].astype(typ)

    for sym_name in sym.list_auxiliary_states():
        if sym_name in attr_dict and "__dtype__" in attr_dict[sym_name]:
            if attr_dict[sym_name]["__dtype__"] != "-1":
                typ = _DTYPE_MX_TO_NP[int(attr_dict[sym_name]["__dtype__"])]
                aux_params[sym_name] = aux_params[sym_name].astype(typ)

    # Return the converted symbol and casted params
    return sym, arg_params, aux_params
+
+def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
+                         fp32_ops=None, conditional_fp32_ops=None,
+                         excluded_sym_names=None, ctx=gpu(0)):
+    """Given a hybrid block/symbol block representing a FP32 model and a 
target_dtype,
+    return a block with mixed precision support which can be used for 
inference use cases.
+
+    Parameters
+    ----------
+    block : HybridBlock or SymbolBlock object
+        FP32 HybridBlock or SymbolBlock object
+    target_dtype : str or numpy
+        currently only supports fp16. The target dtype indicates to add cast 
layers
+        when possible so that lower precision computation can be leveraged.
+    target_precision_ops : list of strs
+        Override the list of operator names casted to target_dtype.
+        If None, uses the framework's default list to be casted to FP32.
+    conditional_fp32_ops : list of (str, str, list of str)
+        Override the list of functions to be casted to FP32.
+        The format of the list is
+        (name of the function, name of the parameter,
+         list of values of the parameter that make the operator to be casted 
to FP32
+    excluded_sym_names : list of strs
+        A list of strings that represent the names of symbols that users want 
to exclude
+        from being quantized
+    ctx : Context
+        Context on which model parameters should live
+    """
+    from ...gluon import HybridBlock, SymbolBlock
+    assert isinstance(block, HybridBlock), "block input should be a 
HybridBlock"
+    if not block._cached_graph:
+        raise RuntimeError(
+            "Please first call block.hybridize() and then run forward with "
+            "this block at least once before calling export.")
+
+    # Prepare inputs to pass to the convert_symbol API
+    inputs, sym = block._cached_graph
+    input_names = []
+    for inp in inputs:
+        input_names.append(inp.name)
+    converted_sym = convert_symbol(sym, target_dtype, target_dtype_ops,
+                                   fp32_ops, conditional_fp32_ops,
+                                   excluded_sym_names, data_names=input_names)
+
+    arg_names = set(converted_sym.list_arguments())
+    aux_names = set(converted_sym.list_auxiliary_states())
+    arg_dict = {}
+
+    # If dtype for the param was set in the json, cast the
+    # param to this dtype
+    attr_dict = converted_sym.attr_dict()
+    for name, param in block.collect_params().items():
+        if name in arg_names:
+            arg_dict['arg:%s'%name] = param._reduce()
+            if name in attr_dict and "__dtype__" in attr_dict[name]:
+                if attr_dict[name]["__dtype__"] != "-1":
+                    typ = _DTYPE_MX_TO_NP[int(attr_dict[name]["__dtype__"])]
+                    arg_dict['arg:%s'%name] = 
arg_dict['arg:%s'%name].astype(typ)
+        else:
+            assert name in aux_names
+            arg_dict['aux:%s'%name] = param._reduce()
+            if name in attr_dict and "__dtype__" in attr_dict[name]:
+                if attr_dict[name]["__dtype__"] != "-1":
+                    typ = _DTYPE_MX_TO_NP[int(attr_dict[name]["__dtype__"])]
+                    arg_dict['aux:%s'%name] = 
arg_dict['aux:%s'%name].astype(typ)
+
+    # Create a symbolblock and cast the params to the dtypes based
+    # on the dtype information from the converted_symbol
+    ret = SymbolBlock(converted_sym, inputs)
 
 Review comment:
   Given a symbol and params, SymbolBlock is supposed to be used to load the 
model and run inference on it. Other Hybrid blocks are more for the hybrid 
mode, where you run imperative mode on the first run and cache ops and run 
forward on cached symbol from the second pass.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to