This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 5596fcf  [SYSTEMDS-2650] Non-recursive construction of HOPs from Lineage
5596fcf is described below

commit 5596fcf0d77946cf11ff34085e8879552dd852be
Author: arnabp <[email protected]>
AuthorDate: Sat Sep 5 22:37:22 2020 +0200

    [SYSTEMDS-2650] Non-recursive construction of HOPs from Lineage
    
    This patch implements a non-recursive version of the HOP DAG
    construction from a lineage DAG, which fixes the stack overflow that
    occurred while recomputing from lineage.
---
 .../runtime/lineage/LineageRecomputeUtils.java     | 325 ++++++++++++++++++++-
 .../functions/lineage/LineageTraceDedupTest.java   |   7 +-
 2 files changed, 314 insertions(+), 18 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
index fffc2dc..0df1651 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
@@ -25,8 +25,10 @@ import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Stack;
 import java.util.stream.Collectors;
 
+import org.apache.commons.lang3.mutable.MutableInt;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.OpOp1;
@@ -100,7 +102,7 @@ public class LineageRecomputeUtils {
                root.resetVisitStatusNR();
                Map<Long, Hop> operands = new HashMap<>();
                Map<String, Hop> partDagRoots = new HashMap<>();
-               rConstructHops(root, operands, partDagRoots, prog);
+               constructHopsNR(root, operands, partDagRoots, prog);
                Hop out = HopRewriteUtils.createTransientWrite(
                        varname, operands.get(rootId));
                
@@ -134,17 +136,38 @@ public class LineageRecomputeUtils {
                prog.addProgramBlock(pb);
        }
        
-       
-       private static void rConstructHops(LineageItem item, Map<Long, Hop> 
operands, Map<String, Hop> partDagRoots, Program prog) 
+       private static void constructHopsNR(LineageItem item, Map<Long, Hop> 
operands, Map<String, Hop> partDagRoots, Program prog) 
+       {
+               //NOTE: This method follows the same non-recursive 
+               //skeleton as explainLineageItemNR
+               Stack<LineageItem> stackItem = new Stack<>();
+               Stack<MutableInt> stackPos = new Stack<>();
+               stackItem.push(item); stackPos.push(new MutableInt(0));
+               while (!stackItem.empty()) {
+                       LineageItem tmpItem = stackItem.peek();
+                       MutableInt tmpPos = stackPos.peek();
+                       //check ascent condition - no item processing
+                       if (tmpItem.isVisited()) {
+                               stackItem.pop(); stackPos.pop();
+                       }
+                       //check ascent condition - append item
+                       else if( tmpItem.getInputs() == null 
+                               || tmpItem.getInputs().length <= 
tmpPos.intValue() ) {
+                               constructSingleHop(tmpItem, operands, 
partDagRoots, prog);
+                               stackItem.pop(); stackPos.pop();
+                               tmpItem.setVisited();
+                       }
+                       //check descent condition
+                       else if( tmpItem.getInputs() != null ) {
+                               
stackItem.push(tmpItem.getInputs()[tmpPos.intValue()]);
+                               tmpPos.increment();
+                               stackPos.push(new MutableInt(0));
+                       }
+               }
+       }
+
+       private static void constructSingleHop(LineageItem item, Map<Long, Hop> 
operands, Map<String, Hop> partDagRoots, Program prog) 
        {
-               if (item.isVisited())
-                       return;
-               
-               //recursively process children (ordering by data dependencies)
-               if (!item.isLeaf())
-                       for (LineageItem c : item.getInputs())
-                               rConstructHops(c, operands, partDagRoots, prog);
-               
                //process current lineage item
                //NOTE: we generate instructions from hops (but without 
rewrites) to automatically
                //handle execution types, rmvar instructions, and rewiring of 
inputs/outputs
@@ -406,8 +429,6 @@ public class LineageRecomputeUtils {
                                break;
                        }
                }
-               
-               item.setVisited();
        }
 
        // Construct and compile the function body
