You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2023/12/29 00:17:13 UTC

(systemds) branch main updated: [SYSTEMDS-3660] GPU cache eviction operator and related rewrite This patch introduces a new operator, _evict, to clean up the free pointer cached in the lineage cache. A shift in the allocation pattern leads to large eviction overhead and memory fragmentation. To address that, we speculatively clear a fraction of the free pointers. Currently, we place a _evict before every mini-batch processing.

This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 5898293b8d [SYSTEMDS-3660] GPU cache eviction operator and related rewrite This patch introduces a new operator, _evict, to clean up the free pointer cached in the lineage cache. A shift in the allocation pattern leads to large eviction overhead and memory fragmentation. To address that, we speculatively clear a fraction of the free pointers. Currently, we place a _evict before every mini-batch processing.
5898293b8d is described below

commit 5898293b8db25b9b1784ff30ec732812f4402d54
Author: Arnab Phani <ph...@gmail.com>
AuthorDate: Fri Dec 29 01:16:36 2023 +0100

    [SYSTEMDS-3660] GPU cache eviction operator and related rewrite
    This patch introduces a new operator, _evict, to clean up the free
    pointers cached in the lineage cache. A shift in the allocation pattern
    leads to large eviction overhead and memory fragmentation. To address
    that, we speculatively clear a fraction of the free pointers. Currently,
    we place an _evict before every mini-batch processing loop.
    
    Closes #1964
---
 src/main/java/org/apache/sysds/common/Types.java   |   2 +-
 .../apache/sysds/conf/ConfigurationManager.java    |   4 +
 .../java/org/apache/sysds/hops/OptimizerUtils.java |   6 +
 src/main/java/org/apache/sysds/hops/UnaryOp.java   |   3 +
 .../org/apache/sysds/lops/rewrite/LopRewriter.java |   1 +
 .../sysds/lops/rewrite/RewriteAddGPUEvictLop.java  | 115 ++++
 .../runtime/instructions/CPInstructionParser.java  |   5 +
 .../runtime/instructions/cp/CPInstruction.java     |   1 +
 .../instructions/cp/EvictCPInstruction.java        |  49 ++
 .../runtime/lineage/LineageGPUCacheEviction.java   |   8 +-
 .../lineage/GPULineageCacheEvictionTest.java       |  16 +-
 .../functions/lineage/GPUCacheEviction6.dml        | 746 +++++++++++++++++++++
 12 files changed, 953 insertions(+), 3 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index 84019e8078..30cd6bf5bd 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -341,7 +341,7 @@ public class Types
 		CEIL, CHOLESKY, COS, COSH, CUMMAX, CUMMIN, CUMPROD, CUMSUM,
 		CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
 		IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
-		MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP,
+		MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, _EVICT,
 		SVD, TAN, TANH, TYPEOF, TRIGREMOTE,
 		//fused ML-specific operators for performance 
 		SPROP, //sample proportion: P * (1 - P)
diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index 1ac4d13974..8c7d5547f5 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -294,6 +294,13 @@ public class ConfigurationManager{
 			|| OptimizerUtils.RULE_BASED_GPU_EXEC));
 	}
 
