sandeep-krishnamurthy commented on a change in pull request #12376: [MXNET-854] SVRG Optimization in Python Module API
URL: https://github.com/apache/incubator-mxnet/pull/12376#discussion_r213406018
 
 

 ##########
 File path: contrib/svrg_optimization_python/src/svrg_module.py
 ##########
 @@ -0,0 +1,581 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""An `SVRGModule` implements the `Module` API by wrapping an auxiliary module to perform
+SVRG optimization logic.
+"""
+
+import mxnet as mx
+import time
+import logging
+from svrg_optimizer import SVRGOptimizer
+from mxnet.module import Module
+
+
+class SVRGModule(Module):
+    """SVRGModule is a module that encapsulates two Modules to accommodate the SVRG optimization technique.
+    It is functionally the same as the Module API, except that it is implemented using SVRG optimization logic.
+
+    Parameters
+    ----------
+    symbol : Symbol
+    data_names : list of str
+        Defaults to `('data')` for a typical model used in image classification.
+    label_names : list of str
+        Defaults to `('softmax_label')` for a typical model used in image
+        classification.
+    logger : Logger
+        Defaults to `logging`.
+    context : Context or list of Context
+        Defaults to ``mx.cpu()``.
+    work_load_list : list of number
+        Default ``None``, indicating uniform workload.
+    fixed_param_names: list of str
+        Default ``None``, indicating no network parameters are fixed.
+    state_names : list of str
+        States are similar to data and label, but are not provided by the data iterator.
+        Instead they are initialized to 0 and can be set by `set_states()`.
+    group2ctxs : dict of str to context or list of context,
+                 or list of dict of str to context
+        Default is `None`. Maps the `ctx_group` attribute to the context assignment.
+    compression_params : dict
+        Specifies the type of gradient compression and additional arguments depending
+        on the type of compression being used. For example, 2bit compression requires a threshold.
+        Arguments would then be {'type':'2bit', 'threshold':0.5}.
+        See the mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
+    update_freq: int
+        Specifies how often, in epochs, the full gradients used in the SVRG optimization are recomputed.
+        For instance, update_freq = 2 will calculate the gradients over all data every two epochs.
+    Examples
+    --------
+    >>> # An example of declaring and using SVRGModule.
+    >>> mod = mod = SVRGModule(symbol=lro, data_names=['data'], label_names=['lin_reg_label'], update_freq=2)
 
 Review comment:
   Please correct this example.
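For context on what `update_freq` controls, the following is a minimal pure-Python sketch of the SVRG update rule on one-dimensional least squares. It assumes the standard formulation: a snapshot of the weights, the full gradient `mu` at that snapshot, and a variance-reduced step `g_i(w) - g_i(w_snap) + mu`. The helpers `grad_i`, `full_grad` and `svrg` are hypothetical illustrations. They are not part of the `SVRGModule` or `SVRGOptimizer` code under review.

```python
import random


def grad_i(w, x, y):
    """Gradient of the per-sample squared loss 0.5 * (w * x - y) ** 2 w.r.t. w."""
    return (w * x - y) * x


def full_grad(w, data):
    """Average gradient over the whole dataset (the SVRG 'full gradient')."""
    return sum(grad_i(w, x, y) for x, y in data) / len(data)


def svrg(data, lr=0.05, epochs=40, update_freq=2, seed=0):
    """Minimal SVRG for a 1-D linear model y ~ w * x (hypothetical sketch).

    The snapshot weight and its full gradient are refreshed at the start of
    every `update_freq`-th epoch, matching the meaning of `update_freq` in
    the SVRGModule docstring.
    """
    rng = random.Random(seed)
    w = 0.0
    w_snap, mu = w, 0.0
    for epoch in range(epochs):
        if epoch % update_freq == 0:
            # New snapshot: recompute the gradient over all data.
            w_snap, mu = w, full_grad(w, data)
        for _ in range(len(data)):
            x, y = rng.choice(data)
            # Variance-reduced stochastic gradient.
            g = grad_i(w, x, y) - grad_i(w_snap, x, y) + mu
            w -= lr * g
    return w
```

On noise-free data such as `[(x, 3.0 * x) for x in (0.5, 1.0, 1.5, 2.0)]`, `svrg(data)` should return a value very close to 3.0. A larger `update_freq` means fewer full passes over the data, at the cost of a staler `mu`.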

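The `compression_params` entry above refers to 2bit compression with a threshold. The following is a minimal sketch of threshold-based 2-bit quantization with error feedback, as MXNet's gradient compression guide describes it. Values at or beyond plus or minus the threshold are sent as plus or minus the threshold, and everything else is sent as 0. The quantization error is then carried into the next step. Assumptions: `two_bit_compress` is a hypothetical helper, the `>=`/`<=` boundary handling is a guess, and the real compression runs inside the KVStore and packs values into 2-bit codes, which this sketch does not do.

```python
def two_bit_compress(grad, residual, threshold=0.5):
    """Quantize each value of grad + residual to -threshold, 0.0 or +threshold.

    Returns (quantized, new_residual), where new_residual holds the
    quantization error to be added back on the next call (error feedback).
    """
    quantized = []
    new_residual = []
    for g, r in zip(grad, residual):
        v = g + r  # accumulate the error left over from the previous step
        if v >= threshold:
            q = threshold
        elif v <= -threshold:
            q = -threshold
        else:
            q = 0.0
        quantized.append(q)
        new_residual.append(v - q)
    return quantized, new_residual
```

With `threshold=0.5`, a gradient of `[0.7, -0.2, -0.9, 0.3]` is first sent as `[0.5, 0.0, -0.5, 0.0]`. On the next step, the leftover `0.3` in the last slot pushes that value over the threshold.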
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 


With regards,
Apache Git Services
