Posted to commits@mxnet.apache.org by cm...@apache.org on 2019/01/10 00:53:41 UTC

[incubator-mxnet] branch master updated: Clojure example for fixed label-width captcha recognition (#13769)

This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d3bd5e7   Clojure example for fixed label-width captcha recognition  (#13769)
d3bd5e7 is described below

commit d3bd5e7c9b0802899cc50ca0fcbe19b4c9f66048
Author: Kedar Bellare <ke...@gmail.com>
AuthorDate: Wed Jan 9 16:53:24 2019 -0800

     Clojure example for fixed label-width captcha recognition  (#13769)
    
    * Clojure example for fixed label-width captcha recognition
    
    * Update evaluation
    
    * Better training and inference (w/ cleanup)
    
    * Captcha generation for testing
    
    * Make simple test work
    
    * Add test and update README
    
    * Add missing consts file
    
    * Follow comments
---
 .../clojure-package/examples/captcha/.gitignore    |   3 +
 contrib/clojure-package/examples/captcha/README.md |  61 ++++++++
 .../examples/captcha/captcha_example.png           | Bin 0 -> 9762 bytes
 .../examples/captcha/gen_captcha.py                |  40 ++++++
 .../clojure-package/examples/captcha/get_data.sh   |  32 +++++
 .../clojure-package/examples/captcha/project.clj   |  28 ++++
 .../examples/captcha/src/captcha/consts.clj        |  27 ++++
 .../examples/captcha/src/captcha/infer_ocr.clj     |  56 ++++++++
 .../examples/captcha/src/captcha/train_ocr.clj     | 156 +++++++++++++++++++++
 .../captcha/test/captcha/train_ocr_test.clj        | 119 ++++++++++++++++
 10 files changed, 522 insertions(+)

diff --git a/contrib/clojure-package/examples/captcha/.gitignore b/contrib/clojure-package/examples/captcha/.gitignore
new file mode 100644
index 0000000..e1569bd
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/.gitignore
@@ -0,0 +1,3 @@
+/.lein-*
+/.nrepl-port
+images/*
diff --git a/contrib/clojure-package/examples/captcha/README.md b/contrib/clojure-package/examples/captcha/README.md
new file mode 100644
index 0000000..6b593b2
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/README.md
@@ -0,0 +1,61 @@
+# Captcha
+
+This is the Clojure version of the [captcha recognition](https://github.com/xlvector/learning-dl/tree/master/mxnet/ocr)
+example by xlvector, and it mirrors the R captcha example. It serves as an
+example of multi-label training: the captcha below is treated as an
+image with 4 labels (one per digit), and a CNN is trained over the data set.
+
+![captcha example](captcha_example.png)
+
+## Installation
+
+Before you run this example, make sure that you have the Clojure package
+installed: run `lein install` in the main Clojure package directory.
+Then run `lein install` in this directory.
+
+## Usage
+
+### Training
+
+First the OCR model needs to be trained based on [labeled data](https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/data/captcha_example.zip).
+The training can be started using the following:
+```
+$ lein train [:cpu|:gpu] [num-devices]
+```
+This downloads the training/evaluation data using the `get_data.sh` script
+before starting training.
+
+You may encounter out-of-memory errors when training with `:gpu` on Ubuntu
+Linux (18.04). If so, running `lein train` (training on one CPU) may avoid the issue.
+
+Training runs for 10 epochs by default and saves the model with the
+prefix `ocr`. The model achieved an exact-match accuracy of ~0.954 and
+~0.628 on training and validation data respectively.
+
+### Inference
+
+Once the model has been saved, it can be used for prediction. This can be done
+by running:
+```
+$ lein infer
+INFO  MXNetJVM: Try loading mxnet-scala from native path.
+INFO  MXNetJVM: Try loading mxnet-scala-linux-x86_64-gpu from native path.
+INFO  MXNetJVM: Try loading mxnet-scala-linux-x86_64-cpu from native path.
+WARN  MXNetJVM: MXNet Scala native library not found in path. Copying native library from the archive. Consider installing the library somewhere in the path (for Windows: PATH, for Linux: LD_LIBRARY_PATH), or specifying by Java cmd option -Djava.library.path=[lib path].
+WARN  org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis
+INFO  org.apache.mxnet.infer.Predictor: Latency increased due to batchSize mismatch 8 vs 1
+WARN  org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis
+WARN  org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis
+CAPTCHA output: 6643
+INFO  org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865/libmxnet.so
+INFO  org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865/mxnet-scala
+INFO  org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865
+```
+The model runs on `captcha_example.png` by default.
+
+It can be run on other generated captcha images as well. The script
+`gen_captcha.py` generates a random captcha image of length 4 on each run.
+Before running the Python script, you will need to install the [captcha](https://pypi.org/project/captcha/)
+library using `pip3 install --user captcha`. The captcha images are written
+to the `images/` folder, and prediction can then be run with, for example,
+`lein infer images/7534.png`.
diff --git a/contrib/clojure-package/examples/captcha/captcha_example.png b/contrib/clojure-package/examples/captcha/captcha_example.png
new file mode 100644
index 0000000..09b84f7
Binary files /dev/null and b/contrib/clojure-package/examples/captcha/captcha_example.png differ
diff --git a/contrib/clojure-package/examples/captcha/gen_captcha.py b/contrib/clojure-package/examples/captcha/gen_captcha.py
new file mode 100755
index 0000000..43e0d26
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/gen_captcha.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from captcha.image import ImageCaptcha
+import os
+import random
+
+length = 4
+width = 160
+height = 60
+IMAGE_DIR = "images"
+
+
+def random_text():
+    return ''.join(str(random.randint(0, 9))
+                   for _ in range(length))
+
+
+if __name__ == '__main__':
+    image = ImageCaptcha(width=width, height=height)
+    captcha_text = random_text()
+    if not os.path.exists(IMAGE_DIR):
+        os.makedirs(IMAGE_DIR)
+    image.write(captcha_text, os.path.join(IMAGE_DIR, captcha_text + ".png"))
diff --git a/contrib/clojure-package/examples/captcha/get_data.sh b/contrib/clojure-package/examples/captcha/get_data.sh
new file mode 100755
index 0000000..baa7f9e
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/get_data.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -evx
+
+EXAMPLE_ROOT=$(cd "$(dirname "$0")"; pwd)
+
+data_path=$EXAMPLE_ROOT
+
+if [ ! -f "$data_path/captcha_example.zip" ]; then
+  wget https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/data/captcha_example.zip -P "$data_path"
+fi
+
+if [ ! -f "$data_path/captcha_example/captcha_train.rec" ]; then
+  unzip "$data_path/captcha_example.zip" -d "$data_path"
+fi
diff --git a/contrib/clojure-package/examples/captcha/project.clj b/contrib/clojure-package/examples/captcha/project.clj
new file mode 100644
index 0000000..fa37fec
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/project.clj
@@ -0,0 +1,28 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(defproject captcha "0.1.0-SNAPSHOT"
+  :description "Captcha recognition via multi-label classification"
+  :plugins [[lein-cljfmt "0.5.7"]]
+  :dependencies [[org.clojure/clojure "1.9.0"]
+                 [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]]
+  :main ^:skip-aot captcha.train-ocr
+  :profiles {:train {:main captcha.train-ocr}
+             :infer {:main captcha.infer-ocr}
+             :uberjar {:aot :all}}
+  :aliases {"train" ["with-profile" "train" "run"]
+            "infer" ["with-profile" "infer" "run"]})
diff --git a/contrib/clojure-package/examples/captcha/src/captcha/consts.clj b/contrib/clojure-package/examples/captcha/src/captcha/consts.clj
new file mode 100644
index 0000000..318e0d8
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/src/captcha/consts.clj
@@ -0,0 +1,27 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns captcha.consts)
+
+(def batch-size 8)
+(def channels 3)
+(def height 30)
+(def width 80)
+(def data-shape [channels height width])
+(def num-labels 10)
+(def label-width 4)
+(def model-prefix "ocr")
diff --git a/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj b/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj
new file mode 100644
index 0000000..f6a648e
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj
@@ -0,0 +1,56 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns captcha.infer-ocr
+  (:require [captcha.consts :refer :all]
+            [org.apache.clojure-mxnet.dtype :as dtype]
+            [org.apache.clojure-mxnet.infer :as infer]
+            [org.apache.clojure-mxnet.layout :as layout]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]))
+
+(defn create-predictor
+  []
+  (let [data-desc {:name "data"
+                   :shape [batch-size channels height width]
+                   :layout layout/NCHW
+                   :dtype dtype/FLOAT32}
+        label-desc {:name "label"
+                    :shape [batch-size label-width]
+                    :layout layout/NT
+                    :dtype dtype/FLOAT32}
+        factory (infer/model-factory model-prefix
+                                     [data-desc label-desc])]
+    (infer/create-predictor factory)))
+
+(defn -main
+  [& args]
+  (let [[filename] args
+        image-fname (or filename "captcha_example.png")
+        image-ndarray (-> image-fname
+                          infer/load-image-from-file
+                          (infer/reshape-image width height)
+                          (infer/buffered-image-to-pixels [channels height width])
+                          (ndarray/expand-dims 0))
+        label-ndarray (ndarray/zeros [1 label-width])
+        predictor (create-predictor)
+        predictions (-> (infer/predict-with-ndarray
+                         predictor
+                         [image-ndarray label-ndarray])
+                        first
+                        (ndarray/argmax 1)
+                        ndarray/->vec)]
+    (println "CAPTCHA output:" (apply str (mapv int predictions)))))
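
Note on the inference output above: the network concatenates its four per-character score blocks along dimension 0, so for a batch of n images the output has 4 * n rows of 10 scores, ordered by character position first and image second. The sketch below, in plain Python rather than Clojure, shows how such rows could be decoded into digit strings. The helper names `argmax` and `decode_captchas` are hypothetical and are not part of this commit or of MXNet.

```python
def argmax(row):
    """Return the index of the largest value in a list of scores."""
    return max(range(len(row)), key=lambda i: row[i])


def decode_captchas(scores, label_width=4):
    """Turn position-major score rows into one digit string per image.

    Assumes `scores` has label_width * n rows: rows 0..n-1 score the first
    character of each image, rows n..2n-1 the second character, and so on.
    """
    if len(scores) % label_width != 0:
        raise ValueError("row count must be a multiple of label_width")
    n = len(scores) // label_width
    digits = [argmax(row) for row in scores]
    return ["".join(str(digits[pos * n + img]) for pos in range(label_width))
            for img in range(n)]


def one_hot(digit, num_labels=10):
    """Build a score row that puts all weight on `digit` (for illustration)."""
    return [1.0 if i == digit else 0.0 for i in range(num_labels)]


if __name__ == "__main__":
    single = [one_hot(d) for d in (6, 6, 4, 3)]
    print(decode_captchas(single))
```

With a single image (n = 1) this reduces to taking the argmax of each of the four rows, which is what `infer_ocr.clj` does with `ndarray/argmax`.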
diff --git a/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj b/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj
new file mode 100644
index 0000000..91ec2ff
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj
@@ -0,0 +1,156 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns captcha.train-ocr
+  (:require [captcha.consts :refer :all]
+            [clojure.java.io :as io]
+            [clojure.java.shell :refer [sh]]
+            [org.apache.clojure-mxnet.callback :as callback]
+            [org.apache.clojure-mxnet.context :as context]
+            [org.apache.clojure-mxnet.eval-metric :as eval-metric]
+            [org.apache.clojure-mxnet.initializer :as initializer]
+            [org.apache.clojure-mxnet.io :as mx-io]
+            [org.apache.clojure-mxnet.module :as m]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]
+            [org.apache.clojure-mxnet.optimizer :as optimizer]
+            [org.apache.clojure-mxnet.symbol :as sym])
+  (:gen-class))
+
+(when-not (.exists (io/file "captcha_example/captcha_train.lst"))
+  (sh "./get_data.sh"))
+
+(defonce train-data
+  (mx-io/image-record-iter {:path-imgrec "captcha_example/captcha_train.rec"
+                            :path-imglist "captcha_example/captcha_train.lst"
+                            :batch-size batch-size
+                            :label-width label-width
+                            :data-shape data-shape
+                            :shuffle true
+                            :seed 42}))
+
+(defonce eval-data
+  (mx-io/image-record-iter {:path-imgrec "captcha_example/captcha_test.rec"
+                            :path-imglist "captcha_example/captcha_test.lst"
+                            :batch-size batch-size
+                            :label-width label-width
+                            :data-shape data-shape}))
+
+(defn accuracy
+  [label pred & {:keys [by-character]
+                 :or {by-character false} :as opts}]
+  (let [[nr nc] (ndarray/shape-vec label)
+        pred-context (ndarray/context pred)
+        label-t (-> label
+                    ndarray/transpose
+                    (ndarray/reshape [-1])
+                    (ndarray/as-in-context pred-context))
+        pred-label (ndarray/argmax pred 1)
+        matches (ndarray/equal label-t pred-label)
+        [digit-matches] (-> matches
+                            ndarray/sum
+                            ndarray/->vec)
+        [complete-matches] (-> matches
+                               (ndarray/reshape [nc nr])
+                               (ndarray/sum 0)
+                               (ndarray/equal label-width)
+                               ndarray/sum
+                               ndarray/->vec)]
+    (if by-character
+      (float (/ digit-matches nr nc))
+      (float (/ complete-matches nr)))))
+
+(defn get-data-symbol
+  []
+  (let [data (sym/variable "data")
+        ;; normalize the input pixels
+        scaled (sym/div (sym/- data 127) 128)
+
+        conv1 (sym/convolution {:data scaled :kernel [5 5] :num-filter 32})
+        pool1 (sym/pooling {:data conv1 :pool-type "max" :kernel [2 2] :stride [1 1]})
+        relu1 (sym/activation {:data pool1 :act-type "relu"})
+
+        conv2 (sym/convolution {:data relu1 :kernel [5 5] :num-filter 32})
+        pool2 (sym/pooling {:data conv2 :pool-type "avg" :kernel [2 2] :stride [1 1]})
+        relu2 (sym/activation {:data pool2 :act-type "relu"})
+
+        conv3 (sym/convolution {:data relu2 :kernel [3 3] :num-filter 32})
+        pool3 (sym/pooling {:data conv3 :pool-type "avg" :kernel [2 2] :stride [1 1]})
+        relu3 (sym/activation {:data pool3 :act-type "relu"})
+
+        conv4 (sym/convolution {:data relu3 :kernel [3 3] :num-filter 32})
+        pool4 (sym/pooling {:data conv4 :pool-type "avg" :kernel [2 2] :stride [1 1]})
+        relu4 (sym/activation {:data pool4 :act-type "relu"})
+
+        flattened (sym/flatten {:data relu4})
+        fc1 (sym/fully-connected {:data flattened :num-hidden 256})
+        fc21 (sym/fully-connected {:data fc1 :num-hidden num-labels})
+        fc22 (sym/fully-connected {:data fc1 :num-hidden num-labels})
+        fc23 (sym/fully-connected {:data fc1 :num-hidden num-labels})
+        fc24 (sym/fully-connected {:data fc1 :num-hidden num-labels})]
+    (sym/concat "concat" nil [fc21 fc22 fc23 fc24] {:dim 0})))
+
+(defn get-label-symbol
+  []
+  (as-> (sym/variable "label") label
+    (sym/transpose {:data label})
+    (sym/reshape {:data label :shape [-1]})))
+
+(defn create-captcha-net
+  []
+  (let [scores (get-data-symbol)
+        labels (get-label-symbol)]
+    (sym/softmax-output {:data scores :label labels})))
+
+(def optimizer
+  (optimizer/adam
+   {:learning-rate 0.0002
+    :wd 0.00001
+    :clip-gradient 10}))
+
+(defn train-ocr
+  [devs]
+  (println "Starting the captcha training ...")
+  (let [model (m/module
+               (create-captcha-net)
+               {:data-names ["data"] :label-names ["label"]
+                :contexts devs})]
+    (m/fit model {:train-data train-data
+                  :eval-data eval-data
+                  :num-epoch 10
+                  :fit-params (m/fit-params
+                               {:kvstore "local"
+                                :batch-end-callback
+                                (callback/speedometer batch-size 100)
+                                :initializer
+                                (initializer/xavier {:factor-type "in"
+                                                     :magnitude 2.34})
+                                :optimizer optimizer
+                                :eval-metric (eval-metric/custom-metric
+                                              #(accuracy %1 %2)
+                                              "accuracy")})})
+    (println "Finished the fit")
+    model))
+
+(defn -main
+  [& args]
+  (let [[dev dev-num] args
+        num-devices (Integer/parseInt (or dev-num "1"))
+        devs (if (= dev ":gpu")
+               (mapv #(context/gpu %) (range num-devices))
+               (mapv #(context/cpu %) (range num-devices)))
+        model (train-ocr devs)]
+    (m/save-checkpoint model {:prefix model-prefix :epoch 0})))
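
Note on the `accuracy` function above: it compares predictions against labels that have been transposed and flattened, so that both are in position-major order, and then reports either the fraction of fully matched captchas or the fraction of matched characters. The following plain-Python sketch reproduces that arithmetic on nested lists, as an illustration only. The name `captcha_accuracy` is hypothetical, and it takes already-argmaxed digits instead of score rows.

```python
def captcha_accuracy(labels, pred_digits, by_character=False):
    """Exact-match or per-character accuracy for fixed-width captchas.

    labels:      n rows of `width` digits, one row per image.
    pred_digits: flat list of n * width predicted digits in position-major
                 order (all first characters, then all second characters...).
    """
    n = len(labels)
    width = len(labels[0])
    if len(pred_digits) != n * width:
        raise ValueError("pred_digits must hold n * width entries")
    # Flatten labels position-major, mirroring transpose followed by reshape [-1].
    flat = [labels[img][pos] for pos in range(width) for img in range(n)]
    matches = [int(a == b) for a, b in zip(flat, pred_digits)]
    if by_character:
        return sum(matches) / (n * width)
    # An image counts only if every one of its characters matched.
    complete = sum(
        all(matches[pos * n + img] for pos in range(width))
        for img in range(n))
    return complete / n


if __name__ == "__main__":
    labels = [[1, 2, 3, 4], [5, 6, 7, 8]]
    preds = [1, 0, 2, 6, 3, 0, 4, 8]
    print(captcha_accuracy(labels, preds))
    print(captcha_accuracy(labels, preds, by_character=True))
```

The sample data is the same as in `test-accuracy` in this commit: the first captcha matches completely, the second misses two of four characters.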
diff --git a/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj b/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj
new file mode 100644
index 0000000..ab785f7
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj
@@ -0,0 +1,119 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns captcha.train-ocr-test
+  (:require [clojure.test :refer :all]
+            [captcha.consts :refer :all]
+            [captcha.train-ocr :refer :all]
+            [org.apache.clojure-mxnet.io :as mx-io]
+            [org.apache.clojure-mxnet.module :as m]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]
+            [org.apache.clojure-mxnet.shape :as shape]
+            [org.apache.clojure-mxnet.util :as util]))
+
+(deftest test-consts
+  (is (= 8 batch-size))
+  (is (= [3 30 80] data-shape))
+  (is (= 4 label-width))
+  (is (= 10 num-labels)))
+
+(deftest test-labeled-data
+  (let [train-batch (mx-io/next train-data)
+        eval-batch (mx-io/next eval-data)
+        allowed-labels (into #{} (map float (range 10)))]
+    (is (= 8 (-> train-batch mx-io/batch-index count)))
+    (is (= 8 (-> eval-batch mx-io/batch-index count)))
+    (is (= [8 3 30 80] (-> train-batch
+                           mx-io/batch-data
+                           first
+                           ndarray/shape-vec)))
+    (is (= [8 3 30 80] (-> eval-batch
+                           mx-io/batch-data
+                           first
+                           ndarray/shape-vec)))
+    (is (every? #(<= 0 % 255) (-> train-batch
+                                  mx-io/batch-data
+                                  first
+                                  ndarray/->vec)))
+    (is (every? #(<= 0 % 255) (-> eval-batch
+                                  mx-io/batch-data
+                                  first
+                                  ndarray/->vec)))
+    (is (= [8 4] (-> train-batch
+                     mx-io/batch-label
+                     first
+                     ndarray/shape-vec)))
+    (is (= [8 4] (-> eval-batch
+                     mx-io/batch-label
+                     first
+                     ndarray/shape-vec)))
+    (is (every? allowed-labels (-> train-batch
+                                   mx-io/batch-label
+                                   first
+                                   ndarray/->vec)))
+    (is (every? allowed-labels (-> eval-batch
+                                   mx-io/batch-label
+                                   first
+                                   ndarray/->vec)))))
+
+(deftest test-model
+  (let [batch (mx-io/next train-data)
+        model (m/module (create-captcha-net)
+                        {:data-names ["data"] :label-names ["label"]})
+        _ (m/bind model
+                  {:data-shapes (mx-io/provide-data-desc train-data)
+                   :label-shapes (mx-io/provide-label-desc train-data)})
+        _ (m/init-params model)
+        _ (m/forward-backward model batch)
+        output-shapes (-> model
+                          m/output-shapes
+                          util/coerce-return-recursive)
+        outputs (-> model
+                    m/outputs-merged
+                    first)
+        grads (->> model m/grad-arrays (map first))]
+    (is (= [["softmaxoutput0_output" (shape/->shape [8 10])]]
+           output-shapes))
+    (is (= [32 10] (-> outputs ndarray/shape-vec)))
+    (is (every? #(<= 0.0 % 1.0) (-> outputs ndarray/->vec)))
+    (is (= [[32 3 5 5] [32]   ; convolution1 weights+bias
+            [32 32 5 5] [32]  ; convolution2 weights+bias
+            [32 32 3 3] [32]  ; convolution3 weights+bias
+            [32 32 3 3] [32]  ; convolution4 weights+bias
+            [256 28672] [256] ; fully-connected1 weights+bias
+            [10 256] [10]     ; 1st label scores
+            [10 256] [10]     ; 2nd label scores
+            [10 256] [10]     ; 3rd label scores
+            [10 256] [10]]    ; 4th label scores
+           (map ndarray/shape-vec grads)))))
+
+(deftest test-accuracy
+  (let [labels (ndarray/array [1 2 3 4,
+                               5 6 7 8]
+                              [2 4])
+        pred-labels (ndarray/array [1 0,
+                                    2 6,
+                                    3 0,
+                                    4 8]
+                                   [8])
+        preds (ndarray/one-hot pred-labels 10)]
+    (is (float? (accuracy labels preds)))
+    (is (float? (accuracy labels preds :by-character false)))
+    (is (float? (accuracy labels preds :by-character true)))
+    (is (= 0.5 (accuracy labels preds)))
+    (is (= 0.5 (accuracy labels preds :by-character false)))
+    (is (= 0.75 (accuracy labels preds :by-character true)))))
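
Note on the `[256 28672]` gradient shape expected in `test-model`: the 28672 figure follows from the layer stack in `get-data-symbol`. The plain-Python sketch below walks the 30x80 input through the four convolution and pooling stages. It assumes that the convolutions use stride 1 with no padding and that pooling uses no padding, which is consistent with the expected shape but is not stated in the commit. The function names are hypothetical.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Output length of a convolution or pooling window along one axis."""
    return (size + 2 * pad - kernel) // stride + 1


def flattened_size(height=30, width=80, num_filter=32,
                   conv_kernels=(5, 5, 3, 3), pool_kernel=2, pool_stride=1):
    """Number of features entering the first fully connected layer."""
    h, w = height, width
    for k in conv_kernels:
        # Convolution (assumed stride 1, no padding).
        h, w = conv_out(h, k), conv_out(w, k)
        # 2x2 pooling with stride 1, as configured in get-data-symbol.
        h = conv_out(h, pool_kernel, pool_stride)
        w = conv_out(w, pool_kernel, pool_stride)
    return num_filter * h * w


if __name__ == "__main__":
    print(flattened_size())
```

Under these assumptions the spatial size shrinks from 30x80 to 14x64, and 32 * 14 * 64 = 28672, matching the fully connected weight shape in the test.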