You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2019/03/23 02:21:55 UTC
[systemml] branch master updated: [SYSTEMML-540] Throw exception
whenever parameter of a Keras layer is not supported by SystemML
This is an automated email from the ASF dual-hosted git repository.
niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new 7cab282 [SYSTEMML-540] Throw exception whenever parameter of a Keras layer is not supported by SystemML
7cab282 is described below
commit 7cab282faa77b3bc66200396803f97ec1375544a
Author: Niketan Pansare <np...@us.ibm.com>
AuthorDate: Fri Mar 22 19:21:40 2019 -0700
[SYSTEMML-540] Throw exception whenever parameter of a Keras layer is not supported by SystemML
---
docs/reference-guide-caffe2dml.md | 15 ++++++
src/main/python/systemml/mllearn/keras2caffe.py | 64 +++++++++++++++++--------
2 files changed, 60 insertions(+), 19 deletions(-)
diff --git a/docs/reference-guide-caffe2dml.md b/docs/reference-guide-caffe2dml.md
index 381b96d..6242e03 100644
--- a/docs/reference-guide-caffe2dml.md
+++ b/docs/reference-guide-caffe2dml.md
@@ -450,6 +450,21 @@ layer {
## Utility Layers
+### Flatten Layer
+
+The Flatten layer is a utility layer that flattens an input of shape n * c * h * w to a simple vector output of shape n * (c*h*w).
+
+
+**Sample Usage:**
+```
+layer {
+ name: "flatten_1"
+ type: "Flatten"
+ bottom: "max_pooling2d_2"
+ top: "flatten_1"
+}
+```
+
### Eltwise Layer
Element-wise operations such as product or sum between two blobs.
diff --git a/src/main/python/systemml/mllearn/keras2caffe.py b/src/main/python/systemml/mllearn/keras2caffe.py
index 2b97560..39a9755 100755
--- a/src/main/python/systemml/mllearn/keras2caffe.py
+++ b/src/main/python/systemml/mllearn/keras2caffe.py
@@ -192,6 +192,7 @@ def _parseKerasLayer(layer):
def _parseBatchNorm(layer):
+ # TODO: Ignoring axis
bnName = layer.name + '_1'
config = layer.get_config()
bias_term = 'true' if config['center'] else 'false'
@@ -215,44 +216,51 @@ def getPadding(kernel_size, padding):
else:
raise ValueError('Unsupported padding:' + str(padding))
+# Used by padding to extract different types of possible padding:
+# int: the same symmetric padding is applied to height and width.
+# tuple of 2 ints: interpreted as two different symmetric padding values for height and width: (symmetric_height_pad, symmetric_width_pad)
+# tuple of 2 tuples of 2 ints: interpreted as ((top_pad, bottom_pad), (left_pad, right_pad))
+def get2Tuple(val):
+ return [val, val] if isinstance(val, int) else [val[0], val[1]]
+
# Helper method to return Caffe's ConvolutionParameter in JSON-like data structure
def getConvParam(layer):
- stride = (1, 1) if layer.strides is None else layer.strides
+ # TODO: dilation_rate, kernel_constraint and bias_constraint are not supported
+ stride = (1, 1) if layer.strides is None else get2Tuple(layer.strides)
+ kernel_size = get2Tuple(layer.kernel_size)
config = layer.get_config()
+ if not layer.use_bias:
+ raise Exception('use_bias=False is not supported for the Conv2D layer. Consider setting use_bias to true.')
return {'num_output': layer.filters, 'bias_term': str(config['use_bias']).lower(
- ), 'kernel_h': layer.kernel_size[0], 'kernel_w': layer.kernel_size[1], 'stride_h': stride[0], 'stride_w': stride[1],
- 'pad_h': getPadding(layer.kernel_size[0], layer.padding), 'pad_w': getPadding(layer.kernel_size[1], layer.padding)}
+ ), 'kernel_h': kernel_size[0], 'kernel_w': kernel_size[1], 'stride_h': stride[0], 'stride_w': stride[1],
+ 'pad_h': getPadding(kernel_size[0], layer.padding), 'pad_w': getPadding(kernel_size[1], layer.padding)}
# Helper method to return newly added UpsampleParameter
# (search for UpsampleParameter in the file src/main/proto/caffe/caffe.proto) in JSON-like data structure
def getUpSamplingParam(layer):
- return {'size_h': layer.size[0], 'size_w': layer.size[1]}
-
-# Used by padding to extract different types of possible padding:
-# int: the same symmetric padding is applied to height and width.
-# tuple of 2 ints: interpreted as two different symmetric padding values for height and width: (symmetric_height_pad, symmetric_width_pad)
-# tuple of 2 tuples of 2 ints: interpreted as ((top_pad, bottom_pad), (left_pad, right_pad))
-def getPaddingTuple(padding):
- return [padding, padding] if isinstance(padding, int) else [padding[0], padding[1]]
+ # TODO: Skipping interpolation type
+ size = get2Tuple(layer.size)
+ return {'size_h': size[0], 'size_w': size[1]}
# Helper method to return newly added PaddingParameter
# (search for PaddingParameter in the file src/main/proto/caffe/caffe.proto) in JSON-like data structure
def getPaddingParam(layer):
if isinstance(layer.padding, int):
- padding = getPaddingTuple(layer.padding) + getPaddingTuple(layer.padding)
+ padding = get2Tuple(layer.padding) + get2Tuple(layer.padding)
elif hasattr(layer.padding, '__len__') and len(layer.padding) == 2:
- padding = getPaddingTuple(layer.padding[0]) + getPaddingTuple(layer.padding[1])
+ padding = get2Tuple(layer.padding[0]) + get2Tuple(layer.padding[1])
else:
raise ValueError('padding should be either an int, a tuple of 2 ints or a tuple of 2 tuples of 2 ints. Found: ' + str(layer.padding))
return {'top_pad': padding[0], 'bottom_pad': padding[1], 'left_pad': padding[2], 'right_pad': padding[3], 'pad_value':0}
# Helper method to return Caffe's PoolingParameter in JSON-like data structure
def getPoolingParam(layer, pool='MAX'):
- stride = (1, 1) if layer.strides is None else layer.strides
- return {'pool': pool, 'kernel_h': layer.pool_size[0], 'kernel_w': layer.pool_size[1],
- 'stride_h': stride[0], 'stride_w': stride[1], 'pad_h': getPadding(layer.pool_size[0], layer.padding),
- 'pad_w': getPadding(layer.pool_size[1], layer.padding)}
+ stride = (1, 1) if layer.strides is None else get2Tuple(layer.strides)
+ pool_size = get2Tuple(layer.pool_size)
+ return {'pool': pool, 'kernel_h': pool_size[0], 'kernel_w': pool_size[1],
+ 'stride_h': stride[0], 'stride_w': stride[1], 'pad_h': getPadding(pool_size[0], layer.padding),
+ 'pad_w': getPadding(pool_size[1], layer.padding)}
# Helper method to return Caffe's RecurrentParameter in JSON-like data structure
def getRecurrentParam(layer):
@@ -270,21 +278,39 @@ def getInnerProductParam(layer):
if len(layer.output_shape) != 2:
raise Exception('Only 2-D input is supported for the Dense layer in the current implementation, but found '
+ str(layer.input_shape) + '. Consider adding a Flatten before ' + str(layer.name))
+ if not layer.use_bias:
+ raise Exception('use_bias=False is not supported for the Dense layer. Consider setting use_bias to true.')
return {'num_output': layer.units}
+# Helper method to return Caffe's DropoutParameter in JSON-like data structure
+def getDropoutParam(layer):
+ if layer.noise_shape is not None:
+ supported = True
+ if len(layer.input_shape) != len(layer.noise_shape):
+ supported = False
+ else:
+ for i in range(len(layer.noise_shape)-1):
+ # Ignore the first dimension
+ if layer.input_shape[i+1] != layer.noise_shape[i+1]:
+ supported = False
+ if not supported:
+ raise Exception('noise_shape=' + str(layer.noise_shape) + ' is not supported for Dropout layer with input_shape='
+ + str(layer.input_shape))
+ return {'dropout_ratio': layer.rate}
+
layerParamMapping = {
keras.layers.InputLayer: lambda l:
{'data_param': {'batch_size': l.batch_size}},
keras.layers.Dense: lambda l:
{'inner_product_param': getInnerProductParam(l)},
keras.layers.Dropout: lambda l:
- {'dropout_param': {'dropout_ratio': l.rate}},
+ {'dropout_param': getDropoutParam(l)},
keras.layers.Add: lambda l:
{'eltwise_param': {'operation': 'SUM'}},
keras.layers.Concatenate: lambda l:
{'concat_param': {'axis': _getCompensatedAxis(l)}},
keras.layers.Conv2DTranspose: lambda l:
- {'convolution_param': getConvParam(l)},
+ {'convolution_param': getConvParam(l)}, # will skip output_padding
keras.layers.UpSampling2D: lambda l:
{'upsample_param': getUpSamplingParam(l)},
keras.layers.ZeroPadding2D: lambda l: