sandeep-krishnamurthy commented on a change in pull request #12376: [MXNET-854] SVRG Optimization in Python Module API
URL: https://github.com/apache/incubator-mxnet/pull/12376#discussion_r213406018
##########
File path: contrib/svrg_optimization_python/src/svrg_module.py
##########
@@ -0,0 +1,581 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A `SVRGModule` implements the `Module` API by wrapping an auxiliary module to perform
+SVRG optimization logic.
+"""
+
+import mxnet as mx
+import time
+import logging
+from svrg_optimizer import SVRGOptimizer
+from mxnet.module import Module
+
+
+class SVRGModule(Module):
+    """SVRGModule is a module that encapsulates two Modules to accommodate the SVRG optimization technique.
+    It is functionally the same as Module API, except it is implemented using SVRG optimization logic.
+
+    Parameters
+    ----------
+    symbol : Symbol
+    data_names : list of str
+        Defaults to `('data')` for a typical model used in image classification.
+    label_names : list of str
+        Defaults to `('softmax_label')` for a typical model used in image
+        classification.
+    logger : Logger
+        Defaults to `logging`.
+    context : Context or list of Context
+        Defaults to ``mx.cpu()``.
+    work_load_list : list of number
+        Default ``None``, indicating uniform workload.
+    fixed_param_names: list of str
+        Default ``None``, indicating no network parameters are fixed.
+    state_names : list of str
+        states are similar to data and label, but not provided by data iterator.
+        Instead they are initialized to 0 and can be set by `set_states()`.
+    group2ctxs : dict of str to context or list of context,
+                 or list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
+    compression_params : dict
+        Specifies type of gradient compression and additional arguments depending
+        on the type of compression being used. For example, 2bit compression requires a threshold.
+        Arguments would then be {'type':'2bit', 'threshold':0.5}
+        See mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
+    update_freq: int
+        Specifies the number of times to update the full gradients to be used in the SVRG optimization. For instance,
+        update_freq = 2 will calculates the gradients over all data every two epochs
+    Examples
+    --------
+    >>> # An example of declaring and using SVRGModule.
+    >>> mod = mod = SVRGModule(symbol=lro, data_names=['data'], label_names=['lin_reg_label'], update_freq=2)

Review comment:
     Please correct this example.
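For context on what `update_freq` controls, the SVRG update can be sketched in plain Python. This is a minimal illustration, not the PR's implementation: it assumes a one-parameter least-squares model y ≈ w·x, and the helpers `full_gradient` and `svrg_fit` are hypothetical names. A snapshot w̃ and its full-data gradient μ̃ are recomputed every `update_freq` epochs; each stochastic step then uses ∇f_i(w) − ∇f_i(w̃) + μ̃ in place of the plain per-sample gradient.

```python
import random


def full_gradient(w, xs, ys):
    """Average gradient of the loss 0.5 * (w*x - y)**2 over the whole dataset."""
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def svrg_fit(xs, ys, lr=0.02, num_epoch=20, update_freq=2, seed=0):
    """Fit y ~ w*x with SVRG; returns (w, number of full-gradient refreshes).

    Hypothetical helper for illustration only. `update_freq` mirrors the
    docstring parameter: the snapshot w_snap and its full gradient mu are
    recomputed every `update_freq` epochs.
    """
    if update_freq < 1:
        raise ValueError("update_freq must be >= 1")
    rng = random.Random(seed)
    n = len(xs)
    w = 0.0
    w_snap, mu = w, 0.0
    refreshes = 0
    for epoch in range(num_epoch):
        if epoch % update_freq == 0:
            # Full pass over all data at the current weights (the snapshot).
            w_snap, mu = w, full_gradient(w, xs, ys)
            refreshes += 1
        for _ in range(n):
            i = rng.randrange(n)
            x, y = xs[i], ys[i]
            g_cur = (w * x - y) * x        # stochastic gradient at current w
            g_snap = (w_snap * x - y) * x  # same sample's gradient at the snapshot
            # Variance-reduced step: unbiased estimate of the full gradient at w.
            w -= lr * (g_cur - g_snap + mu)
    return w, refreshes


xs = [1.0, 2.0, 3.0, 4.0]
ys = [3.0 * x for x in xs]
w, refreshes = svrg_fit(xs, ys, update_freq=2)
print(round(w, 3), refreshes)
```

With `update_freq=2` over 20 epochs the snapshot is refreshed 10 times. A larger `update_freq` means fewer full passes over the data, at the cost of a staler snapshot gradient in the correction term.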
