This is an automated email from the ASF dual-hosted git repository. cmeier pushed a commit to branch clojure-infer-predict-tweak in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit e8e55a90b11686219f260b7ef62e0792b5e79c2c Author: gigasquid <[email protected]> AuthorDate: Fri Jan 11 16:00:20 2019 -0500 change predictions to a map for image-classifiers --- .../src/org/apache/clojure_mxnet/infer.clj | 35 +++++++++++++--------- .../clojure_mxnet/infer/imageclassifier_test.clj | 34 +++++++-------------- 2 files changed, 32 insertions(+), 37 deletions(-) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj index bc5090f..801c717 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj @@ -103,12 +103,15 @@ (s/def ::nil-or-int (s/nilable int?)) -(defn- format-predictions [predictions] +(defn- format-detection-predictions [predictions] (mapv (fn [[c p]] (let [[prob xmin ymin xmax ymax] (mapv float p)] {:class c :prob prob :x-min xmin :y-min ymin :x-max xmax :y-max ymax})) predictions)) +(defn- format-classification-predictions [predictions] + (mapv (fn [[c p]] {:class c :prob p}) predictions)) + (extend-protocol AClassifier WrappedClassifier (classify @@ -181,11 +184,13 @@ (util/validate! ::image image "Invalid image") (util/validate! ::nil-or-int topk "Invalid top-K") (util/validate! ::dtype dtype "Invalid dtype") - (util/coerce-return-recursive - (.classifyImage (:image-classifier wrapped-image-classifier) - image - (util/->int-option topk) - dtype)))) + (-> (.classifyImage (:image-classifier wrapped-image-classifier) + image + (util/->int-option topk) + dtype) + (util/coerce-return-recursive) + (first) + (format-classification-predictions)))) (classify-image-batch ([wrapped-image-classifier images] (classify-image-batch wrapped-image-classifier images nil dtype/FLOAT32)) @@ -196,11 +201,13 @@ "Invalid classifier") (util/validate! ::nil-or-int topk "Invalid top-K") (util/validate! ::dtype dtype "Invalid dtype") - (util/coerce-return-recursive - (.classifyImageBatch (:image-classifier wrapped-image-classifier) - images - (util/->int-option topk) - dtype))))) + (-> (.classifyImageBatch (:image-classifier wrapped-image-classifier) + images + (util/->int-option topk) + dtype) + (util/coerce-return-recursive) + (first) + (format-classification-predictions))))) (extend-protocol AObjectDetector WrappedObjectDetector @@ -217,7 +224,7 @@ (util/->int-option topk)) (util/coerce-return-recursive) (first) - (format-predictions)))) + (format-detection-predictions)))) (detect-objects-batch ([wrapped-detector images] (detect-objects-batch wrapped-detector images nil)) @@ -230,7 +237,7 @@ (util/->int-option topk)) (util/coerce-return-recursive) (first) - (format-predictions)))) + (format-detection-predictions)))) (detect-objects-with-ndarrays ([wrapped-detector input-arrays] (detect-objects-with-ndarrays wrapped-detector input-arrays nil)) @@ -245,7 +252,7 @@ (util/->int-option topk)) (util/coerce-return-recursive) (first) - (format-predictions))))) + (format-detection-predictions))))) (defprotocol AInferenceFactory (create-predictor [factory] [factory opts]) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj index b459b06..448a52f 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj @@ -40,23 +40,15 @@ (deftest test-single-classification (let [classifier (create-classifier) image (infer/load-image-from-file "test/test-images/kitten.jpg") - [predictions-all] (infer/classify-image classifier image) - [predictions-with-default-dtype] (infer/classify-image classifier image 10) - [predictions] (infer/classify-image classifier image 5 dtype/FLOAT32)] + predictions-all (infer/classify-image classifier image) + predictions-with-default-dtype (infer/classify-image classifier image 10) + predictions (infer/classify-image classifier image 5 dtype/FLOAT32)] + predictions (is (= 1000 (count predictions-all))) (is (= 10 (count predictions-with-default-dtype))) - (is (some? predictions)) (is (= 5 (count predictions))) - (is (every? #(= 2 (count %)) predictions)) - (is (every? #(string? (first %)) predictions)) - (is (every? #(float? (second %)) predictions)) - (is (every? #(< 0 (second %) 1) predictions)) - (is (= ["n02123159 tiger cat" - "n02124075 Egyptian cat" - "n02123045 tabby, tabby cat" - "n02127052 lynx, catamount" - "n02128757 snow leopard, ounce, Panthera uncia"] - (map first predictions))))) + (is (= "n02123159 tiger cat" (:class (first predictions)))) + (is (= (< 0 (:prob (first predictions)) 1))))) (deftest test-batch-classification (let [classifier (create-classifier) @@ -64,13 +56,9 @@ "test/test-images/Pug-Cookie.jpg"]) batch-predictions-all (infer/classify-image-batch classifier image-batch) batch-predictions-with-default-dtype (infer/classify-image-batch classifier image-batch 10) - batch-predictions (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32) - predictions (first batch-predictions)] - (is (= 1000 (count (first batch-predictions-all)))) - (is (= 10 (count (first batch-predictions-with-default-dtype)))) - (is (some? batch-predictions)) + predictions (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32)] + (is (= 1000 (count batch-predictions-all))) + (is (= 10 (count batch-predictions-with-default-dtype))) (is (= 5 (count predictions))) - (is (every? #(= 2 (count %)) predictions)) - (is (every? #(string? (first %)) predictions)) - (is (every? #(float? (second %)) predictions)) - (is (every? #(< 0 (second %) 1) predictions)))) + (is (= "n02123159 tiger cat" (:class (first predictions)))) + (is (= (< 0 (:prob (first predictions)) 1)))))
