You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2021/05/21 08:40:57 UTC

[singa] branch dev updated: data downloading for mnist

This is an automated email from the ASF dual-hosted git repository.

zhaojing pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git


The following commit(s) were added to refs/heads/dev by this push:
     new ee48750  data downloading for mnist
     new 9e96ed9  Merge pull request #851 from zlheui/distributed-cnn-mnist-dataset
ee48750 is described below

commit ee48750e33db9acf9a85dc0b746705e20b765582
Author: zhulei <zl...@gmail.com>
AuthorDate: Thu May 20 17:52:23 2021 +0800

    data downloading for mnist
---
 .../cifar_distributed_cnn/data/download_mnist.py   | 49 ++++++++++++
 examples/cifar_distributed_cnn/data/mnist.py       | 91 ++++++++++++++++++++++
 2 files changed, 140 insertions(+)

diff --git a/examples/cifar_distributed_cnn/data/download_mnist.py b/examples/cifar_distributed_cnn/data/download_mnist.py
new file mode 100644
index 0000000..65acb0e
--- /dev/null
+++ b/examples/cifar_distributed_cnn/data/download_mnist.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+import urllib.request
+
+
+def check_exist_or_download(url, download_dir='/tmp/'):
+    """Download ``url`` into ``download_dir`` unless the file is already there.
+
+    The local file name is the last path component of ``url``.
+
+    Args:
+        url: Address of the file to fetch.
+        download_dir: Directory that receives the file. Defaults to
+            ``/tmp/``.
+
+    Returns:
+        The path of the local file.
+
+    Raises:
+        OSError: Propagated from ``urllib.request.urlretrieve`` or the
+            final rename when the transfer fails.
+    """
+    name = url.rsplit('/', 1)[-1]
+    filename = os.path.join(download_dir, name)
+
+    if os.path.isfile(filename):
+        print("Already Downloaded: %s" % url)
+        return filename
+
+    print("Downloading %s" % url)
+    # Fetch into a scratch file and rename it only after the transfer has
+    # finished, so an interrupted download is never mistaken for a complete
+    # file on the next run.
+    partial = filename + '.part'
+    try:
+        urllib.request.urlretrieve(url, partial)
+        os.replace(partial, filename)
+    finally:
+        # Only present here when the transfer or the rename failed.
+        if os.path.exists(partial):
+            os.remove(partial)
+    return filename
+
+
+if __name__ == '__main__':
+
+    # The four gzip archives of the mnist dataset, in download order:
+    # training images, training labels, validation images, validation labels.
+    mnist_urls = (
+        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
+        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
+        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
+        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
+    )
+
+    # Fetch each archive that is not already present locally.
+    for mnist_url in mnist_urls:
+        check_exist_or_download(mnist_url)
diff --git a/examples/cifar_distributed_cnn/data/mnist.py b/examples/cifar_distributed_cnn/data/mnist.py
new file mode 100644
index 0000000..9cd1a84
--- /dev/null
+++ b/examples/cifar_distributed_cnn/data/mnist.py
@@ -0,0 +1,91 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import numpy as np
+import os
+import sys
+import gzip
+import codecs
+
+
+def check_dataset_exist(dirpath):
+    """Return ``dirpath`` if it exists, otherwise stop the program.
+
+    Args:
+        dirpath: Path that is expected to exist on disk.
+
+    Returns:
+        ``dirpath`` unchanged, so the call can wrap a path inline.
+
+    Raises:
+        SystemExit: With exit status 1 when ``dirpath`` does not exist, after
+            printing a hint on how to download the dataset.
+    """
+    if not os.path.exists(dirpath):
+        print(
+            'The MNIST dataset does not exist. Please download the mnist dataset using python data/download_mnist.py'
+        )
+        # A missing dataset is a failure: exit non-zero so that shells and
+        # job launchers do not treat the run as successful.
+        sys.exit(1)
+    return dirpath
+
+
+def load_dataset(data_dir='/tmp/'):
+    """Read the four MNIST archives found in ``data_dir`` as float32 arrays.
+
+    Each archive is checked with ``check_dataset_exist`` right before it is
+    read, so the program exits at the first missing file.
+
+    Args:
+        data_dir: Directory holding the ``*-ubyte.gz`` archives. Defaults to
+            ``/tmp/``.
+
+    Returns:
+        A tuple ``(train_x, train_y, valid_x, valid_y)`` of float32 arrays:
+        images as returned by ``read_image_file`` and labels as returned by
+        ``read_label_file``, each cast to float32.
+    """
+
+    def locate(name):
+        # Resolve an archive name inside data_dir and verify it is there.
+        return check_dataset_exist(os.path.join(data_dir, name))
+
+    train_x = read_image_file(locate('train-images-idx3-ubyte.gz'))
+    train_x = train_x.astype(np.float32)
+    train_y = read_label_file(locate('train-labels-idx1-ubyte.gz'))
+    train_y = train_y.astype(np.float32)
+    valid_x = read_image_file(locate('t10k-images-idx3-ubyte.gz'))
+    valid_x = valid_x.astype(np.float32)
+    valid_y = read_label_file(locate('t10k-labels-idx1-ubyte.gz'))
+    valid_y = valid_y.astype(np.float32)
+    return train_x, train_y, valid_x, valid_y
+
+
+def read_label_file(path):
+    """Parse a gzip-compressed MNIST label file into a uint8 array.
+
+    Layout read here: a 4-byte big-endian magic number (2049), a 4-byte
+    big-endian label count, then one byte per label.
+
+    Args:
+        path: Location of the ``.gz`` label file.
+
+    Returns:
+        A 1-D uint8 array with one entry per label. It is a view of the
+        decompressed bytes; convert or copy it before modifying.
+
+    Raises:
+        ValueError: If the magic number is not 2049, or if the number of
+            label bytes does not match the declared count.
+    """
+    with gzip.open(path, 'rb') as f:
+        data = f.read()
+
+    # Explicit check instead of ``assert`` so it still runs under ``-O``.
+    magic = int.from_bytes(data[:4], 'big')
+    if magic != 2049:
+        raise ValueError(
+            'Invalid magic number %d in MNIST label file %s (expected 2049)' %
+            (magic, path))
+
+    length = int.from_bytes(data[4:8], 'big')
+    # Everything after the 8-byte header is label data.
+    return np.frombuffer(data, dtype=np.uint8, offset=8).reshape(length)
+
+
+def get_int(b):
+    """Interpret the bytes ``b`` as an unsigned big-endian integer.
+
+    Args:
+        b: Bytes to decode, most significant byte first.
+
+    Returns:
+        The decoded non-negative integer.
+
+    Raises:
+        ValueError: If ``b`` is empty.
+    """
+    if not b:
+        # Keep rejecting empty input rather than silently decoding it to 0.
+        raise ValueError('cannot convert an empty byte string to an integer')
+    return int.from_bytes(b, 'big')
+
+
+def read_image_file(path):
+    """Parse a gzip-compressed MNIST image file into a uint8 array.
+
+    Layout read here: four 4-byte big-endian header fields (magic number
+    2051, image count, rows, columns) followed by one byte per pixel.
+
+    Args:
+        path: Location of the ``.gz`` image file.
+
+    Returns:
+        A uint8 array of shape ``(count, 1, rows, columns)``. It is a view of
+        the decompressed bytes; convert or copy it before modifying.
+
+    Raises:
+        ValueError: If the magic number is not 2051, or if the number of
+            pixel bytes does not match the declared dimensions.
+    """
+    with gzip.open(path, 'rb') as f:
+        data = f.read()
+
+    # Explicit check instead of ``assert`` so it still runs under ``-O``.
+    magic = int.from_bytes(data[:4], 'big')
+    if magic != 2051:
+        raise ValueError(
+            'Invalid magic number %d in MNIST image file %s (expected 2051)' %
+            (magic, path))
+
+    length = int.from_bytes(data[4:8], 'big')
+    num_rows = int.from_bytes(data[8:12], 'big')
+    num_cols = int.from_bytes(data[12:16], 'big')
+    # Everything after the 16-byte header is pixel data; the singleton axis
+    # is inserted right after the image count.
+    pixels = np.frombuffer(data, dtype=np.uint8, offset=16)
+    return pixels.reshape((length, 1, num_rows, num_cols))
+
+
+def normalize(train_x, val_x):
+    """Divide both arrays by 255 in place and return them.
+
+    Args:
+        train_x: Training array, modified in place.
+        val_x: Validation array, modified in place.
+
+    Returns:
+        The tuple ``(train_x, val_x)`` after division.
+    """
+    divisor = 255
+    val_x /= divisor
+    train_x /= divisor
+    return (train_x, val_x)
+
+
+def load():
+    """Load MNIST with images scaled by 1/255 and labels as integers.
+
+    Returns:
+        A tuple ``(train_x, train_y, val_x, val_y)``: the image arrays from
+        ``load_dataset`` divided by 255 as float32, and the label arrays as
+        int32.
+    """
+    train_x, train_y, val_x, val_y = load_dataset()
+    train_x, val_x = normalize(train_x, val_x)
+    # load_dataset already returns float32 images, so copy=False keeps this
+    # cast from duplicating the large image arrays a second time.
+    train_x = train_x.astype(np.float32, copy=False)
+    val_x = val_x.astype(np.float32, copy=False)
+    train_y = train_y.astype(np.int32)
+    val_y = val_y.astype(np.int32)
+    return train_x, train_y, val_x, val_y