This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new a4b9802 [Clojure] Helper function for n-dim vector to ndarray (#14305)
a4b9802 is described below
commit a4b98024b951fca91df3188a06141ea3e3411015
Author: Kedar Bellare <[email protected]>
AuthorDate: Mon Mar 11 06:39:34 2019 -0700
[Clojure] Helper function for n-dim vector to ndarray (#14305)
* [Clojure] Helper function for n-dim vector to ndarray
* More tests, specs and rename method
* Address comments
* Allow every number type
---
.../src/org/apache/clojure_mxnet/ndarray.clj | 21 +++++++++++++++++
.../src/org/apache/clojure_mxnet/util.clj | 20 ++++++++++++++++
.../test/org/apache/clojure_mxnet/ndarray_test.clj | 12 ++++++++++
.../test/org/apache/clojure_mxnet/util_test.clj | 27 ++++++++++++++++++++++
4 files changed, 80 insertions(+)
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
index 151e18b..9caa00d 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
@@ -94,6 +94,27 @@
([start stop]
(arange start stop {})))
+(defn- nd-vec-conforms-to-shape?
+  "Returns true when `nd-seq` is a sequential structure whose nesting matches
+   `nd-shape` exactly: every level holds the expected number of elements and
+   only the innermost level holds non-sequential values."
+  [nd-seq [dim & more-dims]]
+  (and (sequential? nd-seq)
+       (= dim (count nd-seq))
+       (if more-dims
+         (every? #(nd-vec-conforms-to-shape? % more-dims) nd-seq)
+         (not-any? sequential? nd-seq))))
+
+(defn ->ndarray
+  "Creates a new NDArray based on the given n-dimensional vector
+   of numbers.
+    `nd-vec`: n-dimensional vector with numbers. Nested vectors must be
+      rectangular, i.e. have consistent lengths at every depth.
+    `opts-map` {
+       `ctx`: Context of the output ndarray, will use default context if unspecified.
+    }
+   returns: `ndarray` with the given values and matching the shape of the input vector.
+   Throws an ex-info when `nd-vec` is ragged or mixes numbers and vectors
+   at the same depth.
+   Ex:
+    (->ndarray [5.0 -4.0])
+    (->ndarray [5 -4] {:ctx (context/cpu)})
+    (->ndarray [[1 2 3] [4 5 6]])
+    (->ndarray [[[1.0] [2.0]]])"
+  ([nd-vec {:keys [ctx]
+            :or {ctx (mx-context/default-context)}}]
+   (let [nd-shape (util/nd-seq-shape nd-vec)]
+     ;; nd-seq-shape only looks at the first element at each depth, so a ragged
+     ;; input such as [[1 2] [3] [4 5 6]] would otherwise be reshaped silently
+     ;; into the wrong layout instead of being rejected.
+     (when-not (nd-vec-conforms-to-shape? nd-vec nd-shape)
+       (throw (ex-info "Invalid N-D vector: inconsistent lengths or nesting"
+                       {:nd-vec nd-vec :shape nd-shape})))
+     (array (vec (clojure.core/flatten nd-vec))
+            nd-shape
+            {:ctx ctx})))
+  ([nd-vec] (->ndarray nd-vec {})))
+
(defn slice
"Return a sliced NDArray that shares memory with current one."
([ndarray i]
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
index 6b5f507..7eb1426 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
@@ -218,6 +218,26 @@
(throw (ex-info error-msg
(s/explain-data spec value)))))
+;; A sequential collection (vector, list, ...) that has at least one element.
+;; N-D sequence helpers validate their input against this spec via validate!.
+(s/def ::non-empty-seq (s/and sequential? not-empty))
+(defn to-array-nd
+  "Turns a nested sequential structure into nested Object arrays of the
+   same dimensions. Whether to recurse is decided from the first element:
+   when it is sequential, every element is converted recursively.
+   Throws (via validate!) when a visited level is not a non-empty sequence."
+  [nd-seq]
+  (validate! ::non-empty-seq nd-seq "Invalid N-D sequence")
+  (let [nested? (sequential? (first nd-seq))
+        elements (if nested?
+                   (mapv to-array-nd nd-seq)
+                   nd-seq)]
+    (to-array elements)))
+
+(defn nd-seq-shape
+  "Returns the shape of an n-dimensional sequential structure as a vector of
+   dimension sizes, outermost first. Only the first element at each depth is
+   followed to discover the next dimension's size."
+  [nd-seq]
+  (validate! ::non-empty-seq nd-seq "Invalid N-D sequence")
+  ;; Walk nd-seq, (first nd-seq), (first (first nd-seq)), ... for as long as
+  ;; the value is still sequential, and record each level's length.
+  (->> nd-seq
+       (iterate first)
+       (take-while sequential?)
+       (mapv count)))
+
(defn map->scala-tuple-seq
"* Convert a map to a scala-Seq of scala-Tubple.
* Should also work if a seq of seq of 2 things passed.
diff --git
a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
index a9ae296..ee7c16b 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
@@ -146,6 +146,18 @@
(is (= [0.0 0.0 0.5 0.5 1.0 1.0 1.5 1.5 2.0 2.0 2.5 2.5 3.0 3.0 3.5 3.5
4.0 4.0 4.5 4.5]
(->vec (ndarray/arange start stop {:step step :repeat repeat}))))))
+(deftest test->ndarray
+  ;; Each case: [input nd-vec, expected flattened values, expected shape]
+  (doseq [[input expected-values expected-shape]
+          [[[5.0 -4.0] [5.0 -4.0] [2]]
+           [[[1 2 3] [4 5 6]] [1.0 2.0 3.0 4.0 5.0 6.0] [2 3]]
+           [[[[7.0] [8.0]]] [7.0 8.0] [1 2 1]]]]
+    (let [nda (ndarray/->ndarray input)]
+      (is (= expected-values (->vec nda)))
+      (is (= expected-shape (mx-shape/->vec (shape nda)))))))
+
(deftest test-power
(let [nda (ndarray/array [3 5] [2 1])]
diff --git
a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
index 4ed7d38..15c4859 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
@@ -163,6 +163,33 @@
(is (= [1 2] (-> (util/convert-tuple [1 2])
(util/tuple->vec)))))
+(deftest test-to-array-nd
+  (let [from-list (util/to-array-nd '(1))
+        flat (util/to-array-nd [1.0 2.0])
+        column (util/to-array-nd [[3.0] [4.0]])
+        cube (util/to-array-nd [[[5 -5]]])]
+    ;; 1-D input given as a list
+    (is (= 1 (alength from-list)))
+    (is (= [1] (vec from-list)))
+    ;; 1-D input given as a vector
+    (is (= 2 (alength flat)))
+    (is (= 2.0 (aget flat 1)))
+    (is (= [1.0 2.0] (vec flat)))
+    ;; 2-D input: every row becomes its own nested array
+    (is (= 2 (alength column)))
+    (is (= 1 (alength (aget column 0))))
+    (is (= 4.0 (aget column 1 0)))
+    (is (= [[3.0] [4.0]] (mapv vec column)))
+    ;; 3-D input
+    (is (= 1 (alength cube)))
+    (is (= 1 (alength (aget cube 0))))
+    (is (= 2 (alength (aget cube 0 0))))
+    (is (= 5 (aget cube 0 0 0)))
+    (is (= [[[5 -5]]] (mapv (fn [plane] (mapv vec plane)) cube)))))
+
+(deftest test-nd-seq-shape
+  (are [expected-shape input] (= expected-shape (util/nd-seq-shape input))
+    [1]     '(5)
+    [2]     [1.0 2.0]
+    [3]     [1 1 1]
+    [2 1]   [[3.0] [4.0]]
+    [1 3 2] [[[5 -5] [5 -5] [5 -5]]]))
+
(deftest test-coerce-return
(is (= [] (util/coerce-return (ArrayBuffer.))))
(is (= [1 2 3] (util/coerce-return (util/vec->indexed-seq [1 2 3]))))