You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2018/11/13 23:41:32 UTC
[incubator-mxnet] 02/02: add a load /save model
This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch can-you-gan
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit a4fb1074ab16d68c412ccd210651b9411270b363
Author: gigasquid <cm...@gigasquidsoftware.com>
AuthorDate: Tue Nov 13 17:32:43 2018 -0500
add a load /save model
---
.../examples/gan/src/gan/gan_mnist.clj | 47 +++++++++++++++++++---
1 file changed, 42 insertions(+), 5 deletions(-)
diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
index bd53946..ac2293c 100644
--- a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
+++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
@@ -40,10 +40,23 @@
(def data-dir "data/")
(def output-path "results/")
(def batch-size 5)
-(def num-epoch 100)
+(def num-epoch 30)
(io/make-parents (str output-path "gout"))
+(defn last-saved-model-number
+ "Returns the highest epoch number found among saved discriminator
+ checkpoint files (regular files whose name contains \"model-d\" and a
+ 4-digit run) anywhere under the current directory, or nil when no such
+ file exists."
+ []
+ (some->> "."
+ clojure.java.io/file
+ file-seq
+ (filter #(.isFile %))
+ (map #(.getName %))
+ (filter #(clojure.string/includes? % "model-d"))
+ ;; Names that contain no 4-digit run are dropped here instead of
+ ;; aborting the whole lookup.
+ (keep #(re-find #"\d{4}" %))
+ (map #(Integer/parseInt %))
+ ;; file-seq order is unspecified, so take the maximum epoch rather than
+ ;; whichever file happened to be listed last. seq turns an empty result
+ ;; into nil so that some->> short-circuits and nil is returned.
+ seq
+ (apply max)))
+
#_(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte")))
@@ -267,19 +280,41 @@
(save-img-diff i n calc-diff))))
(defn train [devs]
- (let [mod-d (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]})
+ (let [last-train-num (last-saved-model-number)
+ _ (println "The last saved trained epoch is " last-train-num)
+ mod-d (-> (if last-train-num
+ (do
+ (println "Loading discriminator from checkpoint of epoch " last-train-num)
+ (m/load-checkpoint {:contexts devs
+ :data-names ["data"]
+ :label-names ["label"]
+ :prefix "model-d"
+ :epoch last-train-num
+ :load-optimizer-states true}))
+ (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]}))
(m/bind {:data-shapes (mx-io/provide-data flan-iter)
:label-shapes (mx-io/provide-label flan-iter)
:inputs-need-grad true})
(m/init-params {:initializer (init/normal 0.02)})
(m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))
- mod-g (-> (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil})
+ mod-g (-> (if last-train-num
+ (do
+ (println "Loading generator from checkpoint of epoch " last-train-num)
+ (m/load-checkpoint {:contexts devs
+ :data-names ["rand"]
+ :label-names [""]
+ :prefix "model-g"
+ :epoch last-train-num
+ :load-optimizer-states true}))
+ (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil}))
(m/bind {:data-shapes (mx-io/provide-data rand-noise-iter)})
(m/init-params {:initializer (init/normal 0.02)})
(m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))]
(println "Training for " num-epoch " epochs...")
- (doseq [i (range num-epoch)]
+ (doseq [i (if last-train-num
+ (range (inc last-train-num) (inc (+ last-train-num num-epoch)))
+ (range num-epoch))]
(mx-io/reduce-batches flan-iter
(fn [n batch]
(let [rbatch (mx-io/next rand-noise-iter)
@@ -312,7 +347,9 @@
(println "iteration = " i "number = " n)
(save-img-gout i n (ndarray/copy (ffirst out-g)))
(save-img-data i n (first dbatch))
- (calc-diff i n (ffirst diff-d)))
+ (calc-diff i n (ffirst diff-d))
+ (m/save-checkpoint mod-g {:prefix "model-g" :epoch i :save-opt-states true})
+ (m/save-checkpoint mod-d {:prefix "model-d" :epoch i :save-opt-states true}))
(inc n)))))))
(defn -main [& args]