You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2018/11/13 23:41:32 UTC

[incubator-mxnet] 02/02: add a load /save model

This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch can-you-gan
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit a4fb1074ab16d68c412ccd210651b9411270b363
Author: gigasquid <cm...@gigasquidsoftware.com>
AuthorDate: Tue Nov 13 17:32:43 2018 -0500

    add a load /save model
---
 .../examples/gan/src/gan/gan_mnist.clj             | 47 +++++++++++++++++++---
 1 file changed, 42 insertions(+), 5 deletions(-)

diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
index bd53946..ac2293c 100644
--- a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
+++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
@@ -40,10 +40,23 @@
 (def data-dir "data/")
 (def output-path "results/")
 (def batch-size 5)
-(def num-epoch 100)
+(def num-epoch 30) ; epochs trained per run; a resumed run trains this many more after the last saved epoch
 
 (io/make-parents (str output-path "gout"))
 
+(defn last-saved-model-number
+  "Returns the highest epoch N for which a ./model-d-NNNN.params checkpoint
+  exists, or nil when there is none. Only the working directory is scanned,
+  and the number is parsed from .params names, never from model-d-symbol.json."
+  []
+  (some->> (.listFiles (clojure.java.io/file "."))
+           (filter #(.isFile %))
+           ;; e.g. "model-d-0007.params" -> ["model-d-0007.params" "0007"]
+           (keep #(re-matches #"model-d-(\d+)\.params" (.getName %)))
+           (map #(Integer/parseInt (second %)))
+           seq
+           (apply max)))
+
 
 
 #_(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte")))
@@ -267,19 +280,41 @@
       (save-img-diff i n calc-diff))))
 
 (defn train [devs]
-  (let [mod-d  (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]})
+  (let [last-train-num (last-saved-model-number)
+        _ (println "The last saved trained epoch is " last-train-num)
+        mod-d  (-> (if last-train-num
+                     (do
+                       (println "Loading discriminator from checkpoint of epoch " last-train-num)
+                       (m/load-checkpoint {:contexts devs
+                                           :data-names ["data"]
+                                           :label-names ["label"]
+                                           :prefix "model-d"
+                                           :epoch last-train-num
+                                           :load-optimizer-states true}))
+                      (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]}))
                    (m/bind {:data-shapes (mx-io/provide-data flan-iter)
                             :label-shapes (mx-io/provide-label flan-iter)
                             :inputs-need-grad true})
                    (m/init-params {:initializer (init/normal 0.02)})
                    (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))
-        mod-g (-> (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil})
+        mod-g (-> (if last-train-num
+                    (do
+                     (println "Loading generator from checkpoint of epoch " last-train-num)
+                     (m/load-checkpoint {:contexts devs
+                                         :data-names ["rand"]
+                                         :label-names [""]
+                                         :prefix "model-g"
+                                         :epoch last-train-num
+                                         :load-optimizer-states true}))
+                    (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil}))
                   (m/bind {:data-shapes (mx-io/provide-data rand-noise-iter)})
                   (m/init-params {:initializer (init/normal 0.02)})
                   (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))]
 
     (println "Training for " num-epoch " epochs...")
-    (doseq [i (range num-epoch)]
+    (doseq [i (if last-train-num
+                (range (inc last-train-num) (inc (+ last-train-num num-epoch)))
+                (range num-epoch))]
       (mx-io/reduce-batches flan-iter
                             (fn [n batch]
                               (let [rbatch (mx-io/next rand-noise-iter)
@@ -312,7 +347,9 @@
                                   (println "iteration = " i  "number = " n)
                                   (save-img-gout i n (ndarray/copy (ffirst out-g)))
                                   (save-img-data i n (first dbatch))
-                                  (calc-diff i n (ffirst diff-d)))
+                                  (calc-diff i n (ffirst diff-d))
+                                  (m/save-checkpoint mod-g {:prefix "model-g" :epoch i :save-opt-states true})
+                                  (m/save-checkpoint mod-d {:prefix "model-d" :epoch i :save-opt-states true}))
                                 (inc n)))))))
 
 (defn -main [& args]