gigasquid closed pull request #13864: [Clojure] package infer tweaks
URL: https://github.com/apache/incubator-mxnet/pull/13864
 
 
   

This pull request was merged from a forked repository.
Because GitHub hides the original diff once such a pull request is
merged, the diff is reproduced below for the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is
included below; GitHub would not otherwise display it:

diff --git 
a/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj
 
b/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj
index 4ec7ff7f149..6994b4fadc2 100644
--- 
a/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj
+++ 
b/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj
@@ -55,8 +55,8 @@
   "Print image classifier predictions for the given input file"
   [predictions]
   (println (apply str (repeat 80 "=")))
-  (doseq [[label probability] predictions]
-    (println (format "Class: %s Probability=%.8f" label probability)))
+  (doseq [p predictions]
+    (println p))
   (println (apply str (repeat 80 "="))))
 
 (defn classify-single-image
@@ -64,8 +64,8 @@
   [classifier input-image]
   (let [image (infer/load-image-from-file input-image)
         topk 5
-        [predictions] (infer/classify-image classifier image topk)]
-    predictions))
+        predictions (infer/classify-image classifier image topk)]
+    [predictions]))
 
 (defn classify-images-in-dir
   "Classify all jpg images in the directory"
@@ -78,12 +78,10 @@
                                 (filter #(re-matches #".*\.jpg$" (.getPath %)))
                                 (mapv #(.getPath %))
                                 (partition-all batch-size))]
-    (apply
-     concat
-     (for [image-files image-file-batches]
-       (let [image-batch (infer/load-image-paths image-files)
-             topk 5]
-         (infer/classify-image-batch classifier image-batch topk))))))
+    (apply concat (for [image-files image-file-batches]
+                    (let [image-batch (infer/load-image-paths image-files)
+                          topk 5]
+                      (infer/classify-image-batch classifier image-batch 
topk))))))
 
 (defn run-classifier
   "Runs an image classifier based on options provided"
@@ -98,6 +96,7 @@
                     factory {:contexts [(context/default-context)]})]
     (println "Classifying a single image")
     (print-predictions (classify-single-image classifier input-image))
+    (println "\n")
     (println "Classifying images in a directory")
     (doseq [predictions (classify-images-in-dir classifier input-dir)]
       (print-predictions predictions))))
diff --git 
a/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj
 
b/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj
index 5b3e08d134f..4b71f845dd5 100644
--- 
a/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj
+++ 
b/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj
@@ -43,27 +43,16 @@
 
 (deftest test-single-classification
   (let [classifier (create-classifier)
-        predictions (classify-single-image classifier image-file)]
+        [[predictions]] (classify-single-image classifier image-file)]
     (is (some? predictions))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(float? (second %)) predictions))
-    (is (every? #(< 0 (second %) 1) predictions))
-    (is (= ["n02123159 tiger cat"
-            "n02124075 Egyptian cat"
-            "n02123045 tabby, tabby cat"
-            "n02127052 lynx, catamount"
-            "n02128757 snow leopard, ounce, Panthera uncia"]
-           (map first predictions)))))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (= (< 0 (:prob (first predictions)) 1)))))
 
 (deftest test-batch-classification
   (let [classifier (create-classifier)
-        batch-predictions (classify-images-in-dir classifier image-dir)
-        predictions (first batch-predictions)]
-    (is (some? batch-predictions))
+        predictions (first (classify-images-in-dir classifier image-dir))]
+    (is (some? predictions))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(float? (second %)) predictions))
-    (is (every? #(< 0 (second %) 1) predictions))))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (= (< 0 (:prob (first predictions)) 1)))))
diff --git 
a/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj
 
b/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj
index 53172f0c8ca..5c30e5db63f 100644
--- 
a/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj
+++ 
b/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj
@@ -54,15 +54,15 @@
   "Print image detector predictions for the given input file"
   [predictions width height]
   (println (apply str (repeat 80 "=")))
-  (doseq [[label prob-and-bounds] predictions]
+  (doseq [{:keys [class prob x-min y-min x-max y-max]} predictions]
     (println (format
               "Class: %s Prob=%.5f Coords=(%.3f, %.3f, %.3f, %.3f)"
-              label
-              (aget prob-and-bounds 0)
-              (* (aget prob-and-bounds 1) width)
-              (* (aget prob-and-bounds 2) height)
-              (* (aget prob-and-bounds 3) width)
-              (* (aget prob-and-bounds 4) height))))
+              class
+              prob
+              (* x-min width)
+              (* y-min height)
+              (* x-max width)
+              (* y-max height))))
   (println (apply str (repeat 80 "="))))
 
 (defn detect-single-image
@@ -84,12 +84,10 @@
                                 (filter #(re-matches #".*\.jpg$" (.getPath %)))
                                 (mapv #(.getPath %))
                                 (partition-all batch-size))]
-    (apply
-     concat
-     (for [image-files image-file-batches]
-       (let [image-batch (infer/load-image-paths image-files)
-             topk 5]
-         (infer/detect-objects-batch detector image-batch topk))))))
+    (apply concat (for [image-files image-file-batches]
+                    (let [image-batch (infer/load-image-paths image-files)
+                          topk 5]
+                      (infer/detect-objects-batch detector image-batch 
topk))))))
 
 (defn run-detector
   "Runs an image detector based on options provided"
@@ -107,6 +105,7 @@
                   {:contexts [(context/default-context)]})]
     (println "Object detection on a single image")
     (print-predictions (detect-single-image detector input-image) width height)
