This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8feb826 #13441 [Clojure] Add Spec Validations for the Random namespace (#13523)
8feb826 is described below
commit 8feb826e0b707531f8596e209e2d37598fd7a4d7
Author: Nicolas Modrzyk <[email protected]>
AuthorDate: Fri Dec 7 02:06:16 2018 +0900
#13441 [Clojure] Add Spec Validations for the Random namespace (#13523)
---
.../neural-style/src/neural_style/core.clj | 2 +-
.../src/org/apache/clojure_mxnet/optimizer.clj | 26 +++++++++----------
.../src/org/apache/clojure_mxnet/random.clj | 30 +++++++++++++++++++---
.../org/apache/clojure_mxnet/operator_test.clj | 2 +-
.../test/org/apache/clojure_mxnet/random_test.clj | 17 ++++++++++--
5 files changed, 57 insertions(+), 20 deletions(-)
diff --git a/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj b/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj
index 50f95c9..fcf402f 100644
--- a/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj
+++ b/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj
@@ -193,7 +193,7 @@
;;;train
;;initialize with random noise
- img (ndarray/- (random/uniform 0 255 content-np-shape dev) 128)
+ img (ndarray/- (random/uniform 0 255 content-np-shape {:ctx dev}) 128)
;;; img (random/uniform -0.1 0.1 content-np-shape dev)
;; img content-np
lr-sched (lr-scheduler/factor-scheduler 10 0.9)
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
index f77f553..672090a 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
@@ -24,11 +24,11 @@
(org.apache.mxnet.optimizer SGD DCASGD NAG AdaDelta RMSProp AdaGrad Adam
SGLD)
(org.apache.mxnet FactorScheduler)))
-(s/def ::learning-rate float?)
-(s/def ::momentum float?)
-(s/def ::wd float?)
-(s/def ::clip-gradient float?)
-(s/def ::lr-scheduler #(instance? FactorScheduler))
+(s/def ::learning-rate number?)
+(s/def ::momentum number?)
+(s/def ::wd number?)
+(s/def ::clip-gradient number?)
+(s/def ::lr-scheduler #(instance? FactorScheduler %))
(s/def ::sgd-opts (s/keys :opt-un [::learning-rate ::momentum ::wd
::clip-gradient ::lr-scheduler]))
(defn sgd
@@ -43,7 +43,7 @@
([]
(sgd {})))
-(s/def ::lambda float?)
+(s/def ::lambda number?)
(s/def ::dcasgd-opts (s/keys :opt-un [::learning-rate ::momentum ::lambda ::wd
::clip-gradient ::lr-scheduler]))
(defn dcasgd
@@ -77,9 +77,9 @@
([]
(nag {})))
-(s/def ::rho float?)
-(s/def ::rescale-gradient float?)
-(s/def ::epsilon float?)
+(s/def ::rho number?)
+(s/def ::rescale-gradient number?)
+(s/def ::epsilon number?)
(s/def ::ada-delta-opts (s/keys :opt-un [::rho ::rescale-gradient ::epsilon
::wd ::clip-gradient]))
(defn ada-delta
@@ -96,8 +96,8 @@
([]
(ada-delta {})))
-(s/def gamma1 float?)
-(s/def gamma2 float?)
+(s/def gamma1 number?)
+(s/def gamma2 number?)
(s/def ::rms-prop-opts (s/keys :opt-un [::learning-rate ::rescale-gradient
::gamma1 ::gamma2 ::wd ::clip-gradient]))
(defn rms-prop
@@ -144,8 +144,8 @@
([]
(ada-grad {})))
-(s/def ::beta1 float?)
-(s/def ::beta2 float?)
+(s/def ::beta1 number?)
+(s/def ::beta2 number?)
(s/def ::adam-opts (s/keys :opt-un [::learning-rate ::beta1 ::beta2 ::epsilon
::decay-factor ::wd ::clip-gradient ::lr-scheduler]))
(defn adam
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj
index d6e3378..0ec2039 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj
@@ -16,8 +16,18 @@
;;
(ns org.apache.clojure-mxnet.random
- (:require [org.apache.clojure-mxnet.shape :as mx-shape])
- (:import (org.apache.mxnet Random)))
+ (:require
+ [org.apache.clojure-mxnet.shape :as mx-shape]
+ [org.apache.clojure-mxnet.context :as context]
+ [clojure.spec.alpha :as s]
+ [org.apache.clojure-mxnet.util :as util])
+ (:import (org.apache.mxnet Context Random)))
+
+(s/def ::low number?)
+(s/def ::high number?)
+(s/def ::shape-vec (s/coll-of pos-int? :kind vector?))
+(s/def ::ctx #(instance? Context %))
+(s/def ::uniform-opts (s/keys :opt-un [::ctx]))
(defn uniform
"Generate uniform distribution in [low, high) with shape.
@@ -29,10 +39,18 @@
out: Output place holder}
returns: The result ndarray with generated result./"
([low high shape-vec {:keys [ctx out] :as opts}]
+ (util/validate! ::uniform-opts opts "Incorrect random uniform parameters")
+ (util/validate! ::low low "Incorrect random uniform parameter")
+ (util/validate! ::high high "Incorrect random uniform parameters")
+ (util/validate! ::shape-vec shape-vec "Incorrect random uniform parameters")
(Random/uniform (float low) (float high) (mx-shape/->shape shape-vec) ctx out))
([low high shape-vec]
(uniform low high shape-vec {})))
+(s/def ::loc number?)
+(s/def ::scale number?)
+(s/def ::normal-opts (s/keys :opt-un [::ctx]))
+
(defn normal
"Generate normal(Gaussian) distribution N(mean, stdvar^^2) with shape.
loc: The standard deviation of the normal distribution
@@ -43,10 +61,15 @@
out: Output place holder}
returns: The result ndarray with generated result./"
([loc scale shape-vec {:keys [ctx out] :as opts}]
+ (util/validate! ::normal-opts opts "Incorrect random normal parameters")
+ (util/validate! ::loc loc "Incorrect random normal parameters")
+ (util/validate! ::scale scale "Incorrect random normal parameters")
+ (util/validate! ::shape-vec shape-vec "Incorrect random uniform parameters")
(Random/normal (float loc) (float scale) (mx-shape/->shape shape-vec) ctx out))
([loc scale shape-vec]
(normal loc scale shape-vec {})))
+(s/def ::seed-state number?)
(defn seed
" Seed the random number generators in mxnet.
This seed will affect behavior of functions in this module,
@@ -58,4 +81,5 @@
This means if you set the same seed, the random number sequence
generated from GPU0 can be different from CPU."
[seed-state]
- (Random/seed (int seed-state)))
+ (util/validate! ::seed-state seed-state "Incorrect seed parameters")
+ (Random/seed (int seed-state)))
\ No newline at end of file
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
index 1b4b2ea..c97711b 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
@@ -462,7 +462,7 @@
test (sym/transpose data)
shape-vec [3 4]
ctx (context/default-context)
- arr-data (random/uniform 0 100 shape-vec ctx)
+ arr-data (random/uniform 0 100 shape-vec {:ctx ctx})
trans (ndarray/transpose (ndarray/copy arr-data))
exec-test (sym/bind test ctx {"data" arr-data})
out (-> (executor/forward exec-test)
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj
index c4e9198..6952335 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj
@@ -26,9 +26,9 @@
(let [[a b] [-10 10]
shape [100 100]
_ (random/seed 128)
- un1 (random/uniform a b shape {:context ctx})
+ un1 (random/uniform a b shape {:ctx ctx})
_ (random/seed 128)
- un2 (random/uniform a b shape {:context ctx})]
+ un2 (random/uniform a b shape {:ctx ctx})]
(is (= un1 un2))
(is (< (Math/abs
(/ (/ (apply + (ndarray/->vec un1))
@@ -52,3 +52,16 @@
(is (< (Math/abs (- mean mu)) 0.1))
(is (< (Math/abs (- stddev sigma)) 0.1)))))
+(defn random-or-normal [fn_]
+ (is (thrown? Exception (fn_ 'a 2 [])))
+ (is (thrown? Exception (fn_ 1 'b [])))
+ (is (thrown? Exception (fn_ 1 2 [-1])))
+ (is (thrown? Exception (fn_ 1 2 [2 3 0])))
+ (is (thrown? Exception (fn_ 1 2 [10 10] {:ctx "a"})))
+ (let [ctx (context/default-context)]
+ (is (not (nil? (fn_ 1 1 [100 100] {:ctx ctx}))))))
+
+(deftest test-random-parameters-specs
+ (random-or-normal random/normal)
+ (random-or-normal random/uniform)
+ (is (thrown? Exception (random/seed "a"))))
\ No newline at end of file