You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/27 19:48:31 UTC

[GitHub] piiswrong closed pull request #9074: Rnn example updates

piiswrong closed pull request #9074: Rnn example updates
URL: https://github.com/apache/incubator-mxnet/pull/9074
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/example/gluon/word_language_model/get_ptb_data.sh b/example/gluon/word_language_model/get_ptb_data.sh
index d2641cb32b..2dc4034a93 100755
--- a/example/gluon/word_language_model/get_ptb_data.sh
+++ b/example/gluon/word_language_model/get_ptb_data.sh
@@ -17,6 +17,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
+echo
+echo "NOTE: To continue, you need to review the licensing of the data sets used by this script"
+echo "See https://catalog.ldc.upenn.edu/ldc99t42 for the licensing"
+read -p "Please confirm you have reviewed the licensing [Y/n]:" -n 1 -r
+echo
+
+if [ $REPLY != "Y" ]
+then
+    echo "License was not reviewed, aborting script."
+    exit 1
+fi
 
 RNN_DIR=$(cd `dirname $0`; pwd)
 DATA_DIR="${RNN_DIR}/data/"
diff --git a/example/model-parallel/lstm/get_ptb_data.sh b/example/model-parallel/lstm/get_ptb_data.sh
index d2641cb32b..2dc4034a93 100755
--- a/example/model-parallel/lstm/get_ptb_data.sh
+++ b/example/model-parallel/lstm/get_ptb_data.sh
@@ -17,6 +17,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
+echo
+echo "NOTE: To continue, you need to review the licensing of the data sets used by this script"
+echo "See https://catalog.ldc.upenn.edu/ldc99t42 for the licensing"
+read -p "Please confirm you have reviewed the licensing [Y/n]:" -n 1 -r
+echo
+
+if [ $REPLY != "Y" ]
+then
+    echo "License was not reviewed, aborting script."
+    exit 1
+fi
 
 RNN_DIR=$(cd `dirname $0`; pwd)
 DATA_DIR="${RNN_DIR}/data/"
diff --git a/example/rnn-time-major/bucket_io.py b/example/rnn-time-major/bucket_io.py
index 950b0c05cf..e689ff1126 100644
--- a/example/rnn-time-major/bucket_io.py
+++ b/example/rnn-time-major/bucket_io.py
@@ -17,9 +17,8 @@
 
 # pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
 # pylint: disable=superfluous-parens, no-member, invalid-name
+
 from __future__ import print_function
-import sys
-sys.path.insert(0, "../../python")
 import numpy as np
 import mxnet as mx
 
@@ -206,7 +205,7 @@ def make_data_iter_plan(self):
             bucket_n_batches.append(len(self.data[i]) / self.batch_size)
             self.data[i] = self.data[i][:int(bucket_n_batches[i]*self.batch_size)]
 
-        bucket_plan = np.hstack([np.zeros(n, int)+i for i, n in enumerate(bucket_n_batches)])
+        bucket_plan = np.hstack([np.zeros(int(n), int)+i for i, n in enumerate(bucket_n_batches)])
         np.random.shuffle(bucket_plan)
 
         bucket_idx_all = [np.random.permutation(len(x)) for x in self.data]
diff --git a/example/rnn-time-major/get_ptb_data.sh b/example/rnn-time-major/get_ptb_data.sh
index d2641cb32b..2dc4034a93 100755
--- a/example/rnn-time-major/get_ptb_data.sh
+++ b/example/rnn-time-major/get_ptb_data.sh
@@ -17,6 +17,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
+echo
+echo "NOTE: To continue, you need to review the licensing of the data sets used by this script"
+echo "See https://catalog.ldc.upenn.edu/ldc99t42 for the licensing"
+read -p "Please confirm you have reviewed the licensing [Y/n]:" -n 1 -r
+echo
+
+if [ $REPLY != "Y" ]
+then
+    echo "License was not reviewed, aborting script."
+    exit 1
+fi
 
 RNN_DIR=$(cd `dirname $0`; pwd)
 DATA_DIR="${RNN_DIR}/data/"
