You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2018/08/31 15:18:39 UTC

[incubator-mxnet] 01/01: update rand-iter as well

This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch update-data-desc-clojure
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 4e653de2ae16e4e3217fafc413eb6e07c3acfc22
Author: gigasquid <cm...@gigasquidsoftware.com>
AuthorDate: Fri Aug 31 11:18:00 2018 -0400

    update rand-iter as well
---
 contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj   |  6 ++++++
 .../clojure-package/test/org/apache/clojure_mxnet/io_test.clj | 11 +++++++++--
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj
index 3463545..5674e9c 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj
@@ -340,7 +340,13 @@
              label-name]
        (provideData []
          (util/list-map {data-name (mx-shape/->vec (ndarray/shape (first data)))}))
+       (provideDataDesc [] ; one data-desc named data-name, shaped like the first ndarray in data
+         (util/vec->indexed-seq [(data-desc {:name data-name
+                                             :shape (mx-shape/->vec
+                                                     (ndarray/shape
+                                                      (first data)))})]))
        (provideLabel [] (util/empty-list-map))
+       (provideLabelDesc [] (util/vec->indexed-seq [])) ; no labels: empty indexed seq like provideDataDesc's, not a ListMap
        (hasNext [] true)
        (getData
          ([] (util/vec->indexed-seq [(random/normal 0 1 (mx-shape/->vec (ndarray/shape (first data))))])))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj
index 9babf1e..7eef73c 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj
@@ -202,8 +202,8 @@
                                       :layout layout/NTC})
           label (ndarray/ones [2 2] {:dtype dtype/INT32})
           data (ndarray/ones [2 2 2] {:dtype dtype/FLOAT32})
-          data-iter3 (mx-io/ndarray-iter {data-desc data}
-                                         {:label {label-desc label}})]
+          data-iter3 (mx-io/ndarray-iter [{data-desc data}]
+                                         {:label [{label-desc label}]})]
       (is (= {:dtype dtype/FLOAT32 :layout layout/NTC}
              (-> (mx-io/provide-data-desc data-iter3)
                  first
@@ -212,3 +212,10 @@
              (-> (mx-io/provide-label-desc data-iter3)
                  first
                  (select-keys [:dtype :layout])))))))
+
+(deftest test-rand-iter ; checks provided data/label shapes and the data descriptor of a rand-iter
+  (let [rand-iter (mx-io/rand-iter [3 100 1 1])]
+    (is (= [{:name "rand", :shape [3 100 1 1]}]
+           (mx-io/provide-data rand-iter)))
+    (is (= [] (mx-io/provide-label rand-iter)))
+    (is (= {:name "rand" :shape [3 100 1 1]} (-> (mx-io/provide-data-desc rand-iter) first (select-keys [:name :shape]))))))