You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2019/01/13 00:01:36 UTC

[incubator-mxnet] branch master updated: [Clojure] package infer tweaks (#13864)

This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new c2110ad  [Clojure] package infer tweaks (#13864)
c2110ad is described below

commit c2110ada6d43f10710d181c0deb0673fe6d829b2
Author: Carin Meier <cm...@gigasquidsoftware.com>
AuthorDate: Sat Jan 12 19:01:17 2019 -0500

    [Clojure] package infer tweaks (#13864)
    
    * change object detection prediction to be a map
    
    * change predictions to a map for image-classifiers
    
    * change return types of the classifiers to be a map
    - add tests for base classifier and with-ndarray as well
    
    * tweak return types and inputs for predict
    - add test for plain predict
    
    * updated infer-classify examples
    
    * adjust the infer/object detections tests
    
    * tweak predictor test
    
    * Feedback from @kedarbellare review
    
    * put scaling back in
    
    * put back predict so it can handle multiple inputs
    
    * restore original functions signatures (remove first)
---
 .../src/infer/imageclassifier_example.clj          |  19 ++-
 .../test/infer/imageclassifier_example_test.clj    |  25 ++--
 .../src/infer/objectdetector_example.clj           |  25 ++--
 .../test/infer/objectdetector_example_test.clj     |  24 ++--
 .../predictor/src/infer/predictor_example.clj      |   4 +-
 .../src/org/apache/clojure_mxnet/infer.clj         | 137 ++++++++++++---------
 .../clojure_mxnet/infer/imageclassifier_test.clj   |  96 +++++++++++----
 .../clojure_mxnet/infer/objectdetector_test.clj    |  47 ++++---
 .../apache/clojure_mxnet/infer/predictor_test.clj  |  24 +++-
 9 files changed, 250 insertions(+), 151 deletions(-)

diff --git a/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj b/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj
index 4ec7ff7..6994b4f 100644
--- a/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj
+++ b/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj
@@ -55,8 +55,8 @@
   "Print image classifier predictions for the given input file"
   [predictions]
   (println (apply str (repeat 80 "=")))
-  (doseq [[label probability] predictions]
-    (println (format "Class: %s Probability=%.8f" label probability)))
+  (doseq [p predictions]
+    (println p))
   (println (apply str (repeat 80 "="))))
 
 (defn classify-single-image
@@ -64,8 +64,8 @@
   [classifier input-image]
   (let [image (infer/load-image-from-file input-image)
         topk 5
-        [predictions] (infer/classify-image classifier image topk)]
-    predictions))
+        predictions (infer/classify-image classifier image topk)]
+    [predictions]))
 
 (defn classify-images-in-dir
   "Classify all jpg images in the directory"
@@ -78,12 +78,10 @@
                                 (filter #(re-matches #".*\.jpg$" (.getPath %)))
                                 (mapv #(.getPath %))
                                 (partition-all batch-size))]
-    (apply
-     concat
-     (for [image-files image-file-batches]
-       (let [image-batch (infer/load-image-paths image-files)
-             topk 5]
-         (infer/classify-image-batch classifier image-batch topk))))))
+    (apply concat (for [image-files image-file-batches]
+                    (let [image-batch (infer/load-image-paths image-files)
+                          topk 5]
+                      (infer/classify-image-batch classifier image-batch topk))))))
 
 (defn run-classifier
   "Runs an image classifier based on options provided"
@@ -98,6 +96,7 @@
                     factory {:contexts [(context/default-context)]})]
     (println "Classifying a single image")
     (print-predictions (classify-single-image classifier input-image))
+    (println "\n")
     (println "Classifying images in a directory")
     (doseq [predictions (classify-images-in-dir classifier input-dir)]
       (print-predictions predictions))))
diff --git a/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj b/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj
index 5b3e08d..4b71f84 100644
--- a/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj
+++ b/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj
@@ -43,27 +43,16 @@
 
 (deftest test-single-classification
   (let [classifier (create-classifier)
-        predictions (classify-single-image classifier image-file)]
+        [[predictions]] (classify-single-image classifier image-file)]
     (is (some? predictions))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(float? (second %)) predictions))
