This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new a5db391  [Clojure] Add methods based on NDArrayAPI/SymbolAPI (#14195)
a5db391 is described below

commit a5db391cb27342bf8e267cfbfb4db26b5ef66721
Author: Kedar Bellare <[email protected]>
AuthorDate: Sat Apr 13 12:19:06 2019 -0700

    [Clojure] Add methods based on NDArrayAPI/SymbolAPI (#14195)
    
    * [Clojure] Add methods based on NDArrayAPI/SymbolAPI
    
    * Add symbol API methods and ndarray API unit tests
    
    * Some more ndarray API unit tests
    
    * Explore direct use of JNI
    
    * Use library info directly instead of reflection
    
    * Add tests for generation op info
    
    * Fix ordering of keys using array-map
    
    * Ignore generated test files
    
    * Minor style changes
    
    * Refactor code for better readability
    
    * Address comments
    
    * Small tweaks to symbol api coercion
---
 contrib/clojure-package/.gitignore                 |   2 +
 contrib/clojure-package/src/dev/generator.clj      | 460 ++++++++++++++++-----
 .../src/org/apache/clojure_mxnet/ndarray_api.clj   |  32 ++
 .../src/org/apache/clojure_mxnet/symbol_api.clj    |  32 ++
 .../src/org/apache/clojure_mxnet/util.clj          |   6 +-
 .../clojure-package/test/dev/generator_test.clj    | 148 ++++++-
 .../clojure-package/test/good-test-ndarray-api.clj |  89 ++++
 .../clojure-package/test/good-test-symbol-api.clj  | 109 +++++
 .../test/org/apache/clojure_mxnet/conv_test.clj    |  24 +-
 .../org/apache/clojure_mxnet/ndarray_api_test.clj  | 415 +++++++++++++++++++
 .../org/apache/clojure_mxnet/symbol_api_test.clj   |  61 +++
 11 files changed, 1257 insertions(+), 121 deletions(-)

diff --git a/contrib/clojure-package/.gitignore b/contrib/clojure-package/.gitignore
index f5d81dd..71d812e 100644
--- a/contrib/clojure-package/.gitignore
+++ b/contrib/clojure-package/.gitignore
@@ -39,6 +39,8 @@ examples/visualization/test-vis.pdf
 src/.DS_Store
 src/org/.DS_Store
 test/test-ndarray.clj
+test/test-ndarray-api.clj
 test/test-symbol.clj
+test/test-symbol-api.clj
 src/org/apache/clojure_mxnet/gen/*
 
diff --git a/contrib/clojure-package/src/dev/generator.clj b/contrib/clojure-package/src/dev/generator.clj
index ca93c34..34210be 100644
--- a/contrib/clojure-package/src/dev/generator.clj
+++ b/contrib/clojure-package/src/dev/generator.clj
@@ -17,10 +17,14 @@
 
 (ns dev.generator
   (:require [t6.from-scala.core :as scala]
+            [t6.from-scala.core :refer [$ $$] :as $]
             [clojure.reflect :as r]
-            [org.apache.clojure-mxnet.util :as util]
-            [clojure.pprint])
-  (:import (org.apache.mxnet NDArray Symbol))
+            [clojure.pprint]
+            [org.apache.clojure-mxnet.util :as util])
+  (:import (org.apache.mxnet NDArray NDArrayAPI
+                             Symbol SymbolAPI
+                             Base Base$RefInt Base$RefLong Base$RefFloat Base$RefString)
+           (scala.collection.mutable ListBuffer ArrayBuffer))
   (:gen-class))
 
 
@@ -34,17 +38,17 @@
       (clojure.string/replace #"\_" "-")
       (clojure.string/replace #"\/" "div")))
 
-(defn symbol-transform-param-name [parameter-types]
+(defn transform-param-names [coerce-fn parameter-types]
   (->> parameter-types
        (map str)
-       (map (fn [x] (or (util/symbol-param-coerce x) x)))
+       (map (fn [x] (or (coerce-fn x) x)))
        (map (fn [x] (last (clojure.string/split x #"\."))))))
 
+(defn symbol-transform-param-name [parameter-types]
+  (transform-param-names util/symbol-param-coerce parameter-types))
+
 (defn ndarray-transform-param-name [parameter-types]
-  (->> parameter-types
-       (map str)
-       (map (fn [x] (or (util/ndarray-param-coerce x) x)))
-       (map (fn [x] (last (clojure.string/split x #"\."))))))
+  (transform-param-names util/ndarray-param-coerce parameter-types))
 
 (defn has-variadic? [params]
   (->> params
@@ -56,37 +60,136 @@
 
 (defn increment-param-name [pname]
   (if-let [num-str (re-find #"-\d" pname)]
-    (str (first (clojure.string/split pname #"-")) "-" (inc (Integer/parseInt (last (clojure.string/split num-str #"-")))))
+    (str 
+     (first (clojure.string/split pname #"-"))
+     "-"
+     (inc (Integer/parseInt (last (clojure.string/split num-str #"-")))))
     (str pname "-" 1)))
 
-(defn rename-duplicate-params [params]
-  (reduce (fn [known-names n] (conj known-names (if (contains? (set known-names) n)
-                                                  (increment-param-name n)
-                                                  n)))
-          []
-          params))
-
+(defn rename-duplicate-params [pnames]
+  (->> (reduce
+        (fn [pname-counts n]
+          (let [rn (if (pname-counts n) (str n "-" (pname-counts n)) n)
+                inc-pname-counts (update-in pname-counts [n] (fnil inc 0))]
+            (update-in inc-pname-counts [:params] conj rn)))
+        {:params []}
+        pnames)
+       :params))
+
+(defn get-public-no-default-methods [obj]
+  (->> (r/reflect obj)
+       :members
+       (map #(into {} %))
+       (filter #(-> % :flags :public))
+       (remove #(re-find #"org\$apache\$mxnet" (str (:name %))))
+       (remove #(re-find #"\$default" (str (:name %))))))
+
+(defn get-public-to-gen-methods [public-to-hand-gen public-no-default]
+  (let [public-to-hand-gen-names
+        (into #{} (mapv (comp str :name) public-to-hand-gen))]
+    (remove #(-> % :name str public-to-hand-gen-names) public-no-default)))
 
-;;;;;;; symbol
+(defn public-by-name-and-param-count [public-reflect-info]
+ (->> public-reflect-info
+      (group-by :name)
+      (map (fn [[k v]] [k (group-by #(count (:parameter-types %)) v)]))
+      (into {})))
 
-(def symbol-reflect-info (->> (:members (r/reflect Symbol))
-                              (map #(into {} %))))
+(def license
+  (str
+   ";; Licensed to the Apache Software Foundation (ASF) under one or more\n"
+   ";; contributor license agreements.  See the NOTICE file distributed with\n"
+   ";; this work for additional information regarding copyright ownership.\n"
+   ";; The ASF licenses this file to You under the Apache License, Version 2.0\n"
+   ";; (the \"License\"); you may not use this file except in compliance with\n"
+   ";; the License.  You may obtain a copy of the License at\n"
+   ";;\n"
+   ";;    http://www.apache.org/licenses/LICENSE-2.0\n"
+   ";;\n"
+   ";; Unless required by applicable law or agreed to in writing, software\n"
+   ";; distributed under the License is distributed on an \"AS IS\" BASIS,\n"
+   ";; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
+   ";; See the License for the specific language governing permissions and\n"
+   ";; limitations under the License.\n"
+   ";;\n"))
 
-(def symbol-public (filter (fn [x] (-> x :flags :public)) symbol-reflect-info))
+(defn write-to-file [functions ns-gen fname]
+  (with-open [w (clojure.java.io/writer fname)]
+    (.write w ns-gen)
+    (.write w "\n\n")
+    (.write w ";; Do not edit - this is auto-generated")
+    (.write w "\n\n")
+    (.write w license)
+    (.write w "\n\n")
+    (.write w "\n\n")
+  (doseq [f functions]
+    (clojure.pprint/pprint f w)
+    (.write w "\n"))))
 
-(def symbol-public-no-default (->> symbol-public
-                                   (filter #(not (re-find #"org\$apache\$mxnet" (str (:name %)))))
-                                   (filter #(not (re-find #"\$default" (str (:name %)))))))
+;;;;;;; Common operations
+
+(def libinfo (Base/_LIB))
+(def op-names
+  (let [l ($ ListBuffer/empty)]
+    (do (.mxListAllOpNames libinfo l)
+        (remove #(or (= "Custom" %)
+                     (re-matches #"^_.*" %))
+                (util/buffer->vec l)))))
+
+(defn- parse-arg-type [s]
+  (let [[_ var-arg-type _ set-arg-type arg-spec _ type-req _ default-val] (re-find #"(([\w-\[\]\s]+)|\{([^}]+)\})\s*(\([^)]+\))?(,\s*(optional|required)(,\s*default=(.*))?)?" s)]
+    {:type (clojure.string/trim (or set-arg-type var-arg-type))
+     :spec arg-spec
+     :optional? (or (= "optional" type-req)
+                    (= "boolean" var-arg-type))
+     :default default-val
+     :orig s}))
+
+(defn- get-op-handle [op-name]
+  (let [ref (new Base$RefLong 0)]
+    (do (.nnGetOpHandle libinfo op-name ref)
+        (.value ref))))
+
+(defn gen-op-info [op-name]
+  (let [handle (get-op-handle op-name)
+        name (new Base$RefString nil)
+        desc (new Base$RefString nil)
+        key-var-num-args (new Base$RefString nil)
+        num-args (new Base$RefInt 0)
+        arg-names ($ ListBuffer/empty)
+        arg-types ($ ListBuffer/empty)
+        arg-descs ($ ListBuffer/empty)]
+    (do (.mxSymbolGetAtomicSymbolInfo libinfo
+                                      handle
+                                      name
+                                      desc
+                                      num-args
+                                      arg-names
+                                      arg-types
+                                      arg-descs
+                                      key-var-num-args)
+        {:fn-name (clojure-case (.value name))
+         :fn-description (.value desc)
+         :args (mapv (fn [t n d] (assoc t :name n :description d))
+                     (mapv parse-arg-type (util/buffer->vec arg-types))
+                     (mapv clojure-case (util/buffer->vec arg-names))
+                     (util/buffer->vec arg-descs))
+         :key-var-num-args (clojure-case (.value key-var-num-args))})))
+
+;;;;;;; Symbol
+
+(def symbol-public-no-default
+  (get-public-no-default-methods Symbol))
 
 (into #{} (mapcat :parameter-types symbol-public-no-default))
-                                        ;#{java.lang.Object scala.collection.Seq scala.Option long double scala.collection.immutable.Map int ml.dmlc.mxnet.Executor float ml.dmlc.mxnet.Context java.lang.String scala.Enumeration$Value ml.dmlc.mxnet.Symbol int<> ml.dmlc.mxnet.Symbol<> ml.dmlc.mxnet.Shape java.lang.String<>}
+;; #{java.lang.Object scala.collection.Seq scala.Option long double scala.collection.immutable.Map int ml.dmlc.mxnet.Executor float ml.dmlc.mxnet.Context java.lang.String scala.Enumeration$Value ml.dmlc.mxnet.Symbol int<> ml.dmlc.mxnet.Symbol<> ml.dmlc.mxnet.Shape java.lang.String<>}
 
-(def symbol-hand-gen-set  #{"scala.Option"
-                            "int org.apache.mxnet.Executor"
-                            "scala.Enumeration$Value"
-                            "org.apache.mxnet.Context"
-                            "scala.Tuple2"
-                            "scala.collection.Traversable"} )
+(def symbol-hand-gen-set
+  #{"scala.Option"
+    "scala.Enumeration$Value"
+    "org.apache.mxnet.Context"
+    "scala.Tuple2"
+    "scala.collection.Traversable"})
 
 ;;; min and max have a conflicting arity of 2 with the auto gen signatures
 (def symbol-filter-name-set #{"max" "min"})
@@ -102,34 +205,35 @@
         count
         pos?)))
 
-(def symbol-public-to-hand-gen (filter is-symbol-hand-gen? symbol-public-no-default))
-(def symbol-public-to-gen (->> (remove  #(contains?(->>  symbol-public-to-hand-gen
-                                                          (mapv :name)
-                                                          (mapv str)
-                                                          (set)) (str (:name %))) symbol-public-no-default)))
+(def symbol-public-to-hand-gen
+  (filter is-symbol-hand-gen? symbol-public-no-default))
+(def symbol-public-to-gen
+  (get-public-to-gen-methods symbol-public-to-hand-gen
+                             symbol-public-no-default))
 
 
 (count symbol-public-to-hand-gen) ;=> 35 mostly bind!
 (count symbol-public-to-gen) ;=> 307
 
-(into #{} (map :name symbol-public-to-hand-gen));=>  #{arange bind ones zeros simpleBind Variable}
+(into #{} (map :name symbol-public-to-hand-gen))
+;;=>  #{arange bind ones zeros simpleBind Variable}
 
-(defn public-by-name-and-param-count [public-reflect-info]
- (->> public-reflect-info
-      (group-by :name)
-      (map (fn [[k v]] [k (group-by #(count (:parameter-types %)) v)]))
-      (into {})))
 
 
 (defn symbol-vector-args []
-  `(if (map? ~'kwargs-map-or-vec-or-sym) (~'util/empty-list) (~'util/coerce-param ~'kwargs-map-or-vec-or-sym #{"scala.collection.Seq"})))
+  `(if (map? ~'kwargs-map-or-vec-or-sym)
+     (~'util/empty-list)
+     (~'util/coerce-param ~'kwargs-map-or-vec-or-sym #{"scala.collection.Seq"})))
 
 (defn symbol-map-args []
-  `(if (map? ~'kwargs-map-or-vec-or-sym) (util/convert-symbol-map ~'kwargs-map-or-vec-or-sym) nil))
+  `(if (map? ~'kwargs-map-or-vec-or-sym)
+     (util/convert-symbol-map ~'kwargs-map-or-vec-or-sym)
+     nil))
 
 
 (defn add-symbol-arities [params function-name]
-  (if (= ["sym-name" "kwargs-map" "symbol-list" "kwargs-map-1"] (mapv str params))
+  (if (= ["sym-name" "kwargs-map" "symbol-list" "kwargs-map-1"]
+         (mapv str params))
     [`([~'sym-name ~'attr-map ~'kwargs-map]
       (~function-name ~'sym-name (~'util/convert-symbol-map ~'attr-map) (~'util/empty-list) (~'util/convert-symbol-map ~'kwargs-map)))
      `([~'sym-name ~'kwargs-map-or-vec-or-sym]
@@ -180,36 +284,7 @@
      `(~'defn ~function-name
        ~@(remove nil? (gen-symbol-function-arity op-name op-values function-name))))))
 
-(def license
-  (str
-   ";; Licensed to the Apache Software Foundation (ASF) under one or more\n"
-   ";; contributor license agreements.  See the NOTICE file distributed with\n"
-   ";; this work for additional information regarding copyright ownership.\n"
-   ";; The ASF licenses this file to You under the Apache License, Version 2.0\n"
-   ";; (the \"License\"); you may not use this file except in compliance with\n"
-   ";; the License.  You may obtain a copy of the License at\n"
-   ";;\n"
-   ";;    http://www.apache.org/licenses/LICENSE-2.0\n"
-   ";;\n"
-   ";; Unless required by applicable law or agreed to in writing, software\n"
-   ";; distributed under the License is distributed on an \"AS IS\" BASIS,\n"
-   ";; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
-   ";; See the License for the specific language governing permissions and\n"
-   ";; limitations under the License.\n"
-   ";;\n"))
 
-(defn write-to-file [functions ns-gen fname]
-  (with-open [w (clojure.java.io/writer fname)]
-    (.write w ns-gen)
-    (.write w "\n\n")
-    (.write w ";; Do not edit - this is auto-generated")
-    (.write w "\n\n")
-    (.write w license)
-    (.write w "\n\n")
-    (.write w "\n\n")
-  (doseq [f functions]
-    (clojure.pprint/pprint f w)
-    (.write w "\n"))))
 
 (def symbol-gen-ns "(ns org.apache.clojure-mxnet.symbol
   (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max
@@ -223,25 +298,18 @@
   (println "Generating symbol file")
   (write-to-file all-symbol-functions symbol-gen-ns "src/org/apache/clojure_mxnet/gen/symbol.clj"))
 
+;;;;;;; NDArray
 
-;;;;;;;;NDARRAY
-
-
-(def ndarray-reflect-info (->> (:members (r/reflect NDArray))
-                                (map #(into {} %))))
 
+(def ndarray-public-no-default
+  (get-public-no-default-methods NDArray))
 
-(def ndarray-public (filter (fn [x] (-> x :flags :public)) ndarray-reflect-info))
-
-(def ndarray-public-no-default (->> ndarray-public
-                                    (filter #(not (re-find #"org\$apache\$mxnet" (str (:name %)))))
-                                    (filter #(not (re-find #"\$default" (str (:name %)))))))
-
-(def ndarray-hand-gen-set  #{"org.apache.mxnet.NDArrayFuncReturn"
-                             "org.apache.mxnet.Context"
-                             "scala.Enumeration$Value"
-                             "scala.Tuple2"
-                             "scala.collection.Traversable"} )
+(def ndarray-hand-gen-set
+  #{"org.apache.mxnet.NDArrayFuncReturn"
+    "org.apache.mxnet.Context"
+    "scala.Enumeration$Value"
+    "scala.Tuple2"
+    "scala.collection.Traversable"})
 
 (defn is-ndarray-hand-gen? [info]
   (->> (map str (:parameter-types info))
@@ -251,17 +319,17 @@
        pos?))
 
 
-(def ndarray-public-to-hand-gen (filter is-ndarray-hand-gen? ndarray-public-no-default))
-(def ndarray-public-to-gen (->> (remove  #(contains?(->>  ndarray-public-to-hand-gen
-                                                          (mapv :name)
-                                                          (mapv str)
-                                                          (set)) (str (:name %))) ndarray-public-no-default)))
+(def ndarray-public-to-hand-gen
+  (filter is-ndarray-hand-gen? ndarray-public-no-default))
+(def ndarray-public-to-gen
+  (get-public-to-gen-methods ndarray-public-to-hand-gen
+                             ndarray-public-no-default))
 
 
 (count ndarray-public-to-hand-gen) ;=> 15
 (count ndarray-public-to-gen) ;=> 486
 
-(map :name ndarray-public-to-hand-gen)
+(->> ndarray-public-to-hand-gen (map :name) (into #{}))
 
 
 
@@ -294,16 +362,19 @@
           )))))
 
 
+(defn gen-ndarray-functions [public-to-gen-methods]
+  (for [operation (sort (public-by-name-and-param-count public-to-gen-methods))]
+    (let [[op-name op-values] operation
+          function-name (-> op-name
+                            str
+                            scala/decode-scala-symbol
+                            clojure-case
+                            symbol)]
+      `(~'defn ~function-name
+        ~@(remove nil? (gen-ndarray-function-arity op-name op-values))))))
+
 (def all-ndarray-functions
- (for [operation  (sort (public-by-name-and-param-count ndarray-public-to-gen))]
-   (let [[op-name op-values] operation
-         function-name (-> op-name
-                           str
-                           scala/decode-scala-symbol
-                           clojure-case
-                           symbol)]
-     `(~'defn ~function-name
-       ~@(remove nil? (gen-ndarray-function-arity op-name op-values))))))
+  (gen-ndarray-functions ndarray-public-to-gen))
 
 (def ndarray-gen-ns "(ns org.apache.clojure-mxnet.ndarray
   (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max
@@ -314,16 +385,191 @@
 
 (defn generate-ndarray-file []
   (println "Generating ndarray file")
-  (write-to-file all-ndarray-functions ndarray-gen-ns  "src/org/apache/clojure_mxnet/gen/ndarray.clj"))
+  (write-to-file all-ndarray-functions
+                 ndarray-gen-ns
+                 "src/org/apache/clojure_mxnet/gen/ndarray.clj"))
+
+;;;;;;; SymbolAPI
+
+(defn symbol-api-coerce-param
+  [{:keys [name sym type optional?]}]
+  (let [coerced-param (case type
+                        "Shape" `(when ~sym (~'mx-shape/->shape ~sym))
+                        "NDArray-or-Symbol[]" `(~'clojure.core/into-array ~sym)
+                        "Map[String, String]"
+                        `(when ~sym
+                           (->> ~sym
+                                (mapv (fn [[~'k ~'v]] [~'k (str ~'v)]))
+                                (into {})
+                                ~'util/convert-map))
+                        sym)
+        nil-param-allowed? (#{"name" "attr"} name)]
+    (if (and optional? (not nil-param-allowed?))
+      `(~'util/->option ~coerced-param)
+      coerced-param)))
+
+(defn gen-symbol-api-doc [fn-description params]
+  (let [param-descriptions (mapv (fn [{:keys [name description optional?]}]
+                                   (str "`" name "`: "
+                                        description
+                                        (when optional? " (optional)")
+                                        "\n"))
+                                 params)]
+    (str fn-description "\n\n"
+         (apply str param-descriptions))))
+
+(defn gen-symbol-api-default-arity [op-name params]
+  (let [opt-params (filter :optional? params)
+        coerced-params (mapv symbol-api-coerce-param params)
+        default-args (array-map :keys (mapv :sym params)
+                                :or (into {}
+                                          (mapv (fn [{:keys [sym]}] [sym nil])
+                                                opt-params))
+                                :as 'opts)]
+    `([~default-args]
+      (~'util/coerce-return
+       (~(symbol (str "SymbolAPI/" op-name))
+        ~@coerced-params)))))
+
+(defn gen-symbol-api-function [op-name]
+  (let [{:keys [fn-name fn-description args]} (gen-op-info op-name)
+        params (mapv (fn [{:keys [name type optional?] :as opts}]
+                       (assoc opts
+                              :sym (symbol name)
+                              :optional? (or optional?
+                                             (= "NDArray-or-Symbol" type))))
+                     (conj args
+                           {:name "name"
+                            :type "String"
+                            :optional? true
+                            :description "Name of the symbol"}
+                           {:name "attr"
+                            :type "Map[String, String]"
+                            :optional? true
+                            :description "Attributes of the symbol"}))
+        doc (gen-symbol-api-doc fn-description params)
+        default-call (gen-symbol-api-default-arity op-name params)]
+    `(~'defn ~(symbol fn-name)
+      ~doc
+      ~@default-call)))
+
+(def all-symbol-api-functions
+  (mapv gen-symbol-api-function op-names))
+
+(def symbol-api-gen-ns "(ns
+  ^{:doc \"Experimental\"}
+  org.apache.clojure-mxnet.symbol-api
+  (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max
+                            min repeat reverse set sort take to-array empty sin
+                            get apply shuffle ref])
+  (:require [org.apache.clojure-mxnet.util :as util]
+            [org.apache.clojure-mxnet.shape :as mx-shape])
+  (:import (org.apache.mxnet SymbolAPI)))")
+
+(defn generate-symbol-api-file []
+  (println "Generating symbol-api file")
+  (write-to-file all-symbol-api-functions symbol-api-gen-ns "src/org/apache/clojure_mxnet/gen/symbol_api.clj"))
+
+;;;;;;; NDArrayAPI
+
+(defn ndarray-api-coerce-param
+  [{:keys [sym type optional?]}]
+  (let [coerced-param (case type
+                        "Shape" `(when ~sym (~'mx-shape/->shape ~sym))
+                        "NDArray-or-Symbol[]" `(~'clojure.core/into-array ~sym)
+                        sym)]
+    (if optional?
+      `(~'util/->option ~coerced-param)
+      coerced-param)))
+
+(defn gen-ndarray-api-doc [fn-description params]
+  (let [param-descriptions (mapv (fn [{:keys [name description optional?]}]
+                                   (str "`" name "`: "
+                                        description
+                                        (when optional? " (optional)")
+                                        "\n"))
+                                 params)]
+    (str fn-description "\n\n"
+         (apply str param-descriptions))))
+
+(defn gen-ndarray-api-default-arity [op-name params]
+  (let [opt-params (filter :optional? params)
+        coerced-params (mapv ndarray-api-coerce-param params)
+        default-args (array-map :keys (mapv :sym params)
+                                :or (into {}
+                                          (mapv (fn [{:keys [sym]}] [sym nil])
+                                                opt-params))
+                                :as 'opts)]
+    `([~default-args]
+      (~'util/coerce-return
+       (~(symbol (str "NDArrayAPI/" op-name))
+        ~@coerced-params)))))
+
+(defn gen-ndarray-api-required-arity [fn-name req-params]
+  (let [req-args (->> req-params
+                      (mapv (fn [{:keys [sym]}] [(keyword sym) sym]))
+                      (into {}))]
+    `(~(mapv :sym req-params)
+      (~(symbol fn-name) ~req-args))))
+
+(defn gen-ndarray-api-function [op-name]
+  (let [{:keys [fn-name fn-description args]} (gen-op-info op-name)
+        params (mapv (fn [{:keys [name] :as opts}]
+                       (assoc opts :sym (symbol name)))
+                     (conj args {:name "out"
+                                 :type "NDArray-or-Symbol"
+                                 :optional? true
+                                 :description "Output array."}))
+        doc (gen-ndarray-api-doc fn-description params)
+        opt-params (filter :optional? params)
+        req-params (remove :optional? params)
+        req-call (gen-ndarray-api-required-arity fn-name req-params)
+        default-call (gen-ndarray-api-default-arity op-name params)]
+    (if (= 1 (count req-params))
+      `(~'defn ~(symbol fn-name)
+        ~doc
+        ~@default-call)
+      `(~'defn ~(symbol fn-name)
+        ~doc
+        ~req-call
+        ~default-call))))
+
+(def all-ndarray-api-functions
+  (mapv gen-ndarray-api-function op-names))
+
+(def ndarray-api-gen-ns "(ns
+  ^{:doc \"Experimental\"}
+  org.apache.clojure-mxnet.ndarray-api
+  (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max
+                            min repeat reverse set sort take to-array empty shuffle
+                            ref])
+  (:require [org.apache.clojure-mxnet.shape :as mx-shape]
+            [org.apache.clojure-mxnet.util :as util])
+  (:import (org.apache.mxnet NDArrayAPI)))")
+
+
+(defn generate-ndarray-api-file []
+  (println "Generating ndarray-api file")
+  (write-to-file all-ndarray-api-functions
+                 ndarray-api-gen-ns
+                 "src/org/apache/clojure_mxnet/gen/ndarray_api.clj"))
 
 ;;; autogen the files
 (do
   (generate-ndarray-file)
-  (generate-symbol-file))
+  (generate-ndarray-api-file)
+  (generate-symbol-file)
+  (generate-symbol-api-file))
 
 
 (comment
 
+  (gen-op-info "ElementWiseSum")
+
+  (gen-ndarray-api-function "Activation")
+
+  (gen-symbol-api-function "Activation")
+
   ;; This generates a file with the bulk of the nd-array functions
   (generate-ndarray-file)
 
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray_api.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray_api.clj
new file mode 100644
index 0000000..70359a6
--- /dev/null
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray_api.clj
@@ -0,0 +1,32 @@
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.ndarray-api
+  "Experimental NDArray API"
+  (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max
+                            min repeat reverse set sort take to-array empty shuffle
+                            ref])
+
+  (:require [org.apache.clojure-mxnet.base :as base]
+            [org.apache.clojure-mxnet.context :as mx-context]
+            [org.apache.clojure-mxnet.shape :as mx-shape]
+            [org.apache.clojure-mxnet.util :as util]
+            [clojure.reflect :as r]
+            [t6.from-scala.core :refer [$] :as $])
+  (:import (org.apache.mxnet NDArrayAPI)))
+
+;; loads the generated functions into the namespace
+(do (clojure.core/load "gen/ndarray_api"))
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol_api.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol_api.clj
new file mode 100644
index 0000000..69cc813
--- /dev/null
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol_api.clj
@@ -0,0 +1,32 @@
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.symbol-api
+  "Experimental Symbol API"
+  (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max
+                            min repeat reverse set sort take to-array empty sin
+                            get apply shuffle ref])
+  (:require [org.apache.clojure-mxnet.base :as base]
+            [org.apache.clojure-mxnet.context :as mx-context]
+            [org.apache.clojure-mxnet.executor :as ex]
+            [org.apache.clojure-mxnet.shape :as mx-shape]
+            [org.apache.clojure-mxnet.util :as util]
+            [t6.from-scala.core :refer [$] :as $]
+            [org.apache.clojure-mxnet.ndarray :as ndarray])
+  (:import (org.apache.mxnet SymbolAPI)))
+
+;; loads the generated functions into the namespace
+(do (clojure.core/load "gen/symbol_api"))
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
index 89ac1cd..7ee25d4 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
@@ -35,7 +35,6 @@
                            "int<>" "vec-of-ints"
                            "float<>" "vec-of-floats"
                            "byte<>" "byte-array"
-                           "java.lang.String<>" "vec-or-strings"
                            "org.apache.mxnet.NDArray" "ndarray"
                            "org.apache.mxnet.Symbol" "sym"
                            "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE" "double-or-float"})
@@ -49,7 +48,7 @@
                           "int<>" "vec-of-ints"
                           "float<>" "vec-of-floats"
                           "byte<>" "byte-array"
-                          "java.lang.String<>" "vec-or-strings"
+                          "java.lang.String<>" "vec-of-strings"
                           "org.apache.mxnet.Symbol" "sym"
                           "java.lang.Object" "object"})
 
@@ -152,9 +151,12 @@
     (and (get targets "scala.collection.Seq") (instance? org.apache.mxnet.Symbol param)) ($/immutable-list param)
     (and (get targets "scala.collection.Seq") (and (or (vector? param) (seq? param)) (empty? param))) (empty-list)
     (and (get targets "scala.collection.Seq") (or (vector? param) (seq? param))) (apply $/immutable-list param)
+    (and (get targets "org.apache.mxnet.Shape") (or (vector? param) (seq? param) (empty? param))) (mx-shape/->shape param)
     (and (get targets "int<>") (vector? param)) (int-array param)
     (and (get targets "float<>") (vector? param)) (float-array param)
     (and (get targets "java.lang.String<>") (vector? param)) (into-array param)
+    (and (get targets "org.apache.mxnet.NDArray<>") (vector? param)) (into-array param)
+    (and (get targets "org.apache.mxnet.Symbol<>") (vector? param)) (into-array param)
     (and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (instance? Float param)) (primitives/mx-float param)
     (and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (number? param)) (primitives/mx-double param)
     :else param))
diff --git a/contrib/clojure-package/test/dev/generator_test.clj b/contrib/clojure-package/test/dev/generator_test.clj
index 05b4a74..cf28241 100644
--- a/contrib/clojure-package/test/dev/generator_test.clj
+++ b/contrib/clojure-package/test/dev/generator_test.clj
@@ -50,6 +50,127 @@
     (is (= transformed-params (gen/symbol-transform-param-name
                                (:parameter-types (symbol-reflect-info 
"floor")))))))
 
+(deftest test-gen-op-info
+  ;; gen-op-info describes an operator with :fn-name, :fn-description,
+  ;; :key-var-num-args ("" for Activation/Convolution, "num-args" for the
+  ;; array-taking Concat/ElementWiseSum) and a sequence of :args maps
+  ;; carrying :name, :type, :spec, :default, :optional? and :description.
+  (testing "activation"
+    (let [{:keys [fn-name fn-description key-var-num-args args]}
+          (gen/gen-op-info "Activation")
+          [data-arg act-type-arg] args]
+      (is (= "activation" fn-name))
+      (is (string? fn-description))
+      (is (= 2 (count args)))
+      (is (= "" key-var-num-args))
+
+      (is (= "data" (:name data-arg)))
+      (is (= "NDArray-or-Symbol" (:type data-arg)))
+      (is (false? (:optional? data-arg)))
+      (is (nil? (:default data-arg)))
+      (is (string? (:description data-arg)))
+
+      (is (= "act-type" (:name act-type-arg)))
+      (is (= "'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'"
+             (:type act-type-arg)))
+      (is (false? (:optional? act-type-arg)))
+      (is (nil? (:default act-type-arg)))
+      (is (string? (:description act-type-arg)))))
+
+  (testing "argmin"
+    (let [info (gen/gen-op-info "argmin")
+          field (fn [i k] (get (nth (:args info) i) k))]
+      (is (= "argmin" (:fn-name info)))
+      (is (= 3 (count (:args info))))
+
+      (is (= "data" (field 0 :name)))
+      (is (= "NDArray-or-Symbol" (field 0 :type)))
+      (is (false? (field 0 :optional?)))
+
+      (is (= "axis" (field 1 :name)))
+      (is (= "int or None" (field 1 :type)))
+      (is (= "'None'" (field 1 :default)))
+      (is (true? (field 1 :optional?)))
+
+      (is (= "keepdims" (field 2 :name)))
+      (is (= "boolean" (field 2 :type)))
+      (is (= "0" (field 2 :default)))
+      (is (true? (field 2 :optional?)))))
+
+  (testing "concat"
+    (let [info (gen/gen-op-info "Concat")
+          field (fn [i k] (get (nth (:args info) i) k))]
+      (is (= "concat" (:fn-name info)))
+      (is (= 3 (count (:args info))))
+      (is (= "num-args" (:key-var-num-args info)))
+
+      (is (= "data" (field 0 :name)))
+      (is (= "NDArray-or-Symbol[]" (field 0 :type)))
+      (is (false? (field 0 :optional?)))
+
+      (is (= "num-args" (field 1 :name)))
+      (is (= "int" (field 1 :type)))
+      (is (false? (field 1 :optional?)))
+
+      (is (= "dim" (field 2 :name)))
+      (is (= "int" (field 2 :type)))
+      (is (= "'1'" (field 2 :default)))
+      (is (true? (field 2 :optional?)))))
+
+  (testing "convolution"
+    (let [info (gen/gen-op-info "Convolution")
+          field (fn [i k] (get (nth (:args info) i) k))]
+      (is (= "convolution" (:fn-name info)))
+      (is (= 14 (count (:args info))))
+      (is (= "" (:key-var-num-args info)))
+
+      (is (= "data" (field 0 :name)))
+      (is (= "NDArray-or-Symbol" (field 0 :type)))
+      (is (false? (field 0 :optional?)))
+
+      (is (= "weight" (field 1 :name)))
+      (is (= "NDArray-or-Symbol" (field 1 :type)))
+      (is (false? (field 1 :optional?)))
+
+      (is (= "kernel" (field 3 :name)))
+      (is (= "Shape" (field 3 :type)))
+      (is (= "(tuple)" (field 3 :spec)))
+      (is (false? (field 3 :optional?)))
+
+      (is (= "stride" (field 4 :name)))
+      (is (= "Shape" (field 4 :type)))
+      (is (= "(tuple)" (field 4 :spec)))
+      (is (= "[]" (field 4 :default)))
+      (is (true? (field 4 :optional?)))
+
+      (is (= "num-filter" (field 7 :name)))
+      (is (= "int" (field 7 :type)))
+      (is (= "(non-negative)" (field 7 :spec)))
+      (is (false? (field 7 :optional?)))
+
+      (is (= "num-group" (field 8 :name)))
+      (is (= "int" (field 8 :type)))
+      (is (= "(non-negative)" (field 8 :spec)))
+      (is (= "1" (field 8 :default)))
+      (is (true? (field 8 :optional?)))
+
+      (is (= "workspace" (field 9 :name)))
+      (is (= "long" (field 9 :type)))
+      (is (= "(non-negative)" (field 9 :spec)))
+      (is (= "1024" (field 9 :default)))
+      (is (true? (field 9 :optional?)))
+
+      (is (= "no-bias" (field 10 :name)))
+      (is (= "boolean" (field 10 :type)))
+      (is (= "0" (field 10 :default)))
+      (is (true? (field 10 :optional?)))
+
+      (is (= "layout" (field 13 :name)))
+      (is (= "None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'"
+             (field 13 :type)))
+      (is (= "'None'" (field 13 :default)))
+      (is (true? (field 13 :optional?)))))
+
+  (testing "element wise sum"
+    (let [info (gen/gen-op-info "ElementWiseSum")
+          field (fn [i k] (get (nth (:args info) i) k))]
+      (is (= "add-n" (:fn-name info)))
+      (is (= 1 (count (:args info))))
+      (is (= "num-args" (:key-var-num-args info)))
+
+      (is (= "args" (field 0 :name)))
+      (is (= "NDArray-or-Symbol[]" (field 0 :type)))
+      (is (false? (field 0 :optional?))))))
+
 (deftest test-ndarray-transform-param-name
   (let [params ["scala.collection.immutable.Map"
                 "scala.collection.Seq"]
@@ -68,7 +189,10 @@
 
 (deftest test-rename-duplicate-params
+  ;; The first occurrence of each name is kept as-is; later repeats of the
+  ;; same name get -1, -2, ... suffixes, counted per name in order of appearance.
   (is (= ["foo" "bar" "baz"] (gen/rename-duplicate-params ["foo" "bar" 
"baz"])))
-  (is (= ["foo" "bar" "bar-1"] (gen/rename-duplicate-params ["foo" "bar" 
"bar"]))))
+  (is (= ["foo" "bar" "bar-1"] (gen/rename-duplicate-params ["foo" "bar" 
"bar"])))
+  (is (= ["foo" "bar" "bar-1" "foo-1"] (gen/rename-duplicate-params ["foo" 
"bar" "bar" "foo"])))
+  (is (= ["foo" "bar" "bar-1" "bar-2"] (gen/rename-duplicate-params ["foo" 
"bar" "bar" "bar"])))
+  (is (= ["foo" "bar" "bar-1" "bar-2" "foo-1" "baz"] 
(gen/rename-duplicate-params ["foo" "bar" "bar" "bar" "foo" "baz"]))))
 
 (deftest test-is-symbol-hand-gen?
   (is (not (false? (gen/is-symbol-hand-gen? (symbol-reflect-info "max")))))
@@ -191,7 +315,17 @@
            (gen/gen-ndarray-function-arity op-name op-values)))))
 
 (deftest test-write-to-file
-  (testing "symbol"
+  (testing "symbol-api"
+    (let [fname "test/test-symbol-api.clj"
+          _ (gen/write-to-file [(first gen/all-symbol-api-functions)
+                                (second gen/all-symbol-api-functions)]
+                               gen/symbol-api-gen-ns
+                               fname)
+          good-contents (slurp "test/good-test-symbol-api.clj")
+          contents (slurp fname)]
+      (is (= good-contents contents))))
+
+ (testing "symbol"
     (let [fname "test/test-symbol.clj"
           _ (gen/write-to-file [(first gen/all-symbol-functions)]
                                gen/symbol-gen-ns
@@ -200,6 +334,16 @@
           contents (slurp fname)]
       (is (= good-contents contents))))
 
+  (testing "ndarray-api"
+    (let [fname "test/test-ndarray-api.clj"
+          _ (gen/write-to-file [(first gen/all-ndarray-api-functions)
+                                (second gen/all-ndarray-api-functions)]
+                               gen/ndarray-api-gen-ns
+                               fname)
+          good-contents (slurp "test/good-test-ndarray-api.clj")
+          contents (slurp fname)]
+      (is (= good-contents contents))))
+
   (testing "ndarray"
     (let [fname "test/test-ndarray.clj"
           _ (gen/write-to-file [(first gen/all-ndarray-functions)]
diff --git a/contrib/clojure-package/test/good-test-ndarray-api.clj 
b/contrib/clojure-package/test/good-test-ndarray-api.clj
new file mode 100644
index 0000000..1b83a7b
--- /dev/null
+++ b/contrib/clojure-package/test/good-test-ndarray-api.clj
@@ -0,0 +1,89 @@
+(ns
+  ^{:doc "Experimental"}
+  org.apache.clojure-mxnet.ndarray-api
+  (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity 
load max
+                            min repeat reverse set sort take to-array empty 
shuffle
+                            ref])
+  (:require [org.apache.clojure-mxnet.shape :as mx-shape]
+            [org.apache.clojure-mxnet.util :as util])
+  (:import (org.apache.mxnet NDArrayAPI)))
+
+;; Do not edit - this is auto-generated
+
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+
+
+
+(defn
+ activation
+ "Applies an activation function element-wise to the input.\n\nThe following 
activation functions are supported:\n\n- `relu`: Rectified Linear Unit, 
:math:`y = max(x, 0)`\n- `sigmoid`: :math:`y = \\frac{1}{1 + exp(-x)}`\n- 
`tanh`: Hyperbolic tangent, :math:`y = \\frac{exp(x) - exp(-x)}{exp(x) + 
exp(-x)}`\n- `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`\n- 
`softsign`: :math:`y = \\frac{x}{1 + abs(x)}`\n\n\n\nDefined in 
src/operator/nn/activation.cc:L167\n\n`data`: The in [...]
+ ([data act-type] (activation {:data data, :act-type act-type}))
+ ([{:keys [data act-type out], :or {out nil}, :as opts}]
+  (util/coerce-return
+   (NDArrayAPI/Activation data act-type (util/->option out)))))
+
+(defn
+ batch-norm
+ "Batch normalization.\n\nNormalizes a data batch by mean and variance, and 
applies a scale ``gamma`` as\nwell as offset ``beta``.\n\nAssume the input has 
more than one dimension and we normalize along axis 1.\nWe first compute the 
mean and variance along this axis:\n\n.. math::\n\n  data\\_mean[i] = 
mean(data[:,i,:,...]) \\\\\n  data\\_var[i] = var(data[:,i,:,...])\n\nThen 
compute the normalized output, which has the same shape as input, as 
following:\n\n.. math::\n\n  out[:,i,:,...] =  [...]
+ ([data gamma beta moving-mean moving-var]
+  (batch-norm
+   {:data data,
+    :gamma gamma,
+    :beta beta,
+    :moving-mean moving-mean,
+    :moving-var moving-var}))
+ ([{:keys
+    [data
+     gamma
+     beta
+     moving-mean
+     moving-var
+     eps
+     momentum
+     fix-gamma
+     use-global-stats
+     output-mean-var
+     axis
+     cudnn-off
+     out],
+    :or
+    {eps nil,
+     momentum nil,
+     fix-gamma nil,
+     use-global-stats nil,
+     output-mean-var nil,
+     axis nil,
+     cudnn-off nil,
+     out nil},
+    :as opts}]
+  (util/coerce-return
+   (NDArrayAPI/BatchNorm
+    data
+    gamma
+    beta
+    moving-mean
+    moving-var
+    (util/->option eps)
+    (util/->option momentum)
+    (util/->option fix-gamma)
+    (util/->option use-global-stats)
+    (util/->option output-mean-var)
+    (util/->option axis)
+    (util/->option cudnn-off)
+    (util/->option out)))))
+
diff --git a/contrib/clojure-package/test/good-test-symbol-api.clj 
b/contrib/clojure-package/test/good-test-symbol-api.clj
new file mode 100644
index 0000000..a030884
--- /dev/null
+++ b/contrib/clojure-package/test/good-test-symbol-api.clj
@@ -0,0 +1,109 @@
+(ns
+  ^{:doc "Experimental"}
+  org.apache.clojure-mxnet.symbol-api
+  (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten 
load max
+                            min repeat reverse set sort take to-array empty sin
+                            get apply shuffle ref])
+  (:require [org.apache.clojure-mxnet.util :as util]
+            [org.apache.clojure-mxnet.shape :as mx-shape])
+  (:import (org.apache.mxnet SymbolAPI)))
+
+;; Do not edit - this is auto-generated
+
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+
+
+
+(defn
+ activation
+ "Applies an activation function element-wise to the input.\n\nThe following 
activation functions are supported:\n\n- `relu`: Rectified Linear Unit, 
:math:`y = max(x, 0)`\n- `sigmoid`: :math:`y = \\frac{1}{1 + exp(-x)}`\n- 
`tanh`: Hyperbolic tangent, :math:`y = \\frac{exp(x) - exp(-x)}{exp(x) + 
exp(-x)}`\n- `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`\n- 
`softsign`: :math:`y = \\frac{x}{1 + abs(x)}`\n\n\n\nDefined in 
src/operator/nn/activation.cc:L167\n\n`data`: The in [...]
+ [{:keys [data act-type name attr],
+   :or {data nil, name nil, attr nil},
+   :as opts}]
+ (util/coerce-return
+  (SymbolAPI/Activation
+   (util/->option data)
+   act-type
+   name
+   (clojure.core/when
+    attr
+    (clojure.core/->>
+     attr
+     (clojure.core/mapv
+      (clojure.core/fn [[k v]] [k (clojure.core/str v)]))
+     (clojure.core/into {})
+     util/convert-map)))))
+
+(defn
+ batch-norm
+ "Batch normalization.\n\nNormalizes a data batch by mean and variance, and 
applies a scale ``gamma`` as\nwell as offset ``beta``.\n\nAssume the input has 
more than one dimension and we normalize along axis 1.\nWe first compute the 
mean and variance along this axis:\n\n.. math::\n\n  data\\_mean[i] = 
mean(data[:,i,:,...]) \\\\\n  data\\_var[i] = var(data[:,i,:,...])\n\nThen 
compute the normalized output, which has the same shape as input, as 
following:\n\n.. math::\n\n  out[:,i,:,...] =  [...]
+ [{:keys
+   [data
+    gamma
+    beta
+    moving-mean
+    moving-var
+    eps
+    momentum
+    fix-gamma
+    use-global-stats
+    output-mean-var
+    axis
+    cudnn-off
+    name
+    attr],
+   :or
+   {output-mean-var nil,
+    axis nil,
+    cudnn-off nil,
+    fix-gamma nil,
+    eps nil,
+    data nil,
+    attr nil,
+    beta nil,
+    name nil,
+    use-global-stats nil,
+    moving-mean nil,
+    moving-var nil,
+    momentum nil,
+    gamma nil},
+   :as opts}]
+ (util/coerce-return
+  (SymbolAPI/BatchNorm
+   (util/->option data)
+   (util/->option gamma)
+   (util/->option beta)
+   (util/->option moving-mean)
+   (util/->option moving-var)
+   (util/->option eps)
+   (util/->option momentum)
+   (util/->option fix-gamma)
+   (util/->option use-global-stats)
+   (util/->option output-mean-var)
+   (util/->option axis)
+   (util/->option cudnn-off)
+   name
+   (clojure.core/when
+    attr
+    (clojure.core/->>
+     attr
+     (clojure.core/mapv
+      (clojure.core/fn [[k v]] [k (clojure.core/str v)]))
+     (clojure.core/into {})
+     util/convert-map)))))
+
diff --git 
a/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj 
b/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj
index feda45b..ca9d4bc 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj
@@ -24,6 +24,8 @@
             [org.apache.clojure-mxnet.module :as m]
             [org.apache.clojure-mxnet.optimizer :as optimizer]
             [org.apache.clojure-mxnet.symbol :as sym]
+            [org.apache.clojure-mxnet.symbol-api :as sym-api]
+            [org.apache.clojure-mxnet.util :as util]
             [clojure.reflect :as r]))
 
 (def data-dir "data/")
@@ -54,17 +56,19 @@
 (defn get-symbol []
   (as-> (sym/variable "data") data
+    ;; Builds the network symbol that test-conv wraps in a module. Each form
+    ;; below receives the previous result as `data` (as->). The sym-api
+    ;; functions take the layer :name inside their options map.
 
-    (sym/convolution "conv1" {:data data :kernel [3 3] :num-filter 32 :stride 
[2 2]})
-    (sym/batch-norm "bn1" {:data data})
-    (sym/activation "relu1" {:data data :act-type "relu"})
-    (sym/pooling "mp1" {:data data :kernel [2 2] :pool-type "max" :stride [2 
2]}) (sym/convolution "conv2" {:data data :kernel [3 3] :num-filter 32 :stride 
[2 2]})
-    (sym/batch-norm "bn2" {:data data})
-    (sym/activation "relu2" {:data data :act-type "relu"})
-    (sym/pooling "mp2" {:data data :kernel [2 2] :pool-type "max" :stride [2 
2]})
+    ;; Stage 1: convolution -> batch-norm -> relu -> max-pool
+    (sym-api/convolution {:name "conv1" :data data :kernel [3 3] :num-filter 
32 :stride [2 2]})
+    (sym-api/batch-norm {:name "bn1" :data data})
+    (sym-api/activation {:name "relu1" :data data :act-type "relu"})
+    (sym-api/pooling {:name "mp1" :data data :kernel [2 2] :pool-type "max" 
:stride [2 2]})
 
-    (sym/flatten "fl" {:data data})
-    (sym/fully-connected "fc2" {:data data :num-hidden 10})
-    (sym/softmax-output "softmax" {:data data})))
+    ;; Stage 2: the same four layers again ("conv2" .. "mp2")
+    (sym-api/convolution {:name "conv2" :data data :kernel [3 3] :num-filter 
32 :stride [2 2]})
+    (sym-api/batch-norm {:name "bn2" :data data})
+    (sym-api/activation {:name "relu2" :data data :act-type "relu"})
+    (sym-api/pooling {:name "mp2" :data data :kernel [2 2] :pool-type "max" 
:stride [2 2]})
+
+    ;; Head: flatten -> fully-connected with 10 hidden units -> softmax output
+    (sym-api/flatten {:name "fl" :data data})
+    (sym-api/fully-connected {:name "fc2" :data data :num-hidden 10})
+    (sym-api/softmax-output {:name "softmax" :data data})))
 
 (deftest test-conv []
   (let [mod (m/module (get-symbol))]
diff --git 
a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_api_test.clj 
b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_api_test.clj
new file mode 100644
index 0000000..18b8b78
--- /dev/null
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_api_test.clj
@@ -0,0 +1,415 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.ndarray-api-test
+  (:require [org.apache.clojure-mxnet.base :as base]
+            [org.apache.clojure-mxnet.context :as ctx]
+            [org.apache.clojure-mxnet.dtype :as dtype]
+            [org.apache.clojure-mxnet.ndarray :as ndarray :refer [->vec zeros 
ones += -= *= full shape shape-vec]]
+            [org.apache.clojure-mxnet.ndarray-api :as ndarray-api]
+            [org.apache.clojure-mxnet.shape :as mx-shape :refer [->shape]]
+            [org.apache.clojure-mxnet.test-util :as test-util :refer [approx=]]
+            [org.apache.clojure-mxnet.util :as util :refer [->option]]
+            [clojure.test :refer :all]))
+
+(deftest test-activation
+  ;; The positional arity returns its result. The map arity with :out
+  ;; writes the result into the caller-supplied `target` array.
+  (let [input (ndarray/array [2 1 0 -1 -2] [1 5])
+        [relu sigmoid softsign] (mapv #(ndarray-api/activation input %)
+                                      ["relu" "sigmoid" "softsign"])
+        target (ndarray/zeros [1 5])]
+    (ndarray-api/activation {:data input :act-type "relu" :out target})
+    (is (= [2.0 1.0 0.0 0.0 0.0] (->vec relu)))
+    (is (approx= 1e-3 [0.881 0.731 0.5 0.269 0.119] (->vec sigmoid)))
+    (is (approx= 1e-3 [0.666 0.5 0.0 -0.5 -0.666] (->vec softsign)))
+    (is (= [2.0 1.0 0.0 0.0 0.0] (->vec target)))))
+
+(deftest test-bilinear-sampler
+  ;; grid-generator turns the affine matrix [[2 0 0] [0 2 0]] into a 4x4
+  ;; sampling grid, and bilinear-sampler reads `input` at those points.
+  ;; The expected output is zero along the whole border.
+  (let [input (ndarray/array [1 4 3 6
+                              1 8 8 9
+                              0 4 1 5
+                              1 0 1 3]
+                             [1 1 4 4])
+        theta (ndarray/array [2 0 0 0 2 0] [1 6])
+        sampling-grid (ndarray-api/grid-generator {:data theta
+                                                   :transform-type "affine"
+                                                   :target-shape [4 4]})
+        sampled (ndarray-api/bilinear-sampler input sampling-grid)
+        expected [0.0 0.0  0.0 0.0
+                  0.0 3.5  6.5 0.0
+                  0.0 1.25 2.5 0.0
+                  0.0 0.0  0.0 0.0]]
+    (is (approx= 1e-3 expected (->vec sampled)))))
+
+(deftest test-cast
+  (let [nda1 (ndarray/array [0.9 1.3] [2])
+        nda2 (ndarray/array [1e20 11.1] [2])
+        nda3 (ndarray/array [300 11.1 10.9 -1 -3] [5])
+        out (ndarray/zeros [2] {:dtype dtype/INT32})
+        _ (ndarray-api/cast {:data nda1 :dtype (str dtype/INT32) :out out})]
+    ;; int32 drops the fractional part: [0.9 1.3] -> [0 1]
+    (is (= [0.0 1.0] (->vec (ndarray-api/cast nda1 (str dtype/INT32)))))
+    ;; the :out arity must write the same int32 result into `out`
+    (is (= [0.0 1.0] (->vec out)))
+    (is (= [(float 1e20) (float 11.1)]
+           (->vec (ndarray-api/cast nda2 (str dtype/FLOAT32)))))
+    ;; uint8 gets converted to native types after ->vec: 300 wraps to 44
+    ;; and fractions are dropped, while -1 and -3 read back as negative
+    ;; values (presumably via the JVM's signed byte type - TODO confirm).
+    (is (= [44.0 11.0 10.0 -1.0 -3.0] (->vec (ndarray-api/cast nda3 "uint8"))))))
+
+(deftest test-concat
+  (let [zs (ndarray/zeros [1 2])
+        os (ndarray/ones [1 2])
+        target (ndarray/zeros [1 4])
+        ;; positional arity: data and num-args; dim keeps its default of 1
+        pair-dim1 (ndarray-api/concat [zs os] 2)
+        pair-dim0 (ndarray-api/concat {:data [zs os] :num-args 2 :dim 0})
+        triple-dim1 (ndarray-api/concat {:data [zs os zs] :num-args 3 :dim 1})]
+    ;; map arity with :out stores the result in `target`
+    (ndarray-api/concat {:data [zs os] :num-args 2 :dim 1 :out target})
+    (is (= [0.0 0.0 1.0 1.0] (->vec pair-dim1)))
+    (is (= [1 4] (shape-vec pair-dim1)))
+    (is (= [0.0 0.0 1.0 1.0] (->vec pair-dim0)))
+    (is (= [2 2] (shape-vec pair-dim0)))
+    (is (= [0.0 0.0 1.0 1.0 0.0 0.0] (->vec triple-dim1)))
+    (is (= [1 6] (shape-vec triple-dim1)))
+    (is (= [0.0 0.0 1.0 1.0] (->vec target)))
+    (is (= [1 4] (shape-vec target)))))
+
+(deftest test-embedding
+  ;; Each index in `indices` selects one row of the [input-dim output-dim]
+  ;; weight matrix, so a [2 2] index array yields a [2 2 output-dim] result.
+  (let [input-dim 4
+        output-dim 5
+        weight (ndarray/array (mapv double (range (* input-dim output-dim)))
+                              [input-dim output-dim])
+        indices (ndarray/array [1. 3. 0. 2.] [2 2])
+        looked-up (ndarray-api/embedding indices weight input-dim output-dim)]
+    (is (= [5.  6.  7.  8.  9.
+            15. 16. 17. 18. 19.
+            0.  1.  2.  3.  4.
+            10. 11. 12. 13. 14.]
+           (->vec looked-up)))
+    (is (= [2 2 5] (shape-vec looked-up)))))
+
+(deftest test-flatten
+  ;; The leading axis is kept and the rest collapse: [2 3 3] -> [2 9].
+  ;; Both the returned array and the :out target are checked.
+  (let [block [1 2 3 4 5 6 7 8 9]
+        nda (ndarray/array (into block block) [2 3 3])
+        target (ndarray/zeros [2 9])
+        flat (ndarray-api/flatten {:data nda})
+        expected (mapv double (into block block))]
+    (ndarray-api/flatten {:data nda :out target})
+    (doseq [result [flat target]]
+      (is (= expected (->vec result)))
+      (is (= [2 9] (shape-vec result))))))
+
+(deftest test-instance-norm
+  ;; gamma and beta are single-element arrays, matching the size-1 middle
+  ;; axis of the [2 1 2] input. Both instances normalize to the same pair.
+  (let [input (ndarray/array [1.1 2.2 3.3 4.4] [2 1 2])
+        gamma (ndarray/array [1.5] [1])
+        beta (ndarray/array [0.5] [1])
+        normed (ndarray-api/instance-norm input gamma beta)
+        per-instance [-0.9975 1.9975]]
+    (is (approx= 1e-4 (into per-instance per-instance) (->vec normed)))
+    (is (= [2 1 2] (shape-vec normed)))))
+
+(deftest test-l2-normalization
+  ;; Omitting :mode behaves like "instance", so res1 and res2 share the
+  ;; same expected values.
+  (let [x (ndarray/array [1 2 3 4 2 2 5 6] [2 2 2])
+        normalize (fn [opts] (ndarray-api/l2-normalization (assoc opts :data x)))
+        res1 (normalize {})
+        res2 (normalize {:mode "instance"})
+        res3 (normalize {:mode "channel"})
+        res4 (normalize {:mode "spatial"})
+        instance-expected [0.1825 0.3651
+                           0.5477 0.7303
+                           0.2407 0.2407
+                           0.6019 0.7223]]
+    (is (approx= 1e-4 instance-expected (->vec res1)))
+    (is (approx= 1e-4 instance-expected (->vec res2)))
+    (is (approx= 1e-4 [0.3162 0.4472
+                       0.9486 0.8944
+                       0.3714 0.3162
+                       0.9284 0.9486] (->vec res3)))
+    (is (approx= 1e-4 [0.4472 0.8944
+                       0.6    0.8
+                       0.7071 0.7071
+                       0.6402 0.7682] (->vec res4)))))
+
+(deftest test-pad
+  ;; Only the last two axes are padded, one element on each side, so the
+  ;; [2 2 2 3] input becomes [2 2 4 5]. "edge" repeats the border values;
+  ;; "constant" fills the new border with :constant-value.
+  (let [x (ndarray/array [1 2 3
+                          4 5 6
+                          7 8 9
+                          10 11 12
+                          11 12 13
+                          14 15 16
+                          17 18 19
+                          20 21 22]
+                         [2 2 2 3])
+        pad-width [0 0 0 0 1 1 1 1]
+        edge-padded (ndarray-api/pad x "edge" pad-width)
+        zero-padded (ndarray-api/pad {:data x
+                                      :mode "constant"
+                                      :pad-width pad-width
+                                      :constant-value 0})
+        edge-expected [1.   1.   2.   3.   3.
+                       1.   1.   2.   3.   3.
+                       4.   4.   5.   6.   6.
+                       4.   4.   5.   6.   6.
+                       7.   7.   8.   9.   9.
+                       7.   7.   8.   9.   9.
+                       10.  10.  11.  12.  12.
+                       10.  10.  11.  12.  12.
+                       11.  11.  12.  13.  13.
+                       11.  11.  12.  13.  13.
+                       14.  14.  15.  16.  16.
+                       14.  14.  15.  16.  16.
+                       17.  17.  18.  19.  19.
+                       17.  17.  18.  19.  19.
+                       20.  20.  21.  22.  22.
+                       20.  20.  21.  22.  22.]
+        zero-expected [0.   0.   0.   0.   0.
+                       0.   1.   2.   3.   0.
+                       0.   4.   5.   6.   0.
+                       0.   0.   0.   0.   0.
+
+                       0.   0.   0.   0.   0.
+                       0.   7.   8.   9.   0.
+                       0.  10.  11.  12.   0.
+                       0.   0.   0.   0.   0.
+
+                       0.   0.   0.   0.   0.
+                       0.  11.  12.  13.   0.
+                       0.  14.  15.  16.   0.
+                       0.   0.   0.   0.   0.
+
+                       0.   0.   0.   0.   0.
+                       0.  17.  18.  19.   0.
+                       0.  20.  21.  22.   0.
+                       0.   0.   0.   0.   0.]]
+    (is (= edge-expected (->vec edge-padded)))
+    (is (= [2 2 4 5] (shape-vec edge-padded)))
+    (is (= zero-expected (->vec zero-padded)))
+    (is (= [2 2 4 5] (shape-vec zero-padded)))))
+
+(deftest test-roi-pooling
+  ;; x holds 0..47 laid out as [1 1 8 6]. A single ROI is pooled to 2x2 at
+  ;; two spatial scales. The ROI row is presumably
+  ;; [batch-index x1 y1 x2 y2] (ROIPooling layout) - TODO confirm.
+  (let [x (ndarray/array (mapv double (range 48)) [1 1 8 6])
+        rois (ndarray/array [0 0 0 4 4] [1 5])
+        pool (fn [scale] (ndarray-api/roi-pooling x rois [2 2] scale))
+        full-scale (pool 1.0)
+        reduced-scale (pool 0.7)]
+    (is (= [14. 16. 26. 28.] (->vec full-scale)))
+    (is (= [1 1 2 2] (shape-vec full-scale)))
+    (is (= [7. 9. 19. 21.] (->vec reduced-scale)))
+    (is (= [1 1 2 2] (shape-vec reduced-scale)))))
+
+(deftest test-reshape
+  ;; Special shape codes exercised below, as shown by the expected shapes:
+  ;; 0 copies the matching input dim, -1 infers one dim, -2 copies all
+  ;; remaining dims, -3 merges two consecutive dims, and -4 splits one dim
+  ;; into the two values that follow it.
+  (let [x (ndarray/array (vec (range 4)) [4])
+        y (ndarray/array (vec (range 24)) [2 3 4])
+        z (ndarray/array (vec (range 120)) [2 3 4 5])
+        reshaped (fn [nda target] (ndarray-api/reshape {:data nda :shape target}))
+        res1 (reshaped x [2 2])]
+    (is (= [0. 1. 2. 3.] (->vec res1)))
+    (is (= [2 2] (shape-vec res1)))
+    (is (= (map float (range 24)) (->vec (reshaped y [4 0 2]))))
+    (are [src target expected] (= expected (shape-vec (reshaped src target)))
+      y [4 0 2]        [4 3 2]
+      y [2 0 0]        [2 3 4]
+      y [6 1 -1]       [6 1 4]
+      y [3 -1 8]       [3 1 8]
+      y [-1]           [24]
+      y [-2]           [2 3 4]
+      y [2 -2]         [2 3 4]
+      y [-2 1 1]       [2 3 4 1 1]
+      y [-3 4]         [6 4]
+      z [-3 -3]        [6 20]
+      y [0 -3]         [2 12]
+      y [-3 -2]        [6 4]
+      y [-4 1 2 -2]    [1 2 3 4]
+      y [2 -4 -1 3 -2] [2 1 3 4])))
+
+(deftest test-sequence-last
+  ;; Disabled: (ndarray-api/sequence-last x nil) currently throws an
+  ;; exception (most likely a scala generation issue), so this test has no
+  ;; live assertions. The fixture is kept inside a `comment` form, so no
+  ;; NDArrays are allocated and no bindings go unused, and it can still be
+  ;; evaluated at the REPL while investigating. Candidate sequence-length
+  ;; arrays were [1 1 1] and [1 2 3] (shape [3]).
+  ;; TODO: move this back into the test body and assert on the result
+  ;; once the call works.
+  (comment
+    (let [xi [[[1.  2.  3.]
+               [4.  5.  6.]
+               [7.  8.  9.]]
+              [[10. 11. 12.]
+               [13. 14. 15.]
+               [16. 17. 18.]]
+              [[19. 20. 21.]
+               [22. 23. 24.]
+               [25. 26. 27.]]]
+          x (ndarray/array (-> xi flatten vec) [3 3 3])]
+      (ndarray-api/sequence-last x nil))))
+
+(deftest test-sequence-mask
+  ;; Disabled: the sequence-mask call below currently throws an exception
+  ;; (the same failure as the sequence-last call, most likely a scala
+  ;; generation issue). The fixture is kept inside a `comment` form, so no
+  ;; NDArrays are allocated and no bindings go unused. A second candidate
+  ;; length array was [2 3] (shape [2]).
+  ;; TODO: move this back into the test body with real assertions once
+  ;; the call works.
+  (comment
+    (let [xi [[[1.  2.  3.]
+               [4.  5.  6.]]
+              [[7.  8.  9.]
+               [10. 11. 12.]]
+              [[13. 14. 15.]
+               [16. 17. 18.]]]
+          x (ndarray/array (-> xi flatten vec) [3 2 3])
+          seq-len1 (ndarray/array [1 1] [2])]
+      (ndarray-api/sequence-mask x seq-len1))))
+
+(deftest test-slice-channel
+  ;; The expected values below match the first output slice of each split.
+  (let [x (ndarray/array [1. 2. 3. 4. 5. 6.] [3 2 1])
+        split (fn [opts] (ndarray-api/slice-channel (assoc opts :data x)))
+        by-axis1 (split {:num-outputs 2 :axis 1})
+        by-axis0 (split {:num-outputs 3 :axis 0})
+        squeezed (split {:num-outputs 3 :axis 0 :squeeze-axis 1})]
+    (is (= [1. 3. 5.] (->vec by-axis1)))
+    (is (= [3 1 1] (shape-vec by-axis1)))
+    (is (= [1. 2.] (->vec by-axis0)))
+    (is (= [1 2 1] (shape-vec by-axis0)))
+    (is (= [1. 2.] (->vec squeezed)))
+    (is (= [2 1] (shape-vec squeezed)))))
+
+(deftest test-softmax-activation ; :mode "instance" on a [2 3] array of ones
+  (let [ones-2x3 (ndarray/array [1 1 1 1 1 1] [2 3])
+        result (ndarray-api/softmax-activation {:data ones-2x3 :mode "instance"})]
+    (is (approx= 1e-3 [0.333 0.333 0.333
+                       0.333 0.333 0.333] (-> result ->vec)))
+    (is (= [2 3] (-> result shape-vec)))))
+
+(deftest test-softmax-output ; forward output of softmax-output on a [4 4] input with labels
+  (let [rows [[1 2 3 4] [2 2 2 2] [3 3 3 3] [4 4 4 4]]
+        data (ndarray/array (vec (flatten rows)) [4 4])
+        label (ndarray/array [1 0 2 3] [4])
+        probs (ndarray-api/softmax-output data label)]
+    (is (approx= 1e-4 [0.0321 0.0871 0.2369 0.6439
+                       0.25 0.25 0.25 0.25
+                       0.25 0.25 0.25 0.25
+                       0.25 0.25 0.25 0.25] (-> probs ->vec)))
+    (is (= [4 4] (-> probs shape-vec)))))
+
+(deftest test-swap-axis ; swaps :dim1 and :dim2 on a [1 3] and a [2 2 2] array
+  (let [row (ndarray/array (range 3) [1 3])
+        cube (ndarray/array (range 8) [2 2 2])
+        swapped-row (ndarray-api/swap-axis {:data row :dim1 0 :dim2 1})
+        swapped-cube (ndarray-api/swap-axis {:data cube :dim1 0 :dim2 2})]
+    (is (= [0. 1. 2.] (-> swapped-row ->vec)))
+    (is (= [3 1] (-> swapped-row shape-vec)))
+    (is (= [0. 4. 2. 6. 1. 5. 3. 7.] (-> swapped-cube ->vec)))
+    (is (= [2 2 2] (-> swapped-cube shape-vec)))))
+
+(deftest test-abs ; element-wise absolute value of a 1-D array
+  (let [input (ndarray/array [-2 0 3] [3])
+        output (ndarray-api/abs {:data input})]
+    (is (= [2. 0. 3.] (-> output ->vec)))
+    (is (= [3] (-> output shape-vec)))))
+
+(deftest test-arccos ; inverse cosine at -1, ~-sqrt(2)/2, 0, ~sqrt(2)/2 and 1
+  (let [pi Math/PI
+        x (ndarray/array [-1 -0.707 0 0.707 1] [5])
+        angles (ndarray-api/arccos {:data x})]
+    (is (approx= 1e-3 (mapv #(* % pi) [1 0.75 0.5 0.25 0]) (->vec angles)))))
+
+(deftest test-arcsin ; inverse sine at -1, ~-sqrt(2)/2, 0, ~sqrt(2)/2 and 1
+  (let [half-pi (/ Math/PI 2)
+        x (ndarray/array [-1 -0.707 0 0.707 1] [5])
+        angles (ndarray-api/arcsin {:data x})]
+    (is (approx= 1e-3 [(- half-pi) (- (/ half-pi 2)) 0 (/ half-pi 2) half-pi] (->vec angles)))))
+
+(deftest test-argmax ; argmax over axis 0 and axis 1, with and without :keepdims
+  (let [m (ndarray/array (range 6) [2 3])
+        by-col (ndarray-api/argmax {:data m :axis 0})
+        by-row (ndarray-api/argmax {:data m :axis 1})
+        by-col-kept (ndarray-api/argmax {:data m :axis 0 :keepdims true})
+        by-row-kept (ndarray-api/argmax {:data m :axis 1 :keepdims true})]
+    (is (= [1. 1. 1.] (-> by-col ->vec)))
+    (is (= [3] (-> by-col shape-vec)))
+    (is (= [2. 2.] (-> by-row ->vec)))
+    (is (= [2] (-> by-row shape-vec)))
+    (is (= [1. 1. 1.] (-> by-col-kept ->vec)))
+    (is (= [1 3] (-> by-col-kept shape-vec)))
+    (is (= [2. 2.] (-> by-row-kept ->vec)))
+    (is (= [2 1] (-> by-row-kept shape-vec)))))
+
+(deftest test-argmax-channel ; per-row argmax of a [2 3] array
+  (let [m (ndarray/array (range 6) [2 3])
+        idx (ndarray-api/argmax-channel {:data m})]
+    (is (= [2. 2.] (-> idx ->vec)))
+    (is (= [2] (-> idx shape-vec)))))
+
+(deftest test-argmin ; argmin over axis 0 and axis 1, with and without :keepdims
+  (let [m (ndarray/array (reverse (range 6)) [2 3])
+        by-col (ndarray-api/argmin {:data m :axis 0})
+        by-row (ndarray-api/argmin {:data m :axis 1})
+        by-col-kept (ndarray-api/argmin {:data m :axis 0 :keepdims true})
+        by-row-kept (ndarray-api/argmin {:data m :axis 1 :keepdims true})]
+    (is (= [1. 1. 1.] (-> by-col ->vec)))
+    (is (= [3] (-> by-col shape-vec)))
+    (is (= [2. 2.] (-> by-row ->vec)))
+    (is (= [2] (-> by-row shape-vec)))
+    (is (= [1. 1. 1.] (-> by-col-kept ->vec)))
+    (is (= [1 3] (-> by-col-kept shape-vec)))
+    (is (= [2. 2.] (-> by-row-kept ->vec)))
+    (is (= [2 1] (-> by-row-kept shape-vec)))))
+
+(deftest test-argsort ; ascending sort indices: default axis, :axis 0, and a 1-D input
+  (let [x (ndarray/array [0.3  0.2  0.4
+                          0.1  0.3  0.2]
+                         [2 3])
+        y (ndarray/array [0.3 0.2 0.4 0.1 0.3 0.2] [6])
+        res1 (ndarray-api/argsort {:data x})
+        res2 (ndarray-api/argsort {:data x :axis 0})
+        res3 (ndarray-api/argsort {:data y})]
+    (is (= [1. 0. 2.
+            0. 2. 1.]
+           (->vec res1)))
+    (is (= [2 3] (shape-vec res1)))
+    (is (= [1. 0. 1.
+            0. 1. 0.]
+           (->vec res2)))
+    (is (= [2 3] (shape-vec res2))) ; shape of the :axis 0 result (res2, not res1)
+    (is (= [3. 1. 5. 0. 4. 2.] (->vec res3)))
+    (is (= [6] (shape-vec res3)))))
+
+(deftest test-batch-take ; picks column idx[k] from row k of a [3 2] array
+  (let [x (ndarray/array (range 6) [3 2])
+        idx (-> (ndarray/array [0 1 0] [3]) (ndarray/as-type dtype/INT32))
+        picked (ndarray-api/batch-take x idx)]
+    (is (= [0. 3. 4.] (->vec picked)))))
+
+(deftest test-broadcast-add ; [2 3] ones plus a [2 1] column, broadcast across columns
+  (let [ones (ndarray/ones [2 3])
+        col (ndarray/array (range 2) [2 1])
+        total (ndarray-api/broadcast-add ones col)]
+    (is (= [1. 1. 1. 2. 2. 2.] (-> total ->vec)))
+    (is (= [2 3] (-> total shape-vec)))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/symbol_api_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/symbol_api_test.clj
new file mode 100644
index 0000000..b642ad7
--- /dev/null
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/symbol_api_test.clj
@@ -0,0 +1,61 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.symbol-api-test
+  (:require [org.apache.clojure-mxnet.dtype :as dtype]
+            [org.apache.clojure-mxnet.executor :as executor]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]
+            [org.apache.clojure-mxnet.symbol :as sym]
+            [org.apache.clojure-mxnet.symbol-api :as sym-api]
+            [org.apache.clojure-mxnet.util :as util]
+            [clojure.test :refer :all]
+            [org.apache.clojure-mxnet.context :as context]))
+
+(deftest test-compose ; composes net2 onto net1 via sym/apply, then groups the outputs
+  (let [data (sym/variable "data")
+        fc1 (sym-api/fully-connected {:data data :num-hidden 10 :name "fc1"})
+        net1 (sym-api/fully-connected {:data fc1 :num-hidden 100 :name "fc2"})
+
+        fc3 (sym-api/fully-connected {:num-hidden 10 :name "fc3"})
+        relu (sym-api/activation {:data fc3 :act-type "relu"})
+        net2 (sym-api/fully-connected {:data relu :num-hidden 20 :name "fc4"})
+
+        composed (sym/apply net2 "composed" {"fc3_data" net1})
+
+        multi-out (sym/group [composed net1])]
+
+    (is (= ["data" "fc1_weight" "fc1_bias" "fc2_weight" "fc2_bias"] (sym/list-arguments net1)))
+    (is (= 2 (-> multi-out sym/list-outputs count)))))
+
+(deftest test-symbol-internal ; fc1's internal output in net1 must report the same arguments as oldfc
+  (let [data (sym/variable "data")
+        oldfc (sym-api/fully-connected {:data data :num-hidden 10 :name "fc1"})
+        net1 (sym-api/fully-connected {:data oldfc :num-hidden 100 :name "fc2"})]
+    (is (= ["data" "fc1_weight" "fc1_bias" "fc2_weight" "fc2_bias"] (sym/list-arguments net1)))
+    (is (= (sym/list-arguments oldfc) (-> (sym/get-internals net1)
+                                          (sym/get "fc1_output")
+                                          (sym/list-arguments))))))
+
+(deftest test-infer-type ; float64 data feeding a float32 cast -> inferred arg/out/aux dtypes
+  (let [data (sym/variable "data")
+        as-f32 (sym-api/cast {:data data :dtype "float32"})
+        fc1 (sym-api/fully-connected {:data as-f32 :num-hidden 128 :name "fc1"})
+        mlp (sym-api/softmax-output {:data fc1 :name "softmax"})
+        [arg-types out-types aux-types] (sym/infer-type mlp {:data dtype/FLOAT64})]
+    (is (= [dtype/FLOAT64 dtype/FLOAT32 dtype/FLOAT32 dtype/FLOAT32] (-> arg-types util/buffer->vec)))
+    (is (= [dtype/FLOAT32] (-> out-types util/buffer->vec)))
+    (is (= [] (-> aux-types util/buffer->vec)))))

Reply via email to