This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 5898293b8d [SYSTEMDS-3660] GPU cache eviction operator and related
rewrite. This patch introduces a new operator, _evict, to clean up the free
pointers cached in the lineage cache. A shift in the allocation pattern leads
to large eviction overhead and memory fragmentation. To address that, we
speculatively clear a fraction of the free pointers. Currently, we insert an
_evict instruction before every mini-batch processing loop.
5898293b8d is described below

commit 5898293b8db25b9b1784ff30ec732812f4402d54
Author: Arnab Phani <[email protected]>
AuthorDate: Fri Dec 29 01:16:36 2023 +0100

    [SYSTEMDS-3660] GPU cache eviction operator and related rewrite
    This patch introduces a new operator, _evict, to clean up the free
    pointers cached in the lineage cache. A shift in the allocation pattern
    leads to large eviction overhead and memory fragmentation. To address
    that, we speculatively clear a fraction of the free pointers. Currently,
    we insert an _evict instruction before every mini-batch processing loop.
    
    Closes #1964
---
 src/main/java/org/apache/sysds/common/Types.java   |   2 +-
 .../apache/sysds/conf/ConfigurationManager.java    |   4 +
 .../java/org/apache/sysds/hops/OptimizerUtils.java |   6 +
 src/main/java/org/apache/sysds/hops/UnaryOp.java   |   3 +
 .../org/apache/sysds/lops/rewrite/LopRewriter.java |   1 +
 .../sysds/lops/rewrite/RewriteAddGPUEvictLop.java  | 115 ++++
 .../runtime/instructions/CPInstructionParser.java  |   5 +
 .../runtime/instructions/cp/CPInstruction.java     |   1 +
 .../instructions/cp/EvictCPInstruction.java        |  49 ++
 .../runtime/lineage/LineageGPUCacheEviction.java   |   8 +-
 .../lineage/GPULineageCacheEvictionTest.java       |  16 +-
 .../functions/lineage/GPUCacheEviction6.dml        | 746 +++++++++++++++++++++
 12 files changed, 953 insertions(+), 3 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index 84019e8078..30cd6bf5bd 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -341,7 +341,7 @@ public class Types
                CEIL, CHOLESKY, COS, COSH, CUMMAX, CUMMIN, CUMPROD, CUMSUM,
                CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, 
INVERSE,
                IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
-               MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, 
STOP,
+               MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, 
STOP, _EVICT,
                SVD, TAN, TANH, TYPEOF, TRIGREMOTE,
                //fused ML-specific operators for performance 
                SPROP, //sample proportion: P * (1 - P)
diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java 
b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index 1ac4d13974..8c7d5547f5 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -294,6 +294,10 @@ public class ConfigurationManager{
                        || OptimizerUtils.RULE_BASED_GPU_EXEC));
        }
 
+       /**
+        * Reports whether the compiler may automatically place GPU lineage-cache
+        * eviction ({@code _evict}) instructions.
+        *
+        * @return the current value of {@code OptimizerUtils.AUTO_GPU_CACHE_EVICTION}
+        */
+       public static boolean isAutoEvictionEnabled() {
+               final boolean enabled = OptimizerUtils.AUTO_GPU_CACHE_EVICTION;
+               return enabled;
+       }
+
        public static ILinearize.DagLinearization getLinearizationOrder() {
                if (OptimizerUtils.COST_BASED_ORDERING)
                        return ILinearize.DagLinearization.AUTO;
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index dc2bc487ed..8953cba378 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -310,6 +310,12 @@ public class OptimizerUtils
         */
        public static boolean RULE_BASED_GPU_EXEC = false;
 
+       /**
+        * Automatic placement of GPU lineage cache eviction ({@code _evict})
+        * instructions by the compiler. Read through
+        * {@code ConfigurationManager.isAutoEvictionEnabled()}; enabled by default.
+        */
+       public static boolean AUTO_GPU_CACHE_EVICTION = true;
+
        //////////////////////
        // Optimizer levels //
        //////////////////////
diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java 
b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index 5dbb55a303..d394beaf0e 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -145,6 +145,9 @@ public class UnaryOp extends MultiThreadedHop
                                case LOCAL:
                                        ret = new Local(input.constructLops(), 
getDataType(), getValueType());
                                        break;
+                               case _EVICT:
+                                       ret = new 
UnaryCP(input.constructLops(), _op, getDataType(), getValueType());
+                                       break;
                                default:
                                        final boolean isScalarIn = 
getInput().get(0).getDataType() == DataType.SCALAR;
                                        if(getDataType() == DataType.SCALAR // 
value type casts or matrix to scalar
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java 
b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
index 8d2c0a63f8..88c1787843 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
@@ -45,6 +45,7 @@ public class LopRewriter
                _lopSBRuleSet.add(new RewriteAddBroadcastLop());
                _lopSBRuleSet.add(new RewriteAddChkpointLop());
                _lopSBRuleSet.add(new RewriteAddChkpointInLoop());
+               _lopSBRuleSet.add(new RewriteAddGPUEvictLop());
                // TODO: A rewrite pass to remove less effective chkpoints
                // Last rewrite to reset Lop IDs in a depth-first manner
                _lopSBRuleSet.add(new RewriteFixIDs());
diff --git 
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddGPUEvictLop.java 
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddGPUEvictLop.java
new file mode 100644
index 0000000000..8618e6a2eb
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddGPUEvictLop.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.UnaryOp;
+import org.apache.sysds.lops.BinaryScalar;
+import org.apache.sysds.lops.Data;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.OperatorOrderingUtils;
+import org.apache.sysds.lops.RightIndex;
+import org.apache.sysds.lops.UnaryCP;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.VariableSet;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class RewriteAddGPUEvictLop extends LopRewriteRule
+{
+       /** Percentage of the free cached GPU pointers the inserted _evict clears. */
+       private static final int EVICT_FRACTION_PERCENT = 100;
+
+       /**
+        * Prepends a statement block holding a single {@code _evict} instruction to a
+        * for-loop whose first body block slices a dataset into mini-batches. The
+        * rewrite only applies if automatic GPU cache eviction is enabled, the
+        * accelerator is used, and lineage reuse is active.
+        *
+        * @param sb statement block to inspect (may be {@code null})
+        * @return {@code [sb]} if nothing was inserted, otherwise {@code [evictBlock, sb]}
+        */
+       @Override
+       public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb) {
+               // TODO: Move this as a Statement block rewrite
+               // Note: the instanceof check also rejects sb == null
+               if (!ConfigurationManager.isAutoEvictionEnabled()
+                       || !(sb instanceof ForStatementBlock)
+                       || !DMLScript.USE_ACCELERATOR || LineageCacheConfig.ReuseCacheType.isNone())
+                       return unchanged(sb);
+
+               // Collect the LOPs of the first statement block of the loop body
+               List<StatementBlock> body = ((ForStatement) sb.getStatement(0)).getBody();
+               if (body == null || body.isEmpty())
+                       return unchanged(sb);
+               ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(body.get(0));
+
+               // Check if this loop is for mini-batch processing
+               if (!findMiniBatchSlicing(lops))
+                       return unchanged(sb);
+
+               // Insert statement block with _evict instruction before the loop
+               ArrayList<StatementBlock> ret = new ArrayList<>();
+               ret.add(createEvictBlock(sb, EVICT_FRACTION_PERCENT));
+               ret.add(sb);
+               return ret;
+       }
+
+       @Override
+       public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
+               // No multi-block rewrite; blocks are returned unchanged
+               return sbs;
+       }
+
+       /**
+        * Wraps {@code sb} in a single-element list. Unlike {@code List.of}, this
+        * tolerates a {@code null} block instead of throwing a NullPointerException.
+        */
+       private static List<StatementBlock> unchanged(StatementBlock sb) {
+               ArrayList<StatementBlock> ret = new ArrayList<>(1);
+               ret.add(sb);
+               return ret;
+       }
+
+       /**
+        * Builds a statement block holding one CP {@code _evict} instruction whose
+        * literal input is the eviction percentage. Both lops and hops are created
+        * (hops for recompilation), in the same order as before.
+        */
+       private static StatementBlock createEvictBlock(StatementBlock sb, int evictFrac) {
+               StatementBlock sb0 = new StatementBlock();
+               sb0.setDMLProg(sb.getDMLProg());
+               sb0.setParseInfo(sb);
+               sb0.setLiveIn(new VariableSet());
+               sb0.setLiveOut(new VariableSet());
+               // TODO: Add another input for the backend (GPU/CPU/Spark)
+               Lop fr = Data.createLiteralLop(Types.ValueType.INT64, Integer.toString(evictFrac));
+               fr.getOutputParameters().setDimensions(0, 0, 0, -1);
+               UnaryCP evict = new UnaryCP(fr, Types.OpOp1._EVICT, fr.getDataType(), fr.getValueType(), Types.ExecType.CP);
+               Hop in = new LiteralOp(evictFrac);
+               Hop evictHop = new UnaryOp("tmp", Types.DataType.SCALAR, Types.ValueType.INT64, Types.OpOp1._EVICT, in);
+               ArrayList<Lop> newlops = new ArrayList<>();
+               newlops.add(evict);
+               ArrayList<Hop> newhops = new ArrayList<>();
+               newhops.add(evictHop);
+               sb0.setLops(newlops);
+               sb0.setHops(newhops);
+               return sb0;
+       }
+
+       // To verify mini-batch processing, match the below pattern
+       // beg = ((i-1) * batch_size) %% N + 1;
+       // end = min(N, beg+batch_size-1);
+       // X_batch = X[beg:end];
+       // A null lop list (block without lops) never matches.
+       private static boolean findMiniBatchSlicing(List<Lop> lops) {
+               if (lops == null)
+                       return false;
+               for (Lop l : lops) {
+                       if (!(l instanceof RightIndex))
+                               continue;
+                       ArrayList<Lop> inputs = l.getInputs();
+                       if (inputs.size() < 3)
+                               continue;
+                       if (inputs.get(0) instanceof Data && ((Data) inputs.get(0)).isTransientRead()
+                               && inputs.get(0).getInputs().size() == 0                //input1 is the dataset
+                               && inputs.get(1) instanceof BinaryScalar                //input2 is beg
+                               && ((BinaryScalar) inputs.get(1)).getOperationType() == Types.OpOp2.PLUS
+                               && inputs.get(2) instanceof BinaryScalar                //input3 is end
+                               && ((BinaryScalar) inputs.get(2)).getOperationType() == Types.OpOp2.MIN)
+                               return true;
+               }
+               return false;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 3de9fcd65d..c73d755b5e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -50,6 +50,7 @@ import 
org.apache.sysds.runtime.instructions.cp.CtableCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.DeCompressionCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.DnnCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.EvictCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.LocalCPInstruction;
@@ -337,6 +338,7 @@ public class CPInstructionParser extends InstructionParser {
                String2CPInstructionType.put( DeCompression.OPCODE, 
CPType.DeCompression);
                String2CPInstructionType.put( "spoof",     CPType.SpoofFused);
                String2CPInstructionType.put( "prefetch",  CPType.Prefetch);
+               String2CPInstructionType.put( "_evict",  
CPType.EvictLineageCache);
                String2CPInstructionType.put( "broadcast",  CPType.Broadcast);
                String2CPInstructionType.put( "trigremote",  CPType.TrigRemote);
                String2CPInstructionType.put( Local.OPCODE, CPType.Local);
@@ -483,6 +485,9 @@ public class CPInstructionParser extends InstructionParser {
                                
                        case Broadcast:
                                return 
BroadcastCPInstruction.parseInstruction(str);
+
+                       case EvictLineageCache:
+                               return EvictCPInstruction.parseInstruction(str);
                        
                        default:
                                throw new DMLRuntimeException("Invalid CP 
Instruction Type: " + cptype );
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
index 3503b256f7..1398d4365b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
@@ -45,6 +45,7 @@ public abstract class CPInstruction extends Instruction {
                Builtin, Reorg, Variable, FCall, Append, Rand, QSort, QPick, 
Local,
                MatrixIndexing, MMTSJ, PMMJ, MMChain, Reshape, Partition, 
Compression, DeCompression, SpoofFused,
                StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, 
Sql, Prefetch, Broadcast, TrigRemote,
+               EvictLineageCache,
                NoOp,
         }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvictCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvictCPInstruction.java
new file mode 100644
index 0000000000..d958f6e1ed
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvictCPInstruction.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.lineage.LineageGPUCacheEviction;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+/**
+ * CP instruction {@code _evict}: speculatively clears a fraction of the free
+ * pointers held by the GPU lineage cache.
+ */
+public class EvictCPInstruction extends UnaryCPInstruction
+{
+       private EvictCPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
+               super(CPType.EvictLineageCache, op, in, out, opcode, istr);
+       }
+
+       /**
+        * Parses an {@code _evict} instruction string. The field count is validated
+        * with {@code InstructionUtils.checkNumFields(str, 3)}; the first two operands
+        * are the scalar input (eviction percentage) and the output.
+        */
+       public static EvictCPInstruction parseInstruction(String str) {
+               InstructionUtils.checkNumFields(str, 3);
+               String[] tokens = InstructionUtils.getInstructionPartsWithValueType(str);
+               CPOperand input = new CPOperand(tokens[1]);
+               CPOperand output = new CPOperand(tokens[2]);
+               return new EvictCPInstruction(null, input, output, tokens[0], str);
+       }
+
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               // The scalar input is a percentage (e.g. 100 -> evict fraction 1.0)
+               long percent = ec.getScalarInput(input1).getLongValue();
+               LineageGPUCacheEviction.removeAllEntries(percent / 100.0);
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageGPUCacheEviction.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageGPUCacheEviction.java
index 7eac5e4a54..5497210999 100644
--- 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageGPUCacheEviction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageGPUCacheEviction.java
@@ -131,10 +131,13 @@ public class LineageGPUCacheEviction
                }
        }
 
-       public static void removeAllEntries() {
+       // Speculative eviction
+       public static void removeAllEntries(double evictFrac) {
                List<Long> sizes = new ArrayList<>(freeQueues.keySet());
                for (Long size : sizes) {
                        TreeSet<LineageCacheEntry> freeList = 
freeQueues.get(size);
+                       int evictLim = (int) (freeList.size() * evictFrac);
+                       int evictCount = 1;
                        LineageCacheEntry le = pollFirstFreeEntry(size);
                        while (le != null) {
                                // Free the pointer
@@ -142,6 +145,9 @@ public class LineageGPUCacheEviction
                                if (DMLScript.STATISTICS)
                                        
LineageCacheStatistics.incrementGpuDel();
                                le = pollFirstFreeEntry(size);
+                               if (evictCount > evictLim)
+                                       break;
+                               evictCount++;
                        }
                }
        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/lineage/GPULineageCacheEvictionTest.java
 
b/src/test/java/org/apache/sysds/test/functions/lineage/GPULineageCacheEvictionTest.java
index cc59da7de9..0a12ffb9af 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/lineage/GPULineageCacheEvictionTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/lineage/GPULineageCacheEvictionTest.java
@@ -29,6 +29,8 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
 import org.junit.Assume;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -39,7 +41,7 @@ public class GPULineageCacheEvictionTest extends 
AutomatedTestBase{
        
        protected static final String TEST_DIR = "functions/lineage/";
        protected static final String TEST_NAME = "GPUCacheEviction";
-       protected static final int TEST_VARIANTS = 5;
+       protected static final int TEST_VARIANTS = 6;
        protected String TEST_CLASS_DIR = TEST_DIR + 
GPULineageCacheEvictionTest.class.getSimpleName() + "/";
        
        @BeforeClass
@@ -80,6 +82,11 @@ public class GPULineageCacheEvictionTest extends 
AutomatedTestBase{
                testLineageTraceExec(TEST_NAME+"5");
        }
 
+       /** Transfer learning and reuse over three models (AlexNet, VGG, ResNet). */
+       @Test
+       public void TransferLearning3Models() {
+               final String testname = TEST_NAME + "6";
+               testLineageTraceExec(testname);
+       }
+
 
        private void testLineageTraceExec(String testname) {
                System.out.println("------------ BEGIN " + testname + 
"------------");
@@ -117,6 +124,13 @@ public class GPULineageCacheEvictionTest extends 
AutomatedTestBase{
 
                //compare results 
                TestUtils.compareMatrices(R_orig, R_reused, 1e-6, "Origin", 
"Reused");
+
+               //Match _evict count
+               if (testname.equalsIgnoreCase(TEST_NAME+"6")) {
+                       long exp_numev = 3;
+                       long numev = Statistics.getCPHeavyHitterCount("_evict");
+                       Assert.assertTrue("Violated Prefetch instruction count: 
"+numev, numev == exp_numev);
+               }
        }
 }
 
diff --git a/src/test/scripts/functions/lineage/GPUCacheEviction6.dml 
b/src/test/scripts/functions/lineage/GPUCacheEviction6.dml
new file mode 100644
index 0000000000..fdf4285610
--- /dev/null
+++ b/src/test/scripts/functions/lineage/GPUCacheEviction6.dml
@@ -0,0 +1,746 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# 2D convolution forward pass followed by a per-filter bias.
+# X has N rows laid out as input_shape=[N,C,Hin,Win]; W has F rows laid out as
+# filter_shape=[F,C,Hf,Wf]; b is passed to bias_add.
+# Returns the output feature map and its spatial size (Hout, Wout).
+conv2d_forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
+  int C, int Hin, int Win, int Hf, int Wf, int strideh, int stridew,
+  int padh, int padw) return (matrix[double] out, int Hout, int Wout)
+{
+  N = nrow(X)
+  F = nrow(W)
+  # Output size from kernel size, stride and padding
+  Hout = as.integer(floor((Hin + 2*padh - Hf)/strideh + 1))
+  Wout = as.integer(floor((Win + 2*padw - Wf)/stridew + 1))
+  # Convolution - built-in implementation
+  out = conv2d(X, W, input_shape=[N,C,Hin,Win], filter_shape=[F,C,Hf,Wf],
+               stride=[strideh,stridew], padding=[padh,padw])
+  # Add bias term to each output filter
+  out = bias_add(out, b)
+}
+
+# Backward pass of conv2d_forward: gradients w.r.t. the input (dX), the
+# filters (dW) and the bias (db). dout has F*Hout*Wout columns (see the
+# reshape below); Hout and Wout are only needed for the bias gradient.
+conv2d_backward = function(matrix[double] dout, int Hout, int Wout, matrix[double] X,
+  matrix[double] W, matrix[double] b, int C, int Hin, int Win, int Hf, int Wf,
+  int strideh, int stridew, int padh, int padw)
+  return (matrix[double] dX, matrix[double] dW, matrix[double] db)
+{
+  N = nrow(X)
+  F = nrow(W)
+  # Partial derivatives for convolution - built-in implementation
+  dW = conv2d_backward_filter(X, dout, stride=[strideh,stridew], padding=[padh,padw],
+                              input_shape=[N,C,Hin,Win], filter_shape=[F,C,Hf,Wf])
+  dX = conv2d_backward_data(W, dout, stride=[strideh,stridew], padding=[padh,padw],
+                            input_shape=[N,C,Hin,Win], filter_shape=[F,C,Hf,Wf])
+  # Partial derivatives for bias vector
+  # Here we sum each column, reshape to (F, Hout*Wout), and sum each row
+  # to result in the summation for each channel.
+  db = rowSums(matrix(colSums(dout), rows=F, cols=Hout*Wout))  # shape (F, 1)
+}
+
+# Initialize conv2d parameters: W is F x (C*Hf*Wf), drawn from a normal
+# distribution and scaled by sqrt(2/(C*Hf*Wf)); b is an F x 1 zero vector.
+# NOTE(review): seed = -1 presumably requests a non-fixed seed — confirm.
+conv2d_init = function(int F, int C, int Hf, int Wf, int seed = -1)
+  return (matrix[double] W, matrix[double] b) {
+  W = rand(rows=F, cols=C*Hf*Wf, pdf="normal", seed=seed) * sqrt(2.0/(C*Hf*Wf))
+  b = matrix(0, rows=F, cols=1)
+}
+
+# Batch-norm forward pass in 'train' mode with parameters created fresh on
+# every call: gamma = 1, beta = 0, ema_mean = 0, ema_var = 1 (each C x 1).
+# Only the normalized output is returned; the updated moving averages and
+# cached statistics produced by batch_norm2d are discarded.
+bn2d_forward = function(matrix[double] X, int C, int Hin, int Win, 
+    double mu, double epsilon) return (matrix[double] out)
+{
+    gamma = matrix(1, rows=C, cols=1)
+    beta = matrix(0, rows=C, cols=1)
+    ema_mean = matrix(0, rows=C, cols=1)
+    ema_var = matrix(1, rows=C, cols=1)
+    # Placeholders for the extra outputs of batch_norm2d
+    ema_mean_upd = ema_mean; 
+    ema_var_upd = ema_var;  
+    cache_mean = ema_mean; 
+    cache_inv_var = ema_var
+    mode = 'train';
+    [out, ema_mean_upd, ema_var_upd, cache_mean, cache_inv_var] = batch_norm2d(X, gamma, beta, ema_mean, ema_var, mode, epsilon, mu)
+}
+
+# Fully-connected (affine) layer: out = X %*% W + b.
+affine_forward = function(matrix[double] X, matrix[double] W, matrix[double] b) return (matrix[double] out) {
+  out = X %*% W + b;
+}
+
+# Initialize an affine layer: W is D x M, normal-distributed and scaled by
+# sqrt(2/D); b is a 1 x M zero row vector.
+affine_init = function(int D, int M, int seed = -1 ) return (matrix[double] W, matrix[double] b) {
+  W = rand(rows=D, cols=M, pdf="normal", seed=seed) * sqrt(2.0/D);
+  b = matrix(0, rows=1, cols=M);
+}
+
+# Element-wise ReLU: out = max(0, X).
+relu_forward = function(matrix[double] X) return (matrix[double] out) {
+  out = max(0, X);
+}
+
+# 2D max pooling with a Hf x Wf window; X is laid out as input_shape=[N,C,Hin,Win].
+# Returns the pooled map and its spatial size (Hout, Wout).
+max_pool2d_forward = function(matrix[double] X, int C, int Hin, int Win, int Hf, int Wf,
+  int strideh, int stridew, int padh, int padw) return(matrix[double] out, int Hout, int Wout)
+{
+  N = nrow(X)
+  # Output size from window size, stride and padding
+  Hout = as.integer(floor((Hin + 2*padh - Hf)/strideh + 1))
+  Wout = as.integer(floor((Win + 2*padw - Wf)/stridew + 1))
+  out = max_pool(X, input_shape=[N,C,Hin,Win], pool_size=[Hf,Wf],
+    stride=[strideh,stridew], padding=[padh,padw])
+}
+
+# Global average pooling: the window covers the full Hin x Win map, so each
+# channel is reduced to a single value (Hout = Wout = 1).
+avg_pool2d_forward = function(matrix[double] X, int C, int Hin, int Win)
+  return (matrix[double] out, int Hout, int Wout) {
+  N = nrow(X)
+  Hout = 1
+  Wout = 1
+  out = avg_pool(X, input_shape=[N,C,Hin,Win], pool_size=[Hin,Win], stride=[1,1], padding=[0, 0])
+}
+
+# Row-wise softmax; each row of probs sums to 1.
+softmax_forward = function(matrix[double] scores) return (matrix[double] probs) {
+  scores = scores - rowMaxs(scores);  # numerical stability
+  unnorm_probs = exp(scores);  # unnormalized probabilities
+  probs = unnorm_probs / rowSums(unnorm_probs);  # normalized probabilities
+}
+
+# ResNet basic block. The residual path is conv3x3 -> bn -> relu -> conv3x3 -> bn.
+# It is added to the identity path, which is downsampled (conv1x1 -> bn) when
+# the stride or the channel count changes, and the sum goes through a final ReLU.
+# Note: the downsample weights WC3/bC3 are freshly initialized (seed 42).
+basic_block = function(matrix[double] X, int C, int C_base, int Hin, int Win, int strideh,
+    int stridew, matrix[double] WC1, matrix[double] bC1, matrix[double] WC2, matrix[double] bC2)
+  return (matrix[double] out, int Hout, int Wout)
+{
+  mu_bn = 0.1;
+  ep_bn = 1e-05;
+  downsample = strideh > 1 | stridew > 1 | C != C_base;
+  if (downsample) {
+    [WC3, bC3] = conv2d_init(C_base, C, Hf=1, Wf=1, 42);
+  }
+  # Residual Path
+  # conv1 -> bn1 -> relu1
+  [out, Hout, Wout] = conv2d_forward(X,WC1,bC1,C,Hin,Win,3,3,strideh,stridew,1,1);
+  out = bn2d_forward(out,C_base,Hout,Wout,mu_bn,ep_bn);
+  out = relu_forward(out);
+  # conv2 -> bn2 (the ReLU follows the residual addition below)
+  [out, Hout, Wout] = conv2d_forward(out,WC2,bC2,C_base,Hout,Wout,3,3,1,1,1,1);
+  out = bn2d_forward(out,C_base,Hout,Wout,mu_bn,ep_bn);
+  # Identity Path
+  identity = X;
+  if (downsample) {
+    # Downsample input: conv1x1 -> bn. The normalized result belongs to the
+    # identity path; writing it to 'out' would discard the residual path.
+    [identity, Hout, Wout] = conv2d_forward(X,WC3,bC3,C,Hin,Win,1,1,strideh,stridew,0,0);
+    identity = bn2d_forward(identity,C_base,Hout,Wout,mu_bn,ep_bn);
+  }
+  out = relu_forward(out + identity);
+}
+
+# Select the parameters for layer `lid`. Layers below the feature-extraction
+# boundary (lid < fel) take the transferred weights (W_pt, b_pt); every other
+# layer takes the freshly initialized weights (W_init, b_init).
+getWeights = function(int fel, int lid,
+    matrix[double] W_pt, matrix[double] b_pt,
+    matrix[double] W_init, matrix[double] b_init)
+  return (matrix[double] Wl, matrix[double] bl)
+{
+  if (lid >= fel) {  # at or beyond the boundary: initialized weights
+    Wl = W_init;
+    bl = b_init;
+  }
+  else {  # below the boundary: transferred weights
+    Wl = W_pt;
+    bl = b_pt;
+  }
+}
+
+# Row-wise argmax: returns, per row, the 1-based column index of the maximum.
+# On ties the largest tied index is returned. As the callers build them,
+# oneVec is a 1 x ncol(X) row of ones and idxSeq has one row 1..ncol(X) per
+# row of X.
+rwRowIndexMax = function(matrix[double] X, matrix[double] oneVec, matrix[double] idxSeq)
+    return (matrix[double] index) {
+  # Broadcast each row maximum across the columns and flag the matching cells
+  isMax = X == (rowMaxs(X) %*% oneVec);
+  index = rowMaxs(isMax * idxSeq);
+}
+
+####################################################################
+
+# Exploratory feature extraction from a ResNet-18 model.
+# For each mini-batch of up to 64 rows, the forward pass runs three times with
+# a decreasing feature-extraction boundary fel = 10, 9, 8. Layers with
+# lid < fel use the transferred weights (*_pt) and the others use freshly
+# initialized weights (*_init), via getWeights. Column j of the N x 3 result
+# holds the predicted class index (1..K) of the j-th setting.
+predict_resnet18 = function(matrix[double] X, int C, int Hin, int Win, int K)
+  return (matrix[double] Y_pred)
+{
+  mu_bn = 0.1;
+  ep_bn = 1e-05;
+
+  # Get the transferred layers. FIXME: use pretrained weights
+  [W1_pt, b1_pt] = conv2d_init(64, C, Hf=7, Wf=7, 42);
+  [W2_pt, b2_pt] = conv2d_init(64, 64, Hf=3, Wf=3, 42);
+  [W3_pt, b3_pt] = conv2d_init(64, 64, Hf=3, Wf=3, 42);
+  [W4_pt, b4_pt] = conv2d_init(64, 64, Hf=3, Wf=3, 42);
+  [W5_pt, b5_pt] = conv2d_init(64, 64, Hf=3, Wf=3, 42);
+  [W6_pt, b6_pt] = conv2d_init(128, 64, Hf=3, Wf=3, 42);
+  [W7_pt, b7_pt] = conv2d_init(128, 128, Hf=3, Wf=3, 42);
+  [W8_pt, b8_pt] = conv2d_init(128, 128, Hf=3, Wf=3, 42);
+  [W9_pt, b9_pt] = conv2d_init(128, 128, Hf=3, Wf=3, 42);
+  [W10_pt, b10_pt] = conv2d_init(256, 128, Hf=3, Wf=3, 42);
+  [W11_pt, b11_pt] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W12_pt, b12_pt] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W13_pt, b13_pt] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W14_pt, b14_pt] = conv2d_init(512, 256, Hf=3, Wf=3, 42);
+  [W15_pt, b15_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W16_pt, b16_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W17_pt, b17_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W18_pt, b18_pt] = affine_init(512, K, 42);
+  W18_pt = W18_pt/sqrt(2);
+
+  # Initialize the weights for the non-transferred layers. All of them use
+  # seed 43 (the transferred layers use 42), like the VGG and AlexNet models.
+  # With a shared seed, a layer's "initialized" weights would presumably be
+  # the same random draw as its transferred weights, making the transfer a
+  # no-op for that layer.
+  [W1_init, b1_init] = conv2d_init(64, C, Hf=7, Wf=7, 43);
+  [W2_init, b2_init] = conv2d_init(64, 64, Hf=3, Wf=3, 43);
+  [W3_init, b3_init] = conv2d_init(64, 64, Hf=3, Wf=3, 43);
+  [W4_init, b4_init] = conv2d_init(64, 64, Hf=3, Wf=3, 43);
+  [W5_init, b5_init] = conv2d_init(64, 64, Hf=3, Wf=3, 43);
+  [W6_init, b6_init] = conv2d_init(128, 64, Hf=3, Wf=3, 43);
+  [W7_init, b7_init] = conv2d_init(128, 128, Hf=3, Wf=3, 43);
+  [W8_init, b8_init] = conv2d_init(128, 128, Hf=3, Wf=3, 43);
+  [W9_init, b9_init] = conv2d_init(128, 128, Hf=3, Wf=3, 43);
+  [W10_init, b10_init] = conv2d_init(256, 128, Hf=3, Wf=3, 43);
+  [W11_init, b11_init] = conv2d_init(256, 256, Hf=3, Wf=3, 43);
+  [W12_init, b12_init] = conv2d_init(256, 256, Hf=3, Wf=3, 43);
+  [W13_init, b13_init] = conv2d_init(256, 256, Hf=3, Wf=3, 43);
+  [W14_init, b14_init] = conv2d_init(512, 256, Hf=3, Wf=3, 43);
+  [W15_init, b15_init] = conv2d_init(512, 512, Hf=3, Wf=3, 43);
+  [W16_init, b16_init] = conv2d_init(512, 512, Hf=3, Wf=3, 43);
+  [W17_init, b17_init] = conv2d_init(512, 512, Hf=3, Wf=3, 43);
+  [W18_init, b18_init] = affine_init(512, K, 43);
+  W18_init = W18_init/sqrt(2);
+
+  # Compute prediction over mini-batches
+  N = nrow(X);
+  Y_pred = matrix(0, rows=N, cols=3);
+  batch_size = 64;
+  oneVec = matrix(1, rows=1, cols=K);
+  idxSeq = matrix(1, rows=batch_size, cols=1) %*% t(seq(1, K));
+  iters = ceil (N / batch_size);
+
+  for (i in 1:iters) {
+    # Get next batch
+    beg = ((i-1) * batch_size) %% N + 1;
+    end = min(N, beg+batch_size-1);
+    X_batch = X[beg:end,];
+    # The last batch may hold fewer than batch_size rows; trim the index
+    # matrix so its row count matches the batch in rwRowIndexMax.
+    idxSeq_b = idxSeq[1:(end-beg+1),];
+
+    # Three extraction settings: fel = 10, 9, 8 transfers layers 1-9, 1-8, 1-7
+    j = 1;
+    fel = 10;
+    while (j < 4) {
+      # Compute forward pass
+      # Layer1: conv2d 7x7 -> bn -> relu -> maxpool 3x3
+      lid = 1;
+      [Wl1, bl1] = getWeights(fel, lid, W1_pt, b1_pt, W1_init, b1_init);
+      [outc1, Houtc1, Woutc1] = conv2d_forward(X_batch,Wl1,bl1,C,Hin,Win,7,7,2,2,3,3);
+      outb1 = bn2d_forward(outc1,64,Houtc1,Woutc1,mu_bn,ep_bn);
+      outr1 = relu_forward(outb1);
+      [outp1, Houtp1, Woutp1] = max_pool2d_forward(outr1,64,Houtc1,Woutc1,3,3,2,2,1,1);
+
+      # Layer2: residual block1
+      lid = 2;
+      [Wc1, bc1] = getWeights(fel, lid, W2_pt, b2_pt, W2_init, b2_init);
+      [Wc2, bc2] = getWeights(fel, lid, W3_pt, b3_pt, W3_init, b3_init);
+      [outrb1, Houtrb1, Woutrb1] = basic_block(outp1,64,64,Houtp1,Woutp1,1,1,Wc1,bc1,Wc2,bc2);
+
+      # Layer3: residual block2
+      lid = 3;
+      [Wc1, bc1] = getWeights(fel, lid, W4_pt, b4_pt, W4_init, b4_init);
+      [Wc2, bc2] = getWeights(fel, lid, W5_pt, b5_pt, W5_init, b5_init);
+      [outrb2, Houtrb2, Woutrb2] = basic_block(outrb1,64,64,Houtrb1,Woutrb1,1,1,Wc1,bc1,Wc2,bc2);
+
+      # Layer4: residual block3 (downsampling, 64 -> 128 channels)
+      lid = 4;
+      [Wc1, bc1] = getWeights(fel, lid, W6_pt, b6_pt, W6_init, b6_init);
+      [Wc2, bc2] = getWeights(fel, lid, W7_pt, b7_pt, W7_init, b7_init);
+      [outrb3, Houtrb3, Woutrb3] = basic_block(outrb2,64,128,Houtrb2,Woutrb2,2,2,Wc1,bc1,Wc2,bc2);
+
+      # Layer5: residual block4
+      lid = 5;
+      [Wc1, bc1] = getWeights(fel, lid, W8_pt, b8_pt, W8_init, b8_init);
+      [Wc2, bc2] = getWeights(fel, lid, W9_pt, b9_pt, W9_init, b9_init);
+      [outrb4, Houtrb4, Woutrb4] = basic_block(outrb3,128,128,Houtrb3,Woutrb3,1,1,Wc1,bc1,Wc2,bc2);
+
+      # Layer6: residual block5 (downsampling, 128 -> 256 channels)
+      lid = 6;
+      [Wc1, bc1] = getWeights(fel, lid, W10_pt, b10_pt, W10_init, b10_init);
+      [Wc2, bc2] = getWeights(fel, lid, W11_pt, b11_pt, W11_init, b11_init);
+      [outrb5, Houtrb5, Woutrb5] = basic_block(outrb4,128,256,Houtrb4,Woutrb4,2,2,Wc1,bc1,Wc2,bc2);
+
+      # Layer7: residual block6
+      lid = 7;
+      [Wc1, bc1] = getWeights(fel, lid, W12_pt, b12_pt, W12_init, b12_init);
+      [Wc2, bc2] = getWeights(fel, lid, W13_pt, b13_pt, W13_init, b13_init);
+      [outrb6, Houtrb6, Woutrb6] = basic_block(outrb5,256,256,Houtrb5,Woutrb5,1,1,Wc1,bc1,Wc2,bc2);
+
+      # Layer8: residual block7 (downsampling, 256 -> 512 channels)
+      lid = 8;
+      [Wc1, bc1] = getWeights(fel, lid, W14_pt, b14_pt, W14_init, b14_init);
+      [Wc2, bc2] = getWeights(fel, lid, W15_pt, b15_pt, W15_init, b15_init);
+      [outrb7, Houtrb7, Woutrb7] = basic_block(outrb6,256,512,Houtrb6,Woutrb6,2,2,Wc1,bc1,Wc2,bc2);
+
+      # Layer9: residual block8
+      lid = 9;
+      [Wc1, bc1] = getWeights(fel, lid, W16_pt, b16_pt, W16_init, b16_init);
+      [Wc2, bc2] = getWeights(fel, lid, W17_pt, b17_pt, W17_init, b17_init);
+      [outrb8, Houtrb8, Woutrb8] = basic_block(outrb7,512,512,Houtrb7,Woutrb7,1,1,Wc1,bc1,Wc2,bc2);
+
+      # Global average pooling
+      [outap1, Houtap1, Woutap1] = avg_pool2d_forward(outrb8, 512, Houtrb8, Woutrb8);
+
+      # layer10 : Fully connected layer
+      lid = 10;
+      [Wl10, bl10] = getWeights(fel, lid, W18_pt, b18_pt, W18_init, b18_init);
+      outa1 = affine_forward(outap1, Wl10, bl10);
+      probs_batch = softmax_forward(outa1);
+
+      # Store the predictions
+      Y_pred[beg:end,j] = rwRowIndexMax(probs_batch, oneVec, idxSeq_b);
+      j = j + 1;
+      fel = fel - 1;
+    }
+  }
+}
+
+
+# Exploratory feature extraction from a VGG16 model.
+# For each mini-batch of up to 64 rows, the forward pass runs three times with
+# a decreasing feature-extraction boundary fel = 8, 7, 6. Layer groups with
+# lid < fel use the transferred weights (*_pt) and the others use freshly
+# initialized weights (*_init), via getWeights. Column j of the N x 3 result
+# holds the predicted class index (1..K) of the j-th setting.
+# `dim` is the input image side length. Only 32 and 224 are supported,
+# because it fixes the input width of the first fully connected layer.
+predict_vgg = function(matrix[double] X, int C, int Hin, int Win, int K, int dim)
+  return (matrix[double] Y_pred)
+{
+  # Fail early on an unsupported size instead of hitting an undefined
+  # W14_pt / W14_init further down.
+  if (dim != 224 & dim != 32)
+    stop("predict_vgg: unsupported input dimension " + dim + ", expected 32 or 224");
+
+  # Get the transferred layers. FIXME: use pretrained weights
+  [W1_pt, b1_pt] = conv2d_init(64, C, Hf=3, Wf=3, 42);
+  [W2_pt, b2_pt] = conv2d_init(64, 64, Hf=3, Wf=3, 42);
+  [W3_pt, b3_pt] = conv2d_init(128, 64, Hf=3, Wf=3, 42);
+  [W4_pt, b4_pt] = conv2d_init(128, 128, Hf=3, Wf=3, 42);
+  [W5_pt, b5_pt] = conv2d_init(256, 128, Hf=3, Wf=3, 42);
+  [W6_pt, b6_pt] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W7_pt, b7_pt] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W8_pt, b8_pt] = conv2d_init(512, 256, Hf=3, Wf=3, 42);
+  [W9_pt, b9_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W10_pt, b10_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W11_pt, b11_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W12_pt, b12_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  [W13_pt, b13_pt] = conv2d_init(512, 512, Hf=3, Wf=3, 42);
+  # First FC input width: 512 channels x 7 x 7 (224 input) or x 1 x 1 (32 input)
+  if (dim == 224)
+    [W14_pt, b14_pt] = affine_init(25088, 4096, 42);
+  if (dim == 32)
+    [W14_pt, b14_pt] = affine_init(512, 4096, 42);
+  [W15_pt, b15_pt] = affine_init(4096, 4096, 42);
+  [W16_pt, b16_pt] = affine_init(4096, K, 42);
+  W16_pt = W16_pt/sqrt(2);
+
+  # Initialize the weights for the non-transferred layers
+  [W1_init, b1_init] = conv2d_init(64, C, Hf=3, Wf=3, 43);
+  [W2_init, b2_init] = conv2d_init(64, 64, Hf=3, Wf=3, 43);
+  [W3_init, b3_init] = conv2d_init(128, 64, Hf=3, Wf=3, 43);
+  [W4_init, b4_init] = conv2d_init(128, 128, Hf=3, Wf=3, 43);
+  [W5_init, b5_init] = conv2d_init(256, 128, Hf=3, Wf=3, 43);
+  [W6_init, b6_init] = conv2d_init(256, 256, Hf=3, Wf=3, 43);
+  [W7_init, b7_init] = conv2d_init(256, 256, Hf=3, Wf=3, 43);
+  [W8_init, b8_init] = conv2d_init(512, 256, Hf=3, Wf=3, 43);
+  [W9_init, b9_init] = conv2d_init(512, 512, Hf=3, Wf=3, 43);
+  [W10_init, b10_init] = conv2d_init(512, 512, Hf=3, Wf=3, 43);
+  [W11_init, b11_init] = conv2d_init(512, 512, Hf=3, Wf=3, 43);
+  [W12_init, b12_init] = conv2d_init(512, 512, Hf=3, Wf=3, 43);
+  [W13_init, b13_init] = conv2d_init(512, 512, Hf=3, Wf=3, 43);
+  if (dim == 224)
+    [W14_init, b14_init] = affine_init(25088, 4096, 43);
+  if (dim == 32)
+    [W14_init, b14_init] = affine_init(512, 4096, 43);
+  [W15_init, b15_init] = affine_init(4096, 4096, 43);
+  [W16_init, b16_init] = affine_init(4096, K, 43);
+  W16_init = W16_init/sqrt(2);
+
+  # Compute prediction over mini-batches
+  N = nrow(X);
+  Y_pred = matrix(0, rows=N, cols=3);
+  batch_size = 64;
+  oneVec = matrix(1, rows=1, cols=K);
+  idxSeq = matrix(1, rows=batch_size, cols=1) %*% t(seq(1, K));
+  iters = ceil (N / batch_size);
+
+  for (i in 1:iters) {
+    # Get next batch
+    beg = ((i-1) * batch_size) %% N + 1;
+    end = min(N, beg+batch_size-1);
+    X_batch = X[beg:end,];
+    # The last batch may hold fewer than batch_size rows; trim the index
+    # matrix so its row count matches the batch in rwRowIndexMax.
+    idxSeq_b = idxSeq[1:(end-beg+1),];
+
+    # Three extraction settings: fel = 8, 7, 6 transfers layers 1-7, 1-6, 1-5
+    j = 1;
+    fel = 8;
+    while (j < 4) {
+      # Compute forward pass
+      # layer 1: Two conv2d layers (w/ activation relu) + 1 max-pooling layer
+      lid = 1;
+      [Wl1, bl1] = getWeights(fel, lid, W1_pt, b1_pt, W1_init, b1_init);
+      [outc1, Houtc1, Woutc1] = conv2d_forward(X_batch,Wl1,bl1,C,Hin,Win,3,3,1,1,1,1);
+      outr1 = relu_forward(outc1);
+      [Wl2, bl2] = getWeights(fel, lid, W2_pt, b2_pt, W2_init, b2_init);
+      [outc2, Houtc2, Woutc2] = conv2d_forward(outr1,Wl2,bl2,64,Houtc1,Woutc1,3,3,1,1,1,1);
+      outr2 = relu_forward(outc2);
+      [outp1, Houtp1, Woutp1] = max_pool2d_forward(outr2,64,Houtc2,Woutc2,2,2,2,2,0,0);
+
+      # layer 2: Two conv2d layers (w/ activation relu) + 1 max-pooling layer
+      lid = 2;
+      [Wl3, bl3] = getWeights(fel, lid, W3_pt, b3_pt, W3_init, b3_init);
+      [outc3, Houtc3, Woutc3] = conv2d_forward(outp1,Wl3,bl3,64,Houtp1,Woutp1,3,3,1,1,1,1);
+      outr3 = relu_forward(outc3);
+      [Wl4, bl4] = getWeights(fel, lid, W4_pt, b4_pt, W4_init, b4_init);
+      [outc4, Houtc4, Woutc4] = conv2d_forward(outr3,Wl4,bl4,128,Houtc3,Woutc3,3,3,1,1,1,1);
+      outr4 = relu_forward(outc4);
+      [outp2, Houtp2, Woutp2] = max_pool2d_forward(outr4,128,Houtc4,Woutc4,2,2,2,2,0,0);
+
+      # layer 3: Three conv2d layers (w/ activation relu) + 1 max-pooling layer
+      lid = 3;
+      [Wl5, bl5] = getWeights(fel, lid, W5_pt, b5_pt, W5_init, b5_init);
+      [outc5, Houtc5, Woutc5] = conv2d_forward(outp2,Wl5,bl5,128,Houtp2,Woutp2,3,3,1,1,1,1);
+      outr5 = relu_forward(outc5);
+      [Wl6, bl6] = getWeights(fel, lid, W6_pt, b6_pt, W6_init, b6_init);
+      [outc6, Houtc6, Woutc6] = conv2d_forward(outr5,Wl6,bl6,256,Houtc5,Woutc5,3,3,1,1,1,1);
+      outr6 = relu_forward(outc6);
+      [Wl7, bl7] = getWeights(fel, lid, W7_pt, b7_pt, W7_init, b7_init);
+      [outc7, Houtc7, Woutc7] = conv2d_forward(outr6,Wl7,bl7,256,Houtc6,Woutc6,3,3,1,1,1,1);
+      outr7 = relu_forward(outc7);
+      [outp3, Houtp3, Woutp3] = max_pool2d_forward(outr7,256,Houtc7,Woutc7,2,2,2,2,0,0);
+
+      # layer 4: Three conv2d layers (w/ activation relu) + 1 max-pooling layer
+      lid = 4;
+      [Wl8, bl8] = getWeights(fel, lid, W8_pt, b8_pt, W8_init, b8_init);
+      [outc8, Houtc8, Woutc8] = conv2d_forward(outp3,Wl8,bl8,256,Houtp3,Woutp3,3,3,1,1,1,1);
+      outr8 = relu_forward(outc8);
+      [Wl9, bl9] = getWeights(fel, lid, W9_pt, b9_pt, W9_init, b9_init);
+      [outc9, Houtc9, Woutc9] = conv2d_forward(outr8,Wl9,bl9,512,Houtc8,Woutc8,3,3,1,1,1,1);
+      outr9 = relu_forward(outc9);
+      [Wl10, bl10] = getWeights(fel, lid, W10_pt, b10_pt, W10_init, b10_init);
+      [outc10, Houtc10, Woutc10] = conv2d_forward(outr9,Wl10,bl10,512,Houtc9,Woutc9,3,3,1,1,1,1);
+      outr10 = relu_forward(outc10);
+      [outp4, Houtp4, Woutp4] = max_pool2d_forward(outr10,512,Houtc10,Woutc10,2,2,2,2,0,0);
+
+      # layer 5: Three conv2d layers (w/ activation relu) + 1 max-pooling layer
+      lid = 5;
+      [Wl11, bl11] = getWeights(fel, lid, W11_pt, b11_pt, W11_init, b11_init);
+      [outc11, Houtc11, Woutc11] = conv2d_forward(outp4,Wl11,bl11,512,Houtp4,Woutp4,3,3,1,1,1,1);
+      outr11 = relu_forward(outc11);
+      [Wl12, bl12] = getWeights(fel, lid, W12_pt, b12_pt, W12_init, b12_init);
+      [outc12, Houtc12, Woutc12] = conv2d_forward(outr11,Wl12,bl12,512,Houtc11,Woutc11,3,3,1,1,1,1);
+      outr12 = relu_forward(outc12);
+      [Wl13, bl13] = getWeights(fel, lid, W13_pt, b13_pt, W13_init, b13_init);
+      [outc13, Houtc13, Woutc13] = conv2d_forward(outr12,Wl13,bl13,512,Houtc12,Woutc12,3,3,1,1,1,1);
+      outr13 = relu_forward(outc13);
+      [outp5, Houtp5, Woutp5] = max_pool2d_forward(outr13,512,Houtc13,Woutc13,2,2,2,2,0,0);
+
+      # layer 6: Fully connected layer (w/ activation relu)
+      lid = 6;
+      [Wl14, bl14] = getWeights(fel, lid, W14_pt, b14_pt, W14_init, b14_init);
+      outa6 = affine_forward(outp5, Wl14, bl14);
+      outr6 = relu_forward(outa6);
+
+      # layer 7: Fully connected layer (w/ activation relu)
+      lid = 7;
+      [Wl15, bl15] = getWeights(fel, lid, W15_pt, b15_pt, W15_init, b15_init);
+      outa7 = affine_forward(outr6, Wl15, bl15);
+      outr7 = relu_forward(outa7);
+
+      # layer 8: Fully connected layer (w/ activation softmax)
+      lid = 8;
+      [Wl16, bl16] = getWeights(fel, lid, W16_pt, b16_pt, W16_init, b16_init);
+      outa8 = affine_forward(outr7, Wl16, bl16);
+      probs_batch = softmax_forward(outa8);
+
+      # Store the predictions
+      Y_pred[beg:end,j] = rwRowIndexMax(probs_batch, oneVec, idxSeq_b);
+      j = j + 1;
+      fel = fel - 1;
+    }
+  }
+}
+
+# Exploratory feature extraction from an AlexNet model (the driver uses it for
+# 224 x 224 inputs; the 6400-wide first FC layer matches a 256 x 5 x 5 pool3
+# output). For each mini-batch of up to 64 rows, the forward pass runs four
+# times with a decreasing feature-extraction boundary fel = 8, 7, 6, 5. Layers
+# with lid < fel use the transferred weights (*_pt) and the others use freshly
+# initialized weights (*_init), via getWeights. Column j of the N x 4 result
+# holds the predicted class index (1..K) of the j-th setting.
+# Setting verbose = TRUE prints per-layer sums and shapes for debugging.
+predict_alex = function(matrix[double] X, int C, int Hin, int Win, int K)
+  return (matrix[double] Y_pred)
+{
+  # Get the transferred layers. FIXME: use pretrained weights
+  [W1_pt, b1_pt] = conv2d_init(96, C, Hf=11, Wf=11, 42);
+  [W2_pt, b2_pt] = conv2d_init(256, 96, Hf=5, Wf=5, 42);
+  [W3_pt, b3_pt] = conv2d_init(384, 256, Hf=3, Wf=3, 42);
+  [W4_pt, b4_pt] = conv2d_init(384, 384, Hf=3, Wf=3, 42);
+  [W5_pt, b5_pt] = conv2d_init(256, 384, Hf=3, Wf=3, 42);
+  [W6_pt, b6_pt] = affine_init(6400, 4096, 42);
+  [W7_pt, b7_pt] = affine_init(4096, 4096, 42);
+  [W8_pt, b8_pt] = affine_init(4096, K, 42);
+  W8_pt = W8_pt/sqrt(2);
+
+  # Initialize the weights for the non-transferred layers
+  [W1_init, b1_init] = conv2d_init(96, C, Hf=11, Wf=11, 43);
+  [W2_init, b2_init] = conv2d_init(256, 96, Hf=5, Wf=5, 43);
+  [W3_init, b3_init] = conv2d_init(384, 256, Hf=3, Wf=3, 43);
+  [W4_init, b4_init] = conv2d_init(384, 384, Hf=3, Wf=3, 43);
+  [W5_init, b5_init] = conv2d_init(256, 384, Hf=3, Wf=3, 43);
+  [W6_init, b6_init] = affine_init(6400, 4096, 43);
+  [W7_init, b7_init] = affine_init(4096, 4096, 43);
+  [W8_init, b8_init] = affine_init(4096, K, 43);
+  W8_init = W8_init/sqrt(2);
+
+  # Compute prediction over mini-batches
+  N = nrow(X);
+  verbose = FALSE;
+  Y_pred = matrix(0, rows=N, cols=4);
+  batch_size = 64;
+  oneVec = matrix(1, rows=1, cols=K);
+  idxSeq = matrix(1, rows=batch_size, cols=1) %*% t(seq(1, K));
+  iters = ceil (N / batch_size);
+  for (i in 1:iters) {
+    # Get next batch
+    beg = ((i-1) * batch_size) %% N + 1;
+    end = min(N, beg+batch_size-1);
+    X_batch = X[beg:end,];
+    # The last batch may hold fewer than batch_size rows; trim the index
+    # matrix so its row count matches the batch in rwRowIndexMax.
+    idxSeq_b = idxSeq[1:(end-beg+1),];
+
+    # Four extraction settings: fel = 8, 7, 6, 5 transfers layers 1-7 .. 1-4
+    j = 1;
+    fel = 8;
+    while (j < 5) {
+      # Compute forward pass
+      # layer 1: conv1 -> relu1 -> pool1
+      lid = 1;
+      [Wl1, bl1] = getWeights(fel, lid, W1_pt, b1_pt, W1_init, b1_init);
+      [outc1, Houtc1, Woutc1] = conv2d_forward(X_batch,Wl1,bl1,C,Hin,Win,11,11,4,4,0,0);
+      if(verbose) print("sum(conv1) = "+sum(outc1));
+      if(verbose) print(nrow(outc1)+", "+ncol(outc1));
+      outr1 = relu_forward(outc1);
+      [outp1, Houtp1, Woutp1] = max_pool2d_forward(outr1,96,Houtc1,Woutc1,3,3,2,2,0,0);
+      if(verbose) print("sum(pool1) = "+sum(outp1));
+      if(verbose) print(nrow(outp1)+", "+ncol(outp1));
+
+      # layer 2: conv2 -> relu2 -> pool2
+      lid = 2;
+      [Wl2, bl2] = getWeights(fel, lid, W2_pt, b2_pt, W2_init, b2_init);
+      [outc2, Houtc2, Woutc2] = conv2d_forward(outp1,Wl2,bl2,96,Houtp1,Woutp1,5,5,1,1,2,2);
+      if(verbose) print("sum(conv2) = "+sum(outc2));
+      if(verbose) print(nrow(outc2)+", "+ncol(outc2));
+      outr2 = relu_forward(outc2);
+      [outp2, Houtp2, Woutp2] = max_pool2d_forward(outr2,256,Houtc2,Woutc2,3,3,2,2,0,0);
+      if(verbose) print("sum(pool2) = "+sum(outp2));
+      if(verbose) print(nrow(outp2)+", "+ncol(outp2));
+
+      # layer 3: conv3 -> relu3
+      lid = 3;
+      [Wl3, bl3] = getWeights(fel, lid, W3_pt, b3_pt, W3_init, b3_init);
+      [outc3, Houtc3, Woutc3] = conv2d_forward(outp2,Wl3,bl3,256,Houtp2,Woutp2,3,3,1,1,1,1);
+      if(verbose) print("sum(conv3) = "+sum(outc3));
+      if(verbose) print(nrow(outc3)+", "+ncol(outc3));
+      outr3 = relu_forward(outc3);
+
+      # layer 4: conv4 -> relu4
+      lid = 4;
+      [Wl4, bl4] = getWeights(fel, lid, W4_pt, b4_pt, W4_init, b4_init);
+      [outc4, Houtc4, Woutc4] = conv2d_forward(outr3,Wl4,bl4,384,Houtc3,Woutc3,3,3,1,1,1,1);
+      if(verbose) print("sum(conv4) = "+sum(outc4));
+      if(verbose) print(nrow(outc4)+", "+ncol(outc4));
+      outr4 = relu_forward(outc4);
+
+      # layer 5: conv5 -> relu5 -> pool3
+      lid = 5;
+      [Wl5, bl5] = getWeights(fel, lid, W5_pt, b5_pt, W5_init, b5_init);
+      [outc5, Houtc5, Woutc5] = conv2d_forward(outr4,Wl5,bl5,384,Houtc4,Woutc4,3,3,1,1,1,1);
+      if(verbose) print("sum(conv5) = "+sum(outc5));
+      if(verbose) print(nrow(outc5)+", "+ncol(outc5));
+      outr5 = relu_forward(outc5);
+      [outp5, Houtp5, Woutp5] = max_pool2d_forward(outr5,256,Houtc5,Woutc5,3,3,2,2,0,0);
+      if(verbose) print("sum(pool3) = "+sum(outp5));
+      if(verbose) print(nrow(outp5)+", "+ncol(outp5));
+
+      # layer 6: affine1 -> relu6
+      lid = 6;
+      [Wl6, bl6] = getWeights(fel, lid, W6_pt, b6_pt, W6_init, b6_init);
+      outa6 = affine_forward(outp5, Wl6, bl6);
+      if(verbose) print(nrow(outa6)+", "+ncol(outa6));
+      outr6 = relu_forward(outa6);
+
+      # layer 7: affine2 -> relu7
+      lid = 7;
+      [Wl7, bl7] = getWeights(fel, lid, W7_pt, b7_pt, W7_init, b7_init);
+      outa7 = affine_forward(outr6, Wl7, bl7);
+      if(verbose) print(nrow(outa7)+", "+ncol(outa7));
+      outr7 = relu_forward(outa7);
+
+      # layer 8: affine3 -> softmax
+      lid = 8;
+      [Wl8, bl8] = getWeights(fel, lid, W8_pt, b8_pt, W8_init, b8_init);
+      outa8 = affine_forward(outr7, Wl8, bl8);
+      if(verbose) print(nrow(outa8)+", "+ncol(outa8));
+      probs_batch = softmax_forward(outa8);
+
+      # Store the predicted classes
+      Y_pred[beg:end,j] = rwRowIndexMax(probs_batch, oneVec, idxSeq_b);
+      j = j + 1;
+      fel = fel - 1;
+    }
+  }
+}
+
+# Exploratory feature extraction from an AlexNet variant (the driver uses it
+# for 32 x 32 inputs; the 256-wide first FC layer matches a 256 x 1 x 1 pool3
+# output). For each mini-batch of up to 64 rows, the forward pass runs four
+# times with a decreasing feature-extraction boundary fel = 8, 7, 6, 5. Layers
+# with lid < fel use the transferred weights (*_pt) and the others use freshly
+# initialized weights (*_init), via getWeights. Column j of the N x 4 result
+# holds the predicted class index (1..K) of the j-th setting.
+# Setting verbose = TRUE prints per-layer sums and shapes for debugging.
+predict_alex_32 = function(matrix[double] X, int C, int Hin, int Win, int K)
+  return (matrix[double] Y_pred)
+{
+  # Get the transferred layers. FIXME: use pretrained weights
+  [W1_pt, b1_pt] = conv2d_init(64, C, Hf=11, Wf=11, 42);
+  [W2_pt, b2_pt] = conv2d_init(192, 64, Hf=5, Wf=5, 42);
+  [W3_pt, b3_pt] = conv2d_init(384, 192, Hf=3, Wf=3, 42);
+  [W4_pt, b4_pt] = conv2d_init(256, 384, Hf=3, Wf=3, 42);
+  [W5_pt, b5_pt] = conv2d_init(256, 256, Hf=3, Wf=3, 42);
+  [W6_pt, b6_pt] = affine_init(256, 4096, 42);
+  [W7_pt, b7_pt] = affine_init(4096, 4096, 42);
+  [W8_pt, b8_pt] = affine_init(4096, K, 42);
+  W8_pt = W8_pt/sqrt(2);
+
+  # Initialize the weights for the non-transferred layers
+  [W1_init, b1_init] = conv2d_init(64, C, Hf=11, Wf=11, 43);
+  [W2_init, b2_init] = conv2d_init(192, 64, Hf=5, Wf=5, 43);
+  [W3_init, b3_init] = conv2d_init(384, 192, Hf=3, Wf=3, 43);
+  [W4_init, b4_init] = conv2d_init(256, 384, Hf=3, Wf=3, 43);
+  [W5_init, b5_init] = conv2d_init(256, 256, Hf=3, Wf=3, 43);
+  [W6_init, b6_init] = affine_init(256, 4096, 43);
+  [W7_init, b7_init] = affine_init(4096, 4096, 43);
+  [W8_init, b8_init] = affine_init(4096, K, 43);
+  W8_init = W8_init/sqrt(2);
+
+  # Compute prediction over mini-batches
+  N = nrow(X);
+  verbose = FALSE;
+  Y_pred = matrix(0, rows=N, cols=4);
+  batch_size = 64;
+  oneVec = matrix(1, rows=1, cols=K);
+  idxSeq = matrix(1, rows=batch_size, cols=1) %*% t(seq(1, K));
+  iters = ceil (N / batch_size);
+  for (i in 1:iters) {
+    # Get next batch
+    beg = ((i-1) * batch_size) %% N + 1;
+    end = min(N, beg+batch_size-1);
+    X_batch = X[beg:end,];
+    # The last batch may hold fewer than batch_size rows; trim the index
+    # matrix so its row count matches the batch in rwRowIndexMax.
+    idxSeq_b = idxSeq[1:(end-beg+1),];
+
+    # Four extraction settings: fel = 8, 7, 6, 5 transfers layers 1-7 .. 1-4
+    j = 1;
+    fel = 8;
+    while (j < 5) {
+      # Compute forward pass
+      # layer 1: conv1 -> relu1 -> pool1
+      lid = 1;
+      [Wl1, bl1] = getWeights(fel, lid, W1_pt, b1_pt, W1_init, b1_init);
+      [outc1, Houtc1, Woutc1] = conv2d_forward(X_batch,Wl1,bl1,C,Hin,Win,11,11,4,4,2,2);
+      if(verbose) print("sum(conv1) = "+sum(outc1));
+      if(verbose) print(nrow(outc1)+", "+ncol(outc1));
+      outr1 = relu_forward(outc1);
+      [outp1, Houtp1, Woutp1] = max_pool2d_forward(outr1,64,Houtc1,Woutc1,3,3,2,2,0,0);
+      if(verbose) print("sum(pool1) = "+sum(outp1));
+      if(verbose) print(nrow(outp1)+", "+ncol(outp1));
+
+      # layer 2: conv2 -> relu2 -> pool2
+      lid = 2;
+      [Wl2, bl2] = getWeights(fel, lid, W2_pt, b2_pt, W2_init, b2_init);
+      [outc2, Houtc2, Woutc2] = conv2d_forward(outp1,Wl2,bl2,64,Houtp1,Woutp1,5,5,1,1,2,2);
+      if(verbose) print("sum(conv2) = "+sum(outc2));
+      if(verbose) print(nrow(outc2)+", "+ncol(outc2));
+      outr2 = relu_forward(outc2);
+      [outp2, Houtp2, Woutp2] = max_pool2d_forward(outr2,192,Houtc2,Woutc2,3,3,2,2,0,0);
+      if(verbose) print("sum(pool2) = "+sum(outp2));
+      if(verbose) print(nrow(outp2)+", "+ncol(outp2));
+
+      # layer 3: conv3 -> relu3
+      lid = 3;
+      [Wl3, bl3] = getWeights(fel, lid, W3_pt, b3_pt, W3_init, b3_init);
+      [outc3, Houtc3, Woutc3] = conv2d_forward(outp2,Wl3,bl3,192,Houtp2,Woutp2,3,3,1,1,1,1);
+      if(verbose) print("sum(conv3) = "+sum(outc3));
+      if(verbose) print(nrow(outc3)+", "+ncol(outc3));
+      outr3 = relu_forward(outc3);
+
+      # layer 4: conv4 -> relu4
+      lid = 4;
+      [Wl4, bl4] = getWeights(fel, lid, W4_pt, b4_pt, W4_init, b4_init);
+      [outc4, Houtc4, Woutc4] = conv2d_forward(outr3,Wl4,bl4,384,Houtc3,Woutc3,3,3,1,1,1,1);
+      if(verbose) print("sum(conv4) = "+sum(outc4));
+      if(verbose) print(nrow(outc4)+", "+ncol(outc4));
+      outr4 = relu_forward(outc4);
+
+      # layer 5: conv5 -> relu5 -> pool3
+      lid = 5;
+      [Wl5, bl5] = getWeights(fel, lid, W5_pt, b5_pt, W5_init, b5_init);
+      [outc5, Houtc5, Woutc5] = conv2d_forward(outr4,Wl5,bl5,256,Houtc4,Woutc4,3,3,1,1,1,1);
+      if(verbose) print("sum(conv5) = "+sum(outc5));
+      if(verbose) print(nrow(outc5)+", "+ncol(outc5));
+      outr5 = relu_forward(outc5);
+      [outp5, Houtp5, Woutp5] = max_pool2d_forward(outr5,256,Houtc5,Woutc5,3,3,2,2,1,1);
+      if(verbose) print("sum(pool3) = "+sum(outp5));
+      if(verbose) print(nrow(outp5)+", "+ncol(outp5));
+
+      # layer 6: affine1 -> relu6
+      lid = 6;
+      [Wl6, bl6] = getWeights(fel, lid, W6_pt, b6_pt, W6_init, b6_init);
+      outa6 = affine_forward(outp5, Wl6, bl6);
+      if(verbose) print(nrow(outa6)+", "+ncol(outa6));
+      outr6 = relu_forward(outa6);
+
+      # layer 7: affine2 -> relu7
+      lid = 7;
+      [Wl7, bl7] = getWeights(fel, lid, W7_pt, b7_pt, W7_init, b7_init);
+      outa7 = affine_forward(outr6, Wl7, bl7);
+      if(verbose) print(nrow(outa7)+", "+ncol(outa7));
+      outr7 = relu_forward(outa7);
+
+      # layer 8: affine3 -> softmax
+      lid = 8;
+      [Wl8, bl8] = getWeights(fel, lid, W8_pt, b8_pt, W8_init, b8_init);
+      outa8 = affine_forward(outr7, Wl8, bl8);
+      if(verbose) print(nrow(outa8)+", "+ncol(outa8));
+      probs_batch = softmax_forward(outa8);
+
+      # Store the predicted classes
+      Y_pred[beg:end,j] = rwRowIndexMax(probs_batch, oneVec, idxSeq_b);
+      j = j + 1;
+      fel = fel - 1;
+    }
+  }
+}
+
+# Generate N random linearized images (C*Hin*Win standard-normal values per
+# row, seed 45) and matching one-hot labels over K classes (class ids drawn
+# by rounding uniform values in [1, K], seed 46).
+generate_dummy_data = function(int N, int C, int Hin, int Win, int K)
+  return (matrix[double] X, matrix[double] Y)
+{
+  ncells = C*Hin*Win;
+  X = rand(rows=N, cols=ncells, pdf="normal", seed=45);
+  labels = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform", seed=46));
+  rowIds = seq(1, N);
+  Y = table(rowIds, labels, N, K);
+}
+
+##########################################################################
+
+# Read training data and settings
+N = 512;     #num of images in the target dataset
+C = 3;       #num of color channels
+K = 10;      #num of classes
+dataset = "cifar";
+if (dataset == "cifar") {
+  Hin = 32;  #input image height
+}
+else if (dataset == "imagenet") {
+  Hin = 224; #input image height
+}
+Win = Hin;   #input image width (square images)
+
+# Generate dummy data
+[X, Y] = generate_dummy_data(N, C, Hin, Win, K);
+
+# Load the CuDNN libraries by calling a conv2d
+print("Eagerly loading cuDNN library");
+[W1, b1] = conv2d_init(96, C, Hf=11, Wf=11, 42);
+[outc1, Houtc1, Woutc1] = conv2d_forward(X[1:8,], W1, b1, C, Hin, Win, 11, 11, 1, 1, 2, 2);
+print(sum(outc1));
+
+print("Starting exploratory feature transfers");
+Y_pred = matrix(0, rows=N, cols=10);
+t1 = time();
+# Columns 1-4: the AlexNet variant that matches the input size
+if (Hin == 32) {
+  Y_pred[,1:4] = predict_alex_32(X, C, Hin, Win, K);
+}
+else if (Hin == 224) {
+  Y_pred[,1:4] = predict_alex(X, C, Hin, Win, K);
+}
+# Columns 5-7: VGG16; columns 8-10: ResNet-18
+Y_pred[,5:7] = predict_vgg(X, C, Hin, Win, K, Hin);
+Y_pred[,8:10] = predict_resnet18(X, C, Hin, Win, K);
+print(toString(colSums(Y_pred)));
+
+t2 = time();
+elapsed_ms = floor((t2-t1)/1000000);
+print("Elapsed time for feature transfers = "+elapsed_ms+" millsec");
+write(Y_pred, $1, format="text");
+

Reply via email to