You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2019/03/23 00:58:05 UTC

[systemml] branch master updated: [SYSTEMML-540] Added zero padding layer in Caffe2DML, Keras2DML and nn library

This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new f48235f  [SYSTEMML-540] Added zero padding layer in Caffe2DML, Keras2DML and nn library
f48235f is described below

commit f48235f3b4ffd254e37570747d019d4c1f312a2d
Author: Niketan Pansare <np...@us.ibm.com>
AuthorDate: Fri Mar 22 17:57:51 2019 -0700

    [SYSTEMML-540] Added zero padding layer in Caffe2DML, Keras2DML and nn library
    
    - Updated the tests and the documentation.
    - This layer is required for ResNet-50 demo with Keras2DML.
---
 docs/beginners-guide-keras2dml.md                  |  6 +-
 docs/reference-guide-caffe2dml.md                  | 30 +++++++
 scripts/nn/layers/zero_pad2d.dml                   | 93 ++++++++++++++++++++++
 scripts/nn/test/grad_check.dml                     | 52 ++++++++++++
 scripts/nn/test/run_tests.dml                      |  2 +
 .../org/apache/sysml/parser/ParserWrapper.java     |  3 +
 src/main/proto/caffe/caffe.proto                   | 10 +++
 src/main/python/systemml/mllearn/estimators.py     |  4 +-
 src/main/python/systemml/mllearn/keras2caffe.py    | 38 ++++++++-
 src/main/python/tests/test_nn_numpy.py             | 20 ++++-
 .../scala/org/apache/sysml/api/dl/CaffeLayer.scala | 37 +++++++++
 .../org/apache/sysml/api/dl/CaffeNetwork.scala     |  1 +
 .../org/apache/sysml/api/dl/DMLGenerator.scala     |  4 +-
 .../sysml/test/integration/scripts/nn/NNTest.java  |  2 +
 14 files changed, 292 insertions(+), 10 deletions(-)

diff --git a/docs/beginners-guide-keras2dml.md b/docs/beginners-guide-keras2dml.md
index 60de360..4517be5 100644
--- a/docs/beginners-guide-keras2dml.md
+++ b/docs/beginners-guide-keras2dml.md
@@ -161,12 +161,16 @@ sysml_model.fit(features, labels)
 
 #### What optimizer and loss does Keras2DML use by default if `keras_model` is not compiled ?
 
-If the user does not `compile` the keras model, then we use cross entropy loss and SGD optimizer with nesterov momentum:
+If the user does not `compile` the keras model, then we throw an error.
+
+For classification applications, you can consider using cross entropy loss and SGD optimizer with nesterov momentum:
 
 ```python 
 keras_model.compile(loss='categorical_crossentropy', optimizer=keras.optimizers.SGD(lr=0.01, momentum=0.95, decay=5e-4, nesterov=True))
 ```
 
