This is an automated email from the ASF dual-hosted git repository. cmeier pushed a commit to branch clojure-infer-predict-tweak in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 488843ee575e02ebf752bdfdc8891e14490994ee Author: gigasquid <[email protected]> AuthorDate: Fri Jan 11 14:24:57 2019 -0500 change object detection prediction to be a map --- .../src/org/apache/clojure_mxnet/infer.clj | 36 ++++++++++------ .../clojure_mxnet/infer/objectdetector_test.clj | 49 +++++++++++++++------- 2 files changed, 57 insertions(+), 28 deletions(-) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj index 224a392..bc5090f 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj @@ -103,6 +103,12 @@ (s/def ::nil-or-int (s/nilable int?)) +(defn- format-predictions [predictions] + (mapv (fn [[c p]] + (let [[prob xmin ymin xmax ymax] (mapv float p)] + {:class c :prob prob :x-min xmin :y-min ymin :x-max xmax :y-max ymax})) + predictions)) + (extend-protocol AClassifier WrappedClassifier (classify @@ -206,10 +212,12 @@ "Invalid object detector") (util/validate! ::image image "Invalid image") (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.imageObjectDetect (:object-detector wrapped-detector) - image - (util/->int-option topk))))) + (->> (.imageObjectDetect (:object-detector wrapped-detector) + image + (util/->int-option topk)) + (util/coerce-return-recursive) + (first) + (format-predictions)))) (detect-objects-batch ([wrapped-detector images] (detect-objects-batch wrapped-detector images nil)) @@ -217,10 +225,12 @@ (util/validate! ::wrapped-detector wrapped-detector "Invalid object detector") (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.imageBatchObjectDetect (:object-detector wrapped-detector) - images - (util/->int-option topk))))) + (->> (.imageBatchObjectDetect (:object-detector wrapped-detector) + images + (util/->int-option topk)) + (util/coerce-return-recursive) + (first) + (format-predictions)))) (detect-objects-with-ndarrays ([wrapped-detector input-arrays] (detect-objects-with-ndarrays wrapped-detector input-arrays nil)) @@ -230,10 +240,12 @@ (util/validate! ::vec-of-ndarrays input-arrays "Invalid inputs") (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.objectDetectWithNDArray (:object-detector wrapped-detector) - (util/vec->indexed-seq input-arrays) - (util/->int-option topk)))))) + (->> (.objectDetectWithNDArray (:object-detector wrapped-detector) + (util/vec->indexed-seq input-arrays) + (util/->int-option topk)) + (util/coerce-return-recursive) + (first) + (format-predictions))))) (defprotocol AInferenceFactory (create-predictor [factory] [factory opts]) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj index 3a0e3d3..91d4f0e 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj @@ -17,11 +17,13 @@ (ns org.apache.clojure-mxnet.infer.objectdetector-test (:require [org.apache.clojure-mxnet.context :as context] [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.image :as image] [org.apache.clojure-mxnet.infer :as infer] [org.apache.clojure-mxnet.layout :as layout] [clojure.java.io :as io] [clojure.java.shell :refer [sh]] - [clojure.test :refer :all])) + [clojure.test :refer :all] + [org.apache.clojure-mxnet.ndarray :as ndarray])) (def model-dir "data/") (def model-path-prefix (str model-dir "resnet50_ssd/resnet50_ssd_model")) @@ -40,28 +42,43 @@ (deftest test-single-detection (let [detector (create-detector) image (infer/load-image-from-file "test/test-images/kitten.jpg") - [predictions-all] (infer/detect-objects detector image) - [predictions] (infer/detect-objects detector image 5)] + predictions-all (infer/detect-objects detector image) + predictions (infer/detect-objects detector image 5) + {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)] (is (some? predictions)) (is (= 5 (count predictions))) (is (= 13 (count predictions-all))) - (is (every? #(= 2 (count %)) predictions)) - (is (every? #(string? (first %)) predictions)) - (is (every? #(= 5 (count (second %))) predictions)) - (is (every? #(< 0 (first (second %)) 1) predictions)) - (is (= "cat" (first (first predictions)))))) + (is (= "cat" class)) + (is (< 0.8 prob)) + (every? #(< 0 % 1) [x-min x-max y-min y-max]))) (deftest test-batch-detection (let [detector (create-detector) image-batch (infer/load-image-paths ["test/test-images/kitten.jpg" "test/test-images/Pug-Cookie.jpg"]) batch-predictions-all (infer/detect-objects-batch detector image-batch) - batch-predictions (infer/detect-objects-batch detector image-batch 5) - predictions (first batch-predictions)] - (is (some? batch-predictions)) - (is (= 13 (count (first batch-predictions-all)))) + predictions (infer/detect-objects-batch detector image-batch 5) + {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)] + (is (some? predictions)) + (is (= 13 (count batch-predictions-all))) (is (= 5 (count predictions))) - (is (every? #(= 2 (count %)) predictions)) - (is (every? #(string? (first %)) predictions)) - (is (every? #(= 5 (count (second %))) predictions)) - (is (every? #(< 0 (first (second %)) 1) predictions)))) + (is (= "cat" class)) + (is (< 0.8 prob)) + (every? #(< 0 % 1) [x-min x-max y-min y-max]))) + +(deftest test-detection-with-ndarrays + (let [detector (create-detector) + image (-> (image/read-image "test/test-images/kitten.jpg" {:to-rbg true}) + (image/resize-image 512 512) + (ndarray/transpose) + (ndarray/expand-dims 0) + (ndarray/cast dtype/FLOAT32)) + predictions-all (infer/detect-objects-with-ndarrays detector [image]) + predictions (infer/detect-objects-with-ndarrays detector [image] 1) + {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)] + (is (some? predictions-all)) + (is (= 1 (count predictions))) + (is (= "cat" class)) + (is (< 0.8 prob)) + (every? #(< 0 % 1) [x-min x-max y-min y-max]))) +
