This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8feb826   #13441 [Clojure] Add Spec Validations for the Random namespace (#13523)
8feb826 is described below

commit 8feb826e0b707531f8596e209e2d37598fd7a4d7
Author: Nicolas Modrzyk <[email protected]>
AuthorDate: Fri Dec 7 02:06:16 2018 +0900

     #13441 [Clojure] Add Spec Validations for the Random namespace (#13523)
---
 .../neural-style/src/neural_style/core.clj         |  2 +-
 .../src/org/apache/clojure_mxnet/optimizer.clj     | 26 +++++++++----------
 .../src/org/apache/clojure_mxnet/random.clj        | 30 +++++++++++++++++++---
 .../org/apache/clojure_mxnet/operator_test.clj     |  2 +-
 .../test/org/apache/clojure_mxnet/random_test.clj  | 17 ++++++++++--
 5 files changed, 57 insertions(+), 20 deletions(-)

diff --git a/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj b/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj
index 50f95c9..fcf402f 100644
--- a/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj
+++ b/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj
@@ -193,7 +193,7 @@
         ;;;train
 
         ;;initialize with random noise
-        img (ndarray/- (random/uniform 0 255 content-np-shape dev) 128)
+        img (ndarray/- (random/uniform 0 255 content-np-shape {:ctx dev}) 128)
         ;;; img (random/uniform -0.1 0.1 content-np-shape dev)
         ;; img content-np
         lr-sched (lr-scheduler/factor-scheduler 10 0.9)
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
index f77f553..672090a 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
@@ -24,11 +24,11 @@
    (org.apache.mxnet.optimizer SGD DCASGD NAG AdaDelta RMSProp AdaGrad Adam 
SGLD)
    (org.apache.mxnet FactorScheduler)))
 
