This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new c2c8864  [SYSTEMDS-3084] Basic Auto-differentiation for Affine layer
c2c8864 is described below

commit c2c88645ac45b4106979841bbd3261ab3cc30169
Author: Shafaq Siddiqi <[email protected]>
AuthorDate: Fri Aug 6 14:14:31 2021 +0200

    [SYSTEMDS-3084] Basic Auto-differentiation for Affine layer
    
    This patch introduces the autoDiff builtin. autoDiff takes the output
    and lineage trace of the last layer and a list of weights and
    biases, and returns their derivatives.
    It internally uses the lineage trace of the forward layer to
    construct the Hop DAGs for the derivatives (reusing common
    sub-DAGs), then compiles and executes them to produce the outputs.
    Current support is limited to the Affine layer and local execution.
    
    AMLS project SS2021
    Closes #1350.
---
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 src/main/java/org/apache/sysds/common/Types.java   |   2 +-
 .../apache/sysds/hops/ParameterizedBuiltinOp.java  |   3 +-
 .../apache/sysds/lops/ParameterizedBuiltin.java    |   8 +-
 .../org/apache/sysds/parser/DMLTranslator.java     |   2 +-
 .../ParameterizedBuiltinFunctionExpression.java    |  23 +-
 .../sysds/runtime/functionobjects/Builtin.java     |   3 +-
 .../functionobjects/ParameterizedBuiltin.java      |   5 +-
 .../runtime/instructions/CPInstructionParser.java  |   3 +-
 .../runtime/instructions/SPInstructionParser.java  |   1 +
 .../cp/ParameterizedBuiltinCPInstruction.java      |  32 +--
 .../org/apache/sysds/runtime/util/AutoDiff.java    | 262 +++++++++++++++++++++
 .../sysds/test/functions/builtin/AutoDiffTest.java |  70 ++++++
 src/test/scripts/functions/builtin/autoDiff.dml    |  58 +++++
 14 files changed, 449 insertions(+), 24 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index a063b47..f2b6c6a 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -272,6 +272,7 @@ public enum Builtins {
        XOR("xor", false),
 
        //parameterized builtin functions
+       AUTODIFF("autoDiff", false, true),
        CDF("cdf", false, true),
        CVLM("cvlm", true, false),
        GROUPEDAGG("aggregate", "groupedAggregate", false, true),
diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index da53091..d15adad 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -466,7 +466,7 @@ public class Types
        }
        
        public enum ParamBuiltinOp {
-               INVALID, CDF, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND,
+               AUTODIFF, INVALID, CDF, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, 
REXPAND,
                LOWER_TRI, UPPER_TRI,
                TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA,
                TOKENIZE, TOSTRING, LIST, PARAMSERV
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java 
b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index b5be675..8c9666b 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -196,7 +196,8 @@ public class ParameterizedBuiltinOp extends 
MultiThreadedHop {
                        case TRANSFORMMETA:
                        case TOSTRING:
                        case PARAMSERV:
-                       case LIST: {
+                       case LIST:
+                       case AUTODIFF:{
                                ExecType et = optFindExecType();
                                ParameterizedBuiltin pbilop = new 
ParameterizedBuiltin(
                                        inputlops, _op, getDataType(), 
getValueType(), et);
diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java 
b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
index 7c39548..a0f9331 100644
--- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
+++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
@@ -171,7 +171,7 @@ public class ParameterizedBuiltin extends Lop
                                        
                                        sb.append(OPERAND_DELIMITOR);
                                }
-                               
+
                                break;
 
                        case TOKENIZE:
@@ -184,6 +184,12 @@ public class ParameterizedBuiltin extends Lop
                                sb.append(compileGenericParamMap(_inputParams));
                                break;
                        }
+                       case AUTODIFF: {
+                               sb.append("autoDiff"); //opcode
+                               sb.append(OPERAND_DELIMITOR);
+                               sb.append(compileGenericParamMap(_inputParams));
+                               break;
+                       }
                        case LIST: {
                                sb.append("nvlist"); //opcode
                                sb.append(OPERAND_DELIMITOR);
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 89fc4ca..430fe7f 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2008,6 +2008,7 @@ public class DMLTranslator
                        case TRANSFORMCOLMAP:
                        case TRANSFORMMETA:
                        case PARAMSERV:
+                       case AUTODIFF:
                                currBuiltinOp = new 
ParameterizedBuiltinOp(target.getName(), target.getDataType(),
                                        target.getValueType(), 
ParamBuiltinOp.valueOf(source.getOpCode().name()), paramHops);
                                break;
@@ -2029,7 +2030,6 @@ public class DMLTranslator
                                                target.getValueType(), 
ParamBuiltinOp.TOSTRING, paramHops) :
                                        
HopRewriteUtils.createBinary(paramHops.get("target"), new LiteralOp(""), 
OpOp2.PLUS);
                                break;
-                       
                        case LISTNV:
                                currBuiltinOp = new 
ParameterizedBuiltinOp(target.getName(), target.getDataType(),
                                        target.getValueType(), 
ParamBuiltinOp.LIST, paramHops);
diff --git 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index d074d0d..26a54ac 100644
--- 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -48,12 +48,14 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
        public static final String TF_FN_PARAM_DATA = "target";
        public static final String TF_FN_PARAM_MTD2 = "meta";
        public static final String TF_FN_PARAM_SPEC = "spec";
+       public static final String LINEAGE_TRACE = "lineage";
        public static final String TF_FN_PARAM_MTD = "transformPath"; //NOTE 
MB: for backwards compatibility
        
        public static HashMap<Builtins, ParamBuiltinOp> pbHopMap;
        static {
                pbHopMap = new HashMap<>();
                
+               pbHopMap.put(Builtins.AUTODIFF, ParamBuiltinOp.AUTODIFF);
                pbHopMap.put(Builtins.GROUPEDAGG, ParamBuiltinOp.GROUPEDAGG);
                pbHopMap.put(Builtins.RMEMPTY, ParamBuiltinOp.RMEMPTY);
                pbHopMap.put(Builtins.REPLACE, ParamBuiltinOp.REPLACE);
@@ -231,7 +233,10 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                case TOSTRING:
                        validateCastAsString(output, conditional);
                        break;
-               
+
+               case AUTODIFF:
+                       validateAutoDiff(output, conditional);
+                       break;
                case LISTNV:
                        validateNamedList(output, conditional);
                        break;
@@ -251,6 +256,22 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                }
        }
 
+       private void validateAutoDiff(DataIdentifier output, boolean 
conditional) {
+               //validate data / metadata (recode maps)
+               checkDataType("lineage", LINEAGE_TRACE, DataType.LIST, 
conditional);
+
+               //validate specification
+               checkDataValueType(false, "lineage", LINEAGE_TRACE, 
DataType.LIST, ValueType.UNKNOWN, conditional);
+               HashMap<String, Expression> varParams = getVarParams();
+               // set output characteristics
+               output.setDataType(DataType.LIST);
+               output.setValueType(ValueType.UNKNOWN);
+               // TODO dimensions should be set to -1, but this is not yet possible due to a 
lineage parsing error in the Spark context
+               output.setDimensions(varParams.size(), 1);
+               // output.setDimensions(-1, 1);
+               output.setBlocksize(-1);
+       }
+
        @Override
        public void validateExpression(MultiAssignmentStatement stmt, 
HashMap<String, DataIdentifier> ids, HashMap<String, ConstIdentifier> 
constVars, boolean conditional)
        {
diff --git 
a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java 
b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
index 46acba4..45b6c33 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
@@ -47,7 +47,7 @@ public class Builtin extends ValueFunction
 {
        private static final long serialVersionUID = 3836744687789840574L;
        
-       public enum BuiltinCode { SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, 
ATAN, LOG, LOG_NZ, MIN,
+       public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, 
ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN,
                MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, 
LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX,
                STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, 
INVERSE, SPROP, SIGMOID, EVAL, LIST,
                TYPEOF, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, 
DROP_INVALID_LENGTH, MAP,
@@ -61,6 +61,7 @@ public class Builtin extends ValueFunction
        static public HashMap<String, BuiltinCode> String2BuiltinCode;
        static {
                String2BuiltinCode = new HashMap<>();
+               String2BuiltinCode.put( "autoDiff"    , BuiltinCode.AUTODIFF);
                String2BuiltinCode.put( "sin"    , BuiltinCode.SIN);
                String2BuiltinCode.put( "cos"    , BuiltinCode.COS);
                String2BuiltinCode.put( "tan"    , BuiltinCode.TAN);
diff --git 
a/src/main/java/org/apache/sysds/runtime/functionobjects/ParameterizedBuiltin.java
 
b/src/main/java/org/apache/sysds/runtime/functionobjects/ParameterizedBuiltin.java
index c15f6da..d800efd 100644
--- 
a/src/main/java/org/apache/sysds/runtime/functionobjects/ParameterizedBuiltin.java
+++ 
b/src/main/java/org/apache/sysds/runtime/functionobjects/ParameterizedBuiltin.java
@@ -43,7 +43,7 @@ public class ParameterizedBuiltin extends ValueFunction
        private static final long serialVersionUID = -7987603644903675052L;
        
        public enum ParameterizedBuiltinCode { 
-               CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI,
+               AUTODIFF, CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, 
UPPER_TRI,
                TOKENIZE, TRANSFORMAPPLY, TRANSFORMDECODE, PARAMSERV }
        public enum ProbabilityDistributionCode { 
                INVALID, NORMAL, EXP, CHISQ, F, T }
@@ -185,6 +185,9 @@ public class ParameterizedBuiltin extends ValueFunction
 
                        case PARAMSERV:
                                return new 
ParameterizedBuiltin(ParameterizedBuiltinCode.PARAMSERV);
+
+                       case AUTODIFF:
+                               return new 
ParameterizedBuiltin(ParameterizedBuiltinCode.AUTODIFF);
                                
                        default:
                                throw new DMLRuntimeException("Invalid 
parameterized builtin code: " + code);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index b07cefa..323fac1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -160,7 +160,7 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "nmax", CPType.BuiltinNary);
                String2CPInstructionType.put( "nmin", CPType.BuiltinNary);
                String2CPInstructionType.put( "n+"  , CPType.BuiltinNary);
-               
+
                String2CPInstructionType.put( "exp"   , CPType.Unary);
                String2CPInstructionType.put( "abs"   , CPType.Unary);
                String2CPInstructionType.put( "sin"   , CPType.Unary);
@@ -203,6 +203,7 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "list",   CPType.BuiltinNary);
                
                // Parameterized Builtin Functions
+               String2CPInstructionType.put( "autoDiff" , 
CPType.ParameterizedBuiltin);
                String2CPInstructionType.put("paramserv",       
CPType.ParameterizedBuiltin);
                String2CPInstructionType.put( "nvlist",         
CPType.ParameterizedBuiltin);
                String2CPInstructionType.put( "cdf",            
CPType.ParameterizedBuiltin);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index d1eb4a7..a02490f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -262,6 +262,7 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "isinf", SPType.Unary);
 
                // Parameterized Builtin Functions
+               String2SPInstructionType.put( "autoDiff"   , 
SPType.ParameterizedBuiltin);
                String2SPInstructionType.put( "groupedagg",     
SPType.ParameterizedBuiltin);
                String2SPInstructionType.put( "mapgroupedagg",  
SPType.ParameterizedBuiltin);
                String2SPInstructionType.put( "rmempty",        
SPType.ParameterizedBuiltin);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index f115b52..6de5878 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -19,14 +19,6 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -37,12 +29,9 @@ import org.apache.sysds.lops.Lop;
 import org.apache.sysds.parser.ParameterizedBuiltinFunctionExpression;
 import org.apache.sysds.parser.Statement;
 import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
-import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
-import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
-import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
+import org.apache.sysds.runtime.controlprogram.caching.*;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysds.runtime.data.TensorBlock;
 import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -61,8 +50,13 @@ import 
org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
 import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
 import org.apache.sysds.runtime.transform.tokenize.Tokenizer;
 import org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;
+import org.apache.sysds.runtime.util.AutoDiff;
 import org.apache.sysds.runtime.util.DataConverter;
 
+import java.util.*;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
 public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction {
        private static final Log LOG = 
LogFactory.getLog(ParameterizedBuiltinCPInstruction.class.getName());
        private static final int TOSTRING_MAXROWS = 100;
@@ -91,7 +85,6 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
        public static LinkedHashMap<String, String> 
constructParameterMap(String[] params) {
                // process all elements in "params" except first(opcode) and 
last(output)
                LinkedHashMap<String, String> paramMap = new LinkedHashMap<>();
-
                // all parameters are of form <name=value>
                String[] parts;
                for(int i = 1; i <= params.length - 2; i++) {
@@ -150,7 +143,7 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                }
                else if(opcode.equals("transformapply") || 
opcode.equals("transformdecode") ||
                        opcode.equals("transformcolmap") || 
opcode.equals("transformmeta") || opcode.equals("tokenize") ||
-                       opcode.equals("toString") || opcode.equals("nvlist")) {
+                       opcode.equals("toString") || opcode.equals("nvlist") || 
opcode.equals("autoDiff")) {
                        return new ParameterizedBuiltinCPInstruction(null, 
paramsMap, out, opcode, str);
                }
                else if("paramserv".equals(opcode)) {
@@ -178,6 +171,13 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                        sores = new DoubleObject(result);
                        ec.setScalarOutput(output.getName(), sores);
                }
+               else if(opcode.equalsIgnoreCase("autoDiff"))
+               {
+                       ArrayList<Data> lineage = (ArrayList<Data>) 
ec.getListObject(params.get("lineage")).getData();
+                       MatrixObject mo = 
ec.getMatrixObject(params.get("output"));
+                       ListObject diffs = AutoDiff.getBackward(mo, lineage, 
ExecutionContextFactory.createContext());
+                       ec.setVariable(output.getName(), diffs);
+               }
                else if(opcode.equalsIgnoreCase("groupedagg")) {
                        // acquire locks
                        MatrixBlock target = 
ec.getMatrixInput(params.get(Statement.GAGG_TARGET));
@@ -501,7 +501,7 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                        return Pair.of(output.getName(),
                                new LineageItem(getOpcode(), 
LineageItemUtils.getLineage(ec, target, meta, spec)));
                }
-               else if (opcode.equalsIgnoreCase("nvlist")) {
+               else if (opcode.equalsIgnoreCase("nvlist") || 
opcode.equalsIgnoreCase("autoDiff")) {
                        List<String> names = new ArrayList<>(params.keySet());
                        CPOperand[] listOperands = names.stream().map(n -> 
ec.containsVariable(params.get(n)) 
                                        ? new CPOperand(n, 
ec.getVariable(params.get(n))) 
diff --git a/src/main/java/org/apache/sysds/runtime/util/AutoDiff.java 
b/src/main/java/org/apache/sysds/runtime/util/AutoDiff.java
new file mode 100644
index 0000000..2178a13
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/util/AutoDiff.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.util;
+
+import org.apache.commons.lang3.mutable.MutableInt;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.*;
+import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.parser.DataIdentifier;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.InstructionParser;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.*;
+import org.apache.sysds.runtime.instructions.spark.RandSPInstruction;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageParser;
+import org.apache.sysds.utils.Explain;
+
+import java.util.*;
+
+public class AutoDiff {
+       private static final String ADVARPREFIX = "adVar";
+       private static final boolean DEBUG = false;
+
+       public static ListObject getBackward(MatrixObject mo, ArrayList<Data> 
lineage, ExecutionContext adec) {
+
+               ArrayList<String> names = new ArrayList<String>();
+               // parse the lineage and take the number of instructions as for 
each instruction there is separate hop DAG
+               String lin = lineage.get(0).toString();
+               // get rid of foo flag
+               lin = lin.replace("foo", "");
+               List<Data>  data = parseNComputeAutoDiffFromLineage(mo, lin, 
names, adec);
+               return new ListObject(data, names);
+       }
+
+       public static List<Data> parseNComputeAutoDiffFromLineage(MatrixObject 
mo, String mainTrace,
+               ArrayList<String> names, ExecutionContext ec ) {
+
+               LineageItem root = LineageParser.parseLineageTrace(mainTrace);
+               if (DEBUG) {
+                       System.out.println("Lineage trace of the forward pass");
+                       System.out.println(mainTrace);
+               }
+               // Recursively construct hops
+               root.resetVisitStatusNR();
+               Map<Long, Hop> operands = new HashMap<>();
+               // set variable for input matrix
+               ec.setVariable("X", mo);
+               DataOp input = HopRewriteUtils.createTransientRead("X", mo);
+               // each instruction Hop is stored separately as each 
instruction creates a new differentiation
+               ArrayList<Hop> allHops = constructHopsNR(root, operands, input, 
names);
+
+               ArrayList<Data> results = new ArrayList<>();
+               for(int i=0; i< allHops.size(); i++) {
+                       DataOp dop = 
HopRewriteUtils.createTransientWrite("advar"+i, allHops.get(i));
+                       ArrayList<Instruction> dInst = Recompiler
+                               .recompileHopsDag(dop, ec.getVariables(), null, 
true, true, 0);
+                       if (DEBUG) {
+                               System.out.println("HOP Dag and instructions 
for " + names.get(i));
+                               System.out.println(Explain.explain(dop));
+                               System.out.println(Explain.explain(dInst));
+                       }
+                       // create derivative instructions
+                       executeInst(dInst, ec);
+                       results.add(ec.getVariable("advar"+i));
+               }
+               return results;
+       }
+
+       public static ArrayList<Hop> constructHopsNR(LineageItem item, 
Map<Long, Hop> operands, Hop mo, ArrayList<String> names)
+       {
+               // Hop dags for the derivatives share common sub-dags with 
+               // the lineage dag of the forward pass. This method starts 
+               // constructing the hop dag from the lineage dag, but adds 
+               // extra hops to the resulting dags as needed.
+               ArrayList<Hop>  allHops = new ArrayList<>();
+               Stack<LineageItem> stackItem = new Stack<>();
+               Stack<MutableInt> stackPos = new Stack<>();
+               stackItem.push(item); stackPos.push(new MutableInt(0));
+               while (!stackItem.empty()) {
+                       LineageItem tmpItem = stackItem.peek();
+                       MutableInt tmpPos = stackPos.peek();
+                       // check ascent condition - no item processing
+                       if (tmpItem.isVisited()) {
+                               stackItem.pop(); stackPos.pop();
+                       }
+                       // check ascent condition - append item
+                       else if( tmpItem.getInputs() == null
+                               || tmpItem.getInputs().length <= 
tmpPos.intValue() ) {
+                               constructSingleHop(tmpItem, operands, mo, 
allHops, names);
+                               stackItem.pop(); stackPos.pop();
+                               tmpItem.setVisited();
+                       }
+                       // check descent condition
+                       else if( tmpItem.getInputs() != null ) {
+                               
stackItem.push(tmpItem.getInputs()[tmpPos.intValue()]);
+                               tmpPos.increment();
+                               stackPos.push(new MutableInt(0));
+                       }
+               }
+               return allHops;
+       }
+
+       private static void constructSingleHop(LineageItem item, Map<Long, Hop> 
operands, Hop mo,
+               ArrayList<Hop> allHops, ArrayList<String> names)
+       {
+               //process current lineage item
+               switch (item.getType()) {
+                       case Creation: {
+                               if(item.getData().startsWith(ADVARPREFIX)) {
+                                       long phId = 
Long.parseLong(item.getData().substring(3));
+                                       Hop input = operands.get(phId);
+                                       operands.remove(phId);
+                                       // Replace the placeholders with TReads
+                                       operands.put(item.getId(), input); // 
order preserving
+                                       break;
+                               }
+                               Instruction inst = 
InstructionParser.parseSingleInstruction(item.getData());
+
+                               if(inst instanceof DataGenCPInstruction) {
+                                       DataGenCPInstruction rand = 
(DataGenCPInstruction) inst;
+                                       HashMap<String, Hop> params = new 
HashMap<>();
+                                       if(rand.getOpcode().equals("rand")) {
+                                               if(rand.output.getDataType() == 
Types.DataType.TENSOR)
+                                                       
params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
+                                               else {
+                                                       
params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
+                                                       
params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
+                                               }
+                                               
params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
+                                               
params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
+                                               
params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
+                                               
params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
+                                               
params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
+                                               
params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
+                                       }
+                                       Hop datagen = new 
DataGenOp(Types.OpOpDG.valueOf(rand.getOpcode().toUpperCase()),
+                                               new DataIdentifier("tmp"), 
params);
+                                       
datagen.setBlocksize(rand.getBlocksize());
+                                       operands.put(item.getId(), datagen);
+                               }
+                               else if(inst instanceof VariableCPInstruction 
&& ((VariableCPInstruction) inst).isCreateVariable()) {
+                                       String parts[] = 
InstructionUtils.getInstructionPartsWithValueType(inst.toString());
+                                       Types.DataType dt = 
Types.DataType.valueOf(parts[4]);
+                                       Types.ValueType vt = dt == 
Types.DataType.MATRIX ? Types.ValueType.FP64 : Types.ValueType.STRING;
+                                       HashMap<String, Hop> params = new 
HashMap<>();
+                                       params.put(DataExpression.IO_FILENAME, 
new LiteralOp(parts[2]));
+                                       params.put(DataExpression.READROWPARAM, 
new LiteralOp(Long.parseLong(parts[6])));
+                                       params.put(DataExpression.READCOLPARAM, 
new LiteralOp(Long.parseLong(parts[7])));
+                                       params.put(DataExpression.READNNZPARAM, 
new LiteralOp(Long.parseLong(parts[8])));
+                                       params.put(DataExpression.FORMAT_TYPE, 
new LiteralOp(parts[5]));
+                                       DataOp pread = new 
DataOp(parts[1].substring(5), dt, vt, Types.OpOpData.PERSISTENTREAD, params);
+                                       pread.setFileName(parts[2]);
+                                       operands.put(item.getId(), pread);
+                               }
+                               else if(inst instanceof RandSPInstruction) {
+                                       RandSPInstruction rand = 
(RandSPInstruction) inst;
+                                       HashMap<String, Hop> params = new 
HashMap<>();
+                                       if(rand.output.getDataType() == 
Types.DataType.TENSOR)
+                                               
params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
+                                       else {
+                                               
params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
+                                               
params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
+                                       }
+                                       params.put(DataExpression.RAND_MIN, new 
LiteralOp(rand.getMinValue()));
+                                       params.put(DataExpression.RAND_MAX, new 
LiteralOp(rand.getMaxValue()));
+                                       params.put(DataExpression.RAND_PDF, new 
LiteralOp(rand.getPdf()));
+                                       params.put(DataExpression.RAND_LAMBDA, 
new LiteralOp(rand.getPdfParams()));
+                                       
params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
+                                       params.put(DataExpression.RAND_SEED, 
new LiteralOp(rand.getSeed()));
+                                       Hop datagen = new 
DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("tmp"), params);
+                                       
datagen.setBlocksize(rand.getBlocksize());
+                                       operands.put(item.getId(), datagen);
+                               }
+                               break;
+                       }
+                       case Instruction: {
+                               CPInstruction.CPType ctype = 
InstructionUtils.getCPTypeByOpcode(item.getOpcode());
+
+                               if(ctype != null) {
+                                       switch(ctype) {
+                                               case AggregateBinary: {
+                                                       Hop input1 = 
operands.get(item.getInputs()[0].getId());
+                                                       Hop input2 = 
operands.get(item.getInputs()[1].getId());
+                                                       //Build the hops for 
the derivatives
+                                                       ReorgOp trasnX = 
HopRewriteUtils.createTranspose(input1);
+                                                       ReorgOp trasnW = 
HopRewriteUtils.createTranspose(input2);
+                                                       Hop dX = 
HopRewriteUtils.createMatrixMultiply(mo, trasnW);
+                                                       Hop dW = 
HopRewriteUtils.createMatrixMultiply(trasnX, mo);
+                                                       
operands.put(item.getId(), dX);
+                                                       
operands.put(item.getId() + 1, dW);
+                                                       allHops.add(dX);
+                                                       allHops.add(dW);
+                                                       names.add("dX");
+                                                       names.add("dW");
+                                                       break;
+                                               }
+                                               case Binary: {
+                                                       //handle special cases 
of binary operations
+                                                       String opcode = 
item.getOpcode();
+                                                       Hop output = null;
+                                                       if(opcode.equals("+"))
+                                                               output = 
HopRewriteUtils.createAggUnaryOp(mo, Types.AggOp.SUM, Types.Direction.Col);
+                                                       
operands.put(item.getId(), output);
+                                                       allHops.add(output);
+                                                       names.add("dB");
+                                                       break;
+                                               }
+                                               default:
+                                                       throw new 
DMLRuntimeException(
+                                                               "Unsupported 
autoDiff instruction " + "type: " + ctype.name() + " (" + item.getOpcode() + 
").");
+                                       }
+                               }
+                               break;
+                       }
+                       case Literal: {
+                               CPOperand op = new CPOperand(item.getData());
+                               operands.put(item.getId(), ScalarObjectFactory
+                                       .createLiteralOp(op.getValueType(), 
op.getName()));
+                               break;
+                       }
+                       default:
+                               throw new DMLRuntimeException("Lineage type " + 
item.getType() + " is not supported");
+               }
+       }
+       /**
+        * Runs the given instructions as a single basic program block.
+        * Any exception raised during setup or execution is rethrown wrapped
+        * in a DMLRuntimeException.
+        *
+        * @param newInst instructions to execute
+        * @param lrwec   execution context the block is executed with
+        */
+       private static void executeInst(ArrayList<Instruction> newInst, ExecutionContext lrwec)
+       {
+               try {
+                       //wrap the instructions into a block of a newly created program
+                       Program prog = new Program();
+                       BasicProgramBlock block = new BasicProgramBlock(prog);
+                       block.setInstructions(newInst);
+                       //run the block on the given context
+                       block.execute(lrwec);
+               }
+               catch (Exception ex) {
+                       throw new DMLRuntimeException("Error executing autoDiff instruction" , ex);
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/AutoDiffTest.java 
b/src/test/java/org/apache/sysds/test/functions/builtin/AutoDiffTest.java
new file mode 100644
index 0000000..ab4a373
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/AutoDiffTest.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+/**
+ * Runs the autoDiff.dml test script with lineage tracing enabled and checks
+ * that the dX matrix written by the script matches the ad_dX matrix obtained
+ * from the autoDiff builtin (tolerance 1e-6).
+ */
+public class AutoDiffTest extends AutomatedTestBase
+{
+       private final static String TEST_NAME = "autoDiff";
+       private final static String TEST_DIR = "functions/builtin/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + AutoDiffTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"}));
+       }
+
+       @Test
+       public void testAutoDiffCP1() {
+               runAutoDiffTest(Types.ExecType.CP);
+       }
+
+       /**
+        * Executes the test script for the given execution type and compares
+        * the two result matrices read from the output directory.
+        *
+        * @param instType execution type used to select the execution mode
+        */
+       private void runAutoDiffTest(Types.ExecType instType)
+       {
+               ExecMode platformOld = setExecMode(instType);
+               //remember the global IPA flag; it is static and would otherwise
+               //stay disabled for every test that runs afterwards in this JVM
+               boolean ipaOld = OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS;
+
+               try
+               {
+                       OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = false;
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{"-lineage", "-args", output("dX"), output("ad_dX")};
+                       runTest(true, false, null, -1);
+                       HashMap<MatrixValue.CellIndex, Double> dml_dX = readDMLMatrixFromOutputDir("dX");
+                       HashMap<MatrixValue.CellIndex, Double> autoDiff_dX = readDMLMatrixFromOutputDir("ad_dX");
+                       TestUtils.compareMatrices(dml_dX, autoDiff_dX, 1e-6, "Stat-DML", "Stat-AutoDiff");
+               }
+               finally {
+                       //restore all global state touched by this test
+                       OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = ipaOld;
+                       rtplatform = platformOld;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/builtin/autoDiff.dml 
b/src/test/scripts/functions/builtin/autoDiff.dml
new file mode 100644
index 0000000..26922dc
--- /dev/null
+++ b/src/test/scripts/functions/builtin/autoDiff.dml
@@ -0,0 +1,58 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Test script: runs the forward pass of an affine layer, obtains derivatives
+# both from the autoDiff builtin (via the lineage of the forward output) and
+# from affine::backward, and writes dX to $1 and ad_dX to $2.
+source("nn/layers/affine.dml") as affine
+
+
+# # # initialize the inputs (original note: parsing issues in the rand command within lineage)
+M = 5; N = 5
+X_batch = rand(rows=M, cols=N, sparsity=1)
+                  
+W_1 = rand(rows=M, cols=N, sparsity=1)
+b_1 = matrix(0, rows=1, cols=M)
+
+prob = affine::forward(X_batch, W_1, b_1)
+lin = lineage(prob)
+
+# # TODO: prevent the instruction parser from parsing the lineage string;
+# # for now, appending the string "foo" serves as a work-around
+if(sum(prob) > 0)
+  lin = lin+"foo"
+
+# # The lineage is wrapped in a list because, even after appending "foo",
+# # the compiler keeps parsing the lineage instruction; passing it as a list
+# # item avoids that parsing.
+# # # compute the derivatives via autoDiff from the lineage instructions
+diffs = autoDiff(output=prob, lineage=list(lin));
+
+ad_dX = as.matrix(diffs['dX'])
+ad_dW = as.matrix(diffs['dW'])
+ad_dB = as.matrix(diffs['dB'])
+
+# # # # compute the derivatives from the backward script
+[dX, dW, dB] = affine::backward(prob, X_batch, W_1, b_1)
+
+# # cell-wise mismatch indicators: despite the names, a nonzero entry marks
+# # a cell where the backward result and the autoDiff result differ
+sameX = dX != ad_dX
+sameW = dW != ad_dW
+sameB = dB != ad_dB
+
+# # TRUE only if none of the three derivative pairs has a mismatching cell
+# # NOTE(review): 'output' is neither written nor printed below, so dW and dB
+# # are not actually checked by anything that reads this script's results
+output = ((sum(sameX) == 0) & (sum(sameW) == 0) & (sum(sameB) == 0))
+
+write(dX, $1)
+write(ad_dX, $2)
\ No newline at end of file

Reply via email to