-    (is (every? #(< 0 (second %) 1) predictions))
-    (is (= ["n02123159 tiger cat"
-            "n02124075 Egyptian cat"
-            "n02123045 tabby, tabby cat"
-            "n02127052 lynx, catamount"
-            "n02128757 snow leopard, ounce, Panthera uncia"]
-           (map first predictions)))))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (< 0 (:prob (first predictions)) 1)))) ; probability must be strictly inside (0, 1)
 
 (deftest test-batch-classification
   (let [classifier (create-classifier)
-        batch-predictions (classify-images-in-dir classifier image-dir)
-        predictions (first batch-predictions)]
-    (is (some? batch-predictions))
+        predictions (first (classify-images-in-dir classifier image-dir))]
+    (is (some? predictions))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(float? (second %)) predictions))
-    (is (every? #(< 0 (second %) 1) predictions))))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (< 0 (:prob (first predictions)) 1)))) ; probability must be strictly inside (0, 1)
diff --git a/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj b/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj
index 53172f0..5c30e5d 100644
--- a/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj
+++ b/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj
@@ -54,15 +54,15 @@
   "Print image detector predictions for the given input file"
   [predictions width height]
   (println (apply str (repeat 80 "=")))
-  (doseq [[label prob-and-bounds] predictions]
+  (doseq [{:keys [class prob x-min y-min x-max y-max]} predictions]
     (println (format
               "Class: %s Prob=%.5f Coords=(%.3f, %.3f, %.3f, %.3f)"
-              label
-              (aget prob-and-bounds 0)
-              (* (aget prob-and-bounds 1) width)
-              (* (aget prob-and-bounds 2) height)
-              (* (aget prob-and-bounds 3) width)
-              (* (aget prob-and-bounds 4) height))))
+              class
+              prob
+              (* x-min width)
+              (* y-min height)
+              (* x-max width)
+              (* y-max height))))
   (println (apply str (repeat 80 "="))))
 
 (defn detect-single-image
@@ -84,12 +84,10 @@
                                 (filter #(re-matches #".*\.jpg$" (.getPath %)))
                                 (mapv #(.getPath %))
                                 (partition-all batch-size))]
-    (apply
-     concat
-     (for [image-files image-file-batches]
-       (let [image-batch (infer/load-image-paths image-files)
-             topk 5]
-         (infer/detect-objects-batch detector image-batch topk))))))
+    (apply concat (for [image-files image-file-batches]
+                    (let [image-batch (infer/load-image-paths image-files)
+                          topk 5]
+                      (infer/detect-objects-batch detector image-batch topk))))))
 
 (defn run-detector
   "Runs an image detector based on options provided"
@@ -107,6 +105,7 @@
                   {:contexts [(context/default-context)]})]
     (println "Object detection on a single image")
     (print-predictions (detect-single-image detector input-image) width height)
+    (println "\n")
     (println "Object detection on images in a directory")
     (doseq [predictions (detect-images-in-dir detector input-dir)]
       (print-predictions predictions width height))))
diff --git a/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj b/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj
index 90ed02f..2b8ad95 100644
--- a/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj
+++ b/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj
@@ -43,23 +43,23 @@
 
 (deftest test-single-detection
   (let [detector (create-detector)
-        predictions (detect-single-image detector image-file)]
+        predictions (detect-single-image detector image-file)
+        {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]
     (is (some? predictions))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(= 5 (count (second %))) predictions))
