gigasquid closed pull request #13523: #13441 [Clojure] Add Spec Validations for
the Random namespace
URL: https://github.com/apache/incubator-mxnet/pull/13523
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git
a/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj
b/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj
index 50f95c9750e..fcf402f3466 100644
--- a/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj
+++ b/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj
@@ -193,7 +193,7 @@
;;;train
;;initialize with random noise
- img (ndarray/- (random/uniform 0 255 content-np-shape dev) 128)
+ img (ndarray/- (random/uniform 0 255 content-np-shape {:ctx dev}) 128)
;;; img (random/uniform -0.1 0.1 content-np-shape dev)
;; img content-np
lr-sched (lr-scheduler/factor-scheduler 10 0.9)
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
b/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
index f77f5532bfb..672090a899b 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
@@ -24,11 +24,11 @@
(org.apache.mxnet.optimizer SGD DCASGD NAG AdaDelta RMSProp AdaGrad Adam
SGLD)
(org.apache.mxnet FactorScheduler)))
-(s/def ::learning-rate float?)
-(s/def ::momentum float?)
-(s/def ::wd float?)
-(s/def ::clip-gradient float?)
-(s/def ::lr-scheduler #(instance? FactorScheduler))
+(s/def ::learning-rate number?)
+(s/def ::momentum number?)
+(s/def ::wd number?)
+(s/def ::clip-gradient number?)
+(s/def ::lr-scheduler #(instance? FactorScheduler %))
(s/def ::sgd-opts (s/keys :opt-un [::learning-rate ::momentum ::wd
::clip-gradient ::lr-scheduler]))
(defn sgd
@@ -43,7 +43,7 @@
([]
(sgd {})))
-(s/def ::lambda float?)
+(s/def ::lambda number?)
(s/def ::dcasgd-opts (s/keys :opt-un [::learning-rate ::momentum ::lambda ::wd
::clip-gradient ::lr-scheduler]))
(defn dcasgd
@@ -77,9 +77,9 @@
([]
(nag {})))
-(s/def ::rho float?)
-(s/def ::rescale-gradient float?)
-(s/def ::epsilon float?)
+(s/def ::rho number?)
+(s/def ::rescale-gradient number?)
+(s/def ::epsilon number?)
(s/def ::ada-delta-opts (s/keys :opt-un [::rho ::rescale-gradient ::epsilon
::wd ::clip-gradient]))
(defn ada-delta
@@ -96,8 +96,8 @@
([]
(ada-delta {})))
-(s/def gamma1 float?)
-(s/def gamma2 float?)
+(s/def gamma1 number?)
+(s/def gamma2 number?)
(s/def ::rms-prop-opts (s/keys :opt-un [::learning-rate ::rescale-gradient
::gamma1 ::gamma2 ::wd ::clip-gradient]))
(defn rms-prop
@@ -144,8 +144,8 @@
([]
(ada-grad {})))
-(s/def ::beta1 float?)
-(s/def ::beta2 float?)
+(s/def ::beta1 number?)
+(s/def ::beta2 number?)
(s/def ::adam-opts (s/keys :opt-un [::learning-rate ::beta1 ::beta2 ::epsilon
::decay-factor ::wd ::clip-gradient ::lr-scheduler]))
(defn adam
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj
b/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj
index d6e33789a62..0ec2039ba79 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj
@@ -16,8 +16,18 @@
;;
(ns org.apache.clojure-mxnet.random
- (:require [org.apache.clojure-mxnet.shape :as mx-shape])
- (:import (org.apache.mxnet Random)))
+ (:require
+ [org.apache.clojure-mxnet.shape :as mx-shape]
+ [org.apache.clojure-mxnet.context :as context]
+ [clojure.spec.alpha :as s]
+ [org.apache.clojure-mxnet.util :as util])
+ (:import (org.apache.mxnet Context Random)))
+
+(s/def ::low number?)
+(s/def ::high number?)
+(s/def ::shape-vec (s/coll-of pos-int? :kind vector?))
+(s/def ::ctx #(instance? Context %))
+(s/def ::uniform-opts (s/keys :opt-un [::ctx]))
(defn uniform
"Generate uniform distribution in [low, high) with shape.
@@ -29,10 +39,18 @@
out: Output place holder}
returns: The result ndarray with generated result./"
([low high shape-vec {:keys [ctx out] :as opts}]
+ (util/validate! ::uniform-opts opts "Incorrect random uniform parameters")
+ (util/validate! ::low low "Incorrect random uniform parameter")
+ (util/validate! ::high high "Incorrect random uniform parameters")
+ (util/validate! ::shape-vec shape-vec "Incorrect random uniform
parameters")
(Random/uniform (float low) (float high) (mx-shape/->shape shape-vec) ctx
out))
([low high shape-vec]
(uniform low high shape-vec {})))
+(s/def ::loc number?)
+(s/def ::scale number?)
+(s/def ::normal-opts (s/keys :opt-un [::ctx]))
+
(defn normal
"Generate normal(Gaussian) distribution N(mean, stdvar^^2) with shape.
loc: The standard deviation of the normal distribution
@@ -43,10 +61,15 @@
out: Output place holder}
returns: The result ndarray with generated result./"
([loc scale shape-vec {:keys [ctx out] :as opts}]
+ (util/validate! ::normal-opts opts "Incorrect random normal parameters")
+ (util/validate! ::loc loc "Incorrect random normal parameters")
+ (util/validate! ::scale scale "Incorrect random normal parameters")
+ (util/validate! ::shape-vec shape-vec "Incorrect random uniform
parameters")
(Random/normal (float loc) (float scale) (mx-shape/->shape shape-vec) ctx
out))
([loc scale shape-vec]
(normal loc scale shape-vec {})))
+(s/def ::seed-state number?)
(defn seed
" Seed the random number generators in mxnet.
This seed will affect behavior of functions in this module,
@@ -58,4 +81,5 @@
This means if you set the same seed, the random number sequence
generated from GPU0 can be different from CPU."
[seed-state]
- (Random/seed (int seed-state)))
+ (util/validate! ::seed-state seed-state "Incorrect seed parameters")
+ (Random/seed (int seed-state)))
\ No newline at end of file
diff --git
a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
index 1b4b2ea2fbe..c97711b5fed 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
@@ -462,7 +462,7 @@
test (sym/transpose data)
shape-vec [3 4]
ctx (context/default-context)
- arr-data (random/uniform 0 100 shape-vec ctx)
+ arr-data (random/uniform 0 100 shape-vec {:ctx ctx})
trans (ndarray/transpose (ndarray/copy arr-data))
exec-test (sym/bind test ctx {"data" arr-data})
out (-> (executor/forward exec-test)
diff --git
a/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj
b/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj
index c4e9198073a..6952335c139 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj
@@ -26,9 +26,9 @@
(let [[a b] [-10 10]
shape [100 100]
_ (random/seed 128)
- un1 (random/uniform a b shape {:context ctx})
+ un1 (random/uniform a b shape {:ctx ctx})
_ (random/seed 128)
- un2 (random/uniform a b shape {:context ctx})]
+ un2 (random/uniform a b shape {:ctx ctx})]
(is (= un1 un2))
(is (< (Math/abs
(/ (/ (apply + (ndarray/->vec un1))
@@ -52,3 +52,16 @@
(is (< (Math/abs (- mean mu)) 0.1))
(is (< (Math/abs (- stddev sigma)) 0.1)))))
+(defn random-or-normal [fn_]
+ (is (thrown? Exception (fn_ 'a 2 [])))
+ (is (thrown? Exception (fn_ 1 'b [])))
+ (is (thrown? Exception (fn_ 1 2 [-1])))
+ (is (thrown? Exception (fn_ 1 2 [2 3 0])))
+ (is (thrown? Exception (fn_ 1 2 [10 10] {:ctx "a"})))
+ (let [ctx (context/default-context)]
+ (is (not (nil? (fn_ 1 1 [100 100] {:ctx ctx}))))))
+
+(deftest test-random-parameters-specs
+ (random-or-normal random/normal)
+ (random-or-normal random/uniform)
+ (is (thrown? Exception (random/seed "a"))))
\ No newline at end of file
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services