You are viewing a plain-text version of this content. The canonical link to the original was a hyperlink that was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2018/11/13 23:41:30 UTC

[incubator-mxnet] branch can-you-gan updated (d51230a -> a4fb107)

This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a change to branch can-you-gan
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git.


    from d51230a  flan 256x256
     new 837c3f1  128x128
     new a4fb107  add a load /save model

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../examples/gan/src/gan/gan_mnist.clj             | 66 +++++++++++++++-------
 1 file changed, 47 insertions(+), 19 deletions(-)


[incubator-mxnet] 01/02: 128x128

Posted by cm...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch can-you-gan
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 837c3f1ccf0fe88ba085acb347172234a0bc2399
Author: gigasquid <cm...@gigasquidsoftware.com>
AuthorDate: Mon Nov 5 17:08:10 2018 -0500

    128x128
---
 .../examples/gan/src/gan/gan_mnist.clj                | 19 +++++--------------
 1 file changed, 5 insertions(+), 14 deletions(-)

diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
index 7460f40..bd53946 100644
--- a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
+++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
@@ -55,8 +55,8 @@
                                        :batch-size batch-size
                                          :shuffle true}))
 
-(def flan-iter (mx-io/image-record-iter {:path-imgrec "flan-256.rec"
-                                         :data-shape [3 256 256]
+(def flan-iter (mx-io/image-record-iter {:path-imgrec "flan-128.rec"
+                                         :data-shape [3 128 128]
                                          :batch-size batch-size}))
 
 (defn normalize-rgb [x]
@@ -138,9 +138,8 @@
   (conv-output-size 8 4 1 2) ;=> 4.0
   (conv-output-size 4 4 0 1) ;=> 1
 
-  ;;;; for 256
-  (conv-output-size 256 4 3 2) ;=> 130
-  (conv-output-size 130 4 2 2) ;=> 67
+  ;;;; for 128
+  (conv-output-size 128 4 3 2) ;=> 66
   (conv-output-size 66 4 2 2) ;=> 34.0
   (conv-output-size 34 4 0 2) ;=> 16
   (conv-output-size 16 4 1 2) ;=> 8
@@ -170,11 +169,7 @@
 
 (defn discriminator []
   (as-> (sym/variable "data") data
-    (sym/convolution "d1" {:data data :kernel [4 4] :pad [3 3] :stride [2 2] :num-filter ndf :no-bias true})
-    (sym/batch-norm "dbn1" {:data data :fix-gamma true :eps eps})
-    (sym/leaky-re-lu "dact1" {:data data :act-type "leaky" :slope 0.2})
-
-    (sym/convolution "d2" {:data data :kernel [4 4] :pad [2 2] :stride [2 2] :num-filter (* 2 ndf) :no-bias true})
+    (sym/convolution "d2" {:data data :kernel [4 4] :pad [3 3] :stride [2 2] :num-filter (* 2 ndf) :no-bias true})
     (sym/batch-norm "dbn2" {:data data :fix-gamma true :eps eps})
     (sym/leaky-re-lu "dact1" {:data data :act_type "leaky" :slope 0.2})
 
@@ -222,10 +217,6 @@
     (sym/batch-norm "gbn5" {:data data :fix-gamma true :eps eps})
     (sym/activation "gact5" {:data data :act-type "relu"})
 
-    (sym/deconvolution "g6" {:data data :kernel [4 4] :pad [2 2] :stride [2 2] :num-filter ndf :no-bias true})
-    (sym/batch-norm "gbn6" {:data data :fix-gamma true :eps eps})
-    (sym/activation "gact6" {:data data :act-type "relu"})
-
     (sym/deconvolution "g7" {:data data :kernel [4 4] :pad [3 3] :stride [2 2] :num-filter nc :no-bias true})
     (sym/activation "gact7" {:data data :act-type "tanh"})))
 


[incubator-mxnet] 02/02: add a load /save model

Posted by cm...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch can-you-gan
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit a4fb1074ab16d68c412ccd210651b9411270b363
Author: gigasquid <cm...@gigasquidsoftware.com>
AuthorDate: Tue Nov 13 17:32:43 2018 -0500

    add a load /save model
---
 .../examples/gan/src/gan/gan_mnist.clj             | 47 +++++++++++++++++++---
 1 file changed, 42 insertions(+), 5 deletions(-)

diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
index bd53946..ac2293c 100644
--- a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
+++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
@@ -40,10 +40,23 @@
 (def data-dir "data/")
 (def output-path "results/")
 (def batch-size 5)
-(def num-epoch 100)
+(def num-epoch 30)
 
 (io/make-parents (str output-path "gout"))
 
+(defn last-saved-model-number []
+  (some->> "."
+           clojure.java.io/file
+           file-seq
+           (filter #(.isFile %))
+           (map #(.getName %))
+           (filter #(clojure.string/includes? %  "model-d"))
+           reverse
+           first
+           (re-seq #"\d{4}")
+           first
+           Integer/parseInt))
+
 
 
 #_(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte")))
@@ -267,19 +280,41 @@
       (save-img-diff i n calc-diff))))
 
 (defn train [devs]
-  (let [mod-d  (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]})
+  (let [last-train-num (last-saved-model-number)
+        _ (println "The last saved trained epoch is " last-train-num)
+        mod-d  (-> (if last-train-num
+                     (do
+                       (println "Loading discriminator from checkpoint of epoch " last-train-num)
+                       (m/load-checkpoint {:contexts devs
+                                           :data-names ["data"]
+                                           :label-names ["label"]
+                                           :prefix "model-d"
+                                           :epoch last-train-num
+                                           :load-optimizer-states true}))
+                      (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]}))
                    (m/bind {:data-shapes (mx-io/provide-data flan-iter)
                             :label-shapes (mx-io/provide-label flan-iter)
                             :inputs-need-grad true})
                    (m/init-params {:initializer (init/normal 0.02)})
                    (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))
-        mod-g (-> (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil})
+        mod-g (-> (if last-train-num
+                    (do
+                     (println "Loading generator from checkpoint of epoch " last-train-num)
+                     (m/load-checkpoint {:contexts devs
+                                         :data-names ["rand"]
+                                         :label-names [""]
+                                         :prefix "model-g"
+                                         :epoch last-train-num
+                                         :load-optimizer-states true}))
+                    (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil}))
                   (m/bind {:data-shapes (mx-io/provide-data rand-noise-iter)})
                   (m/init-params {:initializer (init/normal 0.02)})
                   (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))]
 
     (println "Training for " num-epoch " epochs...")
-    (doseq [i (range num-epoch)]
+    (doseq [i (if last-train-num
+                (range (inc last-train-num) (inc (+ last-train-num num-epoch)))
+                (range num-epoch))]
       (mx-io/reduce-batches flan-iter
                             (fn [n batch]
                               (let [rbatch (mx-io/next rand-noise-iter)
@@ -312,7 +347,9 @@
                                   (println "iteration = " i  "number = " n)
                                   (save-img-gout i n (ndarray/copy (ffirst out-g)))
                                   (save-img-data i n (first dbatch))
-                                  (calc-diff i n (ffirst diff-d)))
+                                  (calc-diff i n (ffirst diff-d))
+                                  (m/save-checkpoint mod-g {:prefix "model-g" :epoch i :save-opt-states true})
+                                  (m/save-checkpoint mod-d {:prefix "model-d" :epoch i :save-opt-states true}))
                                 (inc n)))))))
 
 (defn -main [& args]