This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch clojure-infer-predict-tweak
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit e8e55a90b11686219f260b7ef62e0792b5e79c2c
Author: gigasquid <[email protected]>
AuthorDate: Fri Jan 11 16:00:20 2019 -0500

    Change image-classifier predictions from [class prob] pairs to {:class :prob} maps
---
 .../src/org/apache/clojure_mxnet/infer.clj         | 35 +++++++++++++---------
 .../clojure_mxnet/infer/imageclassifier_test.clj   | 34 +++++++--------------
 2 files changed, 32 insertions(+), 37 deletions(-)

diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
index bc5090f..801c717 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
@@ -103,12 +103,15 @@
 
 (s/def ::nil-or-int (s/nilable int?))
 
-(defn- format-predictions [predictions]
+(defn- format-detection-predictions "Converts [class [prob xmin ymin xmax ymax]] pairs into prediction maps." [predictions]
   (mapv (fn [[c p]]
           (let [[prob xmin ymin xmax ymax] (mapv float p)]
             {:class c :prob prob :x-min xmin :y-min ymin :x-max xmax :y-max ymax}))
         predictions))
 
+(defn- format-classification-predictions "Converts [class prob] pairs into {:class class :prob prob} maps." [predictions]
+  (mapv (fn [[c p]] {:class c :prob p}) predictions))
+
 (extend-protocol AClassifier
   WrappedClassifier
   (classify
@@ -181,11 +184,13 @@
      (util/validate! ::image image "Invalid image")
      (util/validate! ::nil-or-int topk "Invalid top-K")
      (util/validate! ::dtype dtype "Invalid dtype")
-     (util/coerce-return-recursive
-      (.classifyImage (:image-classifier wrapped-image-classifier)
-                      image
-                      (util/->int-option topk)
-                      dtype))))
+     (-> (.classifyImage (:image-classifier wrapped-image-classifier)
+                         image
+                         (util/->int-option topk)
+                         dtype)
+         (util/coerce-return-recursive)
+         (first)
+         (format-classification-predictions))))
   (classify-image-batch
     ([wrapped-image-classifier images]
      (classify-image-batch wrapped-image-classifier images nil dtype/FLOAT32))
@@ -196,11 +201,13 @@
                      "Invalid classifier")
      (util/validate! ::nil-or-int topk "Invalid top-K")
      (util/validate! ::dtype dtype "Invalid dtype")
-     (util/coerce-return-recursive
-      (.classifyImageBatch (:image-classifier wrapped-image-classifier)
-                           images
-                           (util/->int-option topk)
-                           dtype)))))
+     (->> (.classifyImageBatch (:image-classifier wrapped-image-classifier)
+                               images
+                               (util/->int-option topk)
+                               dtype)
+          (util/coerce-return-recursive)
+          ;; one vector of {:class :prob} maps per image in the batch
+          (mapv format-classification-predictions)))))
 
 (extend-protocol AObjectDetector
   WrappedObjectDetector
@@ -217,7 +224,7 @@
                               (util/->int-option topk))
           (util/coerce-return-recursive)
           (first)
-          (format-predictions))))
+          (format-detection-predictions))))
   (detect-objects-batch
     ([wrapped-detector images]
      (detect-objects-batch wrapped-detector images nil))
@@ -230,7 +237,7 @@
                                    (util/->int-option topk))
           (util/coerce-return-recursive)
           (first)
-          (format-predictions))))
+          (format-detection-predictions))))
   (detect-objects-with-ndarrays
     ([wrapped-detector input-arrays]
      (detect-objects-with-ndarrays wrapped-detector input-arrays nil))
@@ -245,7 +252,7 @@
                                     (util/->int-option topk))
           (util/coerce-return-recursive)
           (first)
-          (format-predictions)))))
+          (format-detection-predictions)))))
 
 (defprotocol AInferenceFactory
   (create-predictor [factory] [factory opts])
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
index b459b06..448a52f 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
@@ -40,23 +40,15 @@
 (deftest test-single-classification
   (let [classifier (create-classifier)
         image (infer/load-image-from-file "test/test-images/kitten.jpg")
-        [predictions-all] (infer/classify-image classifier image)
-        [predictions-with-default-dtype] (infer/classify-image classifier image 10)
-        [predictions] (infer/classify-image classifier image 5 dtype/FLOAT32)]
+        predictions-all (infer/classify-image classifier image)
+        predictions-with-default-dtype (infer/classify-image classifier image 10)
+        predictions (infer/classify-image classifier image 5 dtype/FLOAT32)]
+    (is (every? #(and (string? (:class %)) (float? (:prob %))) predictions))
     (is (= 1000 (count predictions-all)))
     (is (= 10 (count predictions-with-default-dtype)))
-    (is (some? predictions))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(float? (second %)) predictions))
-    (is (every? #(< 0 (second %) 1) predictions))
-    (is (= ["n02123159 tiger cat"
-            "n02124075 Egyptian cat"
-            "n02123045 tabby, tabby cat"
-            "n02127052 lynx, catamount"
-            "n02128757 snow leopard, ounce, Panthera uncia"]
-           (map first predictions)))))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (< 0 (:prob (first predictions)) 1))))
 
 (deftest test-batch-classification
   (let [classifier (create-classifier)
@@ -64,13 +56,9 @@
                                              "test/test-images/Pug-Cookie.jpg"])
         batch-predictions-all (infer/classify-image-batch classifier image-batch)
         batch-predictions-with-default-dtype (infer/classify-image-batch classifier image-batch 10)
-        batch-predictions (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32)
-        predictions (first batch-predictions)]
-    (is (= 1000 (count (first batch-predictions-all))))
-    (is (= 10 (count (first batch-predictions-with-default-dtype))))
-    (is (some? batch-predictions))
+        [predictions] (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32)]
+    (is (= [1000 1000] (map count batch-predictions-all)))
+    (is (= [10 10] (map count batch-predictions-with-default-dtype)))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(float? (second %)) predictions))
-    (is (every? #(< 0 (second %) 1) predictions))))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (< 0 (:prob (first predictions)) 1))))

Reply via email to