This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new a5db391 [Clojure] Add methods based on NDArrayAPI/SymbolAPI (#14195)
a5db391 is described below
commit a5db391cb27342bf8e267cfbfb4db26b5ef66721
Author: Kedar Bellare <[email protected]>
AuthorDate: Sat Apr 13 12:19:06 2019 -0700
[Clojure] Add methods based on NDArrayAPI/SymbolAPI (#14195)
* [Clojure] Add methods based on NDArrayAPI/SymbolAPI
* Add symbol API methods and ndarray API unit tests
* Some more ndarray API unit tests
* Explore direct use of JNI
* Use library info directly instead of reflection
* Add tests for generating op info
* Fix ordering of keys using array-map
* Ignore generated test files
* Minor style changes
* Refactor code for better readability
* Address comments
* Small tweaks to symbol api coercion
---
contrib/clojure-package/.gitignore | 2 +
contrib/clojure-package/src/dev/generator.clj | 460 ++++++++++++++++-----
.../src/org/apache/clojure_mxnet/ndarray_api.clj | 32 ++
.../src/org/apache/clojure_mxnet/symbol_api.clj | 32 ++
.../src/org/apache/clojure_mxnet/util.clj | 6 +-
.../clojure-package/test/dev/generator_test.clj | 148 ++++++-
.../clojure-package/test/good-test-ndarray-api.clj | 89 ++++
.../clojure-package/test/good-test-symbol-api.clj | 109 +++++
.../test/org/apache/clojure_mxnet/conv_test.clj | 24 +-
.../org/apache/clojure_mxnet/ndarray_api_test.clj | 415 +++++++++++++++++++
.../org/apache/clojure_mxnet/symbol_api_test.clj | 61 +++
11 files changed, 1257 insertions(+), 121 deletions(-)
diff --git a/contrib/clojure-package/.gitignore b/contrib/clojure-package/.gitignore
index f5d81dd..71d812e 100644
--- a/contrib/clojure-package/.gitignore
+++ b/contrib/clojure-package/.gitignore
@@ -39,6 +39,8 @@ examples/visualization/test-vis.pdf
src/.DS_Store
src/org/.DS_Store
test/test-ndarray.clj
+test/test-ndarray-api.clj
test/test-symbol.clj
+test/test-symbol-api.clj
src/org/apache/clojure_mxnet/gen/*
diff --git a/contrib/clojure-package/src/dev/generator.clj b/contrib/clojure-package/src/dev/generator.clj
index ca93c34..34210be 100644
--- a/contrib/clojure-package/src/dev/generator.clj
+++ b/contrib/clojure-package/src/dev/generator.clj
@@ -17,10 +17,14 @@
(ns dev.generator
(:require [t6.from-scala.core :as scala]
+ [t6.from-scala.core :refer [$ $$] :as $]
[clojure.reflect :as r]
- [org.apache.clojure-mxnet.util :as util]
- [clojure.pprint])
- (:import (org.apache.mxnet NDArray Symbol))
+ [clojure.pprint]
+ [org.apache.clojure-mxnet.util :as util])
+ (:import (org.apache.mxnet NDArray NDArrayAPI
+ Symbol SymbolAPI
+ Base Base$RefInt Base$RefLong Base$RefFloat
Base$RefString)
+ (scala.collection.mutable ListBuffer ArrayBuffer))
(:gen-class))
@@ -34,17 +38,17 @@
(clojure.string/replace #"\_" "-")
(clojure.string/replace #"\/" "div")))
-(defn symbol-transform-param-name [parameter-types]
+(defn transform-param-names [coerce-fn parameter-types]
(->> parameter-types
(map str)
- (map (fn [x] (or (util/symbol-param-coerce x) x)))
+ (map (fn [x] (or (coerce-fn x) x)))
(map (fn [x] (last (clojure.string/split x #"\."))))))
+(defn symbol-transform-param-name [parameter-types]
+ (transform-param-names util/symbol-param-coerce parameter-types))
+
(defn ndarray-transform-param-name [parameter-types]
- (->> parameter-types
- (map str)
- (map (fn [x] (or (util/ndarray-param-coerce x) x)))
- (map (fn [x] (last (clojure.string/split x #"\."))))))
+ (transform-param-names util/ndarray-param-coerce parameter-types))
(defn has-variadic? [params]
(->> params
@@ -56,37 +60,136 @@
(defn increment-param-name [pname]
(if-let [num-str (re-find #"-\d" pname)]
- (str (first (clojure.string/split pname #"-")) "-" (inc (Integer/parseInt
(last (clojure.string/split num-str #"-")))))
+ (str
+ (first (clojure.string/split pname #"-"))
+ "-"
+ (inc (Integer/parseInt (last (clojure.string/split num-str #"-")))))
(str pname "-" 1)))
-(defn rename-duplicate-params [params]
- (reduce (fn [known-names n] (conj known-names (if (contains? (set
known-names) n)
- (increment-param-name n)
- n)))
- []
- params))
-
+(defn rename-duplicate-params [pnames]
+ (->> (reduce
+ (fn [pname-counts n]
+ (let [rn (if (pname-counts n) (str n "-" (pname-counts n)) n)
+ inc-pname-counts (update-in pname-counts [n] (fnil inc 0))]
+ (update-in inc-pname-counts [:params] conj rn)))
+ {:params []}
+ pnames)
+ :params))
+
+(defn get-public-no-default-methods [obj]
+ (->> (r/reflect obj)
+ :members
+ (map #(into {} %))
+ (filter #(-> % :flags :public))
+ (remove #(re-find #"org\$apache\$mxnet" (str (:name %))))
+ (remove #(re-find #"\$default" (str (:name %))))))
+
+(defn get-public-to-gen-methods [public-to-hand-gen public-no-default]
+ (let [public-to-hand-gen-names
+ (into #{} (mapv (comp str :name) public-to-hand-gen))]
+ (remove #(-> % :name str public-to-hand-gen-names) public-no-default)))
-;;;;;;; symbol
+(defn public-by-name-and-param-count [public-reflect-info]
+ (->> public-reflect-info
+ (group-by :name)
+ (map (fn [[k v]] [k (group-by #(count (:parameter-types %)) v)]))
+ (into {})))
-(def symbol-reflect-info (->> (:members (r/reflect Symbol))
- (map #(into {} %))))
+(def license
+ (str
+ ";; Licensed to the Apache Software Foundation (ASF) under one or more\n"
+ ";; contributor license agreements. See the NOTICE file distributed with\n"
+ ";; this work for additional information regarding copyright ownership.\n"
+ ";; The ASF licenses this file to You under the Apache License, Version
2.0\n"
+ ";; (the \"License\"); you may not use this file except in compliance
with\n"
+ ";; the License. You may obtain a copy of the License at\n"
+ ";;\n"
+ ";; http://www.apache.org/licenses/LICENSE-2.0\n"
+ ";;\n"
+ ";; Unless required by applicable law or agreed to in writing, software\n"
+ ";; distributed under the License is distributed on an \"AS IS\" BASIS,\n"
+ ";; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied.\n"
+ ";; See the License for the specific language governing permissions and\n"
+ ";; limitations under the License.\n"
+ ";;\n"))
-(def symbol-public (filter (fn [x] (-> x :flags :public)) symbol-reflect-info))
+(defn write-to-file [functions ns-gen fname]
+ (with-open [w (clojure.java.io/writer fname)]
+ (.write w ns-gen)
+ (.write w "\n\n")
+ (.write w ";; Do not edit - this is auto-generated")
+ (.write w "\n\n")
+ (.write w license)
+ (.write w "\n\n")
+ (.write w "\n\n")
+ (doseq [f functions]
+ (clojure.pprint/pprint f w)
+ (.write w "\n"))))
-(def symbol-public-no-default (->> symbol-public
- (filter #(not (re-find
#"org\$apache\$mxnet" (str (:name %)))))
- (filter #(not (re-find #"\$default" (str
(:name %)))))))
+;;;;;;; Common operations
+
+(def libinfo (Base/_LIB))
+(def op-names
+ (let [l ($ ListBuffer/empty)]
+ (do (.mxListAllOpNames libinfo l)
+ (remove #(or (= "Custom" %)
+ (re-matches #"^_.*" %))
+ (util/buffer->vec l)))))
+
+(defn- parse-arg-type [s]
+ (let [[_ var-arg-type _ set-arg-type arg-spec _ type-req _ default-val]
(re-find
#"(([\w-\[\]\s]+)|\{([^}]+)\})\s*(\([^)]+\))?(,\s*(optional|required)(,\s*default=(.*))?)?"
s)]
+ {:type (clojure.string/trim (or set-arg-type var-arg-type))
+ :spec arg-spec
+ :optional? (or (= "optional" type-req)
+ (= "boolean" var-arg-type))
+ :default default-val
+ :orig s}))
+
+(defn- get-op-handle [op-name]
+ (let [ref (new Base$RefLong 0)]
+ (do (.nnGetOpHandle libinfo op-name ref)
+ (.value ref))))
+
+(defn gen-op-info [op-name]
+ (let [handle (get-op-handle op-name)
+ name (new Base$RefString nil)
+ desc (new Base$RefString nil)
+ key-var-num-args (new Base$RefString nil)
+ num-args (new Base$RefInt 0)
+ arg-names ($ ListBuffer/empty)
+ arg-types ($ ListBuffer/empty)
+ arg-descs ($ ListBuffer/empty)]
+ (do (.mxSymbolGetAtomicSymbolInfo libinfo
+ handle
+ name
+ desc
+ num-args
+ arg-names
+ arg-types
+ arg-descs
+ key-var-num-args)
+ {:fn-name (clojure-case (.value name))
+ :fn-description (.value desc)
+ :args (mapv (fn [t n d] (assoc t :name n :description d))
+ (mapv parse-arg-type (util/buffer->vec arg-types))
+ (mapv clojure-case (util/buffer->vec arg-names))
+ (util/buffer->vec arg-descs))
+ :key-var-num-args (clojure-case (.value key-var-num-args))})))
+
+;;;;;;; Symbol
+
+(def symbol-public-no-default
+ (get-public-no-default-methods Symbol))
(into #{} (mapcat :parameter-types symbol-public-no-default))
- ;#{java.lang.Object
scala.collection.Seq scala.Option long double scala.collection.immutable.Map
int ml.dmlc.mxnet.Executor float ml.dmlc.mxnet.Context java.lang.String
scala.Enumeration$Value ml.dmlc.mxnet.Symbol int<> ml.dmlc.mxnet.Symbol<>
ml.dmlc.mxnet.Shape java.lang.String<>}
+;; #{java.lang.Object scala.collection.Seq scala.Option long double
scala.collection.immutable.Map int ml.dmlc.mxnet.Executor float
ml.dmlc.mxnet.Context java.lang.String scala.Enumeration$Value
ml.dmlc.mxnet.Symbol int<> ml.dmlc.mxnet.Symbol<> ml.dmlc.mxnet.Shape
java.lang.String<>}
-(def symbol-hand-gen-set #{"scala.Option"
- "int org.apache.mxnet.Executor"
- "scala.Enumeration$Value"
- "org.apache.mxnet.Context"
- "scala.Tuple2"
- "scala.collection.Traversable"} )
+(def symbol-hand-gen-set
+ #{"scala.Option"
+ "scala.Enumeration$Value"
+ "org.apache.mxnet.Context"
+ "scala.Tuple2"
+ "scala.collection.Traversable"})
;;; min and max have a conflicting arity of 2 with the auto gen signatures
(def symbol-filter-name-set #{"max" "min"})
@@ -102,34 +205,35 @@
count
pos?)))
-(def symbol-public-to-hand-gen (filter is-symbol-hand-gen?
symbol-public-no-default))
-(def symbol-public-to-gen (->> (remove #(contains?(->>
symbol-public-to-hand-gen
- (mapv :name)
- (mapv str)
- (set)) (str (:name
%))) symbol-public-no-default)))
+(def symbol-public-to-hand-gen
+ (filter is-symbol-hand-gen? symbol-public-no-default))
+(def symbol-public-to-gen
+ (get-public-to-gen-methods symbol-public-to-hand-gen
+ symbol-public-no-default))
(count symbol-public-to-hand-gen) ;=> 35 mostly bind!
(count symbol-public-to-gen) ;=> 307
-(into #{} (map :name symbol-public-to-hand-gen));=> #{arange bind ones zeros
simpleBind Variable}
+(into #{} (map :name symbol-public-to-hand-gen))
+;;=> #{arange bind ones zeros simpleBind Variable}
-(defn public-by-name-and-param-count [public-reflect-info]
- (->> public-reflect-info
- (group-by :name)
- (map (fn [[k v]] [k (group-by #(count (:parameter-types %)) v)]))
- (into {})))
(defn symbol-vector-args []
- `(if (map? ~'kwargs-map-or-vec-or-sym) (~'util/empty-list)
(~'util/coerce-param ~'kwargs-map-or-vec-or-sym #{"scala.collection.Seq"})))
+ `(if (map? ~'kwargs-map-or-vec-or-sym)
+ (~'util/empty-list)
+ (~'util/coerce-param ~'kwargs-map-or-vec-or-sym
#{"scala.collection.Seq"})))
(defn symbol-map-args []
- `(if (map? ~'kwargs-map-or-vec-or-sym) (util/convert-symbol-map
~'kwargs-map-or-vec-or-sym) nil))
+ `(if (map? ~'kwargs-map-or-vec-or-sym)
+ (util/convert-symbol-map ~'kwargs-map-or-vec-or-sym)
+ nil))
(defn add-symbol-arities [params function-name]
- (if (= ["sym-name" "kwargs-map" "symbol-list" "kwargs-map-1"] (mapv str
params))
+ (if (= ["sym-name" "kwargs-map" "symbol-list" "kwargs-map-1"]
+ (mapv str params))
[`([~'sym-name ~'attr-map ~'kwargs-map]
(~function-name ~'sym-name (~'util/convert-symbol-map ~'attr-map)
(~'util/empty-list) (~'util/convert-symbol-map ~'kwargs-map)))
`([~'sym-name ~'kwargs-map-or-vec-or-sym]
@@ -180,36 +284,7 @@
`(~'defn ~function-name
~@(remove nil? (gen-symbol-function-arity op-name op-values
function-name))))))
-(def license
- (str
- ";; Licensed to the Apache Software Foundation (ASF) under one or more\n"
- ";; contributor license agreements. See the NOTICE file distributed with\n"
- ";; this work for additional information regarding copyright ownership.\n"
- ";; The ASF licenses this file to You under the Apache License, Version
2.0\n"
- ";; (the \"License\"); you may not use this file except in compliance
with\n"
- ";; the License. You may obtain a copy of the License at\n"
- ";;\n"
- ";; http://www.apache.org/licenses/LICENSE-2.0\n"
- ";;\n"
- ";; Unless required by applicable law or agreed to in writing, software\n"
- ";; distributed under the License is distributed on an \"AS IS\" BASIS,\n"
- ";; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied.\n"
- ";; See the License for the specific language governing permissions and\n"
- ";; limitations under the License.\n"
- ";;\n"))
-(defn write-to-file [functions ns-gen fname]
- (with-open [w (clojure.java.io/writer fname)]
- (.write w ns-gen)
- (.write w "\n\n")
- (.write w ";; Do not edit - this is auto-generated")
- (.write w "\n\n")
- (.write w license)
- (.write w "\n\n")
- (.write w "\n\n")
- (doseq [f functions]
- (clojure.pprint/pprint f w)
- (.write w "\n"))))
(def symbol-gen-ns "(ns org.apache.clojure-mxnet.symbol
(:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten
load max
@@ -223,25 +298,18 @@
(println "Generating symbol file")
(write-to-file all-symbol-functions symbol-gen-ns
"src/org/apache/clojure_mxnet/gen/symbol.clj"))
+;;;;;;; NDArray
-;;;;;;;;NDARRAY
-
-
-(def ndarray-reflect-info (->> (:members (r/reflect NDArray))
- (map #(into {} %))))
+(def ndarray-public-no-default
+ (get-public-no-default-methods NDArray))
-(def ndarray-public (filter (fn [x] (-> x :flags :public))
ndarray-reflect-info))
-
-(def ndarray-public-no-default (->> ndarray-public
- (filter #(not (re-find
#"org\$apache\$mxnet" (str (:name %)))))
- (filter #(not (re-find #"\$default" (str
(:name %)))))))
-
-(def ndarray-hand-gen-set #{"org.apache.mxnet.NDArrayFuncReturn"
- "org.apache.mxnet.Context"
- "scala.Enumeration$Value"
- "scala.Tuple2"
- "scala.collection.Traversable"} )
+(def ndarray-hand-gen-set
+ #{"org.apache.mxnet.NDArrayFuncReturn"
+ "org.apache.mxnet.Context"
+ "scala.Enumeration$Value"
+ "scala.Tuple2"
+ "scala.collection.Traversable"})
(defn is-ndarray-hand-gen? [info]
(->> (map str (:parameter-types info))
@@ -251,17 +319,17 @@
pos?))
-(def ndarray-public-to-hand-gen (filter is-ndarray-hand-gen?
ndarray-public-no-default))
-(def ndarray-public-to-gen (->> (remove #(contains?(->>
ndarray-public-to-hand-gen
- (mapv :name)
- (mapv str)
- (set)) (str (:name
%))) ndarray-public-no-default)))
+(def ndarray-public-to-hand-gen
+ (filter is-ndarray-hand-gen? ndarray-public-no-default))
+(def ndarray-public-to-gen
+ (get-public-to-gen-methods ndarray-public-to-hand-gen
+ ndarray-public-no-default))
(count ndarray-public-to-hand-gen) ;=> 15
(count ndarray-public-to-gen) ;=> 486
-(map :name ndarray-public-to-hand-gen)
+(->> ndarray-public-to-hand-gen (map :name) (into #{}))
@@ -294,16 +362,19 @@
)))))
+(defn gen-ndarray-functions [public-to-gen-methods]
+ (for [operation (sort (public-by-name-and-param-count
public-to-gen-methods))]
+ (let [[op-name op-values] operation
+ function-name (-> op-name
+ str
+ scala/decode-scala-symbol
+ clojure-case
+ symbol)]
+ `(~'defn ~function-name
+ ~@(remove nil? (gen-ndarray-function-arity op-name op-values))))))
+
(def all-ndarray-functions
- (for [operation (sort (public-by-name-and-param-count
ndarray-public-to-gen))]
- (let [[op-name op-values] operation
- function-name (-> op-name
- str
- scala/decode-scala-symbol
- clojure-case
- symbol)]
- `(~'defn ~function-name
- ~@(remove nil? (gen-ndarray-function-arity op-name op-values))))))
+ (gen-ndarray-functions ndarray-public-to-gen))
(def ndarray-gen-ns "(ns org.apache.clojure-mxnet.ndarray
(:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity
load max
@@ -314,16 +385,191 @@
(defn generate-ndarray-file []
(println "Generating ndarray file")
- (write-to-file all-ndarray-functions ndarray-gen-ns
"src/org/apache/clojure_mxnet/gen/ndarray.clj"))
+ (write-to-file all-ndarray-functions
+ ndarray-gen-ns
+ "src/org/apache/clojure_mxnet/gen/ndarray.clj"))
+
+;;;;;;; SymbolAPI
+
+(defn symbol-api-coerce-param
+ [{:keys [name sym type optional?]}]
+ (let [coerced-param (case type
+ "Shape" `(when ~sym (~'mx-shape/->shape ~sym))
+ "NDArray-or-Symbol[]" `(~'clojure.core/into-array ~sym)
+ "Map[String, String]"
+ `(when ~sym
+ (->> ~sym
+ (mapv (fn [[~'k ~'v]] [~'k (str ~'v)]))
+ (into {})
+ ~'util/convert-map))
+ sym)
+ nil-param-allowed? (#{"name" "attr"} name)]
+ (if (and optional? (not nil-param-allowed?))
+ `(~'util/->option ~coerced-param)
+ coerced-param)))
+
+(defn gen-symbol-api-doc [fn-description params]
+ (let [param-descriptions (mapv (fn [{:keys [name description optional?]}]
+ (str "`" name "`: "
+ description
+ (when optional? " (optional)")
+ "\n"))
+ params)]
+ (str fn-description "\n\n"
+ (apply str param-descriptions))))
+
+(defn gen-symbol-api-default-arity [op-name params]
+ (let [opt-params (filter :optional? params)
+ coerced-params (mapv symbol-api-coerce-param params)
+ default-args (array-map :keys (mapv :sym params)
+ :or (into {}
+ (mapv (fn [{:keys [sym]}] [sym nil])
+ opt-params))
+ :as 'opts)]
+ `([~default-args]
+ (~'util/coerce-return
+ (~(symbol (str "SymbolAPI/" op-name))
+ ~@coerced-params)))))
+
+(defn gen-symbol-api-function [op-name]
+ (let [{:keys [fn-name fn-description args]} (gen-op-info op-name)
+ params (mapv (fn [{:keys [name type optional?] :as opts}]
+ (assoc opts
+ :sym (symbol name)
+ :optional? (or optional?
+ (= "NDArray-or-Symbol" type))))
+ (conj args
+ {:name "name"
+ :type "String"
+ :optional? true
+ :description "Name of the symbol"}
+ {:name "attr"
+ :type "Map[String, String]"
+ :optional? true
+ :description "Attributes of the symbol"}))
+ doc (gen-symbol-api-doc fn-description params)
+ default-call (gen-symbol-api-default-arity op-name params)]
+ `(~'defn ~(symbol fn-name)
+ ~doc
+ ~@default-call)))
+
+(def all-symbol-api-functions
+ (mapv gen-symbol-api-function op-names))
+
+(def symbol-api-gen-ns "(ns
+ ^{:doc \"Experimental\"}
+ org.apache.clojure-mxnet.symbol-api
+ (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten
load max
+ min repeat reverse set sort take to-array empty sin
+ get apply shuffle ref])
+ (:require [org.apache.clojure-mxnet.util :as util]
+ [org.apache.clojure-mxnet.shape :as mx-shape])
+ (:import (org.apache.mxnet SymbolAPI)))")
+
+(defn generate-symbol-api-file []
+ (println "Generating symbol-api file")
+ (write-to-file all-symbol-api-functions symbol-api-gen-ns
"src/org/apache/clojure_mxnet/gen/symbol_api.clj"))
+
+;;;;;;; NDArrayAPI
+
+(defn ndarray-api-coerce-param
+ [{:keys [sym type optional?]}]
+ (let [coerced-param (case type
+ "Shape" `(when ~sym (~'mx-shape/->shape ~sym))
+ "NDArray-or-Symbol[]" `(~'clojure.core/into-array ~sym)
+ sym)]
+ (if optional?
+ `(~'util/->option ~coerced-param)
+ coerced-param)))
+
+(defn gen-ndarray-api-doc [fn-description params]
+ (let [param-descriptions (mapv (fn [{:keys [name description optional?]}]
+ (str "`" name "`: "
+ description
+ (when optional? " (optional)")
+ "\n"))
+ params)]
+ (str fn-description "\n\n"
+ (apply str param-descriptions))))
+
+(defn gen-ndarray-api-default-arity [op-name params]
+ (let [opt-params (filter :optional? params)
+ coerced-params (mapv ndarray-api-coerce-param params)
+ default-args (array-map :keys (mapv :sym params)
+ :or (into {}
+ (mapv (fn [{:keys [sym]}] [sym nil])
+ opt-params))
+ :as 'opts)]
+ `([~default-args]
+ (~'util/coerce-return
+ (~(symbol (str "NDArrayAPI/" op-name))
+ ~@coerced-params)))))
+
+(defn gen-ndarray-api-required-arity [fn-name req-params]
+ (let [req-args (->> req-params
+ (mapv (fn [{:keys [sym]}] [(keyword sym) sym]))
+ (into {}))]
+ `(~(mapv :sym req-params)
+ (~(symbol fn-name) ~req-args))))
+
+(defn gen-ndarray-api-function [op-name]
+ (let [{:keys [fn-name fn-description args]} (gen-op-info op-name)
+ params (mapv (fn [{:keys [name] :as opts}]
+ (assoc opts :sym (symbol name)))
+ (conj args {:name "out"
+ :type "NDArray-or-Symbol"
+ :optional? true
+ :description "Output array."}))
+ doc (gen-ndarray-api-doc fn-description params)
+ opt-params (filter :optional? params)
+ req-params (remove :optional? params)
+ req-call (gen-ndarray-api-required-arity fn-name req-params)
+ default-call (gen-ndarray-api-default-arity op-name params)]
+ (if (= 1 (count req-params))
+ `(~'defn ~(symbol fn-name)
+ ~doc
+ ~@default-call)
+ `(~'defn ~(symbol fn-name)
+ ~doc
+ ~req-call
+ ~default-call))))
+
+(def all-ndarray-api-functions
+ (mapv gen-ndarray-api-function op-names))
+
+(def ndarray-api-gen-ns "(ns
+ ^{:doc \"Experimental\"}
+ org.apache.clojure-mxnet.ndarray-api
+ (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity
load max
+ min repeat reverse set sort take to-array empty
shuffle
+ ref])
+ (:require [org.apache.clojure-mxnet.shape :as mx-shape]
+ [org.apache.clojure-mxnet.util :as util])
+ (:import (org.apache.mxnet NDArrayAPI)))")
+
+
+(defn generate-ndarray-api-file []
+ (println "Generating ndarray-api file")
+ (write-to-file all-ndarray-api-functions
+ ndarray-api-gen-ns
+ "src/org/apache/clojure_mxnet/gen/ndarray_api.clj"))
;;; autogen the files
(do
(generate-ndarray-file)
- (generate-symbol-file))
+ (generate-ndarray-api-file)
+ (generate-symbol-file)
+ (generate-symbol-api-file))
(comment
+ (gen-op-info "ElementWiseSum")
+
+ (gen-ndarray-api-function "Activation")
+
+ (gen-symbol-api-function "Activation")
+
;; This generates a file with the bulk of the nd-array functions
(generate-ndarray-file)
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray_api.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray_api.clj
new file mode 100644
index 0000000..70359a6
--- /dev/null
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray_api.clj
@@ -0,0 +1,32 @@
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.ndarray-api
+ "Experimental NDArray API"
+ (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity
load max
+ min repeat reverse set sort take to-array empty
shuffle
+ ref])
+
+ (:require [org.apache.clojure-mxnet.base :as base]
+ [org.apache.clojure-mxnet.context :as mx-context]
+ [org.apache.clojure-mxnet.shape :as mx-shape]
+ [org.apache.clojure-mxnet.util :as util]
+ [clojure.reflect :as r]
+ [t6.from-scala.core :refer [$] :as $])
+ (:import (org.apache.mxnet NDArrayAPI)))
+
+;; loads the generated functions into the namespace
+(do (clojure.core/load "gen/ndarray_api"))
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol_api.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol_api.clj
new file mode 100644
index 0000000..69cc813
--- /dev/null
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol_api.clj
@@ -0,0 +1,32 @@
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.symbol-api
+ "Experimental Symbol API"
+ (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten
load max
+ min repeat reverse set sort take to-array empty sin
+ get apply shuffle ref])
+ (:require [org.apache.clojure-mxnet.base :as base]
+ [org.apache.clojure-mxnet.context :as mx-context]
+ [org.apache.clojure-mxnet.executor :as ex]
+ [org.apache.clojure-mxnet.shape :as mx-shape]
+ [org.apache.clojure-mxnet.util :as util]
+ [t6.from-scala.core :refer [$] :as $]
+ [org.apache.clojure-mxnet.ndarray :as ndarray])
+ (:import (org.apache.mxnet SymbolAPI)))
+
+;; loads the generated functions into the namespace
+(do (clojure.core/load "gen/symbol_api"))
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
index 89ac1cd..7ee25d4 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
@@ -35,7 +35,6 @@
"int<>" "vec-of-ints"
"float<>" "vec-of-floats"
"byte<>" "byte-array"
- "java.lang.String<>" "vec-or-strings"
"org.apache.mxnet.NDArray" "ndarray"
"org.apache.mxnet.Symbol" "sym"
"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"
"double-or-float"})
@@ -49,7 +48,7 @@
"int<>" "vec-of-ints"
"float<>" "vec-of-floats"
"byte<>" "byte-array"
- "java.lang.String<>" "vec-or-strings"
+ "java.lang.String<>" "vec-of-strings"
"org.apache.mxnet.Symbol" "sym"
"java.lang.Object" "object"})
@@ -152,9 +151,12 @@
(and (get targets "scala.collection.Seq") (instance?
org.apache.mxnet.Symbol param)) ($/immutable-list param)
(and (get targets "scala.collection.Seq") (and (or (vector? param) (seq?
param)) (empty? param))) (empty-list)
(and (get targets "scala.collection.Seq") (or (vector? param) (seq?
param))) (apply $/immutable-list param)
+ (and (get targets "org.apache.mxnet.Shape") (or (vector? param) (seq?
param) (empty? param))) (mx-shape/->shape param)
(and (get targets "int<>") (vector? param)) (int-array param)
(and (get targets "float<>") (vector? param)) (float-array param)
(and (get targets "java.lang.String<>") (vector? param)) (into-array param)
+ (and (get targets "org.apache.mxnet.NDArray<>") (vector? param))
(into-array param)
+ (and (get targets "org.apache.mxnet.Symbol<>") (vector? param))
(into-array param)
(and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE")
(instance? Float param)) (primitives/mx-float param)
(and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE")
(number? param)) (primitives/mx-double param)
:else param))
diff --git a/contrib/clojure-package/test/dev/generator_test.clj b/contrib/clojure-package/test/dev/generator_test.clj
index 05b4a74..cf28241 100644
--- a/contrib/clojure-package/test/dev/generator_test.clj
+++ b/contrib/clojure-package/test/dev/generator_test.clj
@@ -50,6 +50,127 @@
(is (= transformed-params (gen/symbol-transform-param-name
(:parameter-types (symbol-reflect-info
"floor")))))))
+(deftest test-gen-op-info
+ (testing "activation"
+ (let [activation-info (gen/gen-op-info "Activation")]
+ (is (= "activation" (:fn-name activation-info)))
+ (is (string? (:fn-description activation-info)))
+ (is (= 2 (-> activation-info :args count)))
+ (is (= "" (:key-var-num-args activation-info)))
+
+ (is (= "data" (-> activation-info :args first :name)))
+ (is (= "NDArray-or-Symbol" (-> activation-info :args first :type)))
+ (is (false? (-> activation-info :args first :optional?)))
+ (is (nil? (-> activation-info :args first :default)))
+ (is (string? (-> activation-info :args first :description)))
+
+ (is (= "act-type" (-> activation-info :args second :name)))
+ (is (= "'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'" (->
activation-info :args second :type)))
+ (is (false? (-> activation-info :args second :optional?)))
+ (is (nil? (-> activation-info :args second :default)))
+ (is (string? (-> activation-info :args second :description)))))
+
+ (testing "argmin"
+ (let [argmin-info (gen/gen-op-info "argmin")]
+ (is (= "argmin" (:fn-name argmin-info)))
+ (is (= 3 (-> argmin-info :args count)))
+
+ (is (= "data" (-> argmin-info :args (nth 0) :name)))
+ (is (= "NDArray-or-Symbol" (-> argmin-info :args (nth 0) :type)))
+ (is (false? (-> argmin-info :args (nth 0) :optional?)))
+
+ (is (= "axis" (-> argmin-info :args (nth 1) :name)))
+ (is (= "int or None" (-> argmin-info :args (nth 1) :type)))
+ (is (= "'None'" (-> argmin-info :args (nth 1) :default)))
+ (is (true? (-> argmin-info :args (nth 1) :optional?)))
+
+ (is (= "keepdims" (-> argmin-info :args (nth 2) :name)))
+ (is (= "boolean" (-> argmin-info :args (nth 2) :type)))
+ (is (= "0" (-> argmin-info :args (nth 2) :default)))
+ (is (true? (-> argmin-info :args (nth 2) :optional?)))))
+
+ (testing "concat"
+ (let [concat-info (gen/gen-op-info "Concat")]
+ (is (= "concat" (:fn-name concat-info)))
+ (is (= 3 (-> concat-info :args count)))
+ (is (= "num-args" (:key-var-num-args concat-info)))
+
+ (is (= "data" (-> concat-info :args (nth 0) :name)))
+ (is (= "NDArray-or-Symbol[]" (-> concat-info :args (nth 0) :type)))
+ (is (false? (-> concat-info :args (nth 0) :optional?)))
+
+ (is (= "num-args" (-> concat-info :args (nth 1) :name)))
+ (is (= "int" (-> concat-info :args (nth 1) :type)))
+ (is (false? (-> concat-info :args (nth 1) :optional?)))
+
+ (is (= "dim" (-> concat-info :args (nth 2) :name)))
+ (is (= "int" (-> concat-info :args (nth 2) :type)))
+ (is (= "'1'" (-> concat-info :args (nth 2) :default)))
+ (is (true? (-> concat-info :args (nth 2) :optional?)))))
+
+ (testing "convolution"
+ (let [convolution-info (gen/gen-op-info "Convolution")]
+
+ (is (= "convolution" (:fn-name convolution-info)))
+ (is (= 14 (-> convolution-info :args count)))
+ (is (= "" (:key-var-num-args convolution-info)))
+
+ (is (= "data" (-> convolution-info :args (nth 0) :name)))
+ (is (= "NDArray-or-Symbol" (-> convolution-info :args (nth 0) :type)))
+ (is (false? (-> convolution-info :args (nth 0) :optional?)))
+
+ (is (= "weight" (-> convolution-info :args (nth 1) :name)))
+ (is (= "NDArray-or-Symbol" (-> convolution-info :args (nth 1) :type)))
+ (is (false? (-> convolution-info :args (nth 1) :optional?)))
+
+ (is (= "kernel" (-> convolution-info :args (nth 3) :name)))
+ (is (= "Shape" (-> convolution-info :args (nth 3) :type)))
+ (is (= "(tuple)" (-> convolution-info :args (nth 3) :spec)))
+ (is (false? (-> convolution-info :args (nth 3) :optional?)))
+
+ (is (= "stride" (-> convolution-info :args (nth 4) :name)))
+ (is (= "Shape" (-> convolution-info :args (nth 4) :type)))
+ (is (= "(tuple)" (-> convolution-info :args (nth 4) :spec)))
+ (is (= "[]" (-> convolution-info :args (nth 4) :default)))
+ (is (true? (-> convolution-info :args (nth 4) :optional?)))
+
+ (is (= "num-filter" (-> convolution-info :args (nth 7) :name)))
+ (is (= "int" (-> convolution-info :args (nth 7) :type)))
+ (is (= "(non-negative)" (-> convolution-info :args (nth 7) :spec)))
+ (is (false? (-> convolution-info :args (nth 7) :optional?)))
+
+ (is (= "num-group" (-> convolution-info :args (nth 8) :name)))
+ (is (= "int" (-> convolution-info :args (nth 8) :type)))
+ (is (= "(non-negative)" (-> convolution-info :args (nth 8) :spec)))
+ (is (= "1" (-> convolution-info :args (nth 8) :default)))
+ (is (true? (-> convolution-info :args (nth 8) :optional?)))
+
+ (is (= "workspace" (-> convolution-info :args (nth 9) :name)))
+ (is (= "long" (-> convolution-info :args (nth 9) :type)))
+ (is (= "(non-negative)" (-> convolution-info :args (nth 9) :spec)))
+ (is (= "1024" (-> convolution-info :args (nth 9) :default)))
+ (is (true? (-> convolution-info :args (nth 9) :optional?)))
+
+ (is (= "no-bias" (-> convolution-info :args (nth 10) :name)))
+ (is (= "boolean" (-> convolution-info :args (nth 10) :type)))
+ (is (= "0" (-> convolution-info :args (nth 10) :default)))
+ (is (true? (-> convolution-info :args (nth 10) :optional?)))
+
+ (is (= "layout" (-> convolution-info :args (nth 13) :name)))
+ (is (= "None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'" (->
convolution-info :args (nth 13) :type)))
+ (is (= "'None'" (-> convolution-info :args (nth 13) :default)))
+ (is (true? (-> convolution-info :args (nth 13) :optional?)))))
+
+ (testing "element wise sum"
+ (let [element-wise-sum-info (gen/gen-op-info "ElementWiseSum")]
+ (is (= "add-n" (:fn-name element-wise-sum-info)))
+ (is (= 1 (-> element-wise-sum-info :args count)))
+ (is (= "num-args" (:key-var-num-args element-wise-sum-info)))
+
+ (is (= "args" (-> element-wise-sum-info :args (nth 0) :name)))
+ (is (= "NDArray-or-Symbol[]" (-> element-wise-sum-info :args (nth 0)
:type)))
+ (is (false? (-> element-wise-sum-info :args (nth 0) :optional?))))))
+
(deftest test-ndarray-transform-param-name
(let [params ["scala.collection.immutable.Map"
"scala.collection.Seq"]
@@ -68,7 +189,10 @@
(deftest test-rename-duplicate-params
+  ;; A repeated name gets a "-N" suffix, where N counts its earlier
+  ;; occurrences; first occurrences and unique names pass through unchanged.
(is (= ["foo" "bar" "baz"] (gen/rename-duplicate-params ["foo" "bar"
"baz"])))
- (is (= ["foo" "bar" "bar-1"] (gen/rename-duplicate-params ["foo" "bar"
"bar"]))))
+ (is (= ["foo" "bar" "bar-1"] (gen/rename-duplicate-params ["foo" "bar"
"bar"])))
+ (is (= ["foo" "bar" "bar-1" "foo-1"] (gen/rename-duplicate-params ["foo"
"bar" "bar" "foo"])))
+ (is (= ["foo" "bar" "bar-1" "bar-2"] (gen/rename-duplicate-params ["foo"
"bar" "bar" "bar"])))
+ (is (= ["foo" "bar" "bar-1" "bar-2" "foo-1" "baz"]
(gen/rename-duplicate-params ["foo" "bar" "bar" "bar" "foo" "baz"]))))
(deftest test-is-symbol-hand-gen?
(is (not (false? (gen/is-symbol-hand-gen? (symbol-reflect-info "max")))))
@@ -191,7 +315,17 @@
(gen/gen-ndarray-function-arity op-name op-values)))))
(deftest test-write-to-file
- (testing "symbol"
+ (testing "symbol-api"
+ (let [fname "test/test-symbol-api.clj"
+ _ (gen/write-to-file [(first gen/all-symbol-api-functions)
+ (second gen/all-symbol-api-functions)]
+ gen/symbol-api-gen-ns
+ fname)
+ good-contents (slurp "test/good-test-symbol-api.clj")
+ contents (slurp fname)]
+ (is (= good-contents contents))))
+
+ (testing "symbol"
(let [fname "test/test-symbol.clj"
_ (gen/write-to-file [(first gen/all-symbol-functions)]
gen/symbol-gen-ns
@@ -200,6 +334,16 @@
contents (slurp fname)]
(is (= good-contents contents))))
+ (testing "ndarray-api"
+ (let [fname "test/test-ndarray-api.clj"
+ _ (gen/write-to-file [(first gen/all-ndarray-api-functions)
+ (second gen/all-ndarray-api-functions)]
+ gen/ndarray-api-gen-ns
+ fname)
+ good-contents (slurp "test/good-test-ndarray-api.clj")
+ contents (slurp fname)]
+ (is (= good-contents contents))))
+
(testing "ndarray"
(let [fname "test/test-ndarray.clj"
_ (gen/write-to-file [(first gen/all-ndarray-functions)]
diff --git a/contrib/clojure-package/test/good-test-ndarray-api.clj
b/contrib/clojure-package/test/good-test-ndarray-api.clj
new file mode 100644
index 0000000..1b83a7b
--- /dev/null
+++ b/contrib/clojure-package/test/good-test-ndarray-api.clj
@@ -0,0 +1,89 @@
+(ns
+ ^{:doc "Experimental"}
+ org.apache.clojure-mxnet.ndarray-api
+ (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity
load max
+ min repeat reverse set sort take to-array empty
shuffle
+ ref])
+ (:require [org.apache.clojure-mxnet.shape :as mx-shape]
+ [org.apache.clojure-mxnet.util :as util])
+ (:import (org.apache.mxnet NDArrayAPI)))
+
+;; Do not edit - this is auto-generated
+
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+
+
+
+(defn
+ activation
+ "Applies an activation function element-wise to the input.\n\nThe following
activation functions are supported:\n\n- `relu`: Rectified Linear Unit,
:math:`y = max(x, 0)`\n- `sigmoid`: :math:`y = \\frac{1}{1 + exp(-x)}`\n-
`tanh`: Hyperbolic tangent, :math:`y = \\frac{exp(x) - exp(-x)}{exp(x) +
exp(-x)}`\n- `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`\n-
`softsign`: :math:`y = \\frac{x}{1 + abs(x)}`\n\n\n\nDefined in
src/operator/nn/activation.cc:L167\n\n`data`: The in [...]
+ ([data act-type] (activation {:data data, :act-type act-type}))
+ ([{:keys [data act-type out], :or {out nil}, :as opts}]
+ (util/coerce-return
+ (NDArrayAPI/Activation data act-type (util/->option out)))))
+
+(defn
+ batch-norm
+ "Batch normalization.\n\nNormalizes a data batch by mean and variance, and
applies a scale ``gamma`` as\nwell as offset ``beta``.\n\nAssume the input has
more than one dimension and we normalize along axis 1.\nWe first compute the
mean and variance along this axis:\n\n.. math::\n\n data\\_mean[i] =
mean(data[:,i,:,...]) \\\\\n data\\_var[i] = var(data[:,i,:,...])\n\nThen
compute the normalized output, which has the same shape as input, as
following:\n\n.. math::\n\n out[:,i,:,...] = [...]
+ ([data gamma beta moving-mean moving-var]
+ (batch-norm
+ {:data data,
+ :gamma gamma,
+ :beta beta,
+ :moving-mean moving-mean,
+ :moving-var moving-var}))
+ ([{:keys
+ [data
+ gamma
+ beta
+ moving-mean
+ moving-var
+ eps
+ momentum
+ fix-gamma
+ use-global-stats
+ output-mean-var
+ axis
+ cudnn-off
+ out],
+ :or
+ {eps nil,
+ momentum nil,
+ fix-gamma nil,
+ use-global-stats nil,
+ output-mean-var nil,
+ axis nil,
+ cudnn-off nil,
+ out nil},
+ :as opts}]
+ (util/coerce-return
+ (NDArrayAPI/BatchNorm
+ data
+ gamma
+ beta
+ moving-mean
+ moving-var
+ (util/->option eps)
+ (util/->option momentum)
+ (util/->option fix-gamma)
+ (util/->option use-global-stats)
+ (util/->option output-mean-var)
+ (util/->option axis)
+ (util/->option cudnn-off)
+ (util/->option out)))))
+
diff --git a/contrib/clojure-package/test/good-test-symbol-api.clj
b/contrib/clojure-package/test/good-test-symbol-api.clj
new file mode 100644
index 0000000..a030884
--- /dev/null
+++ b/contrib/clojure-package/test/good-test-symbol-api.clj
@@ -0,0 +1,109 @@
+(ns
+ ^{:doc "Experimental"}
+ org.apache.clojure-mxnet.symbol-api
+ (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten
load max
+ min repeat reverse set sort take to-array empty sin
+ get apply shuffle ref])
+ (:require [org.apache.clojure-mxnet.util :as util]
+ [org.apache.clojure-mxnet.shape :as mx-shape])
+ (:import (org.apache.mxnet SymbolAPI)))
+
+;; Do not edit - this is auto-generated
+
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+
+
+
+(defn
+ activation
+ "Applies an activation function element-wise to the input.\n\nThe following
activation functions are supported:\n\n- `relu`: Rectified Linear Unit,
:math:`y = max(x, 0)`\n- `sigmoid`: :math:`y = \\frac{1}{1 + exp(-x)}`\n-
`tanh`: Hyperbolic tangent, :math:`y = \\frac{exp(x) - exp(-x)}{exp(x) +
exp(-x)}`\n- `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`\n-
`softsign`: :math:`y = \\frac{x}{1 + abs(x)}`\n\n\n\nDefined in
src/operator/nn/activation.cc:L167\n\n`data`: The in [...]
+ [{:keys [data act-type name attr],
+ :or {data nil, name nil, attr nil},
+ :as opts}]
+ (util/coerce-return
+ (SymbolAPI/Activation
+ (util/->option data)
+ act-type
+ name
+ (clojure.core/when
+ attr
+ (clojure.core/->>
+ attr
+ (clojure.core/mapv
+ (clojure.core/fn [[k v]] [k (clojure.core/str v)]))
+ (clojure.core/into {})
+ util/convert-map)))))
+
+(defn
+ batch-norm
+ "Batch normalization.\n\nNormalizes a data batch by mean and variance, and
applies a scale ``gamma`` as\nwell as offset ``beta``.\n\nAssume the input has
more than one dimension and we normalize along axis 1.\nWe first compute the
mean and variance along this axis:\n\n.. math::\n\n data\\_mean[i] =
mean(data[:,i,:,...]) \\\\\n data\\_var[i] = var(data[:,i,:,...])\n\nThen
compute the normalized output, which has the same shape as input, as
following:\n\n.. math::\n\n out[:,i,:,...] = [...]
+ [{:keys
+ [data
+ gamma
+ beta
+ moving-mean
+ moving-var
+ eps
+ momentum
+ fix-gamma
+ use-global-stats
+ output-mean-var
+ axis
+ cudnn-off
+ name
+ attr],
+ :or
+ {output-mean-var nil,
+ axis nil,
+ cudnn-off nil,
+ fix-gamma nil,
+ eps nil,
+ data nil,
+ attr nil,
+ beta nil,
+ name nil,
+ use-global-stats nil,
+ moving-mean nil,
+ moving-var nil,
+ momentum nil,
+ gamma nil},
+ :as opts}]
+ (util/coerce-return
+ (SymbolAPI/BatchNorm
+ (util/->option data)
+ (util/->option gamma)
+ (util/->option beta)
+ (util/->option moving-mean)
+ (util/->option moving-var)
+ (util/->option eps)
+ (util/->option momentum)
+ (util/->option fix-gamma)
+ (util/->option use-global-stats)
+ (util/->option output-mean-var)
+ (util/->option axis)
+ (util/->option cudnn-off)
+ name
+ (clojure.core/when
+ attr
+ (clojure.core/->>
+ attr
+ (clojure.core/mapv
+ (clojure.core/fn [[k v]] [k (clojure.core/str v)]))
+ (clojure.core/into {})
+ util/convert-map)))))
+
diff --git
a/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj
b/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj
index feda45b..ca9d4bc 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj
@@ -24,6 +24,8 @@
[org.apache.clojure-mxnet.module :as m]
[org.apache.clojure-mxnet.optimizer :as optimizer]
[org.apache.clojure-mxnet.symbol :as sym]
+ [org.apache.clojure-mxnet.symbol-api :as sym-api]
+ [org.apache.clojure-mxnet.util :as util]
[clojure.reflect :as r]))
+;; NOTE(review): relative path "data/"; presumably the directory this
+;; namespace's tests read their data from (its uses are outside this hunk).
(def data-dir "data/")
@@ -54,17 +56,19 @@
(defn get-symbol []
+  ;; Network used by test-conv: two convolution -> batch-norm -> relu ->
+  ;; max-pool stages, then flatten, a 10-unit fully-connected layer and a
+  ;; softmax output. Layers are built with symbol-api, where each layer's
+  ;; name is passed as :name inside the opts map.
(as-> (sym/variable "data") data
- (sym/convolution "conv1" {:data data :kernel [3 3] :num-filter 32 :stride
[2 2]})
- (sym/batch-norm "bn1" {:data data})
- (sym/activation "relu1" {:data data :act-type "relu"})
- (sym/pooling "mp1" {:data data :kernel [2 2] :pool-type "max" :stride [2
2]}) (sym/convolution "conv2" {:data data :kernel [3 3] :num-filter 32 :stride
[2 2]})
- (sym/batch-norm "bn2" {:data data})
- (sym/activation "relu2" {:data data :act-type "relu"})
- (sym/pooling "mp2" {:data data :kernel [2 2] :pool-type "max" :stride [2
2]})
+ (sym-api/convolution {:name "conv1" :data data :kernel [3 3] :num-filter
32 :stride [2 2]})
+ (sym-api/batch-norm {:name "bn1" :data data})
+ (sym-api/activation {:name "relu1" :data data :act-type "relu"})
+ (sym-api/pooling {:name "mp1" :data data :kernel [2 2] :pool-type "max"
:stride [2 2]})
- (sym/flatten "fl" {:data data})
- (sym/fully-connected "fc2" {:data data :num-hidden 10})
- (sym/softmax-output "softmax" {:data data})))
+ (sym-api/convolution {:name "conv2" :data data :kernel [3 3] :num-filter
32 :stride [2 2]})
+ (sym-api/batch-norm {:name "bn2" :data data})
+ (sym-api/activation {:name "relu2" :data data :act-type "relu"})
+ (sym-api/pooling {:name "mp2" :data data :kernel [2 2] :pool-type "max"
:stride [2 2]})
+
+  ;; classifier head
+ (sym-api/flatten {:name "fl" :data data})
+ (sym-api/fully-connected {:name "fc2" :data data :num-hidden 10})
+ (sym-api/softmax-output {:name "softmax" :data data})))
(deftest test-conv []
(let [mod (m/module (get-symbol))]
diff --git
a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_api_test.clj
b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_api_test.clj
new file mode 100644
index 0000000..18b8b78
--- /dev/null
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_api_test.clj
@@ -0,0 +1,415 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.ndarray-api-test
+ (:require [org.apache.clojure-mxnet.base :as base]
+ [org.apache.clojure-mxnet.context :as ctx]
+ [org.apache.clojure-mxnet.dtype :as dtype]
+ [org.apache.clojure-mxnet.ndarray :as ndarray :refer [->vec zeros
ones += -= *= full shape shape-vec]]
+ [org.apache.clojure-mxnet.ndarray-api :as ndarray-api]
+ [org.apache.clojure-mxnet.shape :as mx-shape :refer [->shape]]
+ [org.apache.clojure-mxnet.test-util :as test-util :refer [approx=]]
+ [org.apache.clojure-mxnet.util :as util :refer [->option]]
+ [clojure.test :refer :all]))
+
+(deftest test-activation
+  ;; Exercise the positional [data act-type] arity for several activation
+  ;; types, and the opts-map arity writing into a preallocated :out.
+  (let [input    (ndarray/array [2 1 0 -1 -2] [1 5])
+        act      (fn [kind] (ndarray-api/activation input kind))
+        relu     (act "relu")
+        sigmoid  (act "sigmoid")
+        softsign (act "softsign")
+        dest     (ndarray/zeros [1 5])]
+    (ndarray-api/activation {:data input :act-type "relu" :out dest})
+    (is (= [2.0 1.0 0.0 0.0 0.0] (->vec relu)))
+    (is (approx= 1e-3 [0.881 0.731 0.5 0.269 0.119] (->vec sigmoid)))
+    (is (approx= 1e-3 [0.666 0.5 0.0 -0.5 -0.666] (->vec softsign)))
+    (is (= [2.0 1.0 0.0 0.0 0.0] (->vec dest)))))
+
+(deftest test-bilinear-sampler
+  ;; Sample a 4x4 input through an affine grid built by grid-generator.
+  ;; The expected output is zero everywhere except the four centre cells.
+  (let [input    (ndarray/array [1 4 3 6
+                                 1 8 8 9
+                                 0 4 1 5
+                                 1 0 1 3]
+                                [1 1 4 4])
+        theta    (ndarray/array [2 0 0 0 2 0] [1 6])
+        grid     (ndarray-api/grid-generator {:data theta
+                                              :transform-type "affine"
+                                              :target-shape [4 4]})
+        sampled  (ndarray-api/bilinear-sampler input grid)
+        expected [0.0 0.0  0.0 0.0
+                  0.0 3.5  6.5 0.0
+                  0.0 1.25 2.5 0.0
+                  0.0 0.0  0.0 0.0]]
+    (is (approx= 1e-3 expected (->vec sampled)))))
+
+(deftest test-cast
+  ;; Cast with both the positional [data dtype] arity and the opts-map arity
+  ;; that writes into a preallocated :out array.
+  (let [nda1 (ndarray/array [0.9 1.3] [2])
+        nda2 (ndarray/array [1e20 11.1] [2])
+        nda3 (ndarray/array [300 11.1 10.9 -1 -3] [5])
+        out (ndarray/zeros [2] {:dtype dtype/INT32})
+        _ (ndarray-api/cast {:data nda1 :dtype (str dtype/INT32) :out out})]
+    (is (= [0.0 1.0] (->vec (ndarray-api/cast nda1 (str dtype/INT32)))))
+    ;; The :out destination must hold the same result as the positional call;
+    ;; without this check the :out arity was exercised but never verified.
+    (is (= [0.0 1.0] (->vec out)))
+    (is (= [(float 1e20) (float 11.1)]
+           (->vec (ndarray-api/cast nda2 (str dtype/FLOAT32)))))
+    ;; uint8 gets converted to native types after ->vec
+    (is (= [44.0 11.0 10.0 -1.0 -3.0] (->vec (ndarray-api/cast nda3 "uint8"))))))
+
+(deftest test-concat
+  ;; Join a 1x2 block of zeros with a 1x2 block of ones along different
+  ;; dimensions, including writing the result into a preallocated :out.
+  (let [zs          (ndarray/zeros [1 2])
+        os          (ndarray/ones [1 2])
+        dest        (ndarray/zeros [1 4])
+        ;; positional arity: num-args = 2, dim left at its default (1)
+        default-dim (ndarray-api/concat [zs os] 2)
+        dim-0       (ndarray-api/concat {:data [zs os] :num-args 2 :dim 0})
+        three-way   (ndarray-api/concat {:data [zs os zs] :num-args 3 :dim 1})]
+    (ndarray-api/concat {:data [zs os] :num-args 2 :dim 1 :out dest})
+    (are [expected actual] (= expected actual)
+      [0.0 0.0 1.0 1.0]         (->vec default-dim)
+      [1 4]                     (shape-vec default-dim)
+      [0.0 0.0 1.0 1.0]         (->vec dim-0)
+      [2 2]                     (shape-vec dim-0)
+      [0.0 0.0 1.0 1.0 0.0 0.0] (->vec three-way)
+      [1 6]                     (shape-vec three-way)
+      [0.0 0.0 1.0 1.0]         (->vec dest)
+      [1 4]                     (shape-vec dest))))
+
+(deftest test-embedding
+  ;; The weight matrix holds 0..19 row-major, so row i is
+  ;; [5i .. 5i+4]; each index in the 2x2 input selects one such row.
+  (let [input-dim  4
+        output-dim 5
+        weight     (ndarray/array (mapv double (range 20)) [input-dim output-dim])
+        indices    (ndarray/array [1. 3. 0. 2.] [2 2])
+        looked-up  (ndarray-api/embedding indices weight input-dim output-dim)
+        row        (fn [i] (mapv double (range (* i output-dim) (* (inc i) output-dim))))]
+    (is (= (vec (mapcat row [1 3 0 2])) (->vec looked-up)))
+    (is (= [2 2 output-dim] (shape-vec looked-up)))))
+
+(deftest test-flatten
+  ;; Flatten a 2x3x3 array to 2x9, both as a fresh result and into :out.
+  (let [block    [1 2 3 4 5 6 7 8 9]
+        nda      (ndarray/array (into block block) [2 3 3])
+        dest     (ndarray/zeros [2 9])
+        flat     (ndarray-api/flatten {:data nda})
+        _        (ndarray-api/flatten {:data nda :out dest})
+        expected (mapv double (into block block))]
+    (doseq [result [flat dest]]
+      (is (= expected (->vec result)))
+      (is (= [2 9] (shape-vec result))))))
+
+(deftest test-instance-norm
+  ;; Normalise a 2x1x2 input with scalar gamma = 1.5 and beta = 0.5.
+  (let [result (ndarray-api/instance-norm (ndarray/array [1.1 2.2 3.3 4.4] [2 1 2])
+                                          (ndarray/array [1.5] [1])
+                                          (ndarray/array [0.5] [1]))]
+    (is (approx= 1e-4 [-0.9975 1.9975 -0.9975 1.9975] (->vec result)))
+    (is (= [2 1 2] (shape-vec result)))))
+
+(deftest test-l2-normalization
+  ;; Normalise the same 2x2x2 input in each mode; omitting :mode gives the
+  ;; same result as "instance".
+  (let [x                 (ndarray/array [1 2 3 4 2 2 5 6] [2 2 2])
+        normalize         (fn [opts] (ndarray-api/l2-normalization (merge {:data x} opts)))
+        by-default        (normalize {})
+        by-instance       (normalize {:mode "instance"})
+        by-channel        (normalize {:mode "channel"})
+        by-spatial        (normalize {:mode "spatial"})
+        instance-expected [0.1825 0.3651
+                           0.5477 0.7303
+                           0.2407 0.2407
+                           0.6019 0.7223]]
+    (are [expected result] (approx= 1e-4 expected (->vec result))
+      instance-expected by-default
+      instance-expected by-instance
+      [0.3162 0.4472 0.9486 0.8944 0.3714 0.3162 0.9284 0.9486] by-channel
+      [0.4472 0.8944 0.6 0.8 0.7071 0.7071 0.6402 0.7682] by-spatial)))
+
+(deftest test-pad
+  ;; Pad the last two axes of a 2x2x2x3 input by one element on each side
+  ;; (shape becomes 2x2x4x5), once replicating edges and once with zeros.
+  (let [input     (ndarray/array [1 2 3 4 5 6 7 8 9 10 11 12
+                                  11 12 13 14 15 16 17 18 19 20 21 22]
+                                 [2 2 2 3])
+        pad-width [0,0,0,0,1,1,1,1]
+        edge      (ndarray-api/pad input "edge" pad-width)
+        constant  (ndarray-api/pad {:data input
+                                    :mode "constant"
+                                    :pad-width pad-width
+                                    :constant-value 0})]
+    (is (= [1. 1. 2. 3. 3.
+            1. 1. 2. 3. 3.
+            4. 4. 5. 6. 6.
+            4. 4. 5. 6. 6.
+            7. 7. 8. 9. 9.
+            7. 7. 8. 9. 9.
+            10. 10. 11. 12. 12.
+            10. 10. 11. 12. 12.
+            11. 11. 12. 13. 13.
+            11. 11. 12. 13. 13.
+            14. 14. 15. 16. 16.
+            14. 14. 15. 16. 16.
+            17. 17. 18. 19. 19.
+            17. 17. 18. 19. 19.
+            20. 20. 21. 22. 22.
+            20. 20. 21. 22. 22.]
+           (->vec edge)))
+    (is (= [2 2 4 5] (shape-vec edge)))
+    (is (= [0. 0. 0. 0. 0.
+            0. 1. 2. 3. 0.
+            0. 4. 5. 6. 0.
+            0. 0. 0. 0. 0.
+
+            0. 0. 0. 0. 0.
+            0. 7. 8. 9. 0.
+            0. 10. 11. 12. 0.
+            0. 0. 0. 0. 0.
+
+            0. 0. 0. 0. 0.
+            0. 11. 12. 13. 0.
+            0. 14. 15. 16. 0.
+            0. 0. 0. 0. 0.
+
+            0. 0. 0. 0. 0.
+            0. 17. 18. 19. 0.
+            0. 20. 21. 22. 0.
+            0. 0. 0. 0. 0.]
+           (->vec constant)))
+    (is (= [2 2 4 5] (shape-vec constant)))))
+
+(deftest test-roi-pooling
+  ;; Max-pool a single region of interest of an 8x6 feature map (values
+  ;; 0..47 row-major) down to 2x2, at spatial scales 1.0 and 0.7.
+  (let [feature-map (ndarray/array (mapv double (range 48)) [1 1 8 6])
+        rois        (ndarray/array [0 0 0 4 4] [1 5])
+        pool        (fn [scale] (ndarray-api/roi-pooling feature-map rois [2 2] scale))
+        full-scale  (pool 1.0)
+        reduced     (pool 0.7)]
+    (is (= [14. 16. 26. 28.] (->vec full-scale)))
+    (is (= [1 1 2 2] (shape-vec full-scale)))
+    (is (= [7. 9. 19. 21.] (->vec reduced)))
+    (is (= [1 1 2 2] (shape-vec reduced)))))
+
+(deftest test-reshape
+  ;; A plain reshape of a 1-D array, then target shapes that use the
+  ;; special codes 0, -1, -2, -3 and -4 on 3-D and 4-D inputs.
+  (let [x       (ndarray/array (vec (range 4)) [4])
+        y       (ndarray/array (vec (range 24)) [2 3 4])
+        z       (ndarray/array (vec (range 120)) [2 3 4 5])
+        reshape (fn [nda target] (ndarray-api/reshape {:data nda :shape target}))
+        res1    (reshape x [2 2])]
+    (is (= [0. 1. 2. 3.] (->vec res1)))
+    (is (= [2 2] (shape-vec res1)))
+    (is (= (map float (range 24)) (->vec (reshape y [4 0 2]))))
+    (are [nda target expected] (= expected (shape-vec (reshape nda target)))
+      y [4 0 2]        [4 3 2]
+      y [2 0 0]        [2 3 4]
+      y [6 1 -1]       [6 1 4]
+      y [3 -1 8]       [3 1 8]
+      y [-1]           [24]
+      y [-2]           [2 3 4]
+      y [2 -2]         [2 3 4]
+      y [-2 1 1]       [2 3 4 1 1]
+      y [-3 4]         [6 4]
+      z [-3 -3]        [6 20]
+      y [0 -3]         [2 12]
+      y [-3 -2]        [6 4]
+      y [-4 1 2 -2]    [1 2 3 4]
+      y [2 -4 -1 3 -2] [2 1 3 4])))
+
+(deftest test-sequence-last
+  ;; Disabled: calling (ndarray-api/sequence-last x nil) currently throws
+  ;; (most likely a scala generation issue), so this test asserts nothing and
+  ;; only builds its fixtures. Restore the check once the call works:
+  ;;   (is (= [] (->vec (ndarray-api/sequence-last x nil))))
+  (let [x        (ndarray/array (mapv double (range 1 28)) [3 3 3])
+        seq-len1 (ndarray/array [1 1 1] [3])
+        seq-len2 (ndarray/array [1 2 3] [3])]))
+
+(deftest test-sequence-mask
+  ;; Disabled: the sequence-mask call below was reported to fail the same way
+  ;; as sequence-last (an exception, most likely a scala generation issue),
+  ;; so this test asserts nothing and only builds its fixtures:
+  ;;   (is (= [] (->vec (ndarray-api/sequence-mask x seq-len1))))
+  (let [x        (ndarray/array (mapv double (range 1 19)) [3 2 3])
+        seq-len1 (ndarray/array [1 1] [2])
+        seq-len2 (ndarray/array [2 3] [2])]))
+
+(deftest test-slice-channel
+  ;; Split a 3x2x1 array into equal parts; the values checked correspond to
+  ;; the first slice of each split.
+  (let [x               (ndarray/array [1. 2. 3. 4. 5. 6.] [3 2 1])
+        slice           (fn [opts] (ndarray-api/slice-channel (assoc opts :data x)))
+        halves          (slice {:num-outputs 2 :axis 1})
+        thirds          (slice {:num-outputs 3 :axis 0})
+        thirds-squeezed (slice {:num-outputs 3 :axis 0 :squeeze-axis 1})]
+    (are [expected actual] (= expected actual)
+      [1. 3. 5.] (->vec halves)
+      [3 1 1]    (shape-vec halves)
+      [1. 2.]    (->vec thirds)
+      [1 2 1]    (shape-vec thirds)
+      [1. 2.]    (->vec thirds-squeezed)
+      [2 1]      (shape-vec thirds-squeezed))))
+
+(deftest test-softmax-activation
+  ;; Equal inputs across each row of three give 1/3 per entry.
+  (let [ones-2x3 (ndarray/array (vec (repeat 6 1)) [2 3])
+        result   (ndarray-api/softmax-activation {:data ones-2x3 :mode "instance"})]
+    (is (approx= 1e-3 (vec (repeat 6 0.333)) (->vec result)))
+    (is (= [2 3] (shape-vec result)))))
+
+(deftest test-softmax-output
+  ;; Only the first row has distinct scores; the rows of equal scores come
+  ;; out as a uniform 0.25 distribution.
+  (let [scores (ndarray/array [1 2 3 4
+                               2 2 2 2
+                               3 3 3 3
+                               4 4 4 4]
+                              [4 4])
+        labels (ndarray/array [1 0 2 3] [4])
+        probs  (ndarray-api/softmax-output scores labels)]
+    (is (approx= 1e-4 [0.0321 0.0871 0.2369 0.6439
+                       0.25 0.25 0.25 0.25
+                       0.25 0.25 0.25 0.25
+                       0.25 0.25 0.25 0.25]
+                 (->vec probs)))
+    (is (= [4 4] (shape-vec probs)))))
+
+(deftest test-swap-axis
+  ;; Swap the two axes of a 1x3 row, then the outer and inner axes of a cube.
+  (let [swap       (fn [nda d1 d2] (ndarray-api/swap-axis {:data nda :dim1 d1 :dim2 d2}))
+        row        (ndarray/array (range 3) [1 3])
+        cube       (ndarray/array (range 8) [2 2 2])
+        transposed (swap row 0 1)
+        swapped    (swap cube 0 2)]
+    (is (= [0. 1. 2.] (->vec transposed)))
+    (is (= [3 1] (shape-vec transposed)))
+    (is (= [0. 4. 2. 6. 1. 5. 3. 7.] (->vec swapped)))
+    (is (= [2 2 2] (shape-vec swapped)))))
+
+(deftest test-abs
+  (let [result (ndarray-api/abs {:data (ndarray/array [-2 0 3] [3])})]
+    (is (= [2. 0. 3.] (->vec result)))
+    (is (= [3] (shape-vec result)))))
+
+(deftest test-arccos
+  ;; arccos of -1, -sqrt(2)/2, 0, sqrt(2)/2, 1 is pi, 3pi/4, pi/2, pi/4, 0.
+  (let [result   (ndarray-api/arccos {:data (ndarray/array [-1 -0.707 0 0.707 1] [5])})
+        expected (mapv #(* % Math/PI) [1 0.75 0.5 0.25 0])]
+    (is (approx= 1e-3 expected (->vec result)))))
+
+(deftest test-arcsin
+  ;; arcsin of -1, -sqrt(2)/2, 0, sqrt(2)/2, 1 is -pi/2, -pi/4, 0, pi/4, pi/2.
+  (let [result   (ndarray-api/arcsin {:data (ndarray/array [-1 -0.707 0 0.707 1] [5])})
+        expected (mapv #(* % Math/PI) [-0.5 -0.25 0 0.25 0.5])]
+    (is (approx= 1e-3 expected (->vec result)))))
+
+(deftest test-argmax
+  ;; x is 0..5 laid out as 2x3, so each maximum sits in the last row/column.
+  (let [x            (ndarray/array (range 6) [2 3])
+        argmax-of    (fn [opts] (ndarray-api/argmax (assoc opts :data x)))
+        along-0      (argmax-of {:axis 0})
+        along-1      (argmax-of {:axis 1})
+        along-0-keep (argmax-of {:axis 0 :keepdims true})
+        along-1-keep (argmax-of {:axis 1 :keepdims true})]
+    (are [expected actual] (= expected actual)
+      [1. 1. 1.] (->vec along-0)
+      [3]        (shape-vec along-0)
+      [2. 2.]    (->vec along-1)
+      [2]        (shape-vec along-1)
+      [1. 1. 1.] (->vec along-0-keep)
+      [1 3]      (shape-vec along-0-keep)
+      [2. 2.]    (->vec along-1-keep)
+      [2 1]      (shape-vec along-1-keep))))
+
+(deftest test-argmax-channel
+  (let [result (ndarray-api/argmax-channel {:data (ndarray/array (range 6) [2 3])})]
+    (is (= [2. 2.] (->vec result)))
+    (is (= [2] (shape-vec result)))))
+
+(deftest test-argmin
+  ;; x is 5..0 laid out as 2x3, so each minimum sits in the last row/column.
+  (let [x            (ndarray/array (reverse (range 6)) [2 3])
+        argmin-of    (fn [opts] (ndarray-api/argmin (assoc opts :data x)))
+        along-0      (argmin-of {:axis 0})
+        along-1      (argmin-of {:axis 1})
+        along-0-keep (argmin-of {:axis 0 :keepdims true})
+        along-1-keep (argmin-of {:axis 1 :keepdims true})]
+    (are [expected actual] (= expected actual)
+      [1. 1. 1.] (->vec along-0)
+      [3]        (shape-vec along-0)
+      [2. 2.]    (->vec along-1)
+      [2]        (shape-vec along-1)
+      [1. 1. 1.] (->vec along-0-keep)
+      [1 3]      (shape-vec along-0-keep)
+      [2. 2.]    (->vec along-1-keep)
+      [2 1]      (shape-vec along-1-keep))))
+
+(deftest test-argsort
+  ;; Sort indices along the default (last) axis of a 2x3 array, along axis 0,
+  ;; and for a 1-D array.
+  (let [x (ndarray/array [0.3 0.2 0.4
+                          0.1 0.3 0.2]
+                         [2 3])
+        y (ndarray/array [0.3 0.2 0.4 0.1 0.3 0.2] [6])
+        res1 (ndarray-api/argsort {:data x})
+        res2 (ndarray-api/argsort {:data x :axis 0})
+        res3 (ndarray-api/argsort {:data y})]
+    (is (= [1. 0. 2.
+            0. 2. 1.]
+           (->vec res1)))
+    (is (= [2 3] (shape-vec res1)))
+    (is (= [1. 0. 1.
+            0. 1. 0.]
+           (->vec res2)))
+    ;; was (shape-vec res1) again, a copy-paste slip that left res2's
+    ;; shape unchecked
+    (is (= [2 3] (shape-vec res2)))
+    (is (= [3. 1. 5. 0. 4. 2.] (->vec res3)))
+    (is (= [6] (shape-vec res3)))))
+
+(deftest test-batch-take
+  ;; x is [[0 1] [2 3] [4 5]]; the int32 indices pick column 0, 1, 0 of
+  ;; successive rows.
+  (let [x       (ndarray/array (range 6) [3 2])
+        indices (-> (ndarray/array [0 1 0] [3])
+                    (ndarray/as-type dtype/INT32))
+        picked  (ndarray-api/batch-take x indices)]
+    (is (= [0. 3. 4.] (->vec picked)))))
+
+(deftest test-broadcast-add
+  ;; A 2x1 column [0 1] is broadcast across a 2x3 array of ones.
+  (let [lhs   (ndarray/ones [2 3])
+        rhs   (ndarray/array (range 2) [2 1])
+        total (ndarray-api/broadcast-add lhs rhs)]
+    (is (= [1. 1. 1. 2. 2. 2.] (->vec total)))
+    (is (= [2 3] (shape-vec total)))))
diff --git
a/contrib/clojure-package/test/org/apache/clojure_mxnet/symbol_api_test.clj
b/contrib/clojure-package/test/org/apache/clojure_mxnet/symbol_api_test.clj
new file mode 100644
index 0000000..b642ad7
--- /dev/null
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/symbol_api_test.clj
@@ -0,0 +1,61 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.symbol-api-test
+ (:require [org.apache.clojure-mxnet.dtype :as dtype]
+ [org.apache.clojure-mxnet.executor :as executor]
+ [org.apache.clojure-mxnet.ndarray :as ndarray]
+ [org.apache.clojure-mxnet.symbol :as sym]
+ [org.apache.clojure-mxnet.symbol-api :as sym-api]
+ [org.apache.clojure-mxnet.util :as util]
+ [clojure.test :refer :all]
+ [org.apache.clojure-mxnet.context :as context]))
+
+(deftest test-compose
+  ;; Build two small networks, splice the first into the second's "fc3_data"
+  ;; input with sym/apply, then group the composed net with the first.
+  (let [data      (sym/variable "data")
+        net1      (as-> data s
+                    (sym-api/fully-connected {:data s :num-hidden 10 :name "fc1"})
+                    (sym-api/fully-connected {:data s :num-hidden 100 :name "fc2"}))
+        net2      (as-> (sym-api/fully-connected {:num-hidden 10 :name "fc3"}) s
+                    (sym-api/activation {:data s :act-type "relu"})
+                    (sym-api/fully-connected {:data s :num-hidden 20 :name "fc4"}))
+        composed  (sym/apply net2 "composed" {"fc3_data" net1})
+        multi-out (sym/group [composed net1])]
+    (is (= ["data" "fc1_weight" "fc1_bias" "fc2_weight" "fc2_bias"]
+           (sym/list-arguments net1)))
+    (is (= 2 (count (sym/list-outputs multi-out))))))
+
+(deftest test-symbol-internal
+  ;; The internals of a two-layer net must expose the first layer's output
+  ;; by name, with the same arguments as the standalone first layer.
+  (let [data  (sym/variable "data")
+        oldfc (sym-api/fully-connected {:data data :num-hidden 10 :name "fc1"})
+        net1  (sym-api/fully-connected {:data oldfc :num-hidden 100 :name "fc2"})]
+    (is (= ["data" "fc1_weight" "fc1_bias" "fc2_weight" "fc2_bias"]
+           (sym/list-arguments net1)))
+    ;; Wrapped in `is`: the bare (= ...) form here previously discarded its
+    ;; result, so this check could never fail.
+    (is (= (sym/list-arguments oldfc)
+           (-> (sym/get-internals net1)
+               (sym/get "fc1_output")
+               (sym/list-arguments))))))
+
+(deftest test-infer-type
+  ;; With float64 input, only the data argument stays float64. The arguments
+  ;; after the float32 cast, and the single output, infer as float32.
+  (let [mlp (as-> (sym/variable "data") s
+              (sym-api/cast {:data s :dtype "float32"})
+              (sym-api/fully-connected {:data s :num-hidden 128 :name "fc1"})
+              (sym-api/softmax-output {:data s :name "softmax"}))
+        [arg-types out-types aux-types] (sym/infer-type mlp {:data dtype/FLOAT64})]
+    (is (= [dtype/FLOAT64 dtype/FLOAT32 dtype/FLOAT32 dtype/FLOAT32]
+           (util/buffer->vec arg-types)))
+    (is (= [dtype/FLOAT32] (util/buffer->vec out-types)))
+    (is (= [] (util/buffer->vec aux-types)))))