-    (is (every? #(< 0 (first (second %)) 1) predictions))
-    (is (= ["car" "bicycle" "dog" "bicycle" "person"]
-           (map first predictions)))))
+    (is (string? class))
+    (is (< 0.8 prob))
+    (is (every? #(< 0 % 1) [x-min x-max y-min y-max]))
+    (is (= #{"dog" "person" "bicycle" "car"} (set (mapv :class predictions))))))
 
 (deftest test-batch-detection
   (let [detector (create-detector)
         batch-predictions (detect-images-in-dir detector image-dir)
-        predictions (first batch-predictions)]
+        predictions (first batch-predictions)
+        {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]
     (is (some? batch-predictions))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(= 5 (count (second %))) predictions))
-    (is (every? #(< 0 (first (second %)) 1) predictions))))
+    (is (string? class))
+    (is (< 0.8 prob))
+    (is (every? #(< 0 % 1) [x-min x-max y-min y-max])) ; wrapped in `is` so the range check is actually asserted
+    (is (= #{"dog" "person" "bicycle" "car"} (set (mapv :class predictions))))))
diff --git a/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj b/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
index 4989641..05eb0ad 100644
--- a/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
+++ b/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
@@ -59,8 +59,8 @@
 (defn do-inference
   "Run inference using given predictor"
   [predictor image]
-  (let [[predictions] (infer/predict-with-ndarray predictor [image])]
-    predictions))
+  (let [predictions (infer/predict-with-ndarray predictor [image])]
+    (first predictions)))
 
 (defn postprocess
   [model-path-prefix predictions]
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
index 224a392..09edf15 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
@@ -22,7 +22,8 @@
             [org.apache.clojure-mxnet.io :as mx-io]
             [org.apache.clojure-mxnet.shape :as shape]
             [org.apache.clojure-mxnet.util :as util]
-            [clojure.spec.alpha :as s])
+            [clojure.spec.alpha :as s]
+            [org.apache.clojure-mxnet.shape :as mx-shape])
   (:import (java.awt.image BufferedImage)
            (org.apache.mxnet NDArray)
            (org.apache.mxnet.infer Classifier ImageClassifier
@@ -39,15 +40,26 @@
 (defrecord WrappedObjectDetector [object-detector])
 
 (s/def ::ndarray #(instance? NDArray %))
-(s/def ::float-array (s/and #(.isArray (class %)) #(every? float? %)))
-(s/def ::vec-of-float-arrays (s/coll-of ::float-array :kind vector?))
+(s/def ::number-array (s/coll-of number? :kind vector?))
+(s/def ::vvec-of-numbers (s/coll-of ::number-array :kind vector?))
 (s/def ::vec-of-ndarrays (s/coll-of ::ndarray :kind vector?))
+(s/def ::image #(instance? BufferedImage %))
+(s/def ::batch-images (s/coll-of ::image :kind vector?))
 
 (s/def ::wrapped-predictor (s/keys :req-un [::predictor]))
 (s/def ::wrapped-classifier (s/keys :req-un [::classifier]))
 (s/def ::wrapped-image-classifier (s/keys :req-un [::image-classifier]))
 (s/def ::wrapped-detector (s/keys :req-un [::object-detector]))
 
+(defn- format-detection-predictions [predictions]
+  (mapv (fn [[c p]]
+          (let [[prob xmin ymin xmax ymax] (mapv float p)]
+            {:class c :prob prob :x-min xmin :y-min ymin :x-max xmax :y-max ymax}))
+        predictions))
+
+(defn- format-classification-predictions [predictions]
+  (mapv (fn [[c p]] {:class c :prob p}) predictions))
+
 (defprotocol APredictor
   (predict [wrapped-predictor inputs])
   (predict-with-ndarray [wrapped-predictor input-arrays]))
@@ -87,19 +99,20 @@
     [wrapped-predictor inputs]
     (util/validate! ::wrapped-predictor wrapped-predictor
                     "Invalid predictor")
-    (util/validate! ::vec-of-float-arrays inputs
+    (util/validate! ::vvec-of-numbers inputs
                     "Invalid inputs")
-    (util/coerce-return-recursive
-     (.predict (:predictor wrapped-predictor)
-               (util/vec->indexed-seq inputs))))
+    (->> (.predict (:predictor wrapped-predictor)
+                   (util/vec->indexed-seq (mapv float-array inputs)))
+         (util/coerce-return-recursive)
+         (mapv #(mapv float %))))
   (predict-with-ndarray [wrapped-predictor input-arrays]
     (util/validate! ::wrapped-predictor wrapped-predictor
                     "Invalid predictor")
     (util/validate! ::vec-of-ndarrays input-arrays
                     "Invalid input arrays")
-    (util/coerce-return-recursive
-     (.predictWithNDArray (:predictor wrapped-predictor)
-                          (util/vec->indexed-seq input-arrays)))))
+    (-> (.predictWithNDArray (:predictor wrapped-predictor)
+                             (util/vec->indexed-seq input-arrays))
+        (util/coerce-return-recursive))))
 
 (s/def ::nil-or-int (s/nilable int?))
 
@@ -111,13 +124,14 @@
     ([wrapped-classifier inputs topk]
      (util/validate! ::wrapped-classifier wrapped-classifier
                      "Invalid classifier")
-     (util/validate! ::vec-of-float-arrays inputs
+     (util/validate! ::vvec-of-numbers inputs
                      "Invalid inputs")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.classify (:classifier wrapped-classifier)
-                 (util/vec->indexed-seq inputs)
-                 (util/->int-option topk)))))
+     (->> (.classify (:classifier wrapped-classifier)
+                     (util/vec->indexed-seq (mapv float-array inputs))
+                     (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (format-classification-predictions))))
   (classify-with-ndarray
     ([wrapped-classifier inputs]
      (classify-with-ndarray wrapped-classifier inputs nil))
@@ -127,10 +141,11 @@
      (util/validate! ::vec-of-ndarrays inputs
                      "Invalid inputs")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.classifyWithNDArray (:classifier wrapped-classifier)
-                            (util/vec->indexed-seq inputs)
-                           (util/->int-option topk)))))
+     (->> (.classifyWithNDArray (:classifier wrapped-classifier)
+                                (util/vec->indexed-seq inputs)
+                                (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (mapv format-classification-predictions))))
   WrappedImageClassifier
   (classify
     ([wrapped-image-classifier inputs]
@@ -138,13 +153,14 @@
     ([wrapped-image-classifier inputs topk]
      (util/validate! ::wrapped-image-classifier wrapped-image-classifier
                      "Invalid classifier")
-     (util/validate! ::vec-of-float-arrays inputs
+     (util/validate! ::vvec-of-numbers inputs
                      "Invalid inputs")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.classify (:image-classifier wrapped-image-classifier)
-                 (util/vec->indexed-seq inputs)
-                 (util/->int-option topk)))))
+     (->> (.classify (:image-classifier wrapped-image-classifier)
+                     (util/vec->indexed-seq (mapv float-array inputs))
+                     (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (format-classification-predictions))))
   (classify-with-ndarray
     ([wrapped-image-classifier inputs]
      (classify-with-ndarray wrapped-image-classifier inputs nil))
@@ -154,10 +170,11 @@
     (util/validate! ::vec-of-ndarrays inputs
                     "Invalid inputs")
     (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.classifyWithNDArray (:image-classifier wrapped-image-classifier)
-                           (util/vec->indexed-seq inputs)
-                           (util/->int-option topk))))))
+    (->> (.classifyWithNDArray (:image-classifier wrapped-image-classifier)
+                               (util/vec->indexed-seq inputs)
+                               (util/->int-option topk))
+         (util/coerce-return-recursive)
+         (mapv format-classification-predictions)))))
 
 (s/def ::image #(instance? BufferedImage %))
 (s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})
@@ -175,11 +192,12 @@
      (util/validate! ::image image "Invalid image")
      (util/validate! ::nil-or-int topk "Invalid top-K")
      (util/validate! ::dtype dtype "Invalid dtype")
-     (util/coerce-return-recursive
-      (.classifyImage (:image-classifier wrapped-image-classifier)
-                      image
-                      (util/->int-option topk)
-                      dtype))))
+     (->> (.classifyImage (:image-classifier wrapped-image-classifier)
+                         image
+                         (util/->int-option topk)
+                         dtype)
+          (util/coerce-return-recursive)
+          (mapv format-classification-predictions))))
   (classify-image-batch
     ([wrapped-image-classifier images]
      (classify-image-batch wrapped-image-classifier images nil dtype/FLOAT32))
@@ -188,13 +206,15 @@
     ([wrapped-image-classifier images topk dtype]
      (util/validate! ::wrapped-image-classifier wrapped-image-classifier
                      "Invalid classifier")
+     (util/validate! ::batch-images images "Invalid Batch Images")
      (util/validate! ::nil-or-int topk "Invalid top-K")
      (util/validate! ::dtype dtype "Invalid dtype")
-     (util/coerce-return-recursive
-      (.classifyImageBatch (:image-classifier wrapped-image-classifier)
-                           images
-                           (util/->int-option topk)
-                           dtype)))))
+     (->> (.classifyImageBatch (:image-classifier wrapped-image-classifier)
+                              (util/vec->indexed-seq images)
+                              (util/->int-option topk)
+                              dtype)
+         (util/coerce-return-recursive)
+         (mapv format-classification-predictions)))))
 
 (extend-protocol AObjectDetector
   WrappedObjectDetector
@@ -206,10 +226,11 @@
                     "Invalid object detector")
      (util/validate! ::image image "Invalid image")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.imageObjectDetect (:object-detector wrapped-detector)