+    (println "\n")
     (println "Object detection on images in a directory")
     (doseq [predictions (detect-images-in-dir detector input-dir)]
       (print-predictions predictions width height))))
diff --git 
a/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj
 
b/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj
index 90ed02f67a7..2b8ad951ae2 100644
--- 
a/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj
+++ 
b/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj
@@ -43,23 +43,23 @@
 
 (deftest test-single-detection
   (let [detector (create-detector)
-        predictions (detect-single-image detector image-file)]
+        predictions (detect-single-image detector image-file)
+        {:keys [class prob x-min x-max y-min y-max] :as pred} (first 
predictions)]
     (is (some? predictions))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(= 5 (count (second %))) predictions))
-    (is (every? #(< 0 (first (second %)) 1) predictions))
-    (is (= ["car" "bicycle" "dog" "bicycle" "person"]
-           (map first predictions)))))
+    (is (string? class))
+    (is (< 0.8 prob))
+    (is (every? #(< 0 % 1) [x-min x-max y-min y-max]))
+    (is (= #{"dog" "person" "bicycle" "car"} (set (mapv :class 
predictions))))))
 
 (deftest test-batch-detection
   (let [detector (create-detector)
         batch-predictions (detect-images-in-dir detector image-dir)
-        predictions (first batch-predictions)]
+        predictions (first batch-predictions)
+        {:keys [class prob x-min x-max y-min y-max] :as pred} (first 
predictions)]
     (is (some? batch-predictions))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(= 5 (count (second %))) predictions))
-    (is (every? #(< 0 (first (second %)) 1) predictions))))
+    (is (string? class))
+    (is (< 0.8 prob))
+    (every? #(< 0 % 1) [x-min x-max y-min y-max])
+    (is (= #{"dog" "person" "bicycle" "car"} (set (mapv :class 
predictions))))))
diff --git 
a/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
 