diff --git a/example/rnn-time-major/readme.md b/example/rnn-time-major/readme.md
new file mode 100644
index 0000000000..7983fe8c6b
--- /dev/null
+++ b/example/rnn-time-major/readme.md
@@ -0,0 +1,24 @@
+Time major data layout for RNN
+==============================
+
+This example demonstrates an RNN implementation with Time-major layout. This implementation shows 1.5x-2x speedups compared to Batch-major RNN.
+	
+An example of Batch-major RNN is available in the MXNet [RNN Bucketing example](https://github.com/apache/incubator-mxnet/tree/master/example/rnn/bucketing).
+	
+## Running the example
+- Prerequisite: an instance with GPU compute resources is required to run MXNet RNN
+- Make the shell script ```get_ptb_data.sh``` executable:
+    ```bash 
+    chmod +x get_ptb_data.sh
+    ```
+- Run ```get_ptb_data.sh``` to download the PTB dataset, and follow the instructions to review the license:
+    ```bash
+    ./get_ptb_data.sh
+    ```
+    The PTB data sets will be downloaded into the ./data directory and made available for the example to train on.
+- Run the example:
+    ```bash
+    python rnn_cell_demo.py
+    ```
+    
+    If everything goes well, the console will print the training speed and perplexity, which you can compare to the batch-major RNN.
diff --git a/example/rnn-time-major/rnn_cell_demo.py b/example/rnn-time-major/rnn_cell_demo.py
index c29d1ddea4..cf1e0a0cd1 100644
--- a/example/rnn-time-major/rnn_cell_demo.py
+++ b/example/rnn-time-major/rnn_cell_demo.py
@@ -48,17 +48,26 @@
 ################################################################################
 
 import os
-
 import numpy as np
 import mxnet as mx
 
 from bucket_io import BucketSentenceIter, default_build_vocab
 
-
 data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data'))
 
 
 def Perplexity(label, pred):
+    """ Calculates prediction perplexity
+
+    Args:
+        label (mx.nd.array): labels array
+        pred (mx.nd.array): prediction array
+
+    Returns:
+        float: calculated perplexity
+
+    """
+
     # collapse the time, batch dimension
     label = label.reshape((-1,))
     pred = pred.reshape((-1, pred.shape[-1]))
@@ -80,7 +89,10 @@ def Perplexity(label, pred):
     learning_rate = 0.01
     momentum = 0.0
 
-    contexts = [mx.context.gpu(i) for i in range(1)]
+    # Update count per available GPUs
+    gpu_count = 1
+    contexts = [mx.context.gpu(i) for i in range(gpu_count)]
+
     vocab = default_build_vocab(os.path.join(data_dir, 'ptb.train.txt'))
 
     init_h = [mx.io.DataDesc('LSTM_state', (num_lstm_layer, batch_size, num_hidden), layout='TNC')]
@@ -95,6 +107,15 @@ def Perplexity(label, pred):
                                   time_major=True)
 
     def sym_gen(seq_len):
