You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/13 00:05:58 UTC

[GitHub] gigasquid closed pull request #12387: MXNET-873 - Bring Clojure Package Inline with New DataDesc and Layout in Scala Package

gigasquid closed pull request #12387: MXNET-873 - Bring Clojure Package Inline with New DataDesc and Layout in Scala Package
URL: https://github.com/apache/incubator-mxnet/pull/12387
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
index 14dd2c5cc3f..e2e3364535e 100644
--- a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
+++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
@@ -159,13 +159,13 @@
 
 (defn train [devs]
   (let [mod-d  (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]})
-                   (m/bind {:data-shapes (mx-io/provide-data mnist-iter)
-                            :label-shapes (mx-io/provide-label mnist-iter)
+                   (m/bind {:data-shapes (mx-io/provide-data-desc mnist-iter)
+                            :label-shapes (mx-io/provide-label-desc mnist-iter)
                             :inputs-need-grad true})
                    (m/init-params {:initializer (init/normal 0.02)})
                    (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))
         mod-g (-> (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil})
-                  (m/bind {:data-shapes (mx-io/provide-data rand-noise-iter)})
+                  (m/bind {:data-shapes (mx-io/provide-data-desc rand-noise-iter)})
                   (m/init-params {:initializer (init/normal 0.02)})
                   (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))]
 
diff --git a/contrib/clojure-package/examples/module/src/mnist_mlp.clj b/contrib/clojure-package/examples/module/src/mnist_mlp.clj
index 74edf71172c..c5ffbbede85 100644
--- a/contrib/clojure-package/examples/module/src/mnist_mlp.clj
+++ b/contrib/clojure-package/examples/module/src/mnist_mlp.clj
@@ -85,7 +85,7 @@
               (m/module (get-symbol) {:contexts devs}))
         metric (eval-metric/accuracy)]
     (-> mod
-        (m/bind {:data-shapes (mx-io/provide-data train-data) :label-shapes (mx-io/provide-label train-data)})
+        (m/bind {:data-shapes (mx-io/provide-data-desc train-data) :label-shapes (mx-io/provide-label-desc train-data)})
         (m/init-params)
         (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.9})}))
 
