You are viewing a plain-text version of this content. (The canonical link to it was a hyperlink in the original page and is not preserved in this plain-text copy.)
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/21 13:06:40 UTC

[GitHub] paroda closed pull request #12630: [MXNET-12627] Fixed param coercion of clojure executor/forward

paroda closed pull request #12630: [MXNET-12627] Fixed param coercion of clojure executor/forward
URL: https://github.com/apache/incubator-mxnet/pull/12630
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (it would not otherwise be shown, because of how GitHub handles pull requests from forks):

diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj
index 4f4155e2d80..64857b3cb92 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj
@@ -18,6 +18,8 @@
 (ns org.apache.clojure-mxnet.executor
   (:require [org.apache.clojure-mxnet.util :as util]
             [clojure.reflect :as r]
+            [clojure.string :as str]
+            [t6.from-scala.core :as $]
             [org.apache.clojure-mxnet.ndarray :as ndarray]
             [org.apache.clojure-mxnet.shape :as mx-shape]))
 
@@ -26,6 +28,28 @@
 (defn ->vec [nd-array]
   (vec (.toArray nd-array)))
 
+(defn- coerce-map->tuple-seq
+  "* Convert a map to a scala-Seq of scala-Tubple.
+   * Should also work if a seq of seq of 2 things passed.
+   * Otherwise passed through unchanged."
+  [map-or-tuple-seq]
+  (letfn [(key->name [k]
+            (if (or (keyword? k) (string? k) (symbol? k))
+              (str/replace (name k) "-" "_")
+              k))
+          (->tuple [kvp-or-tuple]
+            (if (coll? kvp-or-tuple)
+              (let [[k v] kvp-or-tuple]
+                ($/tuple (key->name k) v))
+              ;; pass-through
+              kvp-or-tuple))]
+    (if (coll? map-or-tuple-seq)
+      (->> map-or-tuple-seq
+           (map ->tuple)
+           (apply $/immutable-list))
+      ;; pass-through
+      map-or-tuple-seq)))
+
 (defn forward
   "* Calculate the outputs specified by the binded symbol.
    * @param is-train whether this forward is for evaluation purpose.
@@ -34,7 +58,7 @@
    (do (.forward executor)
        executor))
   ([executor is-train kwargs]
-   (do (.forward executor is-train (util/nil-or-coerce-param kwargs #{"scala.collection.immutable.Map"})))))
+   (do (.forward executor is-train (coerce-map->tuple-seq kwargs)))))
 
 (defn backward
   "* Do backward pass to get the gradient of arguments.


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services