You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2019/01/13 00:01:36 UTC
[incubator-mxnet] branch master updated: [Clojure] package infer tweaks (#13864)
This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new c2110ad [Clojure] package infer tweaks (#13864)
c2110ad is described below
commit c2110ada6d43f10710d181c0deb0673fe6d829b2
Author: Carin Meier <cm...@gigasquidsoftware.com>
AuthorDate: Sat Jan 12 19:01:17 2019 -0500
[Clojure] package infer tweaks (#13864)
* change object detection prediction to be a map
* change predictions to a map for image-classifiers
* change return types of the classifiers to be a map
- add tests for base classifier and with-ndarray as well
* tweak return types and inputs for predict
- add test for plain predict
* updated infer-classify examples
* adjust the infer/object detections tests
* tweak predictor test
* address feedback from @kedarbellare's review
* put scaling back in
* put back predict so it can handle multiple inputs
* restore original function signatures (remove first)
---
.../src/infer/imageclassifier_example.clj | 19 ++-
.../test/infer/imageclassifier_example_test.clj | 25 ++--
.../src/infer/objectdetector_example.clj | 25 ++--
.../test/infer/objectdetector_example_test.clj | 24 ++--
.../predictor/src/infer/predictor_example.clj | 4 +-
.../src/org/apache/clojure_mxnet/infer.clj | 137 ++++++++++++---------
.../clojure_mxnet/infer/imageclassifier_test.clj | 96 +++++++++++----
.../clojure_mxnet/infer/objectdetector_test.clj | 47 ++++---
.../apache/clojure_mxnet/infer/predictor_test.clj | 24 +++-
9 files changed, 250 insertions(+), 151 deletions(-)
diff --git a/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj b/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj
index 4ec7ff7..6994b4f 100644
--- a/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj
+++ b/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj
@@ -55,8 +55,8 @@
"Print image classifier predictions for the given input file"
[predictions]
(println (apply str (repeat 80 "=")))
- (doseq [[label probability] predictions]
- (println (format "Class: %s Probability=%.8f" label probability)))
+ (doseq [p predictions]
+ (println p))
(println (apply str (repeat 80 "="))))
(defn classify-single-image
@@ -64,8 +64,8 @@
[classifier input-image]
(let [image (infer/load-image-from-file input-image)
topk 5
- [predictions] (infer/classify-image classifier image topk)]
- predictions))
+ predictions (infer/classify-image classifier image topk)]
+ [predictions]))
(defn classify-images-in-dir
"Classify all jpg images in the directory"
@@ -78,12 +78,10 @@
(filter #(re-matches #".*\.jpg$" (.getPath %)))
(mapv #(.getPath %))
(partition-all batch-size))]
- (apply
- concat
- (for [image-files image-file-batches]
- (let [image-batch (infer/load-image-paths image-files)
- topk 5]
- (infer/classify-image-batch classifier image-batch topk))))))
+ (apply concat (for [image-files image-file-batches]
+ (let [image-batch (infer/load-image-paths image-files)
+ topk 5]
+ (infer/classify-image-batch classifier image-batch topk))))))
(defn run-classifier
"Runs an image classifier based on options provided"
@@ -98,6 +96,7 @@
factory {:contexts [(context/default-context)]})]
(println "Classifying a single image")
(print-predictions (classify-single-image classifier input-image))
+ (println "\n")
(println "Classifying images in a directory")
(doseq [predictions (classify-images-in-dir classifier input-dir)]
(print-predictions predictions))))
diff --git a/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj b/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj
index 5b3e08d..4b71f84 100644
--- a/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj
+++ b/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj
@@ -43,27 +43,16 @@
(deftest test-single-classification
(let [classifier (create-classifier)
- predictions (classify-single-image classifier image-file)]
+ [[predictions]] (classify-single-image classifier image-file)]
(is (some? predictions))
(is (= 5 (count predictions)))
- (is (every? #(= 2 (count %)) predictions))
- (is (every? #(string? (first %)) predictions))
- (is (every? #(float? (second %)) predictions))
- (is (every? #(< 0 (second %) 1) predictions))
- (is (= ["n02123159 tiger cat"
- "n02124075 Egyptian cat"
- "n02123045 tabby, tabby cat"
- "n02127052 lynx, catamount"
- "n02128757 snow leopard, ounce, Panthera uncia"]
- (map first predictions)))))
+ (is (= "n02123159 tiger cat" (:class (first predictions))))
+ (is (= (< 0 (:prob (first predictions)) 1)))))
(deftest test-batch-classification
(let [classifier (create-classifier)
- batch-predictions (classify-images-in-dir classifier image-dir)
- predictions (first batch-predictions)]
- (is (some? batch-predictions))
+ predictions (first (classify-images-in-dir classifier image-dir))]
+ (is (some? predictions))
(is (= 5 (count predictions)))
- (is (every? #(= 2 (count %)) predictions))
- (is (every? #(string? (first %)) predictions))
- (is (every? #(float? (second %)) predictions))
- (is (every? #(< 0 (second %) 1) predictions))))
+ (is (= "n02123159 tiger cat" (:class (first predictions))))
+ (is (= (< 0 (:prob (first predictions)) 1)))))
diff --git a/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj b/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj
index 53172f0..5c30e5d 100644
--- a/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj
+++ b/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj
@@ -54,15 +54,15 @@
"Print image detector predictions for the given input file"
[predictions width height]
(println (apply str (repeat 80 "=")))
- (doseq [[label prob-and-bounds] predictions]
+ (doseq [{:keys [class prob x-min y-min x-max y-max]} predictions]
(println (format
"Class: %s Prob=%.5f Coords=(%.3f, %.3f, %.3f, %.3f)"
- label
- (aget prob-and-bounds 0)
- (* (aget prob-and-bounds 1) width)
- (* (aget prob-and-bounds 2) height)
- (* (aget prob-and-bounds 3) width)
- (* (aget prob-and-bounds 4) height))))
+ class
+ prob
+ (* x-min width)
+ (* y-min height)
+ (* x-max width)
+ (* y-max height))))
(println (apply str (repeat 80 "="))))
(defn detect-single-image
@@ -84,12 +84,10 @@
(filter #(re-matches #".*\.jpg$" (.getPath %)))
(mapv #(.getPath %))
(partition-all batch-size))]
- (apply
- concat
- (for [image-files image-file-batches]
- (let [image-batch (infer/load-image-paths image-files)
- topk 5]
- (infer/detect-objects-batch detector image-batch topk))))))
+ (apply concat (for [image-files image-file-batches]
+ (let [image-batch (infer/load-image-paths image-files)
+ topk 5]
+ (infer/detect-objects-batch detector image-batch topk))))))
(defn run-detector
"Runs an image detector based on options provided"
@@ -107,6 +105,7 @@
{:contexts [(context/default-context)]})]
(println "Object detection on a single image")
(print-predictions (detect-single-image detector input-image) width height)
+ (println "\n")
(println "Object detection on images in a directory")
(doseq [predictions (detect-images-in-dir detector input-dir)]
(print-predictions predictions width height))))
diff --git a/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj b/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj
index 90ed02f..2b8ad95 100644
--- a/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj
+++ b/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj
@@ -43,23 +43,23 @@
(deftest test-single-detection
(let [detector (create-detector)
- predictions (detect-single-image detector image-file)]
+ predictions (detect-single-image detector image-file)
+ {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]
(is (some? predictions))
(is (= 5 (count predictions)))
- (is (every? #(= 2 (count %)) predictions))
- (is (every? #(string? (first %)) predictions))
- (is (every? #(= 5 (count (second %))) predictions))
- (is (every? #(< 0 (first (second %)) 1) predictions))
- (is (= ["car" "bicycle" "dog" "bicycle" "person"]
- (map first predictions)))))
+ (is (string? class))
+ (is (< 0.8 prob))
+ (is (every? #(< 0 % 1) [x-min x-max y-min y-max]))
+ (is (= #{"dog" "person" "bicycle" "car"} (set (mapv :class predictions))))))
(deftest test-batch-detection
(let [detector (create-detector)
batch-predictions (detect-images-in-dir detector image-dir)
- predictions (first batch-predictions)]
+ predictions (first batch-predictions)
+ {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]
(is (some? batch-predictions))
(is (= 5 (count predictions)))
- (is (every? #(= 2 (count %)) predictions))
- (is (every? #(string? (first %)) predictions))
- (is (every? #(= 5 (count (second %))) predictions))
- (is (every? #(< 0 (first (second %)) 1) predictions))))
+ (is (string? class))
+ (is (< 0.8 prob))
+ (every? #(< 0 % 1) [x-min x-max y-min y-max])
+ (is (= #{"dog" "person" "bicycle" "car"} (set (mapv :class predictions))))))
diff --git a/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj b/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
index 4989641..05eb0ad 100644
--- a/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
+++ b/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
@@ -59,8 +59,8 @@
(defn do-inference
"Run inference using given predictor"
[predictor image]
- (let [[predictions] (infer/predict-with-ndarray predictor [image])]
- predictions))
+ (let [predictions (infer/predict-with-ndarray predictor [image])]
+ (first predictions)))
(defn postprocess
[model-path-prefix predictions]
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
index 224a392..09edf15 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
@@ -22,7 +22,8 @@
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.shape :as shape]
[org.apache.clojure-mxnet.util :as util]
- [clojure.spec.alpha :as s])
+ [clojure.spec.alpha :as s]
+ [org.apache.clojure-mxnet.shape :as mx-shape])
(:import (java.awt.image BufferedImage)
(org.apache.mxnet NDArray)
(org.apache.mxnet.infer Classifier ImageClassifier
@@ -39,15 +40,26 @@
(defrecord WrappedObjectDetector [object-detector])
(s/def ::ndarray #(instance? NDArray %))
-(s/def ::float-array (s/and #(.isArray (class %)) #(every? float? %)))
-(s/def ::vec-of-float-arrays (s/coll-of ::float-array :kind vector?))
+(s/def ::number-array (s/coll-of number? :kind vector?))
+(s/def ::vvec-of-numbers (s/coll-of ::number-array :kind vector?))
(s/def ::vec-of-ndarrays (s/coll-of ::ndarray :kind vector?))
+(s/def ::image #(instance? BufferedImage %))
+(s/def ::batch-images (s/coll-of ::image :kind vector?))
(s/def ::wrapped-predictor (s/keys :req-un [::predictor]))
(s/def ::wrapped-classifier (s/keys :req-un [::classifier]))
(s/def ::wrapped-image-classifier (s/keys :req-un [::image-classifier]))
(s/def ::wrapped-detector (s/keys :req-un [::object-detector]))
+(defn- format-detection-predictions [predictions]
+ (mapv (fn [[c p]]
+ (let [[prob xmin ymin xmax ymax] (mapv float p)]
+ {:class c :prob prob :x-min xmin :y-min ymin :x-max xmax :y-max ymax}))
+ predictions))
+
+(defn- format-classification-predictions [predictions]
+ (mapv (fn [[c p]] {:class c :prob p}) predictions))
+
(defprotocol APredictor
(predict [wrapped-predictor inputs])
(predict-with-ndarray [wrapped-predictor input-arrays]))
@@ -87,19 +99,20 @@
[wrapped-predictor inputs]
(util/validate! ::wrapped-predictor wrapped-predictor
"Invalid predictor")
- (util/validate! ::vec-of-float-arrays inputs
+ (util/validate! ::vvec-of-numbers inputs
"Invalid inputs")
- (util/coerce-return-recursive
- (.predict (:predictor wrapped-predictor)
- (util/vec->indexed-seq inputs))))
+ (->> (.predict (:predictor wrapped-predictor)
+ (util/vec->indexed-seq (mapv float-array inputs)))
+ (util/coerce-return-recursive)
+ (mapv #(mapv float %))))
(predict-with-ndarray [wrapped-predictor input-arrays]
(util/validate! ::wrapped-predictor wrapped-predictor
"Invalid predictor")
(util/validate! ::vec-of-ndarrays input-arrays
"Invalid input arrays")
- (util/coerce-return-recursive
- (.predictWithNDArray (:predictor wrapped-predictor)
- (util/vec->indexed-seq input-arrays)))))
+ (-> (.predictWithNDArray (:predictor wrapped-predictor)
+ (util/vec->indexed-seq input-arrays))
+ (util/coerce-return-recursive))))
(s/def ::nil-or-int (s/nilable int?))
@@ -111,13 +124,14 @@
([wrapped-classifier inputs topk]
(util/validate! ::wrapped-classifier wrapped-classifier
"Invalid classifier")
- (util/validate! ::vec-of-float-arrays inputs
+ (util/validate! ::vvec-of-numbers inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classify (:classifier wrapped-classifier)
- (util/vec->indexed-seq inputs)
- (util/->int-option topk)))))
+ (->> (.classify (:classifier wrapped-classifier)
+ (util/vec->indexed-seq (mapv float-array inputs))
+ (util/->int-option topk))
+ (util/coerce-return-recursive)
+ (format-classification-predictions))))
(classify-with-ndarray
([wrapped-classifier inputs]
(classify-with-ndarray wrapped-classifier inputs nil))
@@ -127,10 +141,11 @@
(util/validate! ::vec-of-ndarrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classifyWithNDArray (:classifier wrapped-classifier)
- (util/vec->indexed-seq inputs)
- (util/->int-option topk)))))
+ (->> (.classifyWithNDArray (:classifier wrapped-classifier)
+ (util/vec->indexed-seq inputs)
+ (util/->int-option topk))
+ (util/coerce-return-recursive)
+ (mapv format-classification-predictions))))
WrappedImageClassifier
(classify
([wrapped-image-classifier inputs]
@@ -138,13 +153,14 @@
([wrapped-image-classifier inputs topk]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
- (util/validate! ::vec-of-float-arrays inputs
+ (util/validate! ::vvec-of-numbers inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classify (:image-classifier wrapped-image-classifier)
- (util/vec->indexed-seq inputs)
- (util/->int-option topk)))))
+ (->> (.classify (:image-classifier wrapped-image-classifier)
+ (util/vec->indexed-seq (mapv float-array inputs))
+ (util/->int-option topk))
+ (util/coerce-return-recursive)
+ (format-classification-predictions))))
(classify-with-ndarray
([wrapped-image-classifier inputs]
(classify-with-ndarray wrapped-image-classifier inputs nil))
@@ -154,10 +170,11 @@
(util/validate! ::vec-of-ndarrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classifyWithNDArray (:image-classifier wrapped-image-classifier)
- (util/vec->indexed-seq inputs)
- (util/->int-option topk))))))
+ (->> (.classifyWithNDArray (:image-classifier wrapped-image-classifier)
+ (util/vec->indexed-seq inputs)
+ (util/->int-option topk))
+ (util/coerce-return-recursive)
+ (mapv format-classification-predictions)))))
(s/def ::image #(instance? BufferedImage %))
(s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})
@@ -175,11 +192,12 @@
(util/validate! ::image image "Invalid image")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/validate! ::dtype dtype "Invalid dtype")
- (util/coerce-return-recursive
- (.classifyImage (:image-classifier wrapped-image-classifier)
- image
- (util/->int-option topk)
- dtype))))
+ (->> (.classifyImage (:image-classifier wrapped-image-classifier)
+ image
+ (util/->int-option topk)
+ dtype)
+ (util/coerce-return-recursive)
+ (mapv format-classification-predictions))))
(classify-image-batch
([wrapped-image-classifier images]
(classify-image-batch wrapped-image-classifier images nil dtype/FLOAT32))
@@ -188,13 +206,15 @@
([wrapped-image-classifier images topk dtype]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
+ (util/validate! ::batch-images images "Invalid Batch Images")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/validate! ::dtype dtype "Invalid dtype")
- (util/coerce-return-recursive
- (.classifyImageBatch (:image-classifier wrapped-image-classifier)
- images
- (util/->int-option topk)
- dtype)))))
+ (->> (.classifyImageBatch (:image-classifier wrapped-image-classifier)
+ (util/vec->indexed-seq images)
+ (util/->int-option topk)
+ dtype)
+ (util/coerce-return-recursive)
+ (mapv format-classification-predictions)))))
(extend-protocol AObjectDetector
WrappedObjectDetector
@@ -206,10 +226,11 @@
"Invalid object detector")
(util/validate! ::image image "Invalid image")
(util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.imageObjectDetect (:object-detector wrapped-detector)
- image
- (util/->int-option topk)))))
+ (->> (.imageObjectDetect (:object-detector wrapped-detector)
+ image
+ (util/->int-option topk))
+ (util/coerce-return-recursive)
+ (mapv format-detection-predictions))))
(detect-objects-batch
([wrapped-detector images]
(detect-objects-batch wrapped-detector images nil))
@@ -217,10 +238,12 @@
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.imageBatchObjectDetect (:object-detector wrapped-detector)
- images
- (util/->int-option topk)))))
+ (util/validate! ::batch-images images "Invalid Batch Images")
+ (->> (.imageBatchObjectDetect (:object-detector wrapped-detector)
+ (util/vec->indexed-seq images)
+ (util/->int-option topk))
+ (util/coerce-return-recursive)
+ (mapv format-detection-predictions))))
(detect-objects-with-ndarrays
([wrapped-detector input-arrays]
(detect-objects-with-ndarrays wrapped-detector input-arrays nil))
@@ -230,10 +253,11 @@
(util/validate! ::vec-of-ndarrays input-arrays
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.objectDetectWithNDArray (:object-detector wrapped-detector)
- (util/vec->indexed-seq input-arrays)
- (util/->int-option topk))))))
+ (->> (.objectDetectWithNDArray (:object-detector wrapped-detector)
+ (util/vec->indexed-seq input-arrays)
+ (util/->int-option topk))
+ (util/coerce-return-recursive)
+ (mapv format-detection-predictions)))))
(defprotocol AInferenceFactory
(create-predictor [factory] [factory opts])
@@ -324,10 +348,12 @@
(defn buffered-image-to-pixels
"Convert input BufferedImage to NDArray of input shape"
- [image input-shape-vec]
- (util/validate! ::image image "Invalid image")
- (util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector")
- (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec) dtype/FLOAT32))
+ ([image input-shape-vec]
+ (buffered-image-to-pixels image input-shape-vec dtype/FLOAT32))
+ ([image input-shape-vec dtype]
+ (util/validate! ::image image "Invalid image")
+ (util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector")
+ (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec) dtype)))
(s/def ::image-path string?)
(s/def ::image-paths (s/coll-of ::image-path))
@@ -342,4 +368,5 @@
"Loads images from a list of file names"
[image-paths]
(util/validate! ::image-paths image-paths "Invalid image paths")
- (ImageClassifier/loadInputBatch (util/convert-vector image-paths)))
+ (util/scala-vector->vec
+ (ImageClassifier/loadInputBatch (util/convert-vector image-paths))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
index b459b06..e3935c3 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
@@ -19,6 +19,7 @@
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.infer :as infer]
[org.apache.clojure-mxnet.layout :as layout]
+ [org.apache.clojure-mxnet.ndarray :as ndarray]
[clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
[clojure.test :refer :all]))
@@ -45,32 +46,83 @@
[predictions] (infer/classify-image classifier image 5 dtype/FLOAT32)]
(is (= 1000 (count predictions-all)))
(is (= 10 (count predictions-with-default-dtype)))
- (is (some? predictions))
(is (= 5 (count predictions)))
- (is (every? #(= 2 (count %)) predictions))
- (is (every? #(string? (first %)) predictions))
- (is (every? #(float? (second %)) predictions))
- (is (every? #(< 0 (second %) 1) predictions))
- (is (= ["n02123159 tiger cat"
- "n02124075 Egyptian cat"
- "n02123045 tabby, tabby cat"
- "n02127052 lynx, catamount"
- "n02128757 snow leopard, ounce, Panthera uncia"]
- (map first predictions)))))
+ (is (= "n02123159 tiger cat" (:class (first predictions))))
+ (is (= (< 0 (:prob (first predictions)) 1)))))
(deftest test-batch-classification
(let [classifier (create-classifier)
image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
"test/test-images/Pug-Cookie.jpg"])
- batch-predictions-all (infer/classify-image-batch classifier image-batch)
- batch-predictions-with-default-dtype (infer/classify-image-batch classifier image-batch 10)
- batch-predictions (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32)
- predictions (first batch-predictions)]
- (is (= 1000 (count (first batch-predictions-all))))
- (is (= 10 (count (first batch-predictions-with-default-dtype))))
- (is (some? batch-predictions))
+ [batch-predictions-all] (infer/classify-image-batch classifier image-batch)
+ [batch-predictions-with-default-dtype] (infer/classify-image-batch classifier image-batch 10)
+ [predictions] (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32)]
+ (is (= 1000 (count batch-predictions-all)))
+ (is (= 10 (count batch-predictions-with-default-dtype)))
(is (= 5 (count predictions)))
- (is (every? #(= 2 (count %)) predictions))
- (is (every? #(string? (first %)) predictions))
- (is (every? #(float? (second %)) predictions))
- (is (every? #(< 0 (second %) 1) predictions))))
+ (is (= "n02123159 tiger cat" (:class (first predictions))))
+ (is (= (< 0 (:prob (first predictions)) 1)))))
+
+(deftest test-single-classification-with-ndarray
+ (let [classifier (create-classifier)
+ image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+ (infer/reshape-image 224 224)
+ (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+ (ndarray/expand-dims 0))
+ [predictions-all] (infer/classify-with-ndarray classifier [image])
+ [predictions] (infer/classify-with-ndarray classifier [image] 5)]
+ (is (= 1000 (count predictions-all)))
+ (is (= 5 (count predictions)))
+ (is (= "n02123159 tiger cat" (:class (first predictions))))
+ (is (= (< 0 (:prob (first predictions)) 1)))))
+
+(deftest test-single-classify
+ (let [classifier (create-classifier)
+ image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+ (infer/reshape-image 224 224)
+ (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+ (ndarray/expand-dims 0))
+ predictions-all (infer/classify classifier [(ndarray/->vec image)])
+ predictions (infer/classify classifier [(ndarray/->vec image)] 5)]
+ (is (= 1000 (count predictions-all)))
+ (is (= 5 (count predictions)))
+ (is (= "n02123159 tiger cat" (:class (first predictions))))
+ (is (= (< 0 (:prob (first predictions)) 1)))))
+
+(deftest test-base-classification-with-ndarray
+ (let [descriptors [{:name "data"
+ :shape [1 3 224 224]
+ :layout layout/NCHW
+ :dtype dtype/FLOAT32}]
+ factory (infer/model-factory model-path-prefix descriptors)
+ classifier (infer/create-classifier factory)
+ image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+ (infer/reshape-image 224 224)
+ (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+ (ndarray/expand-dims 0))
+ [predictions-all] (infer/classify-with-ndarray classifier [image])
+ [predictions] (infer/classify-with-ndarray classifier [image] 5)]
+ (is (= 1000 (count predictions-all)))
+ (is (= 5 (count predictions)))
+ (is (= "n02123159 tiger cat" (:class (first predictions))))
+ (is (= (< 0 (:prob (first predictions)) 1)))))
+
+(deftest test-base-single-classify
+ (let [descriptors [{:name "data"
+ :shape [1 3 224 224]
+ :layout layout/NCHW
+ :dtype dtype/FLOAT32}]
+ factory (infer/model-factory model-path-prefix descriptors)
+ classifier (infer/create-classifier factory)
+ image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+ (infer/reshape-image 224 224)
+ (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+ (ndarray/expand-dims 0))
+ predictions-all (infer/classify classifier [(ndarray/->vec image)])
+ predictions (infer/classify classifier [(ndarray/->vec image)] 5)]
+ (is (= 1000 (count predictions-all)))
+ (is (= 5 (count predictions)))
+ (is (= "n02123159 tiger cat" (:class (first predictions))))
+ (is (= (< 0 (:prob (first predictions)) 1)))))
+
+
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
index 3a0e3d3..e2b9579 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
@@ -21,7 +21,8 @@
[org.apache.clojure-mxnet.layout :as layout]
[clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
- [clojure.test :refer :all]))
+ [clojure.test :refer :all]
+ [org.apache.clojure-mxnet.ndarray :as ndarray]))
(def model-dir "data/")
(def model-path-prefix (str model-dir "resnet50_ssd/resnet50_ssd_model"))
@@ -41,27 +42,41 @@
(let [detector (create-detector)
image (infer/load-image-from-file "test/test-images/kitten.jpg")
[predictions-all] (infer/detect-objects detector image)
- [predictions] (infer/detect-objects detector image 5)]
+ [predictions] (infer/detect-objects detector image 5)
+ {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]
(is (some? predictions))
(is (= 5 (count predictions)))
(is (= 13 (count predictions-all)))
- (is (every? #(= 2 (count %)) predictions))
- (is (every? #(string? (first %)) predictions))
- (is (every? #(= 5 (count (second %))) predictions))
- (is (every? #(< 0 (first (second %)) 1) predictions))
- (is (= "cat" (first (first predictions))))))
+ (is (= "cat" class))
+ (is (< 0.8 prob))
+ (every? #(< 0 % 1) [x-min x-max y-min y-max])))
(deftest test-batch-detection
(let [detector (create-detector)
image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
"test/test-images/Pug-Cookie.jpg"])
- batch-predictions-all (infer/detect-objects-batch detector image-batch)
- batch-predictions (infer/detect-objects-batch detector image-batch 5)
- predictions (first batch-predictions)]
- (is (some? batch-predictions))
- (is (= 13 (count (first batch-predictions-all))))
+ [batch-predictions-all] (infer/detect-objects-batch detector image-batch)
+ [predictions] (infer/detect-objects-batch detector image-batch 5)
+ {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]
+ (is (some? predictions))
+ (is (= 13 (count batch-predictions-all)))
(is (= 5 (count predictions)))
- (is (every? #(= 2 (count %)) predictions))
- (is (every? #(string? (first %)) predictions))
- (is (every? #(= 5 (count (second %))) predictions))
- (is (every? #(< 0 (first (second %)) 1) predictions))))
+ (is (= "cat" class))
+ (is (< 0.8 prob))
+ (every? #(< 0 % 1) [x-min x-max y-min y-max])))
+
+(deftest test-detection-with-ndarrays
+ (let [detector (create-detector)
+ image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+ (infer/reshape-image 512 512)
+ (infer/buffered-image-to-pixels [3 512 512] dtype/FLOAT32)
+ (ndarray/expand-dims 0))
+ [predictions-all] (infer/detect-objects-with-ndarrays detector [image])
+ [predictions] (infer/detect-objects-with-ndarrays detector [image] 1)
+ {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]
+ (is (some? predictions-all))
+ (is (= 1 (count predictions)))
+ (is (= "cat" class))
+ (is (< 0.8 prob))
+ (every? #(< 0 % 1) [x-min x-max y-min y-max])))
+
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj
index 0e7532b..e1526be 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj
@@ -24,7 +24,8 @@
[clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
[clojure.string :refer [split]]
- [clojure.test :refer :all]))
+ [clojure.test :refer :all]
+ [org.apache.clojure-mxnet.util :as util]))
(def model-dir "data/")
(def model-path-prefix (str model-dir "resnet-18/resnet-18"))
@@ -42,6 +43,22 @@
factory (infer/model-factory model-path-prefix descriptors)]
(infer/create-predictor factory)))
+(deftest predictor-test-with-ndarray
+ (let [predictor (create-predictor)
+ image-ndarray (-> "test/test-images/kitten.jpg"
+ infer/load-image-from-file
+ (infer/reshape-image width height)
+ (infer/buffered-image-to-pixels [3 width height])
+ (ndarray/expand-dims 0))
+ predictions (infer/predict-with-ndarray predictor [image-ndarray])
+ synset-file (-> (io/file model-path-prefix)
+ (.getParent)
+ (io/file "synset.txt"))
+ synset-names (split (slurp synset-file) #"\n")
+ [best-index] (ndarray/->int-vec (ndarray/argmax (first predictions) 1))
+ best-prediction (synset-names best-index)]
+ (is (= "n02123159 tiger cat" best-prediction))))
+
(deftest predictor-test
(let [predictor (create-predictor)
image-ndarray (-> "test/test-images/kitten.jpg"
@@ -49,11 +66,12 @@
(infer/reshape-image width height)
(infer/buffered-image-to-pixels [3 width height])
(ndarray/expand-dims 0))
- [predictions] (infer/predict-with-ndarray predictor [image-ndarray])
+ predictions (infer/predict predictor [(ndarray/->vec image-ndarray)])
synset-file (-> (io/file model-path-prefix)
(.getParent)
(io/file "synset.txt"))
synset-names (split (slurp synset-file) #"\n")
- [best-index] (ndarray/->int-vec (ndarray/argmax predictions 1))
+ ndarray-preds (ndarray/array (first predictions) [1 1000])
+ [best-index] (ndarray/->int-vec (ndarray/argmax ndarray-preds 1))
best-prediction (synset-names best-index)]
(is (= "n02123159 tiger cat" best-prediction))))