Posted to commits@systemml.apache.org by ni...@apache.org on 2017/04/19 22:08:34 UTC

[3/3] incubator-systemml git commit: [SYSTEMML-692] Added initial version of DML generator for Caffe

[SYSTEMML-692] Added initial version of DML generator for Caffe

This experimental interface is called Caffe2DML and does not affect other functionality.

- Updated the interface to match the Caffe specification as per
  @bertholdreinwald's suggestion.
- Added support for fine-tuning.
- Added support for explain, statistics and gpu.

Closes #422.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/cc7993fc
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/cc7993fc
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/cc7993fc

Branch: refs/heads/master
Commit: cc7993fc87ccf7d404bc8802f9529aee7da5de5e
Parents: ad3e78a
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Wed Apr 19 14:07:44 2017 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Wed Apr 19 15:07:43 2017 -0700

----------------------------------------------------------------------
 docs/beginners-guide-caffe2dml.md               |  124 ++
 docs/devdocs/deep-learning.md                   |   84 ++
 pom.xml                                         |   47 +-
 .../cp/AggregateUnaryCPInstruction.java         |    2 +-
 .../sysml/runtime/util/ConvolutionUtils.java    |   12 +
 .../udf/lib/Caffe2DMLVisualizeWrapper.java      |   66 +
 .../apache/sysml/utils/TensorboardLogger.java   |  177 +++
 src/main/proto/caffe/caffe.proto                | 1424 ++++++++++++++++++
 src/main/proto/tensorflow/event.proto           |  102 ++
 src/main/proto/tensorflow/summary.proto         |  123 ++
 src/main/python/setup.py                        |    4 +-
 src/main/python/systemml/converters.py          |   31 +-
 src/main/python/systemml/mllearn/estimators.py  |  168 ++-
 .../org/apache/sysml/api/dl/Caffe2DML.scala     |  510 +++++++
 .../org/apache/sysml/api/dl/CaffeLayer.scala    |  357 +++++
 .../org/apache/sysml/api/dl/CaffeNetwork.scala  |  180 +++
 .../org/apache/sysml/api/dl/CaffeSolver.scala   |  158 ++
 .../org/apache/sysml/api/dl/DMLGenerator.scala  |  311 ++++
 .../scala/org/apache/sysml/api/dl/Utils.scala   |  127 ++
 .../sysml/api/ml/BaseSystemMLClassifier.scala   |   38 +-
 .../sysml/api/ml/BaseSystemMLRegressor.scala    |    4 +
 .../sysml/api/ml/LogisticRegression.scala       |    2 +-
 .../org/apache/sysml/api/ml/NaiveBayes.scala    |    2 +-
 .../scala/org/apache/sysml/api/ml/SVM.scala     |    2 +-
 24 files changed, 4036 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/docs/beginners-guide-caffe2dml.md
----------------------------------------------------------------------
diff --git a/docs/beginners-guide-caffe2dml.md b/docs/beginners-guide-caffe2dml.md
new file mode 100644
index 0000000..cfcc0cb
--- /dev/null
+++ b/docs/beginners-guide-caffe2dml.md
@@ -0,0 +1,124 @@
+---
+layout: global
+title: Beginner's Guide for Caffe2DML users
+description: Beginner's Guide for Caffe2DML users
+---
+<!--
+{% comment %}
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+{% endcomment %}
+-->
+
+* This will become a table of contents (this text will be scraped).
+{:toc}
+
+<br/>
+
+## Introduction
+
+Caffe2DML is an experimental API that converts a Caffe specification to DML.
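+
+As a quick orientation, here is a minimal sketch (based on the examples later in this guide and in `docs/devdocs/deep-learning.md`; it assumes a `lenet_solver.proto` and NumPy arrays `X_train`/`y_train` plus held-out `X_test`/`y_test` are available):
+
+	from systemml.mllearn import Caffe2DML
+	# build an estimator from a Caffe solver specification
+	lenet = Caffe2DML(sqlCtx, solver='lenet_solver.proto')
+	# train and score with a scikit-learn like API
+	lenet.fit(X_train, y_train)
+	print(lenet.score(X_test, y_test))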
+
+## Frequently asked questions
+
+- How to set the batch size?
+
+The batch size is set in the `data_param` of the Data layer:
+
+	layer {
+	  name: "mnist"
+	  type: "Data"
+	  top: "data"
+	  top: "label"
+	  data_param {
+	    source: "mnist_train"
+	    batch_size: 64
+	    backend: LMDB
+	  }
+	}
+	
+- How to set the maximum number of iterations for training?
+
+Caffe allows you to set the maximum number of iterations in the solver specification:
+
+	# The maximum number of iterations
+	max_iter: 2000
+	
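+This limit can also be overridden from Python, as done in the mllearn examples added by this commit (sketch; assumes `lenet` is a Caffe2DML estimator built from `lenet_solver.proto`):
+
+	lenet = lenet.set(max_iter=500)
+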
+- How to set the size of the validation dataset?
+
+The size of the validation dataset is determined by the parameter `test_iter` and the batch size. For example, if the batch size is 64 and
+`test_iter` is 10, then the validation set size is 640. This setting generates the following DML code internally:
+
+	num_images = nrow(y_full)
+	BATCH_SIZE = 64
+	num_validation = 10 * BATCH_SIZE
+	X = X_full[(num_validation+1):num_images,]; y = y_full[(num_validation+1):num_images,]
+	X_val = X_full[1:num_validation,]; y_val = y_full[1:num_validation,]
+	num_images = nrow(y) 
+
+- How to monitor the loss via the command line?
+
+To monitor the loss, set the following parameters in the solver specification:
+
+	# Display training loss and accuracy every 100 iterations
+	display: 100
+	# Carry out validation every 500 training iterations and display validation loss and accuracy.
+	test_iter: 10
+	test_interval: 500
+	
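+Additional SystemML-side diagnostics can be toggled on the estimator itself, as in the mllearn examples added by this commit (sketch; assumes `lenet` is a Caffe2DML estimator):
+
+	lenet = lenet.set(debug=True).setStatistics(True)
+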
+- How to pass a single jpeg image to Caffe2DML for prediction?
+
+	from PIL import Image
+	import systemml as sml
+	from systemml.mllearn import Caffe2DML
+	img_shape = (3, 224, 224)
+	input_image = sml.convertImageToNumPyArr(Image.open(img_file_path), img_shape=img_shape)
+	resnet = Caffe2DML(sqlCtx, solver='ResNet_50_solver.proto', weights='ResNet_50_pretrained_weights', input_shape=img_shape)
+	resnet.predict(input_image)
+
+- How to prepare a directory of jpeg images for training with Caffe2DML?
+
+The example below assumes that the input dataset has two labels, `cat` and `dog`, and that each filename has the label as its prefix.
+We iterate through the directory and convert each jpeg image into a `pyspark.ml.linalg.Vector` using PySpark.
+These vectors are stored as a DataFrame and shuffled using Spark SQL's `orderBy(rand())` function.
+The DataFrame is then saved in Parquet format to reduce the cost of preprocessing for repeated training.
+
+	from systemml.mllearn import Caffe2DML
+	from pyspark.sql import SQLContext
+	import numpy as np
+	import urllib, os, scipy.ndimage
+	from pyspark.ml.linalg import Vectors
+	from pyspark import StorageLevel
+	import systemml as sml
+	from pyspark.sql.functions import rand 
+	# ImageNet specific parameters
+	img_shape = (3, 224, 224)
+	train_dir = '/home/biuser/dogs_vs_cats/train'
+	def getLabelFeatures(filename):
+		from PIL import Image
+		vec = Vectors.dense(sml.convertImageToNumPyArr(Image.open(os.path.join(train_dir, filename)), img_shape=img_shape)[0,:])
+		if filename.lower().startswith('cat'):
+			return (1, vec)
+		elif filename.lower().startswith('dog'):
+			return (2, vec)
+		else:
+			raise ValueError('Expected the filename to start with either cat or dog')
+	
+	list_jpeg_files = os.listdir(train_dir)
+	# 10 files per partition
+	train_df = sc.parallelize(list_jpeg_files, int(len(list_jpeg_files)/10)).map(lambda filename : getLabelFeatures(filename)).toDF(['label', 'features']).orderBy(rand())
+	# Optional, but helps separate the conversion step from training
+	# Alternatively, this dataframe can be passed directly to `caffe2dml_model.fit(train_df)`
+	train_df.write.parquet('kaggle-cats-dogs.parquet')
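+	# Illustrative follow-up (assumes `caffe2dml_model` is a Caffe2DML estimator
+	# created from a solver .proto, as shown elsewhere in this guide):
+	# the saved DataFrame can later be read back and passed directly to fit()
+	train_df = sqlCtx.read.parquet('kaggle-cats-dogs.parquet')
+	caffe2dml_model.fit(train_df)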
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/docs/devdocs/deep-learning.md
----------------------------------------------------------------------
diff --git a/docs/devdocs/deep-learning.md b/docs/devdocs/deep-learning.md
index 1fb951a..329c6c8 100644
--- a/docs/devdocs/deep-learning.md
+++ b/docs/devdocs/deep-learning.md
@@ -139,3 +139,87 @@ updates for the image:
 |-----------------|---------------------------------|-----------------|
 | `w3*y1 + w1*y3` | `w4*y1 + w3*y2 + w2*y3 + w1*y4` | `w4*y2 + w2*y4` |
 | `w3*y3`         | `w4*y3 + w3*y4`                 | `w4*y4`         |
+
+# Caffe2DML examples
+
+## Training Lenet using Caffe2DML
+
+The script below also demonstrates how to save the trained model.
+
+```python
+# Download the MNIST dataset
+from mlxtend.data import mnist_data
+import numpy as np
+from sklearn.utils import shuffle
+X, y = mnist_data()
+X, y = shuffle(X, y)
+num_classes = np.unique(y).shape[0]
+img_shape = (1, 28, 28)
+
+# Split the data into training and test
+n_samples = len(X)
+X_train = X[:int(.9 * n_samples)]
+y_train = y[:int(.9 * n_samples)]
+X_test = X[int(.9 * n_samples):]
+y_test = y[int(.9 * n_samples):]
+
+# Download the Lenet network
+import urllib
+urllib.urlretrieve('https://raw.githubusercontent.com/niketanpansare/model_zoo/master/caffe/vision/lenet/mnist/lenet.proto', 'lenet.proto')
+urllib.urlretrieve('https://raw.githubusercontent.com/niketanpansare/model_zoo/master/caffe/vision/lenet/mnist/lenet_solver.proto', 'lenet_solver.proto')
+
+# Train Lenet on MNIST using a scikit-learn-like API
+from systemml.mllearn import Caffe2DML
+lenet = Caffe2DML(sqlCtx, solver='lenet_solver.proto').set(max_iter=500, debug=True).setStatistics(True)
+print('Lenet score: %f' % lenet.fit(X_train, y_train).score(X_test, y_test))
+
+# Save the trained model
+lenet.save('lenet_model')
+```
+
+## Load the trained model and retrain (i.e., fine-tuning)
+
+```python
+# Fine-tune the existing trained model
+new_lenet = Caffe2DML(sqlCtx, solver='lenet_solver.proto', weights='lenet_model').set(max_iter=500, debug=True)
+new_lenet.fit(X_train, y_train)
+new_lenet.save('lenet_model')
+```
+
+## Perform prediction using the above trained model
+
+```python
+# Use the new model for prediction
+predict_lenet = Caffe2DML(sqlCtx, solver='lenet_solver.proto', weights='lenet_model')
+print('Lenet score: %f' % predict_lenet.score(X_test, y_test))
+```
+
+Similarly, you can perform prediction using the pre-trained ResNet network:
+
+```python
+from systemml.mllearn import Caffe2DML
+from pyspark.sql import SQLContext
+import numpy as np
+import urllib, os, scipy.ndimage
+from PIL import Image
+import systemml as sml
+
+# ImageNet specific parameters
+img_shape = (3, 224, 224)
+
+# Download a jpg image, resize it to 224x224 and return it as a numpy array in N x CHW format
+url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/5/58/MountainLion.jpg/312px-MountainLion.jpg'
+outFile = 'test.jpg'
+urllib.urlretrieve(url, outFile)
+input_image = sml.convertImageToNumPyArr(Image.open(outFile), img_shape=img_shape)
+
+# Download the ResNet network
+import urllib
+urllib.urlretrieve('https://raw.githubusercontent.com/niketanpansare/model_zoo/master/caffe/vision/resnet/ilsvrc12/ResNet_50_network.proto', 'ResNet_50_network.proto')
+urllib.urlretrieve('https://raw.githubusercontent.com/niketanpansare/model_zoo/master/caffe/vision/resnet/ilsvrc12/ResNet_50_solver.proto', 'ResNet_50_solver.proto')
+
+# Assumes that you have cloned the model_zoo repository
+# git clone https://github.com/niketanpansare/model_zoo.git
+resnet = Caffe2DML(sqlCtx, solver='ResNet_50_solver.proto', weights='~/model_zoo/caffe/vision/resnet/ilsvrc12/ResNet_50_pretrained_weights').set(input_shape=img_shape)
+resnet.predict(input_image)
+```
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index eba7f57..d107f64 100644
--- a/pom.xml
+++ b/pom.xml
@@ -324,6 +324,46 @@
 					</execution>
 				</executions>
 			</plugin>
+			
+			<plugin>
+			    <groupId>com.github.os72</groupId>
+			    <artifactId>protoc-jar-maven-plugin</artifactId>
+			    <version>3.0.0-b2.1</version>
+			    <executions>
+			        <execution>
+			        	<id>caffe-sources</id>
+			            <phase>generate-sources</phase>
+			            <goals>
+			                <goal>run</goal>
+			            </goals>
+			            <configuration>
+			                <protocVersion>2.5.0</protocVersion> <!-- 2.4.1, 2.5.0, 2.6.1, 3.0.0 -->
+			                <inputDirectories>
+			                    <include>src/main/proto/caffe</include>
+			                </inputDirectories>
+			                <outputDirectories>
+			                    <include>src/main/java</include>
+			                </outputDirectories>
+			            </configuration>
+			        </execution>
+			        <execution>
+			        	<id>tf-sources</id>
+			            <phase>generate-sources</phase>
+			            <goals>
+			                <goal>run</goal>
+			            </goals>
+			            <configuration>
+			                <protocVersion>3.0.0</protocVersion> <!-- 2.4.1, 2.5.0, 2.6.1, 3.0.0 -->
+			                <inputDirectories>
+			                    <include>src/main/proto/tensorflow</include>
+			                </inputDirectories>
+			                <outputDirectories>
+			                    <include>src/main/java</include>
+			                </outputDirectories>
+			            </configuration>
+			        </execution>
+			    </executions>
+			</plugin>
 
 			<!-- Currently, all tests are integration tests. -->
 			<plugin>
@@ -1076,7 +1116,12 @@
 
 
 	<dependencies>
-
+		<dependency>
+	  		<groupId>com.google.protobuf</groupId>
+	  		<artifactId>protobuf-java</artifactId>
+	  		<version>3.2.0</version>
+	  		<scope>provided</scope>
+	  	</dependency>
 		<dependency>
 			<groupId>org.jcuda</groupId>
 			<artifactId>jcuda</artifactId>

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/src/main/java/org/apache/sysml/runtime/instructions/cp/AggregateUnaryCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/AggregateUnaryCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/AggregateUnaryCPInstruction.java
index 8790a53..8dd372a 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/AggregateUnaryCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/AggregateUnaryCPInstruction.java
@@ -121,7 +121,7 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction
 						rval = mc.getRows() * mc.getCols();
 				}
 				else {
-					throw new DMLRuntimeException("Invalid meta data returned by '"+opcode+"': "+rval);
+					throw new DMLRuntimeException("Invalid meta data returned by '"+opcode+"': "+rval + ":" + instString);
 				}
 			}
 			

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java b/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
index 80b20cd..814cf22 100644
--- a/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
@@ -22,6 +22,18 @@ package org.apache.sysml.runtime.util;
 
 public class ConvolutionUtils {
 	
+	public static String getConv2dOutputMap(String H, String R, String verticalStride, String heightPadding) {
+		long padX2 = -1;
+		try {
+			padX2 = Long.parseLong(heightPadding)*2;
+			return "" + getP(Long.parseLong(H), Long.parseLong(R), Long.parseLong(verticalStride), Long.parseLong(heightPadding));
+		} catch(Exception e) {
+			if(padX2 == -1)
+				return "((" + H + " + 2*" + heightPadding + " - " + R + ") / " + verticalStride + "+ 1)";
+			else if(padX2 == 0)
+				return "((" + H + " - " + R + ") / " + verticalStride + "+ 1)";
+			else
+				return "((" + H + " + " + padX2 + " - " + R + ") / " + verticalStride + "+ 1)";
+		}
+	}
+	
 	public static long getP(long H, long R, long verticalStride, long heightPadding) {
 		long ret = (H + 2 * heightPadding - R) / verticalStride + 1;
 		if(ret <= 0) {

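The new getConv2dOutputMap helper emits the same output-size expression that getP computes: it folds the expression to a constant when all arguments are numeric literals, and otherwise returns it as a DML string. As an illustrative check of that arithmetic (not part of this commit):

    # Python sketch of the conv2d output-size formula used above
    def conv2d_out(H, R, stride, pad):
        return (H + 2 * pad - R) // stride + 1

    print(conv2d_out(28, 5, 1, 2))  # 28: "same" padding with a 5x5 kernel
    print(conv2d_out(28, 2, 2, 0))  # 14: 2x2 pooling-style reduction
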
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/src/main/java/org/apache/sysml/udf/lib/Caffe2DMLVisualizeWrapper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/udf/lib/Caffe2DMLVisualizeWrapper.java b/src/main/java/org/apache/sysml/udf/lib/Caffe2DMLVisualizeWrapper.java
new file mode 100644
index 0000000..15c867b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/udf/lib/Caffe2DMLVisualizeWrapper.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.udf.lib;
+
+import org.apache.sysml.udf.FunctionParameter;
+import org.apache.sysml.udf.PackageFunction;
+import org.apache.sysml.udf.Scalar;
+import org.apache.sysml.udf.Scalar.ScalarValueType;
+import org.apache.sysml.utils.TensorboardLogger;
+
+public class Caffe2DMLVisualizeWrapper extends PackageFunction 
+{
+	private static final long serialVersionUID = 1L;
+	private Scalar _ret;
+
+	@Override
+	public int getNumFunctionOutputs() {
+		return 1;
+	}
+
+	@Override
+	public FunctionParameter getFunctionOutput(int pos) {
+		if (pos == 0)
+			return _ret;
+
+		throw new RuntimeException(
+				"Invalid function output being requested");
+	}
+
+	@Override
+	public void execute() {
+		String layerName = ((Scalar) this.getFunctionInput(0)).getValue();
+		String varType = ((Scalar) this.getFunctionInput(1)).getValue();
+		String aggFn = ((Scalar) this.getFunctionInput(2)).getValue();
+		double x = Double.parseDouble(((Scalar) this.getFunctionInput(3)).getValue());
+		double y = Double.parseDouble(((Scalar) this.getFunctionInput(4)).getValue());
+		String logDir = ((Scalar) this.getFunctionInput(5)).getValue();
+
+		String key = null;
+		if(aggFn.equals("training_loss") || aggFn.equals("validation_loss") ||
+				aggFn.equals("training_accuracy") || aggFn.equals("validation_accuracy"))
+			key = aggFn;
+		else
+			key = aggFn + "_" + varType + "_" + layerName;
+		TensorboardLogger.writeScalar(logDir, key, (long)x, (float)y);
+		_ret = new Scalar(ScalarValueType.Double, String.valueOf(1));
+	}
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/src/main/java/org/apache/sysml/utils/TensorboardLogger.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/TensorboardLogger.java b/src/main/java/org/apache/sysml/utils/TensorboardLogger.java
new file mode 100644
index 0000000..245d757
--- /dev/null
+++ b/src/main/java/org/apache/sysml/utils/TensorboardLogger.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.utils;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.zip.Checksum;
+import org.tensorflow.framework.Summary;
+import org.tensorflow.util.Event;
+
+import com.google.common.primitives.Ints;
+import com.google.common.primitives.Longs;
+
+public class TensorboardLogger {
+	private static Crc32c crc32 = new Crc32c();
+	
+	/**
+	 * Writes scalar of given value in tensorboard format 
+	 * Writes a scalar value in the TensorBoard event-file format
+	 * @param logDir log directory of tensorboard 
+	 * @param tag scalar tag (for example: training_loss, validation_loss, ...)
+	 * @param step usually the iteration number
+	 * @param value value of the scalar
+	 */
+	public static void writeScalar(String logDir, String tag, long step, float value) {
+		String filePath = logDir + File.separator + "tfevents.event_systemml_scalar"; 
+		try {
+			FileOutputStream outputStream = new FileOutputStream(filePath, true);
+			Event event = Event.newBuilder()
+					.setWallTime(System.currentTimeMillis() / 1e3)
+					.setStep(step)
+					.setSummary(Summary.newBuilder().addValue(
+							Summary.Value.newBuilder().setTag(tag).setSimpleValue(value)
+							).build())
+							.build();
+			byte[] eventString = event.toByteArray();
+			byte[] header = reverse(Longs.toByteArray((long)eventString.length));
+			write(outputStream, header);
+			write(outputStream, eventString);
+			outputStream.close();
+		}
+		catch(IOException e) {
+			throw new RuntimeException("Error writing event in tensorboard directory:" + filePath, e);
+		}
+	}
+
+	private static void write(FileOutputStream outputStream, byte[] byteString) throws IOException {
+		outputStream.write(byteString);
+		outputStream.write(reverse(Ints.toByteArray((int)maskedCRC32(byteString))));
+	}
+
+	private static byte[] reverse(byte[] nums) {
+		byte[] reversed = new byte[nums.length];
+		for (int i=0; i<nums.length; i++) {
+			reversed[i] = nums[nums.length - 1 - i];
+		}
+		return reversed;
+	}
+
+	private static long maskedCRC32(byte[] data){
+		crc32.reset();
+		crc32.update(data, 0, data.length);
+		long x = u32(crc32.getValue());
+		return u32(((x >> 15) | u32(x << 17)) + 0xa282ead8);
+	}
+
+	private static long u32(long x){
+		return x & 0xffffffff;
+	}
+}
+
+class Crc32c implements Checksum {
+	private static final int[] crcTable = {
+		0x00000000, 0xF26B8303, 0xE13B70F7, 0x1350F3F4,
+		0xC79A971F, 0x35F1141C, 0x26A1E7E8, 0xD4CA64EB,
+		0x8AD958CF, 0x78B2DBCC, 0x6BE22838, 0x9989AB3B,
+		0x4D43CFD0, 0xBF284CD3, 0xAC78BF27, 0x5E133C24,
+		0x105EC76F, 0xE235446C, 0xF165B798, 0x030E349B,
+		0xD7C45070, 0x25AFD373, 0x36FF2087, 0xC494A384,
+		0x9A879FA0, 0x68EC1CA3, 0x7BBCEF57, 0x89D76C54,
+		0x5D1D08BF, 0xAF768BBC, 0xBC267848, 0x4E4DFB4B,
+		0x20BD8EDE, 0xD2D60DDD, 0xC186FE29, 0x33ED7D2A,
+		0xE72719C1, 0x154C9AC2, 0x061C6936, 0xF477EA35,
+		0xAA64D611, 0x580F5512, 0x4B5FA6E6, 0xB93425E5,
+		0x6DFE410E, 0x9F95C20D, 0x8CC531F9, 0x7EAEB2FA,
+		0x30E349B1, 0xC288CAB2, 0xD1D83946, 0x23B3BA45,
+		0xF779DEAE, 0x05125DAD, 0x1642AE59, 0xE4292D5A,
+		0xBA3A117E, 0x4851927D, 0x5B016189, 0xA96AE28A,
+		0x7DA08661, 0x8FCB0562, 0x9C9BF696, 0x6EF07595,
+		0x417B1DBC, 0xB3109EBF, 0xA0406D4B, 0x522BEE48,
+		0x86E18AA3, 0x748A09A0, 0x67DAFA54, 0x95B17957,
+		0xCBA24573, 0x39C9C670, 0x2A993584, 0xD8F2B687,
+		0x0C38D26C, 0xFE53516F, 0xED03A29B, 0x1F682198,
+		0x5125DAD3, 0xA34E59D0, 0xB01EAA24, 0x42752927,
+		0x96BF4DCC, 0x64D4CECF, 0x77843D3B, 0x85EFBE38,
+		0xDBFC821C, 0x2997011F, 0x3AC7F2EB, 0xC8AC71E8,
+		0x1C661503, 0xEE0D9600, 0xFD5D65F4, 0x0F36E6F7,
+		0x61C69362, 0x93AD1061, 0x80FDE395, 0x72966096,
+		0xA65C047D, 0x5437877E, 0x4767748A, 0xB50CF789,
+		0xEB1FCBAD, 0x197448AE, 0x0A24BB5A, 0xF84F3859,
+		0x2C855CB2, 0xDEEEDFB1, 0xCDBE2C45, 0x3FD5AF46,
+		0x7198540D, 0x83F3D70E, 0x90A324FA, 0x62C8A7F9,
+		0xB602C312, 0x44694011, 0x5739B3E5, 0xA55230E6,
+		0xFB410CC2, 0x092A8FC1, 0x1A7A7C35, 0xE811FF36,
+		0x3CDB9BDD, 0xCEB018DE, 0xDDE0EB2A, 0x2F8B6829,
+		0x82F63B78, 0x709DB87B, 0x63CD4B8F, 0x91A6C88C,
+		0x456CAC67, 0xB7072F64, 0xA457DC90, 0x563C5F93,
+		0x082F63B7, 0xFA44E0B4, 0xE9141340, 0x1B7F9043,
+		0xCFB5F4A8, 0x3DDE77AB, 0x2E8E845F, 0xDCE5075C,
+		0x92A8FC17, 0x60C37F14, 0x73938CE0, 0x81F80FE3,
+		0x55326B08, 0xA759E80B, 0xB4091BFF, 0x466298FC,
+		0x1871A4D8, 0xEA1A27DB, 0xF94AD42F, 0x0B21572C,
+		0xDFEB33C7, 0x2D80B0C4, 0x3ED04330, 0xCCBBC033,
+		0xA24BB5A6, 0x502036A5, 0x4370C551, 0xB11B4652,
+		0x65D122B9, 0x97BAA1BA, 0x84EA524E, 0x7681D14D,
+		0x2892ED69, 0xDAF96E6A, 0xC9A99D9E, 0x3BC21E9D,
+		0xEF087A76, 0x1D63F975, 0x0E330A81, 0xFC588982,
+		0xB21572C9, 0x407EF1CA, 0x532E023E, 0xA145813D,
+		0x758FE5D6, 0x87E466D5, 0x94B49521, 0x66DF1622,
+		0x38CC2A06, 0xCAA7A905, 0xD9F75AF1, 0x2B9CD9F2,
+		0xFF56BD19, 0x0D3D3E1A, 0x1E6DCDEE, 0xEC064EED,
+		0xC38D26C4, 0x31E6A5C7, 0x22B65633, 0xD0DDD530,
+		0x0417B1DB, 0xF67C32D8, 0xE52CC12C, 0x1747422F,
+		0x49547E0B, 0xBB3FFD08, 0xA86F0EFC, 0x5A048DFF,
+		0x8ECEE914, 0x7CA56A17, 0x6FF599E3, 0x9D9E1AE0,
+		0xD3D3E1AB, 0x21B862A8, 0x32E8915C, 0xC083125F,
+		0x144976B4, 0xE622F5B7, 0xF5720643, 0x07198540,
+		0x590AB964, 0xAB613A67, 0xB831C993, 0x4A5A4A90,
+		0x9E902E7B, 0x6CFBAD78, 0x7FAB5E8C, 0x8DC0DD8F,
+		0xE330A81A, 0x115B2B19, 0x020BD8ED, 0xF0605BEE,
+		0x24AA3F05, 0xD6C1BC06, 0xC5914FF2, 0x37FACCF1,
+		0x69E9F0D5, 0x9B8273D6, 0x88D28022, 0x7AB90321,
+		0xAE7367CA, 0x5C18E4C9, 0x4F48173D, 0xBD23943E,
+		0xF36E6F75, 0x0105EC76, 0x12551F82, 0xE03E9C81,
+		0x34F4F86A, 0xC69F7B69, 0xD5CF889D, 0x27A40B9E,
+		0x79B737BA, 0x8BDCB4B9, 0x988C474D, 0x6AE7C44E,
+		0xBE2DA0A5, 0x4C4623A6, 0x5F16D052, 0xAD7D5351,
+	};
+
+	private int crc = ~0;
+
+	public void update(byte[] buffer, int offset, int length) {
+		for (int i = offset; i < offset + length; i++) {
+			crc = crc32c(crc, buffer[i]);
+		}
+	}
+	public long getValue() {
+		return (crc ^ 0xFFFFFFFFL) & 0xFFFFFFFFL;
+	}
+	public void reset() {
+		crc = ~0;
+	}
+	private static int crc32c(int crc, int b) {
+		return crc >>> 8 ^ crcTable[(crc ^ b & 0xFF) & 0xFF];
+	}
+	public void update(int arg0) {
+		throw new RuntimeException("Not implemented");
+	}
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/src/main/proto/caffe/caffe.proto
----------------------------------------------------------------------
diff --git a/src/main/proto/caffe/caffe.proto b/src/main/proto/caffe/caffe.proto
new file mode 100644
index 0000000..cf53e17
--- /dev/null
+++ b/src/main/proto/caffe/caffe.proto
@@ -0,0 +1,1424 @@
+//-------------------------------------------------------------
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+//-------------------------------------------------------------
+
+syntax = "proto2";
+
+package caffe;
+
+// Specifies the shape (dimensions) of a Blob.
+message BlobShape {
+  repeated int64 dim = 1 [packed = true];
+}
+
+message BlobProto {
+  optional BlobShape shape = 7;
+  repeated float data = 5 [packed = true];
+  repeated float diff = 6 [packed = true];
+  repeated double double_data = 8 [packed = true];
+  repeated double double_diff = 9 [packed = true];
+
+  // 4D dimensions -- deprecated.  Use "shape" instead.
+  optional int32 num = 1 [default = 0];
+  optional int32 channels = 2 [default = 0];
+  optional int32 height = 3 [default = 0];
+  optional int32 width = 4 [default = 0];
+}
+
+// The BlobProtoVector is simply a way to pass multiple blobproto instances
+// around.
+message BlobProtoVector {
+  repeated BlobProto blobs = 1;
+}
+
+message Datum {
+  optional int32 channels = 1;
+  optional int32 height = 2;
+  optional int32 width = 3;
+  // the actual image data, in bytes
+  optional bytes data = 4;
+  optional int32 label = 5;
+  // Optionally, the datum could also hold float data.
+  repeated float float_data = 6;
+  // If true data contains an encoded image that need to be decoded
+  optional bool encoded = 7 [default = false];
+}
+
+message FillerParameter {
+  // The filler type.
+  optional string type = 1 [default = 'constant'];
+  optional float value = 2 [default = 0]; // the value in constant filler
+  optional float min = 3 [default = 0]; // the min value in uniform filler
+  optional float max = 4 [default = 1]; // the max value in uniform filler
+  optional float mean = 5 [default = 0]; // the mean value in Gaussian filler
+  optional float std = 6 [default = 1]; // the std value in Gaussian filler
+  // The expected number of non-zero output weights for a given input in
+  // Gaussian filler -- the default -1 means don't perform sparsification.
+  optional int32 sparse = 7 [default = -1];
+  // Normalize the filler variance by fan_in, fan_out, or their average.
+  // Applies to 'xavier' and 'msra' fillers.
+  enum VarianceNorm {
+    FAN_IN = 0;
+    FAN_OUT = 1;
+    AVERAGE = 2;
+  }
+  optional VarianceNorm variance_norm = 8 [default = FAN_IN];
+}
+
+message NetParameter {
+  optional string name = 1; // consider giving the network a name
+  // DEPRECATED. See InputParameter. The input blobs to the network.
+  repeated string input = 3;
+  // DEPRECATED. See InputParameter. The shape of the input blobs.
+  repeated BlobShape input_shape = 8;
+
+  // 4D input dimensions -- deprecated.  Use "input_shape" instead.
+  // If specified, for each input blob there should be four
+  // values specifying the num, channels, height and width of the input blob.
+  // Thus, there should be a total of (4 * #input) numbers.
+  repeated int32 input_dim = 4;
+
+  // Whether the network will force every layer to carry out backward operation.
+  // If set False, then whether to carry out backward is determined
+  // automatically according to the net structure and learning rates.
+  optional bool force_backward = 5 [default = false];
+  // The current "state" of the network, including the phase, level, and stage.
+  // Some layers may be included/excluded depending on this state and the states
+  // specified in the layers' include and exclude fields.
+  optional NetState state = 6;
+
+  // Print debugging information about results while running Net::Forward,
+  // Net::Backward, and Net::Update.
+  optional bool debug_info = 7 [default = false];
+
+  // The layers that make up the net.  Each of their configurations, including
+  // connectivity and behavior, is specified as a LayerParameter.
+  repeated LayerParameter layer = 100;  // ID 100 so layers are printed last.
+
+  // DEPRECATED: use 'layer' instead.
+  repeated V1LayerParameter layers = 2;
+}
+
+// NOTE
+// Update the next available ID when you add a new SolverParameter field.
+//
+// SolverParameter next available ID: 43 (last added: test_algo)
+message SolverParameter {
+  //////////////////////////////////////////////////////////////////////////////
+  // Specifying the train and test networks
+  //
+  // Exactly one train net must be specified using one of the following fields:
+  //     train_net_param, train_net, net_param, net
+  // One or more test nets may be specified using any of the following fields:
+  //     test_net_param, test_net, net_param, net
+  // If more than one test net field is specified (e.g., both net and
+  // test_net are specified), they will be evaluated in the field order given
+  // above: (1) test_net_param, (2) test_net, (3) net_param/net.
+  // A test_iter must be specified for each test_net.
+  // A test_level and/or a test_stage may also be specified for each test_net.
+  //////////////////////////////////////////////////////////////////////////////
+  
+  // SystemML extension
+  optional string train_algo = 41 [default = "minibatch"];
+  optional string test_algo = 42 [default = "minibatch"];
+
+  // Proto filename for the train net, possibly combined with one or more
+  // test nets.
+  optional string net = 24;
+  // Inline train net param, possibly combined with one or more test nets.
+  optional NetParameter net_param = 25;
+
+  optional string train_net = 1; // Proto filename for the train net.
+  repeated string test_net = 2; // Proto filenames for the test nets.
+  optional NetParameter train_net_param = 21; // Inline train net params.
+  repeated NetParameter test_net_param = 22; // Inline test net params.
+
+  // The states for the train/test nets. Must be unspecified or
+  // specified once per net.
+  //
+  // By default, all states will have solver = true;
+  // train_state will have phase = TRAIN,
+  // and all test_state's will have phase = TEST.
+  // Other defaults are set according to the NetState defaults.
+  optional NetState train_state = 26;
+  repeated NetState test_state = 27;
+
+  // The number of iterations for each test net.
+  repeated int32 test_iter = 3;
+
+  // The number of iterations between two testing phases.
+  optional int32 test_interval = 4 [default = 0];
+  optional bool test_compute_loss = 19 [default = false];
+  // If true, run an initial test pass before the first iteration,
+  // ensuring memory availability and printing the starting value of the loss.
+  optional bool test_initialization = 32 [default = true];
+  optional float base_lr = 5; // The base learning rate
+  // the number of iterations between displaying info. If display = 0, no info
+  // will be displayed.
+  optional int32 display = 6;
+  // Display the loss averaged over the last average_loss iterations
+  optional int32 average_loss = 33 [default = 1];
+  optional int32 max_iter = 7; // the maximum number of iterations
+  // accumulate gradients over `iter_size` x `batch_size` instances
+  optional int32 iter_size = 36 [default = 1];
+
+  // The learning rate decay policy. The currently implemented learning rate
+  // policies are as follows:
+  //    - fixed: always return base_lr.
+  //    - step: return base_lr * gamma ^ (floor(iter / step))
+  //    - exp: return base_lr * gamma ^ iter
+  //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
+  //    - multistep: similar to step but it allows non uniform steps defined by
+  //      stepvalue
+  //    - poly: the effective learning rate follows a polynomial decay, to be
+  //      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
+  //    - sigmoid: the effective learning rate follows a sigmoid decay
+  //      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
+  //
+  // where base_lr, max_iter, gamma, step, stepvalue and power are defined
+  // in the solver parameter protocol buffer, and iter is the current iteration.
+  optional string lr_policy = 8;
+  optional float gamma = 9; // The parameter to compute the learning rate.
+  optional float power = 10; // The parameter to compute the learning rate.
+  optional float momentum = 11; // The momentum value.
+  optional float weight_decay = 12; // The weight decay.
+  // regularization types supported: L1 and L2
+  // controlled by weight_decay
+  optional string regularization_type = 29 [default = "L2"];
+  // the stepsize for learning rate policy "step"
+  optional int32 stepsize = 13;
+  // the stepsize for learning rate policy "multistep"
+  repeated int32 stepvalue = 34;
+
+  // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm,
+  // whenever their actual L2 norm is larger.
+  optional float clip_gradients = 35 [default = -1];
+
+  optional int32 snapshot = 14 [default = 0]; // The snapshot interval
+  optional string snapshot_prefix = 15; // The prefix for the snapshot.
+  // whether to snapshot diff in the results or not. Snapshotting diff will help
+  // debugging but the final protocol buffer size will be much larger.
+  optional bool snapshot_diff = 16 [default = false];
+  enum SnapshotFormat {
+    HDF5 = 0;
+    BINARYPROTO = 1;
+  }
+  optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO];
+  // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default.
+  enum SolverMode {
+    CPU = 0;
+    GPU = 1;
+  }
+  optional SolverMode solver_mode = 17 [default = GPU];
+  // the device_id will that be used in GPU mode. Use device_id = 0 in default.
+  optional int32 device_id = 18 [default = 0];
+  // If non-negative, the seed with which the Solver will initialize the Caffe
+  // random number generator -- useful for reproducible results. Otherwise,
+  // (and by default) initialize using a seed derived from the system clock.
+  optional int64 random_seed = 20 [default = -1];
+
+  // type of the solver
+  optional string type = 40 [default = "SGD"];
+
+  // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam
+  optional float delta = 31 [default = 1e-8];
+  // parameters for the Adam solver
+  optional float momentum2 = 39 [default = 0.999];
+
+  // RMSProp decay value
+  // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
+  optional float rms_decay = 38 [default = 0.99];
+
+  // If true, print information about the state of the net that may help with
+  // debugging learning problems.
+  optional bool debug_info = 23 [default = false];
+
+  // If false, don't save a snapshot after training finishes.
+  optional bool snapshot_after_train = 28 [default = true];
+
+  // DEPRECATED: old solver enum types, use string instead
+  enum SolverType {
+    SGD = 0;
+    NESTEROV = 1;
+    ADAGRAD = 2;
+    RMSPROP = 3;
+    ADADELTA = 4;
+    ADAM = 5;
+  }
+  // DEPRECATED: use type instead of solver_type
+  optional SolverType solver_type = 30 [default = SGD];
+}
+
+// A message that stores the solver snapshots
+message SolverState {
+  optional int32 iter = 1; // The current iteration
+  optional string learned_net = 2; // The file that stores the learned net.
+  repeated BlobProto history = 3; // The history for sgd solvers
+  optional int32 current_step = 4 [default = 0]; // The current step for learning rate
+}
+
+enum Phase {
+   TRAIN = 0;
+   TEST = 1;
+}
+
+message NetState {
+  optional Phase phase = 1 [default = TEST];
+  optional int32 level = 2 [default = 0];
+  repeated string stage = 3;
+}
+
+message NetStateRule {
+  // Set phase to require the NetState have a particular phase (TRAIN or TEST)
+  // to meet this rule.
+  optional Phase phase = 1;
+
+  // Set the minimum and/or maximum levels in which the layer should be used.
+  // Leave undefined to meet the rule regardless of level.
+  optional int32 min_level = 2;
+  optional int32 max_level = 3;
+
+  // Customizable sets of stages to include or exclude.
+  // The net must have ALL of the specified stages and NONE of the specified
+  // "not_stage"s to meet the rule.
+  // (Use multiple NetStateRules to specify conjunctions of stages.)
+  repeated string stage = 4;
+  repeated string not_stage = 5;
+}
+
+// Specifies training parameters (multipliers on global learning constants,
+// and the name and other settings used for weight sharing).
+message ParamSpec {
+  // The names of the parameter blobs -- useful for sharing parameters among
+  // layers, but never required otherwise.  To share a parameter between two
+  // layers, give it a (non-empty) name.
+  optional string name = 1;
+
+  // Whether to require shared weights to have the same shape, or just the same
+  // count -- defaults to STRICT if unspecified.
+  optional DimCheckMode share_mode = 2;
+  enum DimCheckMode {
+    // STRICT (default) requires that num, channels, height, width each match.
+    STRICT = 0;
+    // PERMISSIVE requires only the count (num*channels*height*width) to match.
+    PERMISSIVE = 1;
+  }
+
+  // The multiplier on the global learning rate for this parameter.
+  optional float lr_mult = 3 [default = 1.0];
+
+  // The multiplier on the global weight decay for this parameter.
+  optional float decay_mult = 4 [default = 1.0];
+}
+
+// NOTE
+// Update the next available ID when you add a new LayerParameter field.
+//
+// LayerParameter next available layer-specific ID: 147 (last added: recurrent_param)
+message LayerParameter {
+  optional string name = 1; // the layer name
+  optional string type = 2; // the layer type
+  repeated string bottom = 3; // the name of each bottom blob
+  repeated string top = 4; // the name of each top blob
+
+  // The train / test phase for computation.
+  optional Phase phase = 10;
+
+  // The amount of weight to assign each top blob in the objective.
+  // Each layer assigns a default value, usually of either 0 or 1,
+  // to each top blob.
+  repeated float loss_weight = 5;
+
+  // Specifies training parameters (multipliers on global learning constants,
+  // and the name and other settings used for weight sharing).
+  repeated ParamSpec param = 6;
+
+  // The blobs containing the numeric parameters of the layer.
+  repeated BlobProto blobs = 7;
+
+  // Specifies whether to backpropagate to each bottom. If unspecified,
+  // Caffe will automatically infer whether each input needs backpropagation
+  // to compute parameter gradients. If set to true for some inputs,
+  // backpropagation to those inputs is forced; if set false for some inputs,
+  // backpropagation to those inputs is skipped.
+  //
+  // The size must be either 0 or equal to the number of bottoms.
+  repeated bool propagate_down = 11;
+
+  // Rules controlling whether and when a layer is included in the network,
+  // based on the current NetState.  You may specify a non-zero number of rules
+  // to include OR exclude, but not both.  If no include or exclude rules are
+  // specified, the layer is always included.  If the current NetState meets
+  // ANY (i.e., one or more) of the specified rules, the layer is
+  // included/excluded.
+  repeated NetStateRule include = 8;
+  repeated NetStateRule exclude = 9;
+
+  // Parameters for data pre-processing.
+  optional TransformationParameter transform_param = 100;
+
+  // Parameters shared by loss layers.
+  optional LossParameter loss_param = 101;
+
+  // Layer type-specific parameters.
+  //
+  // Note: certain layers may have more than one computational engine
+  // for their implementation. These layers include an Engine type and
+  // engine parameter for selecting the implementation.
+  // The default for the engine is set by the ENGINE switch at compile-time.
+  optional AccuracyParameter accuracy_param = 102;
+  optional ArgMaxParameter argmax_param = 103;
+  optional BatchNormParameter batch_norm_param = 139;
+  optional BiasParameter bias_param = 141;
+  optional ConcatParameter concat_param = 104;
+  optional ContrastiveLossParameter contrastive_loss_param = 105;
+  optional ConvolutionParameter convolution_param = 106;
+  optional CropParameter crop_param = 144;
+  optional DataParameter data_param = 107;
+  optional DropoutParameter dropout_param = 108;
+  optional DummyDataParameter dummy_data_param = 109;
+  optional EltwiseParameter eltwise_param = 110;
+  optional ELUParameter elu_param = 140;
+  optional EmbedParameter embed_param = 137;
+  optional ExpParameter exp_param = 111;
+  optional FlattenParameter flatten_param = 135;
+  optional HDF5DataParameter hdf5_data_param = 112;
+  optional HDF5OutputParameter hdf5_output_param = 113;
+  optional HingeLossParameter hinge_loss_param = 114;
+  optional ImageDataParameter image_data_param = 115;
+  optional InfogainLossParameter infogain_loss_param = 116;
+  optional InnerProductParameter inner_product_param = 117;
+  optional InputParameter input_param = 143;
+  optional LogParameter log_param = 134;
+  optional LRNParameter lrn_param = 118;
+  optional MemoryDataParameter memory_data_param = 119;
+  optional MVNParameter mvn_param = 120;
+  optional ParameterParameter parameter_param = 145;
+  optional PoolingParameter pooling_param = 121;
+  optional PowerParameter power_param = 122;
+  optional PReLUParameter prelu_param = 131;
+  optional PythonParameter python_param = 130;
+  optional RecurrentParameter recurrent_param = 146;
+  optional ReductionParameter reduction_param = 136;
+  optional ReLUParameter relu_param = 123;
+  optional ReshapeParameter reshape_param = 133;
+  optional ScaleParameter scale_param = 142;
+  optional SigmoidParameter sigmoid_param = 124;
+  optional SoftmaxParameter softmax_param = 125;
+  optional SPPParameter spp_param = 132;
+  optional SliceParameter slice_param = 126;
+  optional TanHParameter tanh_param = 127;
+  optional ThresholdParameter threshold_param = 128;
+  optional TileParameter tile_param = 138;
+  optional WindowDataParameter window_data_param = 129;
+}
+
+// Message that stores parameters used to apply transformation
+// to the data layer's data
+message TransformationParameter {
+  // For data pre-processing, we can do simple scaling and subtracting the
+  // data mean, if provided. Note that the mean subtraction is always carried
+  // out before scaling.
+  optional float scale = 1 [default = 1];
+  // Specify if we want to randomly mirror data.
+  optional bool mirror = 2 [default = false];
+  // Specify if we would like to randomly crop an image.
+  optional uint32 crop_size = 3 [default = 0];
+  // mean_file and mean_value cannot be specified at the same time
+  optional string mean_file = 4;
+  // if specified can be repeated once (would subtract it from all the channels)
+  // or can be repeated the same number of times as channels
+  // (would subtract them from the corresponding channel)
+  repeated float mean_value = 5;
+  // Force the decoded image to have 3 color channels.
+  optional bool force_color = 6 [default = false];
+  // Force the decoded image to have 1 color channel.
+  optional bool force_gray = 7 [default = false];
+}
+
+// Message that stores parameters shared by loss layers
+message LossParameter {
+  // If specified, ignore instances with the given label.
+  optional int32 ignore_label = 1;
+  // How to normalize the loss for loss layers that aggregate across batches,
+  // spatial dimensions, or other dimensions.  Currently only implemented in
+  // SoftmaxWithLoss layer.
+  enum NormalizationMode {
+    // Divide by the number of examples in the batch times spatial dimensions.
+    // Outputs that receive the ignore label will NOT be ignored in computing
+    // the normalization factor.
+    FULL = 0;
+    // Divide by the total number of output locations that do not take the
+    // ignore_label.  If ignore_label is not set, this behaves like FULL.
+    VALID = 1;
+    // Divide by the batch size.
+    BATCH_SIZE = 2;
+    // Do not normalize the loss.
+    NONE = 3;
+  }
+  optional NormalizationMode normalization = 3 [default = VALID];
+  // Deprecated.  Ignored if normalization is specified.  If normalization
+  // is not specified, then setting this to false will be equivalent to
+  // normalization = BATCH_SIZE to be consistent with previous behavior.
+  optional bool normalize = 2;
+}
+
+// Messages that store parameters used by individual layer types follow, in
+// alphabetical order.
+
+message AccuracyParameter {
+  // When computing accuracy, count as correct by comparing the true label to
+  // the top k scoring classes.  By default, only compare to the top scoring
+  // class (i.e. argmax).
+  optional uint32 top_k = 1 [default = 1];
+
+  // The "label" axis of the prediction blob, whose argmax corresponds to the
+  // predicted label -- may be negative to index from the end (e.g., -1 for the
+  // last axis).  For example, if axis == 1 and the predictions are
+  // (N x C x H x W), the label blob is expected to contain N*H*W ground truth
+  // labels with integer values in {0, 1, ..., C-1}.
+  optional int32 axis = 2 [default = 1];
+
+  // If specified, ignore instances with the given label.
+  optional int32 ignore_label = 3;
+}
+
+message ArgMaxParameter {
+  // If true produce pairs (argmax, maxval)
+  optional bool out_max_val = 1 [default = false];
+  optional uint32 top_k = 2 [default = 1];
+  // The axis along which to maximise -- may be negative to index from the
+  // end (e.g., -1 for the last axis).
+  // By default ArgMaxLayer maximizes over the flattened trailing dimensions
+  // for each index of the first / num dimension.
+  optional int32 axis = 3;
+}
+
+message ConcatParameter {
+  // The axis along which to concatenate -- may be negative to index from the
+  // end (e.g., -1 for the last axis).  Other axes must have the
+  // same dimension for all the bottom blobs.
+  // By default, ConcatLayer concatenates blobs along the "channels" axis (1).
+  optional int32 axis = 2 [default = 1];
+
+  // DEPRECATED: alias for "axis" -- does not support negative indexing.
+  optional uint32 concat_dim = 1 [default = 1];
+}
+
+message BatchNormParameter {
+  // If false, accumulate global mean/variance values via a moving average. If
+  // true, use those accumulated values instead of computing mean/variance
+  // across the batch.
+  optional bool use_global_stats = 1;
+  // How much does the moving average decay each iteration?
+  optional float moving_average_fraction = 2 [default = .999];
+  // Small value to add to the variance estimate so that we don't divide by
+  // zero.
+  optional float eps = 3 [default = 1e-5];
+}
+
+message BiasParameter {
+  // The first axis of bottom[0] (the first input Blob) along which to apply
+  // bottom[1] (the second input Blob).  May be negative to index from the end
+  // (e.g., -1 for the last axis).
+  //
+  // For example, if bottom[0] is 4D with shape 100x3x40x60, the output
+  // top[0] will have the same shape, and bottom[1] may have any of the
+  // following shapes (for the given value of axis):
+  //    (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60
+  //    (axis == 1 == -3)          3;     3x40;     3x40x60
+  //    (axis == 2 == -2)                   40;       40x60
+  //    (axis == 3 == -1)                                60
+  // Furthermore, bottom[1] may have the empty shape (regardless of the value of
+  // "axis") -- a scalar bias.
+  optional int32 axis = 1 [default = 1];
+
+  // (num_axes is ignored unless just one bottom is given and the bias is
+  // a learned parameter of the layer.  Otherwise, num_axes is determined by the
+  // number of axes by the second bottom.)
+  // The number of axes of the input (bottom[0]) covered by the bias
+  // parameter, or -1 to cover all axes of bottom[0] starting from `axis`.
+  // Set num_axes := 0, to add a zero-axis Blob: a scalar.
+  optional int32 num_axes = 2 [default = 1];
+
+  // (filler is ignored unless just one bottom is given and the bias is
+  // a learned parameter of the layer.)
+  // The initialization for the learned bias parameter.
+  // Default is the zero (0) initialization, resulting in the BiasLayer
+  // initially performing the identity operation.
+  optional FillerParameter filler = 3;
+}
+
+message ContrastiveLossParameter {
+  // margin for dissimilar pair
+  optional float margin = 1 [default = 1.0];
+  // The first implementation of this cost did not exactly match the cost of
+  // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2.
+  // legacy_version = false (the default) uses (margin - d)^2 as proposed in the
+  // Hadsell paper. New models should probably use this version.
+  // legacy_version = true uses (margin - d^2). This is kept to support /
+  // reproduce existing models and results
+  optional bool legacy_version = 2 [default = false];
+}
+
+message ConvolutionParameter {
+  optional uint32 num_output = 1; // The number of outputs for the layer
+  optional bool bias_term = 2 [default = true]; // whether to have bias terms
+
+  // Pad, kernel size, and stride are all given as a single value for equal
+  // dimensions in all spatial dimensions, or once per spatial dimension.
+  repeated uint32 pad = 3; // The padding size; defaults to 0
+  repeated uint32 kernel_size = 4; // The kernel size
+  repeated uint32 stride = 6; // The stride; defaults to 1
+  // Factor used to dilate the kernel, (implicitly) zero-filling the resulting
+  // holes. (Kernel dilation is sometimes referred to by its use in the
+  // algorithme à trous from Holschneider et al. 1987.)
+  repeated uint32 dilation = 18; // The dilation; defaults to 1
+
+  // For 2D convolution only, the *_h and *_w versions may also be used to
+  // specify both spatial dimensions.
+  optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only)
+  optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only)
+  optional uint32 kernel_h = 11; // The kernel height (2D only)
+  optional uint32 kernel_w = 12; // The kernel width (2D only)
+  optional uint32 stride_h = 13; // The stride height (2D only)
+  optional uint32 stride_w = 14; // The stride width (2D only)
+
+  optional uint32 group = 5 [default = 1]; // The group size for group conv
+
+  optional FillerParameter weight_filler = 7; // The filler for the weight
+  optional FillerParameter bias_filler = 8; // The filler for the bias
+  enum Engine {
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 15 [default = DEFAULT];
+
+  // The axis to interpret as "channels" when performing convolution.
+  // Preceding dimensions are treated as independent inputs;
+  // succeeding dimensions are treated as "spatial".
+  // With (N, C, H, W) inputs, and axis == 1 (the default), we perform
+  // N independent 2D convolutions, sliding C-channel (or (C/g)-channels, for
+  // groups g>1) filters across the spatial axes (H, W) of the input.
+  // With (N, C, D, H, W) inputs, and axis == 1, we perform
+  // N independent 3D convolutions, sliding (C/g)-channels
+  // filters across the spatial axes (D, H, W) of the input.
+  optional int32 axis = 16 [default = 1];
+
+  // Whether to force use of the general ND convolution, even if a specific
+  // implementation for blobs of the appropriate number of spatial dimensions
+  // is available. (Currently, there is only a 2D-specific convolution
+  // implementation; for input blobs with num_axes != 2, this option is
+  // ignored and the ND implementation will be used.)
+  optional bool force_nd_im2col = 17 [default = false];
+}
+
+message CropParameter {
+  // To crop, elements of the first bottom are selected to fit the dimensions
+  // of the second, reference bottom. The crop is configured by
+  // - the crop `axis` to pick the dimensions for cropping
+  // - the crop `offset` to set the shift for all/each dimension
+  // to align the cropped bottom with the reference bottom.
+  // All dimensions up to but excluding `axis` are preserved, while
+  // the dimensions including and trailing `axis` are cropped.
+  // If only one `offset` is set, then all dimensions are offset by this amount.
+  // Otherwise, the number of offsets must equal the number of cropped axes to
+  // shift the crop in each dimension accordingly.
+  // Note: standard dimensions are N,C,H,W so the default is a spatial crop,
+  // and `axis` may be negative to index from the end (e.g., -1 for the last
+  // axis).
+  optional int32 axis = 1 [default = 2];
+  repeated uint32 offset = 2;
+}
+
+message DataParameter {
+  enum DB {
+    LEVELDB = 0;
+    LMDB = 1;
+  }
+  // Specify the data source.
+  optional string source = 1;
+  // Specify the batch size.
+  optional uint32 batch_size = 4;
+  // The rand_skip variable is for the data layer to skip a few data points
+  // to avoid all asynchronous sgd clients to start at the same point. The skip
+  // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
+  // be larger than the number of keys in the database.
+  // DEPRECATED. Each solver accesses a different subset of the database.
+  optional uint32 rand_skip = 7 [default = 0];
+  optional DB backend = 8 [default = LEVELDB];
+  // DEPRECATED. See TransformationParameter. For data pre-processing, we can do
+  // simple scaling and subtracting the data mean, if provided. Note that the
+  // mean subtraction is always carried out before scaling.
+  optional float scale = 2 [default = 1];
+  optional string mean_file = 3;
+  // DEPRECATED. See TransformationParameter. Specify if we would like to randomly
+  // crop an image.
+  optional uint32 crop_size = 5 [default = 0];
+  // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror
+  // data.
+  optional bool mirror = 6 [default = false];
+  // Force the encoded image to have 3 color channels
+  optional bool force_encoded_color = 9 [default = false];
+  // Prefetch queue (Number of batches to prefetch to host memory, increase if
+  // data access bandwidth varies).
+  optional uint32 prefetch = 10 [default = 4];
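+
+  // Illustrative prototxt usage (a sketch; the source path and batch size are
+  // example assumptions):
+  //   layer {
+  //     name: "data"  type: "Data"  top: "data"  top: "label"
+  //     data_param { source: "train_lmdb"  backend: LMDB  batch_size: 64 }
+  //   }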
+}
+
+message DropoutParameter {
+  optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio
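+
+  // Illustrative prototxt usage (a sketch; names and the ratio are example
+  // assumptions): in-place dropout after a fully-connected layer.
+  //   layer {
+  //     name: "drop6"  type: "Dropout"  bottom: "fc6"  top: "fc6"
+  //     dropout_param { dropout_ratio: 0.5 }
+  //   }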
+}
+
+// DummyDataLayer fills any number of arbitrarily shaped blobs with random
+// (or constant) data generated by "Fillers" (see "message FillerParameter").
+message DummyDataParameter {
+  // This layer produces N >= 1 top blobs.  DummyDataParameter must specify 1 or N
+  // shape fields, and 0, 1 or N data_fillers.
+  //
+  // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used.
+  // If 1 data_filler is specified, it is applied to all top blobs.  If N are
+  // specified, the ith is applied to the ith top blob.
+  repeated FillerParameter data_filler = 1;
+  repeated BlobShape shape = 6;
+
+  // 4D dimensions -- deprecated.  Use "shape" instead.
+  repeated uint32 num = 2;
+  repeated uint32 channels = 3;
+  repeated uint32 height = 4;
+  repeated uint32 width = 5;
+}
+
+message EltwiseParameter {
+  enum EltwiseOp {
+    PROD = 0;
+    SUM = 1;
+    MAX = 2;
+  }
+  optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation
+  repeated float coeff = 2; // blob-wise coefficient for SUM operation
+
+  // Whether to use an asymptotically slower (for >2 inputs) but stabler method
+  // of computing the gradient for the PROD operation. (No effect for SUM op.)
+  optional bool stable_prod_grad = 3 [default = true];
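+
+  // Illustrative prototxt usage (a sketch; names are example assumptions):
+  // element-wise sum of two branches, e.g. a residual connection.
+  //   layer {
+  //     name: "sum1"  type: "Eltwise"
+  //     bottom: "branch1"  bottom: "branch2"  top: "sum1"
+  //     eltwise_param { operation: SUM }
+  //   }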
+}
+
+// Message that stores parameters used by ELULayer
+message ELUParameter {
+  // Described in:
+  // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate
+  // Deep Network Learning by Exponential Linear Units (ELUs). arXiv
+  optional float alpha = 1 [default = 1];
+}
+
+// Message that stores parameters used by EmbedLayer
+message EmbedParameter {
+  optional uint32 num_output = 1; // The number of outputs for the layer
+  // The input is given as integers to be interpreted as one-hot
+  // vector indices with dimension input_dim.  Hence input_dim should be
+  // 1 greater than the maximum possible input value.
+  optional uint32 input_dim = 2;
+
+  optional bool bias_term = 3 [default = true]; // Whether to use a bias term
+  optional FillerParameter weight_filler = 4; // The filler for the weight
+  optional FillerParameter bias_filler = 5; // The filler for the bias
+
+}
+
+// Message that stores parameters used by ExpLayer
+message ExpParameter {
+  // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0.
+  // Or if base is set to the default (-1), base is set to e,
+  // so y = exp(shift + scale * x).
+  optional float base = 1 [default = -1.0];
+  optional float scale = 2 [default = 1.0];
+  optional float shift = 3 [default = 0.0];
+}
+
+/// Message that stores parameters used by FlattenLayer
+message FlattenParameter {
+  // The first axis to flatten: all preceding axes are retained in the output.
+  // May be negative to index from the end (e.g., -1 for the last axis).
+  optional int32 axis = 1 [default = 1];
+
+  // The last axis to flatten: all following axes are retained in the output.
+  // May be negative to index from the end (e.g., the default -1 for the last
+  // axis).
+  optional int32 end_axis = 2 [default = -1];
+}
+
+// Message that stores parameters used by HDF5DataLayer
+message HDF5DataParameter {
+  // Specify the data source.
+  optional string source = 1;
+  // Specify the batch size.
+  optional uint32 batch_size = 2;
+
+  // Specify whether to shuffle the data.
+  // If shuffle == true, the ordering of the HDF5 files is shuffled,
+  // and the ordering of data within any given HDF5 file is shuffled,
+  // but data between different files are not interleaved; all of a file's
+  // data are output (in a random order) before moving onto another file.
+  optional bool shuffle = 3 [default = false];
+}
+
+message HDF5OutputParameter {
+  optional string file_name = 1;
+}
+
+message HingeLossParameter {
+  enum Norm {
+    L1 = 1;
+    L2 = 2;
+  }
+  // Specify the Norm to use L1 or L2
+  optional Norm norm = 1 [default = L1];
+}
+
+message ImageDataParameter {
+  // Specify the data source.
+  optional string source = 1;
+  // Specify the batch size.
+  optional uint32 batch_size = 4 [default = 1];
+  // The rand_skip variable is for the data layer to skip a few data points
+  // so that asynchronous SGD clients do not all start at the same point. The skip
+  // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
+  // be larger than the number of keys in the database.
+  optional uint32 rand_skip = 7 [default = 0];
+  // Whether or not ImageLayer should shuffle the list of files at every epoch.
+  optional bool shuffle = 8 [default = false];
+  // It will also resize images if new_height or new_width are not zero.
+  optional uint32 new_height = 9 [default = 0];
+  optional uint32 new_width = 10 [default = 0];
+  // Specify if the images are color or gray
+  optional bool is_color = 11 [default = true];
+  // DEPRECATED. See TransformationParameter. For data pre-processing, we can do
+  // simple scaling and subtracting the data mean, if provided. Note that the
+  // mean subtraction is always carried out before scaling.
+  optional float scale = 2 [default = 1];
+  optional string mean_file = 3;
+  // DEPRECATED. See TransformationParameter. Specify if we would like to randomly
+  // crop an image.
+  optional uint32 crop_size = 5 [default = 0];
+  // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror
+  // data.
+  optional bool mirror = 6 [default = false];
+  optional string root_folder = 12 [default = ""];
+}
+
+message InfogainLossParameter {
+  // Specify the infogain matrix source.
+  optional string source = 1;
+}
+
+message InnerProductParameter {
+  optional uint32 num_output = 1; // The number of outputs for the layer
+  optional bool bias_term = 2 [default = true]; // whether to have bias terms
+  optional FillerParameter weight_filler = 3; // The filler for the weight
+  optional FillerParameter bias_filler = 4; // The filler for the bias
+
+  // The first axis to be lumped into a single inner product computation;
+  // all preceding axes are retained in the output.
+  // May be negative to index from the end (e.g., -1 for the last axis).
+  optional int32 axis = 5 [default = 1];
+  // Specify whether to transpose the weight matrix or not.
+  // If transpose == true, any operations will be performed on the transpose
+  // of the weight matrix. The weight matrix itself is not going to be transposed
+  // but rather the transfer flag of operations will be toggled accordingly.
+  optional bool transpose = 6 [default = false];
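+
+  // Illustrative prototxt usage (a sketch; names, sizes and fillers are
+  // example assumptions):
+  //   layer {
+  //     name: "fc6"  type: "InnerProduct"  bottom: "pool5"  top: "fc6"
+  //     inner_product_param {
+  //       num_output: 4096
+  //       weight_filler { type: "gaussian" std: 0.01 }
+  //       bias_filler { type: "constant" value: 0 }
+  //     }
+  //   }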
+}
+
+message InputParameter {
+  // This layer produces N >= 1 top blob(s) to be assigned manually.
+  // Define N shapes to set a shape for each top.
+  // Define 1 shape to set the same shape for every top.
+  // Define no shape to defer to reshaping manually.
+  repeated BlobShape shape = 1;
+}
+
+// Message that stores parameters used by LogLayer
+message LogParameter {
+  // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0.
+  // Or if base is set to the default (-1), base is set to e,
+  // so y = ln(shift + scale * x) = log_e(shift + scale * x)
+  optional float base = 1 [default = -1.0];
+  optional float scale = 2 [default = 1.0];
+  optional float shift = 3 [default = 0.0];
+}
+
+// Message that stores parameters used by LRNLayer
+message LRNParameter {
+  optional uint32 local_size = 1 [default = 5];
+  optional float alpha = 2 [default = 1.];
+  optional float beta = 3 [default = 0.75];
+  enum NormRegion {
+    ACROSS_CHANNELS = 0;
+    WITHIN_CHANNEL = 1;
+  }
+  optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS];
+  optional float k = 5 [default = 1.];
+  enum Engine {
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 6 [default = DEFAULT];
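+
+  // Illustrative prototxt usage (a sketch; names and values are example
+  // assumptions, in the style of AlexNet-like networks):
+  //   layer {
+  //     name: "norm1"  type: "LRN"  bottom: "conv1"  top: "norm1"
+  //     lrn_param { local_size: 5  alpha: 0.0001  beta: 0.75 }
+  //   }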
+}
+
+message MemoryDataParameter {
+  optional uint32 batch_size = 1;
+  optional uint32 channels = 2;
+  optional uint32 height = 3;
+  optional uint32 width = 4;
+}
+
+message MVNParameter {
+  // This parameter can be set to false to normalize mean only
+  optional bool normalize_variance = 1 [default = true];
+
+  // This parameter can be set to true to perform DNN-like MVN
+  optional bool across_channels = 2 [default = false];
+
+  // Epsilon for not dividing by zero while normalizing variance
+  optional float eps = 3 [default = 1e-9];
+}
+
+message ParameterParameter {
+  optional BlobShape shape = 1;
+}
+
+message PoolingParameter {
+  enum PoolMethod {
+    MAX = 0;
+    AVE = 1;
+    STOCHASTIC = 2;
+  }
+  optional PoolMethod pool = 1 [default = MAX]; // The pooling method
+  // Pad, kernel size, and stride are all given as a single value for equal
+  // dimensions in height and width or as Y, X pairs.
+  optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X)
+  optional uint32 pad_h = 9 [default = 0]; // The padding height
+  optional uint32 pad_w = 10 [default = 0]; // The padding width
+  optional uint32 kernel_size = 2; // The kernel size (square)
+  optional uint32 kernel_h = 5; // The kernel height
+  optional uint32 kernel_w = 6; // The kernel width
+  optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X)
+  optional uint32 stride_h = 7; // The stride height
+  optional uint32 stride_w = 8; // The stride width
+  enum Engine {
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 11 [default = DEFAULT];
+  // If global_pooling then it will pool over the size of the bottom by doing
+  // kernel_h = bottom->height and kernel_w = bottom->width
+  optional bool global_pooling = 12 [default = false];
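+
+  // Illustrative prototxt usage (a sketch; names and sizes are example
+  // assumptions): 2x2 max pooling with stride 2.
+  //   layer {
+  //     name: "pool1"  type: "Pooling"  bottom: "conv1"  top: "pool1"
+  //     pooling_param { pool: MAX  kernel_size: 2  stride: 2 }
+  //   }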
+}
+
+message PowerParameter {
+  // PowerLayer computes outputs y = (shift + scale * x) ^ power.
+  optional float power = 1 [default = 1.0];
+  optional float scale = 2 [default = 1.0];
+  optional float shift = 3 [default = 0.0];
+}
+
+message PythonParameter {
+  optional string module = 1;
+  optional string layer = 2;
+  // This value is set to the attribute `param_str` of the `PythonLayer` object
+  // in Python before calling the `setup()` method. This could be a number,
+  // string, dictionary in Python dict format, JSON, etc. You may parse this
+  // string in `setup` method and use it in `forward` and `backward`.
+  optional string param_str = 3 [default = ''];
+  // Whether this PythonLayer is shared among worker solvers during data parallelism.
+  // If true, each worker solver sequentially runs forward from this layer.
+  // This value should be set true if you are using it as a data layer.
+  optional bool share_in_parallel = 4 [default = false];
+}
+
+// Message that stores parameters used by RecurrentLayer
+message RecurrentParameter {
+  // The dimension of the output (and usually hidden state) representation --
+  // must be explicitly set to non-zero.
+  optional uint32 num_output = 1 [default = 0];
+
+  optional FillerParameter weight_filler = 2; // The filler for the weight
+  optional FillerParameter bias_filler = 3; // The filler for the bias
+
+  // Whether to enable displaying debug_info in the unrolled recurrent net.
+  optional bool debug_info = 4 [default = false];
+
+  // Whether to add as additional inputs (bottoms) the initial hidden state
+  // blobs, and add as additional outputs (tops) the final timestep hidden state
+  // blobs.  The number of additional bottom/top blobs required depends on the
+  // recurrent architecture -- e.g., 1 for RNNs, 2 for LSTMs.
+  optional bool expose_hidden = 5 [default = false];
+}
+
+// Message that stores parameters used by ReductionLayer
+message ReductionParameter {
+  enum ReductionOp {
+    SUM = 1;
+    ASUM = 2;
+    SUMSQ = 3;
+    MEAN = 4;
+  }
+
+  optional ReductionOp operation = 1 [default = SUM]; // reduction operation
+
+  // The first axis to reduce to a scalar -- may be negative to index from the
+  // end (e.g., -1 for the last axis).
+  // (Currently, only reduction along ALL "tail" axes is supported; reduction
+  // of axis M through N, where N < num_axes - 1, is unsupported.)
+  // Suppose we have an n-axis bottom Blob with shape:
+  //     (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)).
+  // If axis == m, the output Blob will have shape
+  //     (d0, d1, d2, ..., d(m-1)),
+  // and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1))
+  // times, each including (dm * d(m+1) * ... * d(n-1)) individual data.
+  // If axis == 0 (the default), the output Blob always has the empty shape
+  // (count 1), performing reduction across the entire input --
+  // often useful for creating new loss functions.
+  optional int32 axis = 2 [default = 0];
+
+  optional float coeff = 3 [default = 1.0]; // coefficient for output
+}
+
+// Message that stores parameters used by ReLULayer
+message ReLUParameter {
+  // Allow non-zero slope for negative inputs to speed up optimization
+  // Described in:
+  // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities
+  // improve neural network acoustic models. In ICML Workshop on Deep Learning
+  // for Audio, Speech, and Language Processing.
+  optional float negative_slope = 1 [default = 0];
+  enum Engine {
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 2 [default = DEFAULT];
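+
+  // Illustrative prototxt usage (a sketch; names and the slope are example
+  // assumptions): a "leaky" ReLU applied in place.
+  //   layer {
+  //     name: "relu1"  type: "ReLU"  bottom: "conv1"  top: "conv1"
+  //     relu_param { negative_slope: 0.1 }
+  //   }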
+}
+
+message ReshapeParameter {
+  // Specify the output dimensions. If some of the dimensions are set to 0,
+  // the corresponding dimension from the bottom layer is used (unchanged).
+  // Exactly one dimension may be set to -1, in which case its value is
+  // inferred from the count of the bottom blob and the remaining dimensions.
+  // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8:
+  //
+  //   layer {
+  //     type: "Reshape" bottom: "input" top: "output"
+  //     reshape_param { ... }
+  //   }
+  //
+  // If "input" is 2D with shape 2 x 8, then the following reshape_param
+  // specifications are all equivalent, producing a 3D blob "output" with shape
+  // 2 x 2 x 4:
+  //
+  //   reshape_param { shape { dim:  2  dim: 2  dim:  4 } }
+  //   reshape_param { shape { dim:  0  dim: 2  dim:  4 } }
+  //   reshape_param { shape { dim:  0  dim: 2  dim: -1 } }
+  //   reshape_param { shape { dim:  0  dim:-1  dim:  4 } }
+  //
+  optional BlobShape shape = 1;
+
+  // axis and num_axes control the portion of the bottom blob's shape that are
+  // replaced by (included in) the reshape. By default (axis == 0 and
+  // num_axes == -1), the entire bottom blob shape is included in the reshape,
+  // and hence the shape field must specify the entire output shape.
+  //
+  // axis may be non-zero to retain some portion of the beginning of the input
+  // shape (and may be negative to index from the end; e.g., -1 to begin the
+  // reshape after the last axis, including nothing in the reshape,
+  // -2 to include only the last axis, etc.).
+  //
+  // For example, suppose "input" is a 2D blob with shape 2 x 8.
+  // Then the following ReshapeLayer specifications are all equivalent,
+  // producing a blob "output" with shape 2 x 2 x 4:
+  //
+  //   reshape_param { shape { dim: 2  dim: 2  dim: 4 } }
+  //   reshape_param { shape { dim: 2  dim: 4 } axis:  1 }
+  //   reshape_param { shape { dim: 2  dim: 4 } axis: -3 }
+  //
+  // num_axes specifies the extent of the reshape.
+  // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on
+  // input axes in the range [axis, axis+num_axes].
+  // num_axes may also be -1, the default, to include all remaining axes
+  // (starting from axis).
+  //
+  // For example, suppose "input" is a 2D blob with shape 2 x 8.
+  // Then the following ReshapeLayer specifications are equivalent,
+  // producing a blob "output" with shape 1 x 2 x 8.
+  //
+  //   reshape_param { shape { dim:  1  dim: 2  dim:  8 } }
+  //   reshape_param { shape { dim:  1  dim: 2  }  num_axes: 1 }
+  //   reshape_param { shape { dim:  1  }  num_axes: 0 }
+  //
+  // On the other hand, these would produce output blob shape 2 x 1 x 8:
+  //
+  //   reshape_param { shape { dim: 2  dim: 1  dim: 8  }  }
+  //   reshape_param { shape { dim: 1 }  axis: 1  num_axes: 0 }
+  //
+  optional int32 axis = 2 [default = 0];
+  optional int32 num_axes = 3 [default = -1];
+}
+
+message ScaleParameter {
+  // The first axis of bottom[0] (the first input Blob) along which to apply
+  // bottom[1] (the second input Blob).  May be negative to index from the end
+  // (e.g., -1 for the last axis).
+  //
+  // For example, if bottom[0] is 4D with shape 100x3x40x60, the output
+  // top[0] will have the same shape, and bottom[1] may have any of the
+  // following shapes (for the given value of axis):
+  //    (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60
+  //    (axis == 1 == -3)          3;     3x40;     3x40x60
+  //    (axis == 2 == -2)                   40;       40x60
+  //    (axis == 3 == -1)                                60
+  // Furthermore, bottom[1] may have the empty shape (regardless of the value of
+  // "axis") -- a scalar multiplier.
+  optional int32 axis = 1 [default = 1];
+
+  // (num_axes is ignored unless just one bottom is given and the scale is
+  // a learned parameter of the layer.  Otherwise, num_axes is determined by the
+  // number of axes of the second bottom.)
+  // The number of axes of the input (bottom[0]) covered by the scale
+  // parameter, or -1 to cover all axes of bottom[0] starting from `axis`.
+  // Set num_axes := 0, to multiply with a zero-axis Blob: a scalar.
+  optional int32 num_axes = 2 [default = 1];
+
+  // (filler is ignored unless just one bottom is given and the scale is
+  // a learned parameter of the layer.)
+  // The initialization for the learned scale parameter.
+  // Default is the unit (1) initialization, resulting in the ScaleLayer
+  // initially performing the identity operation.
+  optional FillerParameter filler = 3;
+
+  // Whether to also learn a bias (equivalent to a ScaleLayer+BiasLayer, but
+  // may be more efficient).  Initialized with bias_filler (defaults to 0).
+  optional bool bias_term = 4 [default = false];
+  optional FillerParameter bias_filler = 5;
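+
+  // Illustrative prototxt usage (a sketch; names are example assumptions):
+  // a learned per-channel scale and bias, as commonly placed after BatchNorm.
+  //   layer {
+  //     name: "scale1"  type: "Scale"  bottom: "bn1"  top: "bn1"
+  //     scale_param { bias_term: true }
+  //   }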
+}
+
+message SigmoidParameter {
+  enum Engine {
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 1 [default = DEFAULT];
+}
+
+message SliceParameter {
+  // The axis along which to slice -- may be negative to index from the end
+  // (e.g., -1 for the last axis).
+  // By default, SliceLayer slices blobs along the "channels" axis (1).
+  optional int32 axis = 3 [default = 1];
+  repeated uint32 slice_point = 2;
+
+  // DEPRECATED: alias for "axis" -- does not support negative indexing.
+  optional uint32 slice_dim = 1 [default = 1];
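+
+  // Illustrative prototxt usage (a sketch; names and the slice point are
+  // example assumptions): split a 5-channel blob into 3- and 2-channel tops.
+  //   layer {
+  //     name: "slice1"  type: "Slice"
+  //     bottom: "data"  top: "part1"  top: "part2"
+  //     slice_param { axis: 1  slice_point: 3 }
+  //   }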
+}
+
+// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer
+message SoftmaxParameter {
+  enum Engine {
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 1 [default = DEFAULT];
+
+  // The axis along which to perform the softmax -- may be negative to index
+  // from the end (e.g., -1 for the last axis).
+  // Any other axes will be evaluated as independent softmaxes.
+  optional int32 axis = 2 [default = 1];
+}
+
+message TanHParameter {
+  enum Engine {
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 1 [default = DEFAULT];
+}
+
+// Message that stores parameters used by TileLayer
+message TileParameter {
+  // The index of the axis to tile.
+  optional int32 axis = 1 [default = 1];
+
+  // The number of copies (tiles) of the blob to output.
+  optional int32 tiles = 2;
+}
+
+// Message that stores parameters used by ThresholdLayer
+message ThresholdParameter {
+  optional float threshold = 1 [default = 0]; // Strictly positive values
+}
+
+message WindowDataParameter {
+  // Specify the data source.
+  optional string source = 1;
+  // For data pre-processing, we can do simple scaling and subtracting the
+  // data mean, if provided. Note that the mean subtraction is always carried
+  // out before scaling.
+  optional float scale = 2 [default = 1];
+  optional string mean_file = 3;
+  // Specify the batch size.
+  optional uint32 batch_size = 4;
+  // Specify if we would like to randomly crop an image.
+  optional uint32 crop_size = 5 [default = 0];
+  // Specify if we want to randomly mirror data.
+  optional bool mirror = 6 [default = false];
+  // Foreground (object) overlap threshold
+  optional float fg_threshold = 7 [default = 0.5];
+  // Background (non-object) overlap threshold
+  optional float bg_threshold = 8 [default = 0.5];
+  // Fraction of batch that should be foreground objects
+  optional float fg_fraction = 9 [default = 0.25];
+  // Amount of contextual padding to add around a window
+  // (used only by the window_data_layer)
+  optional uint32 context_pad = 10 [default = 0];
+  // Mode for cropping out a detection window
+  // warp: cropped window is warped to a fixed size and aspect ratio
+  // square: the tightest square around the window is cropped
+  optional string crop_mode = 11 [default = "warp"];
+  // cache_images: will load all images in memory for faster access
+  optional bool cache_images = 12 [default = false];
+  // append root_folder to locate images
+  optional string root_folder = 13 [default = ""];
+}
+
+message SPPParameter {
+  enum PoolMethod {
+    MAX = 0;
+    AVE = 1;
+    STOCHASTIC = 2;
+  }
+  optional uint32 pyramid_height = 1;
+  optional PoolMethod pool = 2 [default = MAX]; // The pooling method
+  enum Engine {
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 6 [default = DEFAULT];
+}
+
+// DEPRECATED: use LayerParameter.
+message V1LayerParameter {
+  repeated string bottom = 2;
+  repeated string top = 3;
+  optional string name = 4;
+  repeated NetStateRule include = 32;
+  repeated NetStateRule exclude = 33;
+  enum LayerType {
+    NONE = 0;
+    ABSVAL = 35;
+    ACCURACY = 1;
+    ARGMAX = 30;
+    BNLL = 2;
+    CONCAT = 3;
+    CONTRASTIVE_LOSS = 37;
+    CONVOLUTION = 4;
+    DATA = 5;
+    DECONVOLUTION = 39;
+    DROPOUT = 6;
+    DUMMY_DATA = 32;
+    EUCLIDEAN_LOSS = 7;
+    ELTWISE = 25;
+    EXP = 38;
+    FLATTEN = 8;
+    HDF5_DATA = 9;
+    HDF5_OUTPUT = 10;
+    HINGE_LOSS = 28;
+    IM2COL = 11;
+    IMAGE_DATA = 12;
+    INFOGAIN_LOSS = 13;
+    INNER_PRODUCT = 14;
+    LRN = 15;
+    MEMORY_DATA = 29;
+    MULTINOMIAL_LOGISTIC_LOSS = 16;
+    MVN = 34;
+    POOLING = 17;
+    POWER = 26;
+    RELU = 18;
+    SIGMOID = 19;
+    SIGMOID_CROSS_ENTROPY_LOSS = 27;
+    SILENCE = 36;
+    SOFTMAX = 20;
+    SOFTMAX_LOSS = 21;
+    SPLIT = 22;
+    SLICE = 33;
+    TANH = 23;
+    WINDOW_DATA = 24;
+    THRESHOLD = 31;
+  }
+  optional LayerType type = 5;
+  repeated BlobProto blobs = 6;
+  repeated string param = 1001;
+  repeated DimCheckMode blob_share_mode = 1002;
+  enum DimCheckMode {
+    STRICT = 0;
+    PERMISSIVE = 1;
+  }
+  repeated float blobs_lr = 7;
+  repeated float weight_decay = 8;
+  repeated float loss_weight = 35;
+  optional AccuracyParameter accuracy_param = 27;
+  optional ArgMaxParameter argmax_param = 23;
+  optional ConcatParameter concat_param = 9;
+  optional ContrastiveLossParameter contrastive_loss_param = 40;
+  optional ConvolutionParameter convolution_param = 10;
+  optional DataParameter data_param = 11;
+  optional DropoutParameter dropout_param = 12;
+  optional DummyDataParameter dummy_data_param = 26;
+  optional EltwiseParameter eltwise_param = 24;
+  optional ExpParameter exp_param = 41;
+  optional HDF5DataParameter hdf5_data_param = 13;
+  optional HDF5OutputParameter hdf5_output_param = 14;
+  optional HingeLossParameter hinge_loss_param = 29;
+  optional ImageDataParameter image_data_param = 15;
+  optional InfogainLossParameter infogain_loss_param = 16;
+  optional InnerProductParameter inner_product_param = 17;
+  optional LRNParameter lrn_param = 18;
+  optional MemoryDataParameter memory_data_param = 22;
+  optional MVNParameter mvn_param = 34;
+  optional PoolingParameter pooling_param = 19;
+  optional PowerParameter power_param = 21;
+  optional ReLUParameter relu_param = 30;
+  optional SigmoidParameter sigmoid_param = 38;
+  optional SoftmaxParameter softmax_param = 39;
+  optional SliceParameter slice_param = 31;
+  optional TanHParameter tanh_param = 37;
+  optional ThresholdParameter threshold_param = 25;
+  optional WindowDataParameter window_data_param = 20;
+  optional TransformationParameter transform_param = 36;
+  optional LossParameter loss_param = 42;
+  optional V0LayerParameter layer = 1;
+}
+
+// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters
+// in Caffe.  We keep this message type around for legacy support.
+message V0LayerParameter {
+  optional string name = 1; // the layer name
+  optional string type = 2; // the string to specify the layer type
+
+  // Parameters to specify layers with inner products.
+  optional uint32 num_output = 3; // The number of outputs for the layer
+  optional bool biasterm = 4 [default = true]; // whether to have bias terms
+  optional FillerParameter weight_filler = 5; // The filler for the weight
+  optional FillerParameter bias_filler = 6; // The filler for the bias
+
+  optional uint32 pad = 7 [default = 0]; // The padding size
+  optional uint32 kernelsize = 8; // The kernel size
+  optional uint32 group = 9 [default = 1]; // The group size for group conv
+  optional uint32 stride = 10 [default = 1]; // The stride
+  enum PoolMethod {
+    MAX = 0;
+    AVE = 1;
+    STOCHASTIC = 2;
+  }
+  optional PoolMethod pool = 11 [default = MAX]; // The pooling method
+  optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio
+
+  optional uint32 local_size = 13 [default = 5]; // for local response norm
+  optional float alpha = 14 [default = 1.]; // for local response norm
+  optional float beta = 15 [default = 0.75]; // for local response norm
+  optional float k = 22 [default = 1.];
+
+  // For data layers, specify the data source
+  optional string source = 16;
+  // For data pre-processing, we can do simple scaling and subtracting the
+  // data mean, if provided. Note that the mean subtraction is always carried
+  // out before scaling.
+  optional float scale = 17 [default = 1];
+  optional string meanfile = 18;
+  // For data layers, specify the batch size.
+  optional uint32 batchsize = 19;
+  // For data layers, specify if we would like to randomly crop an image.
+  optional uint32 cropsize = 20 [default = 0];
+  // For data layers, specify if we want to randomly mirror data.
+  optional bool mirror = 21 [default = false];
+
+  // The blobs containing the numeric parameters of the layer
+  repeated BlobProto blobs = 50;
+  // The ratio that is multiplied on the global learning rate. If you want to
+  // set the learning ratio for one blob, you need to set it for all blobs.
+  repeated float blobs_lr = 51;
+  // The weight decay that is multiplied on the global weight decay.
+  repeated float weight_decay = 52;
+
+  // The rand_skip variable is for the data layer to skip a few data points
+  // so that asynchronous SGD clients do not all start at the same point. The skip
+  // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
+  // be larger than the number of keys in the database.
+  optional uint32 rand_skip = 53 [default = 0];
+
+  // Fields related to detection (det_*)
+  // foreground (object) overlap threshold
+  optional float det_fg_threshold = 54 [default = 0.5];
+  // background (non-object) overlap threshold
+  optional float det_bg_threshold = 55 [default = 0.5];
+  // Fraction of batch that should be foreground objects
+  optional float det_fg_fraction = 56 [default = 0.25];
+
+  // optional bool OBSOLETE_can_clobber = 57 [default = true];
+
+  // Amount of contextual padding to add around a window
+  // (used only by the window_data_layer)
+  optional uint32 det_context_pad = 58 [default = 0];
+
+  // Mode for cropping out a detection window
+  // warp: cropped window is warped to a fixed size and aspect ratio
+  // square: the tightest square around the window is cropped
+  optional string det_crop_mode = 59 [default = "warp"];
+
+  // For ReshapeLayer, one needs to specify the new dimensions.
+  optional int32 new_num = 60 [default = 0];
+  optional int32 new_channels = 61 [default = 0];
+  optional int32 new_height = 62 [default = 0];
+  optional int32 new_width = 63 [default = 0];
+
+  // Whether or not ImageLayer should shuffle the list of files at every epoch.
+  // It will also resize images if new_height or new_width are not zero.
+  optional bool shuffle_images = 64 [default = false];
+
+  // For ConcatLayer, one needs to specify the dimension for concatenation, and
+  // the other dimensions must be the same for all the bottom blobs.
+  // By default it will concatenate blobs along the channels dimension.
+  optional uint32 concat_dim = 65 [default = 1];
+
+  optional HDF5OutputParameter hdf5_output_param = 1001;
+}
+
+message PReLUParameter {
+  // Parametric ReLU described in K. He et al, Delving Deep into Rectifiers:
+  // Surpassing Human-Level Performance on ImageNet Classification, 2015.
+
+  // Initial value of a_i. Default is a_i=0.25 for all i.
+  optional FillerParameter filler = 1;
+  // Whether or not slope parameters are shared across channels.
+  optional bool channel_shared = 2 [default = false];
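+
+  // Illustrative prototxt usage (a sketch; names and the initial slope are
+  // example assumptions):
+  //   layer {
+  //     name: "prelu1"  type: "PReLU"  bottom: "conv1"  top: "conv1"
+  //     prelu_param { channel_shared: false  filler { type: "constant" value: 0.25 } }
+  //   }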
+}
\ No newline at end of file