You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2021/05/20 09:48:22 UTC
[singa] branch dev updated: data downloading for cifar-10
This is an automated email from the ASF dual-hosted git repository.
zhaojing pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git
The following commit(s) were added to refs/heads/dev by this push:
new db64353 data downloading for cifar-10
new 4dab111 Merge pull request #850 from zlheui/distributed-cnn-cifar-dataset
db64353 is described below
commit db6435391982b6138803bd7f08d5901622ff27c0
Author: zhulei <zl...@gmail.com>
AuthorDate: Thu May 20 17:08:07 2021 +0800
data downloading for cifar-10
---
examples/cifar_distributed_cnn/data/cifar10.py | 91 ++++++++++++++++++++++
.../cifar_distributed_cnn/data/download_cifar10.py | 49 ++++++++++++
2 files changed, 140 insertions(+)
diff --git a/examples/cifar_distributed_cnn/data/cifar10.py b/examples/cifar_distributed_cnn/data/cifar10.py
new file mode 100644
index 0000000..3b83ad7
--- /dev/null
+++ b/examples/cifar_distributed_cnn/data/cifar10.py
@@ -0,0 +1,91 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+try:
+ import pickle
+except ImportError:
+ import cPickle as pickle
+
+import numpy as np
+import os
+import sys
+
+
def load_dataset(filepath):
    """Read one pickled CIFAR-10 batch file and return (images, labels).

    Images come back as a uint8 array of shape (N, 3, 32, 32); labels as a
    uint8 column vector of shape (N, 1).
    """
    with open(filepath, 'rb') as fin:
        try:
            # Python 3: the batches were pickled by Python 2, so an explicit
            # latin1 encoding is required to decode the byte strings.
            batch = pickle.load(fin, encoding='latin1')
        except TypeError:
            # Python 2's pickle.load has no 'encoding' keyword.
            batch = pickle.load(fin)
    images = batch['data'].astype(dtype=np.uint8).reshape((-1, 3, 32, 32))
    labels = np.asarray(batch['labels'], dtype=np.uint8).reshape(-1, 1)
    return images, labels
+
+
def load_train_data(dir_path='/scratch/snx3000/lyongbin/singa_my/cifar10_log/cifar-10-batches-py', num_batches=5):
    """Load the CIFAR-10 training batches from *dir_path*.

    Args:
        dir_path: directory holding the extracted 'data_batch_*' files.
            NOTE(review): the default is a cluster-specific absolute path;
            consider making it configurable for other environments.
        num_batches: number of 10000-image batches to load (1..5).

    Returns:
        (images, labels): float32 array of shape (num_batches*10000, 3, 32, 32)
        and int32 array of shape (num_batches*10000, 1).
    """
    batchsize = 10000
    images = np.empty((num_batches * batchsize, 3, 32, 32), dtype=np.uint8)
    label_parts = []
    for did in range(1, num_batches + 1):
        # Batch files are named data_batch_1 .. data_batch_5.
        fname_train_data = os.path.join(dir_path, "data_batch_{}".format(did))
        image, label = load_dataset(check_dataset_exist(fname_train_data))
        images[(did - 1) * batchsize:did * batchsize] = image
        label_parts.append(label)
    # Concatenate whole (10000, 1) label blocks instead of extending a
    # Python list row-by-row; the result is identical.
    images = images.astype(np.float32)
    labels = np.concatenate(label_parts).astype(np.int32)
    return images, labels
+
+
def load_test_data(dir_path='/scratch/snx3000/lyongbin/singa_my/cifar10_log/cifar-10-batches-py'):
    """Load the CIFAR-10 test batch from *dir_path*.

    Returns:
        (images, labels): float32 array of shape (10000, 3, 32, 32) and
        int32 array of shape (10000, 1).
    """
    fname = os.path.join(dir_path, "test_batch")
    images, labels = load_dataset(check_dataset_exist(fname))
    return np.array(images, dtype=np.float32), np.array(labels, dtype=np.int32)
+
+
def check_dataset_exist(dirpath):
    """Verify that the dataset path *dirpath* exists.

    Returns *dirpath* unchanged when it exists; otherwise prints a download
    hint and terminates the process.
    """
    if not os.path.exists(dirpath):
        print(
            'Please download the cifar10 dataset using python data/download_cifar10.py'
        )
        # Exit non-zero: a missing dataset is an error condition, and the
        # original sys.exit(0) reported success to calling shell scripts.
        sys.exit(1)
    return dirpath
+
+
def normalize(train_x, val_x):
    """Standardize CIFAR-10 images in place, per channel.

    Both arrays are expected to be float images with channel axis 1
    (N, 3, H, W) and pixel values in [0, 255]; they are scaled to [0, 1]
    and then shifted/scaled by the canonical CIFAR-10 channel statistics.

    Returns the (modified) train_x and val_x.
    """
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
    train_x /= 255
    val_x /= 255
    # Fix: the original looped over range(0, 2) and silently skipped the
    # third (blue) channel; all three channels must be standardized.
    for ch in range(3):
        train_x[:, ch, :, :] -= mean[ch]
        train_x[:, ch, :, :] /= std[ch]
        val_x[:, ch, :, :] -= mean[ch]
        val_x[:, ch, :, :] /= std[ch]
    return train_x, val_x
+
def load():
    """Return normalized CIFAR-10 data as (train_x, train_y, val_x, val_y).

    Images are float32 (N, 3, 32, 32) after per-channel standardization;
    labels are flattened to 1-D int32 arrays.
    """
    train_x, train_y = load_train_data()
    val_x, val_y = load_test_data()
    train_x, val_x = normalize(train_x, val_x)
    return train_x, train_y.flatten(), val_x, val_y.flatten()
diff --git a/examples/cifar_distributed_cnn/data/download_cifar10.py b/examples/cifar_distributed_cnn/data/download_cifar10.py
new file mode 100644
index 0000000..a010b2e
--- /dev/null
+++ b/examples/cifar_distributed_cnn/data/download_cifar10.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+from future import standard_library
+standard_library.install_aliases()
+import urllib.request, urllib.parse, urllib.error
+import tarfile
+import os
+import sys
+
+
def extract_tarfile(filepath):
    """Extract *filepath* into /tmp/ if it exists, then exit the process.

    Does nothing when the archive is absent (the caller then downloads it).
    """
    if not os.path.exists(filepath):
        return
    print('The tar file does exist. Extracting it now..')
    # NOTE(review): extractall on an untrusted archive can write outside the
    # destination directory (tar path traversal). The CIFAR tarball is a
    # well-known source, but consider tarfile's extraction filters on
    # newer Python versions.
    with tarfile.open(filepath, 'r') as archive:
        archive.extractall('/tmp/')
    print('Finished!')
    sys.exit(0)
+
+
def do_download(dirpath, gzfile, url):
    # Download the CIFAR-10 tarball from `url` to the local path `gzfile`,
    # then hand it to extract_tarfile().
    # NOTE(review): `dirpath` is unused inside this function.
    print('Downloading CIFAR from %s' % (url))
    urllib.request.urlretrieve(url, gzfile)
    # extract_tarfile() calls sys.exit(0) after a successful extraction, so
    # the 'Finished!' print below is only reached when the archive was
    # missing and extraction was skipped.
    extract_tarfile(gzfile)
    print('Finished!')
+
+
if __name__ == '__main__':
    # Download the python-pickled CIFAR-10 archive into /tmp and extract it.
    target_dir = '/tmp/'
    archive = target_dir + 'cifar-10-python.tar.gz'
    source_url = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
    do_download(target_dir, archive, source_url)