You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2018/09/21 23:22:47 UTC
[incubator-mxnet] branch master updated: Fixed param coercion of
clojure executor/forward (#12627) (#12630)
This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 846bda4 Fixed param coercion of clojure executor/forward (#12627) (#12630)
846bda4 is described below
commit 846bda4a5dbee5bc82ad71e5695d2fa380dffd97
Author: paroda <pa...@gmail.com>
AuthorDate: Sat Sep 22 04:52:32 2018 +0530
Fixed param coercion of clojure executor/forward (#12627) (#12630)
---
.../src/org/apache/clojure_mxnet/executor.clj | 3 ++-
.../src/org/apache/clojure_mxnet/util.clj | 21 +++++++++++++++++++++
.../test/org/apache/clojure_mxnet/executor_test.clj | 20 ++++++++++++++++++++
.../test/org/apache/clojure_mxnet/util_test.clj | 17 ++++++++++++++++-
4 files changed, 59 insertions(+), 2 deletions(-)
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj
index 4f4155e..b9883f7 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj
@@ -34,7 +34,8 @@
(do (.forward executor)
executor))
([executor is-train kwargs]
- (do (.forward executor is-train (util/nil-or-coerce-param kwargs #{"scala.collection.immutable.Map"})))))
+ (do (.forward executor is-train (util/map->scala-tuple-seq kwargs))
+ executor)))
(defn backward
"* Do backward pass to get the gradient of arguments.
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
index 8f2bb3b..6f22b0e 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
@@ -204,3 +204,24 @@
(throw (ex-info error-msg
(s/explain-data spec value)))))
+(defn map->scala-tuple-seq
+  "* Convert a map (or a seq of [k v] pairs) to a scala immutable List of Tuple2s.
+   * Keyword/symbol/string keys are named with - replaced by _ (:a-b -> a_b).
+   * Non-collection input, and non-collection elements, pass through unchanged."
+  [map-or-tuple-seq]
+  (letfn [(key->name [k]
+            (if (or (keyword? k) (string? k) (symbol? k))
+              (string/replace (name k) "-" "_") ; e.g. :a-b -> "a_b"
+              k))
+          (->tuple [kvp-or-tuple]
+            (if (coll? kvp-or-tuple)
+              (let [[k v] kvp-or-tuple]
+                ($/tuple (key->name k) v))
+              ;; pass-through
+              kvp-or-tuple))]
+    (if (coll? map-or-tuple-seq)
+      (->> map-or-tuple-seq
+           (map ->tuple)
+           (apply $/immutable-list))
+      ;; pass-through (e.g. nil)
+      map-or-tuple-seq)))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj
index b2a87d4..fb73f00 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj
@@ -74,3 +74,23 @@
(is (every? #(= 4.0 %) (->> (executor/outputs exec)
(map ndarray/->vec)
first)))))
+
+(deftest test-forward
+  (let [a (sym/variable "a")
+        b (sym/variable "b")
+        c (sym/+ a b)
+        ex (sym/bind c {:a (ndarray/* (ndarray/ones [1 2]) 2)
+                        :b (ndarray/* (ndarray/ones [1 2]) 3)})]
+    ;; forward with the bound values; every arity returns the executor
+    (is (identical? ex (executor/forward ex)))
+    (is (= [5.0 5.0] (-> ex executor/outputs first ndarray/->vec)))
+    ;; forward with a new a (b is still [3.0 3.0])
+    (is (identical? ex (executor/forward ex false {:a (ndarray/* (ndarray/ones [1 2]) 4)})))
+    (is (= [7.0 7.0] (-> ex executor/outputs first ndarray/->vec)))
+    ;; forward with a new b (a is still [4.0 4.0])
+    (is (identical? ex (executor/forward ex false {:b (ndarray/* (ndarray/ones [1 2]) 5)})))
+    (is (= [9.0 9.0] (-> ex executor/outputs first ndarray/->vec)))
+    ;; forward with both a new a and a new b
+    (is (identical? ex (executor/forward ex false {:a (ndarray/* (ndarray/ones [1 2]) 6)
+                                                   :b (ndarray/* (ndarray/ones [1 2]) 7)})))
+    (is (= [13.0 13.0] (-> ex executor/outputs first ndarray/->vec)))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
index de34808..ee77103 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
@@ -190,4 +190,19 @@
data2 [1 1 1 1 9 9 9 9]
data3 [1 1 1 2]]
(is (not (test-util/approx= 1e-9 data1 data2)))
- (is (test-util/approx= 2 data1 data3))))
\ No newline at end of file
+ (is (test-util/approx= 2 data1 data3))))
+
+(deftest test-map->scala-tuple-seq
+  ;; maps and seqs of [k v] pairs are converted; anything else passes through
+  (is (nil? (util/map->scala-tuple-seq nil)))
+  (is (apply = "List()"
+             (map (comp str util/map->scala-tuple-seq) [{} [] '()])))
+  (is (= "List(a, b)" (str (util/map->scala-tuple-seq ["a" "b"]))))
+  (is (= "List((a_b,1), (c,2))" (str (util/map->scala-tuple-seq [[:a-b 1] ['c 2]]))))
+  ;; keyword, symbol and string keys all become names with - replaced by _
+  (is (= "List((a,b), (c,d), (e,f), (a_b,g), (c_d,h), (e_f,i))"
+         (str (util/map->scala-tuple-seq {:a "b", 'c "d", "e" "f"
+                                          :a-b "g", 'c-d "h", "e-f" "i"}))))
+  (let [tuples (util/map->scala-tuple-seq {:a-b (ndarray/ones [1 2])})]
+    (is (= "a_b" (-> tuples .head ._1)))
+    (is (= [1.0 1.0] (-> tuples .head ._2 ndarray/->vec)))))