You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2018/09/13 00:06:09 UTC

[incubator-mxnet] branch master updated: MXNET-873 - Bring Clojure Package Inline with New DataDesc and Layout in Scala Package (#12387)

This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b8153f6  MXNET-873 - Bring Clojure Package Inline with New DataDesc and Layout in Scala Package (#12387)
b8153f6 is described below

commit b8153f6449d5628491ef2ec4e10871643e8ba5c6
Author: Carin Meier <cm...@gigasquidsoftware.com>
AuthorDate: Wed Sep 12 20:05:56 2018 -0400

    MXNET-873 - Bring Clojure Package Inline with New DataDesc and Layout in Scala Package (#12387)
    
    * Bring clojure package inline with new DataDesc and Layout in Scala package
    
    * formatting cljfmt
    
    * revert the implementation of module fit back now that the DataDesc issue is fixed
    - update Module example to use provide-data-desc and provide-label-desc
    
    * update to provide-data-desc and provide-label-desc
    
    * decrease epochs to speed example
    
    * Add tests and docstrings
    
    * remove let
---
 .../examples/gan/src/gan/gan_mnist.clj             |  6 +-
 .../examples/module/src/mnist_mlp.clj              |  2 +-
 .../src/pre_trained_models/fine_tune.clj           |  2 +-
 .../src/pre_trained_models/predict_image.clj       |  2 +-
 .../examples/rnn/src/rnn/train_char_rnn.clj        | 12 +--
 .../src/org/apache/clojure_mxnet/io.clj            | 88 +++++++++++++++-------
 .../org/apache/clojure_mxnet/layout.clj}           | 26 ++++---
 .../src/org/apache/clojure_mxnet/module.clj        | 58 ++------------
 .../src/org/apache/clojure_mxnet/ndarray.clj       |  2 +-
 .../src/org/apache/clojure_mxnet/symbol.clj        |  2 +-
 .../test/org/apache/clojure_mxnet/io_test.clj      | 53 ++++++++++++-
 .../test/org/apache/clojure_mxnet/module_test.clj  |  9 ++-
 .../test/org/apache/clojure_mxnet/test_util.clj    |  6 +-
 13 files changed, 159 insertions(+), 109 deletions(-)

diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
index 14dd2c5..e2e3364 100644
--- a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
+++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
@@ -159,13 +159,13 @@
 
 (defn train [devs]
   (let [mod-d  (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]})
-                   (m/bind {:data-shapes (mx-io/provide-data mnist-iter)
-                            :label-shapes (mx-io/provide-label mnist-iter)
+                   (m/bind {:data-shapes (mx-io/provide-data-desc mnist-iter)
+                            :label-shapes (mx-io/provide-label-desc mnist-iter)
                             :inputs-need-grad true})
                    (m/init-params {:initializer (init/normal 0.02)})
                    (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))
         mod-g (-> (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil})
-                  (m/bind {:data-shapes (mx-io/provide-data rand-noise-iter)})
+                  (m/bind {:data-shapes (mx-io/provide-data-desc rand-noise-iter)})
                   (m/init-params {:initializer (init/normal 0.02)})
                   (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))]
 
diff --git a/contrib/clojure-package/examples/module/src/mnist_mlp.clj b/contrib/clojure-package/examples/module/src/mnist_mlp.clj
index 74edf71..c5ffbbe 100644
--- a/contrib/clojure-package/examples/module/src/mnist_mlp.clj
+++ b/contrib/clojure-package/examples/module/src/mnist_mlp.clj
@@ -85,7 +85,7 @@
               (m/module (get-symbol) {:contexts devs}))
         metric (eval-metric/accuracy)]
     (-> mod
-        (m/bind {:data-shapes (mx-io/provide-data train-data) :label-shapes (mx-io/provide-label train-data)})
+        (m/bind {:data-shapes (mx-io/provide-data-desc train-data) :label-shapes (mx-io/provide-label-desc train-data)})
         (m/init-params)
         (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.9})}))
 