+        """ Generates the MXNet symbol for the RNN
+
+        Args:
+            seq_len (int): input sequence length
+
+        Returns:
+            tuple: tuple containing symbol, data_names, label_names
+
+        """
         data = mx.sym.Variable('data')
         label = mx.sym.Variable('softmax_label')
         embed = mx.sym.Embedding(data=data, input_dim=len(vocab),
@@ -146,7 +167,7 @@ def sym_gen(seq_len):
         data_names = ['data', 'LSTM_state', 'LSTM_state_cell']
         label_names = ['softmax_label']
 
-        return (sm, data_names, label_names)
+        return sm, data_names, label_names
 
     if len(buckets) == 1:
         mod = mx.mod.Module(*sym_gen(buckets[0]), context=contexts)
diff --git a/example/rnn/bucketing/get_ptb_data.sh b/example/rnn/bucketing/get_ptb_data.sh
index d2641cb32b..2dc4034a93 100755
--- a/example/rnn/bucketing/get_ptb_data.sh
+++ b/example/rnn/bucketing/get_ptb_data.sh
@@ -17,6 +17,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
+echo
+echo "NOTE: To continue, you need to review the licensing of the data sets used by this script"
+echo "See https://catalog.ldc.upenn.edu/ldc99t42 for the licensing"
+read -p "Please confirm you have reviewed the licensing [Y/n]:" -n 1 -r
+echo
+
+if [ $REPLY != "Y" ]
+then
+    echo "License was not reviewed, aborting script."
+    exit 1
+fi
 
 RNN_DIR=$(cd `dirname $0`; pwd)
 DATA_DIR="${RNN_DIR}/data/"
diff --git a/example/rnn/old/get_ptb_data.sh b/example/rnn/old/get_ptb_data.sh
index d2641cb32b..2dc4034a93 100755
--- a/example/rnn/old/get_ptb_data.sh
+++ b/example/rnn/old/get_ptb_data.sh
@@ -17,6 +17,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
+echo
+echo "NOTE: To continue, you need to review the licensing of the data sets used by this script"
+echo "See https://catalog.ldc.upenn.edu/ldc99t42 for the licensing"
+read -p "Please confirm you have reviewed the licensing [Y/n]:" -n 1 -r
+echo
+
+if [ $REPLY != "Y" ]
+then
+    echo "License was not reviewed, aborting script."
+    exit 1
+fi
 
 RNN_DIR=$(cd `dirname $0`; pwd)
 DATA_DIR="${RNN_DIR}/data/"
diff --git a/example/rnn/word_lm/get_ptb_data.sh b/example/rnn/word_lm/get_ptb_data.sh
index 0a0c7051b0..2dc4034a93 100755
--- a/example/rnn/word_lm/get_ptb_data.sh
+++ b/example/rnn/word_lm/get_ptb_data.sh
@@ -17,11 +17,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
-echo ""
-echo "NOTE: Please review the licensing of the datasets in this script before proceeding"
+echo
+echo "NOTE: To continue, you need to review the licensing of the data sets used by this script"
 echo "See https://catalog.ldc.upenn.edu/ldc99t42 for the licensing"
-echo "Once that is done, please uncomment the wget commands in this script"
-echo ""
+read -p "Please confirm you have reviewed the licensing [Y/n]:" -n 1 -r
+echo
+
+if [ $REPLY != "Y" ]
+then
+    echo "License was not reviewed, aborting script."
+    exit 1
+fi
 
 RNN_DIR=$(cd `dirname $0`; pwd)
 DATA_DIR="${RNN_DIR}/data/"
@@ -31,7 +37,7 @@ if [[ ! -d "${DATA_DIR}" ]]; then
   mkdir -p ${DATA_DIR}
 fi
 
-#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.train.txt;
-#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.valid.txt;
-#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.test.txt;
-#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt;
+wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.train.txt;
+wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.valid.txt;
+wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.test.txt;
+wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt;
diff --git a/perl-package/AI-MXNet/examples/get_ptb_data.sh b/perl-package/AI-MXNet/examples/get_ptb_data.sh
index d2641cb32b..2dc4034a93 100755
--- a/perl-package/AI-MXNet/examples/get_ptb_data.sh
+++ b/perl-package/AI-MXNet/examples/get_ptb_data.sh
@@ -17,6 +17,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
+echo
+echo "NOTE: To continue, you need to review the licensing of the data sets used by this script"
+echo "See https://catalog.ldc.upenn.edu/ldc99t42 for the licensing"
+read -p "Please confirm you have reviewed the licensing [Y/n]:" -n 1 -r
+echo
+
+if [ $REPLY != "Y" ]
+then
+    echo "License was not reviewed, aborting script."
+    exit 1
+fi
 
 RNN_DIR=$(cd `dirname $0`; pwd)
 DATA_DIR="${RNN_DIR}/data/"


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services