You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2019/01/11 21:01:04 UTC
[incubator-mxnet] 01/02: change object detection prediction to be a
map
This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch clojure-infer-predict-tweak
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 488843ee575e02ebf752bdfdc8891e14490994ee
Author: gigasquid <cm...@gigasquidsoftware.com>
AuthorDate: Fri Jan 11 14:24:57 2019 -0500
change object detection prediction to be a map
---
.../src/org/apache/clojure_mxnet/infer.clj | 36 ++++++++++------
.../clojure_mxnet/infer/objectdetector_test.clj | 49 +++++++++++++++-------
2 files changed, 57 insertions(+), 28 deletions(-)
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
index 224a392..bc5090f 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
@@ -103,6 +103,12 @@
(s/def ::nil-or-int (s/nilable int?))
+(defn- format-predictions [predictions]
+ (mapv (fn [[c p]]
+ (let [[prob xmin ymin xmax ymax] (mapv float p)]
+ {:class c :prob prob :x-min xmin :y-min ymin :x-max xmax :y-max ymax}))
+ predictions))
+
(extend-protocol AClassifier
WrappedClassifier
(classify
@@ -206,10 +212,12 @@
"Invalid object detector")
(util/validate! ::image image "Invalid image")
(util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.imageObjectDetect (:object-detector wrapped-detector)
- image
- (util/->int-option topk)))))
+ (->> (.imageObjectDetect (:object-detector wrapped-detector)
+ image
+ (util/->int-option topk))
+ (util/coerce-return-recursive)
+ (first)
+ (format-predictions))))
(detect-objects-batch
([wrapped-detector images]
(detect-objects-batch wrapped-detector images nil))
@@ -217,10 +225,12 @@
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.imageBatchObjectDetect (:object-detector wrapped-detector)
- images
- (util/->int-option topk)))))
+ (->> (.imageBatchObjectDetect (:object-detector wrapped-detector)
+ images
+ (util/->int-option topk))
+ (util/coerce-return-recursive)
+ (first)
+ (format-predictions))))
(detect-objects-with-ndarrays
([wrapped-detector input-arrays]
(detect-objects-with-ndarrays wrapped-detector input-arrays nil))
@@ -230,10 +240,12 @@
(util/validate! ::vec-of-ndarrays input-arrays
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.objectDetectWithNDArray (:object-detector wrapped-detector)
- (util/vec->indexed-seq input-arrays)
- (util/->int-option topk))))))
+ (->> (.objectDetectWithNDArray (:object-detector wrapped-detector)
+ (util/vec->indexed-seq input-arrays)
+ (util/->int-option topk))
+ (util/coerce-return-recursive)
+ (first)
+ (format-predictions)))))
(defprotocol AInferenceFactory
(create-predictor [factory] [factory opts])
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
index 3a0e3d3..91d4f0e 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
@@ -17,11 +17,13 @@
(ns org.apache.clojure-mxnet.infer.objectdetector-test
(:require [org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
+ [org.apache.clojure-mxnet.image :as image]
[org.apache.clojure-mxnet.infer :as infer]
[org.apache.clojure-mxnet.layout :as layout]
[clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
- [clojure.test :refer :all]))
+ [clojure.test :refer :all]
+ [org.apache.clojure-mxnet.ndarray :as ndarray]))
(def model-dir "data/")
(def model-path-prefix (str model-dir "resnet50_ssd/resnet50_ssd_model"))
@@ -40,28 +42,43 @@
(deftest test-single-detection
(let [detector (create-detector)
image (infer/load-image-from-file "test/test-images/kitten.jpg")
- [predictions-all] (infer/detect-objects detector image)
- [predictions] (infer/detect-objects detector image 5)]
+ predictions-all (infer/detect-objects detector image)
+ predictions (infer/detect-objects detector image 5)
+ {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]
(is (some? predictions))
(is (= 5 (count predictions)))
(is (= 13 (count predictions-all)))
- (is (every? #(= 2 (count %)) predictions))
- (is (every? #(string? (first %)) predictions))
- (is (every? #(= 5 (count (second %))) predictions))
- (is (every? #(< 0 (first (second %)) 1) predictions))
- (is (= "cat" (first (first predictions))))))
+ (is (= "cat" class))
+ (is (< 0.8 prob))
+ (every? #(< 0 % 1) [x-min x-max y-min y-max])))
(deftest test-batch-detection
(let [detector (create-detector)
image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
"test/test-images/Pug-Cookie.jpg"])
batch-predictions-all (infer/detect-objects-batch detector image-batch)
- batch-predictions (infer/detect-objects-batch detector image-batch 5)
- predictions (first batch-predictions)]
- (is (some? batch-predictions))
- (is (= 13 (count (first batch-predictions-all))))
+ predictions (infer/detect-objects-batch detector image-batch 5)
+ {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]
+ (is (some? predictions))
+ (is (= 13 (count batch-predictions-all)))
(is (= 5 (count predictions)))
- (is (every? #(= 2 (count %)) predictions))
- (is (every? #(string? (first %)) predictions))
- (is (every? #(= 5 (count (second %))) predictions))
- (is (every? #(< 0 (first (second %)) 1) predictions))))
+ (is (= "cat" class))
+ (is (< 0.8 prob))
+ (every? #(< 0 % 1) [x-min x-max y-min y-max])))
+
+(deftest test-detection-with-ndarrays
+ (let [detector (create-detector)
+ image (-> (image/read-image "test/test-images/kitten.jpg" {:to-rbg true})
+ (image/resize-image 512 512)
+ (ndarray/transpose)
+ (ndarray/expand-dims 0)
+ (ndarray/cast dtype/FLOAT32))
+ predictions-all (infer/detect-objects-with-ndarrays detector [image])
+ predictions (infer/detect-objects-with-ndarrays detector [image] 1)
+ {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]
+ (is (some? predictions-all))
+ (is (= 1 (count predictions)))
+ (is (= "cat" class))
+ (is (< 0.8 prob))
+ (every? #(< 0 % 1) [x-min x-max y-min y-max])))
+