This is an automated email from the ASF dual-hosted git repository. cmeier pushed a commit to branch can-you-gan in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit bcb527515569292b40460892d4318e23b48e8d55 Author: gigasquid <[email protected]> AuthorDate: Fri Nov 2 19:03:26 2018 -0400 wip --- .../examples/gan/src/gan/gan_mnist.clj | 41 ++++++++++++++++------ 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj index 9a7bc35..593fe31 100644 --- a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj +++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj @@ -40,7 +40,7 @@ (def data-dir "data/") (def output-path "results/") (def batch-size 10) -(def num-epoch 100) +(def num-epoch 1) (io/make-parents (str output-path "gout")) @@ -59,6 +59,24 @@ :data-shape [3 28 28] :batch-size batch-size})) +(defn normalize-rgb [x] + (/ (- x 128.0) 128.0)) + +(defn normalize-rgb-ndarray [nda] + (let [nda-shape (ndarray/shape-vec nda) + new-values (mapv #(normalize-rgb %) (ndarray/->vec nda))] + (ndarray/array new-values nda-shape))) + + +(defn denormalize-rgb [x] + (+ (* x 128.0) 128.0)) + +(defn clip [x] + (cond + (< x 0) 0 + (> x 255) 255 + :else (int x))) + (defn postprocess-image [img] (let [datas (ndarray/->vec img) @@ -69,11 +87,11 @@ (fn [pic] (let [[rs gs bs] (doall (partition spatial-size pic)) this-pixels (mapv (fn [r g b] - (pixel/pack-pixel - (int r) - (int g) - (int b) - (int 255))) + (pixel/pack-pixel + (int (clip (denormalize-rgb r))) + (int (clip (denormalize-rgb g))) + (int (clip (denormalize-rgb b))) + (int 255))) rs gs bs)] this-pixels)) pics) @@ -84,7 +102,8 @@ (defn postprocess-write-img [img filename] (img/write (-> (postprocess-image img) - (img/zoom 1.5)) filename "png")) + (img/zoom 1.5)) filename "png")) + (comment (def test-img (first (mx-io/batch-data (mx-io/next flan-iter)))) @@ -194,7 +213,6 @@ (defn save-img-gout [i n x] (do - (println "Carin gout shape is " (ndarray/shape x)) (postprocess-write-img x (str output-path "/" "gout-" i "-" n ".png")))) (defn save-img-diff [i n x] @@ -204,7 +222,7 @@ (defn save-img-data [i n batch] (do (postprocess-write-img - (first (mx-io/batch-data batch)) (str output-path "/" "data-" i "-" n ".png")))) + (first batch) (str output-path "/" "data-" i "-" n ".png")))) (defn calc-diff [i n diff-d] (let [diff (ndarray/copy diff-d) @@ -233,6 +251,7 @@ (mx-io/reduce-batches flan-iter (fn [n batch] (let [rbatch (mx-io/next rand-noise-iter) + dbatch (mapv normalize-rgb-ndarray (mx-io/batch-data batch)) out-g (-> mod-g (m/forward rbatch) (m/outputs)) @@ -243,7 +262,7 @@ (m/grad-arrays))) ;; update the discrimintator on the real grads-r (-> mod-d - (m/forward {:data (mx-io/batch-data batch) :label [(ndarray/ones [batch-size])]}) + (m/forward {:data dbatch :label [(ndarray/ones [batch-size])]}) (m/backward) (m/grad-arrays)) _ (mapv (fn [real fake] (let [r (first real)] @@ -260,7 +279,7 @@ (when (zero? n) (println "iteration = " i "number = " n) (save-img-gout i n (ndarray/copy (ffirst out-g))) - (save-img-data i n batch) + (save-img-data i n dbatch) (calc-diff i n (ffirst diff-d))) (inc n)))))))
