You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2021/05/20 09:48:22 UTC

[singa] branch dev updated: data downloading for cifar-10

This is an automated email from the ASF dual-hosted git repository.

zhaojing pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git


The following commit(s) were added to refs/heads/dev by this push:
     new db64353  data downloading for cifar-10
     new 4dab111  Merge pull request #850 from zlheui/distributed-cnn-cifar-dataset
db64353 is described below

commit db6435391982b6138803bd7f08d5901622ff27c0
Author: zhulei <zl...@gmail.com>
AuthorDate: Thu May 20 17:08:07 2021 +0800

    data downloading for cifar-10
---
 examples/cifar_distributed_cnn/data/cifar10.py     | 91 ++++++++++++++++++++++
 .../cifar_distributed_cnn/data/download_cifar10.py | 49 ++++++++++++
 2 files changed, 140 insertions(+)

diff --git a/examples/cifar_distributed_cnn/data/cifar10.py b/examples/cifar_distributed_cnn/data/cifar10.py
new file mode 100644
index 0000000..3b83ad7
--- /dev/null
+++ b/examples/cifar_distributed_cnn/data/cifar10.py
@@ -0,0 +1,91 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+try:
+    import pickle
+except ImportError:
+    import cPickle as pickle
+
+import numpy as np
+import os
+import sys
+
+
+def load_dataset(filepath):
+    """Read one pickled CIFAR-10 batch file into image and label arrays.
+
+    Args:
+        filepath: path of the pickled batch file; it is opened in binary mode.
+
+    Returns:
+        A tuple ``(image, label)``. ``image`` is the ``'data'`` entry cast to
+        uint8 and reshaped to ``(-1, 3, 32, 32)``; ``label`` is the
+        ``'labels'`` entry as a uint8 column of shape ``(n, 1)``.
+    """
+    # NOTE: unpickling can execute arbitrary code; only load trusted files.
+    with open(filepath, 'rb') as batch_file:
+        try:
+            batch = pickle.load(batch_file, encoding='latin1')
+        except TypeError:
+            # Retry without the keyword, e.g. for a pickle implementation
+            # whose load() does not accept ``encoding``.
+            batch = pickle.load(batch_file)
+
+    pixels = np.reshape(batch['data'].astype(dtype=np.uint8), (-1, 3, 32, 32))
+    classes = np.asarray(batch['labels'], dtype=np.uint8)
+    return pixels, np.reshape(classes, (classes.size, 1))
+
+
+#def load_train_data(dir_path='/scratch1/07801/nusbin20/gordon-bell/cifar-10-batches-py', num_batches=5):
+def load_train_data(dir_path='/scratch/snx3000/lyongbin/singa_my/cifar10_log/cifar-10-batches-py', num_batches=5):
+    """Load the CIFAR-10 training batches ``data_batch_1`` .. ``data_batch_N``.
+
+    NOTE(review): the default ``dir_path`` looks machine-specific, whereas
+    download_cifar10.py extracts the archive under /tmp/ -- confirm that
+    callers pass the directory they actually use.
+
+    Args:
+        dir_path: directory holding the ``data_batch_<i>`` files.
+        num_batches: number of batch files to read, numbered from 1.
+
+    Returns:
+        ``(images, labels)`` where ``images`` is a float32 array of shape
+        ``(num_batches * 10000, 3, 32, 32)`` and ``labels`` is an int32 array
+        built from the per-batch label columns.
+    """
+    batchsize = 10000  # each batch file is expected to fill one such slice
+    pixel_buffer = np.empty((num_batches * batchsize, 3, 32, 32), dtype=np.uint8)
+    label_rows = []
+    for index in range(num_batches):
+        batch_path = dir_path + "/data_batch_" + str(index + 1)
+        image, label = load_dataset(check_dataset_exist(batch_path))
+        begin = index * batchsize
+        pixel_buffer[begin:begin + batchsize] = image
+        label_rows.extend(label)
+    images = np.array(pixel_buffer, dtype=np.float32)
+    labels = np.array(label_rows, dtype=np.int32)
+    return images, labels
+
+
+#def load_test_data(dir_path='/scratch1/07801/nusbin20/gordon-bell/cifar-10-batches-py'):
+def load_test_data(dir_path='/scratch/snx3000/lyongbin/singa_my/cifar10_log/cifar-10-batches-py'):
+    """Load the CIFAR-10 ``test_batch`` file found in ``dir_path``.
+
+    NOTE(review): the default ``dir_path`` looks machine-specific, whereas
+    download_cifar10.py extracts the archive under /tmp/ -- confirm that
+    callers pass the directory they actually use.
+
+    Args:
+        dir_path: directory holding the ``test_batch`` file.
+
+    Returns:
+        ``(images, labels)`` converted to float32 and int32 arrays.
+    """
+    test_file = check_dataset_exist(dir_path + "/test_batch")
+    raw_images, raw_labels = load_dataset(test_file)
+    test_x = np.array(raw_images, dtype=np.float32)
+    test_y = np.array(raw_labels, dtype=np.int32)
+    return test_x, test_y
+
+
+def check_dataset_exist(dirpath):
+    """Return ``dirpath`` unchanged if it exists, otherwise abort the program.
+
+    Args:
+        dirpath: file or directory path to check.
+
+    Returns:
+        ``dirpath`` itself, so the call can be nested inside a loader call.
+
+    Raises:
+        SystemExit: with a non-zero status when ``dirpath`` does not exist.
+    """
+    if not os.path.exists(dirpath):
+        print(
+            'Please download the cifar10 dataset using python data/download_cifar10.py'
+        )
+        # A missing dataset is a failure: exit non-zero so that shells and
+        # job launchers do not mistake the abort for a successful run.
+        sys.exit(1)
+    return dirpath
+
+
+def normalize(train_x, val_x):
+    """Divide both image arrays by 255 and standardize each channel in place.
+
+    Both arrays are modified in place and indexed as ``[:, channel, :, :]``,
+    so they must be floating-point arrays with the channel on axis 1.
+
+    Args:
+        train_x: training images; modified in place.
+        val_x: validation images; modified in place.
+
+    Returns:
+        The same ``(train_x, val_x)`` objects after normalization.
+    """
+    mean = [0.4914, 0.4822, 0.4465]
+    std = [0.2023, 0.1994, 0.2010]
+    train_x /= 255
+    val_x /= 255
+    # Walk every channel that has statistics; the former ``range(0, 2)``
+    # stopped after two channels and left the third one un-standardized.
+    for ch, (ch_mean, ch_std) in enumerate(zip(mean, std)):
+        for images in (train_x, val_x):
+            images[:, ch, :, :] -= ch_mean
+            images[:, ch, :, :] /= ch_std
+    return train_x, val_x
+
+def load():
+    """Load the train and test splits from their default paths and normalize.
+
+    Returns:
+        ``(train_x, train_y, val_x, val_y)`` with the images passed through
+        ``normalize`` and both label arrays flattened to one dimension.
+    """
+    train_x, train_y = load_train_data()
+    val_x, val_y = load_test_data()
+    train_x, val_x = normalize(train_x, val_x)
+    return train_x, train_y.flatten(), val_x, val_y.flatten()
diff --git a/examples/cifar_distributed_cnn/data/download_cifar10.py b/examples/cifar_distributed_cnn/data/download_cifar10.py
new file mode 100644
index 0000000..a010b2e
--- /dev/null
+++ b/examples/cifar_distributed_cnn/data/download_cifar10.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#     http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# 
+
+from __future__ import print_function
+from future import standard_library
+standard_library.install_aliases()
+import urllib.request, urllib.parse, urllib.error
+import tarfile
+import os
+import sys
+
+
+def extract_tarfile(filepath, dest='/tmp/'):
+    """Extract the tar archive at ``filepath`` into ``dest``, then exit.
+
+    When ``filepath`` does not exist nothing happens and None is returned.
+
+    Args:
+        filepath: path of the tar archive to unpack.
+        dest: directory to extract into; defaults to ``'/tmp/'``.
+
+    Raises:
+        SystemExit: with status 0 once the archive has been extracted.
+    """
+    if not os.path.exists(filepath):
+        return
+    print('The tar file does exist. Extracting it now..')
+    # NOTE(review): extractall trusts the member paths stored in the archive.
+    # This script fetches the archive over plain HTTP, so consider validating
+    # the members before extraction.
+    with tarfile.open(filepath, 'r') as archive:
+        archive.extractall(dest)
+    print('Finished!')
+    sys.exit(0)
+
+
+def do_download(dirpath, gzfile, url):
+    """Fetch ``url`` into ``gzfile`` and hand the archive to extract_tarfile.
+
+    Args:
+        dirpath: not used by this function.
+        gzfile: local path the downloaded archive is written to.
+        url: address of the archive to download.
+    """
+    banner = 'Downloading CIFAR from %s' % (url)
+    print(banner)
+    urllib.request.urlretrieve(url, filename=gzfile)
+    extract_tarfile(gzfile)
+    # Reached only if extract_tarfile returns instead of exiting the process.
+    print('Finished!')
+
+
+if __name__ == '__main__':
+    # Script entry point: download the CIFAR-10 python archive into /tmp/.
+    dirpath = '/tmp/'
+    url = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
+    gzfile = dirpath + 'cifar-10-python.tar.gz'
+    do_download(dirpath=dirpath, gzfile=gzfile, url=url)