-                          image
-                          (util/->int-option topk)))))
+     (->> (.imageObjectDetect (:object-detector wrapped-detector)
+                              image
+                              (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (mapv format-detection-predictions))))
   (detect-objects-batch
     ([wrapped-detector images]
      (detect-objects-batch wrapped-detector images nil))
@@ -217,10 +238,12 @@
      (util/validate! ::wrapped-detector wrapped-detector
                      "Invalid object detector")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.imageBatchObjectDetect (:object-detector wrapped-detector)
-                               images
-                               (util/->int-option topk)))))
+     (util/validate! ::batch-images images "Invalid Batch Images")
+     (->> (.imageBatchObjectDetect (:object-detector wrapped-detector)
+                                   (util/vec->indexed-seq images)
+                                   (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (mapv format-detection-predictions))))
   (detect-objects-with-ndarrays
     ([wrapped-detector input-arrays]
      (detect-objects-with-ndarrays wrapped-detector input-arrays nil))
@@ -230,10 +253,11 @@
      (util/validate! ::vec-of-ndarrays input-arrays
                      "Invalid inputs")
      (util/validate! ::nil-or-int topk "Invalid top-K")
-     (util/coerce-return-recursive
-      (.objectDetectWithNDArray (:object-detector wrapped-detector)
-                                (util/vec->indexed-seq input-arrays)
-                                (util/->int-option topk))))))
+     (->> (.objectDetectWithNDArray (:object-detector wrapped-detector)
+                                    (util/vec->indexed-seq input-arrays)
+                                    (util/->int-option topk))
+          (util/coerce-return-recursive)
+          (mapv format-detection-predictions)))))
 
 (defprotocol AInferenceFactory
   (create-predictor [factory] [factory opts])
@@ -324,10 +348,12 @@
 
 (defn buffered-image-to-pixels
   "Convert input BufferedImage to NDArray of input shape"
-  [image input-shape-vec]
-  (util/validate! ::image image "Invalid image")
-  (util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector")
-  (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec) dtype/FLOAT32))
+  ([image input-shape-vec]
+   (buffered-image-to-pixels image input-shape-vec dtype/FLOAT32))
+  ([image input-shape-vec dtype]
+   (util/validate! ::image image "Invalid image")
+   (util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector")
+   (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec) dtype)))
 
 (s/def ::image-path string?)
 (s/def ::image-paths (s/coll-of ::image-path))
