Posted to commits@systemml.apache.org by ni...@apache.org on 2019/03/21 16:34:05 UTC

[systemml] branch master updated: [SYSTEMML-540] Bugfix for Python 3+ and updated the documentation

This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new ef8b10a  [SYSTEMML-540] Bugfix for Python 3+ and updated the documentation
ef8b10a is described below

commit ef8b10a964a4b620f43e627303b85616d9abb502
Author: Niketan Pansare <np...@us.ibm.com>
AuthorDate: Thu Mar 21 09:29:07 2019 -0700

    [SYSTEMML-540] Bugfix for Python 3+ and updated the documentation
    
    - Added a quick tour of the documentation in the overview page.
    - Updated GPU documentation to explain how to resolve common setup issues.
    - Updated Keras2DML documentation to be compatible with the recently added features.
    - Updated mllearn documentation to include Keras2DML.
---
 docs/beginners-guide-keras2dml.md               |  89 +++++++++++--
 docs/deep-learning.md                           |  26 +++-
 docs/gpu.md                                     | 162 +++++++++++++++++++-----
 docs/index.md                                   |  25 ++++
 docs/native-backend.md                          |  10 ++
 docs/python-reference.md                        |   8 +-
 src/main/python/systemml/mllearn/keras2caffe.py |  37 +++---
 7 files changed, 288 insertions(+), 69 deletions(-)

diff --git a/docs/beginners-guide-keras2dml.md b/docs/beginners-guide-keras2dml.md
index c99334e..60de360 100644
--- a/docs/beginners-guide-keras2dml.md
+++ b/docs/beginners-guide-keras2dml.md
@@ -45,23 +45,88 @@ Keras models are parsed based on their layer structure and corresponding weights
 configuration. Be aware that currently this is a translation into Caffe, so there will be some loss of information from Keras models, such as
 initializer information and layers which do not exist in Caffe.
 
+First, install SystemML and the other dependencies for the demo below:
+
+```
+pip install systemml keras tensorflow mlxtend
+``` 
+
 To create a Keras2DML object, simply pass the keras object to the Keras2DML constructor. It's also important to note that your models
-should be compiled so that the loss can be accessed for Caffe2DML
+should be compiled so that the loss can be accessed for Caffe2DML.
 
-```python
-from systemml.mllearn import Keras2DML
-import keras
-from keras.applications.resnet50 import preprocess_input, decode_predictions, ResNet50
 
-keras_model = ResNet50(weights='imagenet',include_top=True,pooling='None',input_shape=(224,224,3))
-keras_model.compile(optimizer='sgd', loss= 'categorical_crossentropy')
 
-sysml_model = Keras2DML(spark, keras_model,input_shape=(3,224,224))
-sysml_model.summary()
+```python
+# pyspark --driver-memory 20g
+
+# Prevent TensorFlow from using the GPU, to avoid unnecessary evictions by the SystemML runtime
+import os
+os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+os.environ['CUDA_VISIBLE_DEVICES'] = ''
+
+# Import dependencies
+from mlxtend.data import mnist_data
+import numpy as np
+from sklearn.utils import shuffle
+from keras.models import Sequential
+from keras.layers import Input, Dense, Conv2D, MaxPooling2D, Dropout,Flatten
+from keras import backend as K
+from keras.models import Model
+from keras.optimizers import SGD
+
+# Use the channels-first image data format
+K.set_image_data_format('channels_first')
+
+# Download the MNIST dataset
+X, y = mnist_data()
+X, y = shuffle(X, y)
+
+# Split the data into training and test
+n_samples = len(X)
+X_train = X[:int(.9 * n_samples)]
+y_train = y[:int(.9 * n_samples)]
+X_test = X[int(.9 * n_samples):]
+y_test = y[int(.9 * n_samples):]
+
+# Define Lenet in Keras
+keras_model = Sequential()
+keras_model.add(Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(1,28,28), padding='same'))
+keras_model.add(MaxPooling2D(pool_size=(2, 2)))
+keras_model.add(Conv2D(64, (5, 5), activation='relu', padding='same'))
+keras_model.add(MaxPooling2D(pool_size=(2, 2)))
+keras_model.add(Flatten())
+keras_model.add(Dense(512, activation='relu'))
+keras_model.add(Dropout(0.5))
+keras_model.add(Dense(10, activation='softmax'))
+keras_model.compile(loss='categorical_crossentropy', optimizer=SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True))
+keras_model.summary()
+
+# Scale the input features to [0, 1) (0.00390625 = 1/256)
+scale = 0.00390625
+X_train = X_train*scale
+X_test = X_test*scale
+
+# Train Lenet using SystemML
+from systemml.mllearn import Keras2DML
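+# 'weights_dir' below is a local directory where Keras2DML stores the transferred Keras weights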
+sysml_model = Keras2DML(spark, keras_model, weights='weights_dir')
+# sysml_model.setConfigProperty("sysml.native.blas", "auto")
+# sysml_model.setGPU(True).setForceGPU(True)
+sysml_model.fit(X_train, y_train)
+sysml_model.score(X_test, y_test)
 ```
 
 # Frequently asked questions
 
+#### How can I get the training and prediction DML script for the Keras model?
+
+The training and prediction DML scripts can be generated using `get_training_script()` and `get_prediction_script()` methods.
+
+```python
+from systemml.mllearn import Keras2DML
+sysml_model = Keras2DML(spark, keras_model, input_shape=(3,224,224))
+print(sysml_model.get_training_script())
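+print(sysml_model.get_prediction_script())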
+```
+
 #### What is the mapping between Keras' parameters and Caffe's solver specification?
 
 |                                                        | Specified via the given parameter in the Keras2DML constructor | From input Keras' model                                                                 | Corresponding parameter in the Caffe solver file |
@@ -134,3 +199,9 @@ For example: for the expression `Keras2DML(..., display=100, test_iter=10, test_
 - display the training loss and accuracy every 100 iterations and
 - carry out validation every 500 training iterations and display validation loss and accuracy.
 
+#### How do you ensure that Keras2DML produces the same results as other Keras backends?
+
+To verify that Keras2DML produces the same results as other Keras backends, we have [Python unit tests](https://github.com/apache/systemml/blob/master/src/main/python/tests/test_nn_numpy.py)
+that compare the results of Keras2DML with those of TensorFlow. We assume that the Keras team ensures that all their backends are consistent with their TensorFlow backend.
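+
+A minimal sketch of the comparison idea (not the actual test code): since Keras2DML transfers the current weights of the Keras model, the two backends should agree on the forward pass before any call to `fit`. This assumes `keras_model`, `sysml_model` and `X_test` from the example above:
+
+```python
+import numpy as np
+
+# Predicted labels from the Keras model itself (TensorFlow backend) ...
+keras_labels = np.argmax(keras_model.predict(X_test.reshape((-1, 1, 28, 28))), axis=1)
+# ... and from the same (untrained) model executed by SystemML via Keras2DML
+sysml_labels = sysml_model.predict(X_test)
+
+# The two backends should agree on (almost) every test image
+print(np.mean(keras_labels == sysml_labels))
+```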
+
+
diff --git a/docs/deep-learning.md b/docs/deep-learning.md
index 344150e..2dbb4bb 100644
--- a/docs/deep-learning.md
+++ b/docs/deep-learning.md
@@ -184,13 +184,22 @@ lenet.score(X_test, y_test)
 
 <div data-lang="Keras2DML" markdown="1">
 {% highlight python %}
+# Prevent TensorFlow from using the GPU, to avoid unnecessary evictions by the SystemML runtime
+import os
+os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+os.environ['CUDA_VISIBLE_DEVICES'] = ''
+
 from keras.models import Sequential
 from keras.layers import Input, Dense, Conv2D, MaxPooling2D, Dropout,Flatten
 from keras import backend as K
 from keras.models import Model
-input_shape = (1,28,28) if K.image_data_format() == 'channels_first' else (28,28, 1)
+from keras.optimizers import SGD
+
+# Use the channels-first image data format
+K.set_image_data_format('channels_first')
+
 keras_model = Sequential()
-keras_model.add(Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=input_shape, padding='same'))
+keras_model.add(Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(1,28,28), padding='same'))
 keras_model.add(MaxPooling2D(pool_size=(2, 2)))
 keras_model.add(Conv2D(64, (5, 5), activation='relu', padding='same'))
 keras_model.add(MaxPooling2D(pool_size=(2, 2)))
@@ -206,7 +215,7 @@ X_train = X_train*scale
 X_test = X_test*scale
 
 from systemml.mllearn import Keras2DML
-sysml_model = Keras2DML(spark, keras_model, input_shape=(1,28,28), weights='weights_dir')
+sysml_model = Keras2DML(spark, keras_model, weights='weights_dir')
 # sysml_model.setConfigProperty("sysml.native.blas", "auto")
 # sysml_model.setGPU(True).setForceGPU(True)
 sysml_model.summary()
@@ -235,13 +244,22 @@ Will be added soon ...
 
 <div data-lang="Keras2DML" markdown="1">
 {% highlight python %}
+# Prevent TensorFlow from using the GPU, to avoid unnecessary evictions by the SystemML runtime
+import os
+os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+os.environ['CUDA_VISIBLE_DEVICES'] = ''
+
+# Use the channels-first image data format
+from keras import backend as K
+K.set_image_data_format('channels_first')
+
 from systemml.mllearn import Keras2DML
 import systemml as sml
 import keras, urllib
 from PIL import Image
 from keras.applications.resnet50 import preprocess_input, decode_predictions, ResNet50
 
-keras_model = ResNet50(weights='imagenet',include_top=True,pooling='None',input_shape=(224,224,3))
+keras_model = ResNet50(weights='imagenet',include_top=True,pooling='None',input_shape=(3,224,224))
 keras_model.compile(optimizer='sgd', loss= 'categorical_crossentropy')
 
 sysml_model = Keras2DML(spark,keras_model,input_shape=(3,224,224), weights='weights_dir', labels='https://raw.githubusercontent.com/apache/systemml/master/scripts/nn/examples/caffe2dml/models/imagenet/labels.txt')
diff --git a/docs/gpu.md b/docs/gpu.md
index 5e13e60..f334b47 100644
--- a/docs/gpu.md
+++ b/docs/gpu.md
@@ -32,11 +32,49 @@ limitations under the License.
 To use SystemML on GPUs, please ensure that [CUDA 9](https://developer.nvidia.com/cuda-90-download-archive) and
 [CuDNN 7](https://developer.nvidia.com/cudnn) are installed on your system.
 
+```
+$ nvcc --version | grep release
+Cuda compilation tools, release 9.0, V9.0.176
+$ cat /usr/local/cuda/include/cudnn.h | grep "CUDNN_MAJOR\|CUDNN_MINOR"
+#define CUDNN_MAJOR 7
+#define CUDNN_MINOR 0
+#define CUDNN_VERSION    (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
+```
+
+Depending on the API, the GPU backend can be enabled in different ways:
+
+1. When invoking SystemML from command-line, the GPU backend can be enabled by providing the command-line `-gpu` flag.
+2. When invoking SystemML using the (Python or Scala) MLContext and MLLearn (includes Caffe2DML and Keras2DML) APIs, please use the `setGPU(enable)` method (see the sketch below).
+3. When invoking SystemML using the JMLC API, please set the `useGpu` parameter in `org.apache.sysml.api.jmlc.Connection` class's `prepareScript` method.
+
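+For example, a minimal sketch of option 2 using the Python MLContext API (this assumes a running `SparkSession` named `spark`; the DML snippet is only illustrative):
+
+```python
+from systemml import MLContext, dml
+
+ml = MLContext(spark)
+ml.setGPU(True)  # enable the GPU backend
+# ml.setForceGPU(True)  # optionally force GPU-enabled operators onto the GPU
+
+# A toy DML script whose matrix multiply can execute on the GPU
+script = dml("X = rand(rows=1024, cols=1024); print(sum(X %*% X))")
+ml.execute(script)
+```
+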
+Python users do not need to explicitly provide the extra jar during invocation.
+For all other APIs, please remember to include the `systemml-*-extra.jar` in the classpath as described below.
+
+## Command-line users
+
+To enable the GPU backend via the command line, please provide `systemml-*-extra.jar` in the classpath and pass the `-gpu` flag.
+
+```
+spark-submit --jars systemml-*-extra.jar SystemML.jar -f myDML.dml -gpu
+``` 
+
+To skip memory checks and force all GPU-enabled operations onto the GPU, please pass the `force` option to the `-gpu` flag.
+
+```
+spark-submit --jars systemml-*-extra.jar SystemML.jar -f myDML.dml -gpu force
+``` 
+
 ## Python users
 
 Please install SystemML using pip:
 - For released version: `pip install systemml`
-- For bleeding edge version: `pip install https://sparktc.ibmcloud.com/repo/latest/systemml-1.2.0-SNAPSHOT-python.tar.gz`
+- For bleeding edge version: 
+```
+git clone https://github.com/apache/systemml.git
+cd systemml
+mvn package -P distribution
+pip install target/systemml-*-SNAPSHOT-python.tar.gz
+```
 
 Then you can use the `setGPU(True)` method of [MLContext](http://apache.github.io/systemml/spark-mlcontext-programming-guide.html) and 
 [MLLearn](http://apache.github.io/systemml/beginners-guide-python.html#invoke-systemmls-algorithms) APIs to enable the GPU usage.
@@ -54,45 +92,15 @@ lenet = Caffe2DML(spark, solver='lenet_solver.proto', input_shape=(1, 28, 28))
 lenet.setGPU(True).setForceGPU(True)
 ```
 
-## Command-line users
-
-To enable the GPU backend via command-line, please provide `systemml-1.*-extra.jar` in the classpath and `-gpu` flag.
-
-```
-spark-submit --jars systemml-1.*-extra.jar SystemML.jar -f myDML.dml -gpu
-``` 
-
-To skip memory-checking and force all GPU-enabled operations on the GPU, please provide `force` option to the `-gpu` flag.
-
-```
-spark-submit --jars systemml-1.*-extra.jar SystemML.jar -f myDML.dml -gpu force
-``` 
-
 ## Scala users
 
-To enable the GPU backend via command-line, please provide `systemml-1.*-extra.jar` in the classpath and use 
+To enable the GPU backend from Scala, please provide `systemml-*-extra.jar` in the classpath and use
 the `setGPU(True)` method of [MLContext](http://apache.github.io/systemml/spark-mlcontext-programming-guide.html) API to enable the GPU usage.
 
 ```
-spark-shell --jars systemml-1.*-extra.jar,SystemML.jar
+spark-shell --jars systemml-*-extra.jar,SystemML.jar
 ``` 
 
-# Troubleshooting guide
-
-- If you have older gcc (< 5.0) and if you get `libstdc++.so.6: version CXXABI_1.3.8 not found` error, please upgrade to gcc v5+. 
-On Centos 5, you may have to compile gcc from the source:
-
-```
-sudo yum install libmpc-devel mpfr-devel gmp-devel zlib-devel*
-curl ftp://ftp.gnu.org/pub/gnu/gcc/gcc-5.3.0/gcc-5.3.0.tar.bz2 -O
-tar xvfj gcc-5.3.0.tar.bz2
-cd gcc-5.3.0
-./configure --with-system-zlib --disable-multilib --enable-languages=c,c++
-num_cores=`grep -c ^processor /proc/cpuinfo`
-make -j $num_cores
-sudo make install
-```
-
 # Advanced Configuration
 
 ## Using single precision
@@ -117,4 +125,90 @@ and can potentially lead to OOM if the network is deep as well as wide.
 By default, SystemML uses CUDA's memory allocator and performs on-demand eviction
 using the eviction policy set by the configuration property 'sysml.gpu.eviction.policy'.
 To use CUDA's unified memory allocator that performs page-level eviction instead,
-please set the configuration property 'sysml.gpu.memory.allocator' to 'unified_memory'.
\ No newline at end of file
+please set the configuration property 'sysml.gpu.memory.allocator' to 'unified_memory'.
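+
+For example, a minimal sketch of setting this property from Python (mllearn estimators expose `setConfigProperty`; `sysml_model` is assumed to be an estimator such as Keras2DML):
+
+```python
+# Switch to CUDA's unified memory allocator (page-level eviction)
+sysml_model.setConfigProperty("sysml.gpu.memory.allocator", "unified_memory")
+```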
+
+
+# Frequently asked questions
+
+### How do I find the CUDA and CuDNN version on my system?
+
+- Make sure `/usr/local/cuda` is pointing to the right CUDA version.
+
+```
+ls -l /usr/local/cuda
+```
+
+- Get the CUDA version using `nvcc`
+
+```
+$ nvcc --version | grep release
+Cuda compilation tools, release 9.0, V9.0.176
+```
+
+- Get the CuDNN version using the `cudnn.h` header file.
+
+```
+$ cat /usr/local/cuda/include/cudnn.h | grep "CUDNN_MAJOR\|CUDNN_MINOR"
+#define CUDNN_MAJOR 7
+#define CUDNN_MINOR 0
+#define CUDNN_VERSION    (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
+```
+
+
+### How do I verify the CUDA and CuDNN version that SystemML depends on?
+
+- Check the `jcuda.version` property in SystemML's `pom.xml` file.
+- Then find the CUDA dependency in [JCuda's documentation](http://www.jcuda.org/downloads/downloads.html).
+- For your reference, here are the corresponding CUDA and CuDNN versions for a given JCuda version:
+
+| JCuda  | CUDA    | CuDNN |
+|--------|---------|-------|
+| 0.9.2  | 9.2     | 7.2   |
+| 0.9.0d | 9.0.176 | 7.0.2 |
+| 0.8.0  | 8.0.44  | 5.1   |
+
+
+### How do I verify that CUDA is installed correctly?
+
+- Make sure `/usr/local/cuda` is pointing to the right CUDA version.
+- Make sure that `/usr/local/cuda/bin` is in your `PATH`.
+```
+$ nvcc --version
+$ nvidia-smi 
+```
+- Make sure that `/usr/local/cuda/lib64` is in your `LD_LIBRARY_PATH`.
+- Test using CUDA samples
+```
+$ cd /usr/local/cuda-9.0/samples/
+$ sudo make
+$ ./bin/x86_64/linux/release/deviceQuery
+$ ./bin/x86_64/linux/release/bandwidthTest 
+$ ./bin/x86_64/linux/release/matrixMulCUBLAS 
+```
+
+### How do I install CUDA 9 on CentOS 7 with yum?
+
+```
+sudo yum install cuda-9-0.x86_64
+sudo ln -sfn /usr/local/cuda-9.0/ /usr/local/cuda
+```
+
+### What is the driver requirement for CUDA 9?
+
+As per [Nvidia's documentation](https://docs.nvidia.com/deploy/cuda-compatibility/index.html), the driver version has to be `>= 384.81`.
+
+### What do I do if I get `CXXABI_1.3.8 not found` error?
+
+If you have an older gcc (< 5.0) and you get the `libstdc++.so.6: version CXXABI_1.3.8 not found` error, please upgrade to gcc v5+.
+On CentOS 5, you may have to compile gcc from source:
+
+```
+sudo yum install libmpc-devel mpfr-devel gmp-devel zlib-devel*
+curl ftp://ftp.gnu.org/pub/gnu/gcc/gcc-5.3.0/gcc-5.3.0.tar.bz2 -O
+tar xvfj gcc-5.3.0.tar.bz2
+cd gcc-5.3.0
+./configure --with-system-zlib --disable-multilib --enable-languages=c,c++
+num_cores=`grep -c ^processor /proc/cpuinfo`
+make -j $num_cores
+sudo make install
+```
diff --git a/docs/index.md b/docs/index.md
index 8117735..e7f16f3 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -38,6 +38,31 @@ To download SystemML, visit the [downloads](http://systemml.apache.org/download)
 
 This version of SystemML supports: Java 8+, Scala 2.11+, Python 2.7/3.5+, Hadoop 2.6+, and Spark 2.1+.
 
+## Quick tour of the documentation
+
+* If you are new to SystemML, please refer to the [installation guide](http://systemml.apache.org/install-systemml.html) and try out our [sample notebooks](http://systemml.apache.org/get-started.html#sample-notebook)
+* If you want to invoke one of our [pre-implemented algorithms](algorithms-reference):
+  * Using Python, consider using 
+    * the convenient [mllearn API](http://apache.github.io/systemml/python-reference.html#mllearn-api). The usage is described in our [beginner's guide](http://apache.github.io/systemml/beginners-guide-python.html#invoke-systemmls-algorithms)
+    * OR [Spark MLContext](spark-mlcontext-programming-guide) API
+  * Using Java/Scala, consider using 
+    * [Spark MLContext](spark-mlcontext-programming-guide) API for large datasets
+    * OR [JMLC](jmlc) API for in-memory scoring
+  * From the command line, follow the usage section in the [Algorithms Reference](algorithms-reference)
+* If you want to implement a deep neural network, consider
+  * specifying your network in [Keras](https://keras.io/) format and invoking it with our [Keras2DML](beginners-guide-keras2dml) API
+  * OR specifying your network in [Caffe](http://caffe.berkeleyvision.org/) format and invoking it with our [Caffe2DML](beginners-guide-caffe2dml) API
+  * OR using the DML-bodied [NN library](https://github.com/apache/systemml/tree/master/scripts/nn). The usage is described in our [sample notebook](https://github.com/apache/systemml/blob/master/samples/jupyter-notebooks/Deep%20Learning%20Image%20Classification.ipynb)
+* Since training a deep neural network is often compute-bound, you may want to
+  * Enable [native BLAS](native-backend) in SystemML
+  * OR run it [using our GPU backend](gpu)  
+* If you want to implement a custom machine learning algorithm and you are familiar with:
+  * [R](https://www.r-project.org/about.html), consider implementing your algorithm in [DML](dml-language-reference) (recommended)
+  * [Python](https://www.python.org/), consider implementing your algorithm in [PyDML](beginners-guide-to-dml-and-pydml) or using the [matrix class](http://apache.github.io/systemml/python-reference.html#matrix-class)
+* If you want to try out SystemML on a single machine (for example, your laptop), consider
+  * using the above-mentioned APIs with [Apache Spark](https://spark.apache.org/downloads.html) (recommended). Please refer to our [installation guide](http://systemml.apache.org/install-systemml.html).
+  * OR running it with Java in [standalone mode](standalone-guide)
+
 ## Running SystemML
 
 * [Beginner's Guide For Python Users](beginners-guide-python) - Beginner's Guide for Python users.
diff --git a/docs/native-backend.md b/docs/native-backend.md
index 8f6886f..0f01fa4 100644
--- a/docs/native-backend.md
+++ b/docs/native-backend.md
@@ -244,3 +244,13 @@ The current set of dependencies other than MKL and OpenBLAS, are as follows:
 
 If CMake cannot detect your OpenBLAS installation, set the `OpenBLAS_HOME` environment variable to the OpenBLAS Home.
 
+
+## Debugging SystemML's native code
+
+To debug issues in SystemML's native code, please use the following flags:
+
+```
+$SPARK_HOME/bin/spark-submit --conf 'spark.driver.extraJavaOptions=-XX:OnError="gdb - %p"' SystemML.jar -f test_conv2d.dml -stats 10 -explain -nvargs stride=$stride pad=$pad out=out_cp.csv N=$N C=$C H=$H W=$W K=$K R=$R S=$S
+```
+
+If the process crashes, this option launches a native debugger (gdb) attached to the failing JVM.
\ No newline at end of file
diff --git a/docs/python-reference.md b/docs/python-reference.md
index 4fd78fe..7d6af46 100644
--- a/docs/python-reference.md
+++ b/docs/python-reference.md
@@ -368,8 +368,9 @@ beta = ml.execute(script).get('B_out').toNumPy()
 ## mllearn API
 
 mllearn API is designed to be compatible with scikit-learn and MLLib.
-The classes that are part of mllearn API are LogisticRegression, LinearRegression, SVM, NaiveBayes 
-and [Caffe2DML](http://apache.github.io/systemml/beginners-guide-caffe2dml).
+The classes that are part of mllearn API are LogisticRegression, LinearRegression, SVM, NaiveBayes,
+[Keras2DML](http://apache.github.io/systemml/beginners-guide-keras2dml.html) 
+and [Caffe2DML](http://apache.github.io/systemml/beginners-guide-caffe2dml.html).
 
 The below code describes how to use mllearn API for training:
 
@@ -411,7 +412,8 @@ expects that labels have been converted to 1-based value.
 This avoids unnecessary decoding overhead for large datasets if the label column has already been decoded.
 For the scikit-learn API, there is no such requirement.
 
-The table below describes the parameter available for mllearn algorithms:
+The table below describes the parameters available for the mllearn algorithms.
+These parameters are also specified in the usage section of the [Algorithms Reference](algorithms-reference):
 
 | Parameters | Description of the Parameters | LogisticRegression | LinearRegression | SVM | NaiveBayes |
 |----------------|-----------------------------------------------------------------------------------------------|-----------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ [...]
diff --git a/src/main/python/systemml/mllearn/keras2caffe.py b/src/main/python/systemml/mllearn/keras2caffe.py
index ca0fe3c..ce341fd 100755
--- a/src/main/python/systemml/mllearn/keras2caffe.py
+++ b/src/main/python/systemml/mllearn/keras2caffe.py
@@ -25,7 +25,12 @@
 import numpy as np
 import os
 import math
-from itertools import chain, imap
+from itertools import chain
+try:
+    from itertools import imap
+except ImportError:
+    # Support Python 3.x, where itertools.imap was removed (the built-in map is lazy)
+    imap = map
 from ..converters import *
 from ..classloader import *
 import keras
@@ -112,7 +117,7 @@ def toKV(key, value):
 
 
 def _parseJSONObject(obj):
-    rootName = obj.keys()[0]
+    rootName = list(obj.keys())[0]
     ret = ['\n', rootName, ' {']
     for key in obj[rootName]:
         if isinstance(obj[rootName][key], dict):
@@ -172,7 +177,7 @@ def _parseKerasLayer(layer):
         layerArgs['bottom'] = _getBottomLayers(layer)
         layerArgs['top'] = layer.name
     if len(param) > 0:
-        paramName = param.keys()[0]
+        paramName = list(param.keys())[0]
         layerArgs[paramName] = param[paramName]
     ret = { 'layer': layerArgs }
     return [ret, _parseActivation(
@@ -194,20 +199,20 @@ specialLayers = {
     keras.layers.BatchNormalization: _parseBatchNorm
 }
 
+def getPadding(kernel_size, padding):
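+    # Emulate Keras' 'same' padding in Caffe as pad = floor(kernel_size / 2);
+    # int() keeps the result an integer under Python 3's true division (the SYSTEMML-540 fix)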
+    if padding.lower() == 'same':
+        return int(kernel_size/2)
+    elif padding.lower() == 'valid':
+        return 0
+    else:
+        raise ValueError('Unsupported padding:' + str(padding))
 
 def getConvParam(layer):
     stride = (1, 1) if layer.strides is None else layer.strides
-    padding = [
-        layer.kernel_size[0] /
-        2,
-        layer.kernel_size[1] /
-        2] if layer.padding == 'same' else [
-        0,
-        0]
     config = layer.get_config()
     return {'num_output': layer.filters, 'bias_term': str(config['use_bias']).lower(
     ), 'kernel_h': layer.kernel_size[0], 'kernel_w': layer.kernel_size[1], 'stride_h': stride[0], 'stride_w': stride[1],
-            'pad_h': padding[0], 'pad_w': padding[1]}
+            'pad_h': getPadding(layer.kernel_size[0], layer.padding), 'pad_w': getPadding(layer.kernel_size[1], layer.padding)}
 
 
 def getUpSamplingParam(layer):
@@ -216,15 +221,9 @@ def getUpSamplingParam(layer):
 
 def getPoolingParam(layer, pool='MAX'):
     stride = (1, 1) if layer.strides is None else layer.strides
-    padding = [
-        layer.pool_size[0] /
-        2,
-        layer.pool_size[1] /
-        2] if layer.padding == 'same' else [
-        0,
-        0]
     return {'pool': pool, 'kernel_h': layer.pool_size[0], 'kernel_w': layer.pool_size[1],
-            'stride_h': stride[0], 'stride_w': stride[1], 'pad_h': padding[0], 'pad_w': padding[1]}
+            'stride_h': stride[0], 'stride_w': stride[1], 'pad_h': getPadding(layer.pool_size[0], layer.padding),
+            'pad_w': getPadding(layer.pool_size[1], layer.padding)}
 
 
 def getRecurrentParam(layer):