-(s/def ::learning-rate float?)
-(s/def ::momentum float?)
-(s/def ::wd float?)
-(s/def ::clip-gradient float?)
-(s/def ::lr-scheduler #(instance? FactorScheduler))
+(s/def ::learning-rate number?)
+(s/def ::momentum number?)
+(s/def ::wd number?)
+(s/def ::clip-gradient number?)
+(s/def ::lr-scheduler #(instance? FactorScheduler %))
 (s/def ::sgd-opts (s/keys :opt-un [::learning-rate ::momentum ::wd 
::clip-gradient ::lr-scheduler]))
 
 (defn sgd
@@ -43,7 +43,7 @@
   ([]
    (sgd {})))
 
-(s/def ::lambda float?)
+(s/def ::lambda number?)
 (s/def ::dcasgd-opts (s/keys :opt-un [::learning-rate ::momentum ::lambda ::wd 
::clip-gradient ::lr-scheduler]))
 
 (defn dcasgd
@@ -77,9 +77,9 @@
   ([]
    (nag {})))
 
-(s/def ::rho float?)
-(s/def ::rescale-gradient float?)
-(s/def ::epsilon float?)
+(s/def ::rho number?)
+(s/def ::rescale-gradient number?)
+(s/def ::epsilon number?)
 (s/def ::ada-delta-opts (s/keys :opt-un [::rho ::rescale-gradient ::epsilon 
::wd ::clip-gradient]))
 
 (defn ada-delta
@@ -96,8 +96,8 @@
   ([]
    (ada-delta {})))
 
-(s/def gamma1 float?)
-(s/def gamma2 float?)
+(s/def gamma1 number?)
+(s/def gamma2 number?)
 (s/def ::rms-prop-opts (s/keys :opt-un [::learning-rate ::rescale-gradient 
::gamma1 ::gamma2 ::wd ::clip-gradient]))
 
 (defn rms-prop
@@ -144,8 +144,8 @@
   ([]
    (ada-grad {})))
 
-(s/def ::beta1 float?)
-(s/def ::beta2 float?)
+(s/def ::beta1 number?)
+(s/def ::beta2 number?)
 (s/def ::adam-opts (s/keys :opt-un [::learning-rate ::beta1 ::beta2 ::epsilon 
::decay-factor ::wd ::clip-gradient ::lr-scheduler]))
 
 (defn adam
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj
index d6e3378..0ec2039 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj
@@ -16,8 +16,18 @@
 ;;
 
 (ns org.apache.clojure-mxnet.random
-  (:require [org.apache.clojure-mxnet.shape :as mx-shape])
-  (:import (org.apache.mxnet Random)))
+  (:require
+   [org.apache.clojure-mxnet.shape :as mx-shape]
+   [org.apache.clojure-mxnet.context :as context]
+   [clojure.spec.alpha :as s]
+   [org.apache.clojure-mxnet.util :as util])
+  (:import (org.apache.mxnet Context Random)))
+
+(s/def ::low number?)
+(s/def ::high number?)
+(s/def ::shape-vec (s/coll-of pos-int? :kind vector?))
+(s/def ::ctx #(instance? Context %))
+(s/def ::uniform-opts (s/keys :opt-un [::ctx]))
 
 (defn uniform
   "Generate uniform distribution in [low, high) with shape.
@@ -29,10 +39,18 @@
       out: Output place holder}
     returns: The result ndarray with generated result./"
   ([low high shape-vec {:keys [ctx out] :as opts}]
+   (util/validate! ::uniform-opts opts "Incorrect random uniform parameters")
+   (util/validate! ::low low  "Incorrect random uniform parameter")
+   (util/validate! ::high high  "Incorrect random uniform parameters")
+   (util/validate! ::shape-vec shape-vec  "Incorrect random uniform 
parameters")
    (Random/uniform (float low) (float high) (mx-shape/->shape shape-vec) ctx 
out))
   ([low high shape-vec]
    (uniform low high shape-vec {})))
 
+(s/def ::loc number?)
+(s/def ::scale number?)
+(s/def ::normal-opts (s/keys :opt-un [::ctx]))
+
 (defn normal
   "Generate normal(Gaussian) distribution N(mean, stdvar^^2) with shape.
     loc: The standard deviation of the normal distribution
@@ -43,10 +61,15 @@
       out: Output place holder}
     returns: The result ndarray with generated result./"
   ([loc scale shape-vec {:keys [ctx out] :as opts}]
+   (util/validate! ::normal-opts opts  "Incorrect random normal parameters")
+   (util/validate! ::loc loc  "Incorrect random normal parameters")
+   (util/validate! ::scale scale  "Incorrect random normal parameters")
+   (util/validate! ::shape-vec shape-vec  "Incorrect random uniform 
parameters")
    (Random/normal (float loc) (float scale) (mx-shape/->shape shape-vec) ctx 
out))
   ([loc scale shape-vec]
    (normal loc scale shape-vec {})))
 
+(s/def ::seed-state number?)
 (defn seed
   " Seed the random number generators in mxnet.
     This seed will affect behavior of functions in this module,
@@ -58,4 +81,5 @@
          This means if you set the same seed, the random number sequence
          generated from GPU0 can be different from CPU."
   [seed-state]
-  (Random/seed (int seed-state)))
+  (util/validate! ::seed-state seed-state  "Incorrect seed parameters")
+  (Random/seed (int seed-state)))
\ No newline at end of file
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
index 1b4b2ea..c97711b 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
@@ -462,7 +462,7 @@
         test (sym/transpose data)
         shape-vec [3 4]
         ctx (context/default-context)
-        arr-data (random/uniform 0 100 shape-vec ctx)
+        arr-data (random/uniform 0 100 shape-vec {:ctx ctx})
         trans (ndarray/transpose (ndarray/copy arr-data))
         exec-test (sym/bind test ctx {"data" arr-data})
         out     (->  (executor/forward exec-test)
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj
index c4e9198..6952335 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj
@@ -26,9 +26,9 @@
     (let [[a b] [-10 10]
           shape [100 100]
           _ (random/seed 128)
-          un1 (random/uniform a b shape {:context ctx})
+          un1 (random/uniform a b shape {:ctx ctx})
           _ (random/seed 128)
-          un2 (random/uniform a b shape {:context ctx})]
+          un2 (random/uniform a b shape {:ctx ctx})]
       (is (= un1 un2))
       (is (<  (Math/abs
                (/ (/ (apply + (ndarray/->vec un1))
@@ -52,3 +52,16 @@
       (is (<  (Math/abs (- mean mu)) 0.1))
       (is (< (Math/abs (- stddev sigma)) 0.1)))))
 
+(defn random-or-normal [fn_]
+  (is (thrown? Exception (fn_ 'a 2 [])))
+  (is (thrown? Exception (fn_ 1 'b [])))
+  (is (thrown? Exception (fn_ 1 2 [-1])))
+  (is (thrown? Exception (fn_ 1 2 [2 3 0])))
+  (is (thrown? Exception (fn_ 1 2 [10 10] {:ctx "a"})))
+  (let [ctx (context/default-context)]
+    (is (not (nil? (fn_ 1 1 [100 100] {:ctx ctx}))))))
+
+(deftest test-random-parameters-specs
+  (random-or-normal random/normal)
+  (random-or-normal random/uniform)
+  (is (thrown? Exception (random/seed "a"))))
\ No newline at end of file

Reply via email to