piiswrong commented on a change in pull request #7654: Making mixed precision work with all optimizers
URL: https://github.com/apache/incubator-mxnet/pull/7654#discussion_r140909368
 
 

 ##########
 File path: python/mxnet/optimizer.py
 ##########
 @@ -173,6 +182,36 @@ def create_state(self, index, weight):
             The state associated with the weight.
         """
 
+    def create_mp_state(self, index, weight):
+        """Creates auxiliary state for a given weight, including FP32 master
+        copy if necessary.
+
+        This method is provided to perform automatic mixed precision training
+        for optimizers that do not support it themselves.
+
+        Parameters
+        ----------
+        index : int
+            A unique index to identify the weight.
+        weight : NDArray
+            The weight.
+
+        Returns
+        -------
+        state : any obj
+            The state associated with the weight.
+        """
+        weight_master_copy = None
+        if self.multi_precision and weight.dtype == numpy.float16:
+            weight_master_copy = array(weight, ctx=weight.context, dtype=numpy.float32)
 
 Review comment:
   weight.astype(numpy.float32)
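For reference, here is a minimal sketch of the master-copy step written with the suggested astype call. It assumes NumPy arrays as a stand-in for MXNet NDArray weights. The helpers create_mp_state_sketch and sgd_step_mp_sketch are hypothetical and not part of the optimizer API, and the optimizer's own per-weight state is left out. The second half shows why an FP32 master copy is kept. With a learning rate of 1e-4, SGD updates applied directly to an FP16 weight of 1.0 are rounded away, while the same updates applied to the FP32 copy accumulate.

```python
import numpy as np


def create_mp_state_sketch(weight, multi_precision=True):
    """Hypothetical NumPy stand-in for the master-copy part of create_mp_state.

    Returns an FP32 master copy for an FP16 weight, else None. The
    optimizer's own per-weight state is omitted.
    """
    weight_master_copy = None
    if multi_precision and weight.dtype == np.float16:
        # astype() allocates a new FP32 array; the FP16 weight is untouched.
        weight_master_copy = weight.astype(np.float32)
    return weight_master_copy


def sgd_step_mp_sketch(weight, master, grad, lr):
    """Hypothetical plain-SGD step: update in FP32, then cast back to FP16."""
    master -= lr * grad.astype(np.float32)
    weight[:] = master.astype(np.float16)


grad = np.ones(4, dtype=np.float16)
lr = 1e-4

# FP16 weight updated directly: each 1e-4 step is smaller than half the
# FP16 spacing just below 1.0 (2**-11), so every update rounds back to 1.0.
w_direct = np.ones(4, dtype=np.float16)
for _ in range(10):
    w_direct -= lr * grad

# FP16 weight backed by an FP32 master copy: the small steps accumulate.
w_mp = np.ones(4, dtype=np.float16)
master = create_mp_state_sketch(w_mp)
for _ in range(10):
    sgd_step_mp_sketch(w_mp, master, grad, lr)

print(w_direct[0], w_mp[0])
```

Because astype returns a new array, the master copy does not alias the FP16 weight. The final print shows the directly updated weight still at 1.0, while the weight backed by the master copy has moved slightly below it.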
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to