diff --git a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj
index f2b9edd..93c121f 100644
--- a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj
+++ b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj
@@ -80,7 +80,7 @@
 
 (defn fit [devs msymbol arg-params aux-params]
   (let [mod (-> (m/module msymbol {:contexts devs})
-                (m/bind {:data-shapes (mx-io/provide-data train-iter) :label-shapes (mx-io/provide-label val-iter)})
+                (m/bind {:data-shapes (mx-io/provide-data-desc train-iter) :label-shapes (mx-io/provide-label-desc val-iter)})
                 (m/init-params {:arg-params arg-params :aux-params aux-params
                                 :allow-missing true}))]
     (m/fit mod
diff --git a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj
index 12bdb12..71202bc 100644
--- a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj
+++ b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj
@@ -92,7 +92,7 @@
 
 (comment
 
-  (predict "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/python/predict_image/cat.jpg")
+  (predict "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/python/predict_image/cat.jpg" true)
   ;; ({:prob 0.69066674, :label "n02122948 kitten, kitty"}
   ;;  {:prob 0.04466057, :label "n01323155 kit"}
   ;;  {:prob 0.029682875, :label "n01318894 pet"}
diff --git a/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj b/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj
index 29aba26..150cd94 100644
--- a/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj
+++ b/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj
@@ -109,16 +109,16 @@
                                         :label-name "softmax_label"
                                         :data-batch-size batch-size
                                         :last-batch-handle "pad"})
-        data-and-labels (merge (data-desc->map (mx-io/provide-data train-iter))
-                               (data-desc->map (mx-io/provide-label train-iter))
+        data-and-labels (merge (data-desc->map (mx-io/provide-data-desc train-iter))
+                               (data-desc->map (mx-io/provide-label-desc train-iter))
                                init-states)
         init-states-data (mapv (fn [[k v]] (ndarray/zeros v {:ctx ctx})) init-states)
         rnn-sym (sym-gen (first buckets))
 
         rnn-mod (-> (m/module rnn-sym {:contexts devs})
-                    (m/bind {:data-shapes (into (mx-io/provide-data train-iter)
+                    (m/bind {:data-shapes (into (mx-io/provide-data-desc train-iter)
                                                 (mapv (fn [[k v]] {:name k :shape v}) init-states))
-                             :label-shapes (mx-io/provide-label train-iter)})
+                             :label-shapes (mx-io/provide-label-desc train-iter)})
                     (m/init-params {:initializer (init/xavier {:factor-type "in" :magnitude 2.34})})
                     (m/init-optimizer {:optimizer (optimizer/adam {:learning-rate learning-rate :wd 0.0001})}))
         metric (eval-metric/custom-metric
@@ -141,8 +141,8 @@
 
                 "perplexity")]
 
-    ;; Train for 2 epochs and then show the results of 75
-    (doseq [epoch-num (range 2)]
+    ;; Train for 1 epoch and then show the results of 75
+    (doseq [epoch-num (range 1)]
       (println "Doing epoch " epoch-num)
       (mx-io/reduce-batches
        train-iter
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj
index d6f1499..a2b6399 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj
@@ -17,11 +17,12 @@
 
 (ns org.apache.clojure-mxnet.io
   (:refer-clojure :exclude [next])
-  (:require [org.apache.clojure-mxnet.base :as base]
+  (:require [clojure.spec.alpha :as s]
+            [org.apache.clojure-mxnet.base :as base]
             [org.apache.clojure-mxnet.shape :as mx-shape]
             [org.apache.clojure-mxnet.util :as util]
             [org.apache.clojure-mxnet.dtype :as dtype]
-            [clojure.spec.alpha :as s]
+            [org.apache.clojure-mxnet.layout :as layout]
             [org.apache.clojure-mxnet.ndarray :as ndarray]
             [org.apache.clojure-mxnet.random :as random])
   (:import (org.apache.mxnet IO DataDesc DataBatch NDArray)
@@ -57,18 +58,48 @@
 
 (defn resize-iter [iter nbatch])
 
-(defn provide-data [pack-iterator]
+(defn provide-data
+  "Provides the description of the data iterator in the form of
+  [{:name name :shape shape-vec}]"
+  [pack-iterator]
   (->> pack-iterator
        (.provideData)
        (util/scala-map->map)
        (mapv (fn [[k v]] {:name k :shape (mx-shape/->vec v)}))))
 
-(defn provide-label [pack-iterator]
+(defn provide-label
+  "Provides the description of the label iterator in the form of
+  [{:name name :shape shape-vec}]"
+  [pack-iterator]
   (->> pack-iterator
        (.provideLabel)
        (util/scala-map->map)
        (mapv (fn [[k v]] {:name k :shape (mx-shape/->vec v)}))))
 
+(defn data-desc->map [dd]
+  {:name (.name dd)
+   :shape (mx-shape/->vec (.shape dd))
+   :dtype (.dtype dd)
+   :layout (.layout dd)})
+
+(defn provide-data-desc
+  "Provides the Data Desc of the data iterator in the form of
+  [{:name name :shape shape-vec :dtype dtype :layout layout}]"
+  [pack-iterator]
+  (->> pack-iterator
+       (.provideDataDesc)
+       (util/scala-vector->vec)
+       (mapv data-desc->map)))
+
+(defn provide-label-desc
+  "Provides the Data Desc of the label iterator in the form of
+  [{:name name :shape shape-vec :dtype dtype :layout layout}]"
+  [pack-iterator]
+  (->> pack-iterator
+       (.provideLabelDesc)
+       (util/scala-vector->vec)
+       (mapv data-desc->map)))
+
 (defn reset [iterator]
   (.reset iterator))
 
@@ -194,7 +225,8 @@
 (defn ndarray-iter
   " * NDArrayIter object in mxnet. Taking NDArray to get dataiter.
   *
-  * @param data vector of iter
+  * @param data vector of iter - Can either be in the form of [ndarray ...] or
+  *  {data-desc0 ndarray0 data-desc1 ndarray1 ...}
   * @opts map of:
   *     :label Same as data, but is not fed to the model during testing.
   *     :data-batch-size Batch Size (default 1)
@@ -213,14 +245,23 @@
                last-batch-handle "pad"
                data-name "data"
                label-name "label"}}]
-   (new NDArrayIter
-        (util/vec->indexed-seq data)
-        (if label (util/vec->indexed-seq label) (util/empty-indexed-seq))
-        (int data-batch-size)
-        shuffle
-        last-batch-handle
-        data-name
-        label-name))
+   (if (map? data)
+     (new NDArrayIter
+          (.toIndexedSeq (util/list-map data))
+          (if label
+            (.toIndexedSeq (util/list-map label))
+            (util/empty-indexed-seq))
+          (int data-batch-size)
+          shuffle
+          last-batch-handle)
+     (new NDArrayIter
+          (util/vec->indexed-seq data)
+          (if label (util/vec->indexed-seq label) (util/empty-indexed-seq))
+          (int data-batch-size)
+          shuffle
+          last-batch-handle
+          data-name
+          label-name)))
   ([data]
    (ndarray-iter data {})))
 
@@ -230,24 +271,19 @@
 (s/def ::name string?)
 (s/def ::shape vector?)
 (s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})
+(s/def ::layout (s/or :custom string? :standard #{layout/UNDEFINED
+                                                  layout/NCHW
+                                                  layout/NTC
+                                                  layout/NT
+                                                  layout/N}))
 (s/def ::data-desc (s/keys :req-un [::name ::shape] :opt-un [::dtype ::layout]))
 
-;; NCHW is N:batch size C: channel H: height W: width
-;;; other layouts are
-;; NT, TNC, nad N
-;; the shape length must match the lengh of the layout string size
 (defn data-desc
   ([{:keys [name shape dtype layout] :as opts
-     :or {dtype base/MX_REAL_TYPE}}]
+     :or {dtype base/MX_REAL_TYPE
+          layout layout/UNDEFINED}}]
    (util/validate! ::data-desc opts "Invalid data description")
-   (let [sc (count shape)
-         layout (or layout (cond
-                             (= 1 sc) "N"
-                             (= 2 sc) "NT"
-                             (= 3 sc) "TNC"
-                             (= 4 sc) "NCHW"
-                             :else (apply str (repeat sc "?"))))]
-     (new DataDesc name (mx-shape/->shape shape) dtype layout)))
+   (new DataDesc name (mx-shape/->shape shape) dtype layout))
   ([name shape]
    (data-desc {:name name :shape shape})))
 
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj
similarity index 65%
copy from contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
copy to contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj
index ecd54ca..f379a7a 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj
@@ -15,15 +15,21 @@
 ;; limitations under the License.
 ;;
 
-(ns org.apache.clojure-mxnet.test-util
-  (:require [clojure.test :as t]))
+(ns org.apache.clojure-mxnet.layout
+  (:import (org.apache.mxnet Layout)))
 
-(defn approx= [tolerance x y]
-  (if (and (number? x) (number? y))
-    (let [diff (Math/abs (- x y))]
-      (< diff tolerance))
-    (and
-    	(= (count x) (count y))
-		(reduce (fn [x y] (and x y))
-            (map #(approx= tolerance %1 %2) x y)))))
+;;
+;;    Layout definition of DataDesc
+;;    N Batch size
+;;    C channels
+;;    H Height
+;;    W Width
+;;    T sequence length
+;;    __undefined__ default value of Layout
+;;   
 
+(def UNDEFINED (Layout/UNDEFINED)) ;"__UNDEFINED__"
+(def NCHW (Layout/NCHW)) ;=> "NCHW"
+(def NTC (Layout/NTC)) ;=> "NTC"
+(def NT (Layout/NT)) ;=> "NT"
+(def N (Layout/N)) ;=> "N
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
index ab6d345..aa5ce39 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
@@ -88,8 +88,8 @@
    can perform computation with the module.
    mod : module
    map of opts:
-     :data-shapes Typically is  (provide-data data-iter). Data shape must be in the form of io/data-desc with is a map of :name :shape :dtype and :layout
-     :label-shapes Typically is  (provide-label data-iter). map of :name :shape :dtype and :layout
+     :data-shapes Typically is  (provide-data-desc data-iter). Data shape must be in the form of io/data-desc which is a map of :name :shape :dtype and :layout
+     :label-shapes Typically is  (provide-label-desc data-iter). map of :name :shape :dtype and :layout
      :for-training Default is `true`. Whether the executors should be bind for training.
      :inputs-need-grad Default is `false`.
                        Whether the gradients to the input data need to be computed.
@@ -547,54 +547,12 @@
         `:or {num-epoch 1
               fit-params (new FitParams)}}]
   (util/validate! ::fit-options opts "Invalid options for fit")
-  (let [fmod (-> mod
-                 (bind {:data-shapes (mx-io/provide-data train-data)
-                        :label-shapes (mx-io/provide-label train-data)
-                        :for-training true
-                        :force-rebind (.forceRebind fit-params)})
-                 (init-params (remove (fn [[k v]] (nil? v))
-                                      {:initializer (.initializer fit-params)
-                                       :arg-params (.argParams fit-params)
-                                       :aux-params (.auxParams fit-params)
-                                       :allow-missing (.allowMissing fit-params)}))
-                 (init-optimizer (remove (fn [[k v]] (nil? v))
-                                         {:optimizer (.optimizer fit-params)
-                                          :kvstore (.kvstore fit-params)})))
-        eval-metric (or (.evalMetric fit-params) (eval-metric/accuracy))
-        val-metric (or (util/option->value (.validationMetric fit-params)) (eval-metric/accuracy))]
-    (doseq [i (range num-epoch)]
-      (let [tic (System/currentTimeMillis)]
-        (mx-io/reduce-batches train-data
-                              (fn [batch-num batch]
-                                (-> fmod
-                                    (forward batch)
-                                    (backward)
-                                    (update)
-                                    (update-metric eval-metric (mx-io/batch-label batch)))
-                                (when-let [cb (util/option->value (.batchEndCallback fit-params))]
-                                  (callback/invoke cb i batch-num eval-metric))
-                                (.dispose batch)
-                                (inc batch-num)))
-        (println "Epoch " i " Train-" (eval-metric/get eval-metric))
-        (println "Epoch " i " Time cost-" (- (System/currentTimeMillis) tic))
-
-       ;;sync across kvstores
-        (get-params fmod)
-        (when-let [cb (util/option->value (.epochEndCallback fit-params))]
-          (callback/invoke cb i 0 val-metric))
-
-       ;; evaluation on the validation set
-        (when eval-data
-          (let [res (score fmod {:eval-data eval-data :eval-metric eval-metric :epoch i})]
-            (println "Epoch " i " Validation- " res)))))
-    fmod)
-  ;; old way if the problem with the sizes get resolved in DataDesc
-  #_(doto mod
-      (.fit
-       train-data
-       (util/->option eval-data)
-       (int num-epoch)
-       fit-params)))
+  (doto mod
+    (.fit
+     train-data
+     (util/->option eval-data)
+     (int num-epoch)
+     fit-params)))
 
 (s/def ::eval-data ::train-data)
 (s/def ::num-batch integer?)
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
index 7ca4ede..e37a8bc 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
@@ -89,7 +89,7 @@
    (NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype))
   ([start stop]
    (arange start stop {})))
-  
+
 (defn slice
   "Return a sliced NDArray that shares memory with current one."
   ([ndarray i]
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
index 12135fb..58b1d6d 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
@@ -144,7 +144,7 @@
    which must be known from the rest of the net."
   ([start {:keys [step repeat dtype]
            :or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
-          :as opts}]
+           :as opts}]
    (Symbol/arange (float start) ($/option nil) step repeat true nil dtype))
   ([start]
    (arange-with-inference start {})))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj
index ace39ec..9babf1e 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj
@@ -22,7 +22,9 @@
             [org.apache.clojure-mxnet.ndarray :as ndarray]
             [org.apache.clojure-mxnet.util :as util]
             [org.apache.clojure-mxnet.shape :as mx-shape]
-            [clojure.test :refer :all]))
+            [clojure.test :refer :all]
+            [org.apache.clojure-mxnet.dtype :as dtype]
+            [org.apache.clojure-mxnet.layout :as layout]))
 
 (deftest test-mnsit-iter-and-mnist-pack
   (let [_ (when-not (.exists (io/file "data/train-images-idx3-ubyte"))
@@ -59,6 +61,31 @@
         (is (= label1 label0))
         (is (= data1 data0))))))
 
+(deftest test-provide-data-and-label
+  (let [test-data (mx-io/mnist-iter {:image "data/train-images-idx3-ubyte"
+                                     :label "data/train-labels-idx1-ubyte"
+                                     :label-name "softmax_label"
+                                     :data-shape [1 28 28]
+                                     :label-shape [1 1 10]
+                                     :batch-size 100
+                                     :shuffle true
+                                     :flat false
+                                     :silent false
+                                     :seed 10})]
+    (is (= [{:name "data", :shape [100 1 28 28]}]
+           (mx-io/provide-data test-data)))
+    (is (= [{:name "softmax_label", :shape [100]}]
+           (mx-io/provide-label test-data)))
+    (is (= [{:name "data", :shape [100 1 28 28]
+             :dtype dtype/FLOAT32
+             :layout layout/UNDEFINED}]
+           (mx-io/provide-data-desc test-data)))
+    (is (= [{:name "softmax_label"
+             :shape [100]
+             :dtype dtype/FLOAT32
+             :layout layout/UNDEFINED}]
+           (mx-io/provide-label-desc test-data)))))
+
 (deftest test-image-record-iter
   (let [_ (when-not (.exists (io/file "data/cifar/train.rec"))
             (sh "scripts/get_cifar_data.sh"))
@@ -162,4 +189,26 @@
                                                :last-batch-handle "discard"})
           nbatch2 7]
       (is (= nbatch2 (mx-io/reduce-batches data-iter2 (fn [result batch] (inc result)))))
-      (is (= [] (mx-io/iter-init-label data-iter2))))))
+      (is (= [] (mx-io/iter-init-label data-iter2))))
+
+    ;;; testing with a specified layout
+    (let [label-desc (mx-io/data-desc {:name "label"
+                                       :shape [2 2]
+                                       :dtype dtype/INT32
+                                       :layout layout/NT})
+          data-desc (mx-io/data-desc {:name "data"
+                                      :shape [2 2 2]
+                                      :dtype dtype/FLOAT32
+                                      :layout layout/NTC})
+          label (ndarray/ones [2 2] {:dtype dtype/INT32})
+          data (ndarray/ones [2 2 2] {:dtype dtype/FLOAT32})
+          data-iter3 (mx-io/ndarray-iter {data-desc data}
+                                         {:label {label-desc label}})]
+      (is (= {:dtype dtype/FLOAT32 :layout layout/NTC}
+             (-> (mx-io/provide-data-desc data-iter3)
+                 first
+                 (select-keys [:dtype :layout]))))
+      (is (= {:dtype dtype/INT32 :layout layout/NT}
+             (-> (mx-io/provide-label-desc data-iter3)
+                 first
+                 (select-keys [:dtype :layout])))))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj
index 0f71b5a..d53af2e 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj
@@ -20,6 +20,7 @@
             [org.apache.clojure-mxnet.context :as context]
             [org.apache.clojure-mxnet.dtype :as dtype]
             [org.apache.clojure-mxnet.io :as mx-io]
+            [org.apache.clojure-mxnet.layout :as layout]
             [org.apache.clojure-mxnet.module :as m]
             [org.apache.clojure-mxnet.monitor :as monitor]
             [org.apache.clojure-mxnet.ndarray :as ndarray]
@@ -54,9 +55,9 @@
         c (sym/+ a (sym/+ (sym/* b 2) (sym/* c 3)))
         mod (m/module c ["b" "c" "a"] nil [(context/cpu 0) (context/cpu 1)])]
     (-> mod
-        (m/bind {:data-shapes [{:name "b" :shape [5 5] :layout "NT"}
-                               {:name "c" :shape [5 5] :layout "NT"}
-                               {:name "a" :shape [5 5] :layout "NT"}]
+        (m/bind {:data-shapes [{:name "b" :shape [5 5] :layout layout/NT}
+                               {:name "c" :shape [5 5] :layout layout/NT}
+                               {:name "a" :shape [5 5] :layout layout/NT}]
                  :inputs-need-grad true})
         (m/init-params)
         (m/forward {:data [(ndarray/ones [5 5])
@@ -172,7 +173,7 @@
             (sym/linear-regression-output "softmax" {:data v :grad-scale 2}))
 
         mod (m/module x)]
-    (m/bind mod {:data-shapes (mx-io/provide-data train-data) :label-shapes (mx-io/provide-label train-data)})
+    (m/bind mod {:data-shapes (mx-io/provide-data-desc train-data) :label-shapes (mx-io/provide-label train-data)})
 
     (let [arg-params-correct {"fc_0_weight" (ndarray/array [0.15 0.2 0.25 0.3] [2 2])
                               "fc_0_bias" (ndarray/array [0.35 0.35] [2])
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
index ecd54ca..d632c96 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
@@ -23,7 +23,7 @@
     (let [diff (Math/abs (- x y))]
       (< diff tolerance))
     (and
-    	(= (count x) (count y))
-		(reduce (fn [x y] (and x y))
-            (map #(approx= tolerance %1 %2) x y)))))
+     (= (count x) (count y))
+     (reduce (fn [x y] (and x y))
+             (map #(approx= tolerance %1 %2) x y)))))