You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2018/09/21 23:22:47 UTC

[incubator-mxnet] branch master updated: Fixed param coercion of clojure executor/forward (#12627) (#12630)

This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 846bda4   Fixed param coercion of clojure executor/forward (#12627) (#12630)
846bda4 is described below

commit 846bda4a5dbee5bc82ad71e5695d2fa380dffd97
Author: paroda <pa...@gmail.com>
AuthorDate: Sat Sep 22 04:52:32 2018 +0530

     Fixed param coercion of clojure executor/forward (#12627) (#12630)
---
 .../src/org/apache/clojure_mxnet/executor.clj       |  3 ++-
 .../src/org/apache/clojure_mxnet/util.clj           | 21 +++++++++++++++++++++
 .../test/org/apache/clojure_mxnet/executor_test.clj | 20 ++++++++++++++++++++
 .../test/org/apache/clojure_mxnet/util_test.clj     | 17 ++++++++++++++++-
 4 files changed, 59 insertions(+), 2 deletions(-)

diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj
index 4f4155e..b9883f7 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj
@@ -34,7 +34,8 @@
    (do (.forward executor)
        executor))
   ([executor is-train kwargs]
-   (do (.forward executor is-train (util/nil-or-coerce-param kwargs #{"scala.collection.immutable.Map"})))))
+   (do (.forward executor is-train (util/map->scala-tuple-seq kwargs))
+       executor)))
 
 (defn backward
   "* Do backward pass to get the gradient of arguments.
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
index 8f2bb3b..6f22b0e 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
@@ -204,3 +204,24 @@
     (throw (ex-info error-msg
                     (s/explain-data spec value)))))
 
+(defn map->scala-tuple-seq
+  "* Convert a map to a scala-Seq of scala-Tuple.
+   * Should also work if a seq of 2-element seqs is passed.
+   * Otherwise passed through unchanged."
+  [map-or-tuple-seq]
+  (letfn [(key->name [k]
+            (if (or (keyword? k) (string? k) (symbol? k))
+              (string/replace (name k) "-" "_")
+              k))
+          (->tuple [kvp-or-tuple]
+            (if (coll? kvp-or-tuple)
+              (let [[k v] kvp-or-tuple]
+                ($/tuple (key->name k) v))
+              ;; pass-through
+              kvp-or-tuple))]
+    (if (coll? map-or-tuple-seq)
+      (->> map-or-tuple-seq
+           (map ->tuple)
+           (apply $/immutable-list))
+      ;; pass-through
+      map-or-tuple-seq)))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj
index b2a87d4..fb73f00 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj
@@ -74,3 +74,23 @@
     (is (every? #(= 4.0 %) (->> (executor/outputs exec)
                                 (map ndarray/->vec)
                                 first)))))
+
+(deftest test-forward
+  (let [a (sym/variable "a")
+        b (sym/variable "b")
+        c (sym/+ a b)
+        ex (sym/bind c {:a (ndarray/* (ndarray/ones [1 2]) 2)
+                        :b (ndarray/* (ndarray/ones [1 2]) 3)})]
+    ;; test forward with bound values
+    (executor/forward ex)
+    (is (= [5.0 5.0] (-> ex executor/outputs first ndarray/->vec)))
+    ;; test forward with new a (b is still [3.0 3.0])
+    (executor/forward ex false {:a (ndarray/* (ndarray/ones [1 2]) 4)})
+    (is (= [7.0 7.0] (-> ex executor/outputs first ndarray/->vec)))
+    ;; test forward with new b (a is still [4.0 4.0])
+    (executor/forward ex false {:b (ndarray/* (ndarray/ones [1 2]) 5)})
+    (is (= [9.0 9.0] (-> ex executor/outputs first ndarray/->vec)))
+    ;; test forward with new a & b
+    (executor/forward ex false {:a (ndarray/* (ndarray/ones [1 2]) 6)
+                                :b (ndarray/* (ndarray/ones [1 2]) 7)})
+    (is (= [13.0 13.0] (-> ex executor/outputs first ndarray/->vec)))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
index de34808..ee77103 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
@@ -190,4 +190,19 @@
         data2 [1 1 1 1 9 9 9 9]
         data3 [1 1 1 2]]
     (is (not (test-util/approx= 1e-9 data1 data2)))
-    (is (test-util/approx= 2 data1 data3))))
\ No newline at end of file
+    (is (test-util/approx= 2 data1 data3))))
+
+(deftest test-map->scala-tuple-seq
+  ;; convert as much as possible, and pass through the rest
+  (is (nil? (util/map->scala-tuple-seq nil)))
+  (is (= "List()"
+         (str (util/map->scala-tuple-seq {}))
+         (str (util/map->scala-tuple-seq []))
+         (str (util/map->scala-tuple-seq '()))))
+  (is (= "List(a, b)" (str (util/map->scala-tuple-seq ["a" "b"]))))
+  (is (= "List((a,b), (c,d), (e,f), (a_b,g), (c_d,h), (e_f,i))"
+         (str (util/map->scala-tuple-seq {:a "b", 'c "d", "e" "f"
+                                          :a-b "g", 'c-d "h", "e-f" "i"}))))
+  (let [nda (util/map->scala-tuple-seq {:a-b (ndarray/ones [1 2])})]
+    (is (= "a_b" (._1 (.head nda))))
+    (is (= [1.0 1.0] (ndarray/->vec (._2 (.head nda)))))))