gigasquid closed pull request #13499: [MXNET-13453] [Clojure] - Add Spec Validations to the Optimizer namespace
URL: https://github.com/apache/incubator-mxnet/pull/13499
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

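The change applies one pattern to every optimizer constructor: register a clojure.spec.alpha spec per option key, compose them with s/keys :opt-un, and validate the options map before instantiating the underlying Java optimizer. A minimal, self-contained sketch of that pattern follows; the validate! helper here is only a hypothetical stand-in for org.apache.clojure-mxnet.util/validate!, whose implementation is not part of this diff:

(ns example.optimizer-validation
  (:require [clojure.spec.alpha :as s]))

;; Specs for individual option keys, mirroring the diff below.
(s/def ::learning-rate float?)
(s/def ::momentum float?)
(s/def ::sgd-opts (s/keys :opt-un [::learning-rate ::momentum]))

;; Hypothetical stand-in for org.apache.clojure-mxnet.util/validate!:
;; throw with a readable explanation when the value fails the spec.
(defn validate! [spec value error-msg]
  (when-not (s/valid? spec value)
    (throw (ex-info (str error-msg "\n" (s/explain-str spec value))
                    (s/explain-data spec value)))))

(validate! ::sgd-opts {:learning-rate 0.01 :momentum 0.9} "Incorrect sgd optimizer options")
;; => nil
;; (validate! ::sgd-opts {:momentum 'a} "Incorrect sgd optimizer options")
;; => would throw clojure.lang.ExceptionInfo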

diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
index f18ff40f569..f77f5532bfb 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
@@ -17,7 +17,19 @@
 
 (ns org.apache.clojure-mxnet.optimizer
   (:refer-clojure :exclude [update])
-  (:import (org.apache.mxnet.optimizer SGD DCASGD NAG AdaDelta RMSProp AdaGrad Adam SGLD)))
+  (:require
+   [clojure.spec.alpha :as s]
+   [org.apache.clojure-mxnet.util :as util])
+  (:import
+   (org.apache.mxnet.optimizer SGD DCASGD NAG AdaDelta RMSProp AdaGrad Adam SGLD)
+   (org.apache.mxnet FactorScheduler)))
+
+(s/def ::learning-rate float?)
+(s/def ::momentum float?)
+(s/def ::wd float?)
+(s/def ::clip-gradient float?)
+(s/def ::lr-scheduler #(instance? FactorScheduler %))
+(s/def ::sgd-opts (s/keys :opt-un [::learning-rate ::momentum ::wd ::clip-gradient ::lr-scheduler]))
 
 (defn sgd
   "A very simple SGD optimizer with momentum and weight regularization."
@@ -26,10 +38,14 @@
           momentum 0.0
           wd 0.0001
           clip-gradient 0}}]
+   (util/validate! ::sgd-opts opts "Incorrect sgd optimizer options")
    (new SGD (float learning-rate) (float momentum) (float wd) (float clip-gradient) lr-scheduler))
   ([]
    (sgd {})))
 
+(s/def ::lambda float?)
+(s/def ::dcasgd-opts (s/keys :opt-un [::learning-rate ::momentum ::lambda ::wd ::clip-gradient ::lr-scheduler]))
+
 (defn dcasgd
   "DCASGD optimizer with momentum and weight regularization.
   Implementation of paper 'Asynchronous Stochastic Gradient Descent with
@@ -40,10 +56,13 @@
           lambda 0.04
           wd 0.0
           clip-gradient 0}}]
+   (util/validate! ::dcasgd-opts opts "Incorrect dcasgd optimizer options")
    (new DCASGD (float learning-rate) (float lambda) (float momentum) (float wd) (float clip-gradient) lr-scheduler))
   ([]
    (dcasgd {})))
 
+(s/def ::nag-opts (s/keys :opt-un [::learning-rate ::momentum ::wd ::clip-gradient ::lr-scheduler]))
+
 (defn nag
   "SGD with nesterov.
    It is implemented according to
@@ -53,10 +72,16 @@
           momentum 0.0
           wd 0.0001
           clip-gradient 0}}]
+   (util/validate! ::nag-opts opts "Incorrect nag optimizer options")
    (new NAG (float learning-rate) (float momentum) (float wd) (float clip-gradient) lr-scheduler))
   ([]
    (nag {})))
 
+(s/def ::rho float?)
+(s/def ::rescale-gradient float?)
+(s/def ::epsilon float?)
+(s/def ::ada-delta-opts (s/keys :opt-un [::rho ::rescale-gradient ::epsilon ::wd ::clip-gradient]))
+
 (defn ada-delta
   "AdaDelta optimizer as described in Matthew D. Zeiler, 2012.
    http://arxiv.org/abs/1212.5701"
@@ -66,10 +91,15 @@
           epsilon 1e-8
           wd 0.0
           clip-gradient 0}}]
+   (util/validate! ::ada-delta-opts opts "Incorrect ada-delta optimizer options")
    (new AdaDelta (float rho) (float rescale-gradient) (float epsilon) (float wd) (float clip-gradient)))
   ([]
    (ada-delta {})))
 
+(s/def ::gamma1 float?)
+(s/def ::gamma2 float?)
+(s/def ::rms-prop-opts (s/keys :opt-un [::learning-rate ::rescale-gradient ::gamma1 ::gamma2 ::wd ::clip-gradient]))
+
 (defn rms-prop
   "RMSProp optimizer as described in Tieleman & Hinton, 2012.
    http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45) by Alex Graves, 2013.
@@ -80,18 +110,21 @@
    -  wd L2 regularization coefficient add to all the weights
    -  clip-gradient clip gradient in range [-clip_gradient, clip_gradient]
    -  lr-scheduler The learning rate scheduler"
-  ([{:keys [learning-rate rescale-gradient gamma1 gamma2 wd lr-scheduler clip-gradient]
+  ([{:keys [learning-rate rescale-gradient gamma1 gamma2 wd lr-scheduler clip-gradient] :as opts
      :or {learning-rate 0.002
           rescale-gradient 1.0
           gamma1 0.95
           gamma2 0.9
           wd 0.0
           clip-gradient 0}}]
+   (util/validate! ::rms-prop-opts opts "Incorrect rms-prop optimizer options")
    (new RMSProp (float learning-rate) (float rescale-gradient) (float gamma1)
         (float gamma2) (float wd) lr-scheduler (float clip-gradient)))
   ([]
    (rms-prop {})))
 
+(s/def ::ada-grad-opts (s/keys :opt-un [::learning-rate ::rescale-gradient ::epsilon ::wd]))
+
 (defn ada-grad
   " AdaGrad optimizer as described in Duchi, Hazan and Singer, 2011.
    http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf
@@ -101,15 +134,20 @@
                 Default value is set to 1e-7.
    - rescale-gradient rescaling factor of gradient.
    - wd L2 regularization coefficient add to all the weights"
-  ([{:keys [learning-rate rescale-gradient epsilon wd]
+  ([{:keys [learning-rate rescale-gradient epsilon wd] :as opts
      :or {learning-rate 0.05
           rescale-gradient 1.0
           epsilon 1e-7
           wd 0.0}}]
+   (util/validate! ::ada-grad-opts opts "Incorrect ada-grad optimizer options")
    (new AdaGrad (float learning-rate) (float rescale-gradient) (float epsilon) (float wd)))
   ([]
    (ada-grad {})))
 
+(s/def ::beta1 float?)
+(s/def ::beta2 float?)
+(s/def ::adam-opts (s/keys :opt-un [::learning-rate ::beta1 ::beta2 ::epsilon ::decay-factor ::wd ::clip-gradient ::lr-scheduler]))
+
 (defn adam
   "Adam optimizer as described in [King2014]
 
@@ -125,7 +163,7 @@
    - wd L2 regularization coefficient add to all the weights
    - clip-gradient  clip gradient in range [-clip_gradient, clip_gradient]
    - lr-scheduler The learning rate scheduler"
-  ([{:keys [learning-rate beta1 beta2 epsilon decay-factor wd clip-gradient lr-scheduler]
+  ([{:keys [learning-rate beta1 beta2 epsilon decay-factor wd clip-gradient lr-scheduler] :as opts
      :or {learning-rate 0.002
           beta1 0.9
           beta2 0.999
@@ -133,11 +171,14 @@
           decay-factor (- 1 1e-8)
           wd 0
           clip-gradient 0}}]
+   (util/validate! ::adam-opts opts "Incorrect adam optimizer options")
    (new Adam (float learning-rate) (float beta1) (float beta2) (float epsilon)
         (float decay-factor) (float wd) (float clip-gradient) lr-scheduler))
   ([]
    (adam {})))
 
+(s/def ::sgld-opts (s/keys :opt-un [::learning-rate ::rescale-gradient ::wd ::clip-gradient ::lr-scheduler]))
+
 (defn sgld
   "Stochastic Langevin Dynamics Updater to sample from a distribution.
 
@@ -146,11 +187,12 @@
   - wd L2 regularization coefficient add to all the weights
   - clip-gradient Float, clip gradient in range [-clip_gradient, clip_gradient]
   - lr-scheduler The learning rate scheduler"
-  ([{:keys [learning-rate rescale-gradient wd clip-gradient lr-scheduler]
+  ([{:keys [learning-rate rescale-gradient wd clip-gradient lr-scheduler] :as opts
      :or {learning-rate 0.01
           rescale-gradient 1
           wd 0.0001
           clip-gradient 0}}]
+   (util/validate! ::sgld-opts opts "Incorrect sgld optimizer options")
    (new SGLD (float learning-rate) (float rescale-gradient) (float wd)
         (float clip-gradient) lr-scheduler))
   ([]
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/optimizer_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/optimizer_test.clj
index f6461b10f02..599a0672bea 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/optimizer_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/optimizer_test.clj
@@ -44,3 +44,13 @@
               ["sgld" optimizer/sgld]]]
     (doseq [opt opts]
       (test-optimizer opt))))
+
+(deftest test-optimizers-parameters-specs
+  (is (thrown? Exception (optimizer/sgd {:wd 'a})))
+  (is (thrown? Exception (optimizer/dcasgd {:lambda 'a})))
+  (is (thrown? Exception (optimizer/nag {:momentum 'a})))
+  (is (thrown? Exception (optimizer/ada-delta {:epsilon 'a})))
+  (is (thrown? Exception (optimizer/rms-prop {:gamma1 'a})))
+  (is (thrown? Exception (optimizer/ada-grad {:rescale-gradient 'a})))
+  (is (thrown? Exception (optimizer/adam {:beta1 'a})))
+  (is (thrown? Exception (optimizer/sgld {:lr-scheduler 0.1}))))
\ No newline at end of file
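In practice the validation surfaces as an exception when an optimizer is constructed with a malformed options map, which is exactly what the new test above asserts. A hedged REPL sketch of the behaviour implied by this diff, assuming the contrib/clojure-package (and the MXNet native library) is on the classpath:

(require '[org.apache.clojure-mxnet.optimizer :as optimizer])

;; Well-formed options construct the Java optimizer as before:
(optimizer/sgd {:learning-rate 0.01 :momentum 0.9})

;; Options that fail the new spec validation throw before the
;; SGD constructor is ever reached, as the test above relies on:
(optimizer/sgd {:wd 'a})
;; => throws (the diff passes "Incorrect sgd optimizer options" to util/validate!)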


 
