You are viewing a plain-text version of this content. The canonical version is linked from the original archive page (the "here" link is not preserved in plain text).
Posted to commits@mxnet.apache.org by cm...@apache.org on 2019/01/11 22:21:56 UTC

[incubator-mxnet] branch clojure-infer-predict-tweak updated: change the classifiers' return values to maps; add tests for the base classifier and the with-ndarray variants

This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch clojure-infer-predict-tweak
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/clojure-infer-predict-tweak by this push:
     new 3ec2817  change return types of the classifiers to be a map - add tests for base classifier and with-ndarray as well
3ec2817 is described below

commit 3ec28174192abf37893ea26f96ef83dc8c8d3e74
Author: gigasquid <cm...@gigasquidsoftware.com>
AuthorDate: Fri Jan 11 17:21:31 2019 -0500

    change the classifiers' return values to vectors of maps
    - add tests for the base classifier and the with-ndarray variants as well
---
 .../src/org/apache/clojure_mxnet/infer.clj         | 85 ++++++++++++++--------
 .../clojure_mxnet/infer/imageclassifier_test.clj   | 65 +++++++++++++++++
 .../clojure_mxnet/infer/objectdetector_test.clj    | 10 +--
 3 files changed, 123 insertions(+), 37 deletions(-)

diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
index 801c717..ced8f10 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
@@ -22,7 +22,8 @@
             [org.apache.clojure-mxnet.io :as mx-io]
             [org.apache.clojure-mxnet.shape :as shape]
             [org.apache.clojure-mxnet.util :as util]
-            [clojure.spec.alpha :as s])
+            [clojure.spec.alpha :as s]
+            [org.apache.clojure-mxnet.shape :as mx-shape])
   (:import (java.awt.image BufferedImage)
            (org.apache.mxnet NDArray)
            (org.apache.mxnet.infer Classifier ImageClassifier
@@ -39,15 +40,26 @@
 (defrecord WrappedObjectDetector [object-detector])
 
 (s/def ::ndarray #(instance? NDArray %))
-(s/def ::float-array (s/and #(.isArray (class %)) #(every? float? %)))
-(s/def ::vec-of-float-arrays (s/coll-of ::float-array :kind vector?))
+(s/def ::number-array (s/coll-of number? :kind vector?)) ; a Clojure vector of numbers (not a Java array, despite the name)
+(s/def ::vvec-of-numbers (s/coll-of ::number-array :kind vector?)) ; a vector of ::number-array vectors
 (s/def ::vec-of-ndarrays (s/coll-of ::ndarray :kind vector?))
+(s/def ::image #(instance? BufferedImage %)) ; NOTE(review): ::image is defined again, identically, later in this namespace; one copy can be dropped
+(s/def ::batch-images (s/coll-of ::image :kind vector?)) ; a vector of BufferedImage instances
 
 (s/def ::wrapped-predictor (s/keys :req-un [::predictor]))
 (s/def ::wrapped-classifier (s/keys :req-un [::classifier]))
 (s/def ::wrapped-image-classifier (s/keys :req-un [::image-classifier]))
 (s/def ::wrapped-detector (s/keys :req-un [::object-detector]))
 
+(defn- format-detection-predictions "Turns each [class [prob xmin ymin xmax ymax]] prediction into a map with :class, :prob, :x-min, :y-min, :x-max, :y-max (numbers coerced to float); returns a vector." [predictions]
+  (mapv (fn [[c p]]
+          (let [[prob xmin ymin xmax ymax] (mapv float p)]
+            {:class c :prob prob :x-min xmin :y-min ymin :x-max xmax :y-max ymax}))
+        predictions))
+
+(defn- format-classification-predictions "Turns each [class prob] prediction into a {:class class :prob prob} map; returns a vector." [predictions]
+  (mapv (fn [[c p]] {:class c :prob p}) predictions))
+
 (defprotocol APredictor
   (predict [wrapped-predictor inputs])
   (predict-with-ndarray [wrapped-predictor input-arrays]))
@@ -103,15 +115,6 @@
 
 (s/def ::nil-or-int (s/nilable int?))
 
-(defn- format-detection-predictions [predictions]
-  (mapv (fn [[c p]]
-          (let [[prob xmin ymin xmax ymax] (mapv float p)]
-            {:class c :prob prob :x-min xmin :y-min ymin :x-max xmax :y-max ymax}))
-        predictions))
-
-(defn- format-classification-predictions [predictions]
-  (mapv (fn [[c p]] {:class c :prob p}) predictions))
-
 (extend-protocol AClassifier
   WrappedClassifier
   (classify
@@ -120,13 +123,14 @@
     ([wrapped-classifier inputs topk]
      (util/validate! ::wrapped-classifier wrapped-classifier
                      "Invalid classifier")
-     (util/validate! ::vec-of-float-arrays inputs
+     (util/validate! ::vvec-of-numbers inputs
                      "Invalid inputs")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.classify (:classifier wrapped-classifier)
-                 (util/vec->indexed-seq inputs)
-                 (util/->int-option topk)))))
+     (-> (.classify (:classifier wrapped-classifier)
+                 (util/vec->indexed-seq [(float-array (first inputs))])
+                 (util/->int-option topk))
+         (util/coerce-return-recursive)
+         (format-classification-predictions))))
   (classify-with-ndarray
     ([wrapped-classifier inputs]
      (classify-with-ndarray wrapped-classifier inputs nil))
@@ -136,10 +140,13 @@
      (util/validate! ::vec-of-ndarrays inputs
                      "Invalid inputs")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
+     (->
       (.classifyWithNDArray (:classifier wrapped-classifier)
                             (util/vec->indexed-seq inputs)
-                           (util/->int-option topk)))))
+                            (util/->int-option topk))
+      (util/coerce-return-recursive)
+      (first)
+      (format-classification-predictions))))
   WrappedImageClassifier
   (classify
     ([wrapped-image-classifier inputs]
@@ -147,13 +154,14 @@
     ([wrapped-image-classifier inputs topk]
      (util/validate! ::wrapped-image-classifier wrapped-image-classifier
                      "Invalid classifier")
-     (util/validate! ::vec-of-float-arrays inputs
+     (util/validate! ::vvec-of-numbers inputs
                      "Invalid inputs")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.classify (:image-classifier wrapped-image-classifier)
-                 (util/vec->indexed-seq inputs)
-                 (util/->int-option topk)))))
+     (-> (.classify (:image-classifier wrapped-image-classifier)
+                 (util/vec->indexed-seq [(float-array (first inputs))])
+                 (util/->int-option topk))
+         (util/coerce-return-recursive)
+         (format-classification-predictions))))
   (classify-with-ndarray
     ([wrapped-image-classifier inputs]
      (classify-with-ndarray wrapped-image-classifier inputs nil))
@@ -163,10 +171,12 @@
     (util/validate! ::vec-of-ndarrays inputs
                     "Invalid inputs")
     (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.classifyWithNDArray (:image-classifier wrapped-image-classifier)
-                           (util/vec->indexed-seq inputs)
-                           (util/->int-option topk))))))
+    (-> (.classifyWithNDArray (:image-classifier wrapped-image-classifier)
+                              (util/vec->indexed-seq inputs)
+                              (util/->int-option topk))
+        (util/coerce-return-recursive)
+        (first)
+        (format-classification-predictions)))))
 
 (s/def ::image #(instance? BufferedImage %))
 (s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})
