You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2018/09/13 00:06:09 UTC
[incubator-mxnet] branch master updated: MXNET-873 - Bring Clojure
Package Inline with New DataDesc and Layout in Scala Package (#12387)
This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new b8153f6 MXNET-873 - Bring Clojure Package Inline with New DataDesc and Layout in Scala Package (#12387)
b8153f6 is described below
commit b8153f6449d5628491ef2ec4e10871643e8ba5c6
Author: Carin Meier <cm...@gigasquidsoftware.com>
AuthorDate: Wed Sep 12 20:05:56 2018 -0400
MXNET-873 - Bring Clojure Package Inline with New DataDesc and Layout in Scala Package (#12387)
* Bring clojure package inline with new DataDesc and Layout in Scala package
* formatting cljfmt
* revert the implementation of module fit back now that the DataDesc issue is fixed
- update Module example to use provide-data-desc and provide-label-desc
* update to provide-data-desc and provide-label-desc
* decrease epochs to speed example
* Add tests and docstrings
* remove let
---
.../examples/gan/src/gan/gan_mnist.clj | 6 +-
.../examples/module/src/mnist_mlp.clj | 2 +-
.../src/pre_trained_models/fine_tune.clj | 2 +-
.../src/pre_trained_models/predict_image.clj | 2 +-
.../examples/rnn/src/rnn/train_char_rnn.clj | 12 +--
.../src/org/apache/clojure_mxnet/io.clj | 88 +++++++++++++++-------
.../org/apache/clojure_mxnet/layout.clj} | 26 ++++---
.../src/org/apache/clojure_mxnet/module.clj | 58 ++------------
.../src/org/apache/clojure_mxnet/ndarray.clj | 2 +-
.../src/org/apache/clojure_mxnet/symbol.clj | 2 +-
.../test/org/apache/clojure_mxnet/io_test.clj | 53 ++++++++++++-
.../test/org/apache/clojure_mxnet/module_test.clj | 9 ++-
.../test/org/apache/clojure_mxnet/test_util.clj | 6 +-
13 files changed, 159 insertions(+), 109 deletions(-)
diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
index 14dd2c5..e2e3364 100644
--- a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
+++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
@@ -159,13 +159,13 @@
(defn train [devs]
(let [mod-d (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]})
- (m/bind {:data-shapes (mx-io/provide-data mnist-iter)
- :label-shapes (mx-io/provide-label mnist-iter)
+ (m/bind {:data-shapes (mx-io/provide-data-desc mnist-iter)
+ :label-shapes (mx-io/provide-label-desc mnist-iter)
:inputs-need-grad true})
(m/init-params {:initializer (init/normal 0.02)})
(m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))
mod-g (-> (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil})
- (m/bind {:data-shapes (mx-io/provide-data rand-noise-iter)})
+ (m/bind {:data-shapes (mx-io/provide-data-desc rand-noise-iter)})
(m/init-params {:initializer (init/normal 0.02)})
(m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))]
diff --git a/contrib/clojure-package/examples/module/src/mnist_mlp.clj b/contrib/clojure-package/examples/module/src/mnist_mlp.clj
index 74edf71..c5ffbbe 100644
--- a/contrib/clojure-package/examples/module/src/mnist_mlp.clj
+++ b/contrib/clojure-package/examples/module/src/mnist_mlp.clj
@@ -85,7 +85,7 @@
(m/module (get-symbol) {:contexts devs}))
metric (eval-metric/accuracy)]
(-> mod
- (m/bind {:data-shapes (mx-io/provide-data train-data) :label-shapes (mx-io/provide-label train-data)})
+ (m/bind {:data-shapes (mx-io/provide-data-desc train-data) :label-shapes (mx-io/provide-label-desc train-data)})
(m/init-params)
(m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.9})}))
diff --git a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj
index f2b9edd..93c121f 100644
--- a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj
+++ b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj
@@ -80,7 +80,7 @@
(defn fit [devs msymbol arg-params aux-params]
(let [mod (-> (m/module msymbol {:contexts devs})
- (m/bind {:data-shapes (mx-io/provide-data train-iter) :label-shapes (mx-io/provide-label val-iter)})
+ (m/bind {:data-shapes (mx-io/provide-data-desc train-iter) :label-shapes (mx-io/provide-label-desc val-iter)})
(m/init-params {:arg-params arg-params :aux-params aux-params
:allow-missing true}))]
(m/fit mod
diff --git a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj
index 12bdb12..71202bc 100644
--- a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj
+++ b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj
@@ -92,7 +92,7 @@
(comment
- (predict "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/python/predict_image/cat.jpg")
+ (predict "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/python/predict_image/cat.jpg" true)
;; ({:prob 0.69066674, :label "n02122948 kitten, kitty"}
;; {:prob 0.04466057, :label "n01323155 kit"}
;; {:prob 0.029682875, :label "n01318894 pet"}
diff --git a/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj b/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj
index 29aba26..150cd94 100644
--- a/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj
+++ b/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj
@@ -109,16 +109,16 @@
:label-name "softmax_label"
:data-batch-size batch-size
:last-batch-handle "pad"})
- data-and-labels (merge (data-desc->map (mx-io/provide-data train-iter))
- (data-desc->map (mx-io/provide-label train-iter))
+ data-and-labels (merge (data-desc->map (mx-io/provide-data-desc train-iter))
+ (data-desc->map (mx-io/provide-label-desc train-iter))
init-states)
init-states-data (mapv (fn [[k v]] (ndarray/zeros v {:ctx ctx})) init-states)
rnn-sym (sym-gen (first buckets))
rnn-mod (-> (m/module rnn-sym {:contexts devs})
- (m/bind {:data-shapes (into (mx-io/provide-data train-iter)
+ (m/bind {:data-shapes (into (mx-io/provide-data-desc train-iter)
(mapv (fn [[k v]] {:name k :shape v}) init-states))
- :label-shapes (mx-io/provide-label train-iter)})
+ :label-shapes (mx-io/provide-label-desc train-iter)})
(m/init-params {:initializer (init/xavier {:factor-type "in" :magnitude 2.34})})
(m/init-optimizer {:optimizer (optimizer/adam {:learning-rate learning-rate :wd 0.0001})}))
metric (eval-metric/custom-metric
@@ -141,8 +141,8 @@
"perplexity")]
- ;; Train for 2 epochs and then show the results of 75
- (doseq [epoch-num (range 2)]
+ ;; Train for 1 epoch and then show the results of 75
+ (doseq [epoch-num (range 1)]
(println "Doing epoch " epoch-num)
(mx-io/reduce-batches
train-iter
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj
index d6f1499..a2b6399 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj
@@ -17,11 +17,12 @@
(ns org.apache.clojure-mxnet.io
(:refer-clojure :exclude [next])
- (:require [org.apache.clojure-mxnet.base :as base]
+ (:require [clojure.spec.alpha :as s]
+ [org.apache.clojure-mxnet.base :as base]
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.dtype :as dtype]
- [clojure.spec.alpha :as s]
+ [org.apache.clojure-mxnet.layout :as layout]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.random :as random])
(:import (org.apache.mxnet IO DataDesc DataBatch NDArray)
@@ -57,18 +58,48 @@
(defn resize-iter [iter nbatch])
-(defn provide-data [pack-iterator]
+(defn provide-data
+ "Provides the description of the data iterator in the form of
+ [{:name name :shape shape-vec}]"
+ [pack-iterator]
(->> pack-iterator
(.provideData)
(util/scala-map->map)
(mapv (fn [[k v]] {:name k :shape (mx-shape/->vec v)}))))
-(defn provide-label [pack-iterator]
+(defn provide-label
+ "Provides the description of the label iterator in the form of
+ [{:name name :shape shape-vec}]"
+ [pack-iterator]
(->> pack-iterator
(.provideLabel)
(util/scala-map->map)
(mapv (fn [[k v]] {:name k :shape (mx-shape/->vec v)}))))
+(defn data-desc->map [dd]
+ {:name (.name dd)
+ :shape (mx-shape/->vec (.shape dd))
+ :dtype (.dtype dd)
+ :layout (.layout dd)})
+
+(defn provide-data-desc
+ "Provides the Data Desc of the data iterator in the form of
+ [{:name name :shape shape-vec :dtype dtype :layout layout}]"
+ [pack-iterator]
+ (->> pack-iterator
+ (.provideDataDesc)
+ (util/scala-vector->vec)
+ (mapv data-desc->map)))
+
+(defn provide-label-desc
+ "Provides the Data Desc of the label iterator in the form of
+ [{:name name :shape shape-vec :dtype dtype :layout layout}]"
+ [pack-iterator]
+ (->> pack-iterator
+ (.provideLabelDesc)
+ (util/scala-vector->vec)
+ (mapv data-desc->map)))
+
(defn reset [iterator]
(.reset iterator))
@@ -194,7 +225,8 @@
(defn ndarray-iter
" * NDArrayIter object in mxnet. Taking NDArray to get dataiter.
*
- * @param data vector of iter
+ * @param data vector of iter - Can either be in the form of [ndarray ...] or
+ * {data-desc0 ndarray0 data-desc1 ndarray1 ...}
* @opts map of:
* :label Same as data, but is not fed to the model during testing.
* :data-batch-size Batch Size (default 1)
@@ -213,14 +245,23 @@
last-batch-handle "pad"
data-name "data"
label-name "label"}}]
- (new NDArrayIter
- (util/vec->indexed-seq data)
- (if label (util/vec->indexed-seq label) (util/empty-indexed-seq))
- (int data-batch-size)
- shuffle
- last-batch-handle
- data-name
- label-name))
+ (if (map? data)
+ (new NDArrayIter
+ (.toIndexedSeq (util/list-map data))
+ (if label
+ (.toIndexedSeq (util/list-map label))
+ (util/empty-indexed-seq))
+ (int data-batch-size)
+ shuffle
+ last-batch-handle)
+ (new NDArrayIter
+ (util/vec->indexed-seq data)
+ (if label (util/vec->indexed-seq label) (util/empty-indexed-seq))
+ (int data-batch-size)
+ shuffle
+ last-batch-handle
+ data-name
+ label-name)))
([data]
(ndarray-iter data {})))
@@ -230,24 +271,19 @@
(s/def ::name string?)
(s/def ::shape vector?)
(s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})
+(s/def ::layout (s/or :custom string? :standard #{layout/UNDEFINED
+ layout/NCHW
+ layout/NTC
+ layout/NT
+ layout/N}))
(s/def ::data-desc (s/keys :req-un [::name ::shape] :opt-un [::dtype ::layout]))
-;; NCHW is N:batch size C: channel H: height W: width
-;;; other layouts are
-;; NT, TNC, nad N
-;; the shape length must match the lengh of the layout string size
(defn data-desc
([{:keys [name shape dtype layout] :as opts
- :or {dtype base/MX_REAL_TYPE}}]
+ :or {dtype base/MX_REAL_TYPE
+ layout layout/UNDEFINED}}]
(util/validate! ::data-desc opts "Invalid data description")
- (let [sc (count shape)
- layout (or layout (cond
- (= 1 sc) "N"
- (= 2 sc) "NT"
- (= 3 sc) "TNC"
- (= 4 sc) "NCHW"
- :else (apply str (repeat sc "?"))))]
- (new DataDesc name (mx-shape/->shape shape) dtype layout)))
+ (new DataDesc name (mx-shape/->shape shape) dtype layout))
([name shape]
(data-desc {:name name :shape shape})))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj
similarity index 65%
copy from contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
copy to contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj
index ecd54ca..f379a7a 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj
@@ -15,15 +15,21 @@
;; limitations under the License.
;;
-(ns org.apache.clojure-mxnet.test-util
- (:require [clojure.test :as t]))
+(ns org.apache.clojure-mxnet.layout
+ (:import (org.apache.mxnet Layout)))
-(defn approx= [tolerance x y]
- (if (and (number? x) (number? y))
- (let [diff (Math/abs (- x y))]
- (< diff tolerance))
- (and
- (= (count x) (count y))
- (reduce (fn [x y] (and x y))
- (map #(approx= tolerance %1 %2) x y)))))
+;;
+;; Layout definition of DataDesc
+;; N Batch size
+;; C channels
+;; H Height
+;; W Width
+;; T sequence length
+;; __undefined__ default value of Layout
+;;
+(def UNDEFINED (Layout/UNDEFINED)) ;=> "__undefined__"
+(def NCHW (Layout/NCHW)) ;=> "NCHW"
+(def NTC (Layout/NTC)) ;=> "NTC"
+(def NT (Layout/NT)) ;=> "NT"
+(def N (Layout/N)) ;=> "N"
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
index ab6d345..aa5ce39 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
@@ -88,8 +88,8 @@
can perform computation with the module.
mod : module
map of opts:
- :data-shapes Typically is (provide-data data-iter). Data shape must be in the form of io/data-desc with is a map of :name :shape :dtype and :layout
- :label-shapes Typically is (provide-label data-iter). map of :name :shape :dtype and :layout
+ :data-shapes Typically is (provide-data-desc data-iter). Data shape must be in the form of io/data-desc which is a map of :name :shape :dtype and :layout
+ :label-shapes Typically is (provide-label-desc data-iter). map of :name :shape :dtype and :layout
:for-training Default is `true`. Whether the executors should be bind for training.
:inputs-need-grad Default is `false`.
Whether the gradients to the input data need to be computed.
@@ -547,54 +547,12 @@
`:or {num-epoch 1
fit-params (new FitParams)}}]
(util/validate! ::fit-options opts "Invalid options for fit")
- (let [fmod (-> mod
- (bind {:data-shapes (mx-io/provide-data train-data)
- :label-shapes (mx-io/provide-label train-data)
- :for-training true
- :force-rebind (.forceRebind fit-params)})
- (init-params (remove (fn [[k v]] (nil? v))
- {:initializer (.initializer fit-params)
- :arg-params (.argParams fit-params)
- :aux-params (.auxParams fit-params)
- :allow-missing (.allowMissing fit-params)}))
- (init-optimizer (remove (fn [[k v]] (nil? v))
- {:optimizer (.optimizer fit-params)
- :kvstore (.kvstore fit-params)})))
- eval-metric (or (.evalMetric fit-params) (eval-metric/accuracy))
- val-metric (or (util/option->value (.validationMetric fit-params)) (eval-metric/accuracy))]
- (doseq [i (range num-epoch)]
- (let [tic (System/currentTimeMillis)]
- (mx-io/reduce-batches train-data
- (fn [batch-num batch]
- (-> fmod
- (forward batch)
- (backward)
- (update)
- (update-metric eval-metric (mx-io/batch-label batch)))
- (when-let [cb (util/option->value (.batchEndCallback fit-params))]
- (callback/invoke cb i batch-num eval-metric))
- (.dispose batch)
- (inc batch-num)))
- (println "Epoch " i " Train-" (eval-metric/get eval-metric))
- (println "Epoch " i " Time cost-" (- (System/currentTimeMillis) tic))
-
- ;;sync across kvstores
- (get-params fmod)
- (when-let [cb (util/option->value (.epochEndCallback fit-params))]
- (callback/invoke cb i 0 val-metric))
-
- ;; evaluation on the validation set
- (when eval-data
- (let [res (score fmod {:eval-data eval-data :eval-metric eval-metric :epoch i})]
- (println "Epoch " i " Validation- " res)))))
- fmod)
- ;; old way if the problem with the sizes get resolved in DataDesc
- #_(doto mod
- (.fit
- train-data
- (util/->option eval-data)
- (int num-epoch)
- fit-params)))
+ (doto mod
+ (.fit
+ train-data
+ (util/->option eval-data)
+ (int num-epoch)
+ fit-params)))
(s/def ::eval-data ::train-data)
(s/def ::num-batch integer?)
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
index 7ca4ede..e37a8bc 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
@@ -89,7 +89,7 @@
(NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype))
([start stop]
(arange start stop {})))
-
+
(defn slice
"Return a sliced NDArray that shares memory with current one."
([ndarray i]
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
index 12135fb..58b1d6d 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
@@ -144,7 +144,7 @@
which must be known from the rest of the net."
([start {:keys [step repeat dtype]
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
- :as opts}]
+ :as opts}]
(Symbol/arange (float start) ($/option nil) step repeat true nil dtype))
([start]
(arange-with-inference start {})))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj
index ace39ec..9babf1e 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj
@@ -22,7 +22,9 @@
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.shape :as mx-shape]
- [clojure.test :refer :all]))
+ [clojure.test :refer :all]
+ [org.apache.clojure-mxnet.dtype :as dtype]
+ [org.apache.clojure-mxnet.layout :as layout]))
(deftest test-mnsit-iter-and-mnist-pack
(let [_ (when-not (.exists (io/file "data/train-images-idx3-ubyte"))
@@ -59,6 +61,31 @@
(is (= label1 label0))
(is (= data1 data0))))))
+(deftest test-provide-data-and-label
+ (let [test-data (mx-io/mnist-iter {:image "data/train-images-idx3-ubyte"
+ :label "data/train-labels-idx1-ubyte"
+ :label-name "softmax_label"
+ :data-shape [1 28 28]
+ :label-shape [1 1 10]
+ :batch-size 100
+ :shuffle true
+ :flat false
+ :silent false
+ :seed 10})]
+ (is (= [{:name "data", :shape [100 1 28 28]}]
+ (mx-io/provide-data test-data)))
+ (is (= [{:name "softmax_label", :shape [100]}]
+ (mx-io/provide-label test-data)))
+ (is (= [{:name "data", :shape [100 1 28 28]
+ :dtype dtype/FLOAT32
+ :layout layout/UNDEFINED}]
+ (mx-io/provide-data-desc test-data)))
+ (is (= [{:name "softmax_label"
+ :shape [100]
+ :dtype dtype/FLOAT32
+ :layout layout/UNDEFINED}]
+ (mx-io/provide-label-desc test-data)))))
+
(deftest test-image-record-iter
(let [_ (when-not (.exists (io/file "data/cifar/train.rec"))
(sh "scripts/get_cifar_data.sh"))
@@ -162,4 +189,26 @@
:last-batch-handle "discard"})
nbatch2 7]
(is (= nbatch2 (mx-io/reduce-batches data-iter2 (fn [result batch] (inc result)))))
- (is (= [] (mx-io/iter-init-label data-iter2))))))
+ (is (= [] (mx-io/iter-init-label data-iter2))))
+
+ ;;; testing with a specified layout
+ (let [label-desc (mx-io/data-desc {:name "label"
+ :shape [2 2]
+ :dtype dtype/INT32
+ :layout layout/NT})
+ data-desc (mx-io/data-desc {:name "data"
+ :shape [2 2 2]
+ :dtype dtype/FLOAT32
+ :layout layout/NTC})
+ label (ndarray/ones [2 2] {:dtype dtype/INT32})
+ data (ndarray/ones [2 2 2] {:dtype dtype/FLOAT32})
+ data-iter3 (mx-io/ndarray-iter {data-desc data}
+ {:label {label-desc label}})]
+ (is (= {:dtype dtype/FLOAT32 :layout layout/NTC}
+ (-> (mx-io/provide-data-desc data-iter3)
+ first
+ (select-keys [:dtype :layout]))))
+ (is (= {:dtype dtype/INT32 :layout layout/NT}
+ (-> (mx-io/provide-label-desc data-iter3)
+ first
+ (select-keys [:dtype :layout])))))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj
index 0f71b5a..d53af2e 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj
@@ -20,6 +20,7 @@
[org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.io :as mx-io]
+ [org.apache.clojure-mxnet.layout :as layout]
[org.apache.clojure-mxnet.module :as m]
[org.apache.clojure-mxnet.monitor :as monitor]
[org.apache.clojure-mxnet.ndarray :as ndarray]
@@ -54,9 +55,9 @@
c (sym/+ a (sym/+ (sym/* b 2) (sym/* c 3)))
mod (m/module c ["b" "c" "a"] nil [(context/cpu 0) (context/cpu 1)])]
(-> mod
- (m/bind {:data-shapes [{:name "b" :shape [5 5] :layout "NT"}
- {:name "c" :shape [5 5] :layout "NT"}
- {:name "a" :shape [5 5] :layout "NT"}]
+ (m/bind {:data-shapes [{:name "b" :shape [5 5] :layout layout/NT}
+ {:name "c" :shape [5 5] :layout layout/NT}
+ {:name "a" :shape [5 5] :layout layout/NT}]
:inputs-need-grad true})
(m/init-params)
(m/forward {:data [(ndarray/ones [5 5])
@@ -172,7 +173,7 @@
(sym/linear-regression-output "softmax" {:data v :grad-scale 2}))
mod (m/module x)]
- (m/bind mod {:data-shapes (mx-io/provide-data train-data) :label-shapes (mx-io/provide-label train-data)})
+ (m/bind mod {:data-shapes (mx-io/provide-data-desc train-data) :label-shapes (mx-io/provide-label train-data)})
(let [arg-params-correct {"fc_0_weight" (ndarray/array [0.15 0.2 0.25 0.3] [2 2])
"fc_0_bias" (ndarray/array [0.35 0.35] [2])
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
index ecd54ca..d632c96 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
@@ -23,7 +23,7 @@
(let [diff (Math/abs (- x y))]
(< diff tolerance))
(and
- (= (count x) (count y))
- (reduce (fn [x y] (and x y))
- (map #(approx= tolerance %1 %2) x y)))))
+ (= (count x) (count y))
+ (reduce (fn [x y] (and x y))
+ (map #(approx= tolerance %1 %2) x y)))))