This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch can-you-gan
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 95593b67da2ec7e1eb9b96f3fa5afa6c55487c1c
Author: gigasquid <[email protected]>
AuthorDate: Fri Nov 2 17:57:51 2018 -0400

    wip
---
 contrib/clojure-package/examples/gan/project.clj   |   6 +-
 .../examples/gan/src/gan/gan_mnist.clj             | 111 ++++++++++++++++-----
 .../clojure-package/examples/gan/src/gan/viz.clj   |  13 ++-
 3 files changed, 98 insertions(+), 32 deletions(-)

diff --git a/contrib/clojure-package/examples/gan/project.clj b/contrib/clojure-package/examples/gan/project.clj
index bebbc20..72354af 100644
--- a/contrib/clojure-package/examples/gan/project.clj
+++ b/contrib/clojure-package/examples/gan/project.clj
@@ -19,6 +19,8 @@
   :description "GAN MNIST with MXNet"
   :plugins [[lein-cljfmt "0.5.7"]]
   :dependencies [[org.clojure/clojure "1.9.0"]
-                 [org.apache.mxnet.contrib.clojure/clojure-mxnet 
"1.3.0-SNAPSHOT"]
-                 [nu.pattern/opencv "2.4.9-7"]]
+                 [org.apache.mxnet.contrib.clojure/clojure-mxnet-osx-cpu 
"1.3.0"]
+                 [nu.pattern/opencv "2.4.9-7"]
+                 [net.mikera/imagez "0.12.0"]
+                 [thinktopic/think.image "0.4.16"]]
   :main gan.gan-mnist)
diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
index 14dd2c5..9a7bc35 100644
--- a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
+++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
@@ -29,7 +29,9 @@
             [org.apache.clojure-mxnet.shape :as mx-shape]
             [org.apache.clojure-mxnet.util :as util]
             [gan.viz :as viz]
-            [org.apache.clojure-mxnet.context :as context])
+            [org.apache.clojure-mxnet.context :as context]
+            [think.image.pixel :as pixel]
+            [mikera.image.core :as img])
   (:gen-class))
 
 ;; based off of 
https://medium.com/@julsimon/generative-adversarial-networks-on-apache-mxnet-part-1-b6d39e6b5df1
@@ -37,19 +39,76 @@
 
 (def data-dir "data/")
 (def output-path "results/")
-(def batch-size 100)
-(def num-epoch 10)
+(def batch-size 10)
+(def num-epoch 100)
 
 (io/make-parents (str output-path "gout"))
 
-(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte")))
+
+
+#_(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte")))
   (sh "../../scripts/get_mnist_data.sh"))
 
-(defonce mnist-iter (mx-io/mnist-iter {:image (str data-dir 
"train-images-idx3-ubyte")
+#_(defonce mnist-iter (mx-io/mnist-iter {:image (str data-dir 
"train-images-idx3-ubyte")
                                        :label (str data-dir 
"train-labels-idx1-ubyte")
                                        :input-shape [1 28 28]
                                        :batch-size batch-size
-                                       :shuffle true}))
+                                         :shuffle true}))
+
+(def flan-iter (mx-io/image-record-iter {:path-imgrec "flan.rec"
+                                         :data-shape [3 28 28]
+                                         :batch-size batch-size}))
+
+
+(defn postprocess-image [img]
+  (let [datas (ndarray/->vec img)
+        image-shape (mx-shape/->vec (ndarray/shape img))
+        spatial-size (* (get image-shape 2) (get image-shape 3))
+        pics (doall (partition (* 3 spatial-size) datas))
+        pixels  (mapv
+                 (fn [pic]
+                   (let [[rs gs bs] (doall (partition spatial-size pic))
+                         this-pixels (mapv (fn [r g b]
+                                             (pixel/pack-pixel
+                                              (int r)
+                                              (int g)
+                                              (int b)
+                                              (int 255)))
+                                           rs gs bs)]
+                     this-pixels))
+                 pics)
+        new-pixels (into [] (flatten pixels))
+        new-image (img/new-image (* 1 (get image-shape 3)) (* batch-size (get 
image-shape 2)))
+        _  (img/set-pixels new-image (int-array new-pixels))]
+    new-image))
+
+(defn postprocess-write-img [img filename]
+  (img/write (-> (postprocess-image img)
+                    (img/zoom 1.5)) filename "png"))
+
+(comment 
+  (def test-img (first (mx-io/batch-data (mx-io/next flan-iter))))  
+
+  (ndarray/shape test-img)
+
+  (postprocess-image test-img)
+  (do (img/show (-> (postprocess-image test-img)
+                    (img/zoom 1.5))))
+  
+  (img/write (-> (postprocess-image test-img)
+                    (img/zoom 1.5)) "results/flan.png" "png")
+
+  (ndarray/->vec  test-img)
+
+    (viz/im-sav {:title "Carin"
+               :output-path output-path
+               :x test-img
+               :flip false})
+
+
+
+)
+
 
 (def rand-noise-iter (mx-io/rand-iter [batch-size 100 1 1]))
 
@@ -66,6 +125,11 @@
   (conv-output-size 8 4 1 2) ;=> 4.0
   (conv-output-size 4 4 0 1) ;=> 1
 
