Posted to commits@systemml.apache.org by ni...@apache.org on 2018/02/27 18:23:06 UTC

[2/2] systemml git commit: [SYSTEMML-1872] Added average pooling and upsampling layers

[SYSTEMML-1872] Added average pooling and upsampling layers

- Added avg_pool and avg_pool_backward builtin functions.
- The above builtin functions are supported on both CPU and GPU.
- Also added compare_backends tests for the above functions to compare the results of the CP operators with those of CuDNN.
- Added avg_pool2d_builtin.dml and upsample2d.dml in the nn library.
- Added gradcheck tests for the above layers.
- Supported average pooling and upsampling in Keras2DML as well as Caffe2DML.
- Tested the results of ResNet with average pooling on real-world images as a sanity check.
- Also tested the upsampling layer by comparing its results with those returned by Keras.

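The gradcheck tests mentioned above follow the usual central-difference recipe. The loop in grad_check.dml perturbs X[i,j] by -h and +h, with h = 1e-5, and recomputes the L2 loss after each perturbation. The following is a minimal Python sketch of the same idea. It assumes a single-channel 2x2 average pooling with stride 2 and no padding, and an L2 loss of 0.5 times the sum of squared errors. For each entry, it compares the numerical derivative (loss(x+h) - loss(x-h)) / (2h) with the analytical gradient. All function names are hypothetical. The relative-error metric is an illustrative choice and not necessarily the one SystemML's tests use.

```python
import random


def avg_pool_2x2(x, H, W):
    """Single-channel 2x2 average pooling, stride 2, no padding (H and W even)."""
    Wo = W // 2
    return [sum(x[(2 * ho + i) * W + 2 * wo + j] for i in range(2) for j in range(2)) / 4.0
            for ho in range(H // 2) for wo in range(Wo)]


def avg_pool_2x2_backward(dout, H, W):
    """Each input cell receives 1/4 of the gradient of the window covering it."""
    Wo = W // 2
    return [dout[(r // 2) * Wo + c // 2] / 4.0 for r in range(H) for c in range(W)]


def l2_loss(pred, y):
    """Assumed L2 loss: 0.5 * sum of squared errors (its gradient is pred - y)."""
    return 0.5 * sum((p - t) ** 2 for p, t in zip(pred, y))


def grad_check(H=4, W=4, h=1e-5, seed=0):
    """Max relative error between analytical and central-difference gradients."""
    rng = random.Random(seed)
    x = [rng.random() for _ in range(H * W)]
    y = [rng.random() for _ in range((H // 2) * (W // 2))]
    out = avg_pool_2x2(x, H, W)
    dout = [p - t for p, t in zip(out, y)]  # d loss / d out
    dx = avg_pool_2x2_backward(dout, H, W)  # analytical d loss / d x
    max_rel = 0.0
    for k in range(len(x)):
        old = x[k]
        x[k] = old - h
        loss_mh = l2_loss(avg_pool_2x2(x, H, W), y)
        x[k] = old + h
        loss_ph = l2_loss(avg_pool_2x2(x, H, W), y)
        x[k] = old  # restore the entry before moving on
        num = (loss_ph - loss_mh) / (2 * h)
        max_rel = max(max_rel, abs(num - dx[k]) / max(abs(num), abs(dx[k]), 1e-12))
    return max_rel
```

Because the loss is quadratic in each entry of x, central differences are exact up to rounding. The reported error should therefore be tiny, and a large value would point to a bug in the backward pass.
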
Closes #734.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/54a11eed
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/54a11eed
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/54a11eed

Branch: refs/heads/master
Commit: 54a11eed33529591ec8c21e5f404f4bbea1e8235
Parents: d16cc7c
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Tue Feb 27 10:20:08 2018 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Tue Feb 27 10:21:57 2018 -0800

----------------------------------------------------------------------
 docs/dml-language-reference.md                  |  18 +-
 docs/reference-guide-caffe2dml.md               |  26 +-
 scripts/nn/layers/avg_pool2d_builtin.dml        | 103 +++++++
 scripts/nn/layers/upsample2d.dml                |  75 +++++
 .../nn/test/compare_backends/test_avgpool.dml   |  24 ++
 .../nn/test/compare_backends/test_avgpool.sh    |  48 ++++
 .../test/compare_backends/test_avgpool_bwd.dml  |  25 ++
 .../test/compare_backends/test_avgpool_bwd.sh   |  49 ++++
 scripts/nn/test/grad_check.dml                  |  92 ++++++
 scripts/nn/test/run_tests.dml                   |   2 +
 .../org/apache/sysml/hops/ConvolutionOp.java    |  47 ++--
 src/main/java/org/apache/sysml/hops/Hop.java    |   4 +-
 .../apache/sysml/lops/ConvolutionTransform.java |   9 +-
 .../sysml/parser/BuiltinFunctionExpression.java |  11 +-
 .../org/apache/sysml/parser/DMLTranslator.java  |  16 +-
 .../org/apache/sysml/parser/Expression.java     |   2 +-
 .../instructions/CPInstructionParser.java       |   2 +
 .../instructions/GPUInstructionParser.java      |   2 +
 .../cp/ConvolutionCPInstruction.java            |  30 +-
 .../gpu/ConvolutionGPUInstruction.java          |  29 +-
 .../spark/ConvolutionSPInstruction.java         |  12 +-
 .../runtime/matrix/data/LibMatrixCuDNN.java     |  41 +--
 .../data/LibMatrixCuDNNPoolingDescriptors.java  |  22 +-
 .../sysml/runtime/matrix/data/LibMatrixDNN.java |  39 ++-
 .../matrix/data/LibMatrixDNNPooling.java        | 279 +++++++++++++++----
 src/main/proto/caffe/caffe.proto                |   9 +
 src/main/python/systemml/mllearn/keras2caffe.py |   9 +-
 src/main/python/tests/test_nn_numpy.py          |  20 +-
 .../org/apache/sysml/api/dl/CaffeLayer.scala    |  51 +++-
 .../org/apache/sysml/api/dl/CaffeNetwork.scala  |   7 +-
 .../org/apache/sysml/api/dl/DMLGenerator.scala  |   2 +
 31 files changed, 937 insertions(+), 168 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/docs/dml-language-reference.md
----------------------------------------------------------------------
diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md
index 355b507..d0943d6 100644
--- a/docs/dml-language-reference.md
+++ b/docs/dml-language-reference.md
@@ -1509,15 +1509,15 @@ The images are assumed to be stored NCHW format, where N = batch size, C = #chan
 Hence, the images are internally represented as a matrix with dimension (N, C * H * W).
 
 
-| Function name              | Input matrices | Dimension of first input matrix                           | Dimension of second input matrix (if applicable)          | Dimension of output matrix                                 | Input Parameters                                                                                                                                                                              | Notes                                                                                                                                             |
-|----------------------------|----------------|-----------------------------------------------------------|-----------------------------------------------------------|------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------|
-| conv2d                     | input, filter  | [batch_size X num_channels* height_image* width_image]    | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out]     | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Performs 2D convolution operation                                                                                                                 |
-| conv2d_backward_filter     | input, dout    | [batch_size X num_channels* height_image* width_image]    | [batch_size X num_channels_out* height_out* width_out]    | [num_filters X num_channels* height_filter* width_filter]  | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Computes the gradients wrt filter of 2D convolution                                                                                               |
-| conv2d_backward_data       | filter, dout   | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out]    | [batch_size X num_channels* height_image* width_image]     | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Computes the gradients wrt input of 2D convolution                                                                                                |
-| max_pool                   | input          | [batch_size X num_channels* height_image* width_image]    |                                                           | [batch_size X num_channels* height_out* width_out]         | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool]                                   | Performs max pooling operation                                                                                                                    |
-| max_pool_backward          | input, dout    | [batch_size X num_channels* height_image* width_image]    | [batch_size X num_channels* height_out* width_out]        | [batch_size X num_channels* height_image* width_image]     | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool]                                   | Computes the gradients wrt input of 2D maxpooling                                                                                                 |
-| bias_add                   | input, bias    | [batch_size X num_channels* height_image* width_image]    | [num_channels X 1]                                        | [batch_size X num_channels* height_image* width_image]     |                                                                                                                                                                                               | Adds the bias (row vector of size num_channels) to input with the given num_channels                                                              |
-| bias_multiply              | input, bias    | [batch_size X num_channels* height_image* width_image]    | [num_channels X 1]                                        | [batch_size X num_channels* height_image* width_image]     |                                                                                                                                                                                               | Multiplies the bias (row vector of size num_channels) to input with the given num_channels                                                        |
+| Function name                               | Input matrices | Dimension of first input matrix                           | Dimension of second input matrix (if applicable)          | Dimension of output matrix                                 | Input Parameters                                                                                                                                                                              | Notes                                                                                                                                             |
+|---------------------------------------------|----------------|-----------------------------------------------------------|-----------------------------------------------------------|------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------|
+| conv2d                                      | input, filter  | [batch_size X num_channels* height_image* width_image]    | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out]     | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Performs 2D convolution operation                                                                                                                 |
+| conv2d_backward_filter                      | input, dout    | [batch_size X num_channels* height_image* width_image]    | [batch_size X num_channels_out* height_out* width_out]    | [num_filters X num_channels* height_filter* width_filter]  | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Computes the gradients wrt filter of 2D convolution                                                                                               |
+| conv2d_backward_data                        | filter, dout   | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out]    | [batch_size X num_channels* height_image* width_image]     | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Computes the gradients wrt input of 2D convolution                                                                                                |
+| max_pool, avg_pool                          | input          | [batch_size X num_channels* height_image* width_image]    |                                                           | [batch_size X num_channels* height_out* width_out]         | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool]                                   | Performs max/average pooling operation                                                                                                            |
+| max_pool_backward, avg_pool_backward        | input, dout    | [batch_size X num_channels* height_image* width_image]    | [batch_size X num_channels* height_out* width_out]        | [batch_size X num_channels* height_image* width_image]     | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool]                                   | Computes the gradients wrt input of 2D max/average pooling                                                                                        |
+| bias_add                                    | input, bias    | [batch_size X num_channels* height_image* width_image]    | [num_channels X 1]                                        | [batch_size X num_channels* height_image* width_image]     |                                                                                                                                                                                               | Adds the bias (column vector of size num_channels) to input with the given num_channels                                                           |
+| bias_multiply                               | input, bias    | [batch_size X num_channels* height_image* width_image]    | [num_channels X 1]                                        | [batch_size X num_channels* height_image* width_image]     |                                                                                                                                                                                               | Multiplies input by the bias (column vector of size num_channels) with the given num_channels                                                     |
 
 
 Examples:

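As a quick reference for the shapes in the table above, the following Python sketch computes two things. First, it computes the pooled output size with the formula that avg_pool2d_builtin.dml uses: Hout = floor((Hin + 2*padh - Hf)/strideh + 1). Second, it computes the 0-based column at which element (c, h, w) sits within one row of the flattened (N, C*H*W) NCHW matrix. The function names are hypothetical. Note that DML indexing itself is 1-based.

```python
import math


def pool_out_dim(in_dim, pool, stride, pad):
    """Pooled height or width: floor((in + 2*pad - pool)/stride + 1)."""
    return math.floor((in_dim + 2 * pad - pool) / stride + 1)


def pooled_shape(N, C, Hin, Win, Hf, Wf, strideh, stridew, padh, padw):
    """Shape of the (N, C*Hout*Wout) matrix returned by max_pool/avg_pool."""
    Hout = pool_out_dim(Hin, Hf, strideh, padh)
    Wout = pool_out_dim(Win, Wf, stridew, padw)
    return (N, C * Hout * Wout), Hout, Wout


def nchw_col(c, h, w, H, W):
    """0-based column of element (c, h, w) inside one row of the (N, C*H*W) matrix."""
    return (c * H + h) * W + w
```

For example, a 5x3x28x28 batch pooled with a 3x3 window, stride 2, and padding 1 yields a 5 x 588 matrix (Hout = Wout = 14).
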
http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/docs/reference-guide-caffe2dml.md
----------------------------------------------------------------------
diff --git a/docs/reference-guide-caffe2dml.md b/docs/reference-guide-caffe2dml.md
index 0e191dd..8e2ed1f 100644
--- a/docs/reference-guide-caffe2dml.md
+++ b/docs/reference-guide-caffe2dml.md
@@ -97,7 +97,7 @@ Invokes [nn/layers/max_pool2d_builtin.dml](https://github.com/apache/systemml/bl
 - kernel_size (or kernel_h and kernel_w): specifies height and width of each filter
 
 **Optional Parameters:**
-- pool (default MAX): the pooling method. Currently, we only support MAX, not AVE, or STOCHASTIC.
+- pool (default MAX): the pooling method. Currently, we only support MAX and AVE, not STOCHASTIC.
 - pad (or pad_h and pad_w) (default 0): specifies the number of pixels to (implicitly) add to each side of the input
 - stride (or stride_h and stride_w) (default 1): specifies the intervals at which to apply the filters to the input
 
@@ -116,6 +116,30 @@ layer {
 }
 ```
 
+
+### Upsampling Layer
+
+Invokes [nn/layers/upsample2d.dml](https://github.com/apache/systemml/blob/master/scripts/nn/layers/upsample2d.dml) layer.
+ 
+**Required Parameters:**
+
+- size_h and size_w: specify the upsampling factors for rows and columns, respectively.
+
+**Sample Usage:**
+```
+layer {
+  name: "upsample1"
+  type: "Upsample"
+  bottom: "pool1"
+  top: "upsample1"
+  upsample_param {
+    size_h: 2
+    size_w: 2
+  }
+}
+```
+
+
 ### Deconvolution Layer
 
 Invokes [nn/layers/conv2d_transpose.dml](https://github.com/apache/systemml/blob/master/scripts/nn/layers/conv2d_transpose.dml)

http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/scripts/nn/layers/avg_pool2d_builtin.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/avg_pool2d_builtin.dml b/scripts/nn/layers/avg_pool2d_builtin.dml
new file mode 100644
index 0000000..6615c99
--- /dev/null
+++ b/scripts/nn/layers/avg_pool2d_builtin.dml
@@ -0,0 +1,103 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * 2D Average Pooling layer.
+ *
+ * This implementation uses a built-in operator for higher performance.
+ */
+
+forward = function(matrix[double] X, int C, int Hin, int Win, int Hf, int Wf,
+                   int strideh, int stridew, int padh, int padw)
+    return (matrix[double] out, int Hout, int Wout) {
+  /*
+   * Computes the forward pass for a 2D spatial average pooling layer.
+   * The input data has N examples, each represented as a 3D volume
+   * unrolled into a single vector.
+   *
+   * This implementation uses a built-in operator for higher
+   * performance.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *      A typical value is 0.
+   *  - padw: Padding for left and right sides.
+   *      A typical value is 0.
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, C*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   */
+  N = nrow(X)
+  Hout = as.integer(floor((Hin + 2*padh - Hf)/strideh + 1))
+  Wout = as.integer(floor((Win + 2*padw - Wf)/stridew + 1))
+
+  # Average pooling - built-in implementation
+  out = avg_pool(X, input_shape=[N,C,Hin,Win], pool_size=[Hf,Wf],
+                 stride=[strideh,stridew], padding=[padh,padw])
+}
+
+backward = function(matrix[double] dout, int Hout, int Wout, matrix[double] X,
+                    int C, int Hin, int Win, int Hf, int Wf,
+                    int strideh, int stridew, int padh, int padw)
+    return (matrix[double] dX) {
+  /*
+   * Computes the backward pass for a 2D spatial average pooling layer.
+   * The input data has N examples, each represented as a 3D volume
+   * unrolled into a single vector.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of
+   *      shape (N, C*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *      A typical value is 0.
+   *  - padw: Padding for left and right sides.
+   *      A typical value is 0.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
+   */
+  N = nrow(X)
+
+  # Gradient of average pooling
+  dX = avg_pool_backward(X, dout, input_shape=[N,C,Hin,Win], pool_size=[Hf,Wf],
+                         stride=[strideh,stridew], padding=[padh,padw])
+}
+

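To check avg_pool and avg_pool_backward numerically, the following plain-Python reference sketch operates on the same NCHW-flattened rows. One detail is an assumption: the commit does not say whether padded cells count toward the average. The divisor is therefore controlled by a hypothetical count_include_pad flag. When the flag is True, every window sum is divided by Hf*Wf. When it is False, the sum is divided by the number of in-bounds cells. PoolGeom and the function names are illustrative and not part of SystemML.

```python
from dataclasses import dataclass


@dataclass
class PoolGeom:
    """Hypothetical bundle of the avg_pool2d_builtin.dml geometry arguments."""
    C: int
    Hin: int
    Win: int
    Hf: int
    Wf: int
    strideh: int = 1
    stridew: int = 1
    padh: int = 0
    padw: int = 0
    count_include_pad: bool = True  # assumption: padded cells count toward the mean

    @property
    def Hout(self) -> int:
        # floor((Hin + 2*padh - Hf)/strideh + 1), as in avg_pool2d_builtin.dml
        return (self.Hin + 2 * self.padh - self.Hf) // self.strideh + 1

    @property
    def Wout(self) -> int:
        return (self.Win + 2 * self.padw - self.Wf) // self.stridew + 1

    def window(self, ho, wo):
        """In-bounds (h, w) input cells covered by output cell (ho, wo)."""
        cells = []
        for i in range(self.Hf):
            for j in range(self.Wf):
                h = ho * self.strideh - self.padh + i
                w = wo * self.stridew - self.padw + j
                if 0 <= h < self.Hin and 0 <= w < self.Win:
                    cells.append((h, w))
        return cells

    def divisor(self, cells):
        """Number the window sum is divided by (never zero)."""
        return self.Hf * self.Wf if self.count_include_pad else max(len(cells), 1)


def avg_pool(X, g):
    """Forward pass. X: list of N rows, each C*Hin*Win values in NCHW order."""
    out = []
    for row in X:
        o = [0.0] * (g.C * g.Hout * g.Wout)
        for c in range(g.C):
            base = c * g.Hin * g.Win
            for ho in range(g.Hout):
                for wo in range(g.Wout):
                    cells = g.window(ho, wo)
                    total = sum(row[base + h * g.Win + w] for h, w in cells)
                    o[(c * g.Hout + ho) * g.Wout + wo] = total / g.divisor(cells)
        out.append(o)
    return out


def avg_pool_backward(dout, g):
    """Backward pass: each upstream gradient, divided by the window divisor,
    is added to every in-bounds cell of its window."""
    dX = []
    for drow in dout:
        d = [0.0] * (g.C * g.Hin * g.Win)
        for c in range(g.C):
            base = c * g.Hin * g.Win
            for ho in range(g.Hout):
                for wo in range(g.Wout):
                    cells = g.window(ho, wo)
                    share = drow[(c * g.Hout + ho) * g.Wout + wo] / g.divisor(cells)
                    for h, w in cells:
                        d[base + h * g.Win + w] += share
        dX.append(d)
    return dX
```

Average pooling is linear, so its backward pass is the adjoint operator and depends only on the geometry, not on the values of X. Consistent with this, upsample2d.dml passes an all-zero matrix as the X argument of avg_pool_backward.
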
http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/scripts/nn/layers/upsample2d.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/upsample2d.dml b/scripts/nn/layers/upsample2d.dml
new file mode 100644
index 0000000..f1be552
--- /dev/null
+++ b/scripts/nn/layers/upsample2d.dml
@@ -0,0 +1,75 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Upsampling layer for 2D inputs.
+ *
+ * Repeats the rows and columns of the data by size_h and size_w respectively.
+ */
+
+forward = function(matrix[double] X, int C, int Hin, int Win, int size_h, int size_w)
+    return (matrix[double] out) {
+  /*
+   * Computes the forward pass for an upsampling layer.
+   *
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - size_h: upsampling factor for rows.
+   *  - size_w: upsampling factor for columns.
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, C*Hout*Wout), where Hout = Hin*size_h, and Wout = Win * size_w.
+   */
+  N = nrow(X)
+  Hout = size_h*Hin
+  Wout = size_w*Win
+  emptyInput = matrix(0, rows=N, cols=C*Hout*Wout)
+  out = avg_pool_backward(emptyInput, X, input_shape=[N,C,Hout,Wout], pool_size=[size_h,size_w], stride=[size_h,size_w], padding=[0,0])
+  out = out * size_h * size_w
+}
+
+backward = function(matrix[double] dout, int C, int Hin, int Win, int size_h, int size_w)
+    return (matrix[double] dX) {
+  /*
+   * Computes the backward pass for an upsampling layer.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of shape (N, C*Hout*Wout).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - size_h: upsampling factor for rows.
+   *  - size_w: upsampling factor for columns.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of same shape as `X`.
+   */
+  N = nrow(dout)
+  Hout = size_h*Hin
+  Wout = size_w*Win
+  dX = avg_pool(dout, input_shape=[N,C,Hout,Wout], pool_size=[size_h,size_w], stride=[size_h,size_w], padding=[0,0])
+  dX = dX * size_h * size_w
+}
+

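upsample2d.dml builds nearest-neighbour upsampling from the average-pooling builtins. It sets pool_size = stride = [size_h, size_w] and uses zero padding, so the pooling windows tile the upsampled image without overlap. In the forward pass, avg_pool_backward spreads each input value x as x/(size_h*size_w) over its block, and the final multiplication by size_h*size_w restores x. In the backward pass, avg_pool averages each block of dout, and the same multiplication turns that average into a sum. The following Python sketch computes the same results directly on NCHW-flattened rows. The function names are hypothetical, and the sketch is meant only as a reference for testing.

```python
def upsample2d_forward(X, C, Hin, Win, size_h, size_w):
    """Nearest-neighbour upsampling of rows shaped (N, C*Hin*Win), NCHW order."""
    Hout, Wout = Hin * size_h, Win * size_w
    out = []
    for row in X:
        o = [0.0] * (C * Hout * Wout)
        for c in range(C):
            for h in range(Hout):
                for w in range(Wout):
                    # Output cell (h, w) copies input cell (h // size_h, w // size_w).
                    o[(c * Hout + h) * Wout + w] = row[(c * Hin + h // size_h) * Win + w // size_w]
        out.append(o)
    return out


def upsample2d_backward(dout, C, Hin, Win, size_h, size_w):
    """Gradient wrt X: sum of dout over each size_h x size_w block."""
    Hout, Wout = Hin * size_h, Win * size_w
    dX = []
    for drow in dout:
        d = [0.0] * (C * Hin * Win)
        for c in range(C):
            for h in range(Hout):
                for w in range(Wout):
                    d[(c * Hin + h // size_h) * Win + w // size_w] += drow[(c * Hout + h) * Wout + w]
        dX.append(d)
    return dX
```

For instance, upsampling the 1x2 image [1, 2] by a factor of 2 in each direction gives the 2x4 image whose rows are both [1, 1, 2, 2].
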
http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/scripts/nn/test/compare_backends/test_avgpool.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/test_avgpool.dml b/scripts/nn/test/compare_backends/test_avgpool.dml
new file mode 100644
index 0000000..caf365d
--- /dev/null
+++ b/scripts/nn/test/compare_backends/test_avgpool.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read("input.mtx")
+out = avg_pool(X, input_shape=[$N,$C,$H,$W], pool_size=[$pool,$pool], stride=[$stride,$stride], padding=[$pad,$pad])
+write(out, $out, format="csv")

http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/scripts/nn/test/compare_backends/test_avgpool.sh
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/test_avgpool.sh b/scripts/nn/test/compare_backends/test_avgpool.sh
new file mode 100644
index 0000000..40866ef
--- /dev/null
+++ b/scripts/nn/test/compare_backends/test_avgpool.sh
@@ -0,0 +1,48 @@
+#!/usr/bin/bash
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+jars='systemml-*-extra.jar'
+
+# N = Number of images, C = number of channels, H = height, W = width
+N=5
+C=3
+H=28
+W=28
+for sparsity in 0.1 0.2 0.5 0.6 0.9
+do
+	# Generating the data
+	$SPARK_HOME/bin/spark-submit SystemML.jar -f gen_maxpool.dml -nvargs sp=$sparsity N=$N C=$C H=$H W=$W
+	for stride in 1 2 3
+	do
+		for pad in 0 1 2
+		do
+			# Running a test in CPU mode
+			$SPARK_HOME/bin/spark-submit SystemML.jar -f test_avgpool.dml -nvargs stride=$stride pad=$pad out=out_cp.csv N=$N C=$C H=$H W=$W pool=3
+			# Running a test in GPU mode
+			$SPARK_HOME/bin/spark-submit --jars $jars SystemML.jar -f test_avgpool.dml -stats -gpu force -nvargs stride=$stride pad=$pad out=out_gpu.csv N=$N C=$C H=$H W=$W pool=3
+			# Comparing the CPU vs GPU results to make sure they are the same
+			$SPARK_HOME/bin/spark-submit SystemML.jar -f compare.dml -args out_cp.csv out_gpu.csv "avgpool:sparsity="$sparsity",stride="$stride",pad="$pad
+			rm -rf out_cp.csv out_gpu.csv out_cp.csv.mtd out_gpu.csv.mtd
+		done
+	done
+	rm -rf input.mtx input.mtx.mtd
+done

http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/scripts/nn/test/compare_backends/test_avgpool_bwd.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/test_avgpool_bwd.dml b/scripts/nn/test/compare_backends/test_avgpool_bwd.dml
new file mode 100644
index 0000000..938cc6a
--- /dev/null
+++ b/scripts/nn/test/compare_backends/test_avgpool_bwd.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read("input.mtx")
+dout = read("dout.mtx")
+out = avg_pool_backward(X, dout, input_shape=[$N,$C,$H,$W], pool_size=[$pool,$pool], stride=[$stride,$stride], padding=[$pad,$pad])
+write(out, $out, format="csv")

http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/scripts/nn/test/compare_backends/test_avgpool_bwd.sh
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/test_avgpool_bwd.sh b/scripts/nn/test/compare_backends/test_avgpool_bwd.sh
new file mode 100644
index 0000000..4879057
--- /dev/null
+++ b/scripts/nn/test/compare_backends/test_avgpool_bwd.sh
@@ -0,0 +1,49 @@
+#!/usr/bin/bash
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+jars='systemml-*-extra.jar'
+
+# N = Number of images, C = number of channels, H = height, W = width
+N=5
+C=3
+H=28
+W=28
+for sparsity in 0.1 0.2 0.5 0.6 0.9
+do
+	# The input and dout are generated inside the loop below, since dout's shape depends on stride and pad
+
+	for stride in 1 2 3
+	do
+		for pad in 0 1 2
+		do
+			$SPARK_HOME/bin/spark-submit SystemML.jar -f gen_maxpool_bwd.dml -nvargs sp=$sparsity N=$N C=$C H=$H W=$W pool=3 stride=$stride pad=$pad
+			# Running a test in CPU mode
+			$SPARK_HOME/bin/spark-submit SystemML.jar -f test_avgpool_bwd.dml -nvargs stride=$stride pad=$pad out=out_cp.csv N=$N C=$C H=$H W=$W pool=3
+			# Running a test in GPU mode
+			$SPARK_HOME/bin/spark-submit --jars $jars SystemML.jar -f test_avgpool_bwd.dml -stats -gpu force -nvargs stride=$stride pad=$pad out=out_gpu.csv N=$N C=$C H=$H W=$W pool=3
+			# Comparing the CPU vs GPU results to make sure they are the same
+			$SPARK_HOME/bin/spark-submit SystemML.jar -f compare.dml -args out_cp.csv out_gpu.csv "avgpool_bwd:sparsity="$sparsity",stride="$stride",pad="$pad
+			rm -rf out_cp.csv out_gpu.csv out_cp.csv.mtd out_gpu.csv.mtd
+		done
+	done
+	rm -rf input.mtx input.mtx.mtd
+done

http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/scripts/nn/test/grad_check.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/grad_check.dml b/scripts/nn/test/grad_check.dml
index 47c6499..515bc1f 100644
--- a/scripts/nn/test/grad_check.dml
+++ b/scripts/nn/test/grad_check.dml
@@ -43,6 +43,8 @@ source("nn/layers/log_loss.dml") as log_loss
 source("nn/layers/lstm.dml") as lstm
 source("nn/layers/max_pool2d.dml") as max_pool2d
 source("nn/layers/max_pool2d_builtin.dml") as max_pool2d_builtin
+source("nn/layers/avg_pool2d_builtin.dml") as avg_pool2d_builtin
+source("nn/layers/upsample2d.dml") as upsample2d
 source("nn/layers/relu.dml") as relu
 source("nn/layers/rnn.dml") as rnn
 source("nn/layers/scale_shift1d.dml") as scale_shift1d
@@ -1642,6 +1644,60 @@ max_pool2d_builtin = function() {
   }
 }
 
+avg_pool2d_builtin = function() {
+  /*
+   * Gradient check for the 2D avg pooling layer.
+   */
+  print("Grad checking the built-in 2D avg pooling layer with L2 loss.")
+
+  # Generate data
+  N = 2  # num examples
+  C = 2  # num channels
+  Hin = 4  # input height
+  Win = 4  # input width
+  Hf = 2  # pool filter height
+  Wf = 2  # pool filter width
+  stride = 2
+  X = rand(rows=N, cols=C*Hin*Win)
+
+  for (pad in 0:1) {
+    print(" - Grad checking w/ pad="+pad+".")
+    Hout = as.integer(floor((Hin + 2 * pad - Hf) / stride + 1))
+    Wout = as.integer(floor((Win + 2 * pad - Wf) / stride + 1))
+    y = rand(rows=N, cols=C*Hout*Wout)
+
+    # Compute analytical gradients of loss wrt parameters
+    [out, Hout, Wout] = avg_pool2d_builtin::forward(X, C, Hin, Win, Hf, Wf, stride, stride,
+                                                    pad, pad)
+    dout = l2_loss::backward(out, y)
+    dX = avg_pool2d_builtin::backward(dout, Hout, Wout, X, C, Hin, Win, Hf, Wf, stride, stride,
+                                      pad, pad)
+
+    # Grad check
+    h = 1e-5
+    for (i in 1:nrow(X)) {
+      for (j in 1:ncol(X)) {
+        # Compute numerical derivative
+        old = as.scalar(X[i,j])
+        X[i,j] = old - h
+        [outmh, Hout, Wout] = avg_pool2d_builtin::forward(X, C, Hin, Win, Hf, Wf, stride, stride,
+                                                          pad, pad)
+        lossmh = l2_loss::forward(outmh, y)
+        X[i,j] = old + h
+        [outph, Hout, Wout] = avg_pool2d_builtin::forward(X, C, Hin, Win, Hf, Wf, stride, stride,
+                                                          pad, pad)
+        lossph = l2_loss::forward(outph, y)
+        X[i,j] = old  # reset
+        dX_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+        # Check error
+        rel_error = test_util::check_rel_grad_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+      }
+    }
+  }
+}
+
+
 max_pool2d_simple = function() {
   /*
    * Gradient check for the simple reference 2D max pooling layer.
@@ -1694,6 +1750,42 @@ max_pool2d_simple = function() {
   }
 }
 
+upsample2d = function() {
+  print("Grad checking the upsample2d layer with L2 loss.")
+
+  C=2; Hin=3; Win=3; size_h=2; size_w=2
+  # Generate data
+  N = 3 # num examples
+  M = C*Hin*Win # num neurons
+  X = rand(rows=N, cols=M, min=-5, max=5)
+  y = rand(rows=N, cols=M*size_h*size_w)
+
+  # Compute analytical gradients of loss wrt parameters
+  out = upsample2d::forward(X, C, Hin, Win, size_h, size_w)
+  dout = l2_loss::backward(out, y)
+  dX = upsample2d::backward(dout, C, Hin, Win, size_h, size_w)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      outmh = upsample2d::forward(X, C, Hin, Win, size_h, size_w)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      outph = upsample2d::forward(X, C, Hin, Win, size_h, size_w)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+      # Check error
+      rel_error = test_util::check_rel_grad_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+}
+
 relu = function() {
   /*
    * Gradient check for the ReLU nonlinearity layer.

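The upsample2d grad check above computes the analytical gradient with `upsample2d::backward` and compares it against a central-difference estimate, `(lossph - lossmh) / (2*h)`, under an L2 loss. The Java sketch below reproduces that loop in isolation. It rests on three assumptions. First, upsample2d is taken to be nearest-neighbor upsampling, where each input pixel is repeated into a size_h x size_w block, as in Keras' UpSampling2D. Second, the L2 loss is taken to be 0.5*sum((out - y)^2), with gradient out - y. Third, the relative-error helper is a hypothetical stand-in for test_util::check_rel_grad_error. The class and method names are illustrative and are not SystemML APIs.

```java
import java.util.Random;

/** Hypothetical sketch: nearest-neighbor upsample2d plus a central-difference grad check. */
public class Upsample2dGradCheck {

    // Forward (assumed semantics): repeat each pixel of a C x Hin x Win image (row-major)
    // into a sizeH x sizeW block, giving C x (Hin*sizeH) x (Win*sizeW).
    static double[] forward(double[] x, int C, int Hin, int Win, int sizeH, int sizeW) {
        int Hout = Hin * sizeH, Wout = Win * sizeW;
        double[] out = new double[C * Hout * Wout];
        for (int c = 0; c < C; c++)
            for (int i = 0; i < Hout; i++)
                for (int j = 0; j < Wout; j++)
                    out[c * Hout * Wout + i * Wout + j] = x[c * Hin * Win + (i / sizeH) * Win + (j / sizeW)];
        return out;
    }

    // Backward: each input pixel receives the sum of dout over the block it was copied into.
    static double[] backward(double[] dout, int C, int Hin, int Win, int sizeH, int sizeW) {
        int Hout = Hin * sizeH, Wout = Win * sizeW;
        double[] dx = new double[C * Hin * Win];
        for (int c = 0; c < C; c++)
            for (int i = 0; i < Hout; i++)
                for (int j = 0; j < Wout; j++)
                    dx[c * Hin * Win + (i / sizeH) * Win + (j / sizeW)] += dout[c * Hout * Wout + i * Wout + j];
        return dx;
    }

    // Assumed L2 loss: 0.5 * sum((pred - y)^2).
    static double l2Loss(double[] pred, double[] y) {
        double s = 0;
        for (int k = 0; k < pred.length; k++) { double d = pred[k] - y[k]; s += d * d; }
        return 0.5 * s;
    }

    // Gradient of the assumed L2 loss with respect to pred.
    static double[] l2LossBackward(double[] pred, double[] y) {
        double[] d = new double[pred.length];
        for (int k = 0; k < pred.length; k++) d[k] = pred[k] - y[k];
        return d;
    }

    // Returns the largest relative error between the analytical and numerical gradients.
    static double maxRelError(int C, int Hin, int Win, int sizeH, int sizeW, long seed) {
        Random rnd = new Random(seed);
        double[] x = new double[C * Hin * Win];
        for (int k = 0; k < x.length; k++) x[k] = -5 + 10 * rnd.nextDouble();
        double[] y = new double[x.length * sizeH * sizeW];
        for (int k = 0; k < y.length; k++) y[k] = rnd.nextDouble();

        double[] dx = backward(l2LossBackward(forward(x, C, Hin, Win, sizeH, sizeW), y), C, Hin, Win, sizeH, sizeW);
        double h = 1e-5, worst = 0;
        for (int k = 0; k < x.length; k++) {
            double old = x[k];
            x[k] = old - h;
            double lossmh = l2Loss(forward(x, C, Hin, Win, sizeH, sizeW), y);
            x[k] = old + h;
            double lossph = l2Loss(forward(x, C, Hin, Win, sizeH, sizeW), y);
            x[k] = old; // reset
            double num = (lossph - lossmh) / (2 * h); // central difference
            double denom = Math.max(Math.abs(dx[k]), Math.abs(num));
            double rel = denom == 0 ? 0 : Math.abs(dx[k] - num) / denom;
            worst = Math.max(worst, rel);
        }
        return worst;
    }

    public static void main(String[] args) {
        // Same shape as the DML test: C=2, Hin=Win=3, size_h=size_w=2.
        System.out.println("max relative error = " + maxRelError(2, 3, 3, 2, 2, 42L));
    }
}
```

Because the loss is quadratic in x, the central difference is exact up to floating-point rounding. The reported error should therefore be small unless an analytical gradient happens to be very close to zero.
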
http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/scripts/nn/test/run_tests.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/run_tests.dml b/scripts/nn/test/run_tests.dml
index fd4f0fa..fd6e18e 100644
--- a/scripts/nn/test/run_tests.dml
+++ b/scripts/nn/test/run_tests.dml
@@ -56,6 +56,8 @@ grad_check::lstm()
 grad_check::max_pool2d()
 grad_check::max_pool2d_builtin()
 grad_check::max_pool2d_simple()
+grad_check::avg_pool2d_builtin()
+grad_check::upsample2d()
 grad_check::relu()
 grad_check::rnn()
 grad_check::scale_shift1d()

http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index fce4958..410a83a 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -120,6 +120,8 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 		{
 			case MAX_POOLING:
 			case MAX_POOLING_BACKWARD:
+			case AVG_POOLING:
+			case AVG_POOLING_BACKWARD:
 			case DIRECT_CONV2D:
 			case DIRECT_CONV2D_BACKWARD_DATA:
 			case DIRECT_CONV2D_BACKWARD_FILTER:
@@ -151,7 +153,8 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 	
 	private int getNumExpectedInputs() {
 		switch(op) {
-			case MAX_POOLING_BACKWARD: 
+			case MAX_POOLING_BACKWARD:
+			case AVG_POOLING_BACKWARD:
 			case DIRECT_CONV2D:
 			case DIRECT_CONV2D_BACKWARD_FILTER:
 			case DIRECT_CONV2D_BACKWARD_DATA:
@@ -206,24 +209,24 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 	}
 	
 	/**
-	 * Returns the output lop of maxpool operation with same parameters as this hop.
+	 * Returns the output lop of max_pool/avg_pool operation with same parameters as this hop.
 	 * If corresponding output lop is not found or if this is not a max_pool_backward operation, this function returns null
 	 * 
-	 * @return output lop of maxpool operation with same parameters as this hop
+	 * @return output lop of max_pool/avg_pool operation with same parameters as this hop
 	 * @throws HopsException if error 
 	 * @throws LopsException if error
 	 */
 	private Lop getMaxPoolOutputLop() throws HopsException, LopsException {
-		if(op != ConvOp.MAX_POOLING_BACKWARD)
-			return null;
-		
-		Hop inputImage = getInput().get(0);
-		for(Hop tmpParent : inputImage.getParent()) {
-			if(!(tmpParent instanceof ConvolutionOp))
-				continue;
-			ConvolutionOp parent = (ConvolutionOp) tmpParent;
-			if(parent.getOp() == ConvOp.MAX_POOLING && isPoolingParametersEqualAndKnown(parent._cachedParams, _cachedParams)) {
-				return parent.constructLops();
+		if(op == ConvOp.MAX_POOLING_BACKWARD || op == ConvOp.AVG_POOLING_BACKWARD) {
+			ConvOp opType = (op == ConvOp.MAX_POOLING_BACKWARD) ? ConvOp.MAX_POOLING : ConvOp.AVG_POOLING;
+			Hop inputImage = getInput().get(0);
+			for(Hop tmpParent : inputImage.getParent()) {
+				if(!(tmpParent instanceof ConvolutionOp))
+					continue;
+				ConvolutionOp parent = (ConvolutionOp) tmpParent;
+				if(parent.getOp() == opType && isPoolingParametersEqualAndKnown(parent._cachedParams, _cachedParams)) {
+					return parent.constructLops();
+				}
 			}
 		}
 		return null;
@@ -484,11 +487,11 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 			// im2col operation preserves the worst-case sparsity of the input.
 			cpIntermediates.add(new IntermediateDimensions(this, "CRS", "PQ", getInput().get(0).getSparsity()));
 		}
-		else if(getOp() == ConvOp.MAX_POOLING) {
+		else if(getOp() == ConvOp.MAX_POOLING || getOp() == ConvOp.AVG_POOLING) {
 			// Account for potential sparse-to-dense conversion of atleast 1 input row
 			gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
 		}
-		else if(getOp() == ConvOp.MAX_POOLING_BACKWARD) {
+		else if(getOp() == ConvOp.MAX_POOLING_BACKWARD || getOp() == ConvOp.AVG_POOLING_BACKWARD) {
 			// Account for potential sparse-to-dense conversion of atleast 1 input + dout row
 			gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
 			gpuIntermediates.add(new IntermediateDimensions(this, 1, "CPQ"));
@@ -569,7 +572,7 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 	ConvolutionParameters parseInput() throws DMLRuntimeException {
 		
 		Hop imageHeightHop = null; Hop filterHeightHop = null;
-		if(op == ConvOp.MAX_POOLING_BACKWARD 
+		if(op == ConvOp.MAX_POOLING_BACKWARD || op == ConvOp.AVG_POOLING_BACKWARD 
 				|| op == ConvOp.DIRECT_CONV2D 
 				|| op == ConvOp.DIRECT_CONV2D_BACKWARD_FILTER
 				|| op == ConvOp.DIRECT_CONV2D_BACKWARD_DATA) {
@@ -606,10 +609,10 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 		}
 		
 		if(INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP) {
-			boolean isMaxPool = getOp() == ConvOp.MAX_POOLING;
+			boolean isPool = (getOp() == ConvOp.MAX_POOLING || getOp() == ConvOp.AVG_POOLING);
 			boolean isConv = getOp() == ConvOp.DIRECT_CONV2D;
 			boolean unknownCHWPQ = _cachedParams.C < 0 || _cachedParams.H < 0 || _cachedParams.W < 0 || _cachedParams.P < 0 || _cachedParams.Q < 0;
-			if((isMaxPool || isConv) && unknownCHWPQ) {
+			if((isPool || isConv) && unknownCHWPQ) {
 				// Only infer input shape for convolution and maxpool
 				inferCHWPQFromParentOp();
 			}
@@ -679,7 +682,7 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 		
 		if(parentOp == null)
 			return;
-		else if(parentOp.getOp() == ConvOp.MAX_POOLING) {
+		else if(parentOp.getOp() == ConvOp.MAX_POOLING || parentOp.getOp() == ConvOp.AVG_POOLING) {
 			ConvolutionParameters parentParam = parentOp.parseInput();
 			int prevC = _cachedParams.C; int prevH = _cachedParams.H; int prevW = _cachedParams.W;
 			// [C, P, Q] from maxpool becomes [C, H, W] of next op
@@ -730,6 +733,7 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 		switch(op) 
 		{
 			case MAX_POOLING:
+			case AVG_POOLING:
 			{	
 				_dim1 = getDim("N");
 				_dim2 = getDim("CPQ");
@@ -737,6 +741,7 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 				break;
 			}
 			case MAX_POOLING_BACKWARD:
+			case AVG_POOLING_BACKWARD:
 			{
 				_dim1 = getDim("N");
 				_dim2 = getDim("CHW");
@@ -849,10 +854,10 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 			input = getInput().get(0);
 			dout  = getInput().get(1);
 		}
-		else if(getOp() == ConvOp.MAX_POOLING) {
+		else if(getOp() == ConvOp.MAX_POOLING || getOp() == ConvOp.AVG_POOLING) {
 			input = getInput().get(0);
 		}
-		else if(getOp() == ConvOp.MAX_POOLING_BACKWARD) {
+		else if(getOp() == ConvOp.MAX_POOLING_BACKWARD || getOp() == ConvOp.AVG_POOLING_BACKWARD) {
 			input = getInput().get(0);
 			dout1  = getInput().get(1);
 		}

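The ConvolutionOp changes give AVG_POOLING the same output dimensions as MAX_POOLING (N x C*P*Q) and give AVG_POOLING_BACKWARD the same dimensions as MAX_POOLING_BACKWARD (N x C*H*W). The sketch below shows that shape arithmetic. It assumes P and Q follow the formula used in the avg_pool2d_builtin grad check, Hout = floor((Hin + 2*pad - Hf)/stride + 1). The class and method names are hypothetical.

```java
/** Hypothetical helper mirroring the pooling shape arithmetic; not a SystemML class. */
public class PoolShapes {

    // Output spatial size: floor((in + 2*pad - filter) / stride) + 1.
    static int outDim(int in, int filter, int stride, int pad) {
        return Math.floorDiv(in + 2 * pad - filter, stride) + 1;
    }

    // Returns {rows, cols} of the forward pooling output: N x (C * P * Q).
    static long[] forwardDims(int N, int C, int H, int W, int R, int S, int stride, int pad) {
        int P = outDim(H, R, stride, pad);
        int Q = outDim(W, S, stride, pad);
        return new long[] {N, (long) C * P * Q};
    }

    // The backward pass returns a gradient with the input's shape: N x (C * H * W).
    static long[] backwardDims(int N, int C, int H, int W) {
        return new long[] {N, (long) C * H * W};
    }

    public static void main(String[] args) {
        // Grad-check configuration: N=2, C=2, 4x4 input, 2x2 window, stride 2, pad 0 and 1.
        for (int pad = 0; pad <= 1; pad++) {
            long[] f = forwardDims(2, 2, 4, 4, 2, 2, 2, pad);
            System.out.println("pad=" + pad + ": forward " + f[0] + " x " + f[1]);
        }
    }
}
```

For the grad-check configuration, this gives a 2 x 8 forward output with pad=0 (P=Q=2) and a 2 x 18 output with pad=1 (P=Q=3). The backward output is 2 x 32 in both cases.
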
http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 23d29e4..71f4d89 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1110,7 +1110,7 @@ public abstract class Hop implements ParseInfo
 	}
 	
 	public enum ConvOp {
-		MAX_POOLING, MAX_POOLING_BACKWARD,
+		MAX_POOLING, MAX_POOLING_BACKWARD, AVG_POOLING, AVG_POOLING_BACKWARD,
 		DIRECT_CONV2D, DIRECT_CONV2D_BACKWARD_FILTER, DIRECT_CONV2D_BACKWARD_DATA,
 		BIAS_ADD, BIAS_MULTIPLY
 	}
@@ -1177,6 +1177,8 @@ public abstract class Hop implements ParseInfo
 		HopsConv2Lops = new HashMap<>();
 		HopsConv2Lops.put(ConvOp.MAX_POOLING, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.MAX_POOLING);
 		HopsConv2Lops.put(ConvOp.MAX_POOLING_BACKWARD, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.MAX_POOLING_BACKWARD);
+		HopsConv2Lops.put(ConvOp.AVG_POOLING, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.AVG_POOLING);
+		HopsConv2Lops.put(ConvOp.AVG_POOLING_BACKWARD, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.AVG_POOLING_BACKWARD);
 		HopsConv2Lops.put(ConvOp.DIRECT_CONV2D, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.DIRECT_CONV2D);
 		HopsConv2Lops.put(ConvOp.BIAS_ADD, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.BIAS_ADD);
 		HopsConv2Lops.put(ConvOp.BIAS_MULTIPLY, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.BIAS_MULTIPLY);

http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java b/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
index 68b3b21..bfb4758 100644
--- a/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
+++ b/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
@@ -30,7 +30,8 @@ public class ConvolutionTransform extends Lop
 
 	
 	public enum OperationTypes {
-		MAX_POOLING, MAX_POOLING_BACKWARD, RELU_MAX_POOLING, RELU_BACKWARD, RELU_MAX_POOLING_BACKWARD,
+		MAX_POOLING, MAX_POOLING_BACKWARD, AVG_POOLING, AVG_POOLING_BACKWARD,
+		RELU_MAX_POOLING, RELU_MAX_POOLING_BACKWARD, RELU_BACKWARD,
 		DIRECT_CONV2D, DIRECT_CONV2D_BACKWARD_FILTER, DIRECT_CONV2D_BACKWARD_DATA,
 		BIAS_ADD, DIRECT_CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS
 	}
@@ -140,6 +141,12 @@ public class ConvolutionTransform extends Lop
 		case MAX_POOLING_BACKWARD:
 			return "maxpooling_backward";
 		
+		case AVG_POOLING:
+			return "avgpooling";
+			
+		case AVG_POOLING_BACKWARD:
+			return "avgpooling_backward";
+		
 		case DIRECT_CONV2D:
 			return "conv2d";
 		

http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index 2ed02d2..a79a522 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -326,15 +326,15 @@ public class BuiltinFunctionExpression extends DataIdentifier
 				paramExpression = expandListParams(paramExpression, expand);
 				paramExpression = orderConvolutionParams(paramExpression, 2);
 			}
-			else if(_opcode == BuiltinFunctionOp.MAX_POOL || 
-					_opcode == BuiltinFunctionOp.MAX_POOL_BACKWARD) {
+			else if(_opcode == BuiltinFunctionOp.MAX_POOL || _opcode == BuiltinFunctionOp.AVG_POOL ||  
+					_opcode == BuiltinFunctionOp.MAX_POOL_BACKWARD || _opcode == BuiltinFunctionOp.AVG_POOL_BACKWARD) {
 				HashSet<String> expand = new HashSet<>();
 				expand.add("input_shape"); expand.add("pool_size"); expand.add("stride"); expand.add("padding");
 				paramExpression = expandListParams(paramExpression, expand);
 				paramExpression.add(new ParameterExpression("filter_shape1", new IntIdentifier(1, this)));
 				paramExpression.add(new ParameterExpression("filter_shape2", new IntIdentifier(1, this)));
 				paramExpression = replaceListParams(paramExpression, "pool_size", "filter_shape", 3);
-				if(_opcode == BuiltinFunctionOp.MAX_POOL_BACKWARD)
+				if(_opcode == BuiltinFunctionOp.MAX_POOL_BACKWARD || _opcode == BuiltinFunctionOp.AVG_POOL_BACKWARD)
 					paramExpression = orderConvolutionParams(paramExpression, 2);
 				else
 					paramExpression = orderConvolutionParams(paramExpression, 1);
@@ -1160,6 +1160,7 @@ public class BuiltinFunctionExpression extends DataIdentifier
 		case MAX_POOL:
 		case AVG_POOL:
 		case MAX_POOL_BACKWARD:
+		case AVG_POOL_BACKWARD:
 		{
 			// At DML level:
 			// output = conv2d(input, filter, input_shape=[1, 3, 2, 2], filter_shape=[1, 3, 2, 2], 
@@ -1183,7 +1184,7 @@ public class BuiltinFunctionExpression extends DataIdentifier
 			output.setValueType(ValueType.DOUBLE);
 			output.setBlockDimensions(input.getOutput().getRowsInBlock(), input.getOutput().getColumnsInBlock());
 			
-			if(this.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD) {
+			if(this.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD || this.getOpCode() == BuiltinFunctionOp.AVG_POOL_BACKWARD) {
 				output.setDimensions(input.getOutput().getDim1(), input.getOutput().getDim2());
 			}
 			else {
@@ -1757,6 +1758,8 @@ public class BuiltinFunctionExpression extends DataIdentifier
 			 bifop = Expression.BuiltinFunctionOp.MAX_POOL_BACKWARD;
 		else if (functionName.equals("avg_pool"))
 			 bifop = Expression.BuiltinFunctionOp.AVG_POOL;
+		else if (functionName.equals("avg_pool_backward"))
+			 bifop = Expression.BuiltinFunctionOp.AVG_POOL_BACKWARD;
 		else if (functionName.equals("solve"))
 			bifop = Expression.BuiltinFunctionOp.SOLVE;
 		else if (functionName.equals("ceil") || functionName.equals("ceiling"))

http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index 63c896c..2bcdde1 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -3022,15 +3022,19 @@ public class DMLTranslator
 			if(source.getOpCode() == BuiltinFunctionOp.MAX_POOL)
 				currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING, inHops1);
 			else
-				throw new HopsException("Average pooling is not implemented");
+				currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.AVG_POOLING, inHops1);
 			setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
 			break;
 		}
+		case AVG_POOL_BACKWARD:
 		case MAX_POOL_BACKWARD:
 		{
 			Hop image = expr;
 			ArrayList<Hop> inHops1 = getALHopsForConvOpPoolingCOL2IM(image, source, 1, hops); // process dout as well
-			currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING_BACKWARD, inHops1);
+			if(source.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD)
+				currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING_BACKWARD, inHops1);
+			else
+				currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.AVG_POOLING_BACKWARD, inHops1);
 			setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
 			break;
 		}
@@ -3055,9 +3059,11 @@ public class DMLTranslator
 			throw new ParseException("Unsupported builtin function type: "+source.getOpCode());
 		}
 		
-		if( !(source.getOpCode() == BuiltinFunctionOp.CONV2D || source.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_DATA ||
-				source.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER || source.getOpCode() == BuiltinFunctionOp.MAX_POOL ||
-				source.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD) ) {
+		boolean isConvolution = source.getOpCode() == BuiltinFunctionOp.CONV2D || source.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_DATA ||
+				source.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER || 
+				source.getOpCode() == BuiltinFunctionOp.MAX_POOL || source.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD || 
+				source.getOpCode() == BuiltinFunctionOp.AVG_POOL || source.getOpCode() == BuiltinFunctionOp.AVG_POOL_BACKWARD;
+		if( !isConvolution) {
 			// Since the dimension of output doesnot match that of input variable for these operations
 			setIdentifierParams(currBuiltinOp, source.getOutput());
 		}

http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index 6fa9ac6..ffffb36 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -87,7 +87,7 @@ public abstract class Expression implements ParseInfo
 		DIAG,
 		EIGEN,
 		CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, BIAS_ADD, BIAS_MULTIPLY,
-		MAX_POOL, AVG_POOL, MAX_POOL_BACKWARD,
+		MAX_POOL, AVG_POOL, MAX_POOL_BACKWARD, AVG_POOL_BACKWARD,
 		EXP,
 		FLOOR,
 		IFELSE,

http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index 169d0b4..de8deea 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -232,6 +232,8 @@ public class CPInstructionParser extends InstructionParser
 		String2CPInstructionType.put( "relu_maxpooling_backward"      , CPType.Convolution);
 		String2CPInstructionType.put( "maxpooling"      , CPType.Convolution);
 		String2CPInstructionType.put( "maxpooling_backward"      , CPType.Convolution);
+		String2CPInstructionType.put( "avgpooling"      , CPType.Convolution);
+		String2CPInstructionType.put( "avgpooling_backward"      , CPType.Convolution);
 		String2CPInstructionType.put( "conv2d"      , CPType.Convolution);
 		String2CPInstructionType.put( "conv2d_bias_add"      , CPType.Convolution);
 		String2CPInstructionType.put( "conv2d_backward_filter"      , CPType.Convolution);

http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 3c19b1a..d4e18cb 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -51,6 +51,8 @@ public class GPUInstructionParser  extends InstructionParser
 		String2GPUInstructionType.put( "conv2d_backward_data",   GPUINSTRUCTION_TYPE.Convolution);
 		String2GPUInstructionType.put( "maxpooling",             GPUINSTRUCTION_TYPE.Convolution);
 		String2GPUInstructionType.put( "maxpooling_backward",    GPUINSTRUCTION_TYPE.Convolution);
+		String2GPUInstructionType.put( "avgpooling",             GPUINSTRUCTION_TYPE.Convolution);
+		String2GPUInstructionType.put( "avgpooling_backward",    GPUINSTRUCTION_TYPE.Convolution);
 		String2GPUInstructionType.put( "bias_add",               GPUINSTRUCTION_TYPE.Convolution);
 		String2GPUInstructionType.put( "bias_multiply",          GPUINSTRUCTION_TYPE.Convolution);
 		String2GPUInstructionType.put( "channel_sums",          GPUINSTRUCTION_TYPE.Convolution);

http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
index 34daf33..5d4deb2 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
@@ -31,6 +31,7 @@ import org.apache.sysml.runtime.functionobjects.KahanPlus;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
 import org.apache.sysml.runtime.matrix.data.LibMatrixNative;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.SparseBlock;
@@ -103,7 +104,8 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 
 		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
 		String opcode = parts[0];
-		if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling")) {
+		if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling") ||
+			opcode.equalsIgnoreCase("avgpooling")) {
 			InstructionUtils.checkNumFields(parts, 16);
 			// stride1, stride2, padding1, padding2
 			// input_shape1, input_shape2, input_shape3, input_shape4,
@@ -133,6 +135,7 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 					padding, input_shape, filter_shape, k, Double.parseDouble(parts[16]));
 		} 
 		else if (opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("relu_maxpooling_backward")
+				|| opcode.equalsIgnoreCase("avgpooling_backward")
 				|| opcode.equalsIgnoreCase("conv2d")
 				|| opcode.equalsIgnoreCase("conv2d_backward_filter")
 				|| opcode.equalsIgnoreCase("conv2d_backward_data")) {
@@ -388,7 +391,7 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 		
 		// acquire inputs
 		MatrixBlock outputBlock = null;
-		MatrixBlock matBlock = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
+		MatrixBlock matBlock = instOpcode.equalsIgnoreCase("avgpooling_backward") ? null : ec.getMatrixInput(input1.getName(), getExtendedOpcode());
 		int pad_h = getScalarInput(ec, _padding, 0);
 		int pad_w = getScalarInput(ec, _padding, 1);
 		int stride_h = getScalarInput(ec, _stride, 0);
@@ -408,28 +411,34 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 		
 		ConvolutionParameters params = new ConvolutionParameters(N, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, _numThreads);
 		params.enableNative = NativeHelper.isNativeLibraryLoaded();
-		if (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) {
+		if (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling") ||
+			instOpcode.equalsIgnoreCase("avgpooling")) {
 			if(matBlock.isEmpty()) {
 				outputBlock = new MatrixBlock(N, C*P*Q, true);
 			}
 			else {
 				outputBlock = new MatrixBlock(N, C*P*Q, false).allocateBlock();
+				
+				PoolingType poolType = (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) ? PoolingType.MAX : PoolingType.AVG;
 				if(instOpcode.equalsIgnoreCase("relu_maxpooling"))
 					params.minValForMaxPoolOperations = 0;
-				LibMatrixDNN.maxpooling(matBlock, outputBlock, params);
+				LibMatrixDNN.pooling(matBlock, outputBlock, params, poolType);
 			}
 		}
-		else if (instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("relu_maxpooling_backward")) {
+		else if (instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("relu_maxpooling_backward") ||
+				instOpcode.equalsIgnoreCase("avgpooling_backward")) {
 			MatrixBlock dout = ec.getMatrixInput(_in2.getName(), getExtendedOpcode());
-			if(matBlock.isEmpty() || dout.isEmpty()) {
+			boolean isEmpty = instOpcode.equalsIgnoreCase("avgpooling_backward") ? dout.isEmpty() : (matBlock.isEmpty() || dout.isEmpty());
+			if(isEmpty) {
 				outputBlock = new MatrixBlock(N, C*H*W, true);
 			}
 			else {
 				outputBlock = new MatrixBlock(N, C*H*W, false).allocateBlock();
-				if(instOpcode.equalsIgnoreCase("relu_maxpooling_backward"))
+				PoolingType poolType = (instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("relu_maxpooling_backward")) ? PoolingType.MAX : PoolingType.AVG;
+				boolean performReLUBackward = instOpcode.equalsIgnoreCase("relu_maxpooling_backward");
+				if(performReLUBackward)
 					params.minValForMaxPoolOperations = 0;
-				LibMatrixDNN.maxpoolingBackward(matBlock, dout, outputBlock, params, 
-					!instOpcode.equalsIgnoreCase("maxpooling_backward"));
+				LibMatrixDNN.poolingBackward(matBlock, dout, outputBlock, params, performReLUBackward, poolType);
 			}
 			ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
 		}
@@ -518,7 +527,8 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 		}
 		
 		// release inputs/outputs
-		ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+		if(!instOpcode.equalsIgnoreCase("avgpooling_backward"))
+			ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
 		ec.setMatrixOutput(getOutputVariableName(), outputBlock, getExtendedOpcode());
 	}
 	

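In the CP instruction, avgpooling_backward does not acquire input1, and the output is treated as empty whenever dout alone is empty. This works because the average-pooling gradient depends only on dout and the window geometry, not on the input values. Max pooling is different: it must locate each window's maximum in the input. The Java sketch below covers both passes for a single image and a single channel. It is not LibMatrixDNN code, and the names are hypothetical. It also assumes that padded cells count toward the divisor R*S; SystemML's CPU and cuDNN paths may use a different padding convention.

```java
/** Hypothetical single-image, single-channel average pooling; not SystemML's LibMatrixDNN. */
public class AvgPoolSketch {

    static int outDim(int in, int k, int stride, int pad) {
        return Math.floorDiv(in + 2 * pad - k, stride) + 1;
    }

    // Forward: each output is the mean of an R x S window. Assumption: padded cells are
    // zeros that still count in the divisor (count-include-padding).
    static double[] forward(double[] x, int H, int W, int R, int S, int stride, int pad) {
        int P = outDim(H, R, stride, pad), Q = outDim(W, S, stride, pad);
        double[] out = new double[P * Q];
        for (int p = 0; p < P; p++)
            for (int q = 0; q < Q; q++) {
                double sum = 0;
                for (int r = 0; r < R; r++)
                    for (int s = 0; s < S; s++) {
                        int h = p * stride - pad + r, w = q * stride - pad + s;
                        if (h >= 0 && h < H && w >= 0 && w < W) sum += x[h * W + w];
                    }
                out[p * Q + q] = sum / (R * S);
            }
        return out;
    }

    // Backward: spread each dout value evenly over its window. The input x is never read,
    // so only dout is needed.
    static double[] backward(double[] dout, int H, int W, int R, int S, int stride, int pad) {
        int P = outDim(H, R, stride, pad), Q = outDim(W, S, stride, pad);
        double[] dx = new double[H * W];
        for (int p = 0; p < P; p++)
            for (int q = 0; q < Q; q++) {
                double share = dout[p * Q + q] / (R * S);
                for (int r = 0; r < R; r++)
                    for (int s = 0; s < S; s++) {
                        int h = p * stride - pad + r, w = q * stride - pad + s;
                        if (h >= 0 && h < H && w >= 0 && w < W) dx[h * W + w] += share;
                    }
            }
        return dx;
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; // 4x4 image
        double[] y = forward(x, 4, 4, 2, 2, 2, 0);
        System.out.println(java.util.Arrays.toString(y)); // prints [3.5, 5.5, 11.5, 13.5]
    }
}
```

When windows overlap (stride smaller than the window), a pixel collects a share from every window that covers it. In a 3x3 input pooled with 2x2 windows at stride 1, the center pixel receives four shares of 0.25 * dout.
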
http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
index 62a20b8..9e2d672 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
@@ -28,6 +28,7 @@ import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
 import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
 import org.apache.sysml.utils.GPUStatistics;
@@ -131,7 +132,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 			return new ConvolutionGPUInstruction(in1, in2, out, opcode, str, stride,
 					padding, input_shape, filter_shape, Double.parseDouble(parts[16]));
 		}
-		else if( opcode.equalsIgnoreCase("maxpooling_backward") ) {
+		else if( opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("avgpooling_backward") ) {
 			boolean withMaxPoolOut = false;
 			if(parts.length == 18) {
 				withMaxPoolOut = true;
@@ -191,7 +192,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 			return new ConvolutionGPUInstruction(in1, in2, in3, out, opcode, str, stride,
 					padding, input_shape, filter_shape, Double.parseDouble(parts[17]));
 		}
-		else if (opcode.equalsIgnoreCase("maxpooling")) {
+		else if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("avgpooling")) {
 			InstructionUtils.checkNumFields(parts, 15);
 			CPOperand in1 = new CPOperand(parts[1]);
 			CPOperand out = new CPOperand(parts[14]);
@@ -382,7 +383,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 			LibMatrixCuDNN.conv2dBackwardData(ec.getGPUContext(0), getExtendedOpcode(), filter, dout, out, N, C, H, W,
 					K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget);
 		}
-		else if (instOpcode.equalsIgnoreCase("maxpooling")) {
+		else if (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("avgpooling")) {
 			MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
 
 			if(image.getNumRows() != N || image.getNumColumns() != C*H*W) 
@@ -390,12 +391,11 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 						image.getNumRows() + " != " +  N + " || " + image.getNumColumns() + " != " + C*H*W);
 			
 			MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * P * Q);
-			
-			if(instOpcode.equalsIgnoreCase("maxpooling"))
-				LibMatrixCuDNN.maxpooling(ec.getGPUContext(0), getExtendedOpcode(), image, out, N, C, H, W,
-					K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget);
+			PoolingType poolType = instOpcode.equalsIgnoreCase("maxpooling") ? PoolingType.MAX : PoolingType.AVG;
+			LibMatrixCuDNN.pooling(ec.getGPUContext(0), getExtendedOpcode(), image, out, N, C, H, W,
+					K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolType, _intermediateMemoryBudget);
 		}
-		else if (instOpcode.equalsIgnoreCase("maxpooling_backward")) {
+		else if (instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("avgpooling_backward")) {
 			MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
 			MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
 			MatrixObject maxPoolOutput = _input3 != null ? getMatrixInputForGPUInstruction(ec, _input3.getName()) : null;
@@ -406,9 +406,9 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 						image.getNumRows() + " != " +  N + " || " + image.getNumColumns() + " != " + K*P*Q);
 			
 			MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * H * W);
-			
-			LibMatrixCuDNN.maxpoolingBackward(ec.getGPUContext(0), getExtendedOpcode(), image, dout, maxPoolOutput, out, N, C, H, W,
-					K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget);
+			PoolingType poolType = instOpcode.equalsIgnoreCase("maxpooling_backward") ? PoolingType.MAX : PoolingType.AVG;
+			LibMatrixCuDNN.poolingBackward(ec.getGPUContext(0), getExtendedOpcode(), image, dout, maxPoolOutput, out, N, C, H, W,
+					K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolType, _intermediateMemoryBudget);
 		}
 		else {
 			throw new DMLRuntimeException("Unsupported GPU context for " + instOpcode);
@@ -416,12 +416,15 @@ public class ConvolutionGPUInstruction extends GPUInstruction {
 		
 		// release inputs/outputs
 		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+		
+		boolean isPool = instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("avgpooling");
+		boolean isPoolBackward = instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("avgpooling_backward");
 
-		if ( !instOpcode.equalsIgnoreCase("maxpooling") )
+		if ( !isPool )
 			ec.releaseMatrixInputForGPUInstruction(_input2.getName());
 
 		if (instOpcode.equalsIgnoreCase("conv2d_bias_add") || 
-			(instOpcode.equalsIgnoreCase("maxpooling_backward") && _input3 != null))
+			(isPoolBackward && _input3 != null))
 			ec.releaseMatrixInputForGPUInstruction(_input3.getName());
 
 		ec.releaseMatrixOutputForGPUInstruction(_output.getName());

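The GPU dispatch above now treats `maxpooling`/`avgpooling` and their `_backward` variants alike. It selects the pooling type with a ternary that falls through to `PoolingType.AVG` for anything that is not a max-pooling opcode. That is safe there because the ternaries only run inside branches that have already matched a pooling opcode. Below is a minimal sketch of how those checks could be kept in one place, so that the dispatch and the input-release logic (`isPool`, `isPoolBackward`) cannot drift apart. It assumes a hypothetical helper class `PoolingOpcodes` (not part of SystemML) and a locally redeclared `PoolingType` enum that stands in for `LibMatrixDNN.PoolingType`.

```java
import java.util.Locale;

/** Local stand-in for LibMatrixDNN.PoolingType so that this sketch compiles on its own. */
enum PoolingType { MAX, AVG }

/** Hypothetical helper (not part of SystemML) that centralizes the pooling-opcode checks. */
final class PoolingOpcodes {
    private PoolingOpcodes() {}

    /** Forward pooling reads only the image input. */
    static boolean isPool(String opcode) {
        return opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("avgpooling");
    }

    /** Backward pooling reads the image, dout and, optionally, a precomputed max-pooling output. */
    static boolean isPoolBackward(String opcode) {
        return opcode.equalsIgnoreCase("maxpooling_backward")
            || opcode.equalsIgnoreCase("avgpooling_backward");
    }

    /** Maps an opcode to its pooling type; unlike a bare ternary, it rejects non-pooling opcodes. */
    static PoolingType poolingTypeOf(String opcode) {
        if (!isPool(opcode) && !isPoolBackward(opcode))
            throw new IllegalArgumentException("Not a pooling opcode: " + opcode);
        return opcode.toLowerCase(Locale.ROOT).startsWith("max") ? PoolingType.MAX : PoolingType.AVG;
    }
}
```

With this helper, `poolingTypeOf("avgpooling_backward")` yields `AVG`, and an opcode such as `conv2d` raises an exception instead of being silently treated as average pooling.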
http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/src/main/java/org/apache/sysml/runtime/instructions/spark/ConvolutionSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/ConvolutionSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/ConvolutionSPInstruction.java
index 77141b3..0ec5595 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/ConvolutionSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/ConvolutionSPInstruction.java
@@ -40,6 +40,7 @@ import org.apache.sysml.runtime.matrix.MetaDataFormat;
 import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
 import org.apache.sysml.runtime.matrix.data.InputInfo;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
 import org.apache.sysml.runtime.matrix.data.LibMatrixNative;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
@@ -356,7 +357,16 @@ public class ConvolutionSPInstruction extends UnarySPInstruction {
 					outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, false).allocateBlock();
 					if(instOpcode.equalsIgnoreCase("maxpooling"))
 						outputBlock.getDenseBlock().set(-Double.MAX_VALUE);
-					LibMatrixDNN.maxpooling(matBlock, outputBlock, params);
+					LibMatrixDNN.pooling(matBlock, outputBlock, params, PoolingType.MAX);
+				}
+			}
+			else if(instOpcode.equalsIgnoreCase("avgpooling") || instOpcode.equalsIgnoreCase("relu_avgpooling")) {
+				if(matBlock.isEmptyBlock()) {
+					outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, true);
+				}
+				else {
+					outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, false).allocateBlock();
+					LibMatrixDNN.pooling(matBlock, outputBlock, params, PoolingType.AVG);
 				}
 			}
 			else {

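The new `avgpooling` branch in the Spark instruction returns an empty output for an empty input block, since every window average of zeros is zero. Otherwise it accumulates into a fresh block instead of pre-filling it with `-Double.MAX_VALUE` as the max-pooling branch does. Below is a minimal, hypothetical CPU sketch (`AvgPool2dSketch` is not SystemML code) of average pooling over NCHW data stored row-major as an N x (C*H*W) array. It assumes the divisor is the full window size R*S, mirroring the `CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING` mode chosen for the GPU path. This diff does not show whether the CPU workers in `LibMatrixDNNPooling` do the same. The sketch uses the usual output size floor((H + 2*pad_h - R)/stride_h) + 1. The `@param P` Javadoc elsewhere in this commit, (H - R + 1 + 2*pad_h)/stride_h, agrees with that formula only for stride 1.

```java
/** Hypothetical CPU sketch of 2-D average pooling over NCHW data stored row-major as N x (C*H*W). */
final class AvgPool2dSketch {
    private AvgPool2dSketch() {}

    /** Standard output-size formula: floor((in + 2*pad - window) / stride) + 1. */
    static int outSize(int in, int window, int pad, int stride) {
        return (in + 2 * pad - window) / stride + 1;
    }

    static double[] avgPool(double[] x, int N, int C, int H, int W,
                            int R, int S, int padH, int padW, int strideH, int strideW) {
        int P = outSize(H, R, padH, strideH);
        int Q = outSize(W, S, padW, strideW);
        double[] y = new double[N * C * P * Q]; // Java arrays start at zero
        double windowSize = R * S;               // count includes padded cells
        for (int n = 0; n < N; n++)
            for (int c = 0; c < C; c++)
                for (int p = 0; p < P; p++)
                    for (int q = 0; q < Q; q++) {
                        double sum = 0;
                        for (int r = 0; r < R; r++) {
                            int h = p * strideH - padH + r;
                            if (h < 0 || h >= H) continue; // padded row contributes 0
                            for (int s = 0; s < S; s++) {
                                int w = q * strideW - padW + s;
                                if (w < 0 || w >= W) continue; // padded column contributes 0
                                sum += x[((n * C + c) * H + h) * W + w];
                            }
                        }
                        y[((n * C + c) * P + p) * Q + q] = sum / windowSize;
                    }
        return y;
    }
}
```

For a single 2x2 image [1, 2, 3, 4] with a 2x2 window and stride 2, the sketch produces one output, 2.5. With padding 1 and stride 1 it produces a 3x3 output whose corner is 1/4 = 0.25, because the three padded cells still count toward the divisor.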
http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index 122304e..6642ee0 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -53,6 +53,7 @@ import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.context.CSRPointer;
 import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
 import org.apache.sysml.utils.GPUStatistics;
 import org.apache.sysml.utils.Statistics;
 
@@ -618,13 +619,14 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 	 * @param stride_w		vertical stride
 	 * @param P				(H - R + 1 + 2*pad_h)/stride_h
 	 * @param Q				(W - S + 1 + 2*pad_w)/stride_w
+	 * @param poolingType	type of pooling
 	 * @param intermediateMemoryBudget intermediate memory budget
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
-	public static void maxpooling(GPUContext gCtx, String instName, MatrixObject image,
+	public static void pooling(GPUContext gCtx, String instName, MatrixObject image,
 			MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
 			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
-			int Q, double intermediateMemoryBudget) throws DMLRuntimeException {
+			int Q, PoolingType poolingType, double intermediateMemoryBudget) throws DMLRuntimeException {
 		long CHW = C*H*W; long CPQ = C*P*Q;  
 		long NCHW = N*CHW; long NCPQ = N*CPQ; 
 
@@ -634,12 +636,12 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 			Pointer y = getDensePointerForCuDNN(gCtx, outputBlock, instName);
 			if(overhead <= intermediateMemoryBudget) {
 				Pointer x = getDensePointerForCuDNN(gCtx, image, instName);
-				cudnnMaxpooling(gCtx, instName, x, y, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+				cudnnPoolingHelper(gCtx, instName, x, y, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolingType);
 			}
 			else {
 				LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
 				for(int n = 0; n < N; n++) {
-					cudnnMaxpooling(gCtx, instName, imgFetcher.getNthRow(n), y.withByteOffset(n*CPQ*sizeOfDataType), 1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+					cudnnPoolingHelper(gCtx, instName, imgFetcher.getNthRow(n), y.withByteOffset(n*CPQ*sizeOfDataType), 1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolingType);
 				}
 				imgFetcher.close();
 			}
@@ -649,17 +651,17 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		}
 	}
 
-	private static void cudnnMaxpooling(GPUContext gCtx, String instName, Pointer x,
+	private static void cudnnPoolingHelper(GPUContext gCtx, String instName, Pointer x,
 			Pointer y, int N, int C, int H, int W, int K, int R,
 			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
-			int Q) throws DMLRuntimeException {
+			int Q, PoolingType poolingType) throws DMLRuntimeException {
 		if(LOG.isTraceEnabled()) {
-			LOG.trace("GPU : performMaxpooling" + ", GPUContext=" + gCtx);
+			LOG.trace("GPU : perform pooling" + ", GPUContext=" + gCtx);
 		}
 
 		try(LibMatrixCuDNNPoolingDescriptors desc = 
-				LibMatrixCuDNNPoolingDescriptors.cudnnMaxpoolingDescriptors(gCtx, instName, N, C, H, W, K, R, S, 
-						pad_h, pad_w, stride_h, stride_w, P, Q)) {
+				LibMatrixCuDNNPoolingDescriptors.cudnnPoolingDescriptors(gCtx, instName, N, C, H, W, K, R, S, 
+						pad_h, pad_w, stride_h, stride_w, P, Q, poolingType)) {
 			long t1=0,t2=0;
 			if (DMLScript.FINEGRAINED_STATISTICS) t1 = System.nanoTime();
 			if (DMLScript.FINEGRAINED_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
@@ -673,7 +675,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 			throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
 		}
 	}
-
+	
 	/**
 	 * Performs maxpoolingBackward on GPU by exploiting cudnnPoolingBackward(...)
 	 * This method computes the backpropogation errors for previous layer of maxpooling operation
@@ -696,13 +698,14 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 	 * @param stride_w		vertical stride
 	 * @param P				(H - R + 1 + 2*pad_h)/stride_h
 	 * @param Q				(W - S + 1 + 2*pad_w)/stride_w
+	 * @param poolingType	type of pooling
 	 * @param intermediateMemoryBudget intermediate memory budget
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
-	public static void maxpoolingBackward(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout,
+	public static void poolingBackward(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout,
 			MatrixObject maxpoolOutput, MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
 			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
-			int Q, double intermediateMemoryBudget) throws DMLRuntimeException {
+			int Q, PoolingType poolingType, double intermediateMemoryBudget) throws DMLRuntimeException {
 		long CHW = C*H*W; long CPQ = C*P*Q;  
 		long NCHW = N*CHW; long NCPQ = N*CPQ; 
 
@@ -717,7 +720,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 				Pointer x = getDensePointerForCuDNN(gCtx, image, instName);
 				Pointer dy = getDensePointerForCuDNN(gCtx, dout, instName);
 				Pointer y = isMaxPoolOutputProvided ? getDensePointerForCuDNN(gCtx, maxpoolOutput, instName) : null;
-				cudnnMaxpoolingBackward(gCtx, instName, x, dy, y, dx, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+				cudnnPoolingBackwardHelper(gCtx, instName, x, dy, y, dx, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolingType);
 			}
 			else {
 				LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
@@ -727,9 +730,9 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 					Pointer x = imgFetcher.getNthRow(n);
 					Pointer dy = doutFetcher.getNthRow(n);
 					Pointer y = isMaxPoolOutputProvided ? maxPoolOutFetcher.getNthRow(n) : null;
-					cudnnMaxpoolingBackward(gCtx, instName, x, dy, y, 
+					cudnnPoolingBackwardHelper(gCtx, instName, x, dy, y, 
 							dx.withByteOffset(n*CHW*sizeOfDataType), 
-							1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+							1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolingType);
 				}
 				// Deallocate temporary array to hold one element of input
 				imgFetcher.close();
@@ -743,11 +746,11 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		}
 	}
 	
-	private static void cudnnMaxpoolingBackward(GPUContext gCtx, String instName, 
+	private static void cudnnPoolingBackwardHelper(GPUContext gCtx, String instName, 
 			Pointer x, Pointer dy, Pointer y, Pointer dx, 
 			int N, int C, int H, int W, int K, int R,
 			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
-			int Q) throws DMLRuntimeException {
+			int Q, PoolingType poolingType) throws DMLRuntimeException {
 		if(LOG.isTraceEnabled()) {
 			LOG.trace("GPU : maxpoolingBackward" + ", GPUContext=" + gCtx);
 		}
@@ -755,8 +758,8 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		boolean isMaxPoolOutputProvided = (y != null);
 
 		try(LibMatrixCuDNNPoolingDescriptors desc = 
-				LibMatrixCuDNNPoolingDescriptors.cudnnMaxpoolingBackwardDescriptors(gCtx, instName, N, C, H, W, K, R, S, 
-						pad_h, pad_w, stride_h, stride_w, P, Q)) {
+				LibMatrixCuDNNPoolingDescriptors.cudnnPoolingBackwardDescriptors(gCtx, instName, N, C, H, W, K, R, S, 
+						pad_h, pad_w, stride_h, stride_w, P, Q, poolingType)) {
 			long t1=0, t2=0, t3=0;
 			int status;
 			if(!isMaxPoolOutputProvided) {

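`poolingBackward` now serves both pooling types. For average pooling, the gradient of each output cell is spread evenly over the cells of its window, so the forward input is not needed to compute it. This is consistent with the `LibMatrixDNN.poolingBackward` change in this commit, which validates the input dimensions only for `PoolingType.MAX`. Below is a minimal, hypothetical sketch of that backward rule (`AvgPool2dBackwardSketch` is not SystemML code). It assumes the same row-major NCHW layout and include-padding divisor R*S as the GPU descriptor. Gradient shares that fall on padded cells are simply dropped.

```java
/** Hypothetical CPU sketch of the average-pooling backward pass (count includes padding). */
final class AvgPool2dBackwardSketch {
    private AvgPool2dBackwardSketch() {}

    static int outSize(int in, int window, int pad, int stride) {
        return (in + 2 * pad - window) / stride + 1;
    }

    /** dout is N x (C*P*Q) row-major; returns dx as N x (C*H*W). The forward input is not needed. */
    static double[] avgPoolBackward(double[] dout, int N, int C, int H, int W,
                                    int R, int S, int padH, int padW, int strideH, int strideW) {
        int P = outSize(H, R, padH, strideH);
        int Q = outSize(W, S, padW, strideW);
        if (dout.length != N * C * P * Q)
            throw new IllegalArgumentException("dout must have N*C*P*Q = " + (N * C * P * Q) + " entries");
        double[] dx = new double[N * C * H * W];
        double windowSize = R * S;
        for (int n = 0; n < N; n++)
            for (int c = 0; c < C; c++)
                for (int p = 0; p < P; p++)
                    for (int q = 0; q < Q; q++) {
                        // Each window cell contributed 1/(R*S) of the average, so it gets that share of the gradient.
                        double g = dout[((n * C + c) * P + p) * Q + q] / windowSize;
                        for (int r = 0; r < R; r++) {
                            int h = p * strideH - padH + r;
                            if (h < 0 || h >= H) continue; // share for a padded row is discarded
                            for (int s = 0; s < S; s++) {
                                int w = q * strideW - padW + s;
                                if (w < 0 || w >= W) continue; // share for a padded column is discarded
                                dx[((n * C + c) * H + h) * W + w] += g;
                            }
                        }
                    }
        return dx;
    }
}
```

For example, with a 2x2 window, padding 1 and stride 1 on a 2x2 image, every input cell lies in four windows. An all-ones `dout` therefore produces `dx` equal to 1 everywhere.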
http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java
index d4b213f..8c9dea4 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java
@@ -26,10 +26,12 @@ import static jcuda.jcudnn.JCudnn.cudnnSetPooling2dDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
 import static jcuda.jcudnn.cudnnNanPropagation.CUDNN_PROPAGATE_NAN;
 import static jcuda.jcudnn.cudnnPoolingMode.CUDNN_POOLING_MAX;
+import static jcuda.jcudnn.cudnnPoolingMode.CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
 import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
 
 import jcuda.jcudnn.cudnnPoolingDescriptor;
 import jcuda.jcudnn.cudnnTensorDescriptor;
@@ -80,19 +82,20 @@ public class LibMatrixCuDNNPoolingDescriptors implements java.lang.AutoCloseable
 	 * @param stride_w		vertical stride
 	 * @param P				(H - R + 1 + 2*pad_h)/stride_h
 	 * @param Q				(W - S + 1 + 2*pad_w)/stride_w
+	 * @param poolingType	type of pooling
 	 * @return decriptor wrapper
 	 * @throws DMLRuntimeException if error occurs
 	 */
-	public static LibMatrixCuDNNPoolingDescriptors cudnnMaxpoolingBackwardDescriptors(GPUContext gCtx, 
+	public static LibMatrixCuDNNPoolingDescriptors cudnnPoolingBackwardDescriptors(GPUContext gCtx, 
 			String instName, int N, int C, int H, int W, int K, int R,
 			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
-			int Q) throws DMLRuntimeException {
+			int Q, PoolingType poolingType) throws DMLRuntimeException {
 		LibMatrixCuDNNPoolingDescriptors ret = new LibMatrixCuDNNPoolingDescriptors();
 		ret.xDesc = allocateTensorDescriptor(N, C, H, W);
 		ret.yDesc = allocateTensorDescriptor(N, C, P, Q);
 		ret.dxDesc = allocateTensorDescriptor(N, C, H, W);
 		ret.dyDesc = allocateTensorDescriptor(N, C, P, Q);
-		ret.poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
+		ret.poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w, poolingType);
 		return ret;
 	}
 	
@@ -114,17 +117,18 @@ public class LibMatrixCuDNNPoolingDescriptors implements java.lang.AutoCloseable
 	 * @param stride_w		vertical stride
 	 * @param P				(H - R + 1 + 2*pad_h)/stride_h
 	 * @param Q				(W - S + 1 + 2*pad_w)/stride_w
+	 * @param poolingType 	type of pooling
 	 * @return decriptor wrapper
 	 * @throws DMLRuntimeException if error occurs
 	 */
-	public static LibMatrixCuDNNPoolingDescriptors cudnnMaxpoolingDescriptors(GPUContext gCtx, 
+	public static LibMatrixCuDNNPoolingDescriptors cudnnPoolingDescriptors(GPUContext gCtx, 
 			String instName, int N, int C, int H, int W, int K, int R,
 			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
-			int Q) throws DMLRuntimeException {
+			int Q, PoolingType poolingType) throws DMLRuntimeException {
 		LibMatrixCuDNNPoolingDescriptors ret = new LibMatrixCuDNNPoolingDescriptors();
 		ret.xDesc = allocateTensorDescriptor(N, C, H, W);
 		ret.yDesc = allocateTensorDescriptor(N, C, P, Q);
-		ret.poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
+		ret.poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w, poolingType);
 		return ret;
 	}
 
@@ -152,12 +156,14 @@ public class LibMatrixCuDNNPoolingDescriptors implements java.lang.AutoCloseable
 	 * @param pad_w		horizontal padding
 	 * @param stride_h	pooling vertical stride
 	 * @param stride_w	pooling horizontal stride
+	 * @param poolingType type of pooling
 	 * @return cudnn pooling descriptor
 	 */
-	private static cudnnPoolingDescriptor allocatePoolingDescriptor(int R, int S, int pad_h, int pad_w, int stride_h, int stride_w) {
+	private static cudnnPoolingDescriptor allocatePoolingDescriptor(int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, PoolingType poolingType) {
 		cudnnPoolingDescriptor poolingDesc = new cudnnPoolingDescriptor();
 		cudnnCreatePoolingDescriptor(poolingDesc);
-		cudnnSetPooling2dDescriptor(poolingDesc, CUDNN_POOLING_MAX, CUDNN_PROPAGATE_NAN, R, S, pad_h, pad_w, stride_h, stride_w);
+		int CUDNN_POOLING = (poolingType == PoolingType.MAX) ? CUDNN_POOLING_MAX : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+		cudnnSetPooling2dDescriptor(poolingDesc, CUDNN_POOLING, CUDNN_PROPAGATE_NAN, R, S, pad_h, pad_w, stride_h, stride_w);
 		return poolingDesc;
 	}
 }

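`allocatePoolingDescriptor` maps `PoolingType.AVG` to `CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING`. cuDNN also defines `CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING`, which this commit does not use. The two modes differ only in the divisor applied to border windows, so results near the image edges can differ from frameworks that exclude padded cells from the count. Below is a minimal, hypothetical sketch (`AvgPoolDivisorSketch` is not SystemML or JCuda code) that computes the divisor for one output cell under each convention.

```java
/** Hypothetical sketch of the divisor used by the two cuDNN average-pooling count modes. */
final class AvgPoolDivisorSketch {
    private AvgPoolDivisorSketch() {}

    /** Number of window positions [start, start + window) that fall inside [0, size). */
    static int validExtent(int start, int window, int size) {
        int lo = Math.max(start, 0);
        int hi = Math.min(start + window, size);
        return Math.max(hi - lo, 0);
    }

    /**
     * Divisor for output cell (p, q). includePadding = true corresponds to the INCLUDE_PADDING
     * mode selected in allocatePoolingDescriptor; false corresponds to EXCLUDE_PADDING.
     */
    static int divisor(boolean includePadding, int p, int q, int H, int W, int R, int S,
                       int padH, int padW, int strideH, int strideW) {
        if (includePadding)
            return R * S; // padded cells count as zeros inside the window
        int h0 = p * strideH - padH;
        int w0 = q * strideW - padW;
        return validExtent(h0, R, H) * validExtent(w0, S, W); // only real pixels are counted
    }
}
```

With a 2x2 window, padding 1 and stride 1 on a 2x2 image, the corner output divides by 4 under include-padding but by 1 under exclude-padding. Interior outputs divide by 4 in both modes.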
http://git-wip-us.apache.org/repos/asf/systemml/blob/54a11eed/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 1ad56b2..d089521 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -64,6 +64,9 @@ import org.apache.sysml.runtime.util.ConvolutionUtils;
 public class LibMatrixDNN {
 	
 	protected static final Log LOG =  LogFactory.getLog(LibMatrixDNN.class.getName());
+	public static enum PoolingType {
+		MAX, AVG
+	}
 	
 	//library configurations and external contracts
 	// ------------------------------------------------------------------------------------------------
@@ -189,7 +192,7 @@ public class LibMatrixDNN {
 		outputBlock.examSparsity();
 	}
 	
-	public static void maxpooling(MatrixBlock input, MatrixBlock output, ConvolutionParameters params) throws DMLRuntimeException {
+	public static void pooling(MatrixBlock input, MatrixBlock output, ConvolutionParameters params, PoolingType poolType) throws DMLRuntimeException {
 		params.input1 = input;
 		params.output = output;
 		
@@ -202,7 +205,7 @@ public class LibMatrixDNN {
 		if( !params.isStride1Pad0() || input.sparse )
 			fillIndexesArray(params);
 		
-		long nnz = execute(LibMatrixDNNPooling.getMaxPoolingWorkers(params), params);
+		long nnz = execute(LibMatrixDNNPooling.getPoolingWorkers(params, poolType), params);
 		
 		// post-processing: maintain nnz
 		output.setNonZeros(nnz);
@@ -211,45 +214,51 @@ public class LibMatrixDNN {
 	
 
 	/**
-	 * This method computes the backpropogation errors for previous layer of maxpooling operation
+	 * This method computes the backpropogation errors for previous layer of pooling operation
 	 * 
 	 * @param input input matrix
 	 * @param dout dout matrix
 	 * @param outputBlock output matrix
 	 * @param params convolution parameters
 	 * @param performReluBackward perform ReLU backward
+	 * @param poolType type of pooling
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
-	public static void maxpoolingBackward(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, 
-			ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException {
+	public static void poolingBackward(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, 
+			ConvolutionParameters params, boolean performReluBackward, PoolingType poolType) throws DMLRuntimeException {
 		params.input1 = input;
 		params.input2 = dout;
 		params.output = outputBlock;
-		if(input.getNumColumns() != params.C*params.H*params.W || input.getNumRows() != params.N) {
+		
+		if(poolType == PoolingType.MAX && (input.getNumColumns() != params.C*params.H*params.W || input.getNumRows() != params.N)) {
 			throw new DMLRuntimeException("Incorrect input dimensions in maxpooling_backward:" + input.getNumRows() + " " + input.getNumColumns() + " " + params.N + " " + params.K*params.P*params.Q);
 		}
 
 		if(dout.getNumColumns() != params.C*params.P*params.Q || dout.getNumRows() != params.N) {
-			throw new DMLRuntimeException("Incorrect dout dimensions in maxpooling_backward:" + input.getNumRows() + " " + input.getNumColumns() + " " + params.N + " " + params.K*params.P*params.Q);
+			throw new DMLRuntimeException("Incorrect dout dimensions in pooling_backward:" + input.getNumRows() + " " + input.getNumColumns() + " " + params.N + " " + params.K*params.P*params.Q);
 		}
 		
 		if(DMLScript.FINEGRAINED_STATISTICS) {
-			if(input.isInSparseFormat() || dout.isInSparseFormat())
+			boolean isSparse = (poolType == PoolingType.MAX) ? (input.isInSparseFormat() || dout.isInSparseFormat()) : dout.isInSparseFormat();
+			if(isSparse)
 				maxPoolBwdSparseCount.addAndGet(1);
 			else
 				maxPoolBwdDenseCount.addAndGet(1);
 		}
 		
 		if (params.output.isInSparseFormat())
-			throw new DMLRuntimeException("Sparse maxpooling_backward is not supported");
+			throw new DMLRuntimeException("Sparse pooling_backward is not supported");
 
-		if( !(params.input1.isInSparseFormat() && !params.input2.isInSparseFormat()) )
-			fillIndexesArray(params); //not needed for sparse-dense
-		
-		long nnz = execute(LibMatrixDNNPooling.getMaxPoolingBackwardWorkers(params, performReluBackward), params);
-		
+		if(poolType == PoolingType.AVG) {
+			fillIndexesArray(params); 
+		}
+		else {
+			if( !(params.input1.isInSparseFormat() && !params.input2.isInSparseFormat()) )
+				fillIndexesArray(params); //not needed for sparse-dense	 
+		}
+		long nnz = execute(LibMatrixDNNPooling.getPoolingBackwardWorkers(params, performReluBackward, poolType), params);
 		//post-processing: maintain nnz 
-		outputBlock.setNonZeros(nnz); 
+		outputBlock.setNonZeros(nnz);
 		outputBlock.examSparsity();
 	}