+	/**
+	 * Returns whether _evict instructions for the GPU lineage cache are placed automatically.
+	 */
+	public static boolean isAutoEvictionEnabled() {
+		return OptimizerUtils.AUTO_GPU_CACHE_EVICTION;
+	}
+
 	public static ILinearize.DagLinearization getLinearizationOrder() {
 		if (OptimizerUtils.COST_BASED_ORDERING)
 			return ILinearize.DagLinearization.AUTO;
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index dc2bc487ed..8953cba378 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -310,6 +310,12 @@ public class OptimizerUtils
 	 */
 	public static boolean RULE_BASED_GPU_EXEC = false;
 
+	/**
+	 * Automatic placement of GPU lineage cache eviction (_evict instructions
+	 * before mini-batch loops); only effective with GPU and lineage reuse.
+	 */
+	public static boolean AUTO_GPU_CACHE_EVICTION = true;
+
 	//////////////////////
 	// Optimizer levels //
 	//////////////////////
diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index 5dbb55a303..d394beaf0e 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -145,6 +145,9 @@ public class UnaryOp extends MultiThreadedHop
 				case LOCAL:
 					ret = new Local(input.constructLops(), getDataType(), getValueType());
 					break;
+				case _EVICT:
+					ret = new UnaryCP(input.constructLops(), _op, getDataType(), getValueType());
+					break;
 				default:
 					final boolean isScalarIn = getInput().get(0).getDataType() == DataType.SCALAR;
 					if(getDataType() == DataType.SCALAR // value type casts or matrix to scalar
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
index 8d2c0a63f8..88c1787843 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
@@ -45,6 +45,7 @@ public class LopRewriter
 		_lopSBRuleSet.add(new RewriteAddBroadcastLop());
 		_lopSBRuleSet.add(new RewriteAddChkpointLop());
 		_lopSBRuleSet.add(new RewriteAddChkpointInLoop());
+		_lopSBRuleSet.add(new RewriteAddGPUEvictLop());
 		// TODO: A rewrite pass to remove less effective chkpoints
 		// Last rewrite to reset Lop IDs in a depth-first manner
 		_lopSBRuleSet.add(new RewriteFixIDs());
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddGPUEvictLop.java b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddGPUEvictLop.java
new file mode 100644
index 0000000000..8618e6a2eb
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddGPUEvictLop.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.UnaryOp;
+import org.apache.sysds.lops.BinaryScalar;
+import org.apache.sysds.lops.Data;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.OperatorOrderingUtils;
+import org.apache.sysds.lops.RightIndex;
+import org.apache.sysds.lops.UnaryCP;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.VariableSet;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Lop rewrite that inserts a statement block with a single _evict instruction
+ * in front of a for loop that processes mini-batches, in order to speculatively
+ * clear free pointers cached in the GPU lineage cache before the loop runs.
+ * Only active with auto eviction enabled, GPU execution, and lineage reuse.
+ */
+public class RewriteAddGPUEvictLop extends LopRewriteRule
+{
+	// Percentage of the cached free pointers evicted by the inserted _evict (100 = all)
+	private static final int EVICT_PERCENT = 100;
+
+	@Override
+	public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb) {
+		// TODO: Move this as a Statement block rewrite
+		// Note: singletonList instead of List.of, which throws on a null block
+		if (sb == null || !ConfigurationManager.isAutoEvictionEnabled()
+			|| !(sb instanceof ForStatementBlock)
+			|| !DMLScript.USE_ACCELERATOR || LineageCacheConfig.ReuseCacheType.isNone())
+			return Collections.singletonList(sb);
+
+		// Collect the LOPs of the first block of the loop body
+		List<StatementBlock> body = ((ForStatement) sb.getStatement(0)).getBody();
+		if (body == null || body.isEmpty())
+			return Collections.singletonList(sb);
+		ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(body.get(0));
+
+		// Only loops performing mini-batch processing get an _evict
+		if (!findMiniBatchSlicing(lops))
+			return Collections.singletonList(sb);
+
+		// Insert statement block with _evict instruction before the loop
+		ArrayList<StatementBlock> ret = new ArrayList<>();
+		ret.add(createEvictBlock(sb, EVICT_PERCENT));
+		ret.add(sb);
+		return ret;
+	}
+
+	@Override
+	public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
+		return sbs;
+	}
+
+	/**
+	 * Creates a statement block that holds a single _evict lop (plus the
+	 * matching hop, needed for recompilation) with a literal percentage input.
+	 */
+	private static StatementBlock createEvictBlock(StatementBlock sb, int evictPercent) {
+		StatementBlock sb0 = new StatementBlock();
+		sb0.setDMLProg(sb.getDMLProg());
+		sb0.setParseInfo(sb);
+		sb0.setLiveIn(new VariableSet());
+		sb0.setLiveOut(new VariableSet());
+		// TODO: Add another input for the backend (GPU/CPU/Spark)
+		Lop fr = Data.createLiteralLop(Types.ValueType.INT64, Integer.toString(evictPercent));
+		fr.getOutputParameters().setDimensions(0, 0, 0, -1);
+		UnaryCP evict = new UnaryCP(fr, Types.OpOp1._EVICT, fr.getDataType(), fr.getValueType(), Types.ExecType.CP);
+		Hop evictHop = new UnaryOp("tmp", Types.DataType.SCALAR, Types.ValueType.INT64,
+			Types.OpOp1._EVICT, new LiteralOp(evictPercent));
+		ArrayList<Lop> newlops = new ArrayList<>();
+		newlops.add(evict);
+		ArrayList<Hop> newhops = new ArrayList<>();
+		newhops.add(evictHop);
+		sb0.setLops(newlops);
+		sb0.setHops(newhops);
+		return sb0;
+	}
+
+	// To verify mini-batch processing, match the below pattern
+	// beg = ((i-1) * batch_size) %% N + 1;
+	// end = min(N, beg+batch_size-1);
+	// X_batch = X[beg:end];
+	private boolean findMiniBatchSlicing(ArrayList<Lop> lops) {
+		if (lops == null) // defensive: body block without lops
+			return false;
+		for (Lop l : lops) {
+			if (!(l instanceof RightIndex))
+				continue;
+			ArrayList<Lop> inputs = l.getInputs();
+			if (inputs.get(0) instanceof Data && ((Data) inputs.get(0)).isTransientRead()
+				&& inputs.get(0).getInputs().size() == 0		//input1 is the dataset
+				&& inputs.get(1) instanceof BinaryScalar		//input2 is beg
+				&& ((BinaryScalar) inputs.get(1)).getOperationType() == Types.OpOp2.PLUS
+				&& inputs.get(2) instanceof BinaryScalar		//input3 is end
+				&& ((BinaryScalar) inputs.get(2)).getOperationType() == Types.OpOp2.MIN)
+				return true;
+		}
+		return false;
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 3de9fcd65d..c73d755b5e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -50,6 +50,7 @@ import org.apache.sysds.runtime.instructions.cp.CtableCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.DeCompressionCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.DnnCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.EvictCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.LocalCPInstruction;
@@ -337,6 +338,7 @@ public class CPInstructionParser extends InstructionParser {
 		String2CPInstructionType.put( DeCompression.OPCODE, CPType.DeCompression);
 		String2CPInstructionType.put( "spoof",     CPType.SpoofFused);
 		String2CPInstructionType.put( "prefetch",  CPType.Prefetch);
+		String2CPInstructionType.put( "_evict",  CPType.EvictLineageCache);
 		String2CPInstructionType.put( "broadcast",  CPType.Broadcast);
 		String2CPInstructionType.put( "trigremote",  CPType.TrigRemote);
 		String2CPInstructionType.put( Local.OPCODE, CPType.Local);
@@ -483,6 +485,9 @@ public class CPInstructionParser extends InstructionParser {
 				
 			case Broadcast:
 				return BroadcastCPInstruction.parseInstruction(str);
+
+			case EvictLineageCache:
+				return EvictCPInstruction.parseInstruction(str);
 			
 			default:
 				throw new DMLRuntimeException("Invalid CP Instruction Type: " + cptype );
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
index 3503b256f7..1398d4365b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
@@ -45,6 +45,7 @@ public abstract class CPInstruction extends Instruction {
 		Builtin, Reorg, Variable, FCall, Append, Rand, QSort, QPick, Local,
 		MatrixIndexing, MMTSJ, PMMJ, MMChain, Reshape, Partition, Compression, DeCompression, SpoofFused,
 		StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, Sql, Prefetch, Broadcast, TrigRemote,
+		EvictLineageCache,
 		NoOp,
 	 }
 
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvictCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvictCPInstruction.java
new file mode 100644
index 0000000000..d958f6e1ed
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvictCPInstruction.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.lineage.LineageGPUCacheEviction;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class EvictCPInstruction extends UnaryCPInstruction
+{
+	private EvictCPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
+		super(CPType.EvictLineageCache, op, in, out, opcode, istr);
+	}
+
+	/** Parses an _evict instruction string into its input and output operands. */
+	public static EvictCPInstruction parseInstruction(String str) {
+		InstructionUtils.checkNumFields(str, 3);
+		String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
+		CPOperand input = new CPOperand(fields[1]);
+		CPOperand output = new CPOperand(fields[2]);
+		return new EvictCPInstruction(null, input, output, fields[0], str);
+	}
+
+	/** Evicts a fraction of the free pointers cached in the GPU lineage cache. */
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		// The scalar input is a percentage (e.g., 100 = evict all)
+		long percent = ec.getScalarInput(input1).getLongValue();
+		LineageGPUCacheEviction.removeAllEntries(percent / 100.0);
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageGPUCacheEviction.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageGPUCacheEviction.java
index 7eac5e4a54..5497210999 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageGPUCacheEviction.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageGPUCacheEviction.java
@@ -131,10 +131,17 @@ public class LineageGPUCacheEviction
 		}
 	}
 
-	public static void removeAllEntries() {
+	// Speculative eviction: free (at most) the given fraction of the cached
+	// free pointers of every size, e.g., 1.0 evicts all of them
+	public static void removeAllEntries(double evictFrac) {
 		List<Long> sizes = new ArrayList<>(freeQueues.keySet());
 		for (Long size : sizes) {
 			TreeSet<LineageCacheEntry> freeList = freeQueues.get(size);
+			// Number of entries to free for this size; nothing to do if zero
+			int evictLim = (int) (freeList.size() * evictFrac);
+			if (evictLim <= 0)
+				continue;
+			int evictCount = 0;
 			LineageCacheEntry le = pollFirstFreeEntry(size);
 			while (le != null) {
 				// Free the pointer
@@ -142,6 +149,10 @@ public class LineageGPUCacheEviction
 				if (DMLScript.STATISTICS)
 					LineageCacheStatistics.incrementGpuDel();
+				// Check the limit before polling: a polled entry that is not freed
+				// would be dropped from the free queue without releasing its pointer
+				if (++evictCount >= evictLim)
+					break;
 				le = pollFirstFreeEntry(size);
 			}
 		}
 	}
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/GPULineageCacheEvictionTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/GPULineageCacheEvictionTest.java
index cc59da7de9..0a12ffb9af 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/GPULineageCacheEvictionTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/GPULineageCacheEvictionTest.java
@@ -29,6 +29,8 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
 import org.junit.Assume;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -39,7 +41,7 @@ public class GPULineageCacheEvictionTest extends AutomatedTestBase{
 	
 	protected static final String TEST_DIR = "functions/lineage/";
 	protected static final String TEST_NAME = "GPUCacheEviction";
-	protected static final int TEST_VARIANTS = 5;
+	protected static final int TEST_VARIANTS = 6;
 	protected String TEST_CLASS_DIR = TEST_DIR + GPULineageCacheEvictionTest.class.getSimpleName() + "/";
 	
 	@BeforeClass
@@ -80,6 +82,11 @@ public class GPULineageCacheEvictionTest extends AutomatedTestBase{
 		testLineageTraceExec(TEST_NAME+"5");
 	}
 
+	@Test
+	public void TransferLearning3Models() {  //transfer learning and reuse (AlexNet,VGG,ResNet)
+		testLineageTraceExec(TEST_NAME+"6");
+	}
+

 	private void testLineageTraceExec(String testname) {
 		System.out.println("------------ BEGIN " + testname + "------------");
@@ -117,6 +124,13 @@ public class GPULineageCacheEvictionTest extends AutomatedTestBase{

 		//compare results 
 		TestUtils.compareMatrices(R_orig, R_reused, 1e-6, "Origin", "Reused");
+
+		//Match _evict count (presumably one per model's mini-batch loop)
+		if (testname.equalsIgnoreCase(TEST_NAME+"6")) {
+			long exp_numev = 3;
+			long numev = Statistics.getCPHeavyHitterCount("_evict");
+			Assert.assertEquals("Violated _evict instruction count", exp_numev, numev);
+		}
 	}
 }
 
diff --git a/src/test/scripts/functions/lineage/GPUCacheEviction6.dml b/src/test/scripts/functions/lineage/GPUCacheEviction6.dml
new file mode 100644
index 0000000000..fdf4285610
--- /dev/null
+++ b/src/test/scripts/functions/lineage/GPUCacheEviction6.dml
@@ -0,0 +1,746 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+conv2d_forward = function(matrix[double] X, matrix[double] W, matrix[double] b,  # 2D conv of N=nrow(X) images with F=nrow(W) filters, plus bias
+  int C, int Hin, int Win, int Hf, int Wf, int strideh, int stridew,
+  int padh, int padw) return (matrix[double] out, int Hout, int Wout)
+{
+  N = nrow(X)
+  F = nrow(W)
+  Hout = as.integer(floor((Hin + 2*padh - Hf)/strideh + 1))  # output height
+  Wout = as.integer(floor((Win + 2*padw - Wf)/stridew + 1))  # output width
+  # Convolution - built-in implementation
+  out = conv2d(X, W, input_shape=[N,C,Hin,Win], filter_shape=[F,C,Hf,Wf],
+               stride=[strideh,stridew], padding=[padh,padw])
+  # Add bias term to each output filter
+  out = bias_add(out, b)
+}
+
+conv2d_backward = function(matrix[double] dout, int Hout, int Wout, matrix[double] X,  # gradients of conv2d_forward w.r.t. X, W and b
+  matrix[double] W, matrix[double] b, int C, int Hin, int Win, int Hf, int Wf,
+  int strideh, int stridew, int padh, int padw)
+  return (matrix[double] dX, matrix[double] dW, matrix[double] db)
+{
+  N = nrow(X)
+  F = nrow(W)
+  # Partial derivatives for convolution - built-in implementation
+  dW = conv2d_backward_filter(X, dout, stride=[strideh,stridew], padding=[padh,padw],
+                              input_shape=[N,C,Hin,Win], filter_shape=[F,C,Hf,Wf])
+  dX = conv2d_backward_data(W, dout, stride=[strideh,stridew], padding=[padh,padw],
+                            input_shape=[N,C,Hin,Win], filter_shape=[F,C,Hf,Wf])
+  # Partial derivatives for bias vector
+  # Here we sum each column, reshape to (F, Hout*Wout), and sum each row
+  # to result in the summation for each channel.
+  db = rowSums(matrix(colSums(dout), rows=F, cols=Hout*Wout))  # shape (F, 1)
+}
+
+conv2d_init = function(int F, int C, int Hf, int Wf, int seed = -1)  # normal random weights scaled by sqrt(2/(C*Hf*Wf)), zero bias
+  return (matrix[double] W, matrix[double] b) {
+  W = rand(rows=F, cols=C*Hf*Wf, pdf="normal", seed=seed) * sqrt(2.0/(C*Hf*Wf))
+  b = matrix(0, rows=F, cols=1)
+}
+
+bn2d_forward = function(matrix[double] X, int C, int Hin, int Win, 
+    double mu, double epsilon) return (matrix[double] out)  # batch_norm2d in 'train' mode with gamma=1, beta=0; only out is kept
+{
+    gamma = matrix(1, rows=C, cols=1)
+    beta = matrix(0, rows=C, cols=1)
+    ema_mean = matrix(0, rows=C, cols=1)
+    ema_var = matrix(1, rows=C, cols=1)
+    ema_mean_upd = ema_mean; 
+    ema_var_upd = ema_var;  
+    cache_mean = ema_mean; 
+    cache_inv_var = ema_var
+    mode = 'train';
+    [out, ema_mean_upd, ema_var_upd, cache_mean, cache_inv_var] = batch_norm2d(X, gamma, beta, ema_mean, ema_var, mode, epsilon, mu)
+}
+
+affine_forward = function(matrix[double] X, matrix[double] W, matrix[double] b) return (matrix[double] out) {  # fully connected layer
+  out = X %*% W + b;
+}
+
+affine_init = function(int D, int M, int seed = -1 ) return (matrix[double] W, matrix[double] b) {  # normal random weights scaled by sqrt(2/D), zero bias
+  W = rand(rows=D, cols=M, pdf="normal", seed=seed) * sqrt(2.0/D);
+  b = matrix(0, rows=1, cols=M);
+}
+
+relu_forward = function(matrix[double] X) return (matrix[double] out) {  # elementwise max(0, X)
+  out = max(0, X);
+}
+
+max_pool2d_forward = function(matrix[double] X, int C, int Hin, int Win, int Hf, int Wf,  # 2D max pooling; also returns output height/width
+  int strideh, int stridew, int padh, int padw) return(matrix[double] out, int Hout, int Wout)
+{
+  N = nrow(X)
+  Hout = as.integer(floor((Hin + 2*padh - Hf)/strideh + 1))
+  Wout = as.integer(floor((Win + 2*padw - Wf)/stridew + 1))
+  out = max_pool(X, input_shape=[N,C,Hin,Win], pool_size=[Hf,Wf],
+    stride=[strideh,stridew], padding=[padh,padw])
+}
+
+avg_pool2d_forward = function(matrix[double] X, int C, int Hin, int Win)  # global average pooling: window covers the whole Hin x Win plane
+  return (matrix[double] out, int Hout, int Wout) {
+  N = nrow(X)
+  Hout = 1
+  Wout = 1
+  out = avg_pool(X, input_shape=[N,C,Hin,Win], pool_size=[Hin,Win], stride=[1,1], padding=[0, 0])
+}
+
+softmax_forward = function(matrix[double] scores) return (matrix[double] probs) {  # row-wise softmax
+  scores = scores - rowMaxs(scores);  # numerical stability
+  unnorm_probs = exp(scores);  # unnormalized probabilities
+  probs = unnorm_probs / rowSums(unnorm_probs);  # normalized probabilities
+}
+
+basic_block = function(matrix[double] X, int C, int C_base, int Hin, int Win, int strideh,  # ResNet basic block with optional 1x1 downsampling shortcut
+    int stridew, matrix[double] WC1, matrix[double] bC1, matrix[double] WC2, matrix[double] bC2)
+  return (matrix[double] out, int Hout, int Wout)
+{
+  mu_bn = 0.1;
+  ep_bn = 1e-05;
+  downsample = strideh > 1 | stridew > 1 | C != C_base;
+  if (downsample) {
+    [WC3, bC3] = conv2d_init(C_base, C, Hf=1, Wf=1, 42);
+  }
+  # Residual Path
+  # conv1 -> bn1 -> relu1
+  [out, Hout, Wout] = conv2d_forward(X,WC1,bC1,C,Hin,Win,3,3,strideh,stridew,1,1);
+  out = bn2d_forward(out,C_base,Hout,Wout,mu_bn,ep_bn);
+  out = relu_forward(out);
+  # conv2 -> bn2 -> relu2
+  [out, Hout, Wout] = conv2d_forward(out,WC2,bC2,C_base,Hout,Wout,3,3,1,1,1,1);
+  out = bn2d_forward(out,C_base,Hout,Wout,mu_bn,ep_bn);
+  # Identity Path
+  identity = X;
+  if (downsample) {
+    # Downsample input: conv 1x1 -> bn on the shortcut (must not overwrite the residual path)
+    [identity, Hout, Wout] = conv2d_forward(X,WC3,bC3,C,Hin,Win,1,1,strideh,stridew,0,0);
+    identity = bn2d_forward(identity,C_base,Hout,Wout,mu_bn,ep_bn);
+  }
+  out = relu_forward(out + identity);
+}
+
+getWeights = function(int fel, int lid,  # pretrained weights for layers lid < fel, initialized ones otherwise
+    matrix[double] W_pt, matrix[double] b_pt,
+    matrix[double] W_init, matrix[double] b_init)
+  return (matrix[double] Wl, matrix[double] bl)
+{
+  if (lid < fel) { #extract pretrained features
+    Wl = W_pt;
+    bl = b_pt;
+  }
+  else {  #use initialized weights
+    Wl = W_init;
+    bl = b_init;
+  }
+}
+
+rwRowIndexMax = function(matrix[double] X, matrix[double] oneVec, matrix[double] idxSeq)
+    return (matrix[double] index) {  # row-wise argmax; assumes idxSeq rows are 1..K (ties -> largest index)
+  rm = rowMaxs(X) %*% oneVec;
+  I = X == rm;
+  index = rowMaxs(I * idxSeq);
+}
+
+####################################################################
+
+# Exploratory feature extraction from pre-trained resnet18 model: predicts with 9, 8 and 7 transferred layers (Y_pred columns 1-3)
+predict_resnet18 = function(matrix[double] X, int C, int Hin, int Win, int K)
+  return (matrix[double] Y_pred)
+{
+  mu_bn = 0.1;
+  ep_bn = 1e-05;
+
+  # Get the transferred layers. FIXME: use pretrained weights
+  [W1_pt, b1_pt] = conv2d_init(64, C, Hf=7, Wf=7, 42);
+  [W2_pt, b2_pt] = conv2d_init(64, 64, Hf=3, Wf=3, 42);
+  [W3_pt, b3_pt] = conv2d_init(64, 64, Hf=3, Wf=3, 42);
+  [W4_pt, b4_pt] = conv2d_init(64, 64, Hf=3, Wf=3, 42);
+  [W5_pt, b5_pt] = conv2d_init(64, 64, Hf=3, Wf=3, 42);
+  [W6_pt, b6_pt] = conv2d_init(128, 64, Hf=3, Wf=3, 42);
+  [W7_pt, b7_pt] = conv2d_init(128, 128, Hf=3, Wf=3, 42);
+  [W8_pt, b8_pt] = conv2d_init(128, 128, Hf=3, Wf=3, 42);
+  [W9_pt, b9_pt] = conv2d_init(128, 128, Hf=3, Wf=3, 42);
+  [W10_pt, b10_pt] = conv2d_init(256, 128, Hf=3, Wf=3, 42);
+  [W11_pt, b11_pt] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W12_pt, b12_pt] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W13_pt, b13_pt] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W14_pt, b14_pt] = conv2d_init(512, 256, Hf=3, Wf=3, 42);
+  [W15_pt, b15_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W16_pt, b16_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W17_pt, b17_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W18_pt, b18_pt] = affine_init(512, K, 42);
+  W18_pt = W18_pt/sqrt(2);
+
+  # Initialize the weights for the non-transferred layers (NOTE(review): W10-W18 use seed 42 like the *_pt weights and thus equal them - confirm intended)
+  [W1_init, b1_init] = conv2d_init(64, C, Hf=7, Wf=7, 43);
+  [W2_init, b2_init] = conv2d_init(64, 64, Hf=3, Wf=3, 43);
+  [W3_init, b3_init] = conv2d_init(64, 64, Hf=3, Wf=3, 43);
+  [W4_init, b4_init] = conv2d_init(64, 64, Hf=3, Wf=3, 43);
+  [W5_init, b5_init] = conv2d_init(64, 64, Hf=3, Wf=3, 43);
+  [W6_init, b6_init] = conv2d_init(128, 64, Hf=3, Wf=3, 43);
+  [W7_init, b7_init] = conv2d_init(128, 128, Hf=3, Wf=3, 43);
+  [W8_init, b8_init] = conv2d_init(128, 128, Hf=3, Wf=3, 43);
+  [W9_init, b9_init] = conv2d_init(128, 128, Hf=3, Wf=3, 43);
+  [W10_init, b10_init] = conv2d_init(256, 128, Hf=3, Wf=3, 42);
+  [W11_init, b11_init] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W12_init, b12_init] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W13_init, b13_init] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W14_init, b14_init] = conv2d_init(512, 256, Hf=3, Wf=3, 42);
+  [W15_init, b15_init] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W16_init, b16_init] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W17_init, b17_init] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W18_init, b18_init] = affine_init(512, K, 42);
+  W18_init = W18_init/sqrt(2);
+
+  # Compute prediction over mini-batches
+  N = nrow(X);
+  Y_pred = matrix(0, rows=N, cols=3);
+  batch_size = 64;
+  oneVec = matrix(1, rows=1, cols=K);
+  idxSeq = matrix(1, rows=batch_size, cols=1) %*% t(seq(1, K));  # NOTE(review): batch_size rows; a smaller last batch (N not a multiple of batch_size) likely fails in rwRowIndexMax - confirm
+  iters = ceil (N / batch_size);
+
+  for (i in 1:iters) {
+    # Get next batch
+    beg = ((i-1) * batch_size) %% N + 1;
+    end = min(N, beg+batch_size-1);
+    X_batch = X[beg:end,];
+
+    # Evaluate 3 configurations of transferred layers
+    j = 1;
+    fel = 10; #extract 9, 8, 7 (fel = 10, 9, 8) 
+    while (j < 4) {
+      # Compute forward pass
+      # Layer1: conv2d 7x7 -> bn -> relu -> maxpool 3x3
+      lid = 1;
+      [Wl1, bl1] = getWeights(fel, lid, W1_pt, b1_pt, W1_init, b1_init);
+      [outc1, Houtc1, Woutc1] = conv2d_forward(X_batch,Wl1,bl1,C,Hin,Win,7,7,2,2,3,3);
+      outb1 = bn2d_forward(outc1,64,Houtc1,Woutc1,mu_bn,ep_bn);
+      outr1 = relu_forward(outb1);
+      [outp1, Houtp1, Woutp1] = max_pool2d_forward(outr1,64,Houtc1, Woutc1,3,3,2,2,1,1);
+
+      # Layer2: residual block1
+      lid = 2;
+      [Wc1, bc1] = getWeights(fel, lid, W2_pt, b2_pt, W2_init, b2_init);
+      [Wc2, bc2] = getWeights(fel, lid, W3_pt, b3_pt, W3_init, b3_init);
+      [outrb1, Houtrb1, Woutrb1] = basic_block(outp1,64,64,Houtp1,Woutp1,1,1,Wc1,bc1,Wc2,bc2);
+
+      # Layer3: residual block2
+      lid = 3;
+      [Wc1, bc1] = getWeights(fel, lid, W4_pt, b4_pt, W4_init, b4_init);
+      [Wc2, bc2] = getWeights(fel, lid, W5_pt, b5_pt, W5_init, b5_init);
+      [outrb2, Houtrb2, Woutrb2] = basic_block(outrb1,64,64,Houtrb1,Woutrb1,1,1,Wc1,bc1,Wc2,bc2);
+
+      # Layer4: residual block3
+      lid = 4;
+      [Wc1, bc1] = getWeights(fel, lid, W6_pt, b6_pt, W6_init, b6_init);
+      [Wc2, bc2] = getWeights(fel, lid, W7_pt, b7_pt, W7_init, b7_init);
+      [outrb3, Houtrb3, Woutrb3] = basic_block(outrb2,64,128,Houtrb2,Woutrb2,2,2,Wc1,bc1,Wc2,bc2);
+
+      # Layer5: residual block4
+      lid = 5;
+      [Wc1, bc1] = getWeights(fel, lid, W8_pt, b8_pt, W8_init, b8_init);
+      [Wc2, bc2] = getWeights(fel, lid, W9_pt, b9_pt, W9_init, b9_init);
+      [outrb4, Houtrb4, Woutrb4] = basic_block(outrb3,128,128,Houtrb3,Woutrb3,1,1,Wc1,bc1,Wc2,bc2);
+
+      # Layer6: residual block5
+      lid = 6;
+      [Wc1, bc1] = getWeights(fel, lid, W10_pt, b10_pt, W10_init, b10_init);
+      [Wc2, bc2] = getWeights(fel, lid, W11_pt, b11_pt, W11_init, b11_init);
+      [outrb5, Houtrb5, Woutrb5] = basic_block(outrb4,128,256,Houtrb4,Woutrb4,2,2,Wc1,bc1,Wc2,bc2);
+
+      # Layer7: residual block6
+      lid = 7;
+      [Wc1, bc1] = getWeights(fel, lid, W12_pt, b12_pt, W12_init, b12_init);
+      [Wc2, bc2] = getWeights(fel, lid, W13_pt, b13_pt, W13_init, b13_init);
+      [outrb6, Houtrb6, Woutrb6] = basic_block(outrb5,256,256,Houtrb5,Woutrb5,1,1,Wc1,bc1,Wc2,bc2);
+
+      # Layer8: residual block7
+      lid = 8;
+      [Wc1, bc1] = getWeights(fel, lid, W14_pt, b14_pt, W14_init, b14_init);
+      [Wc2, bc2] = getWeights(fel, lid, W15_pt, b15_pt, W15_init, b15_init);
+      [outrb7, Houtrb7, Woutrb7] = basic_block(outrb6,256,512,Houtrb6,Woutrb6,2,2,Wc1,bc1,Wc2,bc2);
+
+      # Layer9: residual block8
+      lid = 9;
+      [Wc1, bc1] = getWeights(fel, lid, W16_pt, b16_pt, W16_init, b16_init);
+      [Wc2, bc2] = getWeights(fel, lid, W17_pt, b17_pt, W17_init, b17_init);
+      [outrb8, Houtrb8, Woutrb8] = basic_block(outrb7,512,512,Houtrb7,Woutrb7,1,1,Wc1,bc1,Wc2,bc2);
+
+      # Global average pooling 
+      [outap1, Houtap1, Houtap2] = avg_pool2d_forward(outrb8, 512, Houtrb8, Woutrb8);
+
+      # layer10 : Fully connected layer
+      lid = 10;
+      [Wl10, bl10] = getWeights(fel, lid, W18_pt, b18_pt, W18_init, b18_init);
+      outa1 = affine_forward(outap1, Wl10, bl10);
+      probs_batch = softmax_forward(outa1);
+
+      # Store the predictions
+      Y_pred[beg:end,j] = rwRowIndexMax(probs_batch, oneVec, idxSeq);
+      j = j + 1;
+      fel = fel - 1;
+    }
+  }
+}
+
+
+# Exploratory feature extraction with a VGG16-style network (13 conv2d layers
+# in 5 groups, each group followed by a 2x2 max-pool, then 3 affine layers).
+# Two randomly initialized weight sets are built: "_pt" (seed 42, a stand-in for
+# pretrained weights per the FIXME below) and "_init" (seed 43). For every layer,
+# getWeights (defined elsewhere in this script) chooses between the two sets based
+# on the feature-extraction level `fel` and the layer id `lid`.
+#
+# Inputs:
+#   X   - input images, one row per image (C*Hin*Win columns as consumed by
+#         conv2d_forward)
+#   C   - number of input channels
+#   Hin - input image height
+#   Win - input image width
+#   K   - number of output classes
+#   dim - input resolution; selects the input width of the first affine layer,
+#         which consumes the flattened output of the fifth max-pool:
+#         224 -> 25088 (= 512*7*7), 32 -> 512 (= 512*1*1). Any other value leaves
+#         W14_pt/W14_init undefined.
+# Output:
+#   Y_pred - N x 3 matrix; column j holds the rwRowIndexMax of the softmax output
+#            for the j-th forward pass (fel = 8, 7, 6)
+predict_vgg = function(matrix[double] X, int C, int Hin, int Win, int K, int dim)
+  return (matrix[double] Y_pred)
+{
+  # Get the transferred layers. FIXME: use pretrained weights
+  [W1_pt, b1_pt] = conv2d_init(64, C, Hf=3, Wf=3, 42);
+  [W2_pt, b2_pt] = conv2d_init(64, 64, Hf=3, Wf=3, 42);
+  [W3_pt, b3_pt] = conv2d_init(128, 64, Hf=3, Wf=3, 42);
+  [W4_pt, b4_pt] = conv2d_init(128, 128, Hf=3, Wf=3, 42);
+  [W5_pt, b5_pt] = conv2d_init(256, 128, Hf=3, Wf=3, 42);
+  [W6_pt, b6_pt] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W7_pt, b7_pt] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W8_pt, b8_pt] = conv2d_init(512, 256, Hf=3, Wf=3, 42);
+  [W9_pt, b9_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W10_pt, b10_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W11_pt, b11_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W12_pt, b12_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W13_pt, b13_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  # First affine layer: input width depends on the input resolution (see header)
+  if (dim == 224)
+    [W14_pt, b14_pt] = affine_init(25088, 4096, 42);
+  if (dim == 32)
+    [W14_pt, b14_pt] = affine_init(512, 4096, 42);
+  [W15_pt, b15_pt] = affine_init(4096, 4096, 42);
+  [W16_pt, b16_pt] = affine_init(4096, K, 42);
+  # Scale the output-layer weights by 1/sqrt(2)
+  W16_pt = W16_pt/sqrt(2);
+
+  # Initialize the weights for the non-transferred layers (same shapes, seed 43)
+  [W1_init, b1_init] = conv2d_init(64, C, Hf=3, Wf=3, 43);
+  [W2_init, b2_init] = conv2d_init(64, 64, Hf=3, Wf=3, 43);
+  [W3_init, b3_init] = conv2d_init(128, 64, Hf=3, Wf=3, 43);
+  [W4_init, b4_init] = conv2d_init(128, 128, Hf=3, Wf=3, 43);
+  [W5_init, b5_init] = conv2d_init(256, 128, Hf=3, Wf=3, 43);
+  [W6_init, b6_init] = conv2d_init(256, 256, Hf=3, Wf=3, 43);
+  [W7_init, b7_init] = conv2d_init(256, 256, Hf=3, Wf=3, 43);
+  [W8_init, b8_init] = conv2d_init(512, 256, Hf=3, Wf=3, 43);
+  [W9_init, b9_init] = conv2d_init(512, 512, Hf=3, Wf=3, 43);
+  [W10_init, b10_init] = conv2d_init(512, 512, Hf=3, Wf=3, 43);
+  [W11_init, b11_init] = conv2d_init(512, 512, Hf=3, Wf=3, 43);
+  [W12_init, b12_init] = conv2d_init(512, 512, Hf=3, Wf=3, 43);
+  [W13_init, b13_init] = conv2d_init(512, 512, Hf=3, Wf=3, 43);
+  if (dim == 224)
+    [W14_init, b14_init] = affine_init(25088, 4096, 43);
+  if (dim == 32)
+    [W14_init, b14_init] = affine_init(512, 4096, 43);
+  [W15_init, b15_init] = affine_init(4096, 4096, 43);
+  [W16_init, b16_init] = affine_init(4096, K, 43);
+  W16_init = W16_init/sqrt(2);
+
+  # Compute prediction over mini-batches
+  N = nrow(X);
+  Y_pred = matrix(0, rows=N, cols=3);
+  batch_size = 64;
+  # Helper matrices for rwRowIndexMax, built once outside the loop.
+  # NOTE(review): idxSeq has exactly batch_size rows; if N is not a multiple of
+  # batch_size the last batch has fewer rows -- confirm rwRowIndexMax tolerates that.
+  oneVec = matrix(1, rows=1, cols=K);
+  idxSeq = matrix(1, rows=batch_size, cols=1) %*% t(seq(1, K));
+  iters = ceil (N / batch_size);
+
+  for (i in 1:iters) {
+    # Get next batch; the last batch is truncated at row N
+    beg = ((i-1) * batch_size) %% N + 1;
+    end = min(N, beg+batch_size-1);
+    X_batch = X[beg:end,];
+
+    # Extract 3 layers: one full forward pass per fel value (8, 7, 6).
+    # NOTE(review): the inline comment below says "extract 7, 6, 5" while fel takes
+    # the values 8, 7, 6; the mapping depends on getWeights -- confirm.
+    j = 1;
+    fel = 8; #extract 7, 6, 5
+    while (j < 4) {
+      # Compute forward pass
+      # layer 1: Two conv2d layers (w/ activation relu) + 1 max-pooling layer
+      lid = 1;
+      [Wl1, bl1] = getWeights(fel, lid, W1_pt, b1_pt, W1_init, b1_init);
+      [outc1, Houtc1, Woutc1] = conv2d_forward(X_batch,Wl1,bl1,C,Hin,Win,3,3,1,1,1,1);
+      outr1 = relu_forward(outc1);
+      [Wl2, bl2] = getWeights(fel, lid, W2_pt, b2_pt, W2_init, b2_init);
+      [outc2, Houtc2, Woutc2] = conv2d_forward(outr1,Wl2,bl2,64,Houtc1,Woutc1,3,3,1,1,1,1);
+      outr2 = relu_forward(outc2);
+      [outp1, Houtp1, Woutp1] = max_pool2d_forward(outr2,64,Houtc2, Woutc2,2,2,2,2,0,0);
+
+      # layer 2: Two conv2d layers (w/ activation relu) + 1 max-pooling layer
+      lid = 2;
+      [Wl3, bl3] = getWeights(fel, lid, W3_pt, b3_pt, W3_init, b3_init);
+      [outc3, Houtc3, Woutc3] = conv2d_forward(outp1,Wl3,bl3,64,Houtp1,Woutp1,3,3,1,1,1,1);
+      outr3 = relu_forward(outc3);
+      [Wl4, bl4] = getWeights(fel, lid, W4_pt, b4_pt, W4_init, b4_init);
+      [outc4, Houtc4, Woutc4] = conv2d_forward(outr3,Wl4,bl4,128,Houtc3,Woutc3,3,3,1,1,1,1);
+      outr4 = relu_forward(outc4);
+      [outp2, Houtp2, Woutp2] = max_pool2d_forward(outr4,128,Houtc4, Woutc4,2,2,2,2,0,0);
+
+      # layer 3: Three conv2d layers (w/ activation relu) + 1 max-pooling layer
+      lid = 3;
+      [Wl5, bl5] = getWeights(fel, lid, W5_pt, b5_pt, W5_init, b5_init);
+      [outc5, Houtc5, Woutc5] = conv2d_forward(outp2,Wl5,bl5,128,Houtp2,Woutp2,3,3,1,1,1,1);
+      outr5 = relu_forward(outc5);
+      [Wl6, bl6] = getWeights(fel, lid, W6_pt, b6_pt, W6_init, b6_init);
+      [outc6, Houtc6, Woutc6] = conv2d_forward(outr5,Wl6,bl6,256,Houtc5,Woutc5,3,3,1,1,1,1);
+      outr6 = relu_forward(outc6);
+      [Wl7, bl7] = getWeights(fel, lid, W7_pt, b7_pt, W7_init, b7_init);
+      [outc7, Houtc7, Woutc7] = conv2d_forward(outr6,Wl7,bl7,256,Houtc6,Woutc6,3,3,1,1,1,1);
+      outr7 = relu_forward(outc7);
+      [outp3, Houtp3, Woutp3] = max_pool2d_forward(outr7,256,Houtc7, Woutc7,2,2,2,2,0,0);
+
+      # layer 4: Three conv2d layers (w/ activation relu) + 1 max-pooling layer
+      lid = 4;
+      [Wl8, bl8] = getWeights(fel, lid, W8_pt, b8_pt, W8_init, b8_init);
+      [outc8, Houtc8, Woutc8] = conv2d_forward(outp3,Wl8,bl8,256,Houtp3,Woutp3,3,3,1,1,1,1);
+      outr8 = relu_forward(outc8);
+      [Wl9, bl9] = getWeights(fel, lid, W9_pt, b9_pt, W9_init, b9_init);
+      [outc9, Houtc9, Woutc9] = conv2d_forward(outr8,Wl9,bl9,512,Houtc8,Woutc8,3,3,1,1,1,1);
+      outr9 = relu_forward(outc9);
+      [Wl10, bl10] = getWeights(fel, lid, W10_pt, b10_pt, W10_init, b10_init);
+      [outc10, Houtc10, Woutc10] = conv2d_forward(outr9,Wl10,bl10,512,Houtc9,Woutc9,3,3,1,1,1,1);
+      outr10 = relu_forward(outc10);
+      [outp4, Houtp4, Woutp4] = max_pool2d_forward(outr10,512,Houtc10, Woutc10,2,2,2,2,0,0);
+
+      # layer 5: Three conv2d layers (w/ activation relu) + 1 max-pooling layer
+      lid = 5;
+      [Wl11, bl11] = getWeights(fel, lid, W11_pt, b11_pt, W11_init, b11_init);
+      [outc11, Houtc11, Woutc11] = conv2d_forward(outp4,Wl11,bl11,512,Houtp4,Woutp4,3,3,1,1,1,1);
+      outr11 = relu_forward(outc11);
+      [Wl12, bl12] = getWeights(fel, lid, W12_pt, b12_pt, W12_init, b12_init);
+      [outc12, Houtc12, Woutc12] = conv2d_forward(outr11,Wl12,bl12,512,Houtc11,Woutc11,3,3,1,1,1,1);
+      outr12 = relu_forward(outc12);
+      [Wl13, bl13] = getWeights(fel, lid, W13_pt, b13_pt, W13_init, b13_init);
+      [outc13, Houtc13, Woutc13] = conv2d_forward(outr12,Wl13,bl13,512,Houtc12,Woutc12,3,3,1,1,1,1);
+      outr13 = relu_forward(outc13);
+      [outp5, Houtp5, Woutp5] = max_pool2d_forward(outr13,512,Houtc13, Woutc13,2,2,2,2,0,0);
+
+      # layer 6: Fully connected layer (w/ activation relu)
+      # (outr6/outr7 below overwrite the layer-3 relu outputs, which are no longer used)
+      lid = 6;
+      [Wl14, bl14] = getWeights(fel, lid, W14_pt, b14_pt, W14_init, b14_init);
+      outa6 = affine_forward(outp5, Wl14, bl14);
+      outr6 = relu_forward(outa6);
+
+      # layer 7: Fully connected layer (w/ activation relu)
+      lid = 7;
+      [Wl15, bl15] = getWeights(fel, lid, W15_pt, b15_pt, W15_init, b15_init);
+      outa7 = affine_forward(outr6, Wl15, bl15);
+      outr7 = relu_forward(outa7);
+
+      # layer 8: Fully connected layer (w/ activation softmax)
+      lid = 8;
+      [Wl16, bl16] = getWeights(fel, lid, W16_pt, b16_pt, W16_init, b16_init);
+      outa8 = affine_forward(outr7, Wl16, bl16);
+      probs_batch = softmax_forward(outa8);
+
+      # Store the predictions of this pass in column j, then lower fel for the next pass
+      Y_pred[beg:end,j] = rwRowIndexMax(probs_batch, oneVec, idxSeq);
+      j = j + 1;
+      fel = fel - 1;
+    }
+  }
+}
+
+# Exploratory feature extraction with an AlexNet-style network (5 conv2d layers,
+# 3 max-pools, 3 affine layers).
+# The 6400-wide input of the first affine layer equals 256*5*5, the flattened pool3
+# output for 224x224 inputs (assuming floor-based output sizes: conv1 -> 54,
+# pool1 -> 26, conv2 -> 26, pool2 -> 12, conv3..5 -> 12, pool3 -> 5).
+# Two randomly initialized weight sets are built: "_pt" (seed 42, a stand-in for
+# pretrained weights per the FIXME below) and "_init" (seed 43). getWeights
+# (defined elsewhere in this script) chooses between them per layer from `fel`
+# and `lid`.
+#
+# Inputs:  X (one row per image), C channels, Hin x Win input size, K classes
+# Output:  Y_pred - N x 4 matrix; column j holds the rwRowIndexMax of the softmax
+#          output for the j-th forward pass (fel = 8, 7, 6, 5)
+predict_alex = function(matrix[double] X, int C, int Hin, int Win, int K) 
+  return (matrix[double] Y_pred)
+{
+  # Get the transferred layers. FIXME: use pretrained weights
+  [W1_pt, b1_pt] = conv2d_init(96, C, Hf=11, Wf=11, 42);
+  [W2_pt, b2_pt] = conv2d_init(256, 96, Hf=5, Wf=5, 42);
+  [W3_pt, b3_pt] = conv2d_init(384, 256, Hf=3, Wf=3, 42);
+  [W4_pt, b4_pt] = conv2d_init(384, 384, Hf=3, Wf=3, 42);
+  [W5_pt, b5_pt] = conv2d_init(256, 384, Hf=3, Wf=3, 42);
+  [W6_pt, b6_pt] = affine_init(6400, 4096, 42); 
+  [W7_pt, b7_pt] = affine_init(4096, 4096, 42);
+  [W8_pt, b8_pt] = affine_init(4096, K, 42);
+  # Scale the output-layer weights by 1/sqrt(2)
+  W8_pt = W8_pt/sqrt(2);
+
+  # Initialize the weights for the non-transferred layers (same shapes, seed 43)
+  [W1_init, b1_init] = conv2d_init(96, C, Hf=11, Wf=11, 43);
+  [W2_init, b2_init] = conv2d_init(256, 96, Hf=5, Wf=5, 43);
+  [W3_init, b3_init] = conv2d_init(384, 256, Hf=3, Wf=3, 43);
+  [W4_init, b4_init] = conv2d_init(384, 384, Hf=3, Wf=3, 43);
+  [W5_init, b5_init] = conv2d_init(256, 384, Hf=3, Wf=3, 43);
+  [W6_init, b6_init] = affine_init(6400, 4096, 43);
+  [W7_init, b7_init] = affine_init(4096, 4096, 43);
+  [W8_init, b8_init] = affine_init(4096, K, 43);
+  W8_init = W8_init/sqrt(2);
+
+  # Compute prediction over mini-batches
+  N = nrow(X);
+  # Debug output (sums and shapes of intermediates) is disabled by this hard-coded flag
+  verbose = FALSE;
+  Y_pred = matrix(0, rows=N, cols=4);
+  batch_size = 64;
+  # Helper matrices for rwRowIndexMax, built once outside the loop.
+  # NOTE(review): idxSeq has exactly batch_size rows; if N is not a multiple of
+  # batch_size the last batch has fewer rows -- confirm rwRowIndexMax tolerates that.
+  oneVec = matrix(1, rows=1, cols=K);
+  idxSeq = matrix(1, rows=batch_size, cols=1) %*% t(seq(1, K));
+  iters = ceil (N / batch_size);
+  for (i in 1:iters) {
+    # Get next batch; the last batch is truncated at row N
+    beg = ((i-1) * batch_size) %% N + 1;
+    end = min(N, beg+batch_size-1);
+    X_batch = X[beg:end,];
+
+    # Extract 4 layers: one full forward pass per fel value (8, 7, 6, 5)
+    j = 1;
+    fel = 8;
+    while (j < 5) {
+      # Compute forward pass
+      # layer 1: conv1 -> relu1 -> pool1 (11x11 filters, stride 4, no padding)
+      lid = 1;
+      [Wl1, bl1] = getWeights(fel, lid, W1_pt, b1_pt, W1_init, b1_init);
+      [outc1, Houtc1, Woutc1] = conv2d_forward(X_batch,Wl1,bl1,C,Hin,Win,11,11,4,4,0,0);
+      if(verbose) print("sum(conv1) = "+sum(outc1));
+      if(verbose) print(nrow(outc1)+", "+ncol(outc1));
+      outr1 = relu_forward(outc1);
+      [outp1, Houtp1, Woutp1] = max_pool2d_forward(outr1,96,Houtc1,Woutc1,3,3,2,2,0,0)
+      if(verbose) print("sum(pool1) = "+sum(outp1));
+      if(verbose) print(nrow(outp1)+", "+ncol(outp1));
+
+      # layer 2: conv2 -> relu2 -> pool2
+      lid = 2;
+      [Wl2, bl2] = getWeights(fel, lid, W2_pt, b2_pt, W2_init, b2_init);
+      [outc2, Houtc2, Woutc2] = conv2d_forward(outp1,Wl2,bl2,96,Houtp1,Woutp1,5,5,1,1,2,2);
+      if(verbose) print("sum(conv2) = "+sum(outc2));
+      if(verbose) print(nrow(outc2)+", "+ncol(outc2));
+      outr2 = relu_forward(outc2);
+      [outp2, Houtp2, Woutp2] = max_pool2d_forward(outr2,256,Houtc2,Woutc2,3,3,2,2,0,0);
+      if(verbose) print("sum(pool2) = "+sum(outp2));
+      if(verbose) print(nrow(outp2)+", "+ncol(outp2));
+
+      # layer 3: conv3 -> relu3
+      lid = 3;
+      [Wl3, bl3] = getWeights(fel, lid, W3_pt, b3_pt, W3_init, b3_init);
+      [outc3, Houtc3, Woutc3] = conv2d_forward(outp2,Wl3,bl3,256,Houtp2,Woutp2,3,3,1,1,1,1);
+      if(verbose) print("sum(conv3) = "+sum(outc3));
+      if(verbose) print(nrow(outc3)+", "+ncol(outc3));
+      outr3 = relu_forward(outc3);
+      
+      # layer 4: conv4 -> relu4
+      lid = 4;
+      [Wl4, bl4] = getWeights(fel, lid, W4_pt, b4_pt, W4_init, b4_init);
+      [outc4, Houtc4, Woutc4] = conv2d_forward(outr3,Wl4,bl4,384,Houtc3,Woutc3,3,3,1,1,1,1);
+      if(verbose) print("sum(conv4) = "+sum(outc4));
+      if(verbose) print(nrow(outc4)+", "+ncol(outc4));
+      outr4 = relu_forward(outc4);
+
+      # layer 5: conv5 -> relu5 -> pool3 (pool3 output feeds the 6400-wide affine1)
+      lid = 5;
+      [Wl5, bl5] = getWeights(fel, lid, W5_pt, b5_pt, W5_init, b5_init);
+      [outc5, Houtc5, Woutc5] = conv2d_forward(outr4,Wl5,bl5,384,Houtc4,Woutc4,3,3,1,1,1,1);
+      if(verbose) print("sum(conv5) = "+sum(outc5));
+      if(verbose) print(nrow(outc5)+", "+ncol(outc5));
+      outr5 = relu_forward(outc5);
+      [outp5, Houtp5, Woutp5] = max_pool2d_forward(outr5,256,Houtc5,Woutc5,3,3,2,2,0,0)
+      if(verbose) print("sum(pool3) = "+sum(outp5));
+      if(verbose) print(nrow(outp5)+", "+ncol(outp5));
+
+      # layer 6: affine1 -> relu6
+      lid = 6;
+      [Wl6, bl6] = getWeights(fel, lid, W6_pt, b6_pt, W6_init, b6_init);
+      outa6 = affine_forward(outp5, Wl6, bl6);
+      if(verbose) print(nrow(outa6)+", "+ncol(outa6));
+      outr6 = relu_forward(outa6);
+
+      # layer 7: affine2 -> relu7
+      lid = 7;
+      [Wl7, bl7] = getWeights(fel, lid, W7_pt, b7_pt, W7_init, b7_init);
+      outa7 = affine_forward(outr6, Wl7, bl7);
+      if(verbose) print(nrow(outa7)+", "+ncol(outa7));
+      outr7 = relu_forward(outa7);
+
+      # layer 8: affine3 -> softmax
+      lid = 8;
+      [Wl8, bl8] = getWeights(fel, lid, W8_pt, b8_pt, W8_init, b8_init);
+      outa8 = affine_forward(outr7, Wl8, bl8);
+      if(verbose) print(nrow(outa8)+", "+ncol(outa8));
+      probs_batch = softmax_forward(outa8);
+
+      # Store the predicted classes of this pass in column j, then lower fel
+      Y_pred[beg:end,j] = rwRowIndexMax(probs_batch, oneVec, idxSeq);
+      j = j + 1;
+      fel = fel - 1;
+    }
+  }
+}
+
+# AlexNet-style exploratory feature extraction sized for 32x32 inputs (the
+# predict_alex variant used when Hin == 32). Compared to predict_alex, it uses
+# 64/192/384/256/256 conv filters and padding 2 on conv1 and padding 1 on pool3.
+# The 256-wide input of the first affine layer equals 256*1*1, the flattened pool3
+# output for 32x32 inputs (assuming floor-based output sizes: conv1 -> 7,
+# pool1 -> 3, conv2 -> 3, pool2 -> 1, conv3..5 -> 1, pool3 -> 1).
+# Two randomly initialized weight sets are built: "_pt" (seed 42, a stand-in for
+# pretrained weights per the FIXME below) and "_init" (seed 43). getWeights
+# (defined elsewhere in this script) chooses between them per layer from `fel`
+# and `lid`.
+#
+# Inputs:  X (one row per image), C channels, Hin x Win input size, K classes
+# Output:  Y_pred - N x 4 matrix; column j holds the rwRowIndexMax of the softmax
+#          output for the j-th forward pass (fel = 8, 7, 6, 5)
+predict_alex_32 = function(matrix[double] X, int C, int Hin, int Win, int K)
+  return (matrix[double] Y_pred)
+{
+  # Get the transferred layers. FIXME: use pretrained weights
+  [W1_pt, b1_pt] = conv2d_init(64, C, Hf=11, Wf=11, 42);
+  [W2_pt, b2_pt] = conv2d_init(192, 64, Hf=5, Wf=5, 42);
+  [W3_pt, b3_pt] = conv2d_init(384, 192, Hf=3, Wf=3, 42);
+  [W4_pt, b4_pt] = conv2d_init(256, 384, Hf=3, Wf=3, 42);
+  [W5_pt, b5_pt] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W6_pt, b6_pt] = affine_init(256, 4096, 42);
+  [W7_pt, b7_pt] = affine_init(4096, 4096, 42);
+  [W8_pt, b8_pt] = affine_init(4096, K, 42);
+  # Scale the output-layer weights by 1/sqrt(2)
+  W8_pt = W8_pt/sqrt(2);
+
+  # Initialize the weights for the non-transferred layers (same shapes, seed 43)
+  [W1_init, b1_init] = conv2d_init(64, C, Hf=11, Wf=11, 43);
+  [W2_init, b2_init] = conv2d_init(192, 64, Hf=5, Wf=5, 43);
+  [W3_init, b3_init] = conv2d_init(384, 192, Hf=3, Wf=3, 43);
+  [W4_init, b4_init] = conv2d_init(256, 384, Hf=3, Wf=3, 43);
+  [W5_init, b5_init] = conv2d_init(256, 256, Hf=3, Wf=3, 43);
+  [W6_init, b6_init] = affine_init(256, 4096, 43);
+  [W7_init, b7_init] = affine_init(4096, 4096, 43);
+  [W8_init, b8_init] = affine_init(4096, K, 43);
+  W8_init = W8_init/sqrt(2);
+
+  # Compute prediction over mini-batches
+  N = nrow(X);
+  # Debug output (sums and shapes of intermediates) is disabled by this hard-coded flag
+  verbose = FALSE;
+  Y_pred = matrix(0, rows=N, cols=4);
+  batch_size = 64;
+  # Helper matrices for rwRowIndexMax, built once outside the loop.
+  # NOTE(review): idxSeq has exactly batch_size rows; if N is not a multiple of
+  # batch_size the last batch has fewer rows -- confirm rwRowIndexMax tolerates that.
+  oneVec = matrix(1, rows=1, cols=K);
+  idxSeq = matrix(1, rows=batch_size, cols=1) %*% t(seq(1, K));
+  iters = ceil (N / batch_size);
+  for (i in 1:iters) {
+    # Get next batch; the last batch is truncated at row N
+    beg = ((i-1) * batch_size) %% N + 1;
+    end = min(N, beg+batch_size-1);
+    X_batch = X[beg:end,];
+
+    # Extract 4 layers: one full forward pass per fel value (8, 7, 6, 5)
+    j = 1;
+    fel = 8;
+    while (j < 5) {
+      # Compute forward pass
+      # layer 1: conv1 -> relu1 -> pool1 (11x11 filters, stride 4, padding 2)
+      lid = 1;
+      [Wl1, bl1] = getWeights(fel, lid, W1_pt, b1_pt, W1_init, b1_init);
+      [outc1, Houtc1, Woutc1] = conv2d_forward(X_batch,Wl1,bl1,C,Hin,Win,11,11,4,4,2,2);
+      if(verbose) print("sum(conv1) = "+sum(outc1));
+      if(verbose) print(nrow(outc1)+", "+ncol(outc1));
+      outr1 = relu_forward(outc1);
+      [outp1, Houtp1, Woutp1] = max_pool2d_forward(outr1,64,Houtc1,Woutc1,3,3,2,2,0,0)
+      if(verbose) print("sum(pool1) = "+sum(outp1));
+      if(verbose) print(nrow(outp1)+", "+ncol(outp1));
+
+      # layer 2: conv2 -> relu2 -> pool2
+      lid = 2;
+      [Wl2, bl2] = getWeights(fel, lid, W2_pt, b2_pt, W2_init, b2_init);
+      [outc2, Houtc2, Woutc2] = conv2d_forward(outp1,Wl2,bl2,64,Houtp1,Woutp1,5,5,1,1,2,2);
+      if(verbose) print("sum(conv2) = "+sum(outc2));
+      if(verbose) print(nrow(outc2)+", "+ncol(outc2));
+      outr2 = relu_forward(outc2);
+      [outp2, Houtp2, Woutp2] = max_pool2d_forward(outr2,192,Houtc2,Woutc2,3,3,2,2,0,0);
+      if(verbose) print("sum(pool2) = "+sum(outp2));
+      if(verbose) print(nrow(outp2)+", "+ncol(outp2));
+
+      # layer 3: conv3 -> relu3
+      lid = 3;
+      [Wl3, bl3] = getWeights(fel, lid, W3_pt, b3_pt, W3_init, b3_init);
+      [outc3, Houtc3, Woutc3] = conv2d_forward(outp2,Wl3,bl3,192,Houtp2,Woutp2,3,3,1,1,1,1);
+      if(verbose) print("sum(conv3) = "+sum(outc3));
+      if(verbose) print(nrow(outc3)+", "+ncol(outc3));
+      outr3 = relu_forward(outc3);
+
+      # layer 4: conv4 -> relu4
+      lid = 4;
+      [Wl4, bl4] = getWeights(fel, lid, W4_pt, b4_pt, W4_init, b4_init);
+      [outc4, Houtc4, Woutc4] = conv2d_forward(outr3,Wl4,bl4,384,Houtc3,Woutc3,3,3,1,1,1,1);
+      if(verbose) print("sum(conv4) = "+sum(outc4));
+      if(verbose) print(nrow(outc4)+", "+ncol(outc4));
+      outr4 = relu_forward(outc4);
+
+      # layer 5: conv5 -> relu5 -> pool3 (pool3 uses padding 1 and feeds the 256-wide affine1)
+      lid = 5;
+      [Wl5, bl5] = getWeights(fel, lid, W5_pt, b5_pt, W5_init, b5_init);
+      [outc5, Houtc5, Woutc5] = conv2d_forward(outr4,Wl5,bl5,256,Houtc4,Woutc4,3,3,1,1,1,1);
+      if(verbose) print("sum(conv5) = "+sum(outc5));
+      if(verbose) print(nrow(outc5)+", "+ncol(outc5));
+      outr5 = relu_forward(outc5);
+      [outp5, Houtp5, Woutp5] = max_pool2d_forward(outr5,256,Houtc5,Woutc5,3,3,2,2,1,1)
+      if(verbose) print("sum(pool3) = "+sum(outp5));
+      if(verbose) print(nrow(outp5)+", "+ncol(outp5));
+
+      # layer 6: affine1 -> relu6
+      lid = 6;
+      [Wl6, bl6] = getWeights(fel, lid, W6_pt, b6_pt, W6_init, b6_init);
+      outa6 = affine_forward(outp5, Wl6, bl6);
+      if(verbose) print(nrow(outa6)+", "+ncol(outa6));
+      outr6 = relu_forward(outa6);
+
+      # layer 7: affine2 -> relu7
+      lid = 7;
+      [Wl7, bl7] = getWeights(fel, lid, W7_pt, b7_pt, W7_init, b7_init);
+      outa7 = affine_forward(outr6, Wl7, bl7);
+      if(verbose) print(nrow(outa7)+", "+ncol(outa7));
+      outr7 = relu_forward(outa7);
+
+      # layer 8: affine3 -> softmax
+      lid = 8;
+      [Wl8, bl8] = getWeights(fel, lid, W8_pt, b8_pt, W8_init, b8_init);
+      outa8 = affine_forward(outr7, Wl8, bl8);
+      if(verbose) print(nrow(outa8)+", "+ncol(outa8));
+      probs_batch = softmax_forward(outa8);
+
+      # Store the predicted classes of this pass in column j, then lower fel
+      Y_pred[beg:end,j] = rwRowIndexMax(probs_batch, oneVec, idxSeq);
+      j = j + 1;
+      fel = fel - 1;
+    }
+  }
+}
+
+# Build a synthetic dataset of N random images together with one-hot labels.
+#   X: N x (C*Hin*Win) matrix, one linearized image per row, standard-normal values (seed 45)
+#   Y: N x K one-hot label matrix derived from rounded uniform draws in [1, K] (seed 46)
+generate_dummy_data = function(int N, int C, int Hin, int Win, int K)
+  return (matrix[double] X, matrix[double] Y)
+{
+  numFeatures = C * Hin * Win;
+  X = rand(rows=N, cols=numFeatures, pdf="normal", seed=45);
+  # Rounding the uniform draws yields integer class ids between 1 and K
+  labels = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform", seed=46));
+  # Row i receives a single 1 in column labels[i]; explicit N x K output dimensions
+  Y = table(seq(1, N), labels, N, K);
+}
+
+##########################################################################
+
+# Read training data and settings
+N = 512;     #num of images in the target dataset
+C = 3;       #num of color channels
+K = 10;      #num of classes
+dataset = "cifar";
+# Input resolution per dataset: cifar -> 32x32, imagenet -> 224x224.
+# Any other dataset name would leave Hin undefined, so fail fast with a clear message.
+if (dataset == "cifar") {
+  Hin = 32; #input image height
+}
+else if (dataset == "imagenet") {
+  Hin = 224; #input image height
+}
+else {
+  stop("Unsupported dataset: " + dataset);
+}
+Win = Hin; #input image width
+
+# Generate dummy data
+[X, Y] = generate_dummy_data(N, C, Hin, Win, K);
+
+# Load the CuDNN libraries by calling a conv2d before the timer (t1) starts
+print("Eagerly loading cuDNN library");
+[W1, b1] = conv2d_init(96, C, Hf=11, Wf=11, 42);
+[outc1, Houtc1, Woutc1] = conv2d_forward(X[1:8,], W1, b1, C, Hin, Win, 11, 11, 1, 1, 2, 2);
+print(sum(outc1));
+
+# Y_pred columns: 1-4 AlexNet passes, 5-7 VGG16 passes, 8-10 ResNet18 passes
+print("Starting exploratory feature transfers");
+Y_pred = matrix(0, rows=N, cols=10);
+t1 = time();
+if (Hin == 32)
+  Y_pred[,1:4] = predict_alex_32(X, C, Hin, Win, K);
+if (Hin == 224)
+  Y_pred[,1:4] = predict_alex(X, C, Hin, Win, K);
+Y_pred[,5:7] = predict_vgg(X, C, Hin, Win, K, Hin);
+Y_pred[,8:10] = predict_resnet18(X, C, Hin, Win, K);
+print(toString(colSums(Y_pred)));
+
+t2 = time();
+# The time() difference is divided by 10^6 to report milliseconds
+print("Elapsed time for feature transfers = "+floor((t2-t1)/1000000)+" millisec");
+write(Y_pred, $1, format="text");
+