You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cm...@apache.org on 2018/11/02 23:04:38 UTC

[incubator-mxnet] 02/02: wip

This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch can-you-gan
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit bcb527515569292b40460892d4318e23b48e8d55
Author: gigasquid <cm...@gigasquidsoftware.com>
AuthorDate: Fri Nov 2 19:03:26 2018 -0400

    wip
---
 .../examples/gan/src/gan/gan_mnist.clj             | 41 ++++++++++++++++------
 1 file changed, 30 insertions(+), 11 deletions(-)

diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
index 9a7bc35..593fe31 100644
--- a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
+++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
@@ -40,7 +40,7 @@
 (def data-dir "data/")
 (def output-path "results/")
 (def batch-size 10)
-(def num-epoch 100)
+(def num-epoch 1)
 
 (io/make-parents (str output-path "gout"))
 
@@ -59,6 +59,24 @@
                                          :data-shape [3 28 28]
                                          :batch-size batch-size}))
 
+(defn normalize-rgb "Returns (x - 128.0) / 128.0, so 0 maps to -1.0 and 128 maps to 0.0." [x]
+  (-> x (- 128.0) (/ 128.0)))
+
+(defn normalize-rgb-ndarray
+  "Builds a new ndarray with nda's shape, applying normalize-rgb to every element."
+  [nda]
+  (ndarray/array (mapv normalize-rgb (ndarray/->vec nda)) (ndarray/shape-vec nda)))
+
+
+(defn denormalize-rgb "Returns x * 128.0 + 128.0, so -1.0 maps to 0.0 and 0.0 maps to 128.0." [x]
+  (-> x (* 128.0) (+ 128.0)))
+
+(defn clip
+  "Clamps x to 0..255: 0 when negative, 255 when above 255, otherwise (int x)."
+  [x]
+  (let [below? (neg? x), above? (< 255 x)]
+    (cond below? 0, above? 255, :else (int x))))
+
 
 (defn postprocess-image [img]
   (let [datas (ndarray/->vec img)
@@ -69,11 +87,11 @@
                  (fn [pic]
                    (let [[rs gs bs] (doall (partition spatial-size pic))
                          this-pixels (mapv (fn [r g b]
-                                             (pixel/pack-pixel
-                                              (int r)
-                                              (int g)
-                                              (int b)
-                                              (int 255)))
+                                              (pixel/pack-pixel
+                                               (int (clip (denormalize-rgb r)))
+                                               (int (clip (denormalize-rgb g)))
+                                               (int (clip (denormalize-rgb b)))
+                                               (int 255)))
                                            rs gs bs)]
                      this-pixels))
                  pics)
@@ -84,7 +102,8 @@
 
 (defn postprocess-write-img [img filename]
   (img/write (-> (postprocess-image img)
-                    (img/zoom 1.5)) filename "png"))
+                 (img/zoom 1.5)) filename "png"))
+
 
 (comment 
   (def test-img (first (mx-io/batch-data (mx-io/next flan-iter))))  
@@ -194,7 +213,6 @@
 
 (defn save-img-gout [i n x]
   (do
-    (println "Carin gout shape is " (ndarray/shape x))
     (postprocess-write-img x (str output-path "/" "gout-" i "-" n ".png"))))
 
 (defn save-img-diff [i n x]
@@ -204,7 +222,7 @@
 (defn save-img-data [i n batch]
   (do
     (postprocess-write-img
-     (first (mx-io/batch-data batch)) (str output-path "/" "data-" i "-" n ".png"))))
+     (first batch) (str output-path "/" "data-" i "-" n ".png"))))
 
 (defn calc-diff [i n diff-d]
   (let [diff (ndarray/copy diff-d)
@@ -233,6 +251,7 @@
       (mx-io/reduce-batches flan-iter
                             (fn [n batch]
                               (let [rbatch (mx-io/next rand-noise-iter)
+                                    dbatch (mapv normalize-rgb-ndarray (mx-io/batch-data batch))
                                     out-g (-> mod-g
                                               (m/forward rbatch)
                                               (m/outputs))
@@ -243,7 +262,7 @@
                                                                                  (m/grad-arrays)))
                                    ;; update the discrimintator on the real
                                     grads-r (-> mod-d
-                                                (m/forward {:data (mx-io/batch-data batch) :label [(ndarray/ones [batch-size])]})
+                                                (m/forward {:data dbatch :label [(ndarray/ones [batch-size])]})
                                                 (m/backward)
                                                 (m/grad-arrays))
                                     _ (mapv (fn [real fake] (let [r (first real)]
@@ -260,7 +279,7 @@
                                 (when (zero? n)
                                   (println "iteration = " i  "number = " n)
                                   (save-img-gout i n (ndarray/copy (ffirst out-g)))
-                                  (save-img-data i n batch)
+                                  (save-img-data i n dbatch)
                                   (calc-diff i n (ffirst diff-d)))
                                 (inc n)))))))