b/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
index 498964128dd..05eb0add313 100644
--- 
a/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
+++ 
b/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
@@ -59,8 +59,8 @@
 (defn do-inference
   "Run inference using given predictor"
   [predictor image]
-  (let [[predictions] (infer/predict-with-ndarray predictor [image])]
-    predictions))
+  (let [predictions (infer/predict-with-ndarray predictor [image])]
+    (first predictions)))
 
 (defn postprocess
   [model-path-prefix predictions]
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj 
b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
index 224a39275da..09edf15b428 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
@@ -22,7 +22,8 @@
             [org.apache.clojure-mxnet.io :as mx-io]
             [org.apache.clojure-mxnet.shape :as shape]
             [org.apache.clojure-mxnet.util :as util]
-            [clojure.spec.alpha :as s])
+            [clojure.spec.alpha :as s]
+            [org.apache.clojure-mxnet.shape :as mx-shape])
   (:import (java.awt.image BufferedImage)
            (org.apache.mxnet NDArray)
            (org.apache.mxnet.infer Classifier ImageClassifier
@@ -39,15 +40,26 @@
 (defrecord WrappedObjectDetector [object-detector])
 
 (s/def ::ndarray #(instance? NDArray %))
-(s/def ::float-array (s/and #(.isArray (class %)) #(every? float? %)))
-(s/def ::vec-of-float-arrays (s/coll-of ::float-array :kind vector?))
+(s/def ::number-array (s/coll-of number? :kind vector?))
+(s/def ::vvec-of-numbers (s/coll-of ::number-array :kind vector?))
 (s/def ::vec-of-ndarrays (s/coll-of ::ndarray :kind vector?))
+(s/def ::image #(instance? BufferedImage %))
+(s/def ::batch-images (s/coll-of ::image :kind vector?))
 
 (s/def ::wrapped-predictor (s/keys :req-un [::predictor]))
 (s/def ::wrapped-classifier (s/keys :req-un [::classifier]))
 (s/def ::wrapped-image-classifier (s/keys :req-un [::image-classifier]))
 (s/def ::wrapped-detector (s/keys :req-un [::object-detector]))
 
+(defn- format-detection-predictions [predictions]
+  (mapv (fn [[c p]]
+          (let [[prob xmin ymin xmax ymax] (mapv float p)]
+            {:class c :prob prob :x-min xmin :y-min ymin :x-max xmax :y-max 
ymax}))
+        predictions))
+
+(defn- format-classification-predictions [predictions]
+  (mapv (fn [[c p]] {:class c :prob p}) predictions))
+
 (defprotocol APredictor
   (predict [wrapped-predictor inputs])
   (predict-with-ndarray [wrapped-predictor input-arrays]))
@@ -87,19 +99,20 @@
     [wrapped-predictor inputs]
     (util/validate! ::wrapped-predictor wrapped-predictor
                     "Invalid predictor")
-    (util/validate! ::vec-of-float-arrays inputs
+    (util/validate! ::vvec-of-numbers inputs
                     "Invalid inputs")
-    (util/coerce-return-recursive
-     (.predict (:predictor wrapped-predictor)
-               (util/vec->indexed-seq inputs))))
+    (->> (.predict (:predictor wrapped-predictor)
+                   (util/vec->indexed-seq (mapv float-array inputs)))
+         (util/coerce-return-recursive)
+         (mapv #(mapv float %))))
   (predict-with-ndarray [wrapped-predictor input-arrays]
     (util/validate! ::wrapped-predictor wrapped-predictor
                     "Invalid predictor")
     (util/validate! ::vec-of-ndarrays input-arrays
                     "Invalid input arrays")
-    (util/coerce-return-recursive
-     (.predictWithNDArray (:predictor wrapped-predictor)
-                          (util/vec->indexed-seq input-arrays)))))
+    (-> (.predictWithNDArray (:predictor wrapped-predictor)
+                             (util/vec->indexed-seq input-arrays))
+        (util/coerce-return-recursive))))
 
 (s/def ::nil-or-int (s/nilable int?))
 
@@ -111,13 +124,14 @@
     ([wrapped-classifier inputs topk]
      (util/validate! ::wrapped-classifier wrapped-classifier
                      "Invalid classifier")
-     (util/validate! ::vec-of-float-arrays inputs
+     (util/validate! ::vvec-of-numbers inputs
                      "Invalid inputs")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.classify (:classifier wrapped-classifier)
-                 (util/vec->indexed-seq inputs)
-                 (util/->int-option topk)))))
+     (->> (.classify (:classifier wrapped-classifier)
+                     (util/vec->indexed-seq (mapv float-array inputs))
+                     (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (format-classification-predictions))))
   (classify-with-ndarray
     ([wrapped-classifier inputs]
      (classify-with-ndarray wrapped-classifier inputs nil))
@@ -127,10 +141,11 @@
      (util/validate! ::vec-of-ndarrays inputs
                      "Invalid inputs")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.classifyWithNDArray (:classifier wrapped-classifier)
-                            (util/vec->indexed-seq inputs)
-                           (util/->int-option topk)))))
+     (->> (.classifyWithNDArray (:classifier wrapped-classifier)
+                                (util/vec->indexed-seq inputs)
+                                (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (mapv format-classification-predictions))))
   WrappedImageClassifier
   (classify
     ([wrapped-image-classifier inputs]
@@ -138,13 +153,14 @@
     ([wrapped-image-classifier inputs topk]
      (util/validate! ::wrapped-image-classifier wrapped-image-classifier
                      "Invalid classifier")
-     (util/validate! ::vec-of-float-arrays inputs
+     (util/validate! ::vvec-of-numbers inputs
                      "Invalid inputs")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.classify (:image-classifier wrapped-image-classifier)
-                 (util/vec->indexed-seq inputs)
-                 (util/->int-option topk)))))
+     (->> (.classify (:image-classifier wrapped-image-classifier)
+                     (util/vec->indexed-seq (mapv float-array inputs))
+                     (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (format-classification-predictions))))
   (classify-with-ndarray
     ([wrapped-image-classifier inputs]
      (classify-with-ndarray wrapped-image-classifier inputs nil))
@@ -154,10 +170,11 @@
     (util/validate! ::vec-of-ndarrays inputs
                     "Invalid inputs")
     (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.classifyWithNDArray (:image-classifier wrapped-image-classifier)
-                           (util/vec->indexed-seq inputs)
-                           (util/->int-option topk))))))
+    (->> (.classifyWithNDArray (:image-classifier wrapped-image-classifier)
+                               (util/vec->indexed-seq inputs)
+                               (util/->int-option topk))
+         (util/coerce-return-recursive)
+         (mapv format-classification-predictions)))))
 
 (s/def ::image #(instance? BufferedImage %))
 (s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 
dtype/FLOAT64})
@@ -175,11 +192,12 @@
      (util/validate! ::image image "Invalid image")
      (util/validate! ::nil-or-int topk "Invalid top-K")
      (util/validate! ::dtype dtype "Invalid dtype")