@@ -342,4 +368,5 @@
   "Loads images from a list of file names"
   [image-paths]
   (util/validate! ::image-paths image-paths "Invalid image paths")
-  (ImageClassifier/loadInputBatch (util/convert-vector image-paths)))
+  (util/scala-vector->vec
+   (ImageClassifier/loadInputBatch (util/convert-vector image-paths))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
index b459b06..e3935c3 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
@@ -19,6 +19,7 @@
             [org.apache.clojure-mxnet.dtype :as dtype]
             [org.apache.clojure-mxnet.infer :as infer]
             [org.apache.clojure-mxnet.layout :as layout]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]
             [clojure.java.io :as io]
             [clojure.java.shell :refer [sh]]
             [clojure.test :refer :all]))
@@ -45,32 +46,83 @@
         [predictions] (infer/classify-image classifier image 5 dtype/FLOAT32)]
     (is (= 1000 (count predictions-all)))
     (is (= 10 (count predictions-with-default-dtype)))
-    (is (some? predictions))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(float? (second %)) predictions))
-    (is (every? #(< 0 (second %) 1) predictions))
-    (is (= ["n02123159 tiger cat"
-            "n02124075 Egyptian cat"
-            "n02123045 tabby, tabby cat"
-            "n02127052 lynx, catamount"
-            "n02128757 snow leopard, ounce, Panthera uncia"]
-           (map first predictions)))))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (= (< 0 (:prob (first predictions)) 1)))))
 
 (deftest test-batch-classification
   (let [classifier (create-classifier)
         image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
                                              "test/test-images/Pug-Cookie.jpg"])
-        batch-predictions-all (infer/classify-image-batch classifier image-batch)
-        batch-predictions-with-default-dtype (infer/classify-image-batch classifier image-batch 10)
-        batch-predictions (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32)
-        predictions (first batch-predictions)]
-    (is (= 1000 (count (first batch-predictions-all))))
-    (is (= 10 (count (first batch-predictions-with-default-dtype))))
-    (is (some? batch-predictions))
+        [batch-predictions-all] (infer/classify-image-batch classifier image-batch)
+        [batch-predictions-with-default-dtype] (infer/classify-image-batch classifier image-batch 10)
+        [predictions] (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32)]
+    (is (= 1000 (count batch-predictions-all)))
+    (is (= 10 (count batch-predictions-with-default-dtype)))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(float? (second %)) predictions))
-    (is (every? #(< 0 (second %) 1) predictions))))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (< 0 (:prob (first predictions)) 1)))) ; probability must be strictly inside (0, 1)
+
+(deftest test-single-classification-with-ndarray
+  (let [classifier (create-classifier)
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 224 224)
+                  (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
+        [predictions-all] (infer/classify-with-ndarray classifier [image])
+        [predictions] (infer/classify-with-ndarray classifier [image] 5)]
+    (is (= 1000 (count predictions-all)))
+    (is (= 5 (count predictions)))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (< 0 (:prob (first predictions)) 1)))) ; probability must be strictly inside (0, 1)
+
+(deftest test-single-classify
+  (let [classifier (create-classifier)
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 224 224)
+                  (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
+        predictions-all (infer/classify classifier [(ndarray/->vec image)])
+        predictions (infer/classify classifier [(ndarray/->vec image)] 5)]
+    (is (= 1000 (count predictions-all)))
+    (is (= 5 (count predictions)))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (< 0 (:prob (first predictions)) 1)))) ; probability must be strictly inside (0, 1)
+
+(deftest test-base-classification-with-ndarray
+  (let [descriptors [{:name "data"
+                      :shape [1 3 224 224]
+                      :layout layout/NCHW
+                      :dtype dtype/FLOAT32}]
+        factory (infer/model-factory model-path-prefix descriptors)
+        classifier (infer/create-classifier factory)
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 224 224)
+                  (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
+        [predictions-all] (infer/classify-with-ndarray classifier [image])
+        [predictions] (infer/classify-with-ndarray classifier [image] 5)]
+    (is (= 1000 (count predictions-all)))
+    (is (= 5 (count predictions)))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (< 0 (:prob (first predictions)) 1)))) ; probability must be strictly inside (0, 1)
+
+(deftest test-base-single-classify
+  (let [descriptors [{:name "data"
+                      :shape [1 3 224 224]
+                      :layout layout/NCHW
+                      :dtype dtype/FLOAT32}]
+        factory (infer/model-factory model-path-prefix descriptors)
+        classifier (infer/create-classifier factory)
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 224 224)
+                  (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
+        predictions-all (infer/classify classifier [(ndarray/->vec image)])
+        predictions (infer/classify classifier [(ndarray/->vec image)] 5)]
+    (is (= 1000 (count predictions-all)))
+    (is (= 5 (count predictions)))
+    (is (= "n02123159 tiger cat" (:class (first predictions))))
+    (is (< 0 (:prob (first predictions)) 1)))) ; probability must be strictly inside (0, 1)
+
+
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
index 3a0e3d3..e2b9579 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
@@ -21,7 +21,8 @@
             [org.apache.clojure-mxnet.layout :as layout]
             [clojure.java.io :as io]
             [clojure.java.shell :refer [sh]]