@@ -428,7 +449,7 @@ public class LineageRecomputeUtils {
                for (int i=0; i<inputs.length; i++)
                        operands.put((long)i, 
HopRewriteUtils.createTransientRead(inputs[i], inpHops.get(i))); //order 
preserving
                // Construct the Hop dag.
-               rConstructHops(patchRoot, operands, null, null);
+               constructHopsNR(patchRoot, operands, null, null);
                // TWrite the func return (pass dag root to copy datatype)
                Hop out = HopRewriteUtils.createTransientWrite(outname, 
operands.get(patchRoot.getId()));
                // Save the Hop dag
@@ -518,6 +539,282 @@ public class LineageRecomputeUtils {
                throw new DMLRuntimeException("Unsupported opcode: 
"+item.getOpcode());
        }
        
+       @Deprecated
+       @SuppressWarnings("unused")
+       private static void rConstructHops(LineageItem item, Map<Long, Hop> 
operands, Map<String, Hop> partDagRoots, Program prog) 
+       {
+               if (item.isVisited())
+                       return;
+               
+               //recursively process children (ordering by data dependencies)
+               if (!item.isLeaf())
+                       for (LineageItem c : item.getInputs())
+                               rConstructHops(c, operands, partDagRoots, prog);
+               
+               //process current lineage item
+               //NOTE: we generate instructions from hops (but without 
rewrites) to automatically
+               //handle execution types, rmvar instructions, and rewiring of 
inputs/outputs
+               switch (item.getType()) {
+                       case Creation: {
+                               if (item.getData().startsWith(LPLACEHOLDER)) {
+                                       long phId = 
Long.parseLong(item.getData().substring(3));
+                                       Hop input = operands.get(phId);
+                                       operands.remove(phId);
+                                       // Replace the placeholders with TReads
+                                       operands.put(item.getId(), input); // 
order preserving
+                                       break;
+                               }
+                               Instruction inst = 
InstructionParser.parseSingleInstruction(item.getData());
+                               
+                               if (inst instanceof DataGenCPInstruction) {
+                                       DataGenCPInstruction rand = 
(DataGenCPInstruction) inst;
+                                       HashMap<String, Hop> params = new 
HashMap<>();
+                                       if( rand.getOpcode().equals("rand") ) {
+                                               if( rand.output.getDataType() 
== DataType.TENSOR)
+                                                       
params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
+                                               else {
+                                                       
params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
+                                                       
params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
+                                               }
+                                               
params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
+                                               
params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
+                                               
params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
+                                               
params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
+                                               
params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
+                                               
params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
+                                       }
+                                       else if( rand.getOpcode().equals("seq") 
) {
+                                               params.put(Statement.SEQ_FROM, 
new LiteralOp(rand.getFrom()));
+                                               params.put(Statement.SEQ_TO, 
new LiteralOp(rand.getTo()));
+                                               params.put(Statement.SEQ_INCR, 
new LiteralOp(rand.getIncr()));
+                                       }
+                                       Hop datagen = new 
DataGenOp(OpOpDG.valueOf(rand.getOpcode().toUpperCase()),
+                                               new DataIdentifier("tmp"), 
params);
+                                       
datagen.setBlocksize(rand.getBlocksize());
+                                       operands.put(item.getId(), datagen);
+                               } else if (inst instanceof VariableCPInstruction
+                                               && ((VariableCPInstruction) 
inst).isCreateVariable()) {
+                                       String parts[] = 
InstructionUtils.getInstructionPartsWithValueType(inst.toString());
+                                       DataType dt = 
DataType.valueOf(parts[4]);
+                                       ValueType vt = dt == DataType.MATRIX ? 
ValueType.FP64 : ValueType.STRING;
+                                       HashMap<String, Hop> params = new 
HashMap<>();
+                                       params.put(DataExpression.IO_FILENAME, 
new LiteralOp(parts[2]));
+                                       params.put(DataExpression.READROWPARAM, 
new LiteralOp(Long.parseLong(parts[6])));
+                                       params.put(DataExpression.READCOLPARAM, 
new LiteralOp(Long.parseLong(parts[7])));
+                                       params.put(DataExpression.READNNZPARAM, 
new LiteralOp(Long.parseLong(parts[8])));
+                                       params.put(DataExpression.FORMAT_TYPE, 
new LiteralOp(parts[5]));
+                                       DataOp pread = new 
DataOp(parts[1].substring(5), dt, vt, OpOpData.PERSISTENTREAD, params);
+                                       pread.setFileName(parts[2]);
+                                       operands.put(item.getId(), pread);
+                               }
+                               else if  (inst instanceof RandSPInstruction) {
+                                       RandSPInstruction rand = 
(RandSPInstruction) inst;
+                                       HashMap<String, Hop> params = new 
HashMap<>();
+                                       if (rand.output.getDataType() == 
DataType.TENSOR)
+                                               
params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
+                                       else {
+                                               
params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
+                                               
params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
+                                       }
+                                       params.put(DataExpression.RAND_MIN, new 
LiteralOp(rand.getMinValue()));
+                                       params.put(DataExpression.RAND_MAX, new 
LiteralOp(rand.getMaxValue()));
+                                       params.put(DataExpression.RAND_PDF, new 
LiteralOp(rand.getPdf()));
+                                       params.put(DataExpression.RAND_LAMBDA, 
new LiteralOp(rand.getPdfParams()));
+                                       
params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
+                                       params.put(DataExpression.RAND_SEED, 
new LiteralOp(rand.getSeed()));
+                                       Hop datagen = new 
DataGenOp(OpOpDG.RAND, new DataIdentifier("tmp"), params);
+                                       
datagen.setBlocksize(rand.getBlocksize());
+                                       operands.put(item.getId(), datagen);
+                               }
+                               break;
+                       }
+                       case Dedup: {
+                               // Create function call for each dedup entry 
+                               String[] parts = 
item.getOpcode().split(LineageDedupUtils.DEDUP_DELIM); //e.g. dedup_R_SB13_0
+                               String name = parts[2] + parts[1] + parts[3];  
//loopId + outVar + pathId
+                               List<Hop> finputs = 
Arrays.stream(item.getInputs())
+                                               .map(inp -> 
operands.get(inp.getId())).collect(Collectors.toList());
+                               String[] inputNames = new 
String[item.getInputs().length];
+                               for (int i=0; i<item.getInputs().length; i++)
+                                       inputNames[i] = LPLACEHOLDER + i;  
//e.g. IN#0, IN#1
+                               Hop funcOp = new FunctionOp(FunctionType.DML, 
DMLProgram.DEFAULT_NAMESPACE, 
+                                               name, inputNames, finputs, new 
String[] {parts[1]}, false);
+
+                               // Cut the Hop dag after function calls 
+                               partDagRoots.put(parts[1], funcOp);
+                               // Compile the dag and save
+                               constructBasicBlock(partDagRoots, parts[1], 
prog);
+
+                               // Construct a Hop dag for the function body 
from the dedup patch, and compile
+                               Hop output = constructHopsDedupPatch(parts, 
inputNames, finputs, prog);
+                               // Create a TRead on the function o/p as a leaf 
for the next Hop dag
+                               // Use the function body root/return hop to 
propagate right data type
+                               operands.put(item.getId(), 
HopRewriteUtils.createTransientRead(parts[1], output));
+                               break;
+                       }
+                       case Instruction: {
+                               CPType ctype = 
InstructionUtils.getCPTypeByOpcode(item.getOpcode());
+                               SPType stype = 
InstructionUtils.getSPTypeByOpcode(item.getOpcode());
+                               
+                               if (ctype != null) {
+                                       switch (ctype) {
+                                               case AggregateUnary: {
+                                                       Hop input = 
operands.get(item.getInputs()[0].getId());
+                                                       Hop aggunary = 
InstructionUtils.isUnaryMetadata(item.getOpcode()) ?
+                                                               
HopRewriteUtils.createUnary(input, OpOp1.valueOfByOpcode(item.getOpcode())) :
+                                                               
HopRewriteUtils.createAggUnaryOp(input, item.getOpcode());
+                                                       
operands.put(item.getId(), aggunary);
+                                                       break;
+                                               }
+                                               case AggregateBinary: {
+                                                       Hop input1 = 
operands.get(item.getInputs()[0].getId());
+                                                       Hop input2 = 
operands.get(item.getInputs()[1].getId());
+                                                       Hop aggbinary = 
HopRewriteUtils.createMatrixMultiply(input1, input2);
+                                                       
operands.put(item.getId(), aggbinary);
+                                                       break;
+                                               }
+                                               case AggregateTernary: {
+                                                       Hop input1 = 
operands.get(item.getInputs()[0].getId());
+                                                       Hop input2 = 
operands.get(item.getInputs()[1].getId());
+                                                       Hop input3 = 
operands.get(item.getInputs()[2].getId());
+                                                       Hop aggternary = 
HopRewriteUtils.createSum(
+                                                               
HopRewriteUtils.createBinary(
+                                                               
HopRewriteUtils.createBinary(input1, input2, OpOp2.MULT),
+                                                               input3, 
OpOp2.MULT));
+                                                       
operands.put(item.getId(), aggternary);
+                                                       break;
+                                               }
+                                               case Unary:
+                                               case Builtin: {
+                                                       Hop input = 
operands.get(item.getInputs()[0].getId());
+                                                       Hop unary = 
HopRewriteUtils.createUnary(input, item.getOpcode());
+                                                       
operands.put(item.getId(), unary);
+                                                       break;
+                                               }
+                                               case Reorg: {
+                                                       
operands.put(item.getId(), HopRewriteUtils.createReorg(
+                                                               
operands.get(item.getInputs()[0].getId()), item.getOpcode()));
+                                                       break;
+                                               }
+                                               case Reshape: {
+                                                       ArrayList<Hop> inputs = 
new ArrayList<>();
+                                                       for(int i=0; i<5; i++)
+                                                               
inputs.add(operands.get(item.getInputs()[i].getId()));
+                                                       
operands.put(item.getId(), HopRewriteUtils.createReorg(inputs, 
ReOrgOp.RESHAPE));
+                                                       break;
+                                               }
+                                               case Binary: {
+                                                       //handle special cases 
of binary operations 
+                                                       String opcode = 
("^2".equals(item.getOpcode()) 
+                                                               || 
"*2".equals(item.getOpcode())) ? 
+                                                               
item.getOpcode().substring(0, 1) : item.getOpcode();
+                                                       Hop input1 = 
operands.get(item.getInputs()[0].getId());
+                                                       Hop input2 = 
operands.get(item.getInputs()[1].getId());
+                                                       Hop binary = 
HopRewriteUtils.createBinary(input1, input2, opcode);
+                                                       
operands.put(item.getId(), binary);
+                                                       break;
+                                               }
+                                               case Ternary: {
+                                                       
operands.put(item.getId(), HopRewriteUtils.createTernary(
+                                                               
operands.get(item.getInputs()[0].getId()), 
+                                                               
operands.get(item.getInputs()[1].getId()), 
+                                                               
operands.get(item.getInputs()[2].getId()), item.getOpcode()));
+                                                       break;
+                                               }
+                                               case Ctable: { //e.g., ctable 
+                                                       if( 
item.getInputs().length==3 )
+                                                               
operands.put(item.getId(), HopRewriteUtils.createTernary(
+                                                                       
operands.get(item.getInputs()[0].getId()),
+                                                                       
operands.get(item.getInputs()[1].getId()),
+                                                                       
operands.get(item.getInputs()[2].getId()), OpOp3.CTABLE));
+                                                       else if( 
item.getInputs().length==5 )
+                                                               
operands.put(item.getId(), HopRewriteUtils.createTernary(
+                                                                       
operands.get(item.getInputs()[0].getId()),
+                                                                       
operands.get(item.getInputs()[1].getId()),
+                                                                       
operands.get(item.getInputs()[2].getId()),
+                                                                       
operands.get(item.getInputs()[3].getId()),
+                                                                       
operands.get(item.getInputs()[4].getId()), OpOp3.CTABLE));
+                                                       break;
+                                               }
+                                               case BuiltinNary: {
+                                                       String opcode = 
item.getOpcode().equals("n+") ? "plus" : item.getOpcode();
+                                                       
operands.put(item.getId(), HopRewriteUtils.createNary(
+                                                               
OpOpN.valueOf(opcode.toUpperCase()), createNaryInputs(item, operands)));
+                                                       break;
+                                               }
+                                               case ParameterizedBuiltin: {
+                                                       
operands.put(item.getId(), constructParameterizedBuiltinOp(item, operands));
+                                                       break;
+                                               }
+                                               case MatrixIndexing: {
+                                                       
operands.put(item.getId(), constructIndexingOp(item, operands));
+                                                       break;
+                                               }
+                                               case MMTSJ: {
+                                                       //TODO handling of tsmm 
type left and right -> placement transpose
+                                                       Hop input = 
operands.get(item.getInputs()[0].getId());
+                                                       Hop aggunary = 
HopRewriteUtils.createMatrixMultiply(
+                                                               
HopRewriteUtils.createTranspose(input), input);
+                                                       
operands.put(item.getId(), aggunary);
+                                                       break;
+                                               }
+                                               case Variable: {
+                                                       if( 
item.getOpcode().startsWith("cast") )
+                                                               
operands.put(item.getId(), HopRewriteUtils.createUnary(
+                                                                       
operands.get(item.getInputs()[0].getId()),
+                                                                       
OpOp1.valueOfByOpcode(item.getOpcode())));
+                                                       else //cpvar, write
+                                                               
operands.put(item.getId(), operands.get(item.getInputs()[0].getId()));
+                                                       break;
+                                               }
+                                               default:
+                                                       throw new 
DMLRuntimeException("Unsupported instruction "
+                                                               + "type: " + 
ctype.name() + " (" + item.getOpcode() + ").");
+                                       }
+                               }
+                               else if( stype != null ) {
+                                       switch(stype) {
+                                               case Reblock: {
+                                                       Hop input = 
operands.get(item.getInputs()[0].getId());
+                                                       
input.setBlocksize(ConfigurationManager.getBlocksize());
+                                                       
input.setRequiresReblock(true);
+                                                       
operands.put(item.getId(), input);
+                                                       break;
+                                               }
+                                               case Checkpoint: {
+                                                       Hop input = 
operands.get(item.getInputs()[0].getId());
+                                                       
operands.put(item.getId(), input);
+                                                       break;
+                                               }
+                                               case MatrixIndexing: {
+                                                       
operands.put(item.getId(), constructIndexingOp(item, operands));
+                                                       break;
+                                               }
+                                               case GAppend: {
+                                                       
operands.put(item.getId(), HopRewriteUtils.createBinary(
+                                                               
operands.get(item.getInputs()[0].getId()),
+                                                               
operands.get(item.getInputs()[1].getId()), OpOp2.CBIND));
+                                                       break;
+                                               }
+                                               default:
+                                                       throw new 
DMLRuntimeException("Unsupported instruction "
+                                                               + "type: " + 
stype.name() + " (" + item.getOpcode() + ").");
+                                       }
+                               }
+                               else
+                                       throw new 
DMLRuntimeException("Unsupported instruction: " + item.getOpcode());
+                               break;
+                       }
+                       case Literal: {
+                               CPOperand op = new CPOperand(item.getData());
+                               operands.put(item.getId(), ScalarObjectFactory
+                                       .createLiteralOp(op.getValueType(), 
op.getName()));
+                               break;
+                       }
+               }
+               
+               item.setVisited();
+       }
        
        // Below class represents a single loop and contains related data
        // that are needed for recomputation.
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
index 18da399..3b1ae65 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
@@ -95,12 +95,11 @@ public class LineageTraceDedupTest extends AutomatedTestBase
                testLineageTrace(TEST_NAME5);
        }
        
-       /*@Test
+       @Test
        public void testLineageTrace6() {
                testLineageTrace(TEST_NAME6);
-       }*/
-       //FIXME: stack overflow only when ran the full package
-       
+       }
+
        @Test
        public void testLineageTrace7() {
                testLineageTrace(TEST_NAME7);

Reply via email to