This is an automated email from the ASF dual-hosted git repository. cmeier pushed a commit to branch can-you-gan in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit a4fb1074ab16d68c412ccd210651b9411270b363 Author: gigasquid <[email protected]> AuthorDate: Tue Nov 13 17:32:43 2018 -0500 add a load /save model --- .../examples/gan/src/gan/gan_mnist.clj | 47 +++++++++++++++++++--- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj index bd53946..ac2293c 100644 --- a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj +++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj @@ -40,10 +40,23 @@ (def data-dir "data/") (def output-path "results/") (def batch-size 5) -(def num-epoch 100) +(def num-epoch 30) (io/make-parents (str output-path "gout")) +(defn last-saved-model-number [] + (some->> "." + clojure.java.io/file + file-seq + (filter #(.isFile %)) + (map #(.getName %)) + (filter #(clojure.string/includes? % "model-d")) + reverse + first + (re-seq #"\d{4}") + first + Integer/parseInt)) + #_(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte"))) @@ -267,19 +280,41 @@ (save-img-diff i n calc-diff)))) (defn train [devs] - (let [mod-d (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]}) + (let [last-train-num (last-saved-model-number) + _ (println "The last saved trained epoch is " last-train-num) + mod-d (-> (if last-train-num + (do + (println "Loading discriminator from checkpoint of epoch " last-train-num) + (m/load-checkpoint {:contexts devs + :data-names ["data"] + :label-names ["label"] + :prefix "model-d" + :epoch last-train-num + :load-optimizer-states true})) + (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]})) (m/bind {:data-shapes (mx-io/provide-data flan-iter) :label-shapes (mx-io/provide-label flan-iter) :inputs-need-grad true}) (m/init-params {:initializer (init/normal 0.02)}) (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})})) - mod-g (-> (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil}) + mod-g (-> (if last-train-num + (do + (println "Loading generator from checkpoint of epoch " last-train-num) + (m/load-checkpoint {:contexts devs + :data-names ["rand"] + :label-names [""] + :prefix "model-g" + :epoch last-train-num + :load-optimizer-states true})) + (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil})) (m/bind {:data-shapes (mx-io/provide-data rand-noise-iter)}) (m/init-params {:initializer (init/normal 0.02)}) (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))] (println "Training for " num-epoch " epochs...") - (doseq [i (range num-epoch)] + (doseq [i (if last-train-num + (range (inc last-train-num) (inc (+ last-train-num num-epoch))) + (range num-epoch))] (mx-io/reduce-batches flan-iter (fn [n batch] (let [rbatch (mx-io/next rand-noise-iter) @@ -312,7 +347,9 @@ (println "iteration = " i "number = " n) (save-img-gout i n (ndarray/copy (ffirst out-g))) (save-img-data i n (first dbatch)) - (calc-diff i n (ffirst diff-d))) + (calc-diff i n (ffirst diff-d)) + (m/save-checkpoint mod-g {:prefix "model-g" :epoch i :save-opt-states true}) + (m/save-checkpoint mod-d {:prefix "model-d" :epoch i :save-opt-states true})) (inc n))))))) (defn -main [& args]