@@ -199,10 +209,11 @@
     ([wrapped-image-classifier images topk dtype]
      (util/validate! ::wrapped-image-classifier wrapped-image-classifier
                      "Invalid classifier")
+     (util/validate! ::batch-images images "Invalid Batch Images")
      (util/validate! ::nil-or-int topk "Invalid top-K")
      (util/validate! ::dtype dtype "Invalid dtype")
      (-> (.classifyImageBatch (:image-classifier wrapped-image-classifier)
-                              images
+                              (util/vec->indexed-seq images)
                               (util/->int-option topk)
                               dtype)
          (util/coerce-return-recursive)
@@ -232,8 +243,9 @@
      (util/validate! ::wrapped-detector wrapped-detector
                      "Invalid object detector")
      (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/validate! ::batch-images images "Invalid Batch Images")
      (->> (.imageBatchObjectDetect (:object-detector wrapped-detector)
-                                   images
+                                   (util/vec->indexed-seq images)
                                    (util/->int-option topk))
           (util/coerce-return-recursive)
           (first)
@@ -361,4 +373,15 @@
   "Loads images from a list of file names"
   [image-paths]
   (util/validate! ::image-paths image-paths "Invalid image paths")
-  (ImageClassifier/loadInputBatch (util/convert-vector image-paths)))
+  (util/scala-vector->vec
+   (ImageClassifier/loadInputBatch (util/convert-vector image-paths))))
+
+(defn reshape-image
+  "Reshapes buffered-image to the given width and height (both coerced to int) via ImageClassifier/reshapeImage."
+  [buffered-image width height]
+  (ImageClassifier/reshapeImage buffered-image (int width) (int height)))
+
+(defn buffered-image-to-pixels
+  "Converts buffered-image into an NDArray via ImageClassifier/bufferedImageToPixels. shape-vec is a shape vector such as [3 224 224]; dtype is the element type, e.g. dtype/FLOAT32."
+  [buffered-image shape-vec dtype]
+  (ImageClassifier/bufferedImageToPixels buffered-image (shape/->shape shape-vec) dtype))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
index 448a52f..567ebc0 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
@@ -19,6 +19,7 @@
             [org.apache.clojure-mxnet.dtype :as dtype]
             [org.apache.clojure-mxnet.infer :as infer]
             [org.apache.clojure-mxnet.layout :as layout]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]
             [clojure.java.io :as io]
             [clojure.java.shell :refer [sh]]
             [clojure.test :refer :all]))
