You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2019/03/10 23:47:11 UTC

[incubator-mxnet] branch master updated: [clojure-package][wip] add `->nd-vec` function in `ndarray.clj` (#14308)

This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8be97d7  [clojure-package][wip] add `->nd-vec` function in `ndarray.clj` (#14308)
8be97d7 is described below

commit 8be97d7a79f9ea9815e41956e5f15ddcf25026b6
Author: Arthur Caillau <Ch...@users.noreply.github.com>
AuthorDate: Mon Mar 11 00:46:50 2019 +0100

    [clojure-package][wip] add `->nd-vec` function in `ndarray.clj` (#14308)
    
    * [clojure-package][wip] add `->nd-vec` function in `ndarray.clj`
    
    * WIP
    * Unit tests need to be added
    
    * [clojure-package][ndarray] add unit tests for `->nd-vec` util fn
---
 .../src/org/apache/clojure_mxnet/ndarray.clj       | 58 +++++++++++++++++++---
 .../test/org/apache/clojure_mxnet/ndarray_test.clj | 12 +++++
 2 files changed, 64 insertions(+), 6 deletions(-)

diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
index 651bdcb..151e18b 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
@@ -16,15 +16,18 @@
 ;;
 
 (ns org.apache.clojure-mxnet.ndarray
+  "NDArray API for Clojure package."
   (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max
                             min repeat reverse set sort take to-array empty shuffle
                             ref])
-  (:require [org.apache.clojure-mxnet.base :as base]
-            [org.apache.clojure-mxnet.context :as mx-context]
-            [org.apache.clojure-mxnet.shape :as mx-shape]
-            [org.apache.clojure-mxnet.util :as util]
-            [clojure.reflect :as r]
-            [t6.from-scala.core :refer [$] :as $])
+  (:require
+    [clojure.spec.alpha :as s]
+
+    [org.apache.clojure-mxnet.base :as base]
+    [org.apache.clojure-mxnet.context :as mx-context]
+    [org.apache.clojure-mxnet.shape :as mx-shape]
+    [org.apache.clojure-mxnet.util :as util]
+    [t6.from-scala.core :refer [$] :as $])
   (:import (org.apache.mxnet NDArray)))
 
 ;; loads the generated functions into the namespace
@@ -167,3 +170,46 @@
 
 (defn shape-vec [ndarray]
   (mx-shape/->vec (shape ndarray)))
+
+(s/def ::ndarray #(instance? NDArray %)) ; any org.apache.mxnet.NDArray instance
+(s/def ::vector vector?) ; used as the :ret spec of the fdefs below
+(s/def ::sequential sequential?) ; vectors, lists and seqs all qualify
+(s/def ::shape-vec-match-vec ; validates a [v shape-vec] pair
+  (fn [[v vec-shape]] (= (count v) (reduce clojure.core/* 1 vec-shape)))) ; (count v) must equal the product of the dims (1 for an empty shape)
+
+(s/fdef vec->nd-vec
+        :args (s/cat :v ::sequential :shape-vec ::sequential)
+        :ret ::vector)
+
+(defn- vec->nd-vec
+  "Reshape the flat sequential `v` into a nested vector shaped by `shape-vec`.
+   (count v) must equal the product of `shape-vec` (enforced via util/validate!). Ex:
+    (vec->nd-vec [1 2 3] [1 1 3])       ;[[[1 2 3]]]
+    (vec->nd-vec [1 2 3 4 5 6] [2 3 1]) ;[[[1] [2] [3]] [[4] [5] [6]]]
+    (vec->nd-vec [1 2 3 4 5 6] [1 2 3]) ;[[[1 2 3] [4 5 6]]]
+    (vec->nd-vec [1 2 3 4 5 6] [3 1 2]) ;[[[1 2]] [[3 4]] [[5 6]]]
+    (vec->nd-vec [1 2 3 4 5 6] [3 2])   ;[[1 2] [3 4] [5 6]]"
+  [v [s1 & ss :as shape-vec]] ; s1: outermost dimension, ss: remaining dimensions
+  (util/validate! ::sequential v "Invalid input vector `v`")
+  (util/validate! ::sequential shape-vec "Invalid input vector `shape-vec`")
+  (util/validate! ::shape-vec-match-vec
+                  [v shape-vec]
+                  "Mismatch between vector `v` and vector `shape-vec`")
+  (if-not (seq ss) ; one dimension left: v itself is the innermost vector
+    (vec v)
+    (->> v
+         (partition (clojure.core// (count v) s1)) ; chunk = elements per outer slot. NOTE(review): zero-size dims misbehave (s1 = 0 divides by zero; shape [2 0] yields [] not [[] []]) - confirm they cannot occur
+         vec
+         (mapv #(vec->nd-vec % ss))))) ; recurse into each chunk with the remaining dims
+
+(s/fdef ->nd-vec :args (s/cat :ndarray ::ndarray) :ret ::vector)
+
+(defn ->nd-vec
+  "Convert `ndarray` into an n-dimensional (nested) Clojure vector that follows its shape.
+  Ex:
+    (->nd-vec (array [1] [1 1 1]))           ;[[[1.0]]]
+    (->nd-vec (array [1 2 3] [3 1 1]))       ;[[[1.0]] [[2.0]] [[3.0]]]
+    (->nd-vec (array [1 2 3 4 5 6] [3 1 2])) ;[[[1.0 2.0]] [[3.0 4.0]] [[5.0 6.0]]]"
+  [ndarray]
+  (util/validate! ::ndarray ndarray "Invalid input array") ; rejects anything that is not an NDArray
+  (vec->nd-vec (->vec ndarray) (shape-vec ndarray))) ; flat elements reshaped by the array's own shape
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
index 9ffd3ab..a9ae296 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
@@ -473,3 +473,15 @@
     (is (= [2 2] (ndarray/->int-vec nda)))
     (is (= [2.0 2.0] (ndarray/->double-vec nda)))
     (is (= [(byte 2) (byte 2)] (ndarray/->byte-vec nda)))))
+
+(deftest test->nd-vec ; ->nd-vec nests the flat values according to the array's shape
+  (is (= [[[1.0]]]
+         (ndarray/->nd-vec (ndarray/array [1] [1 1 1]))))
+  (is (= [[[1.0]] [[2.0]] [[3.0]]]
+         (ndarray/->nd-vec (ndarray/array [1 2 3] [3 1 1]))))
+  (is (= [[[1.0 2.0]] [[3.0 4.0]] [[5.0 6.0]]]
+         (ndarray/->nd-vec (ndarray/array [1 2 3 4 5 6] [3 1 2]))))
+  (is (= [[[1.0] [2.0]] [[3.0] [4.0]] [[5.0] [6.0]]]
+         (ndarray/->nd-vec (ndarray/array [1 2 3 4 5 6] [3 2 1]))))
+  (is (thrown-with-msg? Exception #"Invalid input array" ; a plain vector is not an NDArray
+                         (ndarray/->nd-vec [1 2 3 4 5]))))