diff --git a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj
index f2b9eddeb2a..93c121f9fc1 100644
--- a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj
+++ b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj
@@ -80,7 +80,7 @@
 
 (defn fit [devs msymbol arg-params aux-params]
   (let [mod (-> (m/module msymbol {:contexts devs})
-                (m/bind {:data-shapes (mx-io/provide-data train-iter) :label-shapes (mx-io/provide-label val-iter)})
+                (m/bind {:data-shapes (mx-io/provide-data-desc train-iter) :label-shapes (mx-io/provide-label-desc val-iter)})
                 (m/init-params {:arg-params arg-params :aux-params aux-params
                                 :allow-missing true}))]
     (m/fit mod
diff --git a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj
index 12bdb12fb5a..71202bc000f 100644
--- a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj
+++ b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj
@@ -92,7 +92,7 @@
 
 (comment
 
-  (predict "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/python/predict_image/cat.jpg")
+  (predict "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/python/predict_image/cat.jpg" true)
   ;; ({:prob 0.69066674, :label "n02122948 kitten, kitty"}
   ;;  {:prob 0.04466057, :label "n01323155 kit"}
   ;;  {:prob 0.029682875, :label "n01318894 pet"}
diff --git a/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj b/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj
index 29aba26b195..150cd94e673 100644
--- a/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj
+++ b/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj
@@ -109,16 +109,16 @@
                                         :label-name "softmax_label"
                                         :data-batch-size batch-size
                                         :last-batch-handle "pad"})
-        data-and-labels (merge (data-desc->map (mx-io/provide-data train-iter))
-                               (data-desc->map (mx-io/provide-label train-iter))
+        data-and-labels (merge (data-desc->map (mx-io/provide-data-desc train-iter))
+                               (data-desc->map (mx-io/provide-label-desc train-iter))
                                init-states)
         init-states-data (mapv (fn [[k v]] (ndarray/zeros v {:ctx ctx})) init-states)
         rnn-sym (sym-gen (first buckets))
 
         rnn-mod (-> (m/module rnn-sym {:contexts devs})
-                    (m/bind {:data-shapes (into (mx-io/provide-data train-iter)
+                    (m/bind {:data-shapes (into (mx-io/provide-data-desc train-iter)
                                                 (mapv (fn [[k v]] {:name k :shape v}) init-states))
-                             :label-shapes (mx-io/provide-label train-iter)})
+                             :label-shapes (mx-io/provide-label-desc train-iter)})
                     (m/init-params {:initializer (init/xavier {:factor-type "in" :magnitude 2.34})})
                     (m/init-optimizer {:optimizer (optimizer/adam {:learning-rate learning-rate :wd 0.0001})}))
         metric (eval-metric/custom-metric
@@ -141,8 +141,8 @@
 
                 "perplexity")]
 
-    ;; Train for 2 epochs and then show the results of 75
-    (doseq [epoch-num (range 2)]
+    ;; Train for 1 epoch and then show the results of 75
+    (doseq [epoch-num (range 1)]
       (println "Doing epoch " epoch-num)
       (mx-io/reduce-batches
        train-iter
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj
index d6f1499ba82..a2b639934f4 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj
@@ -17,11 +17,12 @@
 
 (ns org.apache.clojure-mxnet.io
   (:refer-clojure :exclude [next])
-  (:require [org.apache.clojure-mxnet.base :as base]
+  (:require [clojure.spec.alpha :as s]
+            [org.apache.clojure-mxnet.base :as base]
             [org.apache.clojure-mxnet.shape :as mx-shape]
             [org.apache.clojure-mxnet.util :as util]
             [org.apache.clojure-mxnet.dtype :as dtype]
-            [clojure.spec.alpha :as s]
+            [org.apache.clojure-mxnet.layout :as layout]
             [org.apache.clojure-mxnet.ndarray :as ndarray]
             [org.apache.clojure-mxnet.random :as random])
   (:import (org.apache.mxnet IO DataDesc DataBatch NDArray)
@@ -57,18 +58,48 @@
 
 (defn resize-iter [iter nbatch])
 
-(defn provide-data [pack-iterator]
+(defn provide-data
+  "Provides the description of the data iterator in the form of
+  [{:name name :shape shape-vec}]"
+  [pack-iterator]
   (->> pack-iterator
        (.provideData)
        (util/scala-map->map)
        (mapv (fn [[k v]] {:name k :shape (mx-shape/->vec v)}))))
 
-(defn provide-label [pack-iterator]
+(defn provide-label
+  "Provides the description of the label iterator in the form of
+  [{:name name :shape shape-vec}]"
+  [pack-iterator]
   (->> pack-iterator
        (.provideLabel)
        (util/scala-map->map)
        (mapv (fn [[k v]] {:name k :shape (mx-shape/->vec v)}))))
 
+(defn data-desc->map [dd]
+  {:name (.name dd)
+   :shape (mx-shape/->vec (.shape dd))
+   :dtype (.dtype dd)
+   :layout (.layout dd)})
+
+(defn provide-data-desc
+  "Provides the Data Desc of the data iterator in the form of
+  [{:name name :shape shape-vec :dtype dtype :layout layout}]"
+  [pack-iterator]
+  (->> pack-iterator
+       (.provideDataDesc)
+       (util/scala-vector->vec)
+       (mapv data-desc->map)))
+
+(defn provide-label-desc
+  "Provides the Data Desc of the label iterator in the form of
+  [{:name name :shape shape-vec :dtype dtype :layout layout}]"
+  [pack-iterator]
+  (->> pack-iterator
+       (.provideLabelDesc)
+       (util/scala-vector->vec)
+       (mapv data-desc->map)))
+
 (defn reset [iterator]
   (.reset iterator))
 
@@ -194,7 +225,8 @@
 (defn ndarray-iter
   " * NDArrayIter object in mxnet. Taking NDArray to get dataiter.
   *
-  * @param data vector of iter
+  * @param data vector of iter - Can either be in the form of [ndarray..] or
+  *  {data-desc0 ndarray0 data-desc1 ndarray1 ...}
   * @opts map of:
   *     :label Same as data, but is not fed to the model during testing.
   *     :data-batch-size Batch Size (default 1)
@@ -213,14 +245,23 @@
                last-batch-handle "pad"
                data-name "data"
                label-name "label"}}]
-   (new NDArrayIter
-        (util/vec->indexed-seq data)
-        (if label (util/vec->indexed-seq label) (util/empty-indexed-seq))
-        (int data-batch-size)
-        shuffle
-        last-batch-handle
-        data-name
-        label-name))
+   (if (map? data)
+     (new NDArrayIter
+          (.toIndexedSeq (util/list-map data))
+          (if label
+            (.toIndexedSeq (util/list-map label))
+            (util/empty-indexed-seq))
+          (int data-batch-size)
+          shuffle
+          last-batch-handle)
+     (new NDArrayIter
+          (util/vec->indexed-seq data)
+          (if label (util/vec->indexed-seq label) (util/empty-indexed-seq))
+          (int data-batch-size)
+          shuffle
+          last-batch-handle
+          data-name
+          label-name)))
   ([data]
    (ndarray-iter data {})))
 
@@ -230,24 +271,19 @@
 (s/def ::name string?)
 (s/def ::shape vector?)
 (s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})
+(s/def ::layout (s/or :custom string? :standard #{layout/UNDEFINED
+                                                  layout/NCHW
+                                                  layout/NTC
+                                                  layout/NT
+                                                  layout/N}))
 (s/def ::data-desc (s/keys :req-un [::name ::shape] :opt-un [::dtype ::layout]))
 
-;; NCHW is N:batch size C: channel H: height W: width
-;;; other layouts are
-;; NT, TNC, nad N
-;; the shape length must match the lengh of the layout string size
 (defn data-desc
   ([{:keys [name shape dtype layout] :as opts
-     :or {dtype base/MX_REAL_TYPE}}]
+     :or {dtype base/MX_REAL_TYPE
+          layout layout/UNDEFINED}}]
    (util/validate! ::data-desc opts "Invalid data description")
-   (let [sc (count shape)
-         layout (or layout (cond
-                             (= 1 sc) "N"
-                             (= 2 sc) "NT"
-                             (= 3 sc) "TNC"
-                             (= 4 sc) "NCHW"
-                             :else (apply str (repeat sc "?"))))]
-     (new DataDesc name (mx-shape/->shape shape) dtype layout)))
+   (new DataDesc name (mx-shape/->shape shape) dtype layout))
   ([name shape]
    (data-desc {:name name :shape shape})))
 
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj
new file mode 100644
index 00000000000..f379a7a02d2
--- /dev/null
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj
@@ -0,0 +1,35 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.layout
+  (:import (org.apache.mxnet Layout)))
+
+;;
+;;    Layout definition of DataDesc
+;;    N Batch size
+;;    C channels
+;;    H Height
+;;    W Width
+;;    T sequence length
+;;    __undefined__ default value of Layout
+;;   
+
+(def UNDEFINED (Layout/UNDEFINED)) ;"__UNDEFINED__"
+(def NCHW (Layout/NCHW)) ;=> "NCHW"
+(def NTC (Layout/NTC)) ;=> "NTC"
+(def NT (Layout/NT)) ;=> "NT"
+(def N (Layout/N)) ;=> "N"
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
index ab6d345fe91..aa5ce39f7a8 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
@@ -88,8 +88,8 @@
    can perform computation with the module.
    mod : module
    map of opts:
-     :data-shapes Typically is  (provide-data data-iter). Data shape must be in the form of io/data-desc with is a map of :name :shape :dtype and :layout
-     :label-shapes Typically is  (provide-label data-iter). map of :name :shape :dtype and :layout
+     :data-shapes Typically is  (provide-data-desc data-iter). Data shape must be in the form of io/data-desc which is a map of :name :shape :dtype and :layout
+     :label-shapes Typically is  (provide-label-desc data-iter). map of :name :shape :dtype and :layout
      :for-training Default is `true`. Whether the executors should be bind for training.
      :inputs-need-grad Default is `false`.
                        Whether the gradients to the input data need to be computed.
@@ -547,54 +547,12 @@
         `:or {num-epoch 1
               fit-params (new FitParams)}}]
   (util/validate! ::fit-options opts "Invalid options for fit")
-  (let [fmod (-> mod
-                 (bind {:data-shapes (mx-io/provide-data train-data)
-                        :label-shapes (mx-io/provide-label train-data)
-                        :for-training true
-                        :force-rebind (.forceRebind fit-params)})
-                 (init-params (remove (fn [[k v]] (nil? v))
-                                      {:initializer (.initializer fit-params)
-                                       :arg-params (.argParams fit-params)
-                                       :aux-params (.auxParams fit-params)
-                                       :allow-missing (.allowMissing fit-params)}))
-                 (init-optimizer (remove (fn [[k v]] (nil? v))
-                                         {:optimizer (.optimizer fit-params)
-                                          :kvstore (.kvstore fit-params)})))
-        eval-metric (or (.evalMetric fit-params) (eval-metric/accuracy))
-        val-metric (or (util/option->value (.validationMetric fit-params)) (eval-metric/accuracy))]
-    (doseq [i (range num-epoch)]
-      (let [tic (System/currentTimeMillis)]
-        (mx-io/reduce-batches train-data
-                              (fn [batch-num batch]
-                                (-> fmod
-                                    (forward batch)
-                                    (backward)
-                                    (update)
-                                    (update-metric eval-metric (mx-io/batch-label batch)))
-                                (when-let [cb (util/option->value (.batchEndCallback fit-params))]
-                                  (callback/invoke cb i batch-num eval-metric))
-                                (.dispose batch)
-                                (inc batch-num)))
-        (println "Epoch " i " Train-" (eval-metric/get eval-metric))
-        (println "Epoch " i " Time cost-" (- (System/currentTimeMillis) tic))
-
-       ;;sync across kvstores
-        (get-params fmod)
-        (when-let [cb (util/option->value (.epochEndCallback fit-params))]
-          (callback/invoke cb i 0 val-metric))
-
-       ;; evaluation on the validation set
-        (when eval-data
-          (let [res (score fmod {:eval-data eval-data :eval-metric eval-metric :epoch i})]
-            (println "Epoch " i " Validation- " res)))))
-    fmod)
-  ;; old way if the problem with the sizes get resolved in DataDesc
-  #_(doto mod
-      (.fit
-       train-data
-       (util/->option eval-data)
-       (int num-epoch)
-       fit-params)))
+  (doto mod
+    (.fit
+     train-data
+     (util/->option eval-data)
+     (int num-epoch)
+     fit-params)))
 
 (s/def ::eval-data ::train-data)
 (s/def ::num-batch integer?)
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
index 7ca4ede9733..e37a8bc8c98 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
@@ -89,7 +89,7 @@
    (NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype))
   ([start stop]
    (arange start stop {})))
-  
+
 (defn slice
   "Return a sliced NDArray that shares memory with current one."
   ([ndarray i]
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
index 12135fb75ca..58b1d6d49ff 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
@@ -144,7 +144,7 @@
    which must be known from the rest of the net."
   ([start {:keys [step repeat dtype]
            :or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
-          :as opts}]
+           :as opts}]
    (Symbol/arange (float start) ($/option nil) step repeat true nil dtype))
   ([start]
    (arange-with-inference start {})))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj
index ace39ec201e..9babf1e2253 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj
@@ -22,7 +22,9 @@
             [org.apache.clojure-mxnet.ndarray :as ndarray]
             [org.apache.clojure-mxnet.util :as util]
             [org.apache.clojure-mxnet.shape :as mx-shape]
-            [clojure.test :refer :all]))
+            [clojure.test :refer :all]
+            [org.apache.clojure-mxnet.dtype :as dtype]
+            [org.apache.clojure-mxnet.layout :as layout]))
 
 (deftest test-mnsit-iter-and-mnist-pack
   (let [_ (when-not (.exists (io/file "data/train-images-idx3-ubyte"))
@@ -59,6 +61,31 @@
         (is (= label1 label0))
         (is (= data1 data0))))))
 
+(deftest test-provide-data-and-label
+  (let [test-data (mx-io/mnist-iter {:image "data/train-images-idx3-ubyte"
+                                     :label "data/train-labels-idx1-ubyte"
+                                     :label-name "softmax_label"
+                                     :data-shape [1 28 28]
+                                     :label-shape [1 1 10]
+                                     :batch-size 100
+                                     :shuffle true
+                                     :flat false
+                                     :silent false
+                                     :seed 10})]
+    (is (= [{:name "data", :shape [100 1 28 28]}]
+           (mx-io/provide-data test-data)))
+    (is (= [{:name "softmax_label", :shape [100]}]
+           (mx-io/provide-label test-data)))
+    (is (= [{:name "data", :shape [100 1 28 28]
+             :dtype dtype/FLOAT32
+             :layout layout/UNDEFINED}]
+           (mx-io/provide-data-desc test-data)))
+    (is (= [{:name "softmax_label"
+             :shape [100]
+             :dtype dtype/FLOAT32
+             :layout layout/UNDEFINED}]
+           (mx-io/provide-label-desc test-data)))))
+
 (deftest test-image-record-iter
   (let [_ (when-not (.exists (io/file "data/cifar/train.rec"))
             (sh "scripts/get_cifar_data.sh"))
@@ -162,4 +189,26 @@
                                                :last-batch-handle "discard"})
           nbatch2 7]
       (is (= nbatch2 (mx-io/reduce-batches data-iter2 (fn [result batch] (inc result)))))
-      (is (= [] (mx-io/iter-init-label data-iter2))))))
+      (is (= [] (mx-io/iter-init-label data-iter2))))
+
+    ;;; testing with a specified layout
+    (let [label-desc (mx-io/data-desc {:name "label"
+                                       :shape [2 2]
+                                       :dtype dtype/INT32
+                                       :layout layout/NT})
+          data-desc (mx-io/data-desc {:name "data"
+                                      :shape [2 2 2]
+                                      :dtype dtype/FLOAT32
+                                      :layout layout/NTC})
+          label (ndarray/ones [2 2] {:dtype dtype/INT32})
+          data (ndarray/ones [2 2 2] {:dtype dtype/FLOAT32})
+          data-iter3 (mx-io/ndarray-iter {data-desc data}
+                                         {:label {label-desc label}})]
+      (is (= {:dtype dtype/FLOAT32 :layout layout/NTC}
+             (-> (mx-io/provide-data-desc data-iter3)
+                 first
+                 (select-keys [:dtype :layout]))))
+      (is (= {:dtype dtype/INT32 :layout layout/NT}
+             (-> (mx-io/provide-label-desc data-iter3)
+                 first
+                 (select-keys [:dtype :layout])))))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj
index 0f71b5a850c..d53af2ec249 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj
@@ -20,6 +20,7 @@
             [org.apache.clojure-mxnet.context :as context]
             [org.apache.clojure-mxnet.dtype :as dtype]
             [org.apache.clojure-mxnet.io :as mx-io]
+            [org.apache.clojure-mxnet.layout :as layout]
             [org.apache.clojure-mxnet.module :as m]
             [org.apache.clojure-mxnet.monitor :as monitor]
             [org.apache.clojure-mxnet.ndarray :as ndarray]
@@ -54,9 +55,9 @@
         c (sym/+ a (sym/+ (sym/* b 2) (sym/* c 3)))
         mod (m/module c ["b" "c" "a"] nil [(context/cpu 0) (context/cpu 1)])]
     (-> mod
-        (m/bind {:data-shapes [{:name "b" :shape [5 5] :layout "NT"}
-                               {:name "c" :shape [5 5] :layout "NT"}
-                               {:name "a" :shape [5 5] :layout "NT"}]
+        (m/bind {:data-shapes [{:name "b" :shape [5 5] :layout layout/NT}
+                               {:name "c" :shape [5 5] :layout layout/NT}
+                               {:name "a" :shape [5 5] :layout layout/NT}]
                  :inputs-need-grad true})
         (m/init-params)
         (m/forward {:data [(ndarray/ones [5 5])
@@ -172,7 +173,7 @@
             (sym/linear-regression-output "softmax" {:data v :grad-scale 2}))
 
         mod (m/module x)]
-    (m/bind mod {:data-shapes (mx-io/provide-data train-data) :label-shapes (mx-io/provide-label train-data)})
+    (m/bind mod {:data-shapes (mx-io/provide-data-desc train-data) :label-shapes (mx-io/provide-label train-data)})
 
     (let [arg-params-correct {"fc_0_weight" (ndarray/array [0.15 0.2 0.25 0.3] [2 2])
                               "fc_0_bias" (ndarray/array [0.35 0.35] [2])
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
index ecd54ca7277..d632c969eae 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
@@ -23,7 +23,7 @@
     (let [diff (Math/abs (- x y))]
       (< diff tolerance))
     (and
-    	(= (count x) (count y))
-		(reduce (fn [x y] (and x y))
-            (map #(approx= tolerance %1 %2) x y)))))
+     (= (count x) (count y))
+     (reduce (fn [x y] (and x y))
+             (map #(approx= tolerance %1 %2) x y)))))
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services