-            [clojure.test :refer :all]))
+            [clojure.test :refer :all]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]))
 
 (def model-dir "data/")
 (def model-path-prefix (str model-dir "resnet50_ssd/resnet50_ssd_model"))
@@ -41,27 +42,41 @@
   (let [detector (create-detector)
         image (infer/load-image-from-file "test/test-images/kitten.jpg")
         [predictions-all] (infer/detect-objects detector image)
-        [predictions] (infer/detect-objects detector image 5)]
+        [predictions] (infer/detect-objects detector image 5)
+        {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]
     (is (some? predictions))
     (is (= 5 (count predictions)))
     (is (= 13 (count predictions-all)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(= 5 (count (second %))) predictions))
-    (is (every? #(< 0 (first (second %)) 1) predictions))
-    (is (= "cat" (first (first predictions))))))
+    (is (= "cat" class))
+    (is (< 0.8 prob))
+    (every? #(< 0 % 1) [x-min x-max y-min y-max])))
 
 (deftest test-batch-detection
   (let [detector (create-detector)
         image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
                                              "test/test-images/Pug-Cookie.jpg"])
-        batch-predictions-all (infer/detect-objects-batch detector image-batch)
-        batch-predictions (infer/detect-objects-batch detector image-batch 5)
-        predictions (first batch-predictions)]
-    (is (some? batch-predictions))
-    (is (= 13 (count (first batch-predictions-all))))
+        [batch-predictions-all] (infer/detect-objects-batch detector image-batch) ; keep first batch entry only
+        [predictions] (infer/detect-objects-batch detector image-batch 5) ; same, with the extra arg 5
+        {:keys [class prob x-min x-max y-min y-max]} (first predictions)]
+    (is (some? predictions))
+    (is (= 13 (count batch-predictions-all)))
     (is (= 5 (count predictions)))
-    (is (every? #(= 2 (count %)) predictions))
-    (is (every? #(string? (first %)) predictions))
-    (is (every? #(= 5 (count (second %))) predictions))
-    (is (every? #(< 0 (first (second %)) 1) predictions))))
+    (is (= "cat" class))
+    (is (< 0.8 prob))
+    (is (every? #(< 0 % 1) [x-min x-max y-min y-max])))) ; wrapped in `is`; a bare `every?` result is discarded
+
+(deftest test-detection-with-ndarrays ; object detection fed with a preprocessed NDArray instead of an image
+  (let [detector (create-detector)
+        image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image 512 512)
+                  (infer/buffered-image-to-pixels [3 512 512] dtype/FLOAT32)
+                  (ndarray/expand-dims 0))
+        [predictions-all] (infer/detect-objects-with-ndarrays detector [image])
+        [predictions] (infer/detect-objects-with-ndarrays detector [image] 1)
+        {:keys [class prob x-min x-max y-min y-max]} (first predictions)]
+    (is (some? predictions-all))
+    (is (= 1 (count predictions)))
+    (is (= "cat" class))
+    (is (< 0.8 prob))
+    (is (every? #(< 0 % 1) [x-min x-max y-min y-max])))) ; wrapped in `is`; a bare `every?` result is discarded
+
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj
index 0e7532b..e1526be 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj
@@ -24,7 +24,8 @@
             [clojure.java.io :as io]
             [clojure.java.shell :refer [sh]]
             [clojure.string :refer [split]]
-            [clojure.test :refer :all]))
+            [clojure.test :refer :all]
+            [org.apache.clojure-mxnet.util :as util]))
 
 (def model-dir "data/")
 (def model-path-prefix (str model-dir "resnet-18/resnet-18"))
@@ -42,6 +43,22 @@
         factory (infer/model-factory model-path-prefix descriptors)]
     (infer/create-predictor factory)))
 
+;; Runs the predictor on the kitten image as an NDArray and checks the top-scoring synset label.
+(deftest predictor-test-with-ndarray
+  (let [input (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+                  (infer/reshape-image width height)
+                  (infer/buffered-image-to-pixels [3 width height])
+                  (ndarray/expand-dims 0))
+        [scores] (infer/predict-with-ndarray (create-predictor) [input])
+        labels (-> (io/file model-path-prefix)
+                   .getParent
+                   (io/file "synset.txt")
+                   slurp
+                   (split #"\n"))
+        ;; first element of the argmax result is used as the index into the label list
+        top-index (first (ndarray/->int-vec (ndarray/argmax scores 1)))]
+    (is (= "n02123159 tiger cat" (nth labels top-index)))))
+
 (deftest predictor-test
   (let [predictor (create-predictor)
         image-ndarray (-> "test/test-images/kitten.jpg"
@@ -49,11 +66,12 @@
                           (infer/reshape-image width height)
                           (infer/buffered-image-to-pixels [3 width height])
                           (ndarray/expand-dims 0))
-        [predictions] (infer/predict-with-ndarray predictor [image-ndarray])
+        predictions (infer/predict predictor [(ndarray/->vec image-ndarray)])
         synset-file (-> (io/file model-path-prefix)
                         (.getParent)
                         (io/file "synset.txt"))
         synset-names (split (slurp synset-file) #"\n")
-        [best-index] (ndarray/->int-vec (ndarray/argmax predictions 1))
+        ndarray-preds (ndarray/array (first predictions) [1 1000])
+        [best-index] (ndarray/->int-vec (ndarray/argmax ndarray-preds 1))
         best-prediction (synset-names best-index)]
     (is (= "n02123159 tiger cat" best-prediction))))