[MINOR] Performance function invocation of dml-bodied UDFs. This patch slightly improved the function invocation performance of dml-bodied UDFs from 452K/s to 521K/s.
Furthermore, this also includes a fix of the test for LinregCG over compressed data. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/14c410ce Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/14c410ce Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/14c410ce Branch: refs/heads/master Commit: 14c410ce06f3a5c56d1bcb1ac509fab4a0711f5f Parents: d7d312c Author: Matthias Boehm <[email protected]> Authored: Fri Nov 3 14:23:50 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Nov 3 18:59:29 2017 -0700 ---------------------------------------------------------------------- .../controlprogram/LocalVariableMap.java | 4 + .../controlprogram/ParForProgramBlock.java | 4 +- .../context/ExecutionContext.java | 41 +++--- .../cp/FunctionCallCPInstruction.java | 135 +++++++++---------- .../functions/compress/CompressedLinregCG.java | 19 ++- 5 files changed, 95 insertions(+), 108 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java index 7ebe1a0..e894495 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java @@ -63,6 +63,10 @@ public class LocalVariableMap implements Cloneable return localMap.keySet(); } + public Set<Entry<String, Data>> entrySet() { + return localMap.entrySet(); + } + /** * Retrieves the data object given its name. 
* http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java index e2568cb..760ddff 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java @@ -633,7 +633,7 @@ public class ParForProgramBlock extends ForProgramBlock //preserve shared input/result variables of cleanup ArrayList<String> varList = ec.getVarList(); - HashMap<String, Boolean> varState = ec.pinVariables(varList); + boolean[] varState = ec.pinVariables(varList); try { @@ -1329,7 +1329,7 @@ public class ParForProgramBlock extends ForProgramBlock } } - private void cleanupSharedVariables( ExecutionContext ec, HashMap<String,Boolean> varState ) + private void cleanupSharedVariables( ExecutionContext ec, boolean[] varState ) throws DMLRuntimeException { //TODO needs as precondition a systematic treatment of persistent read information. 
http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java index 79a658d..ecb9629 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java @@ -20,7 +20,6 @@ package org.apache.sysml.runtime.controlprogram.context; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import org.apache.commons.logging.Log; @@ -151,6 +150,11 @@ public class ExecutionContext { return _variables.get(name); } + public Data getVariable(CPOperand operand) throws DMLRuntimeException { + return operand.getDataType().isScalar() ? + getScalarInput(operand) : getVariable(operand.getName()); + } + public void setVariable(String name, Data val) { _variables.put(name, val); } @@ -528,30 +532,25 @@ public class ExecutionContext { * The function returns the OLD "clean up" state of matrix objects. 
* * @param varList variable list - * @return map of old cleanup state of matrix objects + * @return indicator vector of old cleanup state of matrix objects */ - public HashMap<String,Boolean> pinVariables(ArrayList<String> varList) + public boolean[] pinVariables(ArrayList<String> varList) { //2-pass approach since multiple vars might refer to same matrix object - HashMap<String, Boolean> varsState = new HashMap<>(); + boolean[] varsState = new boolean[varList.size()]; //step 1) get current information - for( String var : varList ) - { - Data dat = _variables.get(var); - if( dat instanceof MatrixObject ) { - MatrixObject mo = (MatrixObject)dat; - varsState.put( var, mo.isCleanupEnabled() ); - } + for( int i=0; i<varList.size(); i++ ) { + Data dat = _variables.get(varList.get(i)); + if( dat instanceof MatrixObject ) + varsState[i] = ((MatrixObject)dat).isCleanupEnabled(); } //step 2) pin variables - for( String var : varList ) { - Data dat = _variables.get(var); - if( dat instanceof MatrixObject ) { - MatrixObject mo = (MatrixObject)dat; - mo.enableCleanup(false); - } + for( int i=0; i<varList.size(); i++ ) { + Data dat = _variables.get(varList.get(i)); + if( dat instanceof MatrixObject ) + ((MatrixObject)dat).enableCleanup(false); } return varsState; @@ -573,11 +572,11 @@ public class ExecutionContext { * @param varList variable list * @param varsState variable state */ - public void unpinVariables(ArrayList<String> varList, HashMap<String,Boolean> varsState) { - for( String var : varList) { - Data dat = _variables.get(var); + public void unpinVariables(ArrayList<String> varList, boolean[] varsState) { + for( int i=0; i<varList.size(); i++ ) { + Data dat = _variables.get(varList.get(i)); if( dat instanceof MatrixObject ) - ((MatrixObject)dat).enableCleanup(varsState.get(var)); + ((MatrixObject)dat).enableCleanup(varsState[i]); } } 
http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java index b901dfc..e785196 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java @@ -20,10 +20,8 @@ package org.apache.sysml.runtime.instructions.cp; import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedList; +import java.util.Map.Entry; import org.apache.sysml.api.DMLScript; import org.apache.sysml.lops.Lop; @@ -43,32 +41,30 @@ import org.apache.sysml.runtime.instructions.InstructionUtils; public class FunctionCallCPInstruction extends CPInstruction { private String _functionName; private String _namespace; + private final CPOperand[] _boundInputs; + private final ArrayList<String> _boundInputNames; + private final ArrayList<String> _boundOutputNames; + private HashSet<String> _expectRetVars = null; - public String getFunctionName() { - return _functionName; - } - - public String getNamespace() { - return _namespace; - } - - // stores both the bound input and output parameters - private ArrayList<CPOperand> _boundInputParamOperands; - private ArrayList<String> _boundInputParamNames; - private ArrayList<String> _boundOutputParamNames; - - private FunctionCallCPInstruction(String namespace, String functName, ArrayList<CPOperand> boundInParamOperands, - ArrayList<String> boundInParamNames, ArrayList<String> boundOutParamNames, String istr) { + private FunctionCallCPInstruction(String namespace, String functName, CPOperand[] boundInputs, + ArrayList<String> 
boundInputNames, ArrayList<String> boundOutputNames, String istr) { super(null, functName, istr); _cptype = CPINSTRUCTION_TYPE.External; _functionName = functName; _namespace = namespace; - _boundInputParamOperands = boundInParamOperands; - _boundInputParamNames = boundInParamNames; - _boundOutputParamNames = boundOutParamNames; + _boundInputs = boundInputs; + _boundInputNames = boundInputNames; + _boundOutputNames = boundOutputNames; + } + public String getFunctionName() { + return _functionName; } + public String getNamespace() { + return _namespace; + } + public static FunctionCallCPInstruction parseInstruction(String str) throws DMLRuntimeException { @@ -78,20 +74,17 @@ public class FunctionCallCPInstruction extends CPInstruction { String functionName = parts[2]; int numInputs = Integer.valueOf(parts[3]); int numOutputs = Integer.valueOf(parts[4]); - ArrayList<CPOperand> boundInParamOperands = new ArrayList<>(); - ArrayList<String> boundInParamNames = new ArrayList<>(); - ArrayList<String> boundOutParamNames = new ArrayList<>(); + CPOperand[] boundInputs = new CPOperand[numInputs]; + ArrayList<String> boundInputNames = new ArrayList<>(); + ArrayList<String> boundOutputNames = new ArrayList<>(); for (int i = 0; i < numInputs; i++) { - CPOperand operand = new CPOperand(parts[5 + i]); - boundInParamOperands.add(operand); - boundInParamNames.add(operand.getName()); + boundInputs[i] = new CPOperand(parts[5 + i]); + boundInputNames.add(boundInputs[i].getName()); } - for (int i = 0; i < numOutputs; i++) { - boundOutParamNames.add(parts[5 + numInputs + i]); - } - - return new FunctionCallCPInstruction ( namespace,functionName, - boundInParamOperands, boundInParamNames, boundOutParamNames, str ); + for (int i = 0; i < numOutputs; i++) + boundOutputNames.add(parts[5 + numInputs + i]); + return new FunctionCallCPInstruction ( namespace, + functionName, boundInputs, boundInputNames, boundOutputNames, str ); } @Override @@ -120,10 +113,10 @@ public class 
FunctionCallCPInstruction extends CPInstruction { // get the function program block (stored in the Program object) FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(_namespace, _functionName); - // sanity check number of function paramters - if( _boundInputParamNames.size() < fpb.getInputParams().size() ) { + // sanity check number of function parameters + if( _boundInputs.length < fpb.getInputParams().size() ) { throw new DMLRuntimeException("Number of bound input parameters does not match the function signature " - + "("+_boundInputParamNames.size()+", but "+fpb.getInputParams().size()+" expected)"); + + "("+_boundInputs.length+", but "+fpb.getInputParams().size()+" expected)"); } // create bindings to formal parameters for given function call @@ -131,35 +124,31 @@ public class FunctionCallCPInstruction extends CPInstruction { LocalVariableMap functionVariables = new LocalVariableMap(); for( int i=0; i<fpb.getInputParams().size(); i++) { - DataIdentifier currFormalParam = fpb.getInputParams().get(i); - String currFormalParamName = currFormalParam.getName(); - Data currFormalParamValue = null; - - CPOperand operand = _boundInputParamOperands.get(i); - String varname = operand.getName(); //error handling non-existing variables - if( !operand.isLiteral() && !ec.containsVariable(varname) ) { - throw new DMLRuntimeException("Input variable '"+varname+"' not existing on call of " + + CPOperand input = _boundInputs[i]; + if( !input.isLiteral() && !ec.containsVariable(input.getName()) ) { + throw new DMLRuntimeException("Input variable '"+input.getName()+"' not existing on call of " + DMLProgram.constructFunctionKey(_namespace, _functionName) + " (line "+getLineNum()+")."); } //get input matrix/frame/scalar - currFormalParamValue = (operand.getDataType()!=DataType.SCALAR) ? 
ec.getVariable(varname) : - ec.getScalarInput(varname, operand.getValueType(), operand.isLiteral()); + DataIdentifier currFormalParam = fpb.getInputParams().get(i); + Data value = ec.getVariable(input); //graceful value type conversion for scalar inputs with wrong type - if( currFormalParamValue.getDataType() == DataType.SCALAR - && currFormalParamValue.getValueType() != currFormalParam.getValueType() ) + if( value.getDataType() == DataType.SCALAR + && value.getValueType() != currFormalParam.getValueType() ) { - currFormalParamValue = ScalarObjectFactory.createScalarObject( - currFormalParam.getValueType(), (ScalarObject) currFormalParamValue); + value = ScalarObjectFactory.createScalarObject( + currFormalParam.getValueType(), (ScalarObject)value); } - functionVariables.put(currFormalParamName, currFormalParamValue); + //set input parameter + functionVariables.put(currFormalParam.getName(), value); } // Pin the input variables so that they do not get deleted // from pb's symbol table at the end of execution of function - HashMap<String,Boolean> pinStatus = ec.pinVariables(_boundInputParamNames); + boolean[] pinStatus = ec.pinVariables(_boundInputNames); // Create a symbol table under a new execution context for the function invocation, // and copy the function arguments into the created table. 
@@ -182,29 +171,29 @@ public class FunctionCallCPInstruction extends CPInstruction { String fname = DMLProgram.constructFunctionKey(_namespace, _functionName); throw new DMLRuntimeException("error executing function " + fname, e); } - LocalVariableMap retVars = fn_ec.getVariables(); // cleanup all returned variables w/o binding - Collection<String> retVarnames = new LinkedList<>(retVars.keySet()); - HashSet<String> probeVars = new HashSet<>(); - for(DataIdentifier di : fpb.getOutputParams()) - probeVars.add(di.getName()); - for( String var : retVarnames ) { - if( !probeVars.contains(var) ) //cleanup candidate - { - Data dat = fn_ec.removeVariable(var); - if( dat != null && dat instanceof MatrixObject ) - fn_ec.cleanupMatrixObject((MatrixObject)dat); - } + if( _expectRetVars == null ) { + _expectRetVars = new HashSet<>(); + for(DataIdentifier di : fpb.getOutputParams()) + _expectRetVars.add(di.getName()); + } + + LocalVariableMap retVars = fn_ec.getVariables(); + for( Entry<String,Data> var : retVars.entrySet() ) { + if( _expectRetVars.contains(var.getKey()) ) + continue; + //cleanup unexpected return values to avoid leaks + if( var.getValue() instanceof MatrixObject ) + fn_ec.cleanupMatrixObject((MatrixObject)var.getValue()); } // Unpin the pinned variables - ec.unpinVariables(_boundInputParamNames, pinStatus); + ec.unpinVariables(_boundInputNames, pinStatus); // add the updated binding for each return variable to the variables in original symbol table for (int i=0; i< fpb.getOutputParams().size(); i++){ - - String boundVarName = _boundOutputParamNames.get(i); + String boundVarName = _boundOutputNames.get(i); Data boundValue = retVars.get(fpb.getOutputParams().get(i).getName()); if (boundValue == null) throw new DMLRuntimeException(boundVarName + " was not assigned a return value"); @@ -240,14 +229,12 @@ public class FunctionCallCPInstruction extends CPInstruction { LOG.debug("ExternalBuiltInFunction: " + this.toString()); } - public ArrayList<String> 
getBoundInputParamNames() - { - return _boundInputParamNames; + public ArrayList<String> getBoundInputParamNames() { + return _boundInputNames; } - public ArrayList<String> getBoundOutputParamNames() - { - return _boundOutputParamNames; + public ArrayList<String> getBoundOutputParamNames() { + return _boundOutputNames; } public void setFunctionName(String fname) @@ -277,6 +264,4 @@ public class FunctionCallCPInstruction extends CPInstruction { return sb.substring( 0, sb.length()-Lop.OPERAND_DELIMITOR.length() ); } - - } http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java b/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java index 6e2ddef..a7e0971 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java @@ -100,24 +100,23 @@ public class CompressedLinregCG extends AutomatedTestBase try { - String TEST_NAME = testname; - TestConfiguration config = getTestConfiguration(TEST_NAME); + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); /* This is for running the junit test the new way, i.e., construct the arguments directly */ - String HOME = SCRIPT_DIR + "functions/codegen/"; + String HOME1 = SCRIPT_DIR + "functions/compress/"; + String HOME2 = SCRIPT_DIR + "functions/codegen/"; fullDMLScriptName = "scripts/algorithms/LinearRegCG.dml"; programArgs = new String[]{ "-explain", "-stats", "-nvargs", "X="+input("X"), "Y="+input("y"), "icpt="+String.valueOf(intercept), "tol="+String.valueOf(epsilon), "maxi="+String.valueOf(maxiter), "reg="+String.valueOf(regular), "B="+output("w")}; - 
fullRScriptName = HOME + "Algorithm_LinregCG.R"; + fullRScriptName = HOME2 + "Algorithm_LinregCG.R"; rCmd = "Rscript" + " " + fullRScriptName + " " + - HOME + INPUT_DIR + " " + - String.valueOf(intercept) + " " + String.valueOf(epsilon) + " " + - String.valueOf(maxiter) + " " + String.valueOf(regular) + HOME + EXPECTED_DIR; + HOME1 + INPUT_DIR + " " + + String.valueOf(intercept) + " " + String.valueOf(epsilon) + " " + + String.valueOf(maxiter) + " " + String.valueOf(regular) + " "+ HOME1 + EXPECTED_DIR; - loadTestConfiguration(config); - //generate actual datasets double[][] X = getRandomMatrix(rows, cols, 1, 1, sparse?sparsity2:sparsity1, 7); writeInputMatrixWithMTD("X", X, true); @@ -141,7 +140,7 @@ public class CompressedLinregCG extends AutomatedTestBase finally { rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; - InfrastructureAnalyzer.setLocalMaxMemory(memOld); + InfrastructureAnalyzer.setLocalMaxMemory(memOld); CompressedMatrixBlock.ALLOW_DDC_ENCODING = true; } }
