[MINOR] Improve performance of function invocation for dml-bodied UDFs

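This patch slightly improves the function-invocation throughput of
dml-bodied UDFs, from 452K to 521K calls/s. The gains come from:
- tracking the pin state of input variables in a boolean array instead
of a per-call HashMap;
- caching the set of expected return-variable names;
- iterating the returned variables directly instead of first copying
their names.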

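In addition, it fixes the LinregCG test over compressed data. The R
script now reads its inputs from and writes its expected outputs to the
compress test directory, and a missing space before the expected-output
path argument has been added.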


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/14c410ce
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/14c410ce
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/14c410ce

Branch: refs/heads/master
Commit: 14c410ce06f3a5c56d1bcb1ac509fab4a0711f5f
Parents: d7d312c
Author: Matthias Boehm <[email protected]>
Authored: Fri Nov 3 14:23:50 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Fri Nov 3 18:59:29 2017 -0700

----------------------------------------------------------------------
 .../controlprogram/LocalVariableMap.java        |   4 +
 .../controlprogram/ParForProgramBlock.java      |   4 +-
 .../context/ExecutionContext.java               |  41 +++---
 .../cp/FunctionCallCPInstruction.java           | 135 +++++++++----------
 .../functions/compress/CompressedLinregCG.java  |  19 ++-
 5 files changed, 95 insertions(+), 108 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
index 7ebe1a0..e894495 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
@@ -63,6 +63,10 @@ public class LocalVariableMap implements Cloneable
                return localMap.keySet();
        }
        
+       public Set<Entry<String, Data>> entrySet() {
+               return localMap.entrySet();
+       }
+       
        /**
         * Retrieves the data object given its name.
         * 

http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
index e2568cb..760ddff 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
@@ -633,7 +633,7 @@ public class ParForProgramBlock extends ForProgramBlock
                
                //preserve shared input/result variables of cleanup
                ArrayList<String> varList = ec.getVarList();
-               HashMap<String, Boolean> varState = ec.pinVariables(varList);
+               boolean[] varState = ec.pinVariables(varList);
                
                try 
                {
@@ -1329,7 +1329,7 @@ public class ParForProgramBlock extends ForProgramBlock
                }
        }
 
-       private void cleanupSharedVariables( ExecutionContext ec, 
HashMap<String,Boolean> varState ) 
+       private void cleanupSharedVariables( ExecutionContext ec, boolean[] 
varState ) 
                throws DMLRuntimeException 
        {
                //TODO needs as precondition a systematic treatment of 
persistent read information.

http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
index 79a658d..ecb9629 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
@@ -20,7 +20,6 @@
 package org.apache.sysml.runtime.controlprogram.context;
 
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.List;
 
 import org.apache.commons.logging.Log;
@@ -151,6 +150,11 @@ public class ExecutionContext {
                return _variables.get(name);
        }
        
+       public Data getVariable(CPOperand operand) throws DMLRuntimeException {
+               return operand.getDataType().isScalar() ?
+                       getScalarInput(operand) : 
getVariable(operand.getName());
+       }
+       
        public void setVariable(String name, Data val) {
                _variables.put(name, val);
        }
@@ -528,30 +532,25 @@ public class ExecutionContext {
         * The function returns the OLD "clean up" state of matrix objects.
         * 
         * @param varList variable list
-        * @return map of old cleanup state of matrix objects
+        * @return indicator vector of old cleanup state of matrix objects
         */
-       public HashMap<String,Boolean> pinVariables(ArrayList<String> varList) 
+       public boolean[] pinVariables(ArrayList<String> varList) 
        {
                //2-pass approach since multiple vars might refer to same 
matrix object
-               HashMap<String, Boolean> varsState = new HashMap<>();
+               boolean[] varsState = new boolean[varList.size()];
                
                //step 1) get current information
-               for( String var : varList )
-               {
-                       Data dat = _variables.get(var);
-                       if( dat instanceof MatrixObject ) {
-                               MatrixObject mo = (MatrixObject)dat;
-                               varsState.put( var, mo.isCleanupEnabled() );
-                       }
+               for( int i=0; i<varList.size(); i++ ) {
+                       Data dat = _variables.get(varList.get(i));
+                       if( dat instanceof MatrixObject )
+                               varsState[i] = 
((MatrixObject)dat).isCleanupEnabled();
                }
                
                //step 2) pin variables
-               for( String var : varList ) {
-                       Data dat = _variables.get(var);
-                       if( dat instanceof MatrixObject ) {
-                               MatrixObject mo = (MatrixObject)dat;
-                               mo.enableCleanup(false); 
-                       }
+               for( int i=0; i<varList.size(); i++ ) {
+                       Data dat = _variables.get(varList.get(i));
+                       if( dat instanceof MatrixObject )
+                               ((MatrixObject)dat).enableCleanup(false); 
                }
                
                return varsState;
@@ -573,11 +572,11 @@ public class ExecutionContext {
         * @param varList variable list
         * @param varsState variable state
         */
-       public void unpinVariables(ArrayList<String> varList, 
HashMap<String,Boolean> varsState) {
-               for( String var : varList) {
-                       Data dat = _variables.get(var);
+       public void unpinVariables(ArrayList<String> varList, boolean[] 
varsState) {
+               for( int i=0; i<varList.size(); i++ ) {
+                       Data dat = _variables.get(varList.get(i));
                        if( dat instanceof MatrixObject )
-                               
((MatrixObject)dat).enableCleanup(varsState.get(var));
+                               ((MatrixObject)dat).enableCleanup(varsState[i]);
                }
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
index b901dfc..e785196 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -20,10 +20,8 @@
 package org.apache.sysml.runtime.instructions.cp;
 
 import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashMap;
 import java.util.HashSet;
-import java.util.LinkedList;
+import java.util.Map.Entry;
 
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.lops.Lop;
@@ -43,32 +41,30 @@ import 
org.apache.sysml.runtime.instructions.InstructionUtils;
 public class FunctionCallCPInstruction extends CPInstruction {
        private String _functionName;
        private String _namespace;
+       private final CPOperand[] _boundInputs;
+       private final ArrayList<String> _boundInputNames;
+       private final ArrayList<String> _boundOutputNames;
+       private HashSet<String> _expectRetVars = null;
 
-       public String getFunctionName() {
-               return _functionName;
-       }
-
-       public String getNamespace() {
-               return _namespace;
-       }
-
-       // stores both the bound input and output parameters
-       private ArrayList<CPOperand> _boundInputParamOperands;
-       private ArrayList<String> _boundInputParamNames;
-       private ArrayList<String> _boundOutputParamNames;
-
-       private FunctionCallCPInstruction(String namespace, String functName, 
ArrayList<CPOperand> boundInParamOperands,
-                       ArrayList<String> boundInParamNames, ArrayList<String> 
boundOutParamNames, String istr) {
+       private FunctionCallCPInstruction(String namespace, String functName, 
CPOperand[] boundInputs, 
+               ArrayList<String> boundInputNames, ArrayList<String> 
boundOutputNames, String istr) {
                super(null, functName, istr);
                _cptype = CPINSTRUCTION_TYPE.External;
                _functionName = functName;
                _namespace = namespace;
-               _boundInputParamOperands = boundInParamOperands;
-               _boundInputParamNames = boundInParamNames;
-               _boundOutputParamNames = boundOutParamNames;
+               _boundInputs = boundInputs;
+               _boundInputNames = boundInputNames;
+               _boundOutputNames = boundOutputNames;
+       }
 
+       public String getFunctionName() {
+               return _functionName;
        }
 
+       public String getNamespace() {
+               return _namespace;
+       }
+       
        public static FunctionCallCPInstruction parseInstruction(String str) 
                throws DMLRuntimeException 
        {
@@ -78,20 +74,17 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
                String functionName = parts[2];
                int numInputs = Integer.valueOf(parts[3]);
                int numOutputs = Integer.valueOf(parts[4]);
-               ArrayList<CPOperand> boundInParamOperands = new ArrayList<>();
-               ArrayList<String> boundInParamNames = new ArrayList<>();
-               ArrayList<String> boundOutParamNames = new ArrayList<>();
+               CPOperand[] boundInputs = new CPOperand[numInputs];
+               ArrayList<String> boundInputNames = new ArrayList<>();
+               ArrayList<String> boundOutputNames = new ArrayList<>();
                for (int i = 0; i < numInputs; i++) {
-                       CPOperand operand = new CPOperand(parts[5 + i]);
-                       boundInParamOperands.add(operand);
-                       boundInParamNames.add(operand.getName());
+                       boundInputs[i] = new CPOperand(parts[5 + i]);
+                       boundInputNames.add(boundInputs[i].getName());
                }
-               for (int i = 0; i < numOutputs; i++) {
-                       boundOutParamNames.add(parts[5 + numInputs + i]);
-               }
-               
-               return new FunctionCallCPInstruction ( namespace,functionName, 
-                               boundInParamOperands, boundInParamNames, 
boundOutParamNames, str );
+               for (int i = 0; i < numOutputs; i++)
+                       boundOutputNames.add(parts[5 + numInputs + i]);
+               return new FunctionCallCPInstruction ( namespace,
+                       functionName, boundInputs, boundInputNames, 
boundOutputNames, str );
        }
        
        @Override
@@ -120,10 +113,10 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
                // get the function program block (stored in the Program object)
                FunctionProgramBlock fpb = 
ec.getProgram().getFunctionProgramBlock(_namespace, _functionName);
                
-               // sanity check number of function paramters
-               if( _boundInputParamNames.size() < fpb.getInputParams().size() 
) {
+               // sanity check number of function parameters
+               if( _boundInputs.length < fpb.getInputParams().size() ) {
                        throw new DMLRuntimeException("Number of bound input 
parameters does not match the function signature "
-                               + "("+_boundInputParamNames.size()+", but 
"+fpb.getInputParams().size()+" expected)");
+                               + "("+_boundInputs.length+", but 
"+fpb.getInputParams().size()+" expected)");
                }
                
                // create bindings to formal parameters for given function call
@@ -131,35 +124,31 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
                LocalVariableMap functionVariables = new LocalVariableMap();
                for( int i=0; i<fpb.getInputParams().size(); i++) 
                {
-                       DataIdentifier currFormalParam = 
fpb.getInputParams().get(i);
-                       String currFormalParamName = currFormalParam.getName();
-                       Data currFormalParamValue = null; 
-                       
-                       CPOperand operand = _boundInputParamOperands.get(i);
-                       String varname = operand.getName();
                        //error handling non-existing variables
-                       if( !operand.isLiteral() && 
!ec.containsVariable(varname) ) {
-                               throw new DMLRuntimeException("Input variable 
'"+varname+"' not existing on call of " + 
+                       CPOperand input = _boundInputs[i];
+                       if( !input.isLiteral() && 
!ec.containsVariable(input.getName()) ) {
+                               throw new DMLRuntimeException("Input variable 
'"+input.getName()+"' not existing on call of " + 
                                        
DMLProgram.constructFunctionKey(_namespace, _functionName) + " (line 
"+getLineNum()+").");
                        }
                        //get input matrix/frame/scalar
-                       currFormalParamValue = 
(operand.getDataType()!=DataType.SCALAR) ? ec.getVariable(varname) : 
-                               ec.getScalarInput(varname, 
operand.getValueType(), operand.isLiteral());
+                       DataIdentifier currFormalParam = 
fpb.getInputParams().get(i);
+                       Data value = ec.getVariable(input);
                        
                        //graceful value type conversion for scalar inputs with 
wrong type
-                       if( currFormalParamValue.getDataType() == 
DataType.SCALAR
-                               && currFormalParamValue.getValueType() != 
currFormalParam.getValueType() ) 
+                       if( value.getDataType() == DataType.SCALAR
+                               && value.getValueType() != 
currFormalParam.getValueType() ) 
                        {
-                               currFormalParamValue = 
ScalarObjectFactory.createScalarObject(
-                                       currFormalParam.getValueType(), 
(ScalarObject) currFormalParamValue);
+                               value = ScalarObjectFactory.createScalarObject(
+                                       currFormalParam.getValueType(), 
(ScalarObject)value);
                        }
                        
-                       functionVariables.put(currFormalParamName, 
currFormalParamValue);
+                       //set input parameter
+                       functionVariables.put(currFormalParam.getName(), value);
                }
                
                // Pin the input variables so that they do not get deleted 
                // from pb's symbol table at the end of execution of function
-               HashMap<String,Boolean> pinStatus = 
ec.pinVariables(_boundInputParamNames);
+               boolean[] pinStatus = ec.pinVariables(_boundInputNames);
                
                // Create a symbol table under a new execution context for the 
function invocation,
                // and copy the function arguments into the created table. 
@@ -182,29 +171,29 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
                        String fname = 
DMLProgram.constructFunctionKey(_namespace, _functionName);
                        throw new DMLRuntimeException("error executing function 
" + fname, e);
                }
-               LocalVariableMap retVars = fn_ec.getVariables();
                
                // cleanup all returned variables w/o binding 
-               Collection<String> retVarnames = new 
LinkedList<>(retVars.keySet());
-               HashSet<String> probeVars = new HashSet<>();
-               for(DataIdentifier di : fpb.getOutputParams())
-                       probeVars.add(di.getName());
-               for( String var : retVarnames ) {
-                       if( !probeVars.contains(var) ) //cleanup candidate
-                       {
-                               Data dat = fn_ec.removeVariable(var);
-                               if( dat != null && dat instanceof MatrixObject )
-                                       
fn_ec.cleanupMatrixObject((MatrixObject)dat);
-                       }
+               if( _expectRetVars == null ) {
+                       _expectRetVars = new HashSet<>();
+                       for(DataIdentifier di : fpb.getOutputParams())
+                               _expectRetVars.add(di.getName());
+               }
+               
+               LocalVariableMap retVars = fn_ec.getVariables();
+               for( Entry<String,Data> var : retVars.entrySet() ) {
+                       if( _expectRetVars.contains(var.getKey()) )
+                               continue;
+                       //cleanup unexpected return values to avoid leaks
+                       if( var.getValue() instanceof MatrixObject )
+                               
fn_ec.cleanupMatrixObject((MatrixObject)var.getValue());
                }
                
                // Unpin the pinned variables
-               ec.unpinVariables(_boundInputParamNames, pinStatus);
+               ec.unpinVariables(_boundInputNames, pinStatus);
 
                // add the updated binding for each return variable to the 
variables in original symbol table
                for (int i=0; i< fpb.getOutputParams().size(); i++){
-               
-                       String boundVarName = _boundOutputParamNames.get(i); 
+                       String boundVarName = _boundOutputNames.get(i);
                        Data boundValue = 
retVars.get(fpb.getOutputParams().get(i).getName());
                        if (boundValue == null)
                                throw new DMLRuntimeException(boundVarName + " 
was not assigned a return value");
@@ -240,14 +229,12 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
                LOG.debug("ExternalBuiltInFunction: " + this.toString());
        }
 
-       public ArrayList<String> getBoundInputParamNames()
-       {
-               return _boundInputParamNames;
+       public ArrayList<String> getBoundInputParamNames() {
+               return _boundInputNames;
        }
        
-       public ArrayList<String> getBoundOutputParamNames()
-       {
-               return _boundOutputParamNames;
+       public ArrayList<String> getBoundOutputParamNames() {
+               return _boundOutputNames;
        }
 
        public void setFunctionName(String fname)
@@ -277,6 +264,4 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
 
                return sb.substring( 0, 
sb.length()-Lop.OPERAND_DELIMITOR.length() );
        }
-       
-       
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/14c410ce/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java
index 6e2ddef..a7e0971 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java
@@ -100,24 +100,23 @@ public class CompressedLinregCG extends AutomatedTestBase
                
                try
                {
-                       String TEST_NAME = testname;
-                       TestConfiguration config = 
getTestConfiguration(TEST_NAME);
+                       TestConfiguration config = 
getTestConfiguration(testname);
+                       loadTestConfiguration(config);
                        
                        /* This is for running the junit test the new way, 
i.e., construct the arguments directly */
-                       String HOME = SCRIPT_DIR + "functions/codegen/";
+                       String HOME1 = SCRIPT_DIR + "functions/compress/";
+                       String HOME2 = SCRIPT_DIR + "functions/codegen/";
                        fullDMLScriptName = 
"scripts/algorithms/LinearRegCG.dml";
                        programArgs = new String[]{ "-explain", "-stats", 
"-nvargs", "X="+input("X"), "Y="+input("y"),
                                        "icpt="+String.valueOf(intercept), 
"tol="+String.valueOf(epsilon),
                                        "maxi="+String.valueOf(maxiter), 
"reg="+String.valueOf(regular), "B="+output("w")};
 
-                       fullRScriptName = HOME + "Algorithm_LinregCG.R";
+                       fullRScriptName = HOME2 + "Algorithm_LinregCG.R";
                        rCmd = "Rscript" + " " + fullRScriptName + " " + 
-                              HOME + INPUT_DIR + " " +
-                              String.valueOf(intercept) + " " + 
String.valueOf(epsilon) + " " + 
-                              String.valueOf(maxiter) + " " + 
String.valueOf(regular) + HOME + EXPECTED_DIR;
+                                       HOME1 + INPUT_DIR + " " +
+                                       String.valueOf(intercept) + " " + 
String.valueOf(epsilon) + " " + 
+                                       String.valueOf(maxiter) + " " + 
String.valueOf(regular) + " "+ HOME1 + EXPECTED_DIR;
                        
-                       loadTestConfiguration(config);
-       
                        //generate actual datasets
                        double[][] X = getRandomMatrix(rows, cols, 1, 1, 
sparse?sparsity2:sparsity1, 7);
                        writeInputMatrixWithMTD("X", X, true);
@@ -141,7 +140,7 @@ public class CompressedLinregCG extends AutomatedTestBase
                finally {
                        rtplatform = platformOld;
                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
-                       InfrastructureAnalyzer.setLocalMaxMemory(memOld);       
        
+                       InfrastructureAnalyzer.setLocalMaxMemory(memOld);
                        CompressedMatrixBlock.ALLOW_DDC_ENCODING = true;
                }
        }

Reply via email to