This is an automated email from the ASF dual-hosted git repository. cmeier pushed a commit to branch can-you-gan in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 95593b67da2ec7e1eb9b96f3fa5afa6c55487c1c Author: gigasquid <[email protected]> AuthorDate: Fri Nov 2 17:57:51 2018 -0400 wip --- contrib/clojure-package/examples/gan/project.clj | 6 +- .../examples/gan/src/gan/gan_mnist.clj | 111 ++++++++++++++++----- .../clojure-package/examples/gan/src/gan/viz.clj | 13 ++- 3 files changed, 98 insertions(+), 32 deletions(-) diff --git a/contrib/clojure-package/examples/gan/project.clj b/contrib/clojure-package/examples/gan/project.clj index bebbc20..72354af 100644 --- a/contrib/clojure-package/examples/gan/project.clj +++ b/contrib/clojure-package/examples/gan/project.clj @@ -19,6 +19,8 @@ :description "GAN MNIST with MXNet" :plugins [[lein-cljfmt "0.5.7"]] :dependencies [[org.clojure/clojure "1.9.0"] - [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.3.0-SNAPSHOT"] - [nu.pattern/opencv "2.4.9-7"]] + [org.apache.mxnet.contrib.clojure/clojure-mxnet-osx-cpu "1.3.0"] + [nu.pattern/opencv "2.4.9-7"] + [net.mikera/imagez "0.12.0"] + [thinktopic/think.image "0.4.16"]] :main gan.gan-mnist) diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj index 14dd2c5..9a7bc35 100644 --- a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj +++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj @@ -29,7 +29,9 @@ [org.apache.clojure-mxnet.shape :as mx-shape] [org.apache.clojure-mxnet.util :as util] [gan.viz :as viz] - [org.apache.clojure-mxnet.context :as context]) + [org.apache.clojure-mxnet.context :as context] + [think.image.pixel :as pixel] + [mikera.image.core :as img]) (:gen-class)) ;; based off of https://medium.com/@julsimon/generative-adversarial-networks-on-apache-mxnet-part-1-b6d39e6b5df1 @@ -37,19 +39,76 @@ (def data-dir "data/") (def output-path "results/") -(def batch-size 100) -(def num-epoch 10) +(def batch-size 10) +(def num-epoch 100) (io/make-parents (str output-path "gout")) -(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte"))) + + +#_(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte"))) (sh "../../scripts/get_mnist_data.sh")) -(defonce mnist-iter (mx-io/mnist-iter {:image (str data-dir "train-images-idx3-ubyte") +#_(defonce mnist-iter (mx-io/mnist-iter {:image (str data-dir "train-images-idx3-ubyte") :label (str data-dir "train-labels-idx1-ubyte") :input-shape [1 28 28] :batch-size batch-size - :shuffle true})) + :shuffle true})) + +(def flan-iter (mx-io/image-record-iter {:path-imgrec "flan.rec" + :data-shape [3 28 28] + :batch-size batch-size})) + + +(defn postprocess-image [img] + (let [datas (ndarray/->vec img) + image-shape (mx-shape/->vec (ndarray/shape img)) + spatial-size (* (get image-shape 2) (get image-shape 3)) + pics (doall (partition (* 3 spatial-size) datas)) + pixels (mapv + (fn [pic] + (let [[rs gs bs] (doall (partition spatial-size pic)) + this-pixels (mapv (fn [r g b] + (pixel/pack-pixel + (int r) + (int g) + (int b) + (int 255))) + rs gs bs)] + this-pixels)) + pics) + new-pixels (into [] (flatten pixels)) + new-image (img/new-image (* 1 (get image-shape 3)) (* batch-size (get image-shape 2))) + _ (img/set-pixels new-image (int-array new-pixels))] + new-image)) + +(defn postprocess-write-img [img filename] + (img/write (-> (postprocess-image img) + (img/zoom 1.5)) filename "png")) + +(comment + (def test-img (first (mx-io/batch-data (mx-io/next flan-iter)))) + + (ndarray/shape test-img) + + (postprocess-image test-img) + (do (img/show (-> (postprocess-image test-img) + (img/zoom 1.5)))) + + (img/write (-> (postprocess-image test-img) + (img/zoom 1.5)) "results/flan.png" "png") + + (ndarray/->vec test-img) + + (viz/im-sav {:title "Carin" + :output-path output-path + :x test-img + :flip false}) + + + +) + (def rand-noise-iter (mx-io/rand-iter [batch-size 100 1 1])) @@ -66,6 +125,11 @@ (conv-output-size 8 4 1 2) ;=> 4.0 (conv-output-size 4 4 0 1) ;=> 1 + (conv-output-size 128 4 3 2) ;=> 66 + (conv-output-size 66 4 2 2) ;=> 34 + (conv-output-size 34 4 2 2) ;=> 18.0 + (conv-output-size 18 5 2 2) ;=> 1 + ;; Calcing the layer sizes for generator (defn deconv-output-size [input-size kernel-size padding stride] (- @@ -80,7 +144,7 @@ (def ndf 28) ;; image height /width -(def nc 1) ;; number of channels +(def nc 3) ;; number of channels (def eps (float (+ 1e-5 1e-12))) (def lr 0.0005) ;; learning rate (def beta1 0.5) @@ -130,22 +194,17 @@ (defn save-img-gout [i n x] (do - (viz/im-sav {:title (str "gout-" i "-" n) - :output-path output-path - :x x - :flip false}))) + (println "Carin gout shape is " (ndarray/shape x)) + (postprocess-write-img x (str output-path "/" "gout-" i "-" n ".png")))) (defn save-img-diff [i n x] - (do (viz/im-sav {:title (str "diff-" i "-" n) - :output-path output-path - :x x - :flip false}))) + (do + (postprocess-write-img x (str output-path "/" "diff-" i "-" n ".png")))) (defn save-img-data [i n batch] - (do (viz/im-sav {:title (str "data-" i "-" n) - :output-path output-path - :x (first (mx-io/batch-data batch)) - :flip false}))) + (do + (postprocess-write-img + (first (mx-io/batch-data batch)) (str output-path "/" "data-" i "-" n ".png")))) (defn calc-diff [i n diff-d] (let [diff (ndarray/copy diff-d) @@ -159,8 +218,8 @@ (defn train [devs] (let [mod-d (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]}) - (m/bind {:data-shapes (mx-io/provide-data mnist-iter) - :label-shapes (mx-io/provide-label mnist-iter) + (m/bind {:data-shapes (mx-io/provide-data flan-iter) + :label-shapes (mx-io/provide-label flan-iter) :inputs-need-grad true}) (m/init-params {:initializer (init/normal 0.02)}) (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})})) @@ -171,13 +230,13 @@ (println "Training for " num-epoch " epochs...") (doseq [i (range num-epoch)] - (mx-io/reduce-batches mnist-iter + (mx-io/reduce-batches flan-iter (fn [n batch] (let [rbatch (mx-io/next rand-noise-iter) out-g (-> mod-g (m/forward rbatch) (m/outputs)) - ;; update the discriminiator on the fake + ;; update the discriminiator on the fake grads-f (mapv #(ndarray/copy (first %)) (-> mod-d (m/forward {:data (first out-g) :label [(ndarray/zeros [batch-size])]}) (m/backward) @@ -198,7 +257,7 @@ _ (-> mod-g (m/backward (first diff-d)) (m/update))] - (when (zero? (mod n 100)) + (when (zero? n) (println "iteration = " i "number = " n) (save-img-gout i n (ndarray/copy (ffirst out-g))) (save-img-data i n batch) @@ -214,4 +273,6 @@ (train devs))) (comment - (train [(context/cpu)])) + (train [(context/cpu)]) + + ) diff --git a/contrib/clojure-package/examples/gan/src/gan/viz.clj b/contrib/clojure-package/examples/gan/src/gan/viz.clj index 8b57b94..2780252 100644 --- a/contrib/clojure-package/examples/gan/src/gan/viz.clj +++ b/contrib/clojure-package/examples/gan/src/gan/viz.clj @@ -36,18 +36,21 @@ :else (int %))) (mapv #(.byteValue %)))) + + (defn get-img [raw-data channels height width flip] (let [totals (* height width) img (if (> channels 1) ;; rgb image - (let [[ra ga ba] (byte-array (partition totals raw-data)) + (let [[ra ga ba] (doall (partition totals raw-data)) rr (new Mat height width (CvType/CV_8U)) gg (new Mat height width (CvType/CV_8U)) bb (new Mat height width (CvType/CV_8U)) - result (new Mat)] - (.put rr (int 0) (int 0) ra) - (.put gg (int 0) (int 0) ga) - (.put bb (int 0) (int 0) ba) + result (new Mat height width (CvType/CV_8U))] + (do + (.put rr 0 0 (byte-array ra)) + (.put gg 0 0 (byte-array ga)) + (.put bb 0 0 (byte-array ba))) (Core/merge (java.util.ArrayList. [bb gg rr]) result) result) ;; gray image
