Repository: systemml Updated Branches: refs/heads/master f7fe43420 -> c1ed79150
[SYSTEMML-1989] Performance external UDF invocation (loading, parsing) This patch improves the performance of external (i.e., Java) UDF functions by moving the class loading and input/output parsing out of the critical path into the instruction setup. Furthermore, this includes a number of cleanups and updates of the UDF framework. On a test scenario of a pass-through external UDF with three input and output parameters, this patch improved performance from 75K/s to 267K/s, which is good considering the rate of 520K/s for dml-bodied functions. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c1ed7915 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c1ed7915 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c1ed7915 Branch: refs/heads/master Commit: c1ed7915082f2e33027819861d32745f27e39fe4 Parents: f7fe434 Author: Matthias Boehm <[email protected]> Authored: Sun Nov 5 19:22:14 2017 -0800 Committer: Matthias Boehm <[email protected]> Committed: Sun Nov 5 19:33:32 2017 -0800 ---------------------------------------------------------------------- .../ExternalFunctionProgramBlock.java | 500 +++---------------- .../ExternalFunctionProgramBlockCP.java | 104 +--- .../controlprogram/FunctionProgramBlock.java | 4 +- .../ExternalFunctionInvocationInstruction.java | 189 +++++-- src/main/java/org/apache/sysml/udf/Matrix.java | 29 +- .../org/apache/sysml/udf/PackageFunction.java | 8 +- .../sysml/udf/lib/DynamicProjectMatrixCP.java | 2 +- 7 files changed, 270 insertions(+), 566 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/c1ed7915/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlock.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlock.java index 9738203..da6ffc3 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlock.java @@ -23,7 +23,6 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; -import java.util.StringTokenizer; import java.util.TreeMap; import org.apache.sysml.api.DMLScript; @@ -31,53 +30,35 @@ import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.lops.Lop; import org.apache.sysml.lops.compile.JobType; import org.apache.sysml.parser.DataIdentifier; -import org.apache.sysml.parser.Expression.DataType; -import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.parser.ExternalFunctionStatement; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.caching.CacheException; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysml.runtime.instructions.Instruction; import org.apache.sysml.runtime.instructions.MRJobInstruction; -import org.apache.sysml.runtime.instructions.cp.BooleanObject; +import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.cp.Data; -import org.apache.sysml.runtime.instructions.cp.DoubleObject; -import org.apache.sysml.runtime.instructions.cp.IntObject; -import org.apache.sysml.runtime.instructions.cp.ScalarObject; -import org.apache.sysml.runtime.instructions.cp.StringObject; import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction; -import org.apache.sysml.runtime.matrix.MatrixCharacteristics; -import org.apache.sysml.runtime.matrix.MetaDataFormat; import 
org.apache.sysml.runtime.matrix.data.InputInfo; import org.apache.sysml.runtime.matrix.data.OutputInfo; import org.apache.sysml.udf.ExternalFunctionInvocationInstruction; -import org.apache.sysml.udf.FunctionParameter; -import org.apache.sysml.udf.Matrix; import org.apache.sysml.udf.PackageFunction; -import org.apache.sysml.udf.Scalar; -import org.apache.sysml.udf.FunctionParameter.FunctionParameterType; -import org.apache.sysml.udf.BinaryObject; import org.apache.sysml.udf.Scalar.ScalarValueType; public class ExternalFunctionProgramBlock extends FunctionProgramBlock { - protected static final IDSequence _idSeq = new IDSequence(); + protected long _runID = -1; //ID for block of statements protected String _baseDir = null; + protected HashMap<String, String> _otherParams; // holds other key value parameters - ArrayList<Instruction> block2CellInst; - ArrayList<Instruction> cell2BlockInst; - - // holds other key value parameters specified in function declaration - protected HashMap<String, String> _otherParams; + private ArrayList<Instruction> block2CellInst; + private ArrayList<Instruction> cell2BlockInst; + private HashMap<String, String> _unblockedFileNames; + private HashMap<String, String> _blockedFileNames; - protected HashMap<String, String> _unblockedFileNames; - protected HashMap<String, String> _blockedFileNames; - - protected long _runID = -1; //ID for block of statements /** * Constructor that also provides otherParams that are needed for external @@ -95,7 +76,7 @@ public class ExternalFunctionProgramBlock extends FunctionProgramBlock ArrayList<DataIdentifier> outputParams, String baseDir) throws DMLRuntimeException { - super(prog, inputParams, outputParams); + super(prog, inputParams, outputParams); _baseDir = baseDir; } @@ -140,17 +121,19 @@ public class ExternalFunctionProgramBlock extends FunctionProgramBlock * * @param id this field does nothing */ - private void changeTmpOutput( long id ) - { + private void changeTmpOutput( long id ) { 
ArrayList<DataIdentifier> outputParams = getOutputParams(); cell2BlockInst = getCell2BlockInstructions(outputParams, _blockedFileNames); } - public String getBaseDir() - { + public String getBaseDir() { return _baseDir; } + public HashMap<String,String> getOtherParams() { + return _otherParams; + } + /** * Method to be invoked to execute instructions for the external function * invocation @@ -169,12 +152,10 @@ public class ExternalFunctionProgramBlock extends FunctionProgramBlock try { inputParams = getInputParams(); - for(DataIdentifier di : inputParams ) { + for(DataIdentifier di : inputParams ) { Data d = ec.getVariable(di.getName()); - if ( d.getDataType() == DataType.MATRIX ) { - MatrixObject inputObj = (MatrixObject) d; - inputObj.exportData(); - } + if( d.getDataType().isMatrix() ) + ((MatrixObject) d).exportData(); } } catch (Exception e){ @@ -200,7 +181,7 @@ public class ExternalFunctionProgramBlock extends FunctionProgramBlock { try { if (_inst.get(i) instanceof ExternalFunctionInvocationInstruction) - executeInstruction(ec, (ExternalFunctionInvocationInstruction) _inst.get(i)); + ((ExternalFunctionInvocationInstruction) _inst.get(i)).processInstruction(ec); } catch(Exception e) { throw new DMLRuntimeException(this.printBlockErrorLocation() + @@ -228,52 +209,28 @@ public class ExternalFunctionProgramBlock extends FunctionProgramBlock } /** - * Given a list of parameters as data identifiers, returns a string - * representation. + * Given a list of parameters as data identifiers, returns an array + * of instruction operands. 
* * @param params list of data identifiers - * @return parameter string + * @return operands */ - protected String getParameterString(ArrayList<DataIdentifier> params) { - String parameterString = ""; - + protected CPOperand[] getOperands(ArrayList<DataIdentifier> params) { + CPOperand[] ret = new CPOperand[params.size()]; for (int i = 0; i < params.size(); i++) { - if (i != 0) - parameterString += ","; - DataIdentifier param = params.get(i); - - if (param.getDataType() == DataType.MATRIX) { - String s = getDataTypeString(DataType.MATRIX) + ":"; - s = s + "" + param.getName() + "" + ":"; - s = s + getValueTypeString(param.getValueType()); - parameterString += s; - continue; - } - - if (param.getDataType() == DataType.SCALAR) { - String s = getDataTypeString(DataType.SCALAR) + ":"; - s = s + "" + param.getName() + "" + ":"; - s = s + getValueTypeString(param.getValueType()); - parameterString += s; - continue; - } - - if (param.getDataType() == DataType.OBJECT) { - String s = getDataTypeString(DataType.OBJECT) + ":"; - s = s + "" + param.getName() + "" + ":"; - parameterString += s; - continue; - } + ret[i] = new CPOperand(param.getName(), + param.getValueType(), param.getDataType()); } - - return parameterString; + return ret; } /** * method to create instructions + * + * @throws DMLRuntimeException */ - protected void createInstructions() { + protected void createInstructions() throws DMLRuntimeException { _inst = new ArrayList<>(); @@ -284,32 +241,62 @@ public class ExternalFunctionProgramBlock extends FunctionProgramBlock String className = _otherParams.get(ExternalFunctionStatement.CLASS_NAME); String configFile = _otherParams.get(ExternalFunctionStatement.CONFIG_FILE); - // class name cannot be null, however, configFile and execLocation can - // be null + // class name cannot be null, however, configFile and execLocation can be null if (className == null) throw new RuntimeException(this.printBlockErrorLocation() + ExternalFunctionStatement.CLASS_NAME + " not 
provided!"); - // assemble input and output param strings - String inputParameterString = getParameterString(getInputParams()); - String outputParameterString = getParameterString(getOutputParams()); - - // generate instruction - ExternalFunctionInvocationInstruction einst = new ExternalFunctionInvocationInstruction( - className, configFile, inputParameterString, - outputParameterString); + // assemble input and output operands + CPOperand[] inputs = getOperands(getInputParams()); + CPOperand[] outputs = getOperands(getOutputParams()); + // generate instruction + PackageFunction fun = createFunctionObject(className, configFile); + ExternalFunctionInvocationInstruction einst = + new ExternalFunctionInvocationInstruction(inputs, outputs, fun, _baseDir, InputInfo.TextCellInputInfo); + verifyFunctionInputsOutputs(fun, inputs, outputs); if (getInputParams().size() > 0) einst.setLocation(getInputParams().get(0)); else if (getOutputParams().size() > 0) einst.setLocation(getOutputParams().get(0)); else - einst.setLocation(this.getFilename(), this._beginLine, this._endLine, this._beginColumn, this._endColumn); - + einst.setLocation(getFilename(), _beginLine, _endLine, _beginColumn, _endColumn); _inst.add(einst); // block output matrices cell2BlockInst = getCell2BlockInstructions(getOutputParams(),_blockedFileNames); } + + @SuppressWarnings("unchecked") + protected PackageFunction createFunctionObject(String className, String configFile) throws DMLRuntimeException { + try { + //create instance of package function + Class<Instruction> cla = (Class<Instruction>) Class.forName(className); + Object o = cla.newInstance(); + if (!(o instanceof PackageFunction)) + throw new DMLRuntimeException(this.printBlockErrorLocation() + "Class is not of type PackageFunction"); + PackageFunction fun = (PackageFunction) o; + + //configure package function + fun.setConfiguration(configFile); + fun.setBaseDir(_baseDir); + + return fun; + } + catch (Exception e) { + throw new 
DMLRuntimeException(this.printBlockErrorLocation() + "Error instantiating package function ", e ); + } + } + + protected void verifyFunctionInputsOutputs(PackageFunction fun, CPOperand[] inputs, CPOperand[] outputs) + throws DMLRuntimeException + { + // verify number of outputs + if( outputs.length != fun.getNumFunctionOutputs() ) { + throw new DMLRuntimeException( + "Number of function outputs ("+fun.getNumFunctionOutputs()+") " + + "does not match with declaration ("+outputs.length+")."); + } + } /** @@ -331,7 +318,7 @@ public class ExternalFunctionProgramBlock extends FunctionProgramBlock // identify outputs that are matrices for (int i = 0; i < outputParams.size(); i++) { - if (outputParams.get(i).getDataType() == DataType.MATRIX) { + if( outputParams.get(i).getDataType().isMatrix() ) { if( _skipOutReblock.contains(outputParams.get(i).getName()) ) matricesNoReblock.add(outputParams.get(i)); else @@ -452,7 +439,7 @@ public class ExternalFunctionProgramBlock extends FunctionProgramBlock // find all inputs that are matrices for (int i = 0; i < inputParams.size(); i++) { - if (inputParams.get(i).getDataType() == DataType.MATRIX) { + if( inputParams.get(i).getDataType().isMatrix() ) { if( _skipInReblock.contains(inputParams.get(i).getName()) ) matricesNoReblock.add(inputParams.get(i)); else @@ -559,168 +546,7 @@ public class ExternalFunctionProgramBlock extends FunctionProgramBlock return b2cinst; //null if no input matrices } - - /** - * Method to execute an external function invocation instruction. 
- * - * @param ec execution context - * @param inst external function invocation instructions - * @throws DMLRuntimeException if DMLRuntimeException occurs - */ - @SuppressWarnings("unchecked") - public void executeInstruction(ExecutionContext ec, ExternalFunctionInvocationInstruction inst) - throws DMLRuntimeException - { - String className = inst.getClassName(); - String configFile = inst.getConfigFile(); - - if (className == null) - throw new DMLRuntimeException(this.printBlockErrorLocation() + "Class name can't be null"); - - // create instance of package function. - Object o; - try - { - Class<Instruction> cla = (Class<Instruction>) Class.forName(className); - o = cla.newInstance(); - } - catch (Exception e) - { - throw new DMLRuntimeException(this.printBlockErrorLocation() + "Error generating package function object " ,e ); - } - - if (!(o instanceof PackageFunction)) - throw new DMLRuntimeException(this.printBlockErrorLocation() + "Class is not of type PackageFunction"); - - PackageFunction func = (PackageFunction) o; - - // add inputs to this package function based on input parameter - // and their mappings. 
- setupInputs(func, inst.getInputParams(), ec.getVariables()); - func.setConfiguration(configFile); - func.setBaseDir(_baseDir); - - //executes function - func.execute(); - - // verify output of function execution matches declaration - // and add outputs to variableMapping and Metadata - verifyAndAttachOutputs(ec, func, inst.getOutputParams()); - } - - /** - * Method to verify that function outputs match with declared outputs - * - * @param ec execution context - * @param returnFunc package function - * @param outputParams output parameters - * @throws DMLRuntimeException if DMLRuntimeException occurs - */ - protected void verifyAndAttachOutputs(ExecutionContext ec, PackageFunction returnFunc, - String outputParams) throws DMLRuntimeException { - - ArrayList<String> outputs = getParameters(outputParams); - // make sure they are of equal size first - - if (outputs.size() != returnFunc.getNumFunctionOutputs()) { - throw new DMLRuntimeException( - "Number of function outputs ("+returnFunc.getNumFunctionOutputs()+") " + - "does not match with declaration ("+outputs.size()+")."); - } - - // iterate over each output and verify that type matches - for (int i = 0; i < outputs.size(); i++) { - StringTokenizer tk = new StringTokenizer(outputs.get(i), ":"); - ArrayList<String> tokens = new ArrayList<>(); - while (tk.hasMoreTokens()) { - tokens.add(tk.nextToken()); - } - - if (returnFunc.getFunctionOutput(i).getType() == FunctionParameterType.Matrix) { - Matrix m = (Matrix) returnFunc.getFunctionOutput(i); - - if (!(tokens.get(0).equals(getFunctionParameterDataTypeString(FunctionParameterType.Matrix))) - || !(tokens.get(2).equals(getMatrixValueTypeString(m.getValueType())))) { - throw new DMLRuntimeException( - "Function output '"+outputs.get(i)+"' does not match with declaration."); - } - - // add result to variableMapping - String varName = tokens.get(1); - MatrixObject newVar = createOutputMatrixObject( m ); - newVar.setVarName(varName); - - //getVariables().put(varName, 
newVar); //put/override in local symbol table - ec.setVariable(varName, newVar); - - continue; - } - - if (returnFunc.getFunctionOutput(i).getType() == FunctionParameterType.Scalar) { - Scalar s = (Scalar) returnFunc.getFunctionOutput(i); - - if (!tokens.get(0).equals(getFunctionParameterDataTypeString(FunctionParameterType.Scalar)) - || !tokens.get(2).equals( - getScalarValueTypeString(s.getScalarType()))) { - throw new DMLRuntimeException( - "Function output '"+outputs.get(i)+"' does not match with declaration."); - } - - // allocate and set appropriate object based on type - ScalarObject scalarObject = null; - ScalarValueType type = s.getScalarType(); - switch (type) { - case Integer: - scalarObject = new IntObject(tokens.get(1), - Long.parseLong(s.getValue())); - break; - case Double: - scalarObject = new DoubleObject(tokens.get(1), - Double.parseDouble(s.getValue())); - break; - case Boolean: - scalarObject = new BooleanObject(tokens.get(1), - Boolean.parseBoolean(s.getValue())); - break; - case Text: - scalarObject = new StringObject(tokens.get(1), s.getValue()); - break; - default: - throw new DMLRuntimeException( - "Unknown scalar value type '"+type+"' of output '"+outputs.get(i)+"'."); - } - - //this.getVariables().put(tokens.get(1), scalarObject); - ec.setVariable(tokens.get(1), scalarObject); - continue; - } - - if (returnFunc.getFunctionOutput(i).getType() == FunctionParameterType.Object) { - if (!tokens.get(0).equals(getFunctionParameterDataTypeString(FunctionParameterType.Object))) { - throw new DMLRuntimeException( - "Function output '"+outputs.get(i)+"' does not match with declaration."); - } - - throw new DMLRuntimeException( - "Object types not yet supported"); - - // continue; - } - - throw new DMLRuntimeException( - "Unknown data type '"+returnFunc.getFunctionOutput(i).getType()+"' " + - "of output '"+outputs.get(i)+"'."); - } - } - - protected MatrixObject createOutputMatrixObject( Matrix m ) - throws CacheException - { - MatrixCharacteristics 
mc = new MatrixCharacteristics(m.getNumRows(),m.getNumCols(), ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize()); - MetaDataFormat mfmd = new MetaDataFormat(mc, OutputInfo.TextCellOutputInfo, InputInfo.TextCellInputInfo); - return new MatrixObject(ValueType.DOUBLE, m.getFilePath(), mfmd); - } - + /** * Method to get string representation of scalar value type * @@ -733,184 +559,6 @@ public class ExternalFunctionProgramBlock extends FunctionProgramBlock else return scalarType.toString(); } - - /** - * Method to parse inputs, update labels, and add to package function. - * - * @param func package function - * @param inputParams input parameters - * @param variableMapping local variable map - */ - protected void setupInputs (PackageFunction func, String inputParams, LocalVariableMap variableMapping) { - ArrayList<String> inputs = getParameters(inputParams); - ArrayList<FunctionParameter> inputObjects = getInputObjects(inputs, variableMapping); - func.setNumFunctionInputs(inputObjects.size()); - for (int i = 0; i < inputObjects.size(); i++) - func.setInput(inputObjects.get(i), i); - } - - /** - * Method to convert string representation of input into function input - * object. 
- * - * @param inputs list of inputs - * @param variableMapping local variable map - * @return list of function parameters - */ - protected ArrayList<FunctionParameter> getInputObjects(ArrayList<String> inputs, - LocalVariableMap variableMapping) { - ArrayList<FunctionParameter> inputObjects = new ArrayList<>(); - - for (int i = 0; i < inputs.size(); i++) { - ArrayList<String> tokens = new ArrayList<>(); - StringTokenizer tk = new StringTokenizer(inputs.get(i), ":"); - while (tk.hasMoreTokens()) { - tokens.add(tk.nextToken()); - } - - if (tokens.get(0).equals("Matrix")) { - String varName = tokens.get(1); - MatrixObject mobj = (MatrixObject) variableMapping.get(varName); - MatrixCharacteristics mc = mobj.getMatrixCharacteristics(); - Matrix m = new Matrix(mobj.getFileName(), - mc.getRows(), mc.getCols(), - getMatrixValueType(tokens.get(2))); - modifyInputMatrix(m, mobj); - inputObjects.add(m); - } - - if (tokens.get(0).equals("Scalar")) { - String varName = tokens.get(1); - ScalarObject so = (ScalarObject) variableMapping.get(varName); - Scalar s = new Scalar(getScalarValueType(tokens.get(2)), - so.getStringValue()); - inputObjects.add(s); - - } - - if (tokens.get(0).equals("Object")) { - String varName = tokens.get(1); - Object o = variableMapping.get(varName); - BinaryObject obj = new BinaryObject(o); - inputObjects.add(obj); - - } - } - - return inputObjects; - - } - - protected void modifyInputMatrix(Matrix m, MatrixObject mobj) - { - //do nothing, intended for extensions - } - - /** - * Converts string representation of scalar value type to enum type - * - * @param string scalar value string - * @return scalar value type - */ - protected ScalarValueType getScalarValueType(String string) { - if (string.equals("String")) - return ScalarValueType.Text; - else - return ScalarValueType.valueOf(string); - } - - /** - * Get string representation of matrix value type - * - * @param t matrix value type - * @return matrix value type as string - */ - protected String 
getMatrixValueTypeString(Matrix.ValueType t) { - return t.toString(); - } - - /** - * Converts string representation of matrix value type into enum type - * - * @param string matrix value type as string - * @return matrix value type - */ - protected Matrix.ValueType getMatrixValueType(String string) { - return Matrix.ValueType.valueOf(string); - } - - /** - * Method to break the comma separated input parameters into an arraylist of - * parameters - * - * @param inputParams input parameters - * @return list of string inputs - */ - protected ArrayList<String> getParameters(String inputParams) { - ArrayList<String> inputs = new ArrayList<>(); - - StringTokenizer tk = new StringTokenizer(inputParams, ","); - while (tk.hasMoreTokens()) { - inputs.add(tk.nextToken()); - } - - return inputs; - } - - /** - * Get string representation for data type - * - * @param d data type - * @return string representation of data type - */ - protected String getDataTypeString(DataType d) { - if (d.equals(DataType.MATRIX)) - return "Matrix"; - - if (d.equals(DataType.SCALAR)) - return "Scalar"; - - if (d.equals(DataType.OBJECT)) - return "Object"; - - throw new RuntimeException("Should never come here"); - } - - /** - * Method to get string representation of function parameter type. 
- * - * @param t function parameter type - * @return function parameter type as string - */ - protected String getFunctionParameterDataTypeString(FunctionParameterType t) { - return t.toString(); - } - - /** - * Get string representation of value type - * - * @param v value type - * @return value type string - */ - protected String getValueTypeString(ValueType v) { - if (v.equals(ValueType.DOUBLE)) - return "Double"; - - if (v.equals(ValueType.INT)) - return "Integer"; - - if (v.equals(ValueType.BOOLEAN)) - return "Boolean"; - - if (v.equals(ValueType.STRING)) - return "String"; - - throw new RuntimeException("Should never come here"); - } - - public HashMap<String,String> getOtherParams() { - return _otherParams; - } @Override public String printBlockErrorLocation(){ http://git-wip-us.apache.org/repos/asf/systemml/blob/c1ed7915/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlockCP.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlockCP.java b/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlockCP.java index 8be1455..48a4f52 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlockCP.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlockCP.java @@ -22,20 +22,15 @@ package org.apache.sysml.runtime.controlprogram; import java.util.ArrayList; import java.util.HashMap; -import org.apache.sysml.conf.ConfigurationManager; +import org.apache.sysml.parser.DMLProgram; import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.parser.ExternalFunctionStatement; -import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; 
-import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; -import org.apache.sysml.runtime.matrix.MatrixCharacteristics; -import org.apache.sysml.runtime.matrix.MetaDataFormat; +import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.matrix.data.InputInfo; -import org.apache.sysml.runtime.matrix.data.OutputInfo; import org.apache.sysml.udf.ExternalFunctionInvocationInstruction; -import org.apache.sysml.udf.Matrix; +import org.apache.sysml.udf.PackageFunction; /** * CP external function program block, that overcomes the need for @@ -51,10 +46,6 @@ import org.apache.sysml.udf.Matrix; */ public class ExternalFunctionProgramBlockCP extends ExternalFunctionProgramBlock { - - public static String DEFAULT_FILENAME = "ext_funct"; - private static IDSequence _defaultSeq = new IDSequence(); - /** * Constructor that also provides otherParams that are needed for external * functions. Remaining parameters will just be passed to constructor for @@ -90,45 +81,24 @@ public class ExternalFunctionProgramBlockCP extends ExternalFunctionProgramBlock @Override public void execute(ExecutionContext ec) throws DMLRuntimeException { - _runID = _idSeq.getNextID(); + if( _inst.size() != 1 ) + throw new DMLRuntimeException("Invalid number of instructions: "+_inst.size()); - ExternalFunctionInvocationInstruction inst = null; - - // execute package function - for (int i=0; i < _inst.size(); i++) - { - try { - inst = (ExternalFunctionInvocationInstruction)_inst.get(i); - inst._namespace = _namespace; - inst._functionName = _functionName; - executeInstruction( ec, inst ); - } - catch (Exception e){ - throw new DMLRuntimeException(this.printBlockErrorLocation() + "Error evaluating instruction " + i + " in external function programBlock. 
inst: " + inst.toString(), e); - } + // execute package function via ExternalFunctionInvocationInstruction + try { + _inst.get(0).processInstruction(ec); + } + catch (Exception e){ + throw new DMLRuntimeException(printBlockErrorLocation() + "Error evaluating external function: " + + DMLProgram.constructFunctionKey(_namespace, _functionName), e); } // check return values checkOutputParameters(ec.getVariables()); } - - /** - * Executes the external function instruction. - * - */ - @Override - public void executeInstruction(ExecutionContext ec, ExternalFunctionInvocationInstruction inst) - throws DMLRuntimeException - { - // After the udf framework rework, we moved the code of ExternalFunctionProgramBlockCP - // to ExternalFunctionProgramBlock and hence hence both types of external functions can - // share the same code path here. - super.executeInstruction(ec, inst); - } - @Override - protected void createInstructions() + protected void createInstructions() throws DMLRuntimeException { _inst = new ArrayList<>(); @@ -141,48 +111,16 @@ public class ExternalFunctionProgramBlockCP extends ExternalFunctionProgramBlock throw new RuntimeException(this.printBlockErrorLocation() + ExternalFunctionStatement.CLASS_NAME + " not provided!"); // assemble input and output param strings - String inputParameterString = getParameterString(getInputParams()); - String outputParameterString = getParameterString(getOutputParams()); - - // generate instruction - ExternalFunctionInvocationInstruction einst = new ExternalFunctionInvocationInstruction( - className, configFile, inputParameterString, - outputParameterString); - - _inst.add(einst); - - } - - @Override - protected void modifyInputMatrix(Matrix m, MatrixObject mobj) - { - //pass in-memory object to external function - m.setMatrixObject( mobj ); - } - - @Override - protected MatrixObject createOutputMatrixObject(Matrix m) - { - MatrixObject ret = m.getMatrixObject(); + CPOperand[] inputs = getOperands(getInputParams()); + 
CPOperand[] outputs = getOperands(getOutputParams()); - if( ret == null ) //otherwise, pass in-memory matrix from extfunct back to invoking program - { - MatrixCharacteristics mc = new MatrixCharacteristics(m.getNumRows(),m.getNumCols(), ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize()); - MetaDataFormat mfmd = new MetaDataFormat(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo); - ret = new MatrixObject(ValueType.DOUBLE, m.getFilePath(), mfmd); - } + // generate instruction + PackageFunction fun = createFunctionObject(className, configFile); + ExternalFunctionInvocationInstruction einst = + new ExternalFunctionInvocationInstruction(inputs, outputs, fun, _baseDir, InputInfo.BinaryBlockInputInfo); + verifyFunctionInputsOutputs(fun, inputs, outputs); - //for allowing in-memory packagesupport matrices w/o filesnames - if( ret.getFileName().equals( DEFAULT_FILENAME ) ) - { - ret.setFileName( createDefaultOutputFilePathAndName() ); - } - - return ret; - } - - public String createDefaultOutputFilePathAndName( ) { - return _baseDir + DEFAULT_FILENAME + _defaultSeq.getNextID(); + _inst.add(einst); } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/c1ed7915/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java index fa4838a..a8e1d17 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java @@ -130,8 +130,7 @@ public class FunctionProgramBlock extends ProgramBlock protected void checkOutputParameters( LocalVariableMap vars ) { - for( DataIdentifier diOut : _outputParams ) - { + for( DataIdentifier diOut : _outputParams ) { String 
varName = diOut.getName(); Data dat = vars.get( varName ); if( dat == null ) @@ -140,7 +139,6 @@ public class FunctionProgramBlock extends ProgramBlock LOG.warn("Function output "+ varName +" has wrong data type: "+dat.getDataType()+"."); else if( dat.getValueType() != diOut.getValueType() ) LOG.warn("Function output "+ varName +" has wrong value type: "+dat.getValueType()+"."); - } } http://git-wip-us.apache.org/repos/asf/systemml/blob/c1ed7915/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java b/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java index 93717e5..70eb4bf 100644 --- a/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java +++ b/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java @@ -19,9 +19,27 @@ package org.apache.sysml.udf; +import java.util.ArrayList; + +import org.apache.sysml.conf.ConfigurationManager; +import org.apache.sysml.parser.Expression; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.LocalVariableMap; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysml.runtime.instructions.Instruction; +import org.apache.sysml.runtime.instructions.cp.BooleanObject; +import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.instructions.cp.DoubleObject; +import org.apache.sysml.runtime.instructions.cp.IntObject; +import org.apache.sysml.runtime.instructions.cp.ScalarObject; +import org.apache.sysml.runtime.instructions.cp.StringObject; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import 
org.apache.sysml.runtime.matrix.MetaDataFormat; +import org.apache.sysml.runtime.matrix.data.InputInfo; +import org.apache.sysml.udf.Matrix.ValueType; +import org.apache.sysml.udf.Scalar.ScalarValueType; /** * Class to maintain external function invocation instructions. @@ -31,55 +49,146 @@ import org.apache.sysml.runtime.instructions.Instruction; */ public class ExternalFunctionInvocationInstruction extends Instruction { + private static final IDSequence _defaultSeq = new IDSequence(); - public static final String ELEMENT_DELIM = ":"; - - public String _namespace; - public String _functionName; - protected String className; // name of class that contains the function - protected String configFile; // optional configuration file parameter - protected String inputParams; // string representation of input parameters - protected String outputParams; // string representation of output parameters + protected final CPOperand[] inputs; + protected final CPOperand[] outputs; + protected final PackageFunction fun; + protected final String baseDir; + protected final InputInfo iinfo; - public ExternalFunctionInvocationInstruction(String className, - String configFile, String inputParams, - String outputParams) + public ExternalFunctionInvocationInstruction(CPOperand[] inputs, CPOperand[] outputs, + PackageFunction fun, String baseDir, InputInfo format) { + this.inputs = inputs; + this.outputs = outputs; + this.fun = fun; + this.baseDir = baseDir; + this.iinfo = format; + } + + @Override + public void processInstruction(ExecutionContext ec) + throws DMLRuntimeException { - this.className = className; - this.configFile = configFile; - this.inputParams = inputParams; - this.outputParams = outputParams; + // get the inputs, wrapped into external data types + fun.setFunctionInputs(getInputObjects(inputs, ec.getVariables())); + + //executes function + fun.execute(); + + // get and verify the outputs + verifyAndAttachOutputs(ec, fun, outputs); } - public String getClassName() { - 
return className; - } - - public String getConfigFile() { - return configFile; + @SuppressWarnings("incomplete-switch") + private ArrayList<FunctionParameter> getInputObjects(CPOperand[] inputs, LocalVariableMap vars) { + ArrayList<FunctionParameter> ret = new ArrayList<>(); + for( CPOperand input : inputs ) { + switch( input.getDataType() ) { + case MATRIX: + MatrixObject mobj = (MatrixObject) vars.get(input.getName()); + ret.add(new Matrix(mobj, getMatrixValueType(input.getValueType()))); + break; + case SCALAR: + ScalarObject so = (ScalarObject) vars.get(input.getName()); + ret.add(new Scalar(getScalarValueType(input.getValueType()), so.getStringValue())); + break; + case OBJECT: + ret.add(new BinaryObject(vars.get(input.getName()))); + break; + } + } + return ret; } - - public String getInputParams() { - return inputParams; + + private ScalarValueType getScalarValueType(Expression.ValueType vt) { + switch(vt) { + case STRING: return ScalarValueType.Text; + case DOUBLE: return ScalarValueType.Double; + case INT: return ScalarValueType.Integer; + case BOOLEAN: return ScalarValueType.Boolean; + default: + throw new RuntimeException("Unknown type: "+vt.name()); + } } - - public String getOutputParams() { - return outputParams; + + private ValueType getMatrixValueType(Expression.ValueType vt) { + switch(vt) { + case DOUBLE: return ValueType.Double; + case INT: return ValueType.Integer; + default: + throw new RuntimeException("Unknown type: "+vt.name()); + } } - - @Override - public String toString() { - return className + ELEMENT_DELIM + - configFile + ELEMENT_DELIM + - inputParams + ELEMENT_DELIM + - outputParams; + + private void verifyAndAttachOutputs(ExecutionContext ec, PackageFunction fun, CPOperand[] outputs) + throws DMLRuntimeException + { + for( int i = 0; i < outputs.length; i++) { + CPOperand output = outputs[i]; + switch( fun.getFunctionOutput(i).getType() ) { + case Matrix: + Matrix m = (Matrix) fun.getFunctionOutput(i); + MatrixObject newVar = 
createOutputMatrixObject( m ); + newVar.setVarName(output.getName()); + ec.setVariable(output.getName(), newVar); + break; + case Scalar: + Scalar s = (Scalar) fun.getFunctionOutput(i); + ScalarObject scalarObject = null; + switch( s.getScalarType() ) { + case Integer: + scalarObject = new IntObject(output.getName(), + Long.parseLong(s.getValue())); + break; + case Double: + scalarObject = new DoubleObject(output.getName(), + Double.parseDouble(s.getValue())); + break; + case Boolean: + scalarObject = new BooleanObject(output.getName(), + Boolean.parseBoolean(s.getValue())); + break; + case Text: + scalarObject = new StringObject(output.getName(), s.getValue()); + break; + default: + throw new DMLRuntimeException("Unknown scalar value type '" + + s.getScalarType()+"' of output '"+output.getName()+"'."); + } + ec.setVariable(output.getName(), scalarObject); + break; + default: + throw new DMLRuntimeException("Unsupported data type: " + +fun.getFunctionOutput(i).getType().name()); + } + } } - - @Override - public void processInstruction(ExecutionContext ec) - throws DMLRuntimeException + + private MatrixObject createOutputMatrixObject(Matrix m) throws DMLRuntimeException { - //do nothing (not applicable because this instruction is only used as - //meta data container) + MatrixObject ret = m.getMatrixObject(); + + if( ret == null ) { //otherwise, pass in-memory matrix from extfunct back to invoking program + MatrixCharacteristics mc = new MatrixCharacteristics(m.getNumRows(),m.getNumCols(), + ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize()); + MetaDataFormat mfmd = new MetaDataFormat(mc, InputInfo.getMatchingOutputInfo(iinfo), iinfo); + ret = new MatrixObject(Expression.ValueType.DOUBLE, m.getFilePath(), mfmd); + } + + //for allowing in-memory packagesupport matrices w/o file names + if( ret.getFileName().equals( Matrix.DEFAULT_FILENAME ) ) { + ret.setFileName( createDefaultOutputFilePathAndName() ); + } + + return ret; + } + + private String 
createDefaultOutputFilePathAndName( ) { + StringBuilder sb = new StringBuilder(); + sb.append(baseDir); + sb.append(Matrix.DEFAULT_FILENAME); + sb.append(_defaultSeq.getNextID()); + return sb.toString(); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/c1ed7915/src/main/java/org/apache/sysml/udf/Matrix.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/udf/Matrix.java b/src/main/java/org/apache/sysml/udf/Matrix.java index 5da92e3..b75b915 100644 --- a/src/main/java/org/apache/sysml/udf/Matrix.java +++ b/src/main/java/org/apache/sysml/udf/Matrix.java @@ -24,7 +24,6 @@ import java.io.IOException; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.parser.Expression; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.ExternalFunctionProgramBlockCP; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.io.MatrixReader; import org.apache.sysml.runtime.io.MatrixReaderFactory; @@ -43,13 +42,13 @@ import org.apache.sysml.runtime.util.DataConverter; */ public class Matrix extends FunctionParameter { - private static final long serialVersionUID = -1058329938431848909L; + public static final String DEFAULT_FILENAME = "ext_funct"; - private String _filePath; - private long _rows; - private long _cols; - private ValueType _vType; + private String _filePath; + private long _rows; + private long _cols; + private ValueType _vType; private MatrixObject _mo; public enum ValueType { @@ -67,8 +66,7 @@ public class Matrix extends FunctionParameter * @param vType value type */ public Matrix(long rows, long cols, ValueType vType) { - this( ExternalFunctionProgramBlockCP.DEFAULT_FILENAME, - rows, cols, vType ); + this( DEFAULT_FILENAME, rows, cols, vType ); } /** @@ -88,13 +86,20 @@ public class Matrix extends FunctionParameter _vType = vType; } - public void setMatrixObject( MatrixObject mo ) - 
{ + public Matrix(MatrixObject mo, ValueType vType) { + super(FunctionParameterType.Matrix); + _filePath = mo.getFileName(); + _rows = mo.getNumRows(); + _cols = mo.getNumColumns(); + _vType = vType; _mo = mo; } - public MatrixObject getMatrixObject() - { + public void setMatrixObject( MatrixObject mo ) { + _mo = mo; + } + + public MatrixObject getMatrixObject() { return _mo; } http://git-wip-us.apache.org/repos/asf/systemml/blob/c1ed7915/src/main/java/org/apache/sysml/udf/PackageFunction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/udf/PackageFunction.java b/src/main/java/org/apache/sysml/udf/PackageFunction.java index cf9afe3..7d61639 100644 --- a/src/main/java/org/apache/sysml/udf/PackageFunction.java +++ b/src/main/java/org/apache/sysml/udf/PackageFunction.java @@ -85,7 +85,13 @@ public abstract class PackageFunction implements Serializable * @return function parameter */ public abstract FunctionParameter getFunctionOutput(int pos); - + + public final void setFunctionInputs(ArrayList<FunctionParameter> inputs) { + setNumFunctionInputs(inputs.size()); + for (int i = 0; i < inputs.size(); i++) + setInput(inputs.get(i), i); + } + /** * Method to set the number of inputs for this package function * http://git-wip-us.apache.org/repos/asf/systemml/blob/c1ed7915/src/main/java/org/apache/sysml/udf/lib/DynamicProjectMatrixCP.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/udf/lib/DynamicProjectMatrixCP.java b/src/main/java/org/apache/sysml/udf/lib/DynamicProjectMatrixCP.java index 6d93e39..8e77a8e 100644 --- a/src/main/java/org/apache/sysml/udf/lib/DynamicProjectMatrixCP.java +++ b/src/main/java/org/apache/sysml/udf/lib/DynamicProjectMatrixCP.java @@ -51,7 +51,7 @@ public class DynamicProjectMatrixCP extends PackageFunction public void execute() { try - { + { Matrix mD = (Matrix) this.getFunctionInput(0); Matrix mC = (Matrix) 
this.getFunctionInput(1); MatrixBlock mbD = mD.getMatrixObject().acquireRead();
