[SYSTEMML-555] New transformencode builtin function over frames, tests Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/b1451c0c Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/b1451c0c Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/b1451c0c
Branch: refs/heads/master Commit: b1451c0cd2a7d80cdad799d4a7a9971d3f5c86af Parents: 17b5221 Author: Matthias Boehm <[email protected]> Authored: Thu Apr 7 13:30:02 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu Apr 7 13:30:38 2016 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/parser/DMLTranslator.java | 56 +++++- .../org/apache/sysml/parser/Expression.java | 2 +- .../ParameterizedBuiltinFunctionExpression.java | 75 +++++++- .../org/apache/sysml/parser/StatementBlock.java | 8 +- .../instructions/CPInstructionParser.java | 5 + .../runtime/instructions/cp/CPInstruction.java | 2 +- ...ReturnParameterizedBuiltinCPInstruction.java | 95 +++++++++++ .../functions/jmlc/FrameEncodeTest.java | 170 +++++++++++++++++++ src/test/scripts/functions/jmlc/transform7.dml | 31 ++++ 9 files changed, 430 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b1451c0c/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 2adcd77..87ad226 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -1177,8 +1177,12 @@ public class DMLTranslator else if ( source instanceof BuiltinFunctionExpression && ((BuiltinFunctionExpression)source).multipleReturns() ) { // construct input hops Hop fcall = processMultipleReturnBuiltinFunctionExpression((BuiltinFunctionExpression)source, mas.getTargetList(), ids); - output.add(fcall); - + output.add(fcall); + } + else if ( source instanceof ParameterizedBuiltinFunctionExpression && ((ParameterizedBuiltinFunctionExpression)source).multipleReturns() ) { + // construct input hops + Hop fcall = processMultipleReturnParameterizedBuiltinFunctionExpression((ParameterizedBuiltinFunctionExpression)source, mas.getTargetList(), ids); + output.add(fcall); } else throw new LanguageException("Class \"" + source.getClass() + "\" is not supported in Multiple Assignment statements"); @@ -1892,6 +1896,54 @@ public class DMLTranslator } /** + * + * @param source + * @param targetList + * @param hops + * @return + * @throws ParseException + */ + private Hop processMultipleReturnParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFunctionExpression source, ArrayList<DataIdentifier> targetList, + HashMap<String, Hop> hops) throws ParseException + { + FunctionType ftype = FunctionType.MULTIRETURN_BUILTIN; + String nameSpace = DMLProgram.INTERNAL_NAMESPACE; + + // Create an array list to hold the outputs of this lop. + // Exact list of outputs are added based on opcode. + ArrayList<Hop> outputs = new ArrayList<Hop>(); + + // Construct Hop for current builtin function expression based on its type + Hop currBuiltinOp = null; + switch (source.getOpCode()) { + case TRANSFORMENCODE: + ArrayList<Hop> inputs = new ArrayList<Hop>(); + inputs.add( processExpression(source.getVarParam("target"), null, hops) ); + inputs.add( processExpression(source.getVarParam("spec"), null, hops) ); + String[] outputNames = new String[targetList.size()]; + outputNames[0] = ((DataIdentifier)targetList.get(0)).getName(); + outputNames[1] = ((DataIdentifier)targetList.get(1)).getName(); + outputs.add(new DataOp(outputNames[0], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[0])); + outputs.add(new DataOp(outputNames[1], DataType.FRAME, ValueType.STRING, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[1])); + + currBuiltinOp = new FunctionOp(ftype, nameSpace, source.getOpCode().toString(), inputs, outputNames, outputs); + break; + + default: + throw new ParseException("Invaid Opcode in DMLTranslator:processMultipleReturnParameterizedBuiltinFunctionExpression(): " + source.getOpCode()); + } + + // set properties for created hops based on outputs of source expression + for ( int i=0; i < source.getOutputs().length; i++ ) { + setIdentifierParams( outputs.get(i), source.getOutputs()[i]); + outputs.get(i).setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); + } + currBuiltinOp.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); + + return currBuiltinOp; + } + + /** * Construct Hops from parse tree : Process ParameterizedBuiltinFunction Expression in an * assignment statement * http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b1451c0c/src/main/java/org/apache/sysml/parser/Expression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java index 8695cab..654e7f1 100644 --- a/src/main/java/org/apache/sysml/parser/Expression.java +++ b/src/main/java/org/apache/sysml/parser/Expression.java @@ -127,7 +127,7 @@ public abstract class Expression GROUPEDAGG, RMEMPTY, REPLACE, ORDER, // Distribution Functions CDF, INVCDF, PNORM, QNORM, PT, QT, PF, QF, PCHISQ, QCHISQ, PEXP, QEXP, - TRANSFORM, TRANSFORMAPPLY, TRANSFORMDECODE, + TRANSFORM, TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMENCODE, INVALID }; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b1451c0c/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java index adae9ad..e73a7c5 100644 --- a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java @@ -67,6 +67,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier opcodeMap.put("transform", Expression.ParameterizedBuiltinFunctionOp.TRANSFORM); opcodeMap.put("transformapply", Expression.ParameterizedBuiltinFunctionOp.TRANSFORMAPPLY); opcodeMap.put("transformdecode", Expression.ParameterizedBuiltinFunctionOp.TRANSFORMDECODE); + opcodeMap.put("transformencode", Expression.ParameterizedBuiltinFunctionOp.TRANSFORMENCODE); } public static HashMap<Expression.ParameterizedBuiltinFunctionOp, ParamBuiltinOp> pbHopMap; @@ -182,11 +183,8 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier // validate all input parameters for ( String s : getVarParams().keySet() ) { Expression paramExpr = getVarParam(s); - - if (paramExpr instanceof FunctionCallIdentifier){ - raiseValidateError("UDF function call not supported as parameter to built-in function call", false); - } - + if (paramExpr instanceof FunctionCallIdentifier) + raiseValidateError("UDF function call not supported as parameter to built-in function call", false); paramExpr.validateExpression(ids, constVars, conditional); } @@ -243,8 +241,42 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier break; default: //always unconditional (because unsupported operation) - raiseValidateError("Unsupported parameterized function "+ this.getOpCode(), false, LanguageErrorCodes.INVALID_PARAMETERS); + raiseValidateError("Unsupported parameterized function "+ getOpCode(), false, LanguageErrorCodes.INVALID_PARAMETERS); + } + return; + } + + @Override + public void validateExpression(MultiAssignmentStatement stmt, HashMap<String, DataIdentifier> ids, HashMap<String, ConstIdentifier> constVars, boolean conditional) + throws LanguageException + { + // validate all input parameters + for ( String s : getVarParams().keySet() ) { + Expression paramExpr = getVarParam(s); + if (paramExpr instanceof FunctionCallIdentifier) + raiseValidateError("UDF function call not supported as parameter to built-in function call", false); + paramExpr.validateExpression(ids, constVars, conditional); + } + + _outputs = new Identifier[stmt.getTargetList().size()]; + int count = 0; + for (DataIdentifier outParam: stmt.getTargetList()){ + DataIdentifier tmp = new DataIdentifier(outParam); + tmp.setAllPositions(this.getFilename(), this.getBeginLine(), this.getBeginColumn(), this.getEndLine(), this.getEndColumn()); + _outputs[count++] = tmp; + } + + switch (this.getOpCode()) { + case TRANSFORMENCODE: + DataIdentifier out1 = (DataIdentifier) getOutputs()[0]; + DataIdentifier out2 = (DataIdentifier) getOutputs()[1]; + + validateTransformEncode(out1, out2, conditional); + break; + default: //always unconditional (because unsupported operation) + raiseValidateError("Unsupported parameterized function "+ getOpCode(), false, LanguageErrorCodes.INVALID_PARAMETERS); } + return; } @@ -335,6 +367,30 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier output.setDimensions(-1, -1); } + /** + * + * @param output + * @param conditional + * @throws LanguageException + */ + private void validateTransformEncode(DataIdentifier output1, DataIdentifier output2, boolean conditional) + throws LanguageException + { + //validate data / metadata (recode maps) + checkDataType("transformencode", TF_FN_PARAM_DATA, DataType.FRAME, conditional); + + //validate specification + checkDataValueType("transformencode", TF_FN_PARAM_SPEC, DataType.SCALAR, ValueType.STRING, conditional); + + //set output dimensions + output1.setDataType(DataType.MATRIX); + output1.setValueType(ValueType.DOUBLE); + output1.setDimensions(-1, -1); + output2.setDataType(DataType.FRAME); + output2.setValueType(ValueType.STRING); + output2.setDimensions(-1, -1); + } + private void validateReplace(DataIdentifier output, boolean conditional) throws LanguageException { //check existence and correctness of arguments Expression target = getVarParam("target"); @@ -713,6 +769,11 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier @Override public boolean multipleReturns() { - return false; + switch(_opcode) { + case TRANSFORMENCODE: + return true; + default: + return false; + } } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b1451c0c/src/main/java/org/apache/sysml/parser/StatementBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/StatementBlock.java b/src/main/java/org/apache/sysml/parser/StatementBlock.java index ef38989..c756db0 100644 --- a/src/main/java/org/apache/sysml/parser/StatementBlock.java +++ b/src/main/java/org/apache/sysml/parser/StatementBlock.java @@ -189,7 +189,8 @@ public class StatementBlock extends LiveVariableAnalysis } else sourceExpr = ((MultiAssignmentStatement)stmt).getSource(); - if ( sourceExpr instanceof BuiltinFunctionExpression && ((BuiltinFunctionExpression)sourceExpr).multipleReturns() ) + if ( (sourceExpr instanceof BuiltinFunctionExpression && ((BuiltinFunctionExpression)sourceExpr).multipleReturns()) + || (sourceExpr instanceof ParameterizedBuiltinFunctionExpression && ((ParameterizedBuiltinFunctionExpression)sourceExpr).multipleReturns())) return false; //function calls (only mergable if inlined dml-bodied function) @@ -714,7 +715,8 @@ public class StatementBlock extends LiveVariableAnalysis FunctionCallIdentifier fci = (FunctionCallIdentifier)source; fci.validateExpression(dmlProg, ids.getVariables(), currConstVars, conditional); } - else if ( source instanceof BuiltinFunctionExpression && ((DataIdentifier)source).multipleReturns()) { + else if ( (source instanceof BuiltinFunctionExpression || source instanceof ParameterizedBuiltinFunctionExpression) + && ((DataIdentifier)source).multipleReturns()) { source.validateExpression(mas, ids.getVariables(), currConstVars, conditional); } else @@ -744,7 +746,7 @@ public class StatementBlock extends LiveVariableAnalysis ids.addVariable(target.getName(), target); } } - else if ( source instanceof BuiltinFunctionExpression ) { + else if ( source instanceof BuiltinFunctionExpression || source instanceof ParameterizedBuiltinFunctionExpression ) { Identifier[] outputs = source.getOutputs(); for (int j=0; j < targetList.size(); j++) { ids.addVariable(targetList.get(j).getName(), (DataIdentifier)outputs[j]); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b1451c0c/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index a84b912..f3ac621 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -46,6 +46,7 @@ import org.apache.sysml.runtime.instructions.cp.MMChainCPInstruction; import org.apache.sysml.runtime.instructions.cp.MMTSJCPInstruction; import org.apache.sysml.runtime.instructions.cp.MatrixReshapeCPInstruction; import org.apache.sysml.runtime.instructions.cp.MultiReturnBuiltinCPInstruction; +import org.apache.sysml.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction; import org.apache.sysml.runtime.instructions.cp.PMMJCPInstruction; import org.apache.sysml.runtime.instructions.cp.ParameterizedBuiltinCPInstruction; import org.apache.sysml.runtime.instructions.cp.QuantilePickCPInstruction; @@ -182,6 +183,7 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "transform" , CPINSTRUCTION_TYPE.ParameterizedBuiltin); String2CPInstructionType.put( "transformapply",CPINSTRUCTION_TYPE.ParameterizedBuiltin); String2CPInstructionType.put( "transformdecode",CPINSTRUCTION_TYPE.ParameterizedBuiltin); + String2CPInstructionType.put( "transformencode",CPINSTRUCTION_TYPE.MultiReturnParameterizedBuiltin); // Variable Instruction Opcodes String2CPInstructionType.put( "assignvar" , CPINSTRUCTION_TYPE.Variable); @@ -347,6 +349,9 @@ public class CPInstructionParser extends InstructionParser else //exectype CP_FILE return ParameterizedBuiltinCPFileInstruction.parseInstruction(str); + case MultiReturnParameterizedBuiltin: + return MultiReturnParameterizedBuiltinCPInstruction.parseInstruction(str); + case MultiReturnBuiltin: return MultiReturnBuiltinCPInstruction.parseInstruction(str); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b1451c0c/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java index 4f3f530..4dbd9a3 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java @@ -30,7 +30,7 @@ import org.apache.sysml.runtime.matrix.operators.Operator; public abstract class CPInstruction extends Instruction { - public enum CPINSTRUCTION_TYPE { INVALID, AggregateUnary, AggregateBinary, AggregateTernary, ArithmeticBinary, Ternary, Quaternary, BooleanBinary, BooleanUnary, BuiltinBinary, BuiltinUnary, ParameterizedBuiltin, MultiReturnBuiltin, Builtin, Reorg, RelationalBinary, File, Variable, External, Append, Rand, QSort, QPick, MatrixIndexing, MMTSJ, PMMJ, MMChain, MatrixReshape, Partition, StringInit, CentralMoment, Covariance, UaggOuterChain }; + public enum CPINSTRUCTION_TYPE { INVALID, AggregateUnary, AggregateBinary, AggregateTernary, ArithmeticBinary, Ternary, Quaternary, BooleanBinary, BooleanUnary, BuiltinBinary, BuiltinUnary, MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, Builtin, Reorg, RelationalBinary, File, Variable, External, Append, Rand, QSort, QPick, MatrixIndexing, MMTSJ, PMMJ, MMChain, MatrixReshape, Partition, StringInit, CentralMoment, Covariance, UaggOuterChain }; protected CPINSTRUCTION_TYPE _cptype; protected Operator _optr; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b1451c0c/src/main/java/org/apache/sysml/runtime/instructions/cp/MultiReturnParameterizedBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/MultiReturnParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/MultiReturnParameterizedBuiltinCPInstruction.java new file mode 100644 index 0000000..aff9e74 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/MultiReturnParameterizedBuiltinCPInstruction.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.instructions.cp; + +import java.util.ArrayList; + +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.matrix.data.FrameBlock; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.operators.Operator; +import org.apache.sysml.runtime.transform.encode.Encoder; +import org.apache.sysml.runtime.transform.encode.EncoderFactory; + + +public class MultiReturnParameterizedBuiltinCPInstruction extends ComputationCPInstruction +{ + protected ArrayList<CPOperand> _outputs; + + public MultiReturnParameterizedBuiltinCPInstruction(Operator op, CPOperand input1, CPOperand input2, ArrayList<CPOperand> outputs, String opcode, String istr ) { + super(op, input1, input2, outputs.get(0), opcode, istr); + _cptype = CPINSTRUCTION_TYPE.MultiReturnBuiltin; + _outputs = outputs; + } + + public CPOperand getOutput(int i) { + return _outputs.get(i); + } + + /** + * + * @param str + * @return + * @throws DMLRuntimeException + */ + public static MultiReturnParameterizedBuiltinCPInstruction parseInstruction ( String str ) + throws DMLRuntimeException + { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + ArrayList<CPOperand> outputs = new ArrayList<CPOperand>(); + String opcode = parts[0]; + + if ( opcode.equalsIgnoreCase("transformencode") ) { + // one input and two outputs + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + outputs.add ( new CPOperand(parts[3], ValueType.DOUBLE, DataType.MATRIX) ); + outputs.add ( new CPOperand(parts[4], ValueType.STRING, DataType.FRAME) ); + return new MultiReturnParameterizedBuiltinCPInstruction(null, in1, in2, outputs, opcode, str); + } + else { + throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + opcode); + } + + } + + @Override + public void processInstruction(ExecutionContext ec) + throws DMLRuntimeException + { + //obtain and pin input frame + FrameBlock fin = ec.getFrameInput(input1.getName()); + String spec = ec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getStringValue(); + + //execute block transform encode + Encoder encoder = EncoderFactory.createEncoder(spec, fin.getNumColumns()); + MatrixBlock data = encoder.encode(fin, new MatrixBlock(fin.getNumRows(), fin.getNumColumns(), false)); //build and apply + FrameBlock meta = encoder.getMetaData(new FrameBlock(fin.getNumColumns(), ValueType.STRING)); + + //release input and outputs + ec.releaseFrameInput(input1.getName()); + ec.setMatrixOutput(getOutput(0).getName(), data); + ec.setFrameOutput(getOutput(1).getName(), meta); + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b1451c0c/src/test/java/org/apache/sysml/test/integration/functions/jmlc/FrameEncodeTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/jmlc/FrameEncodeTest.java b/src/test/java/org/apache/sysml/test/integration/functions/jmlc/FrameEncodeTest.java new file mode 100644 index 0000000..aed90f3 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/jmlc/FrameEncodeTest.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.jmlc; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.api.DMLException; +import org.apache.sysml.api.jmlc.Connection; +import org.apache.sysml.api.jmlc.PreparedScript; +import org.apache.sysml.api.jmlc.ResultVariables; +import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + +/** + * + * + */ +public class FrameEncodeTest extends AutomatedTestBase +{ + private final static String TEST_NAME1 = "transform7"; + private final static String TEST_DIR = "functions/jmlc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + FrameEncodeTest.class.getSimpleName() + "/"; + + private final static int rows = 700; + private final static int cols = 3; + + private final static int nRuns = 2; + + private final static double sparsity1 = 0.7; + private final static double sparsity2 = 0.1; + + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "F2" }) ); + } + + @Test + public void testJMLCTransformDense() throws IOException { + runJMLCReuseTest(TEST_NAME1, false, false); + } + + @Test + public void testJMLCTransformSparse() throws IOException { + runJMLCReuseTest(TEST_NAME1, true, false); + } + + @Test + public void testJMLCTransformDenseReuse() throws IOException { + runJMLCReuseTest(TEST_NAME1, false, true); + } + + @Test + public void testJMLCTransformSparseReuse() throws IOException { + runJMLCReuseTest(TEST_NAME1, true, true); + } + + /** + * + * @param sparseM1 + * @param sparseM2 + * @param instType + * @throws IOException + */ + private void runJMLCReuseTest( String testname, boolean sparse, boolean modelReuse ) + throws IOException + { + String TEST_NAME = testname; + + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + //generate inputs + double[][] Fd = TestUtils.round(getRandomMatrix(rows, cols, 0.51, 7.49, sparse?sparsity2:sparsity1, 1234)); + String[][] F1s = FrameTransformTest.createFrameData(Fd); + + //run DML via JMLC + ArrayList<String[][]> F2set = execDMLScriptviaJMLC( TEST_NAME, F1s, modelReuse ); + + //check correct result + for( String[][] data : F2set ) + for( int i=0; i<F1s.length; i++ ) + for( int j=0; j<F1s[i].length; j++ ) + Assert.assertEquals("Wrong result: "+data[i][j]+".", data[i][j], F1s[i][j]); + } + + /** + * + * @param X + * @return + * @throws DMLException + * @throws IOException + */ + private ArrayList<String[][]> execDMLScriptviaJMLC( String testname, String[][] F1, boolean modelReuse) + throws IOException + { + Timing time = new Timing(true); + + ArrayList<String[][]> ret = new ArrayList<String[][]>(); + + //establish connection to SystemML + Connection conn = new Connection(); + + try + { + //prepare input arguments + HashMap<String,String> args = new HashMap<String,String>(); + args.put("$TRANSFORM_SPEC", "{ \"ids\": true ,\"recode\": [ 1, 2, 3] }"); + + //read and precompile script + String script = conn.readScript(SCRIPT_DIR + TEST_DIR + testname + ".dml"); + PreparedScript pstmt = conn.prepareScript(script, args, new String[]{"F1","M"}, new String[]{"F2"}, false); + + if( modelReuse ) + pstmt.setFrame("F1", F1, true); + + //execute script multiple times + for( int i=0; i<nRuns; i++ ) + { + //bind input parameters + if( !modelReuse ) + pstmt.setFrame("F1", F1); + + //execute script + ResultVariables rs = pstmt.executeScript(); + + //get output parameter + String[][] Y = rs.getFrame("F2"); + ret.add(Y); //keep result for comparison + } + } + catch(Exception ex) + { + ex.printStackTrace(); + throw new IOException(ex); + } + finally + { + if( conn != null ) + conn.close(); + } + + System.out.println("JMLC scoring w/ "+nRuns+" runs in "+time.stop()+"ms."); + + return ret; + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b1451c0c/src/test/scripts/functions/jmlc/transform7.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/jmlc/transform7.dml b/src/test/scripts/functions/jmlc/transform7.dml new file mode 100644 index 0000000..a3dfca7 --- /dev/null +++ b/src/test/scripts/functions/jmlc/transform7.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +F1 = read($F1, data_type="frame", format="csv"); +specJson = $TRANSFORM_SPEC + +[X, M] = transformencode(target=F1, spec=specJson); + +X = X * (X!=77.7); + +F2 = transformdecode(target=X, meta=M, spec=specJson); +write(F2, $F2); +
