You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2019/07/19 17:24:50 UTC

[incubator-mxnet] 03/05: basic autoencoder

This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch clojure-autoencoder
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit ef3451a38c99b116fddead55eff8145b0f73bafb
Author: gigasquid <cm...@gigasquidsoftware.com>
AuthorDate: Fri Jul 19 10:08:08 2019 -0400

    basic autoencoder
---
 .../examples/gan/src/gan/auto_encoder.clj          | 56 ++++++++++++----------
 1 file changed, 32 insertions(+), 24 deletions(-)

diff --git a/contrib/clojure-package/examples/gan/src/gan/auto_encoder.clj b/contrib/clojure-package/examples/gan/src/gan/auto_encoder.clj
index ee1a9a6..db44b6a 100644
--- a/contrib/clojure-package/examples/gan/src/gan/auto_encoder.clj
+++ b/contrib/clojure-package/examples/gan/src/gan/auto_encoder.clj
@@ -71,6 +71,29 @@
 
     ))
 
+(def data-desc "Descriptor of the first data input provided by `train-data`; used below for both the module's data and label shapes." (-> train-data mx-io/provide-data-desc first))
+
+(def model (-> (m/module (get-symbol) {:data-names ["input"] :label-names ["input_"]}) ; label "input_" is the reconstruction target
+               (m/bind {:data-shapes [(assoc data-desc :name "input")]
+                        :label-shapes [(assoc data-desc :name "input_")]}) ; label bound with the input's own shape
+               (m/init-params {:initializer  (initializer/uniform 1)})
+               (m/init-optimizer {:optimizer (optimizer/adam {:learning-rate 0.001})}))) ; adam reads :learning-rate; an unrecognized key would not set the rate
+
+(def my-metric "Mean-squared-error metric; `train` updates it on every batch and reads/resets it after every epoch." (eval-metric/mse))
+
+(defn train
+  "Runs `num-epochs` passes over `train-data`, feeding each batch to `model` as both data and label, then prints and resets `my-metric` per epoch."
+  [num-epochs]
+  (doseq [epoch (range num-epochs)]
+    (println "starting epoch " epoch)
+    (mx-io/do-batches train-data
+                      (fn [batch]
+                        (let [input (mx-io/batch-data batch)]
+                          (-> model (m/forward {:data input :label input})
+                              (m/update-metric my-metric input)
+                              m/backward m/update))))
+    (println "result for epoch " epoch " is " (eval-metric/get-and-reset my-metric))))
+
 (comment
 
   (mx-io/provide-data train-data)
@@ -79,38 +102,23 @@
   (def my-batch (mx-io/next train-data))
   (def images (mx-io/batch-data my-batch))
   (ndarray/shape (ndarray/reshape (first images) [100 1 28 28]))
-  (viz/im-sav {:title "first" :output-path "results/" :x (first images)})
-  (viz/im-sav {:title "cm-first" :output-path "results/" :x (ndarray/reshape (first images) [100 1 28 28])})
+  (viz/im-sav {:title "originals" :output-path "results/" :x (ndarray/reshape (first images) [100 1 28 28])})
+
 
+  (train 1)
 
-  (def preds (m/predict-batch my-mod {:data images} ))
-  (ndarray/shape (ndarray/reshape (first preds) [100 1 28 28]))
-    (viz/im-sav {:title "cm-preds" :output-path "results/" :x (ndarray/reshape (first preds) [100 1 28 28])})
   
-  (def my-metric (eval-metric/mse))
+  (def preds (m/predict-batch model {:data images} ))
+  (viz/im-sav {:title "preds" :output-path "results/" :x (ndarray/reshape (first preds) [100 1 28 28])})
+  
 
 
   (sym/list-arguments (m/symbol my-mod))
   (def data-desc (first (mx-io/provide-data-desc train-data)))
 
-  (def my-mod (-> (m/module (get-symbol) {:data-names ["input"] :label-names ["input_"]})
-                  (m/bind {:data-shapes [(assoc data-desc :name "input")]
-                           :label-shapes [(assoc data-desc :name "input_")]})
-                  (m/init-params {:initializer  (initializer/uniform 1)})
-                  (m/init-optimizer {:optimizer (optimizer/adam {:learning-rage 0.001})})))
-
-
-  (doseq [epoch-num (range 0 1)]
-      (println "starting epoch " epoch-num)
-      (mx-io/do-batches
-       train-data
-       (fn [batch]
-         (-> my-mod
-             (m/forward {:data (mx-io/batch-data batch) :label (mx-io/batch-data batch)})
-             (m/update-metric my-metric (mx-io/batch-data batch))
-             (m/backward)
-             (m/update))))
-      (println "result for epoch " epoch-num " is " (eval-metric/get-and-reset my-metric)))
+
+
+