Posted to commits@mxnet.apache.org by cm...@apache.org on 2019/01/10 00:53:41 UTC
[incubator-mxnet] branch master updated: Clojure example for fixed
label-width captcha recognition (#13769)
This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new d3bd5e7 Clojure example for fixed label-width captcha recognition (#13769)
d3bd5e7 is described below
commit d3bd5e7c9b0802899cc50ca0fcbe19b4c9f66048
Author: Kedar Bellare <ke...@gmail.com>
AuthorDate: Wed Jan 9 16:53:24 2019 -0800
Clojure example for fixed label-width captcha recognition (#13769)
* Clojure example for fixed label-width captcha recognition
* Update evaluation
* Better training and inference (w/ cleanup)
* Captcha generation for testing
* Make simple test work
* Add test and update README
* Add missing consts file
* Follow comments
---
.../clojure-package/examples/captcha/.gitignore | 3 +
contrib/clojure-package/examples/captcha/README.md | 61 ++++++++
.../examples/captcha/captcha_example.png | Bin 0 -> 9762 bytes
.../examples/captcha/gen_captcha.py | 40 ++++++
.../clojure-package/examples/captcha/get_data.sh | 32 +++++
.../clojure-package/examples/captcha/project.clj | 28 ++++
.../examples/captcha/src/captcha/consts.clj | 27 ++++
.../examples/captcha/src/captcha/infer_ocr.clj | 56 ++++++++
.../examples/captcha/src/captcha/train_ocr.clj | 156 +++++++++++++++++++++
.../captcha/test/captcha/train_ocr_test.clj | 119 ++++++++++++++++
10 files changed, 522 insertions(+)
diff --git a/contrib/clojure-package/examples/captcha/.gitignore b/contrib/clojure-package/examples/captcha/.gitignore
new file mode 100644
index 0000000..e1569bd
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/.gitignore
@@ -0,0 +1,3 @@
+/.lein-*
+/.nrepl-port
+images/*
diff --git a/contrib/clojure-package/examples/captcha/README.md b/contrib/clojure-package/examples/captcha/README.md
new file mode 100644
index 0000000..6b593b2
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/README.md
@@ -0,0 +1,61 @@
+# Captcha
+
+This is the Clojure version of the [captcha recognition](https://github.com/xlvector/learning-dl/tree/master/mxnet/ocr)
+example by xlvector, and it mirrors the R captcha example. It can be used as an
+example of multi-label training. We treat the captcha shown below as an
+image with 4 labels and train a CNN over the data set.
+
+![captcha example](captcha_example.png)
+
+## Installation
+
+Before you run this example, make sure that you have the Clojure package
+installed. In the main Clojure package directory, run `lein install`.
+Then run `lein install` in this directory.
+
+## Usage
+
+### Training
+
+First, the OCR model needs to be trained on [labeled data](https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/data/captcha_example.zip).
+The training can be started using the following:
+```
+$ lein train [:cpu|:gpu] [num-devices]
+```
+This downloads the training/evaluation data using the `get_data.sh` script
+before starting training.
+
+You may encounter out-of-memory issues while training with `:gpu` on Ubuntu
+Linux (18.04). In that case, running `lein train` with no arguments (training on one CPU) may avoid the issue.
+
+The training runs for 10 epochs by default and saves the model with the
+prefix `ocr-`. The model achieved an exact match accuracy of ~0.954 and
+~0.628 on training and validation data respectively.
+
+### Inference
+
+Once the model has been saved, it can be used for prediction. This can be done
+by running:
+```
+$ lein infer
+INFO MXNetJVM: Try loading mxnet-scala from native path.
+INFO MXNetJVM: Try loading mxnet-scala-linux-x86_64-gpu from native path.
+INFO MXNetJVM: Try loading mxnet-scala-linux-x86_64-cpu from native path.
+WARN MXNetJVM: MXNet Scala native library not found in path. Copying native library from the archive. Consider installing the library somewhere in the path (for Windows: PATH, for Linux: LD_LIBRARY_PATH), or specifying by Java cmd option -Djava.library.path=[lib path].
+WARN org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis
+INFO org.apache.mxnet.infer.Predictor: Latency increased due to batchSize mismatch 8 vs 1
+WARN org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis
+WARN org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis
+CAPTCHA output: 6643
+INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865/libmxnet.so
+INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865/mxnet-scala
+INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865
+```
+The model runs on `captcha_example.png` by default.
+
+It can be run on other generated captcha images as well. Each run of the script
+`gen_captcha.py` generates one random 4-digit captcha image.
+Before running the Python script, you will need to install the [captcha](https://pypi.org/project/captcha/)
+library using `pip3 install --user captcha`. The captcha images are written
+to the `images/` folder, and we can run the prediction using, for example,
+`lein infer images/7534.png`.
diff --git a/contrib/clojure-package/examples/captcha/captcha_example.png b/contrib/clojure-package/examples/captcha/captcha_example.png
new file mode 100644
index 0000000..09b84f7
Binary files /dev/null and b/contrib/clojure-package/examples/captcha/captcha_example.png differ
diff --git a/contrib/clojure-package/examples/captcha/gen_captcha.py b/contrib/clojure-package/examples/captcha/gen_captcha.py
new file mode 100755
index 0000000..43e0d26
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/gen_captcha.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from captcha.image import ImageCaptcha
+import os
+import random
+
+length = 4
+width = 160
+height = 60
+IMAGE_DIR = "images"
+
+
+def random_text():
+ return ''.join(str(random.randint(0, 9))
+ for _ in range(length))
+
+
+if __name__ == '__main__':
+ image = ImageCaptcha(width=width, height=height)
+ captcha_text = random_text()
+ if not os.path.exists(IMAGE_DIR):
+ os.makedirs(IMAGE_DIR)
+ image.write(captcha_text, os.path.join(IMAGE_DIR, captcha_text + ".png"))
diff --git a/contrib/clojure-package/examples/captcha/get_data.sh b/contrib/clojure-package/examples/captcha/get_data.sh
new file mode 100755
index 0000000..baa7f9e
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/get_data.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -evx
+
+EXAMPLE_ROOT=$(cd "$(dirname "$0")"; pwd)
+
+data_path=$EXAMPLE_ROOT
+
+if [ ! -f "$data_path/captcha_example.zip" ]; then
+ wget https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/data/captcha_example.zip -P "$data_path"
+fi
+
+if [ ! -f "$data_path/captcha_example/captcha_train.rec" ]; then
+ unzip "$data_path/captcha_example.zip" -d "$data_path"
+fi
diff --git a/contrib/clojure-package/examples/captcha/project.clj b/contrib/clojure-package/examples/captcha/project.clj
new file mode 100644
index 0000000..fa37fec
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/project.clj
@@ -0,0 +1,28 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(defproject captcha "0.1.0-SNAPSHOT"
+ :description "Captcha recognition via multi-label classification"
+ :plugins [[lein-cljfmt "0.5.7"]]
+ :dependencies [[org.clojure/clojure "1.9.0"]
+ [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]]
+ :main ^:skip-aot captcha.train-ocr
+ :profiles {:train {:main captcha.train-ocr}
+ :infer {:main captcha.infer-ocr}
+ :uberjar {:aot :all}}
+ :aliases {"train" ["with-profile" "train" "run"]
+ "infer" ["with-profile" "infer" "run"]})
diff --git a/contrib/clojure-package/examples/captcha/src/captcha/consts.clj b/contrib/clojure-package/examples/captcha/src/captcha/consts.clj
new file mode 100644
index 0000000..318e0d8
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/src/captcha/consts.clj
@@ -0,0 +1,27 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns captcha.consts)
+
+(def batch-size 8)
+(def channels 3)
+(def height 30)
+(def width 80)
+(def data-shape [channels height width])
+(def num-labels 10)
+(def label-width 4)
+(def model-prefix "ocr")
diff --git a/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj b/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj
new file mode 100644
index 0000000..f6a648e
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj
@@ -0,0 +1,56 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns captcha.infer-ocr
+ (:require [captcha.consts :refer :all]
+ [org.apache.clojure-mxnet.dtype :as dtype]
+ [org.apache.clojure-mxnet.infer :as infer]
+ [org.apache.clojure-mxnet.layout :as layout]
+ [org.apache.clojure-mxnet.ndarray :as ndarray]))
+
+(defn create-predictor
+ []
+ (let [data-desc {:name "data"
+ :shape [batch-size channels height width]
+ :layout layout/NCHW
+ :dtype dtype/FLOAT32}
+ label-desc {:name "label"
+ :shape [batch-size label-width]
+ :layout layout/NT
+ :dtype dtype/FLOAT32}
+ factory (infer/model-factory model-prefix
+ [data-desc label-desc])]
+ (infer/create-predictor factory)))
+
+(defn -main
+ [& args]
+ (let [[filename] args
+ image-fname (or filename "captcha_example.png")
+ image-ndarray (-> image-fname
+ infer/load-image-from-file
+ (infer/reshape-image width height)
+ (infer/buffered-image-to-pixels [channels height width])
+ (ndarray/expand-dims 0))
+ label-ndarray (ndarray/zeros [1 label-width])
+ predictor (create-predictor)
+ predictions (-> (infer/predict-with-ndarray
+ predictor
+ [image-ndarray label-ndarray])
+ first
+ (ndarray/argmax 1)
+ ndarray/->vec)]
+ (println "CAPTCHA output:" (apply str (mapv int predictions)))))
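The decoding step at the end of `-main` in `infer_ocr.clj` can be sketched in plain Python. This is an illustrative sketch and not part of the commit: it assumes the predictor returns one row of 10 class scores per captcha position (a 4 x 10 array for a single image), and the names `decode_captcha` and `one_hot_row`, along with the sample scores, are hypothetical.

```python
def decode_captcha(scores):
    """Pick the highest-scoring digit in each row and join the digits into a string."""
    digits = []
    for row in scores:
        # argmax over the class axis, mirroring (ndarray/argmax 1) in the Clojure code
        best = max(range(len(row)), key=lambda i: row[i])
        digits.append(str(best))
    return "".join(digits)


def one_hot_row(digit, num_labels=10):
    """Hypothetical helper: build a made-up score row that peaks at `digit`."""
    return [0.9 if i == digit else 0.1 / (num_labels - 1) for i in range(num_labels)]


if __name__ == "__main__":
    # Made-up scores for a four-digit captcha, one row per position.
    scores = [one_hot_row(d) for d in (6, 6, 4, 3)]
    print("CAPTCHA output:", decode_captcha(scores))
```

Each row is decoded independently, which is what makes this a multi-label problem with a fixed label width rather than a sequence model.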
diff --git a/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj b/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj
new file mode 100644
index 0000000..91ec2ff
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj
@@ -0,0 +1,156 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns captcha.train-ocr
+ (:require [captcha.consts :refer :all]
+ [clojure.java.io :as io]
+ [clojure.java.shell :refer [sh]]
+ [org.apache.clojure-mxnet.callback :as callback]
+ [org.apache.clojure-mxnet.context :as context]
+ [org.apache.clojure-mxnet.eval-metric :as eval-metric]
+ [org.apache.clojure-mxnet.initializer :as initializer]
+ [org.apache.clojure-mxnet.io :as mx-io]
+ [org.apache.clojure-mxnet.module :as m]
+ [org.apache.clojure-mxnet.ndarray :as ndarray]
+ [org.apache.clojure-mxnet.optimizer :as optimizer]
+ [org.apache.clojure-mxnet.symbol :as sym])
+ (:gen-class))
+
+(when-not (.exists (io/file "captcha_example/captcha_train.lst"))
+ (sh "./get_data.sh"))
+
+(defonce train-data
+ (mx-io/image-record-iter {:path-imgrec "captcha_example/captcha_train.rec"
+ :path-imglist "captcha_example/captcha_train.lst"
+ :batch-size batch-size
+ :label-width label-width
+ :data-shape data-shape
+ :shuffle true
+ :seed 42}))
+
+(defonce eval-data
+ (mx-io/image-record-iter {:path-imgrec "captcha_example/captcha_test.rec"
+ :path-imglist "captcha_example/captcha_test.lst"
+ :batch-size batch-size
+ :label-width label-width
+ :data-shape data-shape}))
+
+(defn accuracy
+ [label pred & {:keys [by-character]
+ :or {by-character false} :as opts}]
+ (let [[nr nc] (ndarray/shape-vec label)
+ pred-context (ndarray/context pred)
+ label-t (-> label
+ ndarray/transpose
+ (ndarray/reshape [-1])
+ (ndarray/as-in-context pred-context))
+ pred-label (ndarray/argmax pred 1)
+ matches (ndarray/equal label-t pred-label)
+ [digit-matches] (-> matches
+ ndarray/sum
+ ndarray/->vec)
+ [complete-matches] (-> matches
+ (ndarray/reshape [nc nr])
+ (ndarray/sum 0)
+ (ndarray/equal label-width)
+ ndarray/sum
+ ndarray/->vec)]
+ (if by-character
+ (float (/ digit-matches nr nc))
+ (float (/ complete-matches nr)))))
+
+(defn get-data-symbol
+ []
+ (let [data (sym/variable "data")
+ ;; normalize the input pixels
+ scaled (sym/div (sym/- data 127) 128)
+
+ conv1 (sym/convolution {:data scaled :kernel [5 5] :num-filter 32})
+ pool1 (sym/pooling {:data conv1 :pool-type "max" :kernel [2 2] :stride [1 1]})
+ relu1 (sym/activation {:data pool1 :act-type "relu"})
+
+ conv2 (sym/convolution {:data relu1 :kernel [5 5] :num-filter 32})
+ pool2 (sym/pooling {:data conv2 :pool-type "avg" :kernel [2 2] :stride [1 1]})
+ relu2 (sym/activation {:data pool2 :act-type "relu"})
+
+ conv3 (sym/convolution {:data relu2 :kernel [3 3] :num-filter 32})
+ pool3 (sym/pooling {:data conv3 :pool-type "avg" :kernel [2 2] :stride [1 1]})
+ relu3 (sym/activation {:data pool3 :act-type "relu"})
+
+ conv4 (sym/convolution {:data relu3 :kernel [3 3] :num-filter 32})
+ pool4 (sym/pooling {:data conv4 :pool-type "avg" :kernel [2 2] :stride [1 1]})
+ relu4 (sym/activation {:data pool4 :act-type "relu"})
+
+ flattened (sym/flatten {:data relu4})
+ fc1 (sym/fully-connected {:data flattened :num-hidden 256})
+ fc21 (sym/fully-connected {:data fc1 :num-hidden num-labels})
+ fc22 (sym/fully-connected {:data fc1 :num-hidden num-labels})
+ fc23 (sym/fully-connected {:data fc1 :num-hidden num-labels})
+ fc24 (sym/fully-connected {:data fc1 :num-hidden num-labels})]
+ (sym/concat "concat" nil [fc21 fc22 fc23 fc24] {:dim 0})))
+
+(defn get-label-symbol
+ []
+ (as-> (sym/variable "label") label
+ (sym/transpose {:data label})
+ (sym/reshape {:data label :shape [-1]})))
+
+(defn create-captcha-net
+ []
+ (let [scores (get-data-symbol)
+ labels (get-label-symbol)]
+ (sym/softmax-output {:data scores :label labels})))
+
+(def optimizer
+ (optimizer/adam
+ {:learning-rate 0.0002
+ :wd 0.00001
+ :clip-gradient 10}))
+
+(defn train-ocr
+ [devs]
+ (println "Starting the captcha training ...")
+ (let [model (m/module
+ (create-captcha-net)
+ {:data-names ["data"] :label-names ["label"]
+ :contexts devs})]
+ (m/fit model {:train-data train-data
+ :eval-data eval-data
+ :num-epoch 10
+ :fit-params (m/fit-params
+ {:kvstore "local"
+ :batch-end-callback
+ (callback/speedometer batch-size 100)
+ :initializer
+ (initializer/xavier {:factor-type "in"
+ :magnitude 2.34})
+ :optimizer optimizer
+ :eval-metric (eval-metric/custom-metric
+ #(accuracy %1 %2)
+ "accuracy")})})
+ (println "Finished the fit")
+ model))
+
+(defn -main
+ [& args]
+ (let [[dev dev-num] args
+ num-devices (Integer/parseInt (or dev-num "1"))
+ devs (if (= dev ":gpu")
+ (mapv #(context/gpu %) (range num-devices))
+ (mapv #(context/cpu %) (range num-devices)))
+ model (train-ocr devs)]
+ (m/save-checkpoint model {:prefix model-prefix :epoch 0})))
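The size of the flattened feature vector that feeds the first fully connected layer in `get-data-symbol` can be checked by hand. The sketch below is not part of the commit. It assumes the convolution and pooling layers use no padding and that the convolutions use stride 1 (the defaults when those options are omitted); the helper names are hypothetical.

```python
def valid_out(size, kernel, stride=1):
    """Output length of an unpadded ('valid') convolution or pooling window."""
    return (size - kernel) // stride + 1


def flattened_size(height=30, width=80, num_filter=32):
    """Number of features after four conv/pool stages on a height x width input."""
    # (convolution kernel, pooling kernel) for each stage; pooling uses stride 1
    stages = [(5, 2), (5, 2), (3, 2), (3, 2)]
    for conv_k, pool_k in stages:
        height = valid_out(valid_out(height, conv_k), pool_k)
        width = valid_out(valid_out(width, conv_k), pool_k)
    return num_filter * height * width


if __name__ == "__main__":
    print(flattened_size())
```

Under these assumptions the spatial size shrinks from 30 x 80 to 14 x 64, so the flattened vector has 32 * 14 * 64 = 28672 entries, which agrees with the `[256 28672]` weight shape asserted for the first fully connected layer in `train_ocr_test.clj`.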
diff --git a/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj b/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj
new file mode 100644
index 0000000..ab785f7
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj
@@ -0,0 +1,119 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns captcha.train-ocr-test
+ (:require [clojure.test :refer :all]
+ [captcha.consts :refer :all]
+ [captcha.train-ocr :refer :all]
+ [org.apache.clojure-mxnet.io :as mx-io]
+ [org.apache.clojure-mxnet.module :as m]
+ [org.apache.clojure-mxnet.ndarray :as ndarray]
+ [org.apache.clojure-mxnet.shape :as shape]
+ [org.apache.clojure-mxnet.util :as util]))
+
+(deftest test-consts
+ (is (= 8 batch-size))
+ (is (= [3 30 80] data-shape))
+ (is (= 4 label-width))
+ (is (= 10 num-labels)))
+
+(deftest test-labeled-data
+ (let [train-batch (mx-io/next train-data)
+ eval-batch (mx-io/next eval-data)
+ allowed-labels (into #{} (map float (range 10)))]
+ (is (= 8 (-> train-batch mx-io/batch-index count)))
+ (is (= 8 (-> eval-batch mx-io/batch-index count)))
+ (is (= [8 3 30 80] (-> train-batch
+ mx-io/batch-data
+ first
+ ndarray/shape-vec)))
+ (is (= [8 3 30 80] (-> eval-batch
+ mx-io/batch-data
+ first
+ ndarray/shape-vec)))
+ (is (every? #(<= 0 % 255) (-> train-batch
+ mx-io/batch-data
+ first
+ ndarray/->vec)))
+ (is (every? #(<= 0 % 255) (-> eval-batch
+ mx-io/batch-data
+ first
+ ndarray/->vec)))
+ (is (= [8 4] (-> train-batch
+ mx-io/batch-label
+ first
+ ndarray/shape-vec)))
+ (is (= [8 4] (-> eval-batch
+ mx-io/batch-label
+ first
+ ndarray/shape-vec)))
+ (is (every? allowed-labels (-> train-batch
+ mx-io/batch-label
+ first
+ ndarray/->vec)))
+ (is (every? allowed-labels (-> eval-batch
+ mx-io/batch-label
+ first
+ ndarray/->vec)))))
+
+(deftest test-model
+ (let [batch (mx-io/next train-data)
+ model (m/module (create-captcha-net)
+ {:data-names ["data"] :label-names ["label"]})
+ _ (m/bind model
+ {:data-shapes (mx-io/provide-data-desc train-data)
+ :label-shapes (mx-io/provide-label-desc train-data)})
+ _ (m/init-params model)
+ _ (m/forward-backward model batch)
+ output-shapes (-> model
+ m/output-shapes
+ util/coerce-return-recursive)
+ outputs (-> model
+ m/outputs-merged
+ first)
+ grads (->> model m/grad-arrays (map first))]
+ (is (= [["softmaxoutput0_output" (shape/->shape [8 10])]]
+ output-shapes))
+ (is (= [32 10] (-> outputs ndarray/shape-vec)))
+ (is (every? #(<= 0.0 % 1.0) (-> outputs ndarray/->vec)))
+ (is (= [[32 3 5 5] [32] ; convolution1 weights+bias
+ [32 32 5 5] [32] ; convolution2 weights+bias
+ [32 32 3 3] [32] ; convolution3 weights+bias
+ [32 32 3 3] [32] ; convolution4 weights+bias
+ [256 28672] [256] ; fully-connected1 weights+bias
+ [10 256] [10] ; 1st label scores
+ [10 256] [10] ; 2nd label scores
+ [10 256] [10] ; 3rd label scores
+ [10 256] [10]] ; 4th label scores
+ (map ndarray/shape-vec grads)))))
+
+(deftest test-accuracy
+ (let [labels (ndarray/array [1 2 3 4,
+ 5 6 7 8]
+ [2 4])
+ pred-labels (ndarray/array [1 0,
+ 2 6,
+ 3 0,
+ 4 8]
+ [8])
+ preds (ndarray/one-hot pred-labels 10)]
+ (is (float? (accuracy labels preds)))
+ (is (float? (accuracy labels preds :by-character false)))
+ (is (float? (accuracy labels preds :by-character true)))
+ (is (= 0.5 (accuracy labels preds)))
+ (is (= 0.5 (accuracy labels preds :by-character false)))
+ (is (= 0.75 (accuracy labels preds :by-character true)))))
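The label layout that `accuracy` relies on can be traced in plain Python. This sketch is not part of the commit: it repeats the same arithmetic on nested lists with the values from `test-accuracy`, and it takes already-decoded digit predictions instead of one-hot scores. The function and argument names are illustrative. Because the four output heads are concatenated along dimension 0, predictions are ordered by position first (all first digits, then all second digits, and so on), so the labels are transposed before being flattened.

```python
def accuracy(labels, pred_labels, by_character=False):
    """Exact-match (default) or per-character accuracy for fixed-width labels."""
    nr = len(labels)      # number of samples
    nc = len(labels[0])   # label width (digits per captcha)
    # Transpose then flatten: position-major order, matching the concatenated heads.
    label_t = [labels[r][c] for c in range(nc) for r in range(nr)]
    matches = [int(a == b) for a, b in zip(label_t, pred_labels)]
    if by_character:
        return sum(matches) / (nr * nc)
    # A sample is a complete match only if all nc of its digits match.
    per_sample = [sum(matches[c * nr + r] for c in range(nc)) for r in range(nr)]
    complete = sum(1 for hits in per_sample if hits == nc)
    return complete / nr


if __name__ == "__main__":
    labels = [[1, 2, 3, 4],
              [5, 6, 7, 8]]
    # Position-major predictions: (sample 0, sample 1) for each of the 4 positions.
    pred_labels = [1, 0, 2, 6, 3, 0, 4, 8]
    print(accuracy(labels, pred_labels))
    print(accuracy(labels, pred_labels, by_character=True))
```

The first captcha is predicted fully correctly and the second has two wrong digits, so the exact-match accuracy is 0.5 and the per-character accuracy is 6/8 = 0.75, the same values the test asserts.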