+  (conv-output-size 128 4 3 2) ;=> 66
+  (conv-output-size 66 4 2 2) ;=> 34
+  (conv-output-size 34 4 2 2) ;=> 18.0
+  (conv-output-size 18 5 2 2) ;=> 1
+
   ;; Calcing the layer sizes for generator
   (defn deconv-output-size [input-size kernel-size padding stride]
     (-
@@ -80,7 +144,7 @@
 
 
 (def ndf 28) ;; image height /width
-(def nc 1) ;; number of channels
+(def nc 3) ;; number of channels
 (def eps (float (+ 1e-5  1e-12)))
 (def lr  0.0005) ;; learning rate
 (def beta1 0.5)
@@ -130,22 +194,17 @@
 
 (defn save-img-gout [i n x]
   (do
-    (viz/im-sav {:title (str "gout-" i "-" n)
-                 :output-path output-path
-                 :x x
-                 :flip false})))
+    (println "Carin gout shape is " (ndarray/shape x))
+    (postprocess-write-img x (str output-path "/" "gout-" i "-" n ".png"))))
 
 (defn save-img-diff [i n x]
-  (do (viz/im-sav {:title (str "diff-" i "-" n)
-                   :output-path output-path
-                   :x x
-                   :flip false})))
+  (do
+    (postprocess-write-img x (str output-path "/" "diff-" i "-" n ".png"))))
 
 (defn save-img-data [i n batch]
-  (do (viz/im-sav {:title (str "data-" i "-" n)
-                   :output-path output-path
-                   :x (first (mx-io/batch-data batch))
-                   :flip false})))
+  (do
+    (postprocess-write-img
+     (first (mx-io/batch-data batch)) (str output-path "/" "data-" i "-" n 
".png"))))
 
 (defn calc-diff [i n diff-d]
   (let [diff (ndarray/copy diff-d)
@@ -159,8 +218,8 @@
 
 (defn train [devs]
   (let [mod-d  (-> (m/module (discriminator) {:contexts devs :data-names 
["data"] :label-names ["label"]})
-                   (m/bind {:data-shapes (mx-io/provide-data mnist-iter)
-                            :label-shapes (mx-io/provide-label mnist-iter)
+                   (m/bind {:data-shapes (mx-io/provide-data flan-iter)
+                            :label-shapes (mx-io/provide-label flan-iter)
                             :inputs-need-grad true})
                    (m/init-params {:initializer (init/normal 0.02)})
                    (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr 
:wd 0.0 :beta1 beta1})}))
@@ -171,13 +230,13 @@
 
     (println "Training for " num-epoch " epochs...")
     (doseq [i (range num-epoch)]
-      (mx-io/reduce-batches mnist-iter
+      (mx-io/reduce-batches flan-iter
                             (fn [n batch]
                               (let [rbatch (mx-io/next rand-noise-iter)
                                     out-g (-> mod-g
                                               (m/forward rbatch)
                                               (m/outputs))
-                                   ;; update the discriminiator on the fake
+                                    ;; update the discriminiator on the fake
                                     grads-f  (mapv #(ndarray/copy (first %)) 
(-> mod-d
                                                                                
  (m/forward {:data (first out-g) :label [(ndarray/zeros [batch-size])]})
                                                                                
  (m/backward)
@@ -198,7 +257,7 @@
                                     _ (-> mod-g
                                           (m/backward (first diff-d))
                                           (m/update))]
-                                (when (zero? (mod n 100))
+                                (when (zero? n)
                                   (println "iteration = " i  "number = " n)
                                   (save-img-gout i n (ndarray/copy (ffirst 
out-g)))
                                   (save-img-data i n batch)
@@ -214,4 +273,6 @@
     (train devs)))
 
 (comment
-  (train [(context/cpu)]))
+  (train [(context/cpu)])
+
+  )
diff --git a/contrib/clojure-package/examples/gan/src/gan/viz.clj b/contrib/clojure-package/examples/gan/src/gan/viz.clj
index 8b57b94..2780252 100644
--- a/contrib/clojure-package/examples/gan/src/gan/viz.clj
+++ b/contrib/clojure-package/examples/gan/src/gan/viz.clj
@@ -36,18 +36,21 @@
                 :else (int %)))
        (mapv #(.byteValue %))))
 
+
+
 (defn get-img [raw-data channels height width flip]
   (let [totals (* height width)
         img (if (> channels 1)
               ;; rgb image
-              (let [[ra ga ba] (byte-array (partition totals raw-data))
+              (let [[ra ga ba] (doall (partition totals raw-data))
                     rr (new Mat height width (CvType/CV_8U))
                     gg (new Mat height width (CvType/CV_8U))
                     bb (new Mat height width (CvType/CV_8U))
-                    result (new Mat)]
-                (.put rr (int 0) (int 0) ra)
-                (.put gg (int 0) (int 0) ga)
-                (.put bb (int 0) (int 0) ba)
+                    result (new Mat height width (CvType/CV_8U))]
+                (do
+                  (.put rr 0 0 (byte-array ra))
+                  (.put gg 0 0 (byte-array ga))
+                  (.put bb 0 0 (byte-array ba)))
                 (Core/merge (java.util.ArrayList. [bb gg rr]) result)
                 result)
               ;; gray image

Reply via email to