@@ -62,3 +63,67 @@
     (is (= 5 (count predictions)))
     (is (= "n02123159 tiger cat" (:class (first predictions))))
     (is (= (< 0 (:prob (first predictions)) 1)))))
+
+(deftest test-single-classification-with-ndarray ; ImageClassifier classify-with-ndarray on one NCHW kitten image: all classes vs. top-5
+  (let [classifier (create-classifier)
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 224 224)
+                  (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
+        predictions-all (infer/classify-with-ndarray classifier [image])
+        predictions (infer/classify-with-ndarray classifier [image] 5)]
+    (is (= 1000 (count predictions-all)))
+    (is (= 5 (count predictions)))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (< 0 (:prob (first predictions)) 1))))  ; top-1 probability must lie strictly in (0, 1)
+
+(deftest test-single-classify ; ImageClassifier classify on the flattened kitten image: all classes vs. top-5
+  (let [classifier (create-classifier)
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 224 224)
+                  (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
+        predictions-all (infer/classify classifier [(ndarray/->vec image)])
+        predictions (infer/classify classifier [(ndarray/->vec image)] 5)]
+    (is (= 1000 (count predictions-all)))
+    (is (= 5 (count predictions)))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (< 0 (:prob (first predictions)) 1))))  ; top-1 probability must lie strictly in (0, 1)
+
+(deftest test-base-classification-with-ndarray ; base Classifier classify-with-ndarray on one NCHW kitten image: all classes vs. top-5
+  (let [descriptors [{:name "data"
+                      :shape [1 3 224 224]
+                      :layout layout/NCHW
+                      :dtype dtype/FLOAT32}]
+        factory (infer/model-factory model-path-prefix descriptors)
+        classifier (infer/create-classifier factory)
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 224 224)
+                  (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
+        predictions-all (infer/classify-with-ndarray classifier [image])
+        predictions (infer/classify-with-ndarray classifier [image] 5)]
+    (is (= 1000 (count predictions-all)))
+    (is (= 5 (count predictions)))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (< 0 (:prob (first predictions)) 1))))  ; top-1 probability must lie strictly in (0, 1)
+
+(deftest test-base-single-classify ; base Classifier classify on the flattened kitten image: all classes vs. top-5
+  (let [descriptors [{:name "data"
+                      :shape [1 3 224 224]
+                      :layout layout/NCHW
+                      :dtype dtype/FLOAT32}]
+        factory (infer/model-factory model-path-prefix descriptors)
+        classifier (infer/create-classifier factory)
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 224 224)
+                  (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
+        predictions-all (infer/classify classifier [(ndarray/->vec image)])
+        predictions (infer/classify classifier [(ndarray/->vec image)] 5)]
+    (is (= 1000 (count predictions-all)))
+    (is (= 5 (count predictions)))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (< 0 (:prob (first predictions)) 1))))  ; top-1 probability must lie strictly in (0, 1)
+
+
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
index 91d4f0e..76acbfc 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
@@ -17,7 +17,6 @@
 (ns org.apache.clojure-mxnet.infer.objectdetector-test
   (:require [org.apache.clojure-mxnet.context :as context]
             [org.apache.clojure-mxnet.dtype :as dtype]
-            [org.apache.clojure-mxnet.image :as image]
             [org.apache.clojure-mxnet.infer :as infer]
             [org.apache.clojure-mxnet.layout :as layout]
             [clojure.java.io :as io]
@@ -68,11 +67,10 @@
 
 (deftest test-detection-with-ndarrays
   (let [detector (create-detector)
-        image     (-> (image/read-image "test/test-images/kitten.jpg" {:to-rbg true})
-                      (image/resize-image 512 512)
-                      (ndarray/transpose)
-                      (ndarray/expand-dims 0)
-                      (ndarray/cast dtype/FLOAT32))
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 512 512)
+                  (infer/buffered-image-to-pixels [3 512 512] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
         predictions-all (infer/detect-objects-with-ndarrays detector [image])
         predictions (infer/detect-objects-with-ndarrays detector [image] 1)
         {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]