Posted to dev@singa.apache.org by GitBox <gi...@apache.org> on 2020/03/11 11:47:50 UTC

[GitHub] [singa] XJDKC opened a new pull request #626: [WIP] SINGA-505 Computational graph with memory optimization

XJDKC opened a new pull request #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626
 
 
   # Overview
   This PR adds the computational graph with memory optimization. It is based on the code developed by @chrishkchris and @XJDKC and some discussions with @nudles.
   
   # Features
   There are three main features in this PR: the construction of the computational graph, lazy allocation, and automatic recycling. Details are as follows:
   * `Computational graph construction`: Construct a computational graph from the user-defined neural network or expressions, then run the graph to accomplish the training task.
   * `Lazy allocation`: When a block is created, the device does not allocate memory for it immediately. Memory is allocated only when an operation accesses the block for the first time.
   * `Automatic recycling`: While the graph is running an iteration, automatically deallocate the intermediate tensors that will not be used by any subsequent operation.
   
   # Design
   1. Computational graph construction
       * Use delayed execution to run the forward and backward propagation once without actually performing the operations. Buffer all the operations and the tensors read or written by each operation.
       * Calculate the dependencies between all the operations to decide the order of execution (supports directed cyclic graphs).
       * Execute all the operations in the calculated order to update all the parameters.
   2. Lazy allocation
       * When a device creates a new block, pass the device to the block instead of allocating a piece of memory from the mempool and passing the pointer to the block.
       * When the block is accessed for the first time, the block asks its device to allocate the memory, and then the access proceeds.
   3. Automatic recycling
       * While calculating the dependencies between operations during graph construction, the reference count of each tensor can be calculated as well.
       * When an operation completes, decrease the reference count of each tensor the operation used.
       * If a tensor's reference count reaches zero, no later operation will access the tensor, so its memory can be recycled.
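The three design steps can be illustrated with a small pure-Python sketch. It is a toy model under stated assumptions, not SINGA's C++ implementation: the `Graph` class, its `buffer`/`schedule`/`run` methods, and the dictionary standing in for device memory are all hypothetical. Operations are recorded together with the tensors they read and write, dependencies (read-after-write, write-after-read, write-after-write) are derived from that record, and a per-tensor count of pending reads decides when an intermediate tensor can be dropped. Unlike the design described above, this sketch only handles the acyclic dependencies that follow from buffering order.

```python
from collections import defaultdict, deque


class Graph:
    """Toy operation buffer (hypothetical; not SINGA's Graph class)."""

    def __init__(self):
        self.ops = []  # (name, fn, reads, writes), in buffering order

    def buffer(self, name, fn, reads, writes):
        # Delayed execution: record the operation instead of running it.
        self.ops.append((name, fn, tuple(reads), tuple(writes)))

    def schedule(self):
        """Return a dependency-respecting order and the read count of each tensor."""
        deps = defaultdict(set)      # op index -> ops it must wait for
        last_writer = {}             # tensor -> index of its latest writer
        readers = defaultdict(list)  # tensor -> readers since that write
        refcount = defaultdict(int)  # tensor -> number of buffered reads
        for i, (_, _, reads, writes) in enumerate(self.ops):
            for t in reads:
                refcount[t] += 1
                if t in last_writer:
                    deps[i].add(last_writer[t])  # read after write
                readers[t].append(i)
            for t in writes:
                if t in last_writer:
                    deps[i].add(last_writer[t])  # write after write
                deps[i].update(j for j in readers[t] if j != i)  # write after read
                last_writer[t] = i
                readers[t] = []
        # Kahn's algorithm: topological sort over the dependency edges.
        indegree = {i: len(deps[i]) for i in range(len(self.ops))}
        children = defaultdict(list)
        for i in range(len(self.ops)):
            for j in sorted(deps[i]):
                children[j].append(i)
        ready = deque(i for i, d in indegree.items() if d == 0)
        order = []
        while ready:
            i = ready.popleft()
            order.append(i)
            for c in children[i]:
                indegree[c] -= 1
                if indegree[c] == 0:
                    ready.append(c)
        return order, dict(refcount)

    def run(self, memory, persistent=()):
        """Execute the buffered ops once; free tensors whose reads are all done."""
        order, refcount = self.schedule()
        freed = []
        for i in order:
            _, fn, reads, writes = self.ops[i]
            outputs = fn(*(memory[t] for t in reads))
            memory.update(zip(writes, outputs))
            for t in reads:
                refcount[t] -= 1
                if refcount[t] == 0 and t not in persistent:
                    del memory[t]  # automatic recycling
                    freed.append(t)
        return freed


g = Graph()
# Toy "network": h = x*w, y = h + b, loss = y*y, then one SGD step on w (lr = 0.1).
g.buffer("matmul", lambda x, w: (x * w,), ["x", "w"], ["h"])
g.buffer("add_bias", lambda h, b: (h + b,), ["h", "b"], ["y"])
g.buffer("square", lambda y: (y * y,), ["y"], ["loss"])
g.buffer("sgd", lambda w, y, x: (w - 0.1 * 2 * y * x,), ["w", "y", "x"], ["w"])

memory = {"x": 2.0, "w": 0.5, "b": 1.0}
freed = g.run(memory, persistent={"x", "w", "b"})
print(freed)           # ['h', 'y']
print(memory["loss"])  # 4.0
```

Inputs and parameters are marked persistent so they survive the iteration; only the intermediates `h` and `y` are released, each right after its last reader finishes.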
   
   # Changes
   * `Tensor`&`Operation`
       * Change the capture type of tensors in lambda expressions to achieve delayed execution.
       * Change the type of input and output parameters to ensure that the input and output of the operation are tensors.
   * `Device`: Add code for
       * buffering operations
       * constructing the graph
       * calculating dependencies
       * executing the graph
   * `Block`: Add a member variable of type `Device` to support lazy allocation, and a function to support automatic recycling.
   * `Swig`: Add some interfaces.
   * `Examples`: Add some examples that use operation buffering.
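A minimal sketch of the `Block` change, assuming a block only needs its size and a device handle: memory is requested from the device on first access and handed back when the recycling hook is called. `ToyDevice`, `ToyBlock` and their methods are hypothetical names for illustration, not SINGA's C++ classes; the toy device only counts live bytes so that the effect on peak memory is visible.

```python
class ToyDevice:
    """Hypothetical device that tracks live and peak allocated bytes."""

    def __init__(self):
        self.live_bytes = 0
        self.peak_bytes = 0

    def malloc(self, size):
        self.live_bytes += size
        self.peak_bytes = max(self.peak_bytes, self.live_bytes)
        return bytearray(size)

    def free(self, size):
        self.live_bytes -= size


class ToyBlock:
    """Holds a device handle instead of a pointer, so allocation can be lazy."""

    def __init__(self, size, device):
        self.size = size
        self.device = device
        self._data = None  # no memory yet

    def initialized(self):
        return self._data is not None

    def mutable_data(self):
        if self._data is None:  # first access: allocate now
            self._data = self.device.malloc(self.size)
        return self._data

    def free_data(self):
        """Recycling hook, e.g. called when the tensor's reference count hits zero."""
        if self._data is not None:
            self.device.free(self.size)
            self._data = None


dev = ToyDevice()
a, b = ToyBlock(1024, dev), ToyBlock(2048, dev)
print(dev.live_bytes)                  # 0: creating blocks allocates nothing
a.mutable_data()                       # first access allocates a
a.free_data()                          # a is no longer needed: recycle it
b.mutable_data()
print(dev.live_bytes, dev.peak_bytes)  # 2048 2048
```

Because `a` is released before `b` is first touched, the peak stays at 2048 bytes; if both blocks were allocated eagerly and kept alive, the peak would be 3072 bytes.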
   
   # Evaluation
   * Experiment settings
       * Model: ResNet50 in [resnet.py](../tree/dev/examples/autograd/resnet.py)
       * GPU: Nvidia RTX 2080Ti
   * Results (`s` = second, `b` = batch):
   <table>
       <tr>
           <th style="text-align: center">Batchsize</th>
           <th style="text-align: center">Cases</th>
           <th style="text-align: center">Peak Memory Usage</th>
           <th style="text-align: center">Throughput</th>
           <th style="text-align: center">Time per Batch</th>
           <th style="text-align: center">Memory Reduction</th>
           <th style="text-align: center">Speedup</th>
       </tr>
       <tr>
           <td rowspan="3">16</td>
           <td>dev branch</td>
           <td>4961MB</td>
           <td>176.9182/s</td>
           <td>0.0903s/b</td>
           <td>0.00%</td>
           <td>1.0000</td>
       </tr>
       <tr>
           <td>PR(no graph)</td>
           <td>4961MB</td>
           <td>173.6726/s</td>
           <td>0.0921s/b</td>
           <td>0.00%</td>
           <td>0.9796</td>
       </tr>
       <tr>
           <td>PR(with graph)</td>
           <td>3105MB</td>
           <td>218.4999/s</td>
           <td>0.0732s/b</td>
           <td>37.41%</td>
           <td>1.2325</td>
       </tr>
       <tr>
           <td rowspan="3">32</td>
           <td>dev branch</td>
           <td>10113MB</td>
           <td>203.3363/s</td>
           <td>0.1574s/b</td>
           <td>0.00%</td>
           <td>1.0000</td>
       </tr>
       <tr>
           <td>PR(no graph)</td>
           <td>10123MB</td>
           <td>202.3836/s</td>
           <td>0.1581s/b</td>
           <td>-0.10%</td>
           <td>0.9953</td>
       </tr>
       <tr>
           <td>PR(with graph)</td>
           <td>6517MB</td>
           <td>234.1376/s</td>
           <td>0.1367s/b</td>
           <td>35.56%</td>
           <td>1.1515</td>
       </tr>
   </table>
   
   From the table above, we can see that:
   * With the graph disabled, this PR has little effect on training time and memory usage, so it is backward compatible.
   * With the graph enabled, this PR reduces peak memory usage by 35.56% to 37.41% and speeds up training by a factor of 1.15 to 1.23.
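The Memory Reduction column is consistent with the relative saving in peak memory against the `dev branch` row of the same batch size, i.e. (baseline - peak) / baseline. A quick recomputation in Python; the numbers are copied from the table and the script only checks the arithmetic:

```python
# Peak memory (MB) per batch size and case, copied from the table above.
peak_mb = {
    16: {"dev branch": 4961, "PR(no graph)": 4961, "PR(with graph)": 3105},
    32: {"dev branch": 10113, "PR(no graph)": 10123, "PR(with graph)": 6517},
}


def reduction_rate(baseline, value):
    """Percentage of peak memory saved relative to the baseline."""
    return 100.0 * (baseline - value) / baseline


rates = {}
for batch, cases in peak_mb.items():
    base = cases["dev branch"]
    for case, mb in cases.items():
        rates[(batch, case)] = round(reduction_rate(base, mb), 2)
        print(f"b{batch} {case:15s} {rates[(batch, case)]:6.2f}%")
```

A negative value, as for `PR(no graph)` at batch size 32, means the peak memory was slightly higher than the baseline.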
   
   # How to use
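   ```python
   from singa import autograd, device, optimizer

   dev = device.get_default_device()
   autograd.training = True
   sgd = optimizer.SGD(0.05)

   # Initialize the input tensors (inputs, target) and parameters (w0, b0, w1, b1)
   # ...

   # Buffer the operations of one iteration: forward, backward and parameter update
   dev.EnableGraph(True)
   x = autograd.matmul(inputs, w0)
   x = autograd.add_bias(x, b0)
   x = autograd.relu(x)
   x = autograd.matmul(x, w1)
   x = autograd.add_bias(x, b1)
   # x = autograd.softmax(x)
   loss = autograd.softmax_cross_entropy(x, target)
   for p, gp in autograd.backward(loss):
       sgd.apply(0, gp, p, "")
   dev.EnableGraph(False)

   # Run the buffered graph; each call executes one iteration
   print("start executing buffered functions")
   for i in range(1001):
       dev.RunGraph()
   ```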
   
   # Plan
   - [ ] Computational graph optimization: replace a subgraph of the input computational graph with a functionally equivalent subgraph.
   - [ ] Perform operations in parallel.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [singa] lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-605433692
 
 
   This pull request **introduces 4 alerts** and **fixes 1** when merging 0f7eeec60740a0c120794338a29459b3b7a8c648 into 38fa20b74ba6ed414e4ecce65d42063ef1662564 - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-638b106e7bd2f9b1f0699480c10b9312a1175e7a)
   
   **new alerts:**
   
   * 2 for Unused local variable
   * 1 for Mismatch between signature and use of an overridden method
   * 1 for Redundant assignment
   
   **fixed alerts:**
   
   * 1 for Missing return statement


[GitHub] [singa] lgtm-com[bot] commented on issue #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-599933370
 
 
   This pull request **fixes 1 alert** when merging fbb7a869a78dbb8108f7b9973b2c067882377f79 into 5b478113f80adda5ffc6009cd6539a7c9f47f76d - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-43f5110edfe17800aec0838e7c52b32913acd040)
   
   **fixed alerts:**
   
   * 1 for Missing return statement


[GitHub] [singa] nudles merged pull request #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
nudles merged pull request #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626
 
 
   


[GitHub] [singa] lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-605487329
 
 
   This pull request **fixes 1 alert** when merging f7a200c02c23695c8f2c78020fb7185f4829e2b0 into 38fa20b74ba6ed414e4ecce65d42063ef1662564 - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-61edfdacaa06507eb767a06579a30a716309d376)
   
   **fixed alerts:**
   
   * 1 for Missing return statement


[GitHub] [singa] nudles commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
nudles commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#discussion_r393497630
 
 

 ##########
 File path: examples/autograd/mlp_buffer.py
 ##########
 @@ -0,0 +1,112 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import tensor
+from singa.tensor import Tensor
+from singa import autograd
+from singa import optimizer
+from singa import device
+import numpy as np
+
+
+if __name__ == "__main__":
+    dev = device.get_default_device()
+
+    autograd.training = True
+    np.random.seed(0)
+
+    # prepare training data in numpy array
+
+    # generate the boundary
+    f = lambda x: (5 * x + 1)
+    bd_x = np.linspace(-1.0, 1, 200)
+    bd_y = f(bd_x)
+    # generate the training data
+    x = np.random.uniform(-1, 1, 400)
+    y = f(x) + 2 * np.random.randn(len(x))
+    # convert training data to 2d space
+    label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)])
+    data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np.float32)
+
+    def to_categorical(y, num_classes):
+        """
+        Converts a class vector (integers) to binary class matrix.
+
+        Args
+            y: class vector to be converted into a matrix
+                (integers from 0 to num_classes - 1).
+            num_classes: total number of classes.
+
+        Return
+            A binary matrix representation of the input.
+        """
+        y = np.array(y, dtype="int")
+        n = y.shape[0]
+        categorical = np.zeros((n, num_classes))
+        categorical[np.arange(n), y] = 1
+        return categorical
+
+    label = to_categorical(label, 2).astype(np.float32)
+    print("train_data_shape:", data.shape)
+    print("train_label_shape:", label.shape)
+
+    inputs = Tensor(data=data, device=dev)
+    target = Tensor(data=label, device=dev)
+
+    w0 = Tensor(shape=(2, 3), device=dev, requires_grad=True, stores_grad=True)
+    w0.gaussian(0.0, 0.1)
+    b0 = Tensor(shape=(1, 3), device=dev, requires_grad=True, stores_grad=True)
+    b0.set_value(0.0)
+
+    w1 = Tensor(shape=(3, 2), device=dev, requires_grad=True, stores_grad=True)
+    w1.gaussian(0.0, 0.1)
+    b1 = Tensor(shape=(1, 2), device=dev, requires_grad=True, stores_grad=True)
+    b1.set_value(0.0)
+
+    print("finished init inputs")
+    print("w0:\n", tensor.to_numpy(w0))
+    print("b0:\n", tensor.to_numpy(b0))
+    print("w1:\n", tensor.to_numpy(w1))
+    print("b1:\n", tensor.to_numpy(b1))
+
+    sgd = optimizer.SGD(0.05)
+
+    # training process
+    print("start training")
+
+    # Buffer the operations
+    dev.EnableGraph(True)
 
 Review comment:
   It would be better to make the code the same whether the graph is on or off, so that we can switch between them easily.
   Currently this piece of code only works when the graph is used.
   Without the graph, we would need to write another piece of code.
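One way to follow this suggestion is to write the training step once and let a flag decide whether it is executed eagerly every iteration or buffered once and replayed. The sketch below is pure Python: `FakeDevice` is a hypothetical stand-in whose `EnableGraph`/`RunGraph` method names mirror the ones used in this example file, while `exec_op`, `train` and `step` are invented for illustration and are not SINGA APIs.

```python
class FakeDevice:
    """Hypothetical stand-in for a device that can buffer operations."""

    def __init__(self):
        self.buffering = False
        self.graph = []

    def EnableGraph(self, flag):
        self.buffering = flag

    def exec_op(self, op):
        if self.buffering:
            self.graph.append(op)  # delayed execution: record only
        else:
            op()                   # eager execution

    def RunGraph(self):
        for op in self.graph:
            op()


def train(dev, step, num_iters, use_graph):
    """Run the same step function with the graph switched on or off."""
    if use_graph:
        dev.EnableGraph(True)
        step(dev)  # buffer one iteration
        dev.EnableGraph(False)
        for _ in range(num_iters):
            dev.RunGraph()
    else:
        for _ in range(num_iters):
            step(dev)


state = {"w": 1.0}


def step(dev):
    # Toy "update": halve w once per iteration.
    dev.exec_op(lambda: state.update(w=state["w"] * 0.5))


results = {}
for use_graph in (False, True):
    state["w"] = 1.0
    train(FakeDevice(), step, num_iters=3, use_graph=use_graph)
    results[use_graph] = state["w"]
print(results)  # {False: 0.125, True: 0.125}
```

The model code (`step`) is identical in both modes; only the single `use_graph` flag changes.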


[GitHub] [singa] XJDKC commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
XJDKC commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#discussion_r393512188
 
 

 ##########
 File path: examples/autograd/mlp_buffer.py
 ##########
 @@ -0,0 +1,112 @@
+    x = autograd.matmul(inputs, w0)
+    x = autograd.add_bias(x, b0)
+    x = autograd.relu(x)
+    x = autograd.matmul(x, w1)
+    x = autograd.add_bias(x, b1)
+    # x = autograd.softmax(x)
+    loss = autograd.softmax_cross_entropy(x, target)
+    print("start backward")
+    for p, gp in autograd.backward(loss):
+        sgd.apply(0, gp, p, "")
+    dev.EnableGraph(False)
+
+    # exec the buffered ops
+    print("start executing buffered functions")
+    for i in range(1001):
+        dev.RunGraph()
 
 Review comment:
   Yes, just like this.
   ```python
   # Copy the batch data into the input tensors
   tx.copy_from_numpy(x)
   ty.copy_from_numpy(y)
   ```
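The replayed graph keeps using the tensors it captured while buffering, which is why each new batch has to be copied into `tx` and `ty` in place rather than assigned to new tensors. The NumPy sketch below illustrates this under that assumption: `np.copyto` plays the role of `copy_from_numpy`, and `make_matmul`/`run_graph` are hypothetical helpers, not SINGA APIs.

```python
import numpy as np


def make_matmul(a, b, c):
    """Buffer c = a @ b; the closure captures the array objects, not copies."""
    return lambda: np.matmul(a, b, out=c)


tx = np.zeros((2, 3), dtype=np.float32)  # fixed input buffer (plays the role of tx)
w = np.ones((3, 1), dtype=np.float32)
out = np.zeros((2, 1), dtype=np.float32)
graph = [make_matmul(tx, w, out)]        # "buffered" once


def run_graph():
    for op in graph:
        op()


outputs = []
for i in range(1, 4):
    batch = np.full((2, 3), i, dtype=np.float32)
    np.copyto(tx, batch)  # in-place copy, analogous to tx.copy_from_numpy(x)
    run_graph()
    outputs.append(float(out[0, 0]))
print(outputs)  # [3.0, 6.0, 9.0]

# Rebinding the name instead of copying has no effect on the buffered graph.
tx = np.full((2, 3), 100, dtype=np.float32)
run_graph()
print(float(out[0, 0]))  # still 9.0
```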


[GitHub] [singa] nudles commented on issue #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
nudles commented on issue #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-605542596
 
 
   Have you rebased onto (or merged with) the latest dev?


[GitHub] [singa] chrishkchris commented on issue #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
chrishkchris commented on issue #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-597994780
 
 
   Please correct the result of the b32 PR (no graph) row; it looks like a typo.


[GitHub] [singa] lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-601080498
 
 
   This pull request **fixes 1 alert** when merging 2018eba08bb7d69d8b23b9097375484251041e46 into 07ff7d937d73b350bd027e6890b567e8d1601dc5 - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-5855b11f4dddfcd3cd722d0cfecc39fa60458d6e)
   
   **fixed alerts:**
   
   * 1 for Missing return statement


[GitHub] [singa] lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-604639314
 
 
   This pull request **introduces 1 alert** and **fixes 1** when merging 2eabbb76fecbe2453f5b15b12c4d22c14bf28ebc into 38fa20b74ba6ed414e4ecce65d42063ef1662564 - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-e35a46c1fc28f066480a8fc66db0d2ec3e9b560a)
   
   **new alerts:**
   
   * 1 for Unused import
   
   **fixed alerts:**
   
   * 1 for Missing return statement


[GitHub] [singa] nudles commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
nudles commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#discussion_r393498170
 
 

 ##########
 File path: examples/autograd/mnist_cnn_buffer.py
 ##########
 @@ -0,0 +1,271 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import singa_wrap as singa
+from singa import autograd
+from singa import tensor
+from singa import device
+from singa import opt
+import numpy as np
+import os
+import sys
+import gzip
+import codecs
+import time
+
+class CNN:
+    def __init__(self):
+        self.conv1 = autograd.Conv2d(1, 20, 5, padding=0)
+        self.conv2 = autograd.Conv2d(20, 50, 5, padding=0)
+        self.linear1 = autograd.Linear(4 * 4 * 50, 500)
+        self.linear2 = autograd.Linear(500, 10)
+        self.pooling1 = autograd.MaxPool2d(2, 2, padding=0)
+        self.pooling2 = autograd.MaxPool2d(2, 2, padding=0)
+
+    def forward(self, x):
+        y = self.conv1(x)
+        y = autograd.relu(y)
+        y = self.pooling1(y)
+        y = self.conv2(y)
+        y = autograd.relu(y)
+        y = self.pooling2(y)
+        y = autograd.flatten(y)
+        y = self.linear1(y)
+        y = autograd.relu(y)
+        y = self.linear2(y)
+        return y
+
+def check_dataset_exist(dirpath):
+    if not os.path.exists(dirpath):
+        print('The MNIST dataset does not exist. Please download the mnist dataset using download_mnist.py (e.g. python3 download_mnist.py)')
+        sys.exit(0)
+    return dirpath
+
+def load_dataset():
+    train_x_path = '/tmp/train-images-idx3-ubyte.gz'
+    train_y_path = '/tmp/train-labels-idx1-ubyte.gz'
+    valid_x_path = '/tmp/t10k-images-idx3-ubyte.gz'
+    valid_y_path = '/tmp/t10k-labels-idx1-ubyte.gz'
+
+    train_x = read_image_file(check_dataset_exist(train_x_path)).astype(
+        np.float32)
+    train_y = read_label_file(check_dataset_exist(train_y_path)).astype(
+        np.float32)
+    valid_x = read_image_file(check_dataset_exist(valid_x_path)).astype(
+        np.float32)
+    valid_y = read_label_file(check_dataset_exist(valid_y_path)).astype(
+        np.float32)
+    return train_x, train_y, valid_x, valid_y
+
+def read_label_file(path):
+    with gzip.open(path, 'rb') as f:
+        data = f.read()
+        assert get_int(data[:4]) == 2049
+        length = get_int(data[4:8])
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=8).reshape(
+            (length))
+        return parsed
+
+def get_int(b):
+    return int(codecs.encode(b, 'hex'), 16)
+
+def read_image_file(path):
+    with gzip.open(path, 'rb') as f:
+        data = f.read()
+        assert get_int(data[:4]) == 2051
+        length = get_int(data[4:8])
+        num_rows = get_int(data[8:12])
+        num_cols = get_int(data[12:16])
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(
+            (length, 1, num_rows, num_cols))
+        return parsed
+
+def to_categorical(y, num_classes):
+    y = np.array(y, dtype="int")
+    n = y.shape[0]
+    categorical = np.zeros((n, num_classes))
+    categorical[np.arange(n), y] = 1
+    categorical = categorical.astype(np.float32)
+    return categorical
+
+
+def accuracy(pred, target):
+    y = np.argmax(pred, axis=1)
+    t = np.argmax(target, axis=1)
+    a = y == t
+    return np.array(a, "int").sum()
+
+# Function to all-reduce NumPy accuracy and loss values across multiple devices
+def reduce_variable(variable, dist_opt, reducer):
+    reducer.copy_from_numpy(variable)
+    dist_opt.all_reduce(reducer.data)
+    dist_opt.wait()
+    output=tensor.to_numpy(reducer)
+    return output
+
+# Function to synchronize the initial model parameters (SINGA tensors) across devices
+def sychronize(tensor, dist_opt):
+    dist_opt.all_reduce(tensor.data)
+    dist_opt.wait()
+    tensor /= dist_opt.world_size
+
+# Data augmentation
+def augmentation(x, batch_size):
+    xpad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'symmetric')
+    for data_num in range(0, batch_size):
+        offset = np.random.randint(8, size=2)
+        x[data_num,:,:,:] = xpad[data_num, :, offset[0]: offset[0] + 28, offset[1]: offset[1] + 28]
+        if_flip = np.random.randint(2)
+        if (if_flip):
+            x[data_num, :, :, :] = x[data_num, :, :, ::-1]
+    return x
+
+def train_mnist_cnn(sgd, max_epoch, batch_size, DIST=False, data_partition=None,
+                    gpu_num=None, gpu_per_node=None, nccl_id=None, spars=0, topK=False, corr=True):
+    # Prepare training and validation data
+    train_x, train_y, test_x, test_y = load_dataset()
+    IMG_SIZE = 28
+    num_classes = 10
+    train_y = to_categorical(train_y, num_classes)
+    test_y = to_categorical(test_y, num_classes)
+
+    # Normalization
+    train_x = train_x / 255
+    test_x = test_x / 255
+
+    if DIST:
+        # For Distributed GPU Training
+        sgd = opt.DistOpt(sgd, nccl_id=nccl_id, gpu_num=gpu_num, gpu_per_node=gpu_per_node)
+        dev = device.create_cuda_gpu_on(sgd.rank_in_local)
+        # Dataset partition for distributed training
+        train_x, train_y = data_partition(train_x, train_y, sgd.rank_in_global, sgd.world_size)
+        test_x, test_y = data_partition(test_x, test_y, sgd.rank_in_global, sgd.world_size)
+        world_size = sgd.world_size
+    else:
+        # For Single GPU
+        print("for single GPU")
+        dev = device.create_cuda_gpu_on(0)
+        device.set_default_device(dev)
+        dev.SetRandSeed(0)
+        world_size = 1
+
+    # create model
+    print("create the model")
+    model = CNN()
+
+    # create input tensors
+    print("create input tensors")
+    tx = tensor.Tensor((batch_size, 1, IMG_SIZE, IMG_SIZE), dev, tensor.float32)
+    ty = tensor.Tensor((batch_size, num_classes), dev, tensor.int32)
+    num_train_batch = train_x.shape[0] // batch_size
+    num_test_batch = test_x.shape[0] // batch_size
+    idx = np.arange(train_x.shape[0], dtype=np.int32)
+
+    if DIST:
+        # Synchronize the initial parameters
+        autograd.training = True
+        x = np.random.randn(batch_size, 1, IMG_SIZE, IMG_SIZE).astype(np.float32)
+        y = np.zeros( shape=(batch_size, num_classes), dtype=np.int32)
+        tx.copy_from_numpy(x)
+        ty.copy_from_numpy(y)
+        out = model.forward(tx)
+        loss = autograd.softmax_cross_entropy(out, ty)
+        for p, g in autograd.backward(loss):
+            sychronize(p, sgd)
+
+    # buffer all the operations in one iteration
+    print("buffer all the operations")
+    dev.EnableGraph(True)
+    autograd.training = True
+    out = model.forward(tx)
+    loss = autograd.softmax_cross_entropy(out, ty)
+    print("buffer softmax_cross_entropy")
+    for p, g in autograd.backward(loss):
+        sgd.update(p, g)
+        # print("update sgd")
+    autograd.training = False
+    dev.EnableGraph(False)
 
 Review comment:
   Does it work for distributed training?


[GitHub] [singa] lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-602077786
 
 
   This pull request **fixes 1 alert** when merging 84d7f875d90d85b0e31fe75938343797d7c6cdb8 into 07ff7d937d73b350bd027e6890b567e8d1601dc5 - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-729b3a79b6cc037e3038e49e0bea0d7903986798)
   
   **fixed alerts:**
   
   * 1 for Missing return statement


[GitHub] [singa] lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-602081170
 
 
   This pull request **fixes 1 alert** when merging 6031040fdbb5112f83697aaef68cf8da0d042495 into 07ff7d937d73b350bd027e6890b567e8d1601dc5 - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-2592209f26bf57cb7ac3007bd515e36f4a0b6b95)
   
   **fixed alerts:**
   
   * 1 for Missing return statement


[GitHub] [singa] XJDKC commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
XJDKC commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#discussion_r393500871
 
 

 ##########
 File path: examples/autograd/mlp_buffer.py
 ##########
 @@ -0,0 +1,112 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import tensor
+from singa.tensor import Tensor
+from singa import autograd
+from singa import optimizer
+from singa import device
+import numpy as np
+
+
+if __name__ == "__main__":
+    dev = device.get_default_device()
+
+    autograd.training = True
+    np.random.seed(0)
+
+    # prepare training data in numpy array
+
+    # generate the boundary
+    f = lambda x: (5 * x + 1)
+    bd_x = np.linspace(-1.0, 1, 200)
+    bd_y = f(bd_x)
+    # generate the training data
+    x = np.random.uniform(-1, 1, 400)
+    y = f(x) + 2 * np.random.randn(len(x))
+    # convert training data to 2d space
+    label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)])
+    data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np.float32)
+
+    def to_categorical(y, num_classes):
+        """
+        Converts a class vector (integers) to binary class matrix.
+
+        Args
+            y: class vector to be converted into a matrix
+                (integers from 0 to num_classes).
+            num_classes: total number of classes.
+
+        Return
+            A binary matrix representation of the input.
+        """
+        y = np.array(y, dtype="int")
+        n = y.shape[0]
+        categorical = np.zeros((n, num_classes))
+        categorical[np.arange(n), y] = 1
+        return categorical
+
+    label = to_categorical(label, 2).astype(np.float32)
+    print("train_data_shape:", data.shape)
+    print("train_label_shape:", label.shape)
+
+    inputs = Tensor(data=data, device=dev)
+    target = Tensor(data=label, device=dev)
+
+    w0 = Tensor(shape=(2, 3), device=dev, requires_grad=True, stores_grad=True)
+    w0.gaussian(0.0, 0.1)
+    b0 = Tensor(shape=(1, 3), device=dev, requires_grad=True, stores_grad=True)
+    b0.set_value(0.0)
+
+    w1 = Tensor(shape=(3, 2), device=dev, requires_grad=True, stores_grad=True)
+    w1.gaussian(0.0, 0.1)
+    b1 = Tensor(shape=(1, 2), device=dev, requires_grad=True, stores_grad=True)
+    b1.set_value(0.0)
+
+    print("finished init inputs")
+    print("w0:\n", tensor.to_numpy(w0))
+    print("b0:\n", tensor.to_numpy(b0))
+    print("w1:\n", tensor.to_numpy(w1))
+    print("b1:\n", tensor.to_numpy(b1))
+
+    sgd = optimizer.SGD(0.05)
+
+    # training process
+    print("start training")
+
+    # Buffer the operations
+    dev.EnableGraph(True)
+    x = autograd.matmul(inputs, w0)
+    x = autograd.add_bias(x, b0)
+    x = autograd.relu(x)
+    x = autograd.matmul(x, w1)
+    x = autograd.add_bias(x, b1)
+    # x = autograd.softmax(x)
+    loss = autograd.softmax_cross_entropy(x, target)
+    print("start backward")
+    for p, gp in autograd.backward(loss):
+        sgd.apply(0, gp, p, "")
+    dev.EnableGraph(False)
+
+    # exec the buffered ops
+    print("start executing buffered functions")
+    for i in range(1001):
+        dev.RunGraph()
 
 Review comment:
   In this example, we randomly generate the input data once and do not change it during training, so the loop only needs to call `RunGraph()`.
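The buffer-then-replay flow in this example (`EnableGraph(True)`, issue the forward, backward and update ops once, `EnableGraph(False)`, then call `RunGraph()` in a loop) can be illustrated with a minimal pure-Python sketch of delayed execution. `ToyDevice`, `submit` and `run_graph` are hypothetical names, not SINGA's API, and one-element lists stand in for tensors. Because `x` is never rewritten, every replay trains on the same input.

```python
class ToyDevice:
    """Hypothetical device that can buffer ops instead of running them (not SINGA's API)."""

    def __init__(self):
        self.graph_enabled = False
        self.ops = []  # buffered operations, in the order they were issued

    def enable_graph(self, flag):
        self.graph_enabled = flag

    def submit(self, fn):
        # In graph mode the op is only recorded; otherwise it runs eagerly.
        if self.graph_enabled:
            self.ops.append(fn)
        else:
            fn()

    def run_graph(self):
        # Replay all buffered ops once, in recorded order (one iteration).
        for fn in self.ops:
            fn()


def train(iterations, lr=0.1):
    dev = ToyDevice()
    # One-element lists act as "tensors" that ops update in place.
    x, w, grad = [2.0], [1.0], [0.0]

    def backward():
        # loss = 0.5 * (w * x) ** 2, so dloss/dw = w * x * x
        grad[0] = w[0] * x[0] * x[0]

    def sgd_update():
        w[0] -= lr * grad[0]

    dev.enable_graph(True)
    dev.submit(backward)    # recorded, not executed
    dev.submit(sgd_update)  # recorded, not executed
    dev.enable_graph(False)

    for _ in range(iterations):
        dev.run_graph()     # the same fixed input x is used on every replay
    return w[0]
```

Each replay multiplies `w` by 0.6 here, so after three calls to `run_graph()` the weight is about 0.216, while `train(0)` returns the untouched initial value.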


[GitHub] [singa] nudles commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
nudles commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#discussion_r393510526
 
 

 ##########
 File path: examples/autograd/mlp_buffer.py
 ##########
 
 Review comment:
   copy the data to a placeholder variable in the loop?


[GitHub] [singa] XJDKC commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
XJDKC commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#discussion_r393507692
 
 

 ##########
 File path: examples/autograd/mlp_buffer.py
 ##########
 
 Review comment:
   For real training, just call `copy_from_numpy` on the input tensors as usual to load each new batch before calling `RunGraph()`.
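To feed real data, the graph itself can stay untouched while the contents of its input tensors change. This sketch assumes that buffered ops hold references to the input buffer, so each new batch is copied into that same buffer in place before every replay. `np.copyto` plays the role of `copy_from_numpy`; `ToyGraph`, `build_graph` and `train_step` are hypothetical names, and the single recorded op (a row sum) is only a stand-in for a real forward pass.

```python
import numpy as np


class ToyGraph:
    """Hypothetical stand-in for a buffered graph: a list of recorded ops."""

    def __init__(self):
        self.ops = []

    def add(self, fn):
        self.ops.append(fn)

    def run(self):
        for fn in self.ops:
            fn()


def build_graph(tx, out):
    # The recorded op closes over these specific array objects.
    graph = ToyGraph()
    graph.add(lambda: np.sum(tx, axis=1, out=out))  # row sums written into out
    return graph


batch_size, dim = 2, 3
tx = np.zeros((batch_size, dim), dtype=np.float32)  # input placeholder
out = np.zeros(batch_size, dtype=np.float32)        # output buffer
graph = build_graph(tx, out)                        # buffered once


def train_step(batch):
    # Copy the new batch into the existing placeholder in place; binding a
    # new array to a name would not change what the recorded op reads.
    np.copyto(tx, batch)
    graph.run()
    return out.copy()
```

Calling `train_step` with different batches changes the output without rebuilding the graph, which keeps exactly one recorded op.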


[GitHub] [singa] lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-605604183
 
 
   This pull request **fixes 1 alert** when merging 743addce22af0d62517418c7c2f144ac71ffd926 into 38fa20b74ba6ed414e4ecce65d42063ef1662564 - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-deeccee15d7b05bf1ca2f8665d7a7d918536953c)
   
   **fixed alerts:**
   
   * 1 for Missing return statement


[GitHub] [singa] nudles edited a comment on issue #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
nudles edited a comment on issue #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-605542596
 
 
   Have you rebased onto or merged with the latest dev branch?
   Is it ready for merge?


[GitHub] [singa] XJDKC commented on issue #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
XJDKC commented on issue #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-597996012
 
 
   Noted, thanks. I will fix it.


[GitHub] [singa] lgtm-com[bot] commented on issue #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-599762491
 
 
   This pull request **introduces 3 alerts** and **fixes 1** when merging 93b8067606e53d2e3fd1cac1e7f9d487750bcc4f into 5b478113f80adda5ffc6009cd6539a7c9f47f76d - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-748d83de7a1e8db4630f9e57b99cb4d72c511c3e)
   
   **new alerts:**
   
   * 2 for Wrong type of arguments to formatting function
   * 1 for Unused local variable
   
   **fixed alerts:**
   
   * 1 for Missing return statement


[GitHub] [singa] lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-602189779
 
 
   This pull request **fixes 1 alert** when merging cd4639e6aeca5b4a8447a3db400f9cd9be8e332d into 07ff7d937d73b350bd027e6890b567e8d1601dc5 - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-0d7e974dd588f3f10210cacd0e16c702ef0449ec)
   
   **fixed alerts:**
   
   * 1 for Missing return statement


[GitHub] [singa] nudles commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
nudles commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#discussion_r393496835
 
 

 ##########
 File path: examples/autograd/mlp_buffer.py
 ##########
 
 Review comment:
   how is the data fed into the graph?


[GitHub] [singa] chrishkchris commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
chrishkchris commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#discussion_r393511167
 
 

 ##########
 File path: examples/autograd/mnist_cnn_buffer.py
 ##########
 @@ -0,0 +1,271 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import singa_wrap as singa
+from singa import autograd
+from singa import tensor
+from singa import device
+from singa import opt
+import numpy as np
+import os
+import sys
+import gzip
+import codecs
+import time
+
+class CNN:
+    def __init__(self):
+        self.conv1 = autograd.Conv2d(1, 20, 5, padding=0)
+        self.conv2 = autograd.Conv2d(20, 50, 5, padding=0)
+        self.linear1 = autograd.Linear(4 * 4 * 50, 500)
+        self.linear2 = autograd.Linear(500, 10)
+        self.pooling1 = autograd.MaxPool2d(2, 2, padding=0)
+        self.pooling2 = autograd.MaxPool2d(2, 2, padding=0)
+
+    def forward(self, x):
+        y = self.conv1(x)
+        y = autograd.relu(y)
+        y = self.pooling1(y)
+        y = self.conv2(y)
+        y = autograd.relu(y)
+        y = self.pooling2(y)
+        y = autograd.flatten(y)
+        y = self.linear1(y)
+        y = autograd.relu(y)
+        y = self.linear2(y)
+        return y
+
+def check_dataset_exist(dirpath):
+    if not os.path.exists(dirpath):
+        print('The MNIST dataset does not exist. Please download the mnist dataset using download_mnist.py (e.g. python3 download_mnist.py)')
+        sys.exit(0)
+    return dirpath
+
+def load_dataset():
+    train_x_path = '/tmp/train-images-idx3-ubyte.gz'
+    train_y_path = '/tmp/train-labels-idx1-ubyte.gz'
+    valid_x_path = '/tmp/t10k-images-idx3-ubyte.gz'
+    valid_y_path = '/tmp/t10k-labels-idx1-ubyte.gz'
+
+    train_x = read_image_file(check_dataset_exist(train_x_path)).astype(
+        np.float32)
+    train_y = read_label_file(check_dataset_exist(train_y_path)).astype(
+        np.float32)
+    valid_x = read_image_file(check_dataset_exist(valid_x_path)).astype(
+        np.float32)
+    valid_y = read_label_file(check_dataset_exist(valid_y_path)).astype(
+        np.float32)
+    return train_x, train_y, valid_x, valid_y
+
+def read_label_file(path):
+    with gzip.open(path, 'rb') as f:
+        data = f.read()
+        assert get_int(data[:4]) == 2049
+        length = get_int(data[4:8])
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=8).reshape(
+            (length))
+        return parsed
+
+def get_int(b):
+    return int(codecs.encode(b, 'hex'), 16)
+
+def read_image_file(path):
+    with gzip.open(path, 'rb') as f:
+        data = f.read()
+        assert get_int(data[:4]) == 2051
+        length = get_int(data[4:8])
+        num_rows = get_int(data[8:12])
+        num_cols = get_int(data[12:16])
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(
+            (length, 1, num_rows, num_cols))
+        return parsed
+
+def to_categorical(y, num_classes):
+    y = np.array(y, dtype="int")
+    n = y.shape[0]
+    categorical = np.zeros((n, num_classes))
+    categorical[np.arange(n), y] = 1
+    categorical = categorical.astype(np.float32)
+    return categorical
+
+
+def accuracy(pred, target):
+    y = np.argmax(pred, axis=1)
+    t = np.argmax(target, axis=1)
+    a = y == t
+    return np.array(a, "int").sum()
+
+# Function to all reduce NUMPY Accuracy and Loss from Multiple Devices
+def reduce_variable(variable, dist_opt, reducer):
+    reducer.copy_from_numpy(variable)
+    dist_opt.all_reduce(reducer.data)
+    dist_opt.wait()
+    output=tensor.to_numpy(reducer)
+    return output
+
+# Function to synchronize SINGA TENSOR initial model parameters
+def sychronize(tensor, dist_opt):
+    dist_opt.all_reduce(tensor.data)
+    dist_opt.wait()
+    tensor /= dist_opt.world_size
+
+# Data augmentation
+def augmentation(x, batch_size):
+    xpad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'symmetric')
+    for data_num in range(0, batch_size):
+        offset = np.random.randint(8, size=2)
+        x[data_num,:,:,:] = xpad[data_num, :, offset[0]: offset[0] + 28, offset[1]: offset[1] + 28]
+        if_flip = np.random.randint(2)
+        if (if_flip):
+            x[data_num, :, :, :] = x[data_num, :, :, ::-1]
+    return x
+
+def train_mnist_cnn(sgd, max_epoch, batch_size, DIST=False, data_partition=None,
+                    gpu_num=None, gpu_per_node=None, nccl_id=None, spars=0, topK=False, corr=True):
+    # Prepare training and validation data
+    train_x, train_y, test_x, test_y = load_dataset()
+    IMG_SIZE = 28
+    num_classes = 10
+    train_y = to_categorical(train_y, num_classes)
+    test_y = to_categorical(test_y, num_classes)
+
+    # Normalization
+    train_x = train_x / 255
+    test_x = test_x / 255
+
+    if DIST:
+        # For Distributed GPU Training
+        sgd = opt.DistOpt(sgd, nccl_id=nccl_id, gpu_num=gpu_num, gpu_per_node=gpu_per_node)
+        dev = device.create_cuda_gpu_on(sgd.rank_in_local)
+        # Dataset partition for distributed training
+        train_x, train_y = data_partition(train_x, train_y, sgd.rank_in_global, sgd.world_size)
+        test_x, test_y = data_partition(test_x, test_y, sgd.rank_in_global, sgd.world_size)
+        world_size = sgd.world_size
+    else:
+        # For Single GPU
+        print("for single GPU")
+        dev = device.create_cuda_gpu_on(0)
+        device.set_default_device(dev)
+        dev.SetRandSeed(0)
+        world_size = 1
+
+    # create model
+    print("create the model")
+    model = CNN()
+
+    # create input tensors
+    print("create input tensors")
+    tx = tensor.Tensor((batch_size, 1, IMG_SIZE, IMG_SIZE), dev, tensor.float32)
+    ty = tensor.Tensor((batch_size, num_classes), dev, tensor.int32)
+    num_train_batch = train_x.shape[0] // batch_size
+    num_test_batch = test_x.shape[0] // batch_size
+    idx = np.arange(train_x.shape[0], dtype=np.int32)
+
+    if DIST:
+        # Synchronize the initial parameters
+        autograd.training = True
+        x = np.random.randn(batch_size, 1, IMG_SIZE, IMG_SIZE).astype(np.float32)
+        y = np.zeros( shape=(batch_size, num_classes), dtype=np.int32)
+        tx.copy_from_numpy(x)
+        ty.copy_from_numpy(y)
+        out = model.forward(tx)
+        loss = autograd.softmax_cross_entropy(out, ty)
+        for p, g in autograd.backward(loss):
+            sychronize(p, sgd)
+
+    # buffer all the operations in one iteration
+    print("buffer all the operations")
+    dev.EnableGraph(True)
+    autograd.training = True
+    out = model.forward(tx)
+    loss = autograd.softmax_cross_entropy(out, ty)
+    print("buffer softmax_cross_entropy")
+    for p, g in autograd.backward(loss):
+        sgd.update(p, g)
+        # print("update sgd")
+    autograd.training = False
+    dev.EnableGraph(False)
 
 Review comment:
   > I discussed with Chris yesterday, to support distributed training I need to change some code in communicator.cc
   
   Yes, i.e. buffer the all-reduce operation in communicator.cc; then the graph will support multi-process (MPI) training.
   This works because the training procedure is exactly the same for single-GPU and multi-process (MPI) runs; the only difference is that the multi-process version adds an all-reduce.
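Following that answer, the all-reduce becomes one more buffered operation in the graph. Here is a hypothetical single-process simulation; `Op`, `run_graph_lockstep` and `simulate` are illustrative names, not SINGA's communicator or NCCL. Each rank records compute-gradient, all-reduce and SGD-update ops, and the collective op runs once across all ranks per replay, averaging the gradients much as `sychronize` in the diff does with `all_reduce` followed by `tensor /= world_size`.

```python
class Op:
    """A recorded op (hypothetical); collective ops run once across all ranks."""

    def __init__(self, fn, collective=False):
        self.fn = fn
        self.collective = collective


def run_graph_lockstep(ops, world_size):
    # Replay the recorded ops in order. Every rank finishes op k before
    # op k+1 starts, which is what a blocking all-reduce enforces.
    for op in ops:
        if op.collective:
            op.fn()
        else:
            for rank in range(world_size):
                op.fn(rank)


def simulate(data, w_init=0.0, lr=0.5, iterations=1):
    world_size = len(data)
    w = [w_init] * world_size  # parameter replica on each rank
    g = [0.0] * world_size     # gradient buffer on each rank

    def compute_grad(rank):
        # loss = 0.5 * (w - d) ** 2 on this rank's single data point d
        g[rank] = w[rank] - data[rank]

    def all_reduce_mean():
        # Sum gradients over ranks, then divide by world_size.
        mean = sum(g) / world_size
        for rank in range(world_size):
            g[rank] = mean

    def sgd_update(rank):
        w[rank] -= lr * g[rank]

    # Buffered once; only the all-reduce differs from single-device training.
    ops = [Op(compute_grad), Op(all_reduce_mean, collective=True), Op(sgd_update)]
    for _ in range(iterations):
        run_graph_lockstep(ops, world_size)
    return w
```

Because every rank applies the same averaged gradient, the replicas stay identical after each replay, which is the property the initial-parameter synchronization is meant to establish.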


[GitHub] [singa] XJDKC commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
XJDKC commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#discussion_r393502094
 
 

 ##########
 File path: examples/autograd/mlp_buffer.py
 ##########
 @@ -0,0 +1,112 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import tensor
+from singa.tensor import Tensor
+from singa import autograd
+from singa import optimizer
+from singa import device
+import numpy as np
+
+
+if __name__ == "__main__":
+    dev = device.get_default_device()
+
+    autograd.training = True
+    np.random.seed(0)
+
+    # prepare training data in numpy array
+
+    # generate the boundary
+    f = lambda x: (5 * x + 1)
+    bd_x = np.linspace(-1.0, 1, 200)
+    bd_y = f(bd_x)
+    # generate the training data
+    x = np.random.uniform(-1, 1, 400)
+    y = f(x) + 2 * np.random.randn(len(x))
+    # convert training data to 2d space
+    label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)])
+    data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np.float32)
+
+    def to_categorical(y, num_classes):
+        """
+        Converts a class vector (integers) to binary class matrix.
+
+        Args:
+            y: class vector to be converted into a matrix
+                (integers from 0 to num_classes - 1).
+            num_classes: total number of classes.
+
+        Returns:
+            A binary matrix representation of the input.
+        """
+        y = np.array(y, dtype="int")
+        n = y.shape[0]
+        categorical = np.zeros((n, num_classes))
+        categorical[np.arange(n), y] = 1
+        return categorical
+
+    label = to_categorical(label, 2).astype(np.float32)
+    print("train_data_shape:", data.shape)
+    print("train_label_shape:", label.shape)
+
+    inputs = Tensor(data=data, device=dev)
+    target = Tensor(data=label, device=dev)
+
+    w0 = Tensor(shape=(2, 3), device=dev, requires_grad=True, stores_grad=True)
+    w0.gaussian(0.0, 0.1)
+    b0 = Tensor(shape=(1, 3), device=dev, requires_grad=True, stores_grad=True)
+    b0.set_value(0.0)
+
+    w1 = Tensor(shape=(3, 2), device=dev, requires_grad=True, stores_grad=True)
+    w1.gaussian(0.0, 0.1)
+    b1 = Tensor(shape=(1, 2), device=dev, requires_grad=True, stores_grad=True)
+    b1.set_value(0.0)
+
+    print("finished init inputs")
+    print("w0:\n", tensor.to_numpy(w0))
+    print("b0:\n", tensor.to_numpy(b0))
+    print("w1:\n", tensor.to_numpy(w1))
+    print("b1:\n", tensor.to_numpy(b1))
+
+    sgd = optimizer.SGD(0.05)
+
+    # training process
+    print("start training")
+
+    # Buffer the operations
+    dev.EnableGraph(True)
 
 Review comment:
   Now we buffer all the operations directly in the graph; there is no separate buffer (i.e., buffer = graph).
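With no separate buffer, each recorded operation is itself a graph node, and an execution order can be derived from the tensors each node reads and writes. The following is a minimal Python sketch of that idea, not SINGA's C++ implementation: ops are `(name, reads, writes)` tuples with tensors named by strings, and `build_edges` / `execution_order` are hypothetical helpers. Edges always point from an earlier op to a later one in buffering order, so this sketch only ever builds an acyclic graph and Kahn's algorithm is enough to order it.

```python
from collections import defaultdict, deque


def build_edges(ops):
    """Derive dependency edges from ops given as (name, reads, writes) tuples,
    listed in the order they were buffered."""
    last_writer = {}              # tensor -> index of the op that last wrote it
    readers = defaultdict(list)   # tensor -> ops that read it since that write
    edges = defaultdict(set)      # op index -> indices of ops that must run after it
    for i, (_, reads, writes) in enumerate(ops):
        for t in reads:
            if t in last_writer:  # read-after-write
                edges[last_writer[t]].add(i)
            readers[t].append(i)
        for t in writes:
            if t in last_writer:  # write-after-write
                edges[last_writer[t]].add(i)
            for r in readers[t]:  # write-after-read
                if r != i:
                    edges[r].add(i)
            last_writer[t] = i
            readers[t] = []
    return edges


def execution_order(ops, edges):
    """Kahn's algorithm: return op names in an order that respects every edge."""
    indegree = [0] * len(ops)
    for dsts in edges.values():
        for d in dsts:
            indegree[d] += 1
    ready = deque(i for i, deg in enumerate(indegree) if deg == 0)
    order = []
    while ready:
        i = ready.popleft()
        order.append(ops[i][0])
        for d in sorted(edges[i]):
            indegree[d] -= 1
            if indegree[d] == 0:
                ready.append(d)
    return order


ops = [
    ("matmul", ["x", "w"], ["h"]),
    ("relu", ["h"], ["a"]),
    ("loss", ["a", "y"], ["l"]),
    ("sgd_update", ["w", "dw"], ["w"]),
]
edges = build_edges(ops)
print(execution_order(ops, edges))  # ['matmul', 'relu', 'sgd_update', 'loss']
```

In this toy graph `sgd_update` only has to wait for `matmul`, the earlier op that reads the old `w` (a write-after-read dependency); it does not depend on `loss`.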

[GitHub] [singa] nudles commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
nudles commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#discussion_r393497630
 
 

 ##########
 File path: examples/autograd/mlp_buffer.py
 ##########
 @@ -0,0 +1,112 @@
+    # Buffer the operations
+    dev.EnableGraph(True)
 
 Review comment:
   It would be better to make the code identical for buffer=On and buffer=Off, so that we can switch between them easily.
   Currently, this piece of code only works when we use buffer + graph.
   If we do not use the graph, we need to write a separate piece of code.
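One way to get the behaviour suggested above is to route every operation through the device and let the device decide whether to execute it immediately or buffer it, so the user's training step is written only once. The sketch below is hypothetical: `ToyDevice`, `exec_op`, `run` and `make_step` are illustrative names, not SINGA's API, and the "graph" is simply a list of closures replayed in recording order.

```python
class ToyDevice:
    """Executes ops eagerly, or records them once and replays them (graph mode)."""

    def __init__(self, use_graph):
        self.use_graph = use_graph
        self.recording = False
        self.ops = []

    def exec_op(self, fn):
        if self.recording:
            self.ops.append(fn)  # buffer the op instead of executing it
        else:
            fn()

    def run(self, step):
        """Run one iteration of `step`, whichever mode is active."""
        if not self.use_graph:
            step(self)  # eager: ops execute as they are issued
            return
        if not self.ops:
            self.recording = True
            step(self)  # first call only builds the graph
            self.recording = False
        for fn in self.ops:  # every call executes the buffered graph
            fn()


def make_step(state):
    def step(dev):
        # Identical user code in both modes: one SGD update w -= lr * grad.
        dev.exec_op(lambda: state.update(w=state["w"] - 0.1 * state["grad"]))
    return step


results = {}
for use_graph in (False, True):
    state = {"w": 1.0, "grad": 2.0}
    dev = ToyDevice(use_graph)
    step = make_step(state)
    for _ in range(3):  # three iterations of the same training step
        dev.run(step)
    results[use_graph] = state["w"]

print({mode: round(w, 6) for mode, w in results.items()})  # {False: 0.4, True: 0.4}
```

A consequence of this design is that anything the step computes outside `exec_op` runs only once in graph mode, so per-iteration data presumably has to be copied into pre-allocated tensors, as the example scripts do with `tx` and `ty`.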

[GitHub] [singa] lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-605406893
 
 
   This pull request **fixes 1 alert** when merging 1f16fcce637caa1bfad2eedd601f8360174f6ec6 into 38fa20b74ba6ed414e4ecce65d42063ef1662564 - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-fc85d06ec1aaaafd1996c469d1135d4e9d5417b7)
   
   **fixed alerts:**
   
   * 1 for Missing return statement

[GitHub] [singa] lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-606516955
 
 
   This pull request **fixes 1 alert** when merging e36b26849bb8e1f7fd2e9467154f9a4ffb9de4d7 into ba7839bda3868de673d9e48d59b544ad5a47c79b - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-38a7312f4945cca2eac32ef0083438b0c6ea0f66)
   
   **fixed alerts:**
   
   * 1 for Missing return statement

[GitHub] [singa] XJDKC commented on issue #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
XJDKC commented on issue #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-605597959
 
 
   > Have you rebased/merged with the latest dev?
   > Is it ready for merge?
   
   I rebased it earlier, but not onto the latest dev branch. I'll rebase my PR again, and then it will be ready for merge.

[GitHub] [singa] lgtm-com[bot] commented on issue #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-599754062
 
 
   This pull request **introduces 2 alerts** and **fixes 1** when merging 3123a6a6195bf439b47d65b04559f4af9181c70c into 5b478113f80adda5ffc6009cd6539a7c9f47f76d - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-8d43bbdb407edc9a940c3b5bda82a9fdede94a11)
   
   **new alerts:**
   
   * 2 for Wrong type of arguments to formatting function
   
   **fixed alerts:**
   
   * 1 for Missing return statement

[GitHub] [singa] XJDKC commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
XJDKC commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#discussion_r393504236
 
 

 ##########
 File path: examples/autograd/mnist_cnn_buffer.py
 ##########
 @@ -0,0 +1,271 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import singa_wrap as singa
+from singa import autograd
+from singa import tensor
+from singa import device
+from singa import opt
+import numpy as np
+import os
+import sys
+import gzip
+import codecs
+import time
+
+class CNN:
+    def __init__(self):
+        self.conv1 = autograd.Conv2d(1, 20, 5, padding=0)
+        self.conv2 = autograd.Conv2d(20, 50, 5, padding=0)
+        self.linear1 = autograd.Linear(4 * 4 * 50, 500)
+        self.linear2 = autograd.Linear(500, 10)
+        self.pooling1 = autograd.MaxPool2d(2, 2, padding=0)
+        self.pooling2 = autograd.MaxPool2d(2, 2, padding=0)
+
+    def forward(self, x):
+        y = self.conv1(x)
+        y = autograd.relu(y)
+        y = self.pooling1(y)
+        y = self.conv2(y)
+        y = autograd.relu(y)
+        y = self.pooling2(y)
+        y = autograd.flatten(y)
+        y = self.linear1(y)
+        y = autograd.relu(y)
+        y = self.linear2(y)
+        return y
+
+def check_dataset_exist(dirpath):
+    if not os.path.exists(dirpath):
+        print('The MNIST dataset does not exist. Please download the MNIST dataset using download_mnist.py (e.g. python3 download_mnist.py)')
+        sys.exit(1)
+    return dirpath
+
+def load_dataset():
+    train_x_path = '/tmp/train-images-idx3-ubyte.gz'
+    train_y_path = '/tmp/train-labels-idx1-ubyte.gz'
+    valid_x_path = '/tmp/t10k-images-idx3-ubyte.gz'
+    valid_y_path = '/tmp/t10k-labels-idx1-ubyte.gz'
+
+    train_x = read_image_file(check_dataset_exist(train_x_path)).astype(
+        np.float32)
+    train_y = read_label_file(check_dataset_exist(train_y_path)).astype(
+        np.float32)
+    valid_x = read_image_file(check_dataset_exist(valid_x_path)).astype(
+        np.float32)
+    valid_y = read_label_file(check_dataset_exist(valid_y_path)).astype(
+        np.float32)
+    return train_x, train_y, valid_x, valid_y
+
+def read_label_file(path):
+    with gzip.open(path, 'rb') as f:
+        data = f.read()
+        assert get_int(data[:4]) == 2049
+        length = get_int(data[4:8])
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=8).reshape(
+            (length))
+        return parsed
+
+def get_int(b):
+    return int(codecs.encode(b, 'hex'), 16)
+
+def read_image_file(path):
+    with gzip.open(path, 'rb') as f:
+        data = f.read()
+        assert get_int(data[:4]) == 2051
+        length = get_int(data[4:8])
+        num_rows = get_int(data[8:12])
+        num_cols = get_int(data[12:16])
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(
+            (length, 1, num_rows, num_cols))
+        return parsed
+
+def to_categorical(y, num_classes):
+    y = np.array(y, dtype="int")
+    n = y.shape[0]
+    categorical = np.zeros((n, num_classes))
+    categorical[np.arange(n), y] = 1
+    categorical = categorical.astype(np.float32)
+    return categorical
+
+
+def accuracy(pred, target):
+    y = np.argmax(pred, axis=1)
+    t = np.argmax(target, axis=1)
+    a = y == t
+    return np.array(a, "int").sum()
+
+# Function to all-reduce NumPy accuracy and loss values across multiple devices
+def reduce_variable(variable, dist_opt, reducer):
+    reducer.copy_from_numpy(variable)
+    dist_opt.all_reduce(reducer.data)
+    dist_opt.wait()
+    output = tensor.to_numpy(reducer)
+    return output
+
+# Function to synchronize the initial model parameters (SINGA tensors)
+def sychronize(tensor, dist_opt):
+    dist_opt.all_reduce(tensor.data)
+    dist_opt.wait()
+    tensor /= dist_opt.world_size
+
+# Data augmentation
+def augmentation(x, batch_size):
+    xpad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'symmetric')
+    for data_num in range(0, batch_size):
+        offset = np.random.randint(8, size=2)
+        x[data_num,:,:,:] = xpad[data_num, :, offset[0]: offset[0] + 28, offset[1]: offset[1] + 28]
+        if_flip = np.random.randint(2)
+        if (if_flip):
+            x[data_num, :, :, :] = x[data_num, :, :, ::-1]
+    return x
+
+def train_mnist_cnn(sgd, max_epoch, batch_size, DIST=False, data_partition=None,
+                    gpu_num=None, gpu_per_node=None, nccl_id=None, spars=0, topK=False, corr=True):
+    # Prepare training and validation data
+    train_x, train_y, test_x, test_y = load_dataset()
+    IMG_SIZE = 28
+    num_classes = 10
+    train_y = to_categorical(train_y, num_classes)
+    test_y = to_categorical(test_y, num_classes)
+
+    # Normalization
+    train_x = train_x / 255
+    test_x = test_x / 255
+
+    if DIST:
+        # For Distributed GPU Training
+        sgd = opt.DistOpt(sgd, nccl_id=nccl_id, gpu_num=gpu_num, gpu_per_node=gpu_per_node)
+        dev = device.create_cuda_gpu_on(sgd.rank_in_local)
+        # Dataset partition for distributed training
+        train_x, train_y = data_partition(train_x, train_y, sgd.rank_in_global, sgd.world_size)
+        test_x, test_y = data_partition(test_x, test_y, sgd.rank_in_global, sgd.world_size)
+        world_size = sgd.world_size
+    else:
+        # For Single GPU
+        print("for single GPU")
+        dev = device.create_cuda_gpu_on(0)
+        device.set_default_device(dev)
+        dev.SetRandSeed(0)
+        world_size = 1
+
+    # create model
+    print("create the model")
+    model = CNN()
+
+    # create input tensors
+    print("create input tensors")
+    tx = tensor.Tensor((batch_size, 1, IMG_SIZE, IMG_SIZE), dev, tensor.float32)
+    ty = tensor.Tensor((batch_size, num_classes), dev, tensor.int32)
+    num_train_batch = train_x.shape[0] // batch_size
+    num_test_batch = test_x.shape[0] // batch_size
+    idx = np.arange(train_x.shape[0], dtype=np.int32)
+
+    if DIST:
+        # Synchronize the initial parameters
+        autograd.training = True
+        x = np.random.randn(batch_size, 1, IMG_SIZE, IMG_SIZE).astype(np.float32)
+        y = np.zeros(shape=(batch_size, num_classes), dtype=np.int32)
+        tx.copy_from_numpy(x)
+        ty.copy_from_numpy(y)
+        out = model.forward(tx)
+        loss = autograd.softmax_cross_entropy(out, ty)
+        for p, g in autograd.backward(loss):
+            sychronize(p, sgd)
+
+    # buffer all the operations in one iteration
+    print("buffer all the operations")
+    dev.EnableGraph(True)
+    autograd.training = True
+    out = model.forward(tx)
+    loss = autograd.softmax_cross_entropy(out, ty)
+    print("buffer softmax_cross_entropy")
+    for p, g in autograd.backward(loss):
+        sgd.update(p, g)
+        # print("update sgd")
+    autograd.training = False
+    dev.EnableGraph(False)
 
 Review comment:
   I discussed this with Chris yesterday: to support distributed training, I need to change some code in [communicator.cc](../tree/dev/src/io/communicator.cc).

[GitHub] [singa] lgtm-com[bot] commented on issue #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-599923097
 
 
   This pull request **introduces 2 alerts** and **fixes 1** when merging fa818834ae3b704c2d9dba5e48cb3c2473cedc8a into 5b478113f80adda5ffc6009cd6539a7c9f47f76d - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-04e56d56b3a2b8f5ee622417028a13977cfc7acc)
   
   **new alerts:**
   
   * 2 for Wrong type of arguments to formatting function
   
   **fixed alerts:**
   
   * 1 for Missing return statement

[GitHub] [singa] lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-604864503
 
 
   This pull request **introduces 1 alert** and **fixes 1** when merging 286d5563171d3bf8a291349af76085b0f3299323 into 38fa20b74ba6ed414e4ecce65d42063ef1662564 - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-8e24bc38d3bd3fe68f22ff209df4c87e4a2893b9)
   
   **new alerts:**
   
   * 1 for Unused import
   
   **fixed alerts:**
   
   * 1 for Missing return statement

[GitHub] [singa] lgtm-com[bot] commented on issue #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
lgtm-com[bot] commented on issue #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#issuecomment-599954065
 
 
   This pull request **fixes 1 alert** when merging b2342b92126924502f1d7f372a0e519d6ef93559 into 65010c0b83b193e915de00ff47f58f97089750c0 - [view on LGTM.com](https://lgtm.com/projects/g/apache/singa/rev/pr-5d95a56b0dad26c77160ed668109784c56b592b4)
   
   **fixed alerts:**
   
   * 1 for Missing return statement

[GitHub] [singa] nudles commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization

Posted by GitBox <gi...@apache.org>.
nudles commented on a change in pull request #626: [WIP] SINGA-505 Computational graph with memory optimization
URL: https://github.com/apache/singa/pull/626#discussion_r393496613
 
 

 ##########
 File path: examples/autograd/mlp_buffer.py
 ##########
 @@ -0,0 +1,112 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
 
 Review comment:
   Are all the new files mode 644 or 755?
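To answer this locally, the permission bits of the new files can be listed with Python's standard `os` and `stat` modules. This is a minimal sketch assuming a POSIX file system; `file_mode` and `check_modes` are hypothetical helper names, and the usage creates throwaway files rather than inspecting the PR. Note that git itself only distinguishes non-executable (100644) from executable (100755) regular files, so a mode such as 775 on disk would still be committed as 100755.

```python
import os
import stat
import tempfile


def file_mode(path):
    """Return the permission bits of `path` as an octal string, e.g. '644'."""
    return format(stat.S_IMODE(os.stat(path).st_mode), "o")


def check_modes(paths, allowed=("644", "755")):
    """Return {path: mode} for every path whose mode is not in `allowed`."""
    unexpected = {}
    for path in paths:
        mode = file_mode(path)
        if mode not in allowed:
            unexpected[path] = mode
    return unexpected


# Usage on two throwaway files (POSIX permissions assumed).
with tempfile.TemporaryDirectory() as tmp:
    ok = os.path.join(tmp, "mlp_buffer.py")
    odd = os.path.join(tmp, "run.sh")  # hypothetical file name
    for path in (ok, odd):
        open(path, "w").close()
    os.chmod(ok, 0o644)
    os.chmod(odd, 0o775)
    result = {os.path.basename(p): m for p, m in check_modes([ok, odd]).items()}

print(result)  # {'run.sh': '775'}
```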
