You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2019/01/11 22:21:56 UTC
[incubator-mxnet] branch clojure-infer-predict-tweak updated:
change return types of the classifiers to be a map - add tests for base
classifier and with-ndarray as well
This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch clojure-infer-predict-tweak
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/clojure-infer-predict-tweak by this push:
new 3ec2817 change return types of the classifiers to be a map - add tests for base classifier and with-ndarray as well
3ec2817 is described below
commit 3ec28174192abf37893ea26f96ef83dc8c8d3e74
Author: gigasquid <cm...@gigasquidsoftware.com>
AuthorDate: Fri Jan 11 17:21:31 2019 -0500
change return types of the classifiers to be a map
- add tests for base classifier and with-ndarray as well
---
.../src/org/apache/clojure_mxnet/infer.clj | 85 ++++++++++++++--------
.../clojure_mxnet/infer/imageclassifier_test.clj | 65 +++++++++++++++++
.../clojure_mxnet/infer/objectdetector_test.clj | 10 +--
3 files changed, 123 insertions(+), 37 deletions(-)
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
index 801c717..ced8f10 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
@@ -22,7 +22,8 @@
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.shape :as shape]
[org.apache.clojure-mxnet.util :as util]
- [clojure.spec.alpha :as s])
+ [clojure.spec.alpha :as s]
+ [org.apache.clojure-mxnet.shape :as mx-shape])
(:import (java.awt.image BufferedImage)
(org.apache.mxnet NDArray)
(org.apache.mxnet.infer Classifier ImageClassifier
@@ -39,15 +40,26 @@
(defrecord WrappedObjectDetector [object-detector])
(s/def ::ndarray #(instance? NDArray %))
-(s/def ::float-array (s/and #(.isArray (class %)) #(every? float? %)))
-(s/def ::vec-of-float-arrays (s/coll-of ::float-array :kind vector?))
+(s/def ::number-array (s/coll-of number? :kind vector?))
+(s/def ::vvec-of-numbers (s/coll-of ::number-array :kind vector?))
(s/def ::vec-of-ndarrays (s/coll-of ::ndarray :kind vector?))
+(s/def ::image #(instance? BufferedImage %))
+(s/def ::batch-images (s/coll-of ::image :kind vector?))
(s/def ::wrapped-predictor (s/keys :req-un [::predictor]))
(s/def ::wrapped-classifier (s/keys :req-un [::classifier]))
(s/def ::wrapped-image-classifier (s/keys :req-un [::image-classifier]))
(s/def ::wrapped-detector (s/keys :req-un [::object-detector]))
+(defn- format-detection-predictions [predictions] ; each prediction is destructured below as a [class values] pair
+ (mapv (fn [[c p]]
+ (let [[prob xmin ymin xmax ymax] (mapv float p)] ; first five values, each coerced with float
+ {:class c :prob prob :x-min xmin :y-min ymin :x-max xmax :y-max ymax}))
+ predictions))
+
+(defn- format-classification-predictions [predictions] ; each prediction is destructured below as a [class prob] pair
+ (mapv (fn [[c p]] {:class c :prob p}) predictions)) ; returns a vector of {:class .. :prob ..} maps
+
(defprotocol APredictor
(predict [wrapped-predictor inputs])
(predict-with-ndarray [wrapped-predictor input-arrays]))
@@ -103,15 +115,6 @@
(s/def ::nil-or-int (s/nilable int?))
-(defn- format-detection-predictions [predictions]
- (mapv (fn [[c p]]
- (let [[prob xmin ymin xmax ymax] (mapv float p)]
- {:class c :prob prob :x-min xmin :y-min ymin :x-max xmax :y-max ymax}))
- predictions))
-
-(defn- format-classification-predictions [predictions]
- (mapv (fn [[c p]] {:class c :prob p}) predictions))
-
(extend-protocol AClassifier
WrappedClassifier
(classify
@@ -120,13 +123,14 @@
([wrapped-classifier inputs topk]
(util/validate! ::wrapped-classifier wrapped-classifier
"Invalid classifier")
- (util/validate! ::vec-of-float-arrays inputs
+ (util/validate! ::vvec-of-numbers inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classify (:classifier wrapped-classifier)
- (util/vec->indexed-seq inputs)
- (util/->int-option topk)))))
+ (-> (.classify (:classifier wrapped-classifier)
+ (util/vec->indexed-seq [(float-array (first inputs))])
+ (util/->int-option topk))
+ (util/coerce-return-recursive)
+ (format-classification-predictions))))
(classify-with-ndarray
([wrapped-classifier inputs]
(classify-with-ndarray wrapped-classifier inputs nil))
@@ -136,10 +140,13 @@
(util/validate! ::vec-of-ndarrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
+ (->
(.classifyWithNDArray (:classifier wrapped-classifier)
(util/vec->indexed-seq inputs)
- (util/->int-option topk)))))
+ (util/->int-option topk))
+ (util/coerce-return-recursive)
+ (first)
+ (format-classification-predictions))))
WrappedImageClassifier
(classify
([wrapped-image-classifier inputs]
@@ -147,13 +154,14 @@
([wrapped-image-classifier inputs topk]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
- (util/validate! ::vec-of-float-arrays inputs
+ (util/validate! ::vvec-of-numbers inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classify (:image-classifier wrapped-image-classifier)
- (util/vec->indexed-seq inputs)
- (util/->int-option topk)))))
+ (-> (.classify (:image-classifier wrapped-image-classifier)
+ (util/vec->indexed-seq [(float-array (first inputs))])
+ (util/->int-option topk))
+ (util/coerce-return-recursive)
+ (format-classification-predictions))))
(classify-with-ndarray
([wrapped-image-classifier inputs]
(classify-with-ndarray wrapped-image-classifier inputs nil))
@@ -163,10 +171,12 @@
(util/validate! ::vec-of-ndarrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classifyWithNDArray (:image-classifier wrapped-image-classifier)
- (util/vec->indexed-seq inputs)
- (util/->int-option topk))))))
+ (-> (.classifyWithNDArray (:image-classifier wrapped-image-classifier)
+ (util/vec->indexed-seq inputs)
+ (util/->int-option topk))
+ (util/coerce-return-recursive)
+ (first)
+ (format-classification-predictions)))))
(s/def ::image #(instance? BufferedImage %))
(s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})
@@ -199,10 +209,11 @@
([wrapped-image-classifier images topk dtype]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
+ (util/validate! ::batch-images images "Invalid Batch Images")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/validate! ::dtype dtype "Invalid dtype")
(-> (.classifyImageBatch (:image-classifier wrapped-image-classifier)
- images
+ (util/vec->indexed-seq images)
(util/->int-option topk)
dtype)
(util/coerce-return-recursive)
@@ -232,8 +243,9 @@
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/validate! ::batch-images images "Invalid Batch Images")
(->> (.imageBatchObjectDetect (:object-detector wrapped-detector)
- images
+ (util/vec->indexed-seq images)
(util/->int-option topk))
(util/coerce-return-recursive)
(first)
@@ -361,4 +373,15 @@
"Loads images from a list of file names"
[image-paths]
(util/validate! ::image-paths image-paths "Invalid image paths")
- (ImageClassifier/loadInputBatch (util/convert-vector image-paths)))
+ (util/scala-vector->vec
+ (ImageClassifier/loadInputBatch (util/convert-vector image-paths))))
+
+(defn reshape-image
+ "Reshapes a buffered image by calling ImageClassifier/reshapeImage with width and height coerced to int"
+ [buffered-image width height]
+ (ImageClassifier/reshapeImage buffered-image (int width) (int height)))
+
+(defn buffered-image-to-pixels
+ "Turns the buffered image into an ndarray via ImageClassifier/bufferedImageToPixels; shape-vec is converted with mx-shape/->shape and dtype is passed through as given"
+ [buffered-image shape-vec dtype]
+ (ImageClassifier/bufferedImageToPixels buffered-image (mx-shape/->shape shape-vec) dtype))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
index 448a52f..567ebc0 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
@@ -19,6 +19,7 @@
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.infer :as infer]
[org.apache.clojure-mxnet.layout :as layout]
+ [org.apache.clojure-mxnet.ndarray :as ndarray]
[clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
[clojure.test :refer :all]))
@@ -62,3 +63,67 @@
(is (= 5 (count predictions)))
(is (= "n02123159 tiger cat" (:class (first predictions))))
(is (= (< 0 (:prob (first predictions)) 1)))))
+
+(deftest test-single-classification-with-ndarray ; classify-with-ndarray using the classifier from create-classifier
+ (let [classifier (create-classifier)
+ image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+ (infer/reshape-image 224 224)
+ (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+ (ndarray/expand-dims 0))
+ predictions-all (infer/classify-with-ndarray classifier [image])
+ predictions (infer/classify-with-ndarray classifier [image] 5)]
+ (is (= 1000 (count predictions-all)))
+ (is (= 5 (count predictions)))
+ (is (= "n02123159 tiger cat" (:class (first predictions))))
+ (is (< 0 (:prob (first predictions)) 1)))) ; assert the range directly: one-arg (= x) is always true
+
+(deftest test-single-classify ; classify on the image converted with ndarray/->vec
+ (let [classifier (create-classifier)
+ image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+ (infer/reshape-image 224 224)
+ (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+ (ndarray/expand-dims 0))
+ predictions-all (infer/classify classifier [(ndarray/->vec image)])
+ predictions (infer/classify classifier [(ndarray/->vec image)] 5)]
+ (is (= 1000 (count predictions-all)))
+ (is (= 5 (count predictions)))
+ (is (= "n02123159 tiger cat" (:class (first predictions))))
+ (is (< 0 (:prob (first predictions)) 1)))) ; assert the range directly: one-arg (= x) is always true
+
+(deftest test-base-classification-with-ndarray ; classify-with-ndarray using a classifier from infer/create-classifier
+ (let [descriptors [{:name "data"
+ :shape [1 3 224 224]
+ :layout layout/NCHW
+ :dtype dtype/FLOAT32}]
+ factory (infer/model-factory model-path-prefix descriptors)
+ classifier (infer/create-classifier factory)
+ image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+ (infer/reshape-image 224 224)
+ (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+ (ndarray/expand-dims 0))
+ predictions-all (infer/classify-with-ndarray classifier [image])
+ predictions (infer/classify-with-ndarray classifier [image] 5)]
+ (is (= 1000 (count predictions-all)))
+ (is (= 5 (count predictions)))
+ (is (= "n02123159 tiger cat" (:class (first predictions))))
+ (is (< 0 (:prob (first predictions)) 1)))) ; assert the range directly: one-arg (= x) is always true
+
+(deftest test-base-single-classify ; classify using a classifier from infer/create-classifier
+ (let [descriptors [{:name "data"
+ :shape [1 3 224 224]
+ :layout layout/NCHW
+ :dtype dtype/FLOAT32}]
+ factory (infer/model-factory model-path-prefix descriptors)
+ classifier (infer/create-classifier factory)
+ image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+ (infer/reshape-image 224 224)
+ (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32)
+ (ndarray/expand-dims 0))
+ predictions-all (infer/classify classifier [(ndarray/->vec image)])
+ predictions (infer/classify classifier [(ndarray/->vec image)] 5)]
+ (is (= 1000 (count predictions-all)))
+ (is (= 5 (count predictions)))
+ (is (= "n02123159 tiger cat" (:class (first predictions))))
+ (is (< 0 (:prob (first predictions)) 1)))) ; assert the range directly: one-arg (= x) is always true
+
+
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
index 91d4f0e..76acbfc 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
@@ -17,7 +17,6 @@
(ns org.apache.clojure-mxnet.infer.objectdetector-test
(:require [org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
- [org.apache.clojure-mxnet.image :as image]
[org.apache.clojure-mxnet.infer :as infer]
[org.apache.clojure-mxnet.layout :as layout]
[clojure.java.io :as io]
@@ -68,11 +67,10 @@
(deftest test-detection-with-ndarrays
(let [detector (create-detector)
- image (-> (image/read-image "test/test-images/kitten.jpg" {:to-rbg true})
- (image/resize-image 512 512)
- (ndarray/transpose)
- (ndarray/expand-dims 0)
- (ndarray/cast dtype/FLOAT32))
+ image (-> (infer/load-image-from-file "test/test-images/kitten.jpg")
+ (infer/reshape-image 512 512)
+ (infer/buffered-image-to-pixels [3 512 512] dtype/FLOAT32)
+ (ndarray/expand-dims 0))
predictions-all (infer/detect-objects-with-ndarrays detector [image])
predictions (infer/detect-objects-with-ndarrays detector [image] 1)
{:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]