You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2018/12/27 22:20:46 UTC
[incubator-mxnet] branch master updated: Port of scala infer package to clojure (#13595)
This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new b6b197a Port of scala infer package to clojure (#13595)
b6b197a is described below
commit b6b197ae75b4b08e3880257269dcb0f2733cf95f
Author: Kedar Bellare <ke...@gmail.com>
AuthorDate: Thu Dec 27 14:20:29 2018 -0800
Port of scala infer package to clojure (#13595)
* Port of scala infer package to clojure
* Add inference examples
* Fix project.clj
* Update code for integration tests
* Address comments and add unit tests
* Add specs and simplify interface
* Minor nit
* Update README
---
.../examples/infer/imageclassifier/.gitignore | 12 +
.../examples/infer/imageclassifier/README.md | 24 ++
.../examples/infer/imageclassifier/project.clj | 25 ++
.../imageclassifier/scripts/get_resnet_18_data.sh | 45 +++
.../imageclassifier/scripts/get_resnet_data.sh} | 34 +-
.../src/infer/imageclassifier_example.clj | 95 ++++++
.../test/infer/imageclassifier_example_test.clj | 69 ++++
.../examples/infer/objectdetector/.gitignore | 12 +
.../examples/infer/objectdetector/README.md | 24 ++
.../examples/infer/objectdetector/project.clj | 25 ++
.../infer/objectdetector/scripts/get_ssd_data.sh | 49 +++
.../src/infer/objectdetector_example.clj | 121 +++++++
.../test/infer/objectdetector_example_test.clj | 65 ++++
.../examples/infer/predictor/.gitignore | 12 +
.../examples/infer/predictor/README.md | 24 ++
.../examples/infer/predictor/project.clj | 25 ++
.../infer/predictor/scripts/get_resnet_18_data.sh} | 32 +-
.../infer/predictor/scripts/get_resnet_data.sh} | 34 +-
.../predictor/src/infer/predictor_example.clj | 101 ++++++
.../test/infer/predictor_example_test.clj | 51 +++
contrib/clojure-package/integration-tests.sh | 6 +-
.../infer/get_resnet_18_data.sh} | 26 +-
.../infer/get_ssd_data.sh} | 27 +-
.../src/org/apache/clojure_mxnet/image.clj | 6 +-
.../src/org/apache/clojure_mxnet/infer.clj | 353 +++++++++++++++++++++
.../src/org/apache/clojure_mxnet/util.clj | 9 +
.../clojure_mxnet/infer/imageclassifier_test.clj | 68 ++++
.../clojure_mxnet/infer/objectdetector_test.clj | 63 ++++
.../apache/clojure_mxnet/infer/predictor_test.clj | 59 ++++
.../test/org/apache/clojure_mxnet/util_test.clj | 10 +
.../test/test-images/Pug-Cookie.jpg | Bin 0 -> 104323 bytes
.../clojure-package/test/test-images/kitten.jpg | Bin 0 -> 110969 bytes
32 files changed, 1458 insertions(+), 48 deletions(-)
diff --git a/contrib/clojure-package/examples/infer/imageclassifier/.gitignore b/contrib/clojure-package/examples/infer/imageclassifier/.gitignore
new file mode 100644
index 0000000..35491f1
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/imageclassifier/.gitignore
@@ -0,0 +1,12 @@
+/target
+/classes
+/checkouts
+/images
+pom.xml
+pom.xml.asc
+*.jar
+*.class
+/.lein-*
+/.nrepl-port
+.hgignore
+.hg/
diff --git a/contrib/clojure-package/examples/infer/imageclassifier/README.md b/contrib/clojure-package/examples/infer/imageclassifier/README.md
new file mode 100644
index 0000000..a832860
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/imageclassifier/README.md
@@ -0,0 +1,24 @@
+# imageclassifier
+
+Run image classification using clojure infer package.
+
+## Installation
+
+Before you run this example, make sure that you have the clojure package installed.
+In the main clojure package directory, do `lein install`. Then you can run
+`lein install` in this directory.
+
+## Usage
+
+```
+$ chmod +x scripts/get_resnet_18_data.sh
+$ ./scripts/get_resnet_18_data.sh
+$
+$ lein run -- --help
+$ lein run -- -m models/resnet-18/resnet-18 -i images/kitten.jpg -d images/
+$
+$ lein uberjar
+$ java -jar target/imageclassifier-0.1.0-SNAPSHOT-standalone.jar --help
+$ java -jar target/imageclassifier-0.1.0-SNAPSHOT-standalone.jar \
+ -m models/resnet-18/resnet-18 -i images/kitten.jpg -d images/
+```
diff --git a/contrib/clojure-package/examples/infer/imageclassifier/project.clj b/contrib/clojure-package/examples/infer/imageclassifier/project.clj
new file mode 100644
index 0000000..2d5b171
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/imageclassifier/project.clj
@@ -0,0 +1,25 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(defproject imageclassifier "0.1.0-SNAPSHOT"
+ :description "Image classification using infer with MXNet"
+ :plugins [[lein-cljfmt "0.5.7"]]
+ :dependencies [[org.clojure/clojure "1.9.0"]
+ [org.clojure/tools.cli "0.4.1"]
+ [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]]
+ :main ^:skip-aot infer.imageclassifier-example
+ :profiles {:uberjar {:aot :all}})
diff --git a/contrib/clojure-package/examples/infer/imageclassifier/scripts/get_resnet_18_data.sh b/contrib/clojure-package/examples/infer/imageclassifier/scripts/get_resnet_18_data.sh
new file mode 100755
index 0000000..1a142e8
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/imageclassifier/scripts/get_resnet_18_data.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -evx
+
+MXNET_ROOT=$(cd "$(dirname $0)/.."; pwd)
+
+data_path=$MXNET_ROOT/models/resnet-18/
+
+image_path=$MXNET_ROOT/images/
+
+if [ ! -d "$data_path" ]; then
+ mkdir -p "$data_path"
+fi
+
+if [ ! -d "$image_path" ]; then
+ mkdir -p "$image_path"
+fi
+
+if [ ! -f "$data_path/resnet-18-0000.params" ]; then
+ wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-symbol.json -P $data_path
+ wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-0000.params -P $data_path
+ wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/synset.txt -P $data_path
+fi
+
+if [ ! -f "$image_path/kitten.jpg" ]; then
+ wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/kitten.jpg -P $image_path
+ wget https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg -P $image_path
+fi
diff --git a/contrib/clojure-package/integration-tests.sh b/contrib/clojure-package/examples/infer/imageclassifier/scripts/get_resnet_data.sh
similarity index 50%
copy from contrib/clojure-package/integration-tests.sh
copy to contrib/clojure-package/examples/infer/imageclassifier/scripts/get_resnet_data.sh
index 3297fdc..fcef59b 100755
--- a/contrib/clojure-package/integration-tests.sh
+++ b/contrib/clojure-package/examples/infer/imageclassifier/scripts/get_resnet_data.sh
@@ -1,4 +1,5 @@
#!/bin/bash
+
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -16,13 +17,28 @@
# specific language governing permissions and limitations
# under the License.
-set -evx
+set -e
+
+MXNET_ROOT=$(cd "$(dirname $0)/.."; pwd)
+
+data_path=$MXNET_ROOT/models/resnet-152/
+
+image_path=$MXNET_ROOT/images/
+
+if [ ! -d "$data_path" ]; then
+ mkdir -p "$data_path"
+fi
+
+if [ ! -d "$image_path" ]; then
+ mkdir -p "$image_path"
+fi
+
+if [ ! -f "$data_path/resnet-152-0000.params" ]; then
+ wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/resnet-152-0000.params -P $data_path
+ wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/resnet-152-symbol.json -P $data_path
+ wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/synset.txt -P $data_path
+fi
-MXNET_HOME=${PWD}
-EXAMPLES_HOME=${MXNET_HOME}/contrib/clojure-package/examples
-#cd ${MXNET_HOME}/contrib/clojure-package
-#lein test
-#lein cloverage --codecov
-for i in `find ${EXAMPLES_HOME} -name test` ; do
-cd ${i} && lein test
-done
+if [ ! -f "$image_path/kitten.jpg" ]; then
+ wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/kitten.jpg -P $image_path
+fi
diff --git a/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj b/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj
new file mode 100644
index 0000000..d680b9a
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj
@@ -0,0 +1,95 @@
+(ns infer.imageclassifier-example
+ (:require [org.apache.clojure-mxnet.context :as context]
+ [org.apache.clojure-mxnet.dtype :as dtype]
+ [org.apache.clojure-mxnet.infer :as infer]
+ [org.apache.clojure-mxnet.layout :as layout]
+ [clojure.java.io :as io]
+ [clojure.string :refer [join]]
+ [clojure.tools.cli :refer [parse-opts]])
+ (:gen-class))
+
+(defn check-valid-dir
+ "Check that the input directory exists"
+ [input-dir]
+ (let [dir (io/file input-dir)]
+ (and
+ (.exists dir)
+ (.isDirectory dir))))
+
+(defn check-valid-file
+ "Check that the file exists"
+ [input-file]
+ (.exists (io/file input-file)))
+
+(def cli-options
+ [["-m" "--model-path-prefix PREFIX" "Model path prefix"
+ :default "models/resnet-18/resnet-18"
+ :validate [#(check-valid-file (str % "-symbol.json"))
+ "Model path prefix is invalid"]]
+ ["-i" "--input-image IMAGE" "Input image"
+ :default "images/kitten.jpg"
+ :validate [check-valid-file "Input file not found"]]
+ ["-d" "--input-dir IMAGE_DIR" "Input directory"
+ :default "images/"
+ :validate [check-valid-dir "Input directory not found"]]
+ ["-h" "--help"]])
+
+(defn print-predictions
+ "Print image classifier predictions for the given input file"
+ [predictions]
+ (println (apply str (repeat 80 "=")))
+ (doseq [[label probability] predictions]
+ (println (format "Class: %s Probability=%.8f" label probability)))
+ (println (apply str (repeat 80 "="))))
+
+(defn classify-single-image
+ "Classify a single image and print top-5 predictions"
+ [classifier input-image]
+ (let [image (infer/load-image-from-file input-image)
+ topk 5
+ [predictions] (infer/classify-image classifier image topk)]
+ predictions))
+
+(defn classify-images-in-dir
+ "Classify all jpg images in the directory"
+ [classifier input-dir]
+ (let [batch-size 20
+ image-file-batches (->> input-dir
+ io/file
+ file-seq
+ (filter #(.isFile %))
+ (filter #(re-matches #".*\.jpg$" (.getPath %)))
+ (mapv #(.getPath %))
+ (partition-all batch-size))]
+ (apply
+ concat
+ (for [image-files image-file-batches]
+ (let [image-batch (infer/load-image-paths image-files)
+ topk 5]
+ (infer/classify-image-batch classifier image-batch topk))))))
+
+(defn run-classifier
+ "Runs an image classifier based on options provided"
+ [options]
+ (let [{:keys [model-path-prefix input-image input-dir]} options
+ descriptors [{:name "data"
+ :shape [1 3 224 224]
+ :layout layout/NCHW
+ :dtype dtype/FLOAT32}]
+ factory (infer/model-factory model-path-prefix descriptors)
+ classifier (infer/create-image-classifier
+ factory {:contexts [(context/default-context)]})]
+ (println "Classifying a single image")
+ (print-predictions (classify-single-image classifier input-image))
+ (println "Classifying images in a directory")
+ (doseq [predictions (classify-images-in-dir classifier input-dir)]
+ (print-predictions predictions))))
+
+(defn -main
+ [& args]
+ (let [{:keys [options summary errors] :as opts}
+ (parse-opts args cli-options)]
+ (cond
+ (:help options) (println summary)
+ (some? errors) (println (join "\n" errors))
+ :else (run-classifier options))))
diff --git a/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj b/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj
new file mode 100644
index 0000000..5b3e08d
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj
@@ -0,0 +1,69 @@
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns infer.imageclassifier-example-test
+ (:require [infer.imageclassifier-example :refer [classify-single-image
+ classify-images-in-dir]]
+ [org.apache.clojure-mxnet.context :as context]
+ [org.apache.clojure-mxnet.dtype :as dtype]
+ [org.apache.clojure-mxnet.infer :as infer]
+ [org.apache.clojure-mxnet.layout :as layout]
+ [clojure.java.io :as io]
+ [clojure.java.shell :refer [sh]]
+ [clojure.test :refer :all]))
+
+(def model-dir "models/")
+(def image-dir "images/")
+(def model-path-prefix (str model-dir "resnet-18/resnet-18"))
+(def image-file (str image-dir "kitten.jpg"))
+
+(when-not (.exists (io/file (str model-path-prefix "-symbol.json")))
+ (sh "./scripts/get_resnet_18_data.sh"))
+
+(defn create-classifier []
+ (let [descriptors [{:name "data"
+ :shape [1 3 224 224]
+ :layout layout/NCHW
+ :dtype dtype/FLOAT32}]
+ factory (infer/model-factory model-path-prefix descriptors)]
+ (infer/create-image-classifier factory)))
+
+(deftest test-single-classification
+ (let [classifier (create-classifier)
+ predictions (classify-single-image classifier image-file)]
+ (is (some? predictions))
+ (is (= 5 (count predictions)))
+ (is (every? #(= 2 (count %)) predictions))
+ (is (every? #(string? (first %)) predictions))
+ (is (every? #(float? (second %)) predictions))
+ (is (every? #(< 0 (second %) 1) predictions))
+ (is (= ["n02123159 tiger cat"
+ "n02124075 Egyptian cat"
+ "n02123045 tabby, tabby cat"
+ "n02127052 lynx, catamount"
+ "n02128757 snow leopard, ounce, Panthera uncia"]
+ (map first predictions)))))
+
+(deftest test-batch-classification
+ (let [classifier (create-classifier)
+ batch-predictions (classify-images-in-dir classifier image-dir)
+ predictions (first batch-predictions)]
+ (is (some? batch-predictions))
+ (is (= 5 (count predictions)))
+ (is (every? #(= 2 (count %)) predictions))
+ (is (every? #(string? (first %)) predictions))
+ (is (every? #(float? (second %)) predictions))
+ (is (every? #(< 0 (second %) 1) predictions))))
diff --git a/contrib/clojure-package/examples/infer/objectdetector/.gitignore b/contrib/clojure-package/examples/infer/objectdetector/.gitignore
new file mode 100644
index 0000000..35491f1
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/objectdetector/.gitignore
@@ -0,0 +1,12 @@
+/target
+/classes
+/checkouts
+/images
+pom.xml
+pom.xml.asc
+*.jar
+*.class
+/.lein-*
+/.nrepl-port
+.hgignore
+.hg/
diff --git a/contrib/clojure-package/examples/infer/objectdetector/README.md b/contrib/clojure-package/examples/infer/objectdetector/README.md
new file mode 100644
index 0000000..921c53e
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/objectdetector/README.md
@@ -0,0 +1,24 @@
+# objectdetector
+
+Run object detection on images using clojure infer package.
+
+## Installation
+
+Before you run this example, make sure that you have the clojure package installed.
+In the main clojure package directory, do `lein install`. Then you can run
+`lein install` in this directory.
+
+## Usage
+
+```
+$ chmod +x scripts/get_ssd_data.sh
+$ ./scripts/get_ssd_data.sh
+$
+$ lein run -- --help
+$ lein run -- -m models/resnet50_ssd/resnet50_ssd_model -i images/dog.jpg -d images/
+$
+$ lein uberjar
+$ java -jar target/objectdetector-0.1.0-SNAPSHOT-standalone.jar --help
+$ java -jar target/objectdetector-0.1.0-SNAPSHOT-standalone.jar \
+ -m models/resnet50_ssd/resnet50_ssd_model -i images/dog.jpg -d images/
+```
diff --git a/contrib/clojure-package/examples/infer/objectdetector/project.clj b/contrib/clojure-package/examples/infer/objectdetector/project.clj
new file mode 100644
index 0000000..4501f14
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/objectdetector/project.clj
@@ -0,0 +1,25 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(defproject objectdetector "0.1.0-SNAPSHOT"
+ :description "Object detection using infer with MXNet"
+ :plugins [[lein-cljfmt "0.5.7"]]
+ :dependencies [[org.clojure/clojure "1.9.0"]
+ [org.clojure/tools.cli "0.4.1"]
+ [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]]
+ :main ^:skip-aot infer.objectdetector-example
+ :profiles {:uberjar {:aot :all}})
diff --git a/contrib/clojure-package/examples/infer/objectdetector/scripts/get_ssd_data.sh b/contrib/clojure-package/examples/infer/objectdetector/scripts/get_ssd_data.sh
new file mode 100755
index 0000000..06440a2
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/objectdetector/scripts/get_ssd_data.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+set -e
+
+MXNET_ROOT=$(cd "$(dirname $0)/.."; pwd)
+
+data_path=$MXNET_ROOT/models/resnet50_ssd
+
+image_path=$MXNET_ROOT/images
+
+if [ ! -d "$data_path" ]; then
+ mkdir -p "$data_path"
+fi
+
+if [ ! -d "$image_path" ]; then
+ mkdir -p "$image_path"
+fi
+
+if [ ! -f "$data_path/resnet50_ssd_model-0000.params" ]; then
+ wget https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-symbol.json -P $data_path
+ wget https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-0000.params -P $data_path
+ wget https://s3.amazonaws.com/model-server/models/resnet50_ssd/synset.txt -P $data_path
+fi
+
+if [ ! -f "$image_path/000001.jpg" ]; then
+ cd $image_path
+ wget https://cloud.githubusercontent.com/assets/3307514/20012566/cbb53c76-a27d-11e6-9aaa-91939c9a1cd5.jpg -O 000001.jpg
+ wget https://cloud.githubusercontent.com/assets/3307514/20012567/cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg -O dog.jpg
+ wget https://cloud.githubusercontent.com/assets/3307514/20012563/cbb41382-a27d-11e6-92a9-18dab4fd1ad3.jpg -O person.jpg
+fi
+
diff --git a/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj b/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj
new file mode 100644
index 0000000..53172f0
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj
@@ -0,0 +1,121 @@
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns infer.objectdetector-example
+ (:require [org.apache.clojure-mxnet.context :as context]
+ [org.apache.clojure-mxnet.dtype :as dtype]
+ [org.apache.clojure-mxnet.infer :as infer]
+ [org.apache.clojure-mxnet.layout :as layout]
+ [clojure.java.io :as io]
+ [clojure.string :refer [join]]
+ [clojure.tools.cli :refer [parse-opts]])
+ (:gen-class))
+
+(defn check-valid-dir
+ "Check that the input directory exists"
+ [input-dir]
+ (let [dir (io/file input-dir)]
+ (and
+ (.exists dir)
+ (.isDirectory dir))))
+
+(defn check-valid-file
+ "Check that the file exists"
+ [input-file]
+ (.exists (io/file input-file)))
+
+(def cli-options
+ [["-m" "--model-path-prefix PREFIX" "Model path prefix"
+ :default "models/resnet50_ssd/resnet50_ssd_model"
+ :validate [#(check-valid-file (str % "-symbol.json"))
+ "Model path prefix is invalid"]]
+ ["-i" "--input-image IMAGE" "Input image"
+ :default "images/dog.jpg"
+ :validate [check-valid-file "Input file not found"]]
+ ["-d" "--input-dir IMAGE_DIR" "Input directory"
+ :default "images/"
+ :validate [check-valid-dir "Input directory not found"]]
+ ["-h" "--help"]])
+
+(defn print-predictions
+ "Print image detector predictions for the given input file"
+ [predictions width height]
+ (println (apply str (repeat 80 "=")))
+ (doseq [[label prob-and-bounds] predictions]
+ (println (format
+ "Class: %s Prob=%.5f Coords=(%.3f, %.3f, %.3f, %.3f)"
+ label
+ (aget prob-and-bounds 0)
+ (* (aget prob-and-bounds 1) width)
+ (* (aget prob-and-bounds 2) height)
+ (* (aget prob-and-bounds 3) width)
+ (* (aget prob-and-bounds 4) height))))
+ (println (apply str (repeat 80 "="))))
+
+(defn detect-single-image
+ "Detect objects in a single image and print top-5 predictions"
+ [detector input-image]
+ (let [image (infer/load-image-from-file input-image)
+ topk 5
+ [predictions] (infer/detect-objects detector image topk)]
+ predictions))
+
+(defn detect-images-in-dir
+ "Detect objects in all jpg images in the directory"
+ [detector input-dir]
+ (let [batch-size 20
+ image-file-batches (->> input-dir
+ io/file
+ file-seq
+ (filter #(.isFile %))
+ (filter #(re-matches #".*\.jpg$" (.getPath %)))
+ (mapv #(.getPath %))
+ (partition-all batch-size))]
+ (apply
+ concat
+ (for [image-files image-file-batches]
+ (let [image-batch (infer/load-image-paths image-files)
+ topk 5]
+ (infer/detect-objects-batch detector image-batch topk))))))
+
+(defn run-detector
+ "Runs an image detector based on options provided"
+ [options]
+ (let [{:keys [model-path-prefix input-image input-dir
+ device device-id]} options
+ width 512 height 512
+ descriptors [{:name "data"
+ :shape [1 3 height width]
+ :layout layout/NCHW
+ :dtype dtype/FLOAT32}]
+ factory (infer/model-factory model-path-prefix descriptors)
+ detector (infer/create-object-detector
+ factory
+ {:contexts [(context/default-context)]})]
+ (println "Object detection on a single image")
+ (print-predictions (detect-single-image detector input-image) width height)
+ (println "Object detection on images in a directory")
+ (doseq [predictions (detect-images-in-dir detector input-dir)]
+ (print-predictions predictions width height))))
+
+(defn -main
+ [& args]
+ (let [{:keys [options summary errors] :as opts}
+ (parse-opts args cli-options)]
+ (cond
+ (:help options) (println summary)
+ (some? errors) (println (join "\n" errors))
+ :else (run-detector options))))
diff --git a/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj b/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj
new file mode 100644
index 0000000..90ed02f
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj
@@ -0,0 +1,65 @@
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns infer.objectdetector-example-test
+ (:require [infer.objectdetector-example :refer [detect-single-image
+ detect-images-in-dir]]
+ [org.apache.clojure-mxnet.context :as context]
+ [org.apache.clojure-mxnet.dtype :as dtype]
+ [org.apache.clojure-mxnet.infer :as infer]
+ [org.apache.clojure-mxnet.layout :as layout]
+ [clojure.java.io :as io]
+ [clojure.java.shell :refer [sh]]
+ [clojure.test :refer :all]))
+
+(def model-dir "models/")
+(def image-dir "images/")
+(def model-path-prefix (str model-dir "resnet50_ssd/resnet50_ssd_model"))
+(def image-file (str image-dir "dog.jpg"))
+
+(when-not (.exists (io/file (str model-path-prefix "-symbol.json")))
+ (sh "./scripts/get_ssd_data.sh"))
+
+(defn create-detector []
+ (let [descriptors [{:name "data"
+ :shape [1 3 512 512]
+ :layout layout/NCHW
+ :dtype dtype/FLOAT32}]
+ factory (infer/model-factory model-path-prefix descriptors)]
+ (infer/create-object-detector factory)))
+
+(deftest test-single-detection
+ (let [detector (create-detector)
+ predictions (detect-single-image detector image-file)]
+ (is (some? predictions))
+ (is (= 5 (count predictions)))
+ (is (every? #(= 2 (count %)) predictions))
+ (is (every? #(string? (first %)) predictions))
+ (is (every? #(= 5 (count (second %))) predictions))
+ (is (every? #(< 0 (first (second %)) 1) predictions))
+ (is (= ["car" "bicycle" "dog" "bicycle" "person"]
+ (map first predictions)))))
+
+(deftest test-batch-detection
+ (let [detector (create-detector)
+ batch-predictions (detect-images-in-dir detector image-dir)
+ predictions (first batch-predictions)]
+ (is (some? batch-predictions))
+ (is (= 5 (count predictions)))
+ (is (every? #(= 2 (count %)) predictions))
+ (is (every? #(string? (first %)) predictions))
+ (is (every? #(= 5 (count (second %))) predictions))
+ (is (every? #(< 0 (first (second %)) 1) predictions))))
diff --git a/contrib/clojure-package/examples/infer/predictor/.gitignore b/contrib/clojure-package/examples/infer/predictor/.gitignore
new file mode 100644
index 0000000..35491f1
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/predictor/.gitignore
@@ -0,0 +1,12 @@
+/target
+/classes
+/checkouts
+/images
+pom.xml
+pom.xml.asc
+*.jar
+*.class
+/.lein-*
+/.nrepl-port
+.hgignore
+.hg/
diff --git a/contrib/clojure-package/examples/infer/predictor/README.md b/contrib/clojure-package/examples/infer/predictor/README.md
new file mode 100644
index 0000000..9ca71cf
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/predictor/README.md
@@ -0,0 +1,24 @@
+# predictor
+
+Run model prediction using clojure infer package.
+
+## Installation
+
+Before you run this example, make sure that you have the clojure package installed.
+In the main clojure package directory, do `lein install`. Then you can run
+`lein install` in this directory.
+
+## Usage
+
+```
+$ chmod +x scripts/get_resnet_18_data.sh
+$ ./scripts/get_resnet_18_data.sh
+$
+$ lein run -- --help
+$ lein run -- -m models/resnet-18/resnet-18 -i images/kitten.jpg
+$
+$ lein uberjar
+$ java -jar target/predictor-0.1.0-SNAPSHOT-standalone.jar --help
+$ java -jar target/predictor-0.1.0-SNAPSHOT-standalone.jar \
+ -m models/resnet-18/resnet-18 -i images/kitten.jpg
+```
diff --git a/contrib/clojure-package/examples/infer/predictor/project.clj b/contrib/clojure-package/examples/infer/predictor/project.clj
new file mode 100644
index 0000000..0bd1eae
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/predictor/project.clj
@@ -0,0 +1,25 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(defproject predictor "0.1.0-SNAPSHOT"
+ :description "Model prediction using infer with MXNet"
+ :plugins [[lein-cljfmt "0.5.7"]]
+ :dependencies [[org.clojure/clojure "1.9.0"]
+ [org.clojure/tools.cli "0.4.1"]
+ [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]]
+ :main ^:skip-aot infer.predictor-example
+ :profiles {:uberjar {:aot :all}})
diff --git a/contrib/clojure-package/integration-tests.sh b/contrib/clojure-package/examples/infer/predictor/scripts/get_resnet_18_data.sh
similarity index 51%
copy from contrib/clojure-package/integration-tests.sh
copy to contrib/clojure-package/examples/infer/predictor/scripts/get_resnet_18_data.sh
index 3297fdc..cf85355 100755
--- a/contrib/clojure-package/integration-tests.sh
+++ b/contrib/clojure-package/examples/infer/predictor/scripts/get_resnet_18_data.sh
@@ -1,4 +1,5 @@
#!/bin/bash
+
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -18,11 +19,26 @@
set -evx
-MXNET_HOME=${PWD}
-EXAMPLES_HOME=${MXNET_HOME}/contrib/clojure-package/examples
-#cd ${MXNET_HOME}/contrib/clojure-package
-#lein test
-#lein cloverage --codecov
-for i in `find ${EXAMPLES_HOME} -name test` ; do
-cd ${i} && lein test
-done
+MXNET_ROOT=$(cd "$(dirname $0)/.."; pwd)
+
+data_path=$MXNET_ROOT/models/resnet-18/
+
+image_path=$MXNET_ROOT/images/
+
+if [ ! -d "$data_path" ]; then
+ mkdir -p "$data_path"
+fi
+
+if [ ! -d "$image_path" ]; then
+ mkdir -p "$image_path"
+fi
+
+if [ ! -f "$data_path/resnet-18-0000.params" ]; then
+ wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-symbol.json -P $data_path
+ wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-0000.params -P $data_path
+ wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/synset.txt -P $data_path
+fi
+
+if [ ! -f "$image_path/kitten.jpg" ]; then
+ wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/kitten.jpg -P $image_path
+fi
diff --git a/contrib/clojure-package/integration-tests.sh b/contrib/clojure-package/examples/infer/predictor/scripts/get_resnet_data.sh
similarity index 50%
copy from contrib/clojure-package/integration-tests.sh
copy to contrib/clojure-package/examples/infer/predictor/scripts/get_resnet_data.sh
index 3297fdc..fcef59b 100755
--- a/contrib/clojure-package/integration-tests.sh
+++ b/contrib/clojure-package/examples/infer/predictor/scripts/get_resnet_data.sh
@@ -1,4 +1,5 @@
#!/bin/bash
+
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -16,13 +17,28 @@
# specific language governing permissions and limitations
# under the License.
-set -evx
+set -e
+
+MXNET_ROOT=$(cd "$(dirname $0)/.."; pwd)
+
+data_path=$MXNET_ROOT/models/resnet-152/
+
+image_path=$MXNET_ROOT/images/
+
+if [ ! -d "$data_path" ]; then
+ mkdir -p "$data_path"
+fi
+
+if [ ! -d "$image_path" ]; then
+ mkdir -p "$image_path"
+fi
+
+if [ ! -f "$data_path/resnet-152-0000.params" ]; then
+ wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/resnet-152-0000.params -P $data_path
+ wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/resnet-152-symbol.json -P $data_path
+ wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/synset.txt -P $data_path
+fi
-MXNET_HOME=${PWD}
-EXAMPLES_HOME=${MXNET_HOME}/contrib/clojure-package/examples
-#cd ${MXNET_HOME}/contrib/clojure-package
-#lein test
-#lein cloverage --codecov
-for i in `find ${EXAMPLES_HOME} -name test` ; do
-cd ${i} && lein test
-done
+# Fetch the sample image once. The target directory is quoted so a project
+# path containing spaces is passed to wget as a single argument.
+if [ ! -f "$image_path/kitten.jpg" ]; then
+  wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/kitten.jpg -P "$image_path"
+fi
diff --git a/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj b/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
new file mode 100644
index 0000000..4989641
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
@@ -0,0 +1,101 @@
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns infer.predictor-example
+ (:require [org.apache.clojure-mxnet.context :as context]
+ [org.apache.clojure-mxnet.dtype :as dtype]
+ [org.apache.clojure-mxnet.image :as image]
+ [org.apache.clojure-mxnet.infer :as infer]
+ [org.apache.clojure-mxnet.layout :as layout]
+ [org.apache.clojure-mxnet.ndarray :as ndarray]
+ [clojure.java.io :as io]
+ [clojure.string :refer [join split]]
+ [clojure.tools.cli :refer [parse-opts]])
+ (:gen-class))
+
+(defn check-valid-file
+  "Returns true when input-file names an existing regular file.
+
+  Directories are rejected: the command-line options use this to validate
+  paths that must be files (the model's \"-symbol.json\" and the input
+  image), and a directory would only fail later when it is read."
+  [input-file]
+  (.isFile (io/file input-file)))
+
+(def cli-options
+  "tools.cli option specs used by -main: the model path prefix (accepted when
+  \"<prefix>-symbol.json\" passes check-valid-file), the input image path, and
+  -h/--help."
+  (let [model-prefix-valid? (fn [prefix]
+                              (check-valid-file (str prefix "-symbol.json")))]
+    [["-m" "--model-path-prefix PREFIX" "Model path prefix"
+      :default "models/resnet-18/resnet-18"
+      :validate [model-prefix-valid? "Model path prefix is invalid"]]
+     ["-i" "--input-image IMAGE" "Image path"
+      :default "images/kitten.jpg"
+      :validate [check-valid-file "Input image path not found"]]
+     ["-h" "--help"]]))
+
+(defn print-prediction
+  "Prints prediction on its own line, with a rule of 80 '=' characters on the
+  line above and the line below. Returns nil."
+  [prediction]
+  (let [rule (join (repeat 80 "="))]
+    (doseq [line [rule prediction rule]]
+      (println line))))
+
+(defn preprocess
+  "Preprocesses an image file into a batch of one, ready for prediction.
+
+  Loads image-path, resizes it to width x height, converts the pixels to an
+  NDArray of shape [3 height width] and adds a leading batch axis, giving
+  [1 3 height width]. That matches the NCHW input descriptor built in
+  run-predictor."
+  [image-path width height]
+  (-> image-path
+      infer/load-image-from-file
+      (infer/reshape-image width height)
+      ;; Channels-first [C H W]: height (rows) comes before width (columns),
+      ;; as in the [1 3 height width] descriptor. Swapping them goes
+      ;; unnoticed only for square inputs such as 224x224.
+      (infer/buffered-image-to-pixels [3 height width])
+      (ndarray/expand-dims 0)))
+
+(defn do-inference
+  "Runs predictor on the single input NDArray image and returns the first
+  output it produces (nil when there is none)."
+  [predictor image]
+  (-> predictor
+      (infer/predict-with-ndarray [image])
+      (nth 0 nil)))
+
+(defn postprocess
+  "Returns the synset label of the highest-scoring class in predictions.
+
+  Labels are read one per line from synset.txt in the same directory as
+  model-path-prefix. The argmax over axis 1 is taken and its first entry
+  (the first row of the batch) selects the label."
+  [model-path-prefix predictions]
+  (let [synset-file (-> model-path-prefix
+                        io/file
+                        (.getParent)
+                        (io/file "synset.txt"))
+        ;; Accept LF and CRLF line endings so a synset.txt saved on Windows
+        ;; does not leave a trailing \r on every label.
+        synset-names (split (slurp synset-file) #"\r?\n")
+        [max-idx] (ndarray/->int-vec (ndarray/argmax predictions 1))]
+    (synset-names max-idx)))
+
+(defn run-predictor
+  "Runs the predictor example end to end. It loads the model at
+  :model-path-prefix with a single 224x224 NCHW float32 input on the default
+  context, predicts on :input-image, and prints the best synset label."
+  [options]
+  (let [{:keys [model-path-prefix input-image]} options
+        width 224
+        height 224
+        predictor (-> model-path-prefix
+                      (infer/model-factory [{:name "data"
+                                             :shape [1 3 height width]
+                                             :layout layout/NCHW
+                                             :dtype dtype/FLOAT32}])
+                      (infer/create-predictor
+                       {:contexts [(context/default-context)]}))]
+    (->> (preprocess input-image width height)
+         (do-inference predictor)
+         (postprocess model-path-prefix)
+         print-prediction)))
+
+(defn -main
+  "Command-line entry point.
+
+  -h/--help prints the option summary. Invalid options print the errors and
+  the summary to stderr and exit with status 1, so scripts can detect the
+  failure. Otherwise the predictor runs with the parsed options."
+  [& args]
+  (let [{:keys [options summary errors]} (parse-opts args cli-options)]
+    (cond
+      (:help options) (println summary)
+      (some? errors) (binding [*out* *err*]
+                       (println (join "\n" errors))
+                       (println summary)
+                       (System/exit 1))
+      :else (run-predictor options))))
diff --git a/contrib/clojure-package/examples/infer/predictor/test/infer/predictor_example_test.clj b/contrib/clojure-package/examples/infer/predictor/test/infer/predictor_example_test.clj
new file mode 100644
index 0000000..02f826f
--- /dev/null
+++ b/contrib/clojure-package/examples/infer/predictor/test/infer/predictor_example_test.clj
@@ -0,0 +1,51 @@
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns infer.predictor-example-test
+ (:require [infer.predictor-example :refer [preprocess
+ do-inference
+ postprocess]]
+ [org.apache.clojure-mxnet.context :as context]
+ [org.apache.clojure-mxnet.dtype :as dtype]
+ [org.apache.clojure-mxnet.infer :as infer]
+ [org.apache.clojure-mxnet.layout :as layout]
+ [clojure.java.io :as io]
+ [clojure.java.shell :refer [sh]]
+ [clojure.test :refer :all]))
+
+(def model-dir "models/")
+(def image-file "images/kitten.jpg")
+(def model-path-prefix (str model-dir "resnet-18/resnet-18"))
+(def width 224)
+(def height 224)
+
+;; Fetch the model and the test image on first run. The script downloads the
+;; symbol, params and image files, so all three are checked. A missing params
+;; file or image must not skip the download.
+;; clojure.java.shell/sh reports failure only through :exit (it never throws),
+;; so check it here. Otherwise a failed download surfaces later as an obscure
+;; model- or image-loading error.
+(when-not (every? #(.exists (io/file %))
+                  [(str model-path-prefix "-symbol.json")
+                   (str model-path-prefix "-0000.params")
+                   image-file])
+  (let [{:keys [exit out err]} (sh "./scripts/get_resnet_18_data.sh")]
+    (when-not (zero? exit)
+      (throw (ex-info "Downloading the resnet-18 model and test image failed"
+                      {:exit exit :out out :err err})))))
+
+(defn create-predictor
+  "Builds a predictor for the resnet-18 test model, which has a single NCHW
+  float32 input named \"data\" of shape [1 3 height width]."
+  []
+  (-> model-path-prefix
+      (infer/model-factory [{:name "data"
+                             :shape [1 3 height width]
+                             :layout layout/NCHW
+                             :dtype dtype/FLOAT32}])
+      infer/create-predictor))
+
+;; End-to-end check: the kitten image is predicted as "tiger cat".
+(deftest predictor-test
+  (let [predictor (create-predictor)
+        best-prediction (->> (preprocess image-file width height)
+                             (do-inference predictor)
+                             (postprocess model-path-prefix))]
+    (is (= "n02123159 tiger cat" best-prediction))))
diff --git a/contrib/clojure-package/integration-tests.sh b/contrib/clojure-package/integration-tests.sh
index 3297fdc..6e58687 100755
--- a/contrib/clojure-package/integration-tests.sh
+++ b/contrib/clojure-package/integration-tests.sh
@@ -18,11 +18,11 @@
set -evx
-MXNET_HOME=${PWD}
+MXNET_HOME=$(cd "$(dirname "$0")/../.." && pwd)
EXAMPLES_HOME=${MXNET_HOME}/contrib/clojure-package/examples
#cd ${MXNET_HOME}/contrib/clojure-package
#lein test
#lein cloverage --codecov
-for i in `find ${EXAMPLES_HOME} -name test` ; do
-cd ${i} && lein test
+# Run each example project's tests.
+# - find -print0 with read -d '' keeps paths containing spaces intact; split
+#   backtick output would break them apart.
+# - The ( ... ) subshell stops each cd from carrying over to the next iteration.
+# - The while loop is the last command of the pipeline, so under set -e a
+#   failing cd or test run still aborts the script.
+find "${EXAMPLES_HOME}" -type d -name test -print0 |
+  while IFS= read -r -d '' test_dir; do
+    (cd "${test_dir}" && lein test)
done
diff --git a/contrib/clojure-package/integration-tests.sh b/contrib/clojure-package/scripts/infer/get_resnet_18_data.sh
similarity index 54%
copy from contrib/clojure-package/integration-tests.sh
copy to contrib/clojure-package/scripts/infer/get_resnet_18_data.sh
index 3297fdc..601f362 100755
--- a/contrib/clojure-package/integration-tests.sh
+++ b/contrib/clojure-package/scripts/infer/get_resnet_18_data.sh
@@ -1,4 +1,5 @@
#!/bin/bash
+
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -18,11 +19,20 @@
set -evx
-MXNET_HOME=${PWD}
-EXAMPLES_HOME=${MXNET_HOME}/contrib/clojure-package/examples
-#cd ${MXNET_HOME}/contrib/clojure-package
-#lein test
-#lein cloverage --codecov
-for i in `find ${EXAMPLES_HOME} -name test` ; do
-cd ${i} && lein test
-done
+if [ ! -z "$MXNET_HOME" ]; then
+ data_path="$MXNET_HOME/data"
+else
+ MXNET_ROOT=$(cd "$(dirname $0)/../.."; pwd)
+ data_path="$MXNET_ROOT/data"
+fi
+
+if [ ! -d "$data_path" ]; then
+ mkdir -p "$data_path"
+fi
+
+resnet_18_data_path="$data_path/resnet-18"
+if [ ! -f "$resnet_18_data_path/resnet-18-0000.params" ]; then
+ wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-symbol.json -P $resnet_18_data_path
+ wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-0000.params -P $resnet_18_data_path
+ wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/synset.txt -P $resnet_18_data_path
+fi
diff --git a/contrib/clojure-package/integration-tests.sh b/contrib/clojure-package/scripts/infer/get_ssd_data.sh
similarity index 54%
copy from contrib/clojure-package/integration-tests.sh
copy to contrib/clojure-package/scripts/infer/get_ssd_data.sh
index 3297fdc..96e27a1 100755
--- a/contrib/clojure-package/integration-tests.sh
+++ b/contrib/clojure-package/scripts/infer/get_ssd_data.sh
@@ -1,4 +1,5 @@
#!/bin/bash
+
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -16,13 +17,23 @@
# specific language governing permissions and limitations
# under the License.
+
set -evx
-MXNET_HOME=${PWD}
-EXAMPLES_HOME=${MXNET_HOME}/contrib/clojure-package/examples
-#cd ${MXNET_HOME}/contrib/clojure-package
-#lein test
-#lein cloverage --codecov
-for i in `find ${EXAMPLES_HOME} -name test` ; do
-cd ${i} && lein test
-done
+if [ ! -z "$MXNET_HOME" ]; then
+ data_path="$MXNET_HOME/data"
+else
+ MXNET_ROOT=$(cd "$(dirname $0)/../.."; pwd)
+ data_path="$MXNET_ROOT/data"
+fi
+
+if [ ! -d "$data_path" ]; then
+ mkdir -p "$data_path"
+fi
+
+resnet50_ssd_data_path="$data_path/resnet50_ssd"
+if [ ! -f "$resnet50_ssd_data_path/resnet50_ssd_model-0000.params" ]; then
+ wget https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-symbol.json -P $resnet50_ssd_data_path
+ wget https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-0000.params -P $resnet50_ssd_data_path
+ wget https://s3.amazonaws.com/model-server/models/resnet50_ssd/synset.txt -P $resnet50_ssd_data_path
+fi
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/image.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/image.clj
index 6e726eb..e2e87ed 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/image.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/image.clj
@@ -62,8 +62,8 @@
(util/validate! ::optional-color-flag color-flag "Invalid color flag")
(util/validate! ::optional-to-rgb to-rgb "Invalid conversion flag")
(util/validate! ::output output "Invalid output")
- (Image/imRead
- filename
+ (Image/imRead
+ filename
($/option color-flag)
($/option to-rgb)
($/option output)))
@@ -89,7 +89,7 @@
(defn apply-border
"Pad image border"
- ([input top bottom left right
+ ([input top bottom left right
{:keys [fill-type value values output]
:or {fill-type nil value nil values nil output nil}
:as opts}]
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
new file mode 100644
index 0000000..b2b23da
--- /dev/null
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
@@ -0,0 +1,353 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.infer
+ (:refer-clojure :exclude [type])
+ (:require [org.apache.clojure-mxnet.context :as context]
+ [org.apache.clojure-mxnet.io :as mx-io]
+ [org.apache.clojure-mxnet.shape :as shape]
+ [org.apache.clojure-mxnet.util :as util]
+ [clojure.spec.alpha :as s])
+ (:import (java.awt.image BufferedImage)
+ (org.apache.mxnet NDArray)
+ (org.apache.mxnet.infer Classifier ImageClassifier
+ ObjectDetector Predictor)))
+
+;; Specs for the raw org.apache.mxnet.infer objects. The ::wrapped-* specs
+;; require these under the wrapper records' field keys.
+(s/def ::predictor #(instance? Predictor %))
+(s/def ::classifier #(instance? Classifier %))
+(s/def ::image-classifier #(instance? ImageClassifier %))
+(s/def ::object-detector #(instance? ObjectDetector %))
+
+;; Single-field record wrappers around the raw inference objects. The
+;; inference protocols in this namespace are extended to these records, and
+;; each method reads the field (e.g. :predictor) to reach the Java object.
+(defrecord WrappedPredictor [predictor])
+(defrecord WrappedClassifier [classifier])
+(defrecord WrappedImageClassifier [image-classifier])
+(defrecord WrappedObjectDetector [object-detector])
+
+(s/def ::ndarray #(instance? NDArray %))
+
+;; JVM class of a primitive float array (float[]).
+(def ^:private float-array-class (class (float-array 0)))
+
+;; A primitive float[]. Checking the exact array class:
+;; - is O(1), so large inputs are not boxed element by element on every call;
+;; - returns false for nil instead of throwing a NullPointerException;
+;; - rejects double[] and boxed Object[] arrays. `float?` is also true for
+;;   Doubles, so an element-wise `float?` check would accept them.
+;; NOTE(review): assumes the Scala infer API's Array[Float] arguments are JVM
+;; float[] at runtime; confirm against the Scala signatures.
+(s/def ::float-array #(instance? float-array-class %))
+(s/def ::vec-of-float-arrays (s/coll-of ::float-array :kind vector?))
+(s/def ::vec-of-ndarrays (s/coll-of ::ndarray :kind vector?))
+
+;; Validation specs for the wrapper records. Each requires its unqualified
+;; field key (e.g. :predictor) to satisfy the spec of the same name, i.e. to
+;; hold the matching Java inference object.
+(s/def ::wrapped-predictor (s/keys :req-un [::predictor]))
+(s/def ::wrapped-classifier (s/keys :req-un [::classifier]))
+(s/def ::wrapped-image-classifier (s/keys :req-un [::image-classifier]))
+(s/def ::wrapped-detector (s/keys :req-un [::object-detector]))
+
+(defprotocol APredictor
+ "Prediction with a WrappedPredictor. The implementation validates its
+ arguments with util/validate! and returns the Scala result passed through
+ util/coerce-return-recursive."
+ (predict [wrapped-predictor inputs]
+ "Runs .predict on inputs, a vector of float arrays.")
+ (predict-with-ndarray [wrapped-predictor input-arrays]
+ "Runs .predictWithNDArray on input-arrays, a vector of NDArrays."))
+
+(defprotocol AClassifier
+ "Classification with a WrappedClassifier or WrappedImageClassifier. topk is
+ nil or an integer. It is converted with util/->int-option before it is
+ passed to the Scala API."
+ (classify
+ [wrapped-classifier inputs]
+ [wrapped-classifier inputs topk]
+ "Classifies inputs, a vector of float arrays.")
+ (classify-with-ndarray
+ [wrapped-classifier inputs]
+ [wrapped-classifier inputs topk]
+ "Classifies inputs, a vector of NDArrays."))
+
+(defprotocol AImageClassifier
+ "Image classification with a WrappedImageClassifier. topk is nil or an
+ integer. It is converted with util/->int-option before it is passed to the
+ Scala API."
+ (classify-image
+ [wrapped-image-classifier image]
+ [wrapped-image-classifier image topk]
+ "Classifies a single image, which must be a java.awt.image.BufferedImage.")
+ (classify-image-batch
+ [wrapped-image-classifier images]
+ [wrapped-image-classifier images topk]
+ "Classifies a batch of images (e.g. as returned by load-image-paths). The
+ images argument itself is not validated."))
+
+(defprotocol AObjectDetector
+ "Object detection with a WrappedObjectDetector. topk is nil or an integer.
+ It is converted with util/->int-option before it is passed to the Scala API."
+ (detect-objects
+ [wrapped-detector image]
+ [wrapped-detector image topk]
+ "Detects objects in a single java.awt.image.BufferedImage.")
+ (detect-objects-batch
+ [wrapped-detector images]
+ [wrapped-detector images topk]
+ "Detects objects in a batch of images. The images argument itself is not
+ validated.")
+ (detect-objects-with-ndarrays
+ [wrapped-detector input-arrays]
+ [wrapped-detector input-arrays topk]
+ "Detects objects in input-arrays, a vector of NDArrays."))
+
+;; WrappedPredictor implementation of APredictor. Each method validates the
+;; wrapper and its inputs, forwards to the wrapped Predictor, and coerces the
+;; Scala result into Clojure data.
+(extend-type WrappedPredictor
+  APredictor
+  (predict [wrapped-predictor inputs]
+    (util/validate! ::wrapped-predictor wrapped-predictor "Invalid predictor")
+    (util/validate! ::vec-of-float-arrays inputs "Invalid inputs")
+    (let [java-predictor (:predictor wrapped-predictor)
+          input-seq (util/vec->indexed-seq inputs)]
+      (util/coerce-return-recursive (.predict java-predictor input-seq))))
+  (predict-with-ndarray [wrapped-predictor input-arrays]
+    (util/validate! ::wrapped-predictor wrapped-predictor "Invalid predictor")
+    (util/validate! ::vec-of-ndarrays input-arrays "Invalid input arrays")
+    (let [java-predictor (:predictor wrapped-predictor)
+          input-seq (util/vec->indexed-seq input-arrays)]
+      (util/coerce-return-recursive
+       (.predictWithNDArray java-predictor input-seq)))))
+
+;; Optional top-K argument: nil or an integer. It is converted with
+;; util/->int-option before it is passed to the Scala API.
+(s/def ::nil-or-int (s/nilable int?))
+
+;; extend-protocol takes all arities of a method in ONE form:
+;; (m ([this a] ...) ([this a b] ...)). Writing each arity as a separate
+;; (m [this a] ...) form, the defrecord/deftype style, silently keeps only the
+;; last one, so the 2-arity calls would throw an ArityException.
+(extend-protocol AClassifier
+  WrappedClassifier
+  (classify
+    ([wrapped-classifier inputs]
+     ;; Delegates with topk nil; the 3-arity body does all the validation.
+     (classify wrapped-classifier inputs nil))
+    ([wrapped-classifier inputs topk]
+     (util/validate! ::wrapped-classifier wrapped-classifier
+                     "Invalid classifier")
+     (util/validate! ::vec-of-float-arrays inputs
+                     "Invalid inputs")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.classify (:classifier wrapped-classifier)
+                 (util/vec->indexed-seq inputs)
+                 (util/->int-option topk)))))
+  (classify-with-ndarray
+    ([wrapped-classifier inputs]
+     (classify-with-ndarray wrapped-classifier inputs nil))
+    ([wrapped-classifier inputs topk]
+     (util/validate! ::wrapped-classifier wrapped-classifier
+                     "Invalid classifier")
+     (util/validate! ::vec-of-ndarrays inputs
+                     "Invalid inputs")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.classifyWithNDArray (:classifier wrapped-classifier)
+                            (util/vec->indexed-seq inputs)
+                            (util/->int-option topk)))))
+
+  WrappedImageClassifier
+  (classify
+    ([wrapped-image-classifier inputs]
+     (classify wrapped-image-classifier inputs nil))
+    ([wrapped-image-classifier inputs topk]
+     (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+                     "Invalid classifier")
+     (util/validate! ::vec-of-float-arrays inputs
+                     "Invalid inputs")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.classify (:image-classifier wrapped-image-classifier)
+                 (util/vec->indexed-seq inputs)
+                 (util/->int-option topk)))))
+  (classify-with-ndarray
+    ([wrapped-image-classifier inputs]
+     (classify-with-ndarray wrapped-image-classifier inputs nil))
+    ([wrapped-image-classifier inputs topk]
+     (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+                     "Invalid classifier")
+     (util/validate! ::vec-of-ndarrays inputs
+                     "Invalid inputs")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.classifyWithNDArray (:image-classifier wrapped-image-classifier)
+                            (util/vec->indexed-seq inputs)
+                            (util/->int-option topk))))))
+
+;; A java.awt.image.BufferedImage, the image type the helpers in this
+;; namespace validate and pass to the Scala API.
+(s/def ::image #(instance? BufferedImage %))
+
+;; extend-protocol takes all arities of a method in ONE form:
+;; (m ([this a] ...) ([this a b] ...)). Writing each arity as a separate
+;; (m [this a] ...) form, the defrecord/deftype style, silently keeps only the
+;; last one, so the 2-arity calls would throw an ArityException.
+(extend-protocol AImageClassifier
+  WrappedImageClassifier
+  (classify-image
+    ([wrapped-image-classifier image]
+     ;; Delegates with topk nil; the 3-arity body does all the validation.
+     (classify-image wrapped-image-classifier image nil))
+    ([wrapped-image-classifier image topk]
+     (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+                     "Invalid classifier")
+     (util/validate! ::image image "Invalid image")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.classifyImage (:image-classifier wrapped-image-classifier)
+                      image
+                      (util/->int-option topk)))))
+  (classify-image-batch
+    ([wrapped-image-classifier images]
+     (classify-image-batch wrapped-image-classifier images nil))
+    ([wrapped-image-classifier images topk]
+     (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+                     "Invalid classifier")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.classifyImageBatch (:image-classifier wrapped-image-classifier)
+                           images
+                           (util/->int-option topk))))))
+
+;; extend-protocol takes all arities of a method in ONE form:
+;; (m ([this a] ...) ([this a b] ...)). Writing each arity as a separate
+;; (m [this a] ...) form, the defrecord/deftype style, silently keeps only the
+;; last one, so the 2-arity calls would throw an ArityException.
+(extend-protocol AObjectDetector
+  WrappedObjectDetector
+  (detect-objects
+    ([wrapped-detector image]
+     ;; Delegates with topk nil; the 3-arity body does all the validation.
+     (detect-objects wrapped-detector image nil))
+    ([wrapped-detector image topk]
+     (util/validate! ::wrapped-detector wrapped-detector
+                     "Invalid object detector")
+     (util/validate! ::image image "Invalid image")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.imageObjectDetect (:object-detector wrapped-detector)
+                          image
+                          (util/->int-option topk)))))
+  (detect-objects-batch
+    ([wrapped-detector images]
+     (detect-objects-batch wrapped-detector images nil))
+    ([wrapped-detector images topk]
+     (util/validate! ::wrapped-detector wrapped-detector
+                     "Invalid object detector")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.imageBatchObjectDetect (:object-detector wrapped-detector)
+                               images
+                               (util/->int-option topk)))))
+  (detect-objects-with-ndarrays
+    ([wrapped-detector input-arrays]
+     (detect-objects-with-ndarrays wrapped-detector input-arrays nil))
+    ([wrapped-detector input-arrays topk]
+     (util/validate! ::wrapped-detector wrapped-detector
+                     "Invalid object detector")
+     (util/validate! ::vec-of-ndarrays input-arrays
+                     "Invalid inputs")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.objectDetectWithNDArray (:object-detector wrapped-detector)
+                                (util/vec->indexed-seq input-arrays)
+                                (util/->int-option topk))))))
+
+(defprotocol AInferenceFactory
+ "Builds wrapped inference objects for a model. In InferenceFactory, opts is
+ a map that may contain :contexts (default [(context/cpu)]) and :epoch
+ (default 0); the 1-arity forms pass {}."
+ (create-predictor [factory] [factory opts] "Returns a WrappedPredictor.")
+ (create-classifier [factory] [factory opts] "Returns a WrappedClassifier.")
+ (create-image-classifier [factory] [factory opts] "Returns a WrappedImageClassifier.")
+ (create-object-detector [factory] [factory opts] "Returns a WrappedObjectDetector."))
+
+(defn convert-descriptors
+  "Maps each descriptor map through mx-io/data-desc and returns the results
+  as the indexed sequence passed to the Scala infer constructors."
+  [descriptors]
+  (util/vec->indexed-seq (mapv mx-io/data-desc descriptors)))
+
+(defn- constructor-args
+  "Returns the four constructor arguments shared by Predictor, Classifier,
+  ImageClassifier and ObjectDetector:
+  [model-path-prefix converted-descriptors context-array epoch-option].
+  opts may supply :contexts (default [(context/cpu)]) and :epoch (default 0)."
+  [model-path-prefix input-descriptors opts]
+  (let [{:keys [contexts epoch]
+         :or {contexts [(context/cpu)] epoch 0}} opts]
+    [model-path-prefix
+     (convert-descriptors input-descriptors)
+     (into-array contexts)
+     (util/->int-option epoch)]))
+
+;; Factory record: each 1-arity method delegates with an empty options map,
+;; and each 2-arity method constructs and wraps the matching Scala object.
+(defrecord InferenceFactory [model-path-prefix input-descriptors]
+  AInferenceFactory
+  (create-predictor [factory]
+    (create-predictor factory {}))
+  (create-predictor [factory opts]
+    (let [[prefix descs ctxs ep] (constructor-args model-path-prefix
+                                                   input-descriptors opts)]
+      (->WrappedPredictor (Predictor. prefix descs ctxs ep))))
+  (create-classifier [factory]
+    (create-classifier factory {}))
+  (create-classifier [factory opts]
+    (let [[prefix descs ctxs ep] (constructor-args model-path-prefix
+                                                   input-descriptors opts)]
+      (->WrappedClassifier (Classifier. prefix descs ctxs ep))))
+  (create-image-classifier [factory]
+    (create-image-classifier factory {}))
+  (create-image-classifier [factory opts]
+    (let [[prefix descs ctxs ep] (constructor-args model-path-prefix
+                                                   input-descriptors opts)]
+      (->WrappedImageClassifier (ImageClassifier. prefix descs ctxs ep))))
+  (create-object-detector [factory]
+    (create-object-detector factory {}))
+  (create-object-detector [factory opts]
+    (let [[prefix descs ctxs ep] (constructor-args model-path-prefix
+                                                   input-descriptors opts)]
+      (->WrappedObjectDetector (ObjectDetector. prefix descs ctxs ep)))))
+
+;; Arguments accepted by model-factory: a model path prefix string and a
+;; collection of descriptor maps, each checked against ::mx-io/data-desc.
+(s/def ::model-path-prefix string?)
+(s/def ::input-descriptors (s/coll-of ::mx-io/data-desc))
+
+(defn model-factory
+  "Returns an InferenceFactory for the model at model-path-prefix with the
+  given input descriptors. Use create-predictor, create-classifier,
+  create-image-classifier or create-object-detector to build an inference
+  object from it. Both arguments are checked with util/validate! first."
+  [model-path-prefix input-descriptors]
+  (util/validate! ::model-path-prefix model-path-prefix
+                  "Invalid model path prefix")
+  (util/validate! ::input-descriptors input-descriptors
+                  "Invalid input descriptors")
+  (map->InferenceFactory {:model-path-prefix model-path-prefix
+                          :input-descriptors input-descriptors}))
+
+(defn reshape-image
+  "Resizes image (a BufferedImage) to width x height pixels with
+  ImageClassifier/reshapeImage, after checking that width and height are
+  integers."
+  [image width height]
+  (util/validate! ::image image "Invalid image")
+  (util/validate! int? width "Invalid width")
+  (util/validate! int? height "Invalid height")
+  (-> image
+      (ImageClassifier/reshapeImage width height)))
+
+(defn buffered-image-to-pixels
+  "Converts a BufferedImage into an NDArray of the shape described by
+  input-shape-vec, a collection of ints (the predictor example passes
+  [3 224 224]), using ImageClassifier/bufferedImageToPixels."
+  [image input-shape-vec]
+  (util/validate! ::image image "Invalid image")
+  (util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector")
+  (let [nd-shape (shape/->shape input-shape-vec)]
+    (ImageClassifier/bufferedImageToPixels image nd-shape)))
+
+;; Image file path(s) accepted by load-image-from-file and load-image-paths.
+(s/def ::image-path string?)
+(s/def ::image-paths (s/coll-of ::image-path))
+
+(defn load-image-from-file
+  "Reads the image file at image-path (a string) with
+  ImageClassifier/loadImageFromFile and returns the result."
+  [image-path]
+  (util/validate! ::image-path image-path "Invalid image path")
+  (-> image-path
+      ImageClassifier/loadImageFromFile))
+
+(defn load-image-paths
+  "Loads every image named in image-paths (a collection of path strings)
+  with ImageClassifier/loadInputBatch and returns its result."
+  [image-paths]
+  (util/validate! ::image-paths image-paths "Invalid image paths")
+  (-> image-paths
+      util/convert-vector
+      ImageClassifier/loadInputBatch))
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
index 6f22b0e..21e31ba 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
@@ -66,6 +66,9 @@
(defn ->option [v]
($ Option v))
+(defn ->int-option
+  "Coerces v to an int and wraps it with ->option. A falsey v (nil or false)
+  is passed to ->option as nil."
+  [v]
+  (->option (if v (int v) nil)))
+
(defn option->value [opt]
($/view opt))
@@ -176,6 +179,12 @@
(instance? Tuple3 return-val) (tuple->vec return-val)
:else return-val))
+(defn coerce-return-recursive
+  "Applies coerce-return to return-val. If the result is a vector, each of its
+  elements is coerced recursively in turn, so nested values that
+  coerce-return turns into vectors are converted at every level."
+  [return-val]
+  (let [coerced (coerce-return return-val)]
+    (cond->> coerced
+      (vector? coerced) (mapv coerce-return-recursive))))
+
(defmacro scala-fn
"Creates a scala fn from an anonymous clojure fn of the form (fn [x] body)"
[f]
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
new file mode 100644
index 0000000..9badfed
--- /dev/null
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
@@ -0,0 +1,68 @@
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.infer.imageclassifier-test
+ (:require [org.apache.clojure-mxnet.context :as context]
+ [org.apache.clojure-mxnet.dtype :as dtype]
+ [org.apache.clojure-mxnet.infer :as infer]
+ [org.apache.clojure-mxnet.layout :as layout]
+ [clojure.java.io :as io]
+ [clojure.java.shell :refer [sh]]
+ [clojure.test :refer :all]))
+
+(def model-dir "data/")
+(def model-path-prefix (str model-dir "resnet-18/resnet-18"))
+
+;; Download the resnet-18 model on first run. Both the symbol and params files
+;; are checked, so a download that stopped before the params file is retried.
+;; clojure.java.shell/sh reports failure only through :exit (it never throws),
+;; so check it here rather than failing later with an obscure loading error.
+(when-not (every? #(.exists (io/file (str model-path-prefix %)))
+                  ["-symbol.json" "-0000.params"])
+  (let [{:keys [exit out err]} (sh "./scripts/infer/get_resnet_18_data.sh")]
+    (when-not (zero? exit)
+      (throw (ex-info "Downloading the resnet-18 test model failed"
+                      {:exit exit :out out :err err})))))
+
+(defn create-classifier
+  "Builds an image classifier for the resnet-18 test model, which has a
+  single NCHW float32 input named \"data\" of shape [1 3 224 224]."
+  []
+  (-> model-path-prefix
+      (infer/model-factory [{:name "data"
+                             :shape [1 3 224 224]
+                             :layout layout/NCHW
+                             :dtype dtype/FLOAT32}])
+      infer/create-image-classifier))
+
+;; Top-5 classification of the kitten image: label/probability pairs with the
+;; expected cat-like labels in order.
+(deftest test-single-classification
+  (let [classifier (create-classifier)
+        [predictions] (infer/classify-image
+                       classifier
+                       (infer/load-image-from-file "test/test-images/kitten.jpg")
+                       5)
+        labels (map first predictions)
+        probabilities (map second predictions)]
+    (is (some? predictions))
+    (is (= 5 (count predictions)))
+    (is (every? #(= 2 (count %)) predictions))
+    (is (every? string? labels))
+    (is (every? float? probabilities))
+    (is (every? #(< 0 % 1) probabilities))
+    (is (= ["n02123159 tiger cat"
+            "n02124075 Egyptian cat"
+            "n02123045 tabby, tabby cat"
+            "n02127052 lynx, catamount"
+            "n02128757 snow leopard, ounce, Panthera uncia"]
+           labels))))
+
+;; Batch classification of two images; checks the shape of the first
+;; image's top-5 label/probability pairs.
+(deftest test-batch-classification
+  (let [image-paths ["test/test-images/kitten.jpg"
+                     "test/test-images/Pug-Cookie.jpg"]
+        classifier (create-classifier)
+        batch-predictions (infer/classify-image-batch
+                           classifier (infer/load-image-paths image-paths) 5)
+        predictions (first batch-predictions)
+        labels (map first predictions)
+        probabilities (map second predictions)]
+    (is (some? batch-predictions))
+    (is (= 5 (count predictions)))
+    (is (every? #(= 2 (count %)) predictions))
+    (is (every? string? labels))
+    (is (every? float? probabilities))
+    (is (every? #(< 0 % 1) probabilities))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
new file mode 100644
index 0000000..788a594
--- /dev/null
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
@@ -0,0 +1,63 @@
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.infer.objectdetector-test
+ (:require [org.apache.clojure-mxnet.context :as context]
+ [org.apache.clojure-mxnet.dtype :as dtype]
+ [org.apache.clojure-mxnet.infer :as infer]
+ [org.apache.clojure-mxnet.layout :as layout]
+ [clojure.java.io :as io]
+ [clojure.java.shell :refer [sh]]
+ [clojure.test :refer :all]))
+
+(def model-dir "data/")
+(def model-path-prefix (str model-dir "resnet50_ssd/resnet50_ssd_model"))
+;; Fetch the model on first run; fail fast at load time if the script errors.
+(when-not (.exists (io/file (str model-path-prefix "-symbol.json")))
+  (let [{:keys [exit err]} (sh "./scripts/infer/get_ssd_data.sh")] (assert (zero? exit) err)))
+
+(defn create-detector
+  "Returns an object detector over the resnet50_ssd model at model-path-prefix,
+  declaring one float32 NCHW input named data with shape [1 3 512 512]."
+  []
+  (-> model-path-prefix
+      (infer/model-factory [{:name "data" :shape [1 3 512 512] :layout layout/NCHW :dtype dtype/FLOAT32}])
+      infer/create-object-detector))
+
+(deftest test-single-detection
+  (let [kitten (infer/load-image-from-file "test/test-images/kitten.jpg")
+        [predictions] (infer/detect-objects (create-detector) kitten 5)
+        details (map second predictions)] ; 5-element vectors, first expected in (0, 1)
+    (is (some? predictions))
+    (is (= 5 (count predictions)))
+    (is (every? #(= 2 (count %)) predictions))
+    (is (every? (comp string? first) predictions))
+    (is (every? #(= 5 (count %)) details))
+    (is (every? #(< 0 (first %) 1) details))
+    (is (= "cat" (ffirst predictions)))))
+
+(deftest test-batch-detection
+  (let [detector (create-detector)
+        image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
+                                             "test/test-images/Pug-Cookie.jpg"])
+        batch-predictions (infer/detect-objects-batch detector image-batch 5)]
+    ;; Expect one top-5 detection list per input image; nil also fails here.
+    (is (= 2 (count batch-predictions)))
+    (doseq [predictions batch-predictions]
+      (is (= 5 (count predictions)))
+      (is (every? #(and (= 2 (count %)) (string? (first %))) predictions))
+      (is (every? #(= 5 (count (second %))) predictions))
+      (is (every? #(< 0 (first (second %)) 1) predictions)))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj
new file mode 100644
index 0000000..0e7532b
--- /dev/null
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj
@@ -0,0 +1,59 @@
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.infer.predictor-test
+ (:require [org.apache.clojure-mxnet.context :as context]
+ [org.apache.clojure-mxnet.dtype :as dtype]
+ [org.apache.clojure-mxnet.infer :as infer]
+ [org.apache.clojure-mxnet.layout :as layout]
+ [org.apache.clojure-mxnet.ndarray :as ndarray]
+ [org.apache.clojure-mxnet.shape :as shape]
+ [clojure.java.io :as io]
+ [clojure.java.shell :refer [sh]]
+ [clojure.string :refer [split]]
+ [clojure.test :refer :all]))
+
+(def model-dir "data/")
+(def model-path-prefix (str model-dir "resnet-18/resnet-18"))
+(def width 224)
+(def height 224)
+;; Fetch the model on first run; fail fast at load time if the script errors.
+(when-not (.exists (io/file (str model-path-prefix "-symbol.json")))
+  (let [{:keys [exit err]} (sh "./scripts/infer/get_resnet_18_data.sh")] (assert (zero? exit) err)))
+
+(defn create-predictor
+  "Returns a predictor over the resnet-18 model at model-path-prefix, declaring
+  one float32 NCHW input named data with shape [1 3 height width]."
+  []
+  (-> model-path-prefix
+      (infer/model-factory [{:name "data" :shape [1 3 height width] :layout layout/NCHW :dtype dtype/FLOAT32}])
+      infer/create-predictor))
+
+(deftest predictor-test
+  (let [predictor (create-predictor)
+        ;; Per-image shape is the NCHW descriptor [1 3 height width] minus the batch dim.
+        image-ndarray (-> "test/test-images/kitten.jpg"
+                          infer/load-image-from-file
+                          (infer/reshape-image width height)
+                          (infer/buffered-image-to-pixels [3 height width])
+                          (ndarray/expand-dims 0))
+        [predictions] (infer/predict-with-ndarray predictor [image-ndarray])
+        synset-file (io/file (.getParent (io/file model-path-prefix)) "synset.txt")
+        ;; Accept CRLF as well as LF line endings in the label file.
+        synset-names (split (slurp synset-file) #"\r?\n")
+        [best-index] (ndarray/->int-vec (ndarray/argmax predictions 1))
+        best-prediction (synset-names best-index)]
+    (is (= "n02123159 tiger cat" best-prediction))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
index ee77103..bd77a8a 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
@@ -54,6 +54,16 @@
(is (instance? Option x))
(is (= 1 (.get x)))))
+(deftest test->int-option
+  ;; A fractional double is expected to come back wrapped as the int 4.
+  (let [opt (util/->int-option 4.5)]
+    (is (instance? Option opt)) (is (= 4 (.get opt)))))
+
+(deftest test-empty->int-option
+  ;; nil is expected to map to an empty Option.
+  (let [opt (util/->int-option nil)]
+    (is (instance? Option opt)) (is (.isEmpty opt))))
+
(deftest test-option->value
(is (= 2 (-> (util/->option 2)
(util/option->value)))))
diff --git a/contrib/clojure-package/test/test-images/Pug-Cookie.jpg b/contrib/clojure-package/test/test-images/Pug-Cookie.jpg
new file mode 100644
index 0000000..56f5dc1
Binary files /dev/null and b/contrib/clojure-package/test/test-images/Pug-Cookie.jpg differ
diff --git a/contrib/clojure-package/test/test-images/kitten.jpg b/contrib/clojure-package/test/test-images/kitten.jpg
new file mode 100644
index 0000000..ffcd2be
Binary files /dev/null and b/contrib/clojure-package/test/test-images/kitten.jpg differ