+Please refer to [Keras's documentation](https://keras.io/losses/) for more detail.
+
 #### What is the learning rate schedule used ?
 
 Keras2DML does not support the `LearningRateScheduler` callback. 
diff --git a/docs/reference-guide-caffe2dml.md b/docs/reference-guide-caffe2dml.md
index 8e2ed1f..381b96d 100644
--- a/docs/reference-guide-caffe2dml.md
+++ b/docs/reference-guide-caffe2dml.md
@@ -139,6 +139,36 @@ layer {
 }
 ```
 
+### Padding Layer
+
+Invokes [nn/layers/zero_pad2d.dml](https://github.com/apache/systemml/blob/master/scripts/nn/layers/zero_pad2d.dml) layer.
+ 
+**Optional Parameters:**
+
+- top_pad: Padding for top side (default: 0).
+- bottom_pad: Padding for bottom side (default: 0).
+- left_pad: Padding for left side (default: 0).
+- right_pad: Padding for right side (default: 0).
+- right_pad: Padding for right side (default: 0).
+- pad_value: value to use for padding (default: 0). Only zero padding supported for now.
+
+**Sample Usage:**
+```
+layer {
+  name: "padding1"
+  type: "Padding"
+  bottom: "pool1"
+  top: "padding1"
+  padding_param  {
+    top_pad: 1
+    bottom_pad: 1
+    left_pad: 1
+    right_pad: 1
+    pad_value: 0
+  }
+}
+```
+
 
 ### Deconvolution Layer
 
diff --git a/scripts/nn/layers/zero_pad2d.dml b/scripts/nn/layers/zero_pad2d.dml
new file mode 100644
index 0000000..ac3eedf
--- /dev/null
+++ b/scripts/nn/layers/zero_pad2d.dml
@@ -0,0 +1,93 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Zero-padding layer for 2D input.
+ */
+
+forward = function(matrix[double] img, int C, int Hin, int Win, int top_pad, int bottom_pad, int left_pad, int right_pad)
+    return (matrix[double] img_padded) {
+  /*
+   * Computes the forward pass for a zero-padding layer.
+   *
+   * Inputs:
+   *  - img: Input images, of shape (N, C*Hin*Win)
+   *  - C: Number of input channels
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - top_pad: Padding for top side.
+   *  - bottom_pad: Padding for bottom side.
+   *  - left_pad: Padding for left side.
+   *  - right_pad: Padding for right side.
+   *
+   * Outputs:
+   *  - img_padded: The input images padded along the height and width
+   *      dimensions, of shape (N, C*(Hin+top_pad+bottom_pad)*(Win+left_pad+right_pad)).
+   */
+  N = nrow(img)
+  img_padded = matrix(0, rows=N, cols=C*(Hin+top_pad+bottom_pad)*(Win+left_pad+right_pad))  # zeros
+  img_index = 1
+  img_padded_index = 1
+  for(c in 1:C) {
+  	img_padded_index = img_padded_index + top_pad*(Win+left_pad+right_pad)
+    for(h in 1:Hin) {
+      img_padded_index = img_padded_index + left_pad
+  	  img_padded[,img_padded_index:(img_padded_index+Win-1)] = img[,img_index:(img_index+Win-1)] # vectorized over all images
+  	  img_padded_index = img_padded_index + Win + right_pad
+  	  img_index = img_index + Win
+  	}
+  	img_padded_index = img_padded_index + bottom_pad*(Win+left_pad+right_pad)
+  }
+}
+
+backward = function(matrix[double] dout, int C, int Hin, int Win, int top_pad, int bottom_pad, int left_pad, int right_pad)
+    return (matrix[double] dX) {
+  /*
+   * Computes the backward pass for a zero-padding layer.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of shape (N, C*(Hin+top_pad+bottom_pad)*(Win+left_pad+right_pad)).
+   *  - C: Number of input channels
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - top_pad: Padding for top side.
+   *  - bottom_pad: Padding for bottom side.
+   *  - left_pad: Padding for left side.
+   *  - right_pad: Padding for right side.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
+   */
+  N = nrow(dout)
+  dX = matrix(0, rows=N, cols=C*Hin*Win)  # zeros
+  img_index = 1
+  img_padded_index = 1
+  for(c in 1:C) {
+  	img_padded_index = img_padded_index + top_pad*(Win+left_pad+right_pad)
+    for(h in 1:Hin) {
+      img_padded_index = img_padded_index + left_pad
+  	  dX[,img_index:(img_index+Win-1)] = dout[,img_padded_index:(img_padded_index+Win-1)] # vectorized over all images
+  	  img_padded_index = img_padded_index + Win + right_pad
+  	  img_index = img_index + Win
+  	}
+  	img_padded_index = img_padded_index + bottom_pad*(Win+left_pad+right_pad)
+  }
+}
diff --git a/scripts/nn/test/grad_check.dml b/scripts/nn/test/grad_check.dml
index a5da859..bb93db5 100644
--- a/scripts/nn/test/grad_check.dml
+++ b/scripts/nn/test/grad_check.dml
@@ -58,6 +58,7 @@ source("nn/test/max_pool2d_simple.dml") as max_pool2d_simple
 source("nn/test/util.dml") as test_util
 source("nn/util.dml") as util
 source("nn/layers/elu.dml") as elu
+source("nn/layers/zero_pad2d.dml") as zero_pad2d
 
 affine = function() {
   /*
@@ -1827,6 +1828,57 @@ relu = function() {
   }
 }
 
+zero_pad2d = function() {
+  /*
+   * Gradient check for the Zero-padding layer for 2D input.
+   *
+   * NOTE: Unlike the ReLU gradient check, this layer has no kink:
+   * the forward pass only copies the input values into a zero
+   * matrix, so there is no max(0, x) threshold that f(x-h) and
+   * f(x+h) could straddle.  A failure of this check is therefore
+   * most likely a genuine indexing error in forward/backward
+   * rather than a numerical false-negative.
+   */
+  print("Grad checking the Zero-padding layer for 2D input with L2 loss.")
+
+  # Generate data
+  N = 48 # number of images
+  C = 3 # number channels
+  H = 32 # height
+  W = 64 # width
+  top_pad = 1
+  bottom_pad = 3
+  left_pad = 4
+  right_pad = 2
+  X = rand(rows=N, cols=C*H*W, min=-5, max=5)
+  y = rand(rows=N, cols=C*(H+top_pad+bottom_pad)*(W+left_pad+right_pad))
+
+  # Compute analytical gradients of loss wrt parameters
+  out = zero_pad2d::forward(X, C, H, W, top_pad, bottom_pad, left_pad, right_pad)
+  dout = l2_loss::backward(out, y)
+  dX = zero_pad2d::backward(dout, C, H, W, top_pad, bottom_pad, left_pad, right_pad)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      outmh = zero_pad2d::forward(X, C, H, W, top_pad, bottom_pad, left_pad, right_pad)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      outph = zero_pad2d::forward(X, C, H, W, top_pad, bottom_pad, left_pad, right_pad)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+      # Check error
+      rel_error = test_util::check_rel_grad_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+}
+
 rnn = function() {
   /*
    * Gradient check for the simple RNN layer.
diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml
index 36f1583..5fc74ec 100644
--- a/scripts/nn/test/run_tests.dml
+++ b/scripts/nn/test/run_tests.dml
@@ -67,6 +67,8 @@ grad_check::sigmoid()
 grad_check::softmax()
 grad_check::softmax2d()
 grad_check::tanh()
+# TODO: Enable after adding a builtin function. The layer was tested by comparing its results with TensorFlow.
+# grad_check::zero_pad2d()
 print("")
 
 # Example model
diff --git a/src/main/java/org/apache/sysml/parser/ParserWrapper.java b/src/main/java/org/apache/sysml/parser/ParserWrapper.java
index 8dc9712..dfcfa65 100644
--- a/src/main/java/org/apache/sysml/parser/ParserWrapper.java
+++ b/src/main/java/org/apache/sysml/parser/ParserWrapper.java
@@ -53,6 +53,9 @@ public abstract class ParserWrapper {
 	 * @return corresponding statement block
 	 */
 	public static StatementBlock getStatementBlock(Statement current) {
+		if(current == null) {
+			throw new LanguageException("Error occured while parsing the script");
+		}
 		StatementBlock blk = null;
 		if(current instanceof ParForStatement) {
 			blk = new ParForStatementBlock();
diff --git a/src/main/proto/caffe/caffe.proto b/src/main/proto/caffe/caffe.proto
index 8d1d796..d8671d6 100644
--- a/src/main/proto/caffe/caffe.proto
+++ b/src/main/proto/caffe/caffe.proto
@@ -408,6 +408,7 @@ message LayerParameter {
   
   // Nike:
   optional UpsampleParameter upsample_param = 147;
+  optional PaddingParameter padding_param = 148;
 }
 
 // Message that stores parameters used to apply transformation
@@ -623,6 +624,15 @@ message ConvolutionParameter {
   optional bool force_nd_im2col = 17 [default = false];
 }
 
+// Nike:
+message PaddingParameter {
+  optional uint32 top_pad = 1 [default = 0]; // The top padding height (2D only)
+  optional uint32 bottom_pad = 2 [default = 0]; // The bottom padding height (2D only)
+  optional uint32 left_pad = 3 [default = 0]; // The left padding width (2D only)
+  optional uint32 right_pad = 4 [default = 0]; // The right padding width (2D only)
+  optional float pad_value = 5 [default = 0]; // only zero supported for now
+}
+
 message CropParameter {
   // To crop, elements of the first bottom are selected to fit the dimensions
   // of the second, reference bottom. The crop is configured by
diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py
index d6aa8e8..8d1e164 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -36,9 +36,7 @@ from sklearn.metrics import accuracy_score, r2_score
 from py4j.protocol import Py4JError
 import traceback
 from sklearn.preprocessing import LabelEncoder
-import threading
-import time
-import math
+import threading, time, math, os
 
 from ..converters import *
 from ..classloader import *
diff --git a/src/main/python/systemml/mllearn/keras2caffe.py b/src/main/python/systemml/mllearn/keras2caffe.py
index 892deb2..2b97560 100755
--- a/src/main/python/systemml/mllearn/keras2caffe.py
+++ b/src/main/python/systemml/mllearn/keras2caffe.py
@@ -56,6 +56,12 @@ except ImportError:
 # - To add an activation, simply add the keras type to caffe type in supportedCaffeActivations.
 # - To add a layer, add the corresponding caffe layer type in supportedLayers. If the layer accepts parameters then update layerParamMapping too.
 # - The above logic is implemented in the function converKerasToCaffeNetwork
+#
+#
+# Example guide to add a new layer that does not have a weight and bias (eg: UpSampling2D or ZeroPadding2D):
+# - Add mapping of Keras class to Caffe layer in the supportedLayers map below
+# - Define a helper method that returns Caffe's layer parameter in JSON-like data structure. See getConvParam, getUpSamplingParam, getPaddingParam, etc.
+# - Add mapping of Keras class to Caffe layer parameter in the layerParamMapping map below
 # --------------------------------------------------------------------------------------
 
 supportedCaffeActivations = {
@@ -78,7 +84,8 @@ supportedLayers = {
     keras.layers.LSTM: 'LSTM',
     keras.layers.Flatten: 'Flatten',
     keras.layers.BatchNormalization: 'None',
-    keras.layers.Activation: 'None'
+    keras.layers.Activation: 'None',
+    keras.layers.ZeroPadding2D: 'Padding'
 }
 
 
@@ -199,6 +206,7 @@ specialLayers = {
     keras.layers.BatchNormalization: _parseBatchNorm
 }
 
+# Used by convolution and maxpooling to return the padding value as integer based on type 'same' and 'valid'
 def getPadding(kernel_size, padding):
     if padding.lower() == 'same':
         return int(kernel_size/2)
@@ -207,6 +215,7 @@ def getPadding(kernel_size, padding):
     else:
         raise ValueError('Unsupported padding:' + str(padding))
 
+# Helper method to return Caffe's ConvolutionParameter in JSON-like data structure
 def getConvParam(layer):
     stride = (1, 1) if layer.strides is None else layer.strides
     config = layer.get_config()
@@ -215,17 +224,37 @@ def getConvParam(layer):
             'pad_h': getPadding(layer.kernel_size[0], layer.padding), 'pad_w': getPadding(layer.kernel_size[1], layer.padding)}
 
 
+# Helper method to return newly added UpsampleParameter
+# (search for UpsampleParameter in the file src/main/proto/caffe/caffe.proto) in JSON-like data structure
 def getUpSamplingParam(layer):
     return {'size_h': layer.size[0], 'size_w': layer.size[1]}
 
+# Used by padding to extract different types of possible padding:
+# int: the same symmetric padding is applied to height and width.
+# tuple of 2 ints: interpreted as two different symmetric padding values for height and width: (symmetric_height_pad, symmetric_width_pad)
+# tuple of 2 tuples of 2 ints: interpreted as  ((top_pad, bottom_pad), (left_pad, right_pad))
+def getPaddingTuple(padding):
+    return [padding, padding] if isinstance(padding, int) else [padding[0], padding[1]]
+
+# Helper method to return newly added PaddingParameter
+# (search for PaddingParameter in the file src/main/proto/caffe/caffe.proto) in JSON-like data structure
+def getPaddingParam(layer):
+    if isinstance(layer.padding, int):
+        padding = getPaddingTuple(layer.padding) + getPaddingTuple(layer.padding)
+    elif hasattr(layer.padding, '__len__') and len(layer.padding) == 2:
+        padding = getPaddingTuple(layer.padding[0]) + getPaddingTuple(layer.padding[1])
+    else:
+        raise ValueError('padding should be either an int, a tuple of 2 ints or or a tuple of 2 tuples of 2 ints. Found: ' + str(layer.padding))
+    return {'top_pad': padding[0], 'bottom_pad': padding[1], 'left_pad': padding[2], 'right_pad': padding[3], 'pad_value':0}
 
+# Helper method to return Caffe's PoolingParameter in JSON-like data structure
 def getPoolingParam(layer, pool='MAX'):
     stride = (1, 1) if layer.strides is None else layer.strides
     return {'pool': pool, 'kernel_h': layer.pool_size[0], 'kernel_w': layer.pool_size[1],
             'stride_h': stride[0], 'stride_w': stride[1], 'pad_h': getPadding(layer.pool_size[0], layer.padding),
             'pad_w': getPadding(layer.pool_size[1], layer.padding)}
 
-
+# Helper method to return Caffe's RecurrentParameter in JSON-like data structure
 def getRecurrentParam(layer):
     if (not layer.use_bias):
         raise Exception('Only use_bias=True supported for recurrent layers')
@@ -236,14 +265,13 @@ def getRecurrentParam(layer):
     return {'num_output': layer.units, 'return_sequences': str(
         layer.return_sequences).lower()}
 
-
+# Helper method to return Caffe's InnerProductParameter in JSON-like data structure
 def getInnerProductParam(layer):
     if len(layer.output_shape) != 2:
         raise Exception('Only 2-D input is supported for the Dense layer in the current implementation, but found '
                         + str(layer.input_shape) + '. Consider adding a Flatten before ' + str(layer.name))
     return {'num_output': layer.units}
 
-# TODO: Update AveragePooling2D when we add maxpooling support
 layerParamMapping = {
     keras.layers.InputLayer: lambda l:
     {'data_param': {'batch_size': l.batch_size}},
@@ -259,6 +287,8 @@ layerParamMapping = {
     {'convolution_param': getConvParam(l)},
     keras.layers.UpSampling2D: lambda l:
     {'upsample_param': getUpSamplingParam(l)},
+    keras.layers.ZeroPadding2D: lambda l:
+    {'padding_param': getPaddingParam(l)},
     keras.layers.Conv2D: lambda l:
     {'convolution_param': getConvParam(l)},
     keras.layers.MaxPooling2D: lambda l:
diff --git a/src/main/python/tests/test_nn_numpy.py b/src/main/python/tests/test_nn_numpy.py
index 43e3303..d30c692 100644
--- a/src/main/python/tests/test_nn_numpy.py
+++ b/src/main/python/tests/test_nn_numpy.py
@@ -44,7 +44,7 @@ import unittest
 
 import numpy as np
 from keras.models import Sequential
-from keras.layers import Input, Dense, Conv2D, MaxPooling2D, Dropout, Flatten, LSTM, UpSampling2D, SimpleRNN, Activation
+from keras.layers import Input, Dense, Conv2D, MaxPooling2D, Dropout, Flatten, LSTM, UpSampling2D, SimpleRNN, Activation, ZeroPadding2D
 from keras.optimizers import SGD
 from keras import backend as K
 from keras.models import Model
@@ -276,5 +276,23 @@ class TestNNLibrary(unittest.TestCase):
     def test_upsampling_backward(self):
         self.failUnless(test_backward(UpSampling2D(size=(2, 2), input_shape=(3, 64, 32))))
 
+    def test_zeropadding_forward(self):
+        self.failUnless(test_forward(ZeroPadding2D(padding=1, input_shape=(3, 64, 32))))
+
+    def test_zeropadding_backward(self):
+        self.failUnless(test_backward(ZeroPadding2D(padding=1, input_shape=(3, 64, 32))))
+
+    def test_zeropadding_forward1(self):
+        self.failUnless(test_forward(ZeroPadding2D(padding=(1, 2), input_shape=(3, 64, 32))))
+
+    def test_zeropadding_backward1(self):
+        self.failUnless(test_backward(ZeroPadding2D(padding=(1, 2), input_shape=(3, 64, 32))))
+
+    def test_zeropadding_forward2(self):
+        self.failUnless(test_forward(ZeroPadding2D(padding=((3, 2), (1, 3)), input_shape=(3, 64, 32))))
+
+    def test_zeropadding_backward2(self):
+        self.failUnless(test_backward(ZeroPadding2D(padding=((3, 2), (1, 3)), input_shape=(3, 64, 32))))
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
index cd17af5..62323d1 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
@@ -671,6 +671,43 @@ class TanH(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extend
   // -------------------------------------------------
 }
 
+class Padding(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer {
+  override def sourceFileName                       = {
+    if(param.getPaddingParam.getPadValue == 0) "zero_pad2d"
+    else throw new DMLRuntimeException("Only pad_value = 0 is supported. Found: " + param.getPaddingParam.getPadValue)
+  }
+  override def init(dmlScript: StringBuilder): Unit = {}
+  
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean) = {
+    if(skipPadding) {
+      assign(dmlScript, out, X)
+    }
+    else {
+      invokeForward(dmlScript, List[String](out), X, numChannels, Hin, Win, top_pad, bottom_pad, left_pad, right_pad)
+    }
+  }
+  override def backward(dmlScript: StringBuilder, outSuffix: String): Unit = {
+    if(skipPadding) {
+      assignDoutToDX(dmlScript, outSuffix)
+    }
+    else {
+      invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id), dout, numChannels, Hin, Win, top_pad, bottom_pad, left_pad, right_pad)
+    }
+  }
+  override def weightShape(): Array[Int]                             = null
+  override def biasShape(): Array[Int]                               = null
+  override def outputShape = (numChannels, int_add(Hin, top_pad, bottom_pad), int_add(Win, left_pad, right_pad))
+  def skipPadding = param.getPaddingParam.getTopPad == 0 && param.getPaddingParam.getBottomPad == 0 && 
+    param.getPaddingParam.getLeftPad == 0 && param.getPaddingParam.getRightPad == 0
+  def top_pad = param.getPaddingParam.getTopPad.toString
+  def bottom_pad = param.getPaddingParam.getBottomPad.toString
+  def left_pad = param.getPaddingParam.getLeftPad.toString
+  def right_pad = param.getPaddingParam.getRightPad.toString
+  def numChannels  = bottomLayerOutputShape._1
+  def Hin          = bottomLayerOutputShape._2
+  def Win          = bottomLayerOutputShape._3
+}
+
 class ReLU(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer {
   // TODO: Leaky ReLU: negative_slope [default 0]: specifies whether to leak the negative part by multiplying it with the slope value rather than setting it to 0.
   // -------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
index d3449f3..278b07b 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
@@ -250,6 +250,7 @@ class CaffeNetwork(netFilePath: String, val currentPhase: Phase, var numChannels
       case "rnn"             => new RNN(param, id, this)
       case "lstm"            => new LSTM(param, id, this)
       case "flatten"         => new Flatten(param, id, this)
+      case "padding"         => new Padding(param, id, this)
       case _                 => throw new LanguageException("Layer of type " + param.getType + " is not supported")
     }
   }
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index 59c75ad..8597efd 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -51,6 +51,8 @@ trait BaseDMLGenerator {
     try { (v1.toDouble * v2.toDouble * v3.toDouble).toInt.toString } catch { case _: Throwable => "(" + v1 + "*" + v2 + "*" + v3 + ")" }
   def int_mult(v1: String, v2: String): String =
     try { (v1.toDouble * v2.toDouble).toInt.toString } catch { case _: Throwable => "(" + v1 + "*" + v2 + ")" }
+  def int_add(v1: String, v2: String, v3: String): String =
+    try { (v1.toDouble + v2.toDouble + v3.toDouble).toInt.toString } catch { case _: Throwable => "(" + v1 + "+" + v2 + "+" + v3 + ")" }
   def isNumber(x: String): Boolean                                                   = x forall Character.isDigit
   def transpose(x: String): String                                                   = "t(" + x + ")"
   def write(varName: String, fileName: String, format: String): String               = "write(" + varName + ", \"" + fileName + "\", format=\"" + format + "\")\n"
@@ -246,7 +248,7 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator {
     // Append source statements for layers as well as solver
     source(net, solver, if (isTraining) Array[String]("l1_reg") else null)
     source(net, solver, if (isTraining) Array[String]("l2_reg") else null)
-    source(dmlScript, numTabs, "util", Caffe2DML.nnDir)
+    source(dmlScript, numTabs, "util", Caffe2DML.nnDir)  
 
     if (isTraining) {
       // Append external built-in function headers:
diff --git a/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java b/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java
index 92b9f67..4bcc2b0 100644
--- a/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java
@@ -27,6 +27,8 @@ import org.junit.Test;
 
 /**
  * Test the SystemML deep learning library, `nn`.
+ * 
+ * mvn -Dit.test=org.apache.sysml.test.integration.scripts.nn.NNTest verify
  */
 public class NNTest extends MLContextTestBase {