You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by cm...@apache.org on 2018/11/02 23:04:38 UTC
[incubator-mxnet] 02/02: wip
This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch can-you-gan
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit bcb527515569292b40460892d4318e23b48e8d55
Author: gigasquid <cm...@gigasquidsoftware.com>
AuthorDate: Fri Nov 2 19:03:26 2018 -0400
wip
---
.../examples/gan/src/gan/gan_mnist.clj | 41 ++++++++++++++++------
1 file changed, 30 insertions(+), 11 deletions(-)
diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
index 9a7bc35..593fe31 100644
--- a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
+++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
@@ -40,7 +40,7 @@
(def data-dir "data/")
(def output-path "results/")
(def batch-size 10)
-(def num-epoch 100)
+(def num-epoch 1)
(io/make-parents (str output-path "gout"))
@@ -59,6 +59,24 @@
:data-shape [3 28 28]
:batch-size batch-size}))
+(defn normalize-rgb "Maps a 0-255 channel value x to (x - 128) / 128, e.g. 0 -> -1.0, 128 -> 0.0, 255 -> 0.9921875." [x]
+ (-> x (- 128.0) (/ 128.0)))
+
+(defn normalize-rgb-ndarray "Returns a new NDArray with nda's shape whose elements are normalize-rgb applied to each element of nda; nda itself is only read." [nda]
+ (let [nda-shape (ndarray/shape-vec nda)
+ new-values (mapv normalize-rgb (ndarray/->vec nda))] ; copies the data into a Clojure vector and scales each element
+ (ndarray/array new-values nda-shape))) ; NOTE(review): relies on ndarray/array's default dtype/context — confirm they match nda
+
+
+(defn denormalize-rgb "Inverse of normalize-rgb: returns x * 128 + 128 (e.g. -1.0 -> 0.0, 0.0 -> 128.0); the result is not clamped to 0-255." [x]
+ (-> x (* 128.0) (+ 128.0)))
+
+(defn clip
+ "Clamps x into [lo, hi] (default [0, 255]) and coerces the result with int,
+ so every branch returns the same type; a non-integral x is truncated by int."
+ ([x] (clip x 0 255))
+ ([x lo hi] (int (max lo (min hi x)))))
+
(defn postprocess-image [img]
(let [datas (ndarray/->vec img)
@@ -69,11 +87,11 @@
(fn [pic]
(let [[rs gs bs] (doall (partition spatial-size pic))
this-pixels (mapv (fn [r g b]
- (pixel/pack-pixel
- (int r)
- (int g)
- (int b)
- (int 255)))
+ (pixel/pack-pixel
+ (int (clip (denormalize-rgb r)))
+ (int (clip (denormalize-rgb g)))
+ (int (clip (denormalize-rgb b)))
+ (int 255)))
rs gs bs)]
this-pixels))
pics)
@@ -84,7 +102,8 @@
(defn postprocess-write-img [img filename]
(img/write (-> (postprocess-image img)
- (img/zoom 1.5)) filename "png"))
+ (img/zoom 1.5)) filename "png"))
+
(comment
(def test-img (first (mx-io/batch-data (mx-io/next flan-iter))))
@@ -194,7 +213,6 @@
(defn save-img-gout [i n x]
(do
- (println "Carin gout shape is " (ndarray/shape x))
(postprocess-write-img x (str output-path "/" "gout-" i "-" n ".png"))))
(defn save-img-diff [i n x]
@@ -204,7 +222,7 @@
(defn save-img-data [i n batch]
(do
(postprocess-write-img
- (first (mx-io/batch-data batch)) (str output-path "/" "data-" i "-" n ".png"))))
+ (first batch) (str output-path "/" "data-" i "-" n ".png"))))
(defn calc-diff [i n diff-d]
(let [diff (ndarray/copy diff-d)
@@ -233,6 +251,7 @@
(mx-io/reduce-batches flan-iter
(fn [n batch]
(let [rbatch (mx-io/next rand-noise-iter)
+ dbatch (mapv normalize-rgb-ndarray (mx-io/batch-data batch))
out-g (-> mod-g
(m/forward rbatch)
(m/outputs))
@@ -243,7 +262,7 @@
(m/grad-arrays)))
;; update the discrimintator on the real
grads-r (-> mod-d
- (m/forward {:data (mx-io/batch-data batch) :label [(ndarray/ones [batch-size])]})
+ (m/forward {:data dbatch :label [(ndarray/ones [batch-size])]})
(m/backward)
(m/grad-arrays))
_ (mapv (fn [real fake] (let [r (first real)]
@@ -260,7 +279,7 @@
(when (zero? n)
(println "iteration = " i "number = " n)
(save-img-gout i n (ndarray/copy (ffirst out-g)))
- (save-img-data i n batch)
+ (save-img-data i n dbatch)
(calc-diff i n (ffirst diff-d)))
(inc n)))))))