This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new a4b9802  [Clojure] Helper function for n-dim vector to ndarray (#14305)
a4b9802 is described below

commit a4b98024b951fca91df3188a06141ea3e3411015
Author: Kedar Bellare <[email protected]>
AuthorDate: Mon Mar 11 06:39:34 2019 -0700

    [Clojure] Helper function for n-dim vector to ndarray (#14305)
    
    * [Clojure] Helper function for n-dim vector to ndarray
    
    * More tests, specs and rename method
    
    * Address comments
    
    * Allow every number type
---
 .../src/org/apache/clojure_mxnet/ndarray.clj       | 21 +++++++++++++++++
 .../src/org/apache/clojure_mxnet/util.clj          | 20 ++++++++++++++++
 .../test/org/apache/clojure_mxnet/ndarray_test.clj | 12 ++++++++++
 .../test/org/apache/clojure_mxnet/util_test.clj    | 27 ++++++++++++++++++++++
 4 files changed, 80 insertions(+)

diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
index 151e18b..9caa00d 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
@@ -94,6 +94,31 @@
   ([start stop]
    (arange start stop {})))
 
+(defn ->ndarray
+  "Creates a new NDArray based on the given n-dimensional vector
+   of numbers.
+    `nd-vec`: non-empty n-dimensional vector (or other sequential
+      nesting) of numbers. It must be rectangular: every element at a
+      given depth must have the same shape.
+    `opts-map` {
+       `ctx`: Context of the output ndarray, will use default context if unspecified.
+    }
+    returns: `ndarray` with the given values and matching the shape of the input vector.
+   Ex:
+    (->ndarray [5.0 -4.0])
+    (->ndarray [5 -4] {:ctx (context/cpu)})
+    (->ndarray [[1 2 3] [4 5 6]])
+    (->ndarray [[[1.0] [2.0]]])"
+  ([nd-vec {:keys [ctx]
+            :or {ctx (mx-context/default-context)}}]
+   ;; Compute the shape first: `util/nd-seq-shape` validates the input
+   ;; (e.g. rejects empty or non-sequential values) before any flattening.
+   (let [shape (util/nd-seq-shape nd-vec)]
+     (array (vec (clojure.core/flatten nd-vec))
+            shape
+            {:ctx ctx})))
+  ([nd-vec] (->ndarray nd-vec {})))
+
 (defn slice
   "Return a sliced NDArray that shares memory with current one."
   ([ndarray i]
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
index 6b5f507..7eb1426 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
@@ -218,6 +218,35 @@
     (throw (ex-info error-msg
                     (s/explain-data spec value)))))
 
+(s/def ::non-empty-seq (s/and sequential? not-empty))
+(defn to-array-nd
+  "Converts any N-D sequential structure to an array
+   with the same dimensions.
+   Whether a level is nested is decided by its first element; each
+   level that is recursed into must satisfy ::non-empty-seq."
+  [nd-seq]
+  (validate! ::non-empty-seq nd-seq "Invalid N-D sequence")
+  (if (sequential? (first nd-seq))
+    (to-array (mapv to-array-nd nd-seq))
+    (to-array nd-seq)))
+
+(defn nd-seq-shape
+  "Computes the shape of a n-dimensional sequential structure.
+   Every level must be a non-empty sequential collection and all
+   elements at the same depth must share one shape (no ragged or
+   mixed-depth nesting such as [[1 2] [3]] or [[1] 2]); otherwise
+   an ex-info is thrown."
+  [nd-seq]
+  (validate! ::non-empty-seq nd-seq "Invalid N-D sequence")
+  (let [sub-shapes (mapv #(if (sequential? %) (nd-seq-shape %) []) nd-seq)
+        sub-shape (first sub-shapes)]
+    ;; Check every sibling, not only the first: otherwise a ragged input
+    ;; yields a shape whose product differs from its element count.
+    (when-not (every? #(= sub-shape %) sub-shapes)
+      (throw (ex-info "Invalid N-D sequence: ragged dimensions"
+                      {:shapes (distinct sub-shapes)})))
+    (into [(count nd-seq)] sub-shape)))
+
 (defn map->scala-tuple-seq
   "* Convert a map to a scala-Seq of scala-Tubple.
    * Should also work if a seq of seq of 2 things passed.
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
index a9ae296..ee7c16b 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
@@ -146,6 +146,25 @@
     (is (= [0.0 0.0 0.5 0.5 1.0 1.0 1.5 1.5 2.0 2.0 2.5 2.5 3.0 3.0 3.5 3.5 4.0 4.0 4.5 4.5]
            (->vec (ndarray/arange start stop {:step step :repeat repeat}))))))
 
+(deftest test->ndarray
+  (let [nda1 (ndarray/->ndarray [5.0 -4.0])
+        nda2 (ndarray/->ndarray [[1 2 3]
+                                 [4 5 6]])
+        nda3 (ndarray/->ndarray [[[7.0] [8.0]]])
+        nda4 (ndarray/->ndarray '(1 2) {})]
+    (is (= [5.0 -4.0] (->vec nda1)))
+    (is (= [2] (mx-shape/->vec (shape nda1))))
+    (is (= [1.0 2.0 3.0 4.0 5.0 6.0] (->vec nda2)))
+    (is (= [2 3] (mx-shape/->vec (shape nda2))))
+    (is (= [7.0 8.0] (->vec nda3)))
+    (is (= [1 2 1] (mx-shape/->vec (shape nda3))))
+    ;; Lists are accepted as well as vectors; the 2-arity uses the default context.
+    (is (= [1.0 2.0] (->vec nda4)))
+    (is (= [2] (mx-shape/->vec (shape nda4))))
+    ;; Empty or non-sequential input is rejected by validation.
+    (is (thrown? Exception (ndarray/->ndarray [])))
+    (is (thrown? Exception (ndarray/->ndarray 5)))))
+
 (deftest test-power
   (let [nda (ndarray/array [3 5] [2 1])]
 
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
index 4ed7d38..15c4859 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
@@ -163,6 +163,39 @@
   (is (= [1 2] (-> (util/convert-tuple [1 2])
                    (util/tuple->vec)))))
 
+(deftest test-to-array-nd
+  (let [a1 (util/to-array-nd '(1))
+        a2 (util/to-array-nd [1.0 2.0])
+        a3 (util/to-array-nd [[3.0] [4.0]])
+        a4 (util/to-array-nd [[[5 -5]]])]
+    (is (= 1 (alength a1)))
+    (is (= [1] (->> a1 vec)))
+    (is (= 2 (alength a2)))
+    (is (= 2.0 (aget a2 1)))
+    (is (= [1.0 2.0] (->> a2 vec)))
+    (is (= 2 (alength a3)))
+    (is (= 1 (alength (aget a3 0))))
+    (is (= 4.0 (aget a3 1 0)))
+    (is (= [[3.0] [4.0]] (->> a3 vec (mapv vec))))
+    (is (= 1 (alength a4)))
+    (is (= 1 (alength (aget a4 0))))
+    (is (= 2 (alength (aget a4 0 0))))
+    (is (= 5 (aget a4 0 0 0)))
+    (is (= [[[5 -5]]] (->> a4 vec (mapv vec) (mapv #(mapv vec %)))))
+    ;; Empty or non-sequential input is rejected, including nested levels.
+    (is (thrown? Exception (util/to-array-nd [])))
+    (is (thrown? Exception (util/to-array-nd 1)))
+    (is (thrown? Exception (util/to-array-nd [[1] []])))))
+
+(deftest test-nd-seq-shape
+  (is (= [1] (util/nd-seq-shape '(5))))
+  (is (= [2] (util/nd-seq-shape [1.0 2.0])))
+  (is (= [3] (util/nd-seq-shape [1 1 1])))
+  (is (= [2 1] (util/nd-seq-shape [[3.0] [4.0]])))
+  (is (= [1 3 2] (util/nd-seq-shape [[[5 -5] [5 -5] [5 -5]]])))
+  (is (thrown? Exception (util/nd-seq-shape [])))
+  (is (thrown? Exception (util/nd-seq-shape 5))))
+
 (deftest test-coerce-return
   (is (= [] (util/coerce-return (ArrayBuffer.))))
   (is (= [1 2 3] (util/coerce-return (util/vec->indexed-seq [1 2 3]))))

Reply via email to