You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2019/01/11 21:01:04 UTC

[incubator-mxnet] 01/02: change object detection prediction to be a map

This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch clojure-infer-predict-tweak
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 488843ee575e02ebf752bdfdc8891e14490994ee
Author: gigasquid <cm...@gigasquidsoftware.com>
AuthorDate: Fri Jan 11 14:24:57 2019 -0500

    change object detection prediction to be a map
---
 .../src/org/apache/clojure_mxnet/infer.clj         | 37 ++++++++++++------
 .../clojure_mxnet/infer/objectdetector_test.clj    | 45 +++++++++++++++-------
 2 files changed, 57 insertions(+), 25 deletions(-)

diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
index 224a392..bc5090f 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
@@ -103,6 +103,15 @@
 
 (s/def ::nil-or-int (s/nilable int?))
 
+(defn- format-predictions
+  "Converts [class [prob xmin ymin xmax ymax]] detector output pairs into a
+  vector of {:class :prob :x-min :y-min :x-max :y-max} maps (numbers as floats)."
+  [predictions]
+  (mapv (fn [[c p]]
+          (let [[prob xmin ymin xmax ymax] (mapv float p)]
+            {:class c :prob prob :x-min xmin :y-min ymin :x-max xmax :y-max ymax}))
+        predictions))
+
 (extend-protocol AClassifier
   WrappedClassifier
   (classify
@@ -206,10 +215,12 @@
                     "Invalid object detector")
      (util/validate! ::image image "Invalid image")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.imageObjectDetect (:object-detector wrapped-detector)
-                          image
-                          (util/->int-option topk)))))
+     (->> (.imageObjectDetect (:object-detector wrapped-detector)
+                              image
+                              (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (first)
+          (format-predictions))))
   (detect-objects-batch
     ([wrapped-detector images]
      (detect-objects-batch wrapped-detector images nil))
@@ -217,10 +228,11 @@
      (util/validate! ::wrapped-detector wrapped-detector
                      "Invalid object detector")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.imageBatchObjectDetect (:object-detector wrapped-detector)
-                               images
-                               (util/->int-option topk)))))
+     (->> (.imageBatchObjectDetect (:object-detector wrapped-detector)
+                                   images
+                                   (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (mapv format-predictions))))
   (detect-objects-with-ndarrays
     ([wrapped-detector input-arrays]
      (detect-objects-with-ndarrays wrapped-detector input-arrays nil))
@@ -230,10 +242,11 @@
      (util/validate! ::vec-of-ndarrays input-arrays
                      "Invalid inputs")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.objectDetectWithNDArray (:object-detector wrapped-detector)
-                                (util/vec->indexed-seq input-arrays)
-                                (util/->int-option topk))))))
+     (->> (.objectDetectWithNDArray (:object-detector wrapped-detector)
+                                    (util/vec->indexed-seq input-arrays)
+                                    (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (mapv format-predictions)))))
 
 (defprotocol AInferenceFactory
   (create-predictor [factory] [factory opts])
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
index 3a0e3d3..91d4f0e 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
@@ -17,11 +17,13 @@
 (ns org.apache.clojure-mxnet.infer.objectdetector-test
   (:require [org.apache.clojure-mxnet.context :as context]
             [org.apache.clojure-mxnet.dtype :as dtype]
+            [org.apache.clojure-mxnet.image :as image]
             [org.apache.clojure-mxnet.infer :as infer]
             [org.apache.clojure-mxnet.layout :as layout]
             [clojure.java.io :as io]
             [clojure.java.shell :refer [sh]]
-            [clojure.test :refer :all]))
+            [clojure.test :refer :all]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]))
 
 (def model-dir "data/")
 (def model-path-prefix (str model-dir "resnet50_ssd/resnet50_ssd_model"))
@@ -40,28 +42,45 @@
 (deftest test-single-detection
   (let [detector (create-detector)
         image (infer/load-image-from-file "test/test-images/kitten.jpg")
-        [predictions-all] (infer/detect-objects detector image)
-        [predictions] (infer/detect-objects detector image 5)]
+        predictions-all (infer/detect-objects detector image)
+        predictions (infer/detect-objects detector image 5)
+        {:keys [class prob x-min x-max y-min y-max]} (first predictions)]
     (is (some? predictions))
     (is (= 5 (count predictions)))
     (is (= 13 (count predictions-all)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(= 5 (count (second %))) predictions))
-    (is (every? #(< 0 (first (second %)) 1) predictions))
-    (is (= "cat" (first (first predictions))))))
+    (is (= "cat" class))
+    (is (< 0.8 prob))
+    (is (every? #(< 0 % 1) [x-min x-max y-min y-max]))))
 
 (deftest test-batch-detection
   (let [detector (create-detector)
         image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
                                              "test/test-images/Pug-Cookie.jpg"])
         batch-predictions-all (infer/detect-objects-batch detector image-batch)
         batch-predictions (infer/detect-objects-batch detector image-batch 5)
-        predictions (first batch-predictions)]
+        predictions (first batch-predictions)
+        {:keys [class prob x-min x-max y-min y-max]} (first predictions)]
     (is (some? batch-predictions))
+    (is (= 2 (count batch-predictions)))
     (is (= 13 (count (first batch-predictions-all))))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(= 5 (count (second %))) predictions))
-    (is (every? #(< 0 (first (second %)) 1) predictions))))
+    (is (= "cat" class))
+    (is (< 0.8 prob))
+    (is (every? #(< 0 % 1) [x-min x-max y-min y-max]))))
+
+(deftest test-detection-with-ndarrays
+  (let [detector (create-detector)
+        image     (-> (image/read-image "test/test-images/kitten.jpg" {:to-rgb true})
+                      (image/resize-image 512 512)
+                      (ndarray/transpose)
+                      (ndarray/expand-dims 0)
+                      (ndarray/cast dtype/FLOAT32))
+        [predictions-all] (infer/detect-objects-with-ndarrays detector [image])
+        [predictions] (infer/detect-objects-with-ndarrays detector [image] 1)
+        {:keys [class prob x-min x-max y-min y-max]} (first predictions)]
+    (is (some? predictions-all))
+    (is (= 1 (count predictions)))
+    (is (= "cat" class))
+    (is (< 0.8 prob))
+    (is (every? #(< 0 % 1) [x-min x-max y-min y-max]))))
+