-     (util/coerce-return-recursive
-      (.classifyImage (:image-classifier wrapped-image-classifier)
-                      image
-                      (util/->int-option topk)
-                      dtype))))
+     (->> (.classifyImage (:image-classifier wrapped-image-classifier)
+                         image
+                         (util/->int-option topk)
+                         dtype)
+          (util/coerce-return-recursive)
+          (mapv format-classification-predictions))))
   (classify-image-batch
     ([wrapped-image-classifier images]
      (classify-image-batch wrapped-image-classifier images nil dtype/FLOAT32))
@@ -188,13 +206,15 @@
     ([wrapped-image-classifier images topk dtype]
      (util/validate! ::wrapped-image-classifier wrapped-image-classifier
                      "Invalid classifier")
+     (util/validate! ::batch-images images "Invalid Batch Images")
      (util/validate! ::nil-or-int topk "Invalid top-K")
      (util/validate! ::dtype dtype "Invalid dtype")
-     (util/coerce-return-recursive
-      (.classifyImageBatch (:image-classifier wrapped-image-classifier)
-                           images
-                           (util/->int-option topk)
-                           dtype)))))
+     (->> (.classifyImageBatch (:image-classifier wrapped-image-classifier)
+                              (util/vec->indexed-seq images)
+                              (util/->int-option topk)
+                              dtype)
+         (util/coerce-return-recursive)
+         (mapv format-classification-predictions)))))
 
 (extend-protocol AObjectDetector
   WrappedObjectDetector
@@ -206,10 +226,11 @@
                     "Invalid object detector")
      (util/validate! ::image image "Invalid image")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.imageObjectDetect (:object-detector wrapped-detector)
-                          image
-                          (util/->int-option topk)))))
+     (->> (.imageObjectDetect (:object-detector wrapped-detector)
+                              image
+                              (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (mapv format-detection-predictions))))
   (detect-objects-batch
     ([wrapped-detector images]
      (detect-objects-batch wrapped-detector images nil))
@@ -217,10 +238,12 @@
      (util/validate! ::wrapped-detector wrapped-detector
                      "Invalid object detector")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.imageBatchObjectDetect (:object-detector wrapped-detector)
-                               images
-                               (util/->int-option topk)))))
+     (util/validate! ::batch-images images "Invalid Batch Images")
+     (->> (.imageBatchObjectDetect (:object-detector wrapped-detector)
+                                   (util/vec->indexed-seq images)
+                                   (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (mapv format-detection-predictions))))
   (detect-objects-with-ndarrays
     ([wrapped-detector input-arrays]
      (detect-objects-with-ndarrays wrapped-detector input-arrays nil))
@@ -230,10 +253,11 @@
      (util/validate! ::vec-of-ndarrays input-arrays
                      "Invalid inputs")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.objectDetectWithNDArray (:object-detector wrapped-detector)
-                                (util/vec->indexed-seq input-arrays)
-                                (util/->int-option topk))))))
+     (->> (.objectDetectWithNDArray (:object-detector wrapped-detector)
+                                    (util/vec->indexed-seq input-arrays)
+                                    (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (mapv format-detection-predictions)))))
 
 (defprotocol AInferenceFactory
   (create-predictor [factory] [factory opts])
@@ -324,10 +348,12 @@
 
 (defn buffered-image-to-pixels
   "Convert input BufferedImage to NDArray of input shape"
-  [image input-shape-vec]
-  (util/validate! ::image image "Invalid image")
-  (util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector")
-  (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec) 
dtype/FLOAT32))
+  ([image input-shape-vec]
+   (buffered-image-to-pixels image input-shape-vec dtype/FLOAT32))
+  ([image input-shape-vec dtype]
+   (util/validate! ::image image "Invalid image")
+   (util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector")
+   (ImageClassifier/bufferedImageToPixels image (shape/->shape 
input-shape-vec) dtype)))
 
 (s/def ::image-path string?)
 (s/def ::image-paths (s/coll-of ::image-path))
@@ -342,4 +368,5 @@
   "Loads images from a list of file names"
   [image-paths]
   (util/validate! ::image-paths image-paths "Invalid image paths")
-  (ImageClassifier/loadInputBatch (util/convert-vector image-paths)))
+  (util/scala-vector->vec
+   (ImageClassifier/loadInputBatch (util/convert-vector image-paths))))
diff --git 
a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
 
b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
index b459b06132b..e3935c31e34 100644
--- 
a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
+++ 
b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
@@ -19,6 +19,7 @@
             [org.apache.clojure-mxnet.dtype :as dtype]
             [org.apache.clojure-mxnet.infer :as infer]
             [org.apache.clojure-mxnet.layout :as layout]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]
             [clojure.java.io :as io]
             [clojure.java.shell :refer [sh]]
             [clojure.test :refer :all]))
@@ -45,32 +46,83 @@
         [predictions] (infer/classify-image classifier image 5 dtype/FLOAT32)]
     (is (= 1000 (count predictions-all)))
     (is (= 10 (count predictions-with-default-dtype)))
-    (is (some? predictions))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(float? (second %)) predictions))
-    (is (every? #(< 0 (second %) 1) predictions))
-    (is (= ["n02123159 tiger cat"
-            "n02124075 Egyptian cat"
-            "n02123045 tabby, tabby cat"
-            "n02127052 lynx, catamount"
-            "n02128757 snow leopard, ounce, Panthera uncia"]
-           (map first predictions)))))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (= (< 0 (:prob (first predictions)) 1)))))
 
 (deftest test-batch-classification
   (let [classifier (create-classifier)
         image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
                                              
"test/test-images/Pug-Cookie.jpg"])
-        batch-predictions-all (infer/classify-image-batch classifier 
image-batch)
-        batch-predictions-with-default-dtype (infer/classify-image-batch 
classifier image-batch 10)
-        batch-predictions (infer/classify-image-batch classifier image-batch 5 
dtype/FLOAT32)
-        predictions (first batch-predictions)]
-    (is (= 1000 (count (first batch-predictions-all))))
-    (is (= 10 (count (first batch-predictions-with-default-dtype))))
-    (is (some? batch-predictions))
+        [batch-predictions-all] (infer/classify-image-batch classifier 
image-batch)
+        [batch-predictions-with-default-dtype] (infer/classify-image-batch 
classifier image-batch 10)
+        [predictions] (infer/classify-image-batch classifier image-batch 5 
dtype/FLOAT32)]
+    (is (= 1000 (count batch-predictions-all)))
+    (is (= 10 (count batch-predictions-with-default-dtype)))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(float? (second %)) predictions))
-    (is (every? #(< 0 (second %) 1) predictions))))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (= (< 0 (:prob (first predictions)) 1)))))
+
+(deftest test-single-classification-with-ndarray
+  (let [classifier (create-classifier)
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 224 224)
+                  (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
+        [predictions-all] (infer/classify-with-ndarray classifier [image])
+        [predictions] (infer/classify-with-ndarray classifier [image] 5)]
+    (is (= 1000 (count predictions-all)))
+    (is (= 5 (count predictions)))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (= (< 0 (:prob (first predictions)) 1)))))
+
+(deftest test-single-classify
+  (let [classifier (create-classifier)
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 224 224)
+                  (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
+        predictions-all (infer/classify classifier [(ndarray/->vec image)])
+        predictions (infer/classify classifier [(ndarray/->vec image)] 5)]
+    (is (= 1000 (count predictions-all)))
+    (is (= 5 (count predictions)))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (= (< 0 (:prob (first predictions)) 1)))))
+
+(deftest test-base-classification-with-ndarray
+  (let [descriptors [{:name "data"
+                      :shape [1 3 224 224]
+                      :layout layout/NCHW
+                      :dtype dtype/FLOAT32}]
+        factory (infer/model-factory model-path-prefix descriptors)
+        classifier (infer/create-classifier factory)
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 224 224)
+                  (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
+        [predictions-all] (infer/classify-with-ndarray classifier [image])
+        [predictions] (infer/classify-with-ndarray classifier [image] 5)]
+    (is (= 1000 (count predictions-all)))
+    (is (= 5 (count predictions)))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (= (< 0 (:prob (first predictions)) 1)))))
+
+(deftest test-base-single-classify
+  (let [descriptors [{:name "data"
+                      :shape [1 3 224 224]
+                      :layout layout/NCHW
+                      :dtype dtype/FLOAT32}]
+        factory (infer/model-factory model-path-prefix descriptors)
+        classifier (infer/create-classifier factory)
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 224 224)
+                  (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
+        predictions-all (infer/classify classifier [(ndarray/->vec image)])
+        predictions (infer/classify classifier [(ndarray/->vec image)] 5)]
+    (is (= 1000 (count predictions-all)))
+    (is (= 5 (count predictions)))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (= (< 0 (:prob (first predictions)) 1)))))
+
+
diff --git 
a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
 
b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
index 3a0e3d30a1d..e2b9579c700 100644
--- 
a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
+++ 
b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
@@ -21,7 +21,8 @@
             [org.apache.clojure-mxnet.layout :as layout]
             [clojure.java.io :as io]
             [clojure.java.shell :refer [sh]]
-            [clojure.test :refer :all]))
+            [clojure.test :refer :all]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]))
 
 (def model-dir "data/")
 (def model-path-prefix (str model-dir "resnet50_ssd/resnet50_ssd_model"))
@@ -41,27 +42,41 @@
   (let [detector (create-detector)
         image (infer/load-image-from-file "test/test-images/kitten.jpg")
         [predictions-all] (infer/detect-objects detector image)
-        [predictions] (infer/detect-objects detector image 5)]
+        [predictions] (infer/detect-objects detector image 5)
+        {:keys [class prob x-min x-max y-min y-max] :as pred} (first 
predictions)]
     (is (some? predictions))
     (is (= 5 (count predictions)))
     (is (= 13 (count predictions-all)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(= 5 (count (second %))) predictions))
-    (is (every? #(< 0 (first (second %)) 1) predictions))
-    (is (= "cat" (first (first predictions))))))
+    (is (= "cat" class))
+    (is (< 0.8 prob))
+    (every? #(< 0 % 1) [x-min x-max y-min y-max])))
 
 (deftest test-batch-detection
   (let [detector (create-detector)
         image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
                                              
"test/test-images/Pug-Cookie.jpg"])
-        batch-predictions-all (infer/detect-objects-batch detector image-batch)
-        batch-predictions (infer/detect-objects-batch detector image-batch 5)
-        predictions (first batch-predictions)]
-    (is (some? batch-predictions))
-    (is (= 13 (count (first batch-predictions-all))))
+        [batch-predictions-all] (infer/detect-objects-batch detector 
image-batch)
+        [predictions] (infer/detect-objects-batch detector image-batch 5)
+        {:keys [class prob x-min x-max y-min y-max] :as pred} (first 
predictions)]
+    (is (some? predictions))
+    (is (= 13 (count batch-predictions-all)))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(= 5 (count (second %))) predictions))
-    (is (every? #(< 0 (first (second %)) 1) predictions))))
+    (is (= "cat" class))
+    (is (< 0.8 prob))
+    (every? #(< 0 % 1) [x-min x-max y-min y-max])))
+
+(deftest test-detection-with-ndarrays
+  (let [detector (create-detector)
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 512 512)
+                  (infer/buffered-image-to-pixels [3 512 512] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
+        [predictions-all] (infer/detect-objects-with-ndarrays detector [image])
+        [predictions] (infer/detect-objects-with-ndarrays detector [image] 1)
+        {:keys [class prob x-min x-max y-min y-max] :as pred} (first 
predictions)]
+        (is (some? predictions-all))
+        (is (= 1 (count predictions)))
+        (is (= "cat" class))
+        (is (< 0.8 prob))
+        (every? #(< 0 % 1) [x-min x-max y-min y-max])))
+
diff --git 
a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj
 
b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj
index 0e7532bc225..e1526be61fb 100644
--- 
a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj
+++ 
b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj
@@ -24,7 +24,8 @@
             [clojure.java.io :as io]
             [clojure.java.shell :refer [sh]]
             [clojure.string :refer [split]]
-            [clojure.test :refer :all]))
+            [clojure.test :refer :all]
+            [org.apache.clojure-mxnet.util :as util]))
 
 (def model-dir "data/")
 (def model-path-prefix (str model-dir "resnet-18/resnet-18"))
@@ -42,6 +43,22 @@
         factory (infer/model-factory model-path-prefix descriptors)]
     (infer/create-predictor factory)))
 
+(deftest predictor-test-with-ndarray
+  (let [predictor (create-predictor)
+        image-ndarray (-> "test/test-images/kitten.jpg"
+                           infer/load-image-from-file
+                           (infer/reshape-image width height)
+                           (infer/buffered-image-to-pixels [3 width height])
+                           (ndarray/expand-dims 0))
+        predictions (infer/predict-with-ndarray predictor [image-ndarray])
+        synset-file (-> (io/file model-path-prefix)
+                        (.getParent)
+                        (io/file "synset.txt"))
+        synset-names (split (slurp synset-file) #"\n")
+        [best-index] (ndarray/->int-vec (ndarray/argmax (first predictions) 1))
+        best-prediction (synset-names best-index)]
+    (is (= "n02123159 tiger cat" best-prediction))))
+
 (deftest predictor-test
   (let [predictor (create-predictor)
         image-ndarray (-> "test/test-images/kitten.jpg"
@@ -49,11 +66,12 @@
                           (infer/reshape-image width height)
                           (infer/buffered-image-to-pixels [3 width height])
                           (ndarray/expand-dims 0))
-        [predictions] (infer/predict-with-ndarray predictor [image-ndarray])
+        predictions (infer/predict predictor [(ndarray/->vec image-ndarray)])
         synset-file (-> (io/file model-path-prefix)
                         (.getParent)
                         (io/file "synset.txt"))
         synset-names (split (slurp synset-file) #"\n")
-        [best-index] (ndarray/->int-vec (ndarray/argmax predictions 1))
+        ndarray-preds (ndarray/array (first predictions) [1 1000])
+        [best-index] (ndarray/->int-vec (ndarray/argmax ndarray-preds 1))
         best-prediction (synset-names best-index)]
     (is (= "n02123159 tiger cat" best-prediction))))


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to