Repository: incubator-systemml Updated Branches: refs/heads/master bcd5d96fc -> 98f509944
[SYSTEMML-535] Add optional printf formatting to print function Overload print function to allow Java-based printf formatting. This requires allowing a variable number of operands. Create MultipleOp Hop, MultipleCP Lop, BuiltinMultipleCPInstruction, and ScalarBuiltinMultipleCPInstruction classes and add related code. Closes #317. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/98f50994 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/98f50994 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/98f50994 Branch: refs/heads/master Commit: 98f509944c93ef5418af8b6dd25d8f59d65506fa Parents: bcd5d96 Author: Deron Eriksson <[email protected]> Authored: Fri Jan 6 13:51:47 2017 -0800 Committer: Deron Eriksson <[email protected]> Committed: Fri Jan 6 13:51:47 2017 -0800 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/Hop.java | 19 +- .../java/org/apache/sysml/hops/MultipleOp.java | 193 +++++++++++++++++++ src/main/java/org/apache/sysml/lops/Lop.java | 4 +- .../org/apache/sysml/lops/LopProperties.java | 14 +- .../java/org/apache/sysml/lops/MultipleCP.java | 128 ++++++++++++ .../java/org/apache/sysml/lops/compile/Dag.java | 15 +- .../org/apache/sysml/parser/DMLTranslator.java | 68 +++++-- .../sysml/parser/ParForStatementBlock.java | 6 +- .../org/apache/sysml/parser/PrintStatement.java | 104 ++++++---- .../org/apache/sysml/parser/StatementBlock.java | 20 +- .../parser/common/CommonSyntacticValidator.java | 67 +++++-- .../sysml/runtime/functionobjects/Builtin.java | 16 +- .../instructions/CPInstructionParser.java | 9 +- .../cp/BuiltinMultipleCPInstruction.java | 79 ++++++++ .../runtime/instructions/cp/CPInstruction.java | 2 +- .../runtime/instructions/cp/CPOperand.java | 13 ++ .../cp/ScalarBuiltinMultipleCPInstruction.java | 101 ++++++++++ .../integration/mlcontext/MLContextTest.java | 171 
++++++++++++++++ 18 files changed, 934 insertions(+), 95 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 4202fe3..3aa3dab 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -34,6 +34,7 @@ import org.apache.sysml.lops.Data; import org.apache.sysml.lops.Lop; import org.apache.sysml.lops.LopProperties.ExecType; import org.apache.sysml.lops.LopsException; +import org.apache.sysml.lops.MultipleCP; import org.apache.sysml.lops.ReBlock; import org.apache.sysml.lops.UnaryCP; import org.apache.sysml.parser.Expression.DataType; @@ -1024,6 +1025,10 @@ public abstract class Hop WUMM //weighted unary mm }; + // Operations that require a variable number of operands + public enum MultipleOperandOperation { + PRINTF + } public enum AggOp { SUM, SUM_SQ, MIN, MAX, TRACE, PROD, MEAN, VAR, MAXINDEX, MININDEX @@ -1259,6 +1264,18 @@ public abstract class Hop HopsOpOp1LopsUS.put(OpOp1.STOP, org.apache.sysml.lops.UnaryCP.OperationTypes.STOP); } + /** + * Maps from a multiple (variable number of operands) Hop operation type to + * the corresponding Lop operation type. This is called in the MultipleOp + * constructLops() method that is used to construct the Lops that correspond + * to a Hop. 
+ */ + protected static final HashMap<MultipleOperandOperation, MultipleCP.OperationType> MultipleOperandOperationHopTypeToLopType; + static { + MultipleOperandOperationHopTypeToLopType = new HashMap<MultipleOperandOperation, MultipleCP.OperationType>(); + MultipleOperandOperationHopTypeToLopType.put(MultipleOperandOperation.PRINTF, MultipleCP.OperationType.PRINTF); + } + protected static final HashMap<Hop.OpOp1, String> HopsOpOp12String; static { HopsOpOp12String = new HashMap<OpOp1, String>(); @@ -1287,7 +1304,7 @@ public abstract class Hop HopsOpOp12String.put(OpOp1.SPROP, "sprop"); HopsOpOp12String.put(OpOp1.SIGMOID, "sigmoid"); } - + protected static final HashMap<Hop.ParamBuiltinOp, org.apache.sysml.lops.ParameterizedBuiltin.OperationTypes> HopsParameterizedBuiltinLops; static { HopsParameterizedBuiltinLops = new HashMap<Hop.ParamBuiltinOp, org.apache.sysml.lops.ParameterizedBuiltin.OperationTypes>(); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/hops/MultipleOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/MultipleOp.java b/src/main/java/org/apache/sysml/hops/MultipleOp.java new file mode 100644 index 0000000..d850b8a --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/MultipleOp.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops; + +import java.util.ArrayList; + +import org.apache.sysml.lops.Lop; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.lops.LopsException; +import org.apache.sysml.lops.MultipleCP; +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.parser.Expression.ValueType; + +/** + * The MultipleOp Hop allows for a variable number of operands. Functionality + * such as 'printf' (overloaded into the existing print function) is an example + * of an operation that potentially takes a variable number of operands. + * + */ +public class MultipleOp extends Hop { + protected MultipleOperandOperation multipleOperandOperation = null; + + protected MultipleOp() { + } + + /** + * MultipleOp constructor. + * + * @param name + * the target name, typically set by the DMLTranslator when + * constructing Hops. (For example, 'parsertemp1'.) + * @param dataType + * the target data type (SCALAR for printf) + * @param valueType + * the target value type (STRING for printf) + * @param multipleOperandOperation + * the operation type (such as PRINTF) + * @param inputs + * a variable number of input Hops + * @throws HopsException + * thrown if a HopsException occurs + */ + public MultipleOp(String name, DataType dataType, ValueType valueType, + MultipleOperandOperation multipleOperandOperation, Hop... 
inputs) throws HopsException { + super(name, dataType, valueType); + this.multipleOperandOperation = multipleOperandOperation; + + for (int i = 0; i < inputs.length; i++) { + getInput().add(i, inputs[i]); + inputs[i].getParent().add(this); + } + + // compute unknown dims and nnz + refreshSizeInformation(); + } + + public MultipleOperandOperation getOp() { + return multipleOperandOperation; + } + + public void printMe() throws HopsException { + if (LOG.isDebugEnabled()) { + if (getVisited() != VisitStatus.DONE) { + super.printMe(); + LOG.debug(" Operation: " + multipleOperandOperation); + for (Hop h : getInput()) { + h.printMe(); + } + } + setVisited(VisitStatus.DONE); + } + } + + @Override + public String getOpString() { + return "m(" + multipleOperandOperation.toString().toLowerCase() + ")"; + } + + /** + * Construct the corresponding Lops for this Hop + */ + @Override + public Lop constructLops() throws HopsException, LopsException { + // reuse existing lop + if (getLops() != null) + return getLops(); + + try { + ArrayList<Hop> inHops = getInput(); + Lop[] inLops = new Lop[inHops.size()]; + for (int i = 0; i < inHops.size(); i++) { + Hop inHop = inHops.get(i); + Lop inLop = inHop.constructLops(); + inLops[i] = inLop; + } + + MultipleCP.OperationType opType = MultipleOperandOperationHopTypeToLopType.get(multipleOperandOperation); + if (opType == null) { + throw new HopsException("Unknown MultipleCP Lop operation type for MultipleOperandOperation Hop type '" + + multipleOperandOperation + "'"); + } + + MultipleCP multipleCPLop = new MultipleCP(opType, getDataType(), getValueType(), inLops); + setOutputDimensions(multipleCPLop); + setLineNumbers(multipleCPLop); + setLops(multipleCPLop); + } catch (Exception e) { + throw new HopsException(this.printErrorLocation() + "error constructing Lops for MultipleOp Hop -- \n ", e); + } + + // add reblock/checkpoint lops if necessary + constructAndSetLopsDataFlowProperties(); + + return getLops(); + } + + @Override + protected 
double computeOutputMemEstimate(long dim1, long dim2, long nnz) { + double sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); + return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); + } + + @Override + public boolean allowsAllExecTypes() { + return false; // true? + } + + @Override + protected ExecType optFindExecType() throws HopsException { + checkAndSetForcedPlatform(); // ? + return ExecType.CP; + } + + @Override + public void refreshSizeInformation() { + // do nothing + } + + @Override + public Object clone() throws CloneNotSupportedException { + MultipleOp multipleOp = new MultipleOp(); + + // copy generic attributes + multipleOp.clone(this, false); + + // copy specific attributes + multipleOp.multipleOperandOperation = multipleOperandOperation; + + return multipleOp; + } + + @Override + public boolean compare(Hop that) { + if (!(that instanceof MultipleOp)) + return false; + + if (multipleOperandOperation == MultipleOperandOperation.PRINTF) { + return false; + } + + // if add new multiple operand types in addition to PRINTF, + // probably need to modify this. 
+ MultipleOp mo = (MultipleOp) that; + return (multipleOperandOperation == mo.multipleOperandOperation); + } + + @Override + protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) { + return 0; + } + + @Override + protected long[] inferOutputCharacteristics(MemoTable memo) { + return null; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/lops/Lop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Lop.java b/src/main/java/org/apache/sysml/lops/Lop.java index a5ed9e9..ad25970 100644 --- a/src/main/java/org/apache/sysml/lops/Lop.java +++ b/src/main/java/org/apache/sysml/lops/Lop.java @@ -59,6 +59,8 @@ public abstract class Lop SortKeys, PickValues, Checkpoint, //Spark persist into storage level PlusMult, MinusMult, //CP + /** CP operation on a variable number of operands */ + MULTIPLE_CP }; /** @@ -353,7 +355,7 @@ public abstract class Lop } /** - * Method to get the execution type (CP or MR) of LOP + * Method to get the execution type (CP, CP_FILE, MR, SPARK, GPU, INVALID) of LOP * * @return execution type */ http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/lops/LopProperties.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/LopProperties.java b/src/main/java/org/apache/sysml/lops/LopProperties.java index 48e8f57..5720487 100644 --- a/src/main/java/org/apache/sysml/lops/LopProperties.java +++ b/src/main/java/org/apache/sysml/lops/LopProperties.java @@ -26,7 +26,19 @@ import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; public class LopProperties { - + /** + * Execution types: + * + * <ul> + * <li>CP - Control Program (single JVM)</li> + * <li>CP_FILE - ?</li> + * <li>MR - Apache Hadoop</li> + * <li>SPARK - Apache Spark</li> + * <li>GPU - Execute on a GPU</li> + * 
<li>INVALID - Invalid execution type</li> + * </ul> + * + */ public enum ExecType { CP, CP_FILE, MR, SPARK, GPU, INVALID }; public enum ExecLocation {INVALID, RecordReader, Map, MapOrReduce, MapAndReduce, Reduce, Data, ControlProgram }; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/lops/MultipleCP.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/MultipleCP.java b/src/main/java/org/apache/sysml/lops/MultipleCP.java new file mode 100644 index 0000000..74a6070 --- /dev/null +++ b/src/main/java/org/apache/sysml/lops/MultipleCP.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.lops; + +import org.apache.sysml.lops.LopProperties.ExecLocation; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.lops.compile.JobType; +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.parser.Expression.ValueType; + +/** + * Lop to perform an operation on a variable number of operands. 
+ * + */ +public class MultipleCP extends Lop { + + public enum OperationType { + PRINTF + }; + + OperationType operationType; + + public MultipleCP(OperationType operationType, DataType dt, ValueType vt, Lop... inputLops) { + super(Lop.Type.MULTIPLE_CP, dt, vt); + this.operationType = operationType; + for (Lop inputLop : inputLops) { + addInput(inputLop); + inputLop.addOutput(this); + } + + boolean breaksAlignment = false; // ? + boolean aligner = false; // ? + boolean definesMRJob = false; // ? + lps.addCompatibility(JobType.INVALID); // ? + this.lps.setProperties(inputs, ExecType.CP, ExecLocation.ControlProgram, breaksAlignment, aligner, + definesMRJob); // ? + } + + @Override + public String toString() { + return "Operation Type: " + operationType; + } + + public OperationType getOperationType() { + return operationType; + } + + /** + * Generate the complete instruction string for this Lop. This instruction + * string can have a variable number of input operands. It displays the + * following: + * + * <ul> + * <li>Execution type (CP, SPARK, etc.) + * <li>Operand delimiter (°)</li> + * <li>Opcode (printf, etc.)</li> + * <li>Operand delimiter (°)</li> + * <li>Variable number of inputs, each followed by an operand delimiter + * (°)</li> + * <ul> + * <li>Input consists of (label · data type · value type + * · is literal)</li> + * </ul> + * <li>Output consisting of (label · data type · value + * type)</li> + * </ul> + * + * Example: <br> + * The following DML<br> + * <code>print('hello %s', 'world')</code><br> + * generates the instruction string:<br> + * <code>CP°printf°hello %s·SCALAR·STRING·true°world·SCALAR·STRING·true°_Var1·SCALAR·STRING</code><br> + * + * Note: This generated instruction string is parsed in the + * parseInstruction() method of BuiltinMultipleCPInstruction, which parses + * the instruction string to generate an instruction object that is a + * subclass of BuiltinMultipleCPInstruction. 
+ */ + @Override + public String getInstructions(String output) throws LopsException { + String opString = getOpcode(); + + StringBuilder sb = new StringBuilder(); + + sb.append(getExecType()); + sb.append(Lop.OPERAND_DELIMITOR); + + sb.append(opString); + sb.append(OPERAND_DELIMITOR); + + for (Lop input : inputs) { + sb.append(input.prepScalarInputOperand(getExecType())); + sb.append(OPERAND_DELIMITOR); + } + + sb.append(prepOutputOperand(output)); + + return sb.toString(); + } + + private String getOpcode() throws LopsException { + switch (operationType) { + case PRINTF: + return OperationType.PRINTF.toString().toLowerCase(); + default: + throw new UnsupportedOperationException( + "MultipleCP operation type (" + operationType + ") is not defined."); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/lops/compile/Dag.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/compile/Dag.java b/src/main/java/org/apache/sysml/lops/compile/Dag.java index 8a2fa1c..8f17b2c 100644 --- a/src/main/java/org/apache/sysml/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysml/lops/compile/Dag.java @@ -43,9 +43,6 @@ import org.apache.sysml.lops.AppendM; import org.apache.sysml.lops.BinaryM; import org.apache.sysml.lops.CombineBinary; import org.apache.sysml.lops.Data; -import org.apache.sysml.lops.PMMJ; -import org.apache.sysml.lops.ParameterizedBuiltin; -import org.apache.sysml.lops.SortKeys; import org.apache.sysml.lops.Data.OperationTypes; import org.apache.sysml.lops.FunctionCallCP; import org.apache.sysml.lops.Lop; @@ -56,12 +53,15 @@ import org.apache.sysml.lops.LopsException; import org.apache.sysml.lops.MapMult; import org.apache.sysml.lops.OutputParameters; import org.apache.sysml.lops.OutputParameters.Format; +import org.apache.sysml.lops.PMMJ; +import org.apache.sysml.lops.ParameterizedBuiltin; import 
org.apache.sysml.lops.PickByCount; +import org.apache.sysml.lops.SortKeys; import org.apache.sysml.lops.Unary; import org.apache.sysml.parser.DataExpression; import org.apache.sysml.parser.Expression; -import org.apache.sysml.parser.ParameterizedBuiltinFunctionExpression; import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.parser.ParameterizedBuiltinFunctionExpression; import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; @@ -70,11 +70,11 @@ import org.apache.sysml.runtime.instructions.CPInstructionParser; import org.apache.sysml.runtime.instructions.Instruction; import org.apache.sysml.runtime.instructions.Instruction.INSTRUCTION_TYPE; import org.apache.sysml.runtime.instructions.InstructionParser; +import org.apache.sysml.runtime.instructions.MRJobInstruction; import org.apache.sysml.runtime.instructions.SPInstructionParser; import org.apache.sysml.runtime.instructions.cp.CPInstruction; -import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE; -import org.apache.sysml.runtime.instructions.MRJobInstruction; +import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.InputInfo; import org.apache.sysml.runtime.matrix.data.OutputInfo; @@ -1427,6 +1427,9 @@ public class Dag<N extends Lop> inst_string = node.getInstructions(inputs, outputs); } + else if (node.getType() == Lop.Type.MULTIPLE_CP) { // ie, MultipleCP class + inst_string = node.getInstructions(node.getOutputParameters().getLabel()); + } else { if ( node.getInputs().isEmpty() ) { // currently, such a case exists only for Rand lop http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/parser/DMLTranslator.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 545fe12..1d92e25 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; +import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -30,6 +31,7 @@ import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; +import org.apache.sysml.hops.ConvolutionOp; import org.apache.sysml.hops.DataGenOp; import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.FunctionOp; @@ -48,25 +50,25 @@ import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LeftIndexingOp; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.MemoTable; +import org.apache.sysml.hops.MultipleOp; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.ReorgOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.ipa.InterProceduralAnalysis; -import org.apache.sysml.hops.rewrite.ProgramRewriter; import org.apache.sysml.hops.recompile.Recompiler; +import org.apache.sysml.hops.rewrite.HopRewriteUtils; +import org.apache.sysml.hops.rewrite.ProgramRewriter; import org.apache.sysml.lops.Lop; import org.apache.sysml.lops.LopsException; +import org.apache.sysml.parser.Expression.BuiltinFunctionOp; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.FormatType; import org.apache.sysml.parser.Expression.ParameterizedBuiltinFunctionOp; import 
org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.parser.PrintStatement.PRINTTYPE; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.hops.ConvolutionOp; -import org.apache.sysml.hops.rewrite.HopRewriteUtils; -import org.apache.sysml.parser.Expression.BuiltinFunctionOp; public class DMLTranslator @@ -1014,27 +1016,57 @@ public class DMLTranslator } if (current instanceof PrintStatement) { - PrintStatement ps = (PrintStatement) current; - Expression source = ps.getExpression(); - PRINTTYPE ptype = ps.getType(); - DataIdentifier target = createTarget(); target.setDataType(DataType.SCALAR); target.setValueType(ValueType.STRING); - target.setAllPositions(current.getFilename(), current.getBeginLine(), target.getBeginColumn(), current.getEndLine(), current.getEndColumn()); - - Hop ae = processExpression(source, target, ids); - + target.setAllPositions(current.getFilename(), current.getBeginLine(), target.getBeginColumn(), + current.getEndLine(), current.getEndColumn()); + + PrintStatement ps = (PrintStatement) current; + PRINTTYPE ptype = ps.getType(); + try { - Hop.OpOp1 op = (ptype == PRINTTYPE.PRINT ? 
Hop.OpOp1.PRINT : Hop.OpOp1.STOP); - Hop printHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae); - printHop.setAllPositions(current.getBeginLine(), current.getBeginColumn(), current.getEndLine(), current.getEndColumn()); - output.add(printHop); - } catch ( HopsException e ) { + if (ptype == PRINTTYPE.PRINT) { + Hop.OpOp1 op = Hop.OpOp1.PRINT; + Expression source = ps.getExpressions().get(0); + Hop ae = processExpression(source, target, ids); + Hop printHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, + ae); + printHop.setAllPositions(current.getBeginLine(), current.getBeginColumn(), current.getEndLine(), + current.getEndColumn()); + output.add(printHop); + } else if (ptype == PRINTTYPE.STOP) { + Hop.OpOp1 op = Hop.OpOp1.STOP; + Expression source = ps.getExpressions().get(0); + Hop ae = processExpression(source, target, ids); + Hop stopHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, + ae); + stopHop.setAllPositions(current.getBeginLine(), current.getBeginColumn(), current.getEndLine(), + current.getEndColumn()); + output.add(stopHop); + } else if (ptype == PRINTTYPE.PRINTF) { + Hop.MultipleOperandOperation printfOperation = Hop.MultipleOperandOperation.PRINTF; + List<Expression> expressions = ps.getExpressions(); + Hop[] inHops = new Hop[expressions.size()]; + // process the expressions (function parameters) that + // make up the printf-styled print statement + // into Hops so that these can be passed to the printf + // Hop (ie, MultipleOp) as input Hops + for (int j = 0; j < expressions.size(); j++) { + Hop inHop = processExpression(expressions.get(j), target, ids); + inHops[j] = inHop; + } + target.setValueType(ValueType.STRING); + Hop printfHop = new MultipleOp(target.getName(), target.getDataType(), target.getValueType(), + printfOperation, inHops); + output.add(printfHop); + } + + } catch (HopsException e) { throw new LanguageException(e); } } - + if 
(current instanceof AssignmentStatement) { AssignmentStatement as = (AssignmentStatement) current; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java b/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java index 627f1d4..bd92a15 100644 --- a/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java +++ b/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java @@ -876,7 +876,11 @@ public class ParForStatementBlock extends ForStatementBlock else if (s instanceof PrintStatement) { PrintStatement s2 = (PrintStatement)s; - ret = rGetDataIdentifiers(s2.getExpression()); + ret = new ArrayList<DataIdentifier>(); + for (Expression expression : s2.getExpressions()) { + List<DataIdentifier> dataIdentifiers = rGetDataIdentifiers(expression); + ret.addAll(dataIdentifiers); + } } //potentially extend this list with other Statements if required http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/parser/PrintStatement.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/PrintStatement.java b/src/main/java/org/apache/sysml/parser/PrintStatement.java index 019d02a..463b02d 100644 --- a/src/main/java/org/apache/sysml/parser/PrintStatement.java +++ b/src/main/java/org/apache/sysml/parser/PrintStatement.java @@ -19,20 +19,35 @@ package org.apache.sysml.parser; +import java.util.ArrayList; +import java.util.List; + import org.apache.sysml.api.DMLScript; import org.apache.sysml.debug.DMLBreakpointManager; public class PrintStatement extends Statement { - public enum PRINTTYPE {PRINT, STOP}; - - protected PRINTTYPE _type; // print or stop - protected Expression _expr; + /** + * The PRINTTYPE options are: PRINT, 
PRINTF, and STOP. + * <p> + * Note that PRINTF functionality is overloaded onto the existing 'print' + * built-in function. + */ + public enum PRINTTYPE { + PRINT, PRINTF, STOP + }; + + protected PRINTTYPE _type; // print, printf, or stop + protected List<Expression> expressions; - private static PRINTTYPE getPrintType(String type) throws LanguageException { + private static PRINTTYPE getPrintType(String type, List<Expression> expressions) throws LanguageException { if(type.equalsIgnoreCase("print")) { - return PRINTTYPE.PRINT; + if (expressions.size() == 1) { + return PRINTTYPE.PRINT; + } else { + return PRINTTYPE.PRINTF; + } } else if (type.equalsIgnoreCase("stop")) { return PRINTTYPE.STOP; @@ -40,26 +55,30 @@ public class PrintStatement extends Statement else throw new LanguageException("Unknown statement type: " + type); } - - public PrintStatement(String type, Expression expr, int beginLine, int beginCol, int endLine, int endCol) - throws LanguageException - { - this(getPrintType(type), expr); - + + public PrintStatement(String type, List<Expression> expressions, int beginLine, int beginCol, + int endLine, int endCol) throws LanguageException { + this(getPrintType(type, expressions), expressions); + setBeginLine(beginLine); setBeginColumn(beginCol); setEndLine(endLine); setEndColumn(endCol); } - - public PrintStatement(PRINTTYPE type, Expression expr) throws LanguageException{ + + public PrintStatement(PRINTTYPE type, List<Expression> expressions) + throws LanguageException { _type = type; - _expr = expr; + this.expressions = expressions; } - + public Statement rewriteStatement(String prefix) throws LanguageException{ - Expression newExpr = _expr.rewriteExpression(prefix); - PrintStatement retVal = new PrintStatement(_type, newExpr); + List<Expression> newExpressions = new ArrayList<Expression>(); + for (Expression oldExpression : expressions) { + Expression newExpression = oldExpression.rewriteExpression(prefix); + newExpressions.add(newExpression); + } + 
PrintStatement retVal = new PrintStatement(_type, newExpressions); retVal.setBeginLine(this.getBeginLine()); retVal.setBeginColumn(this.getBeginColumn()); retVal.setEndLine(this.getEndLine()); @@ -74,22 +93,34 @@ public class PrintStatement extends Statement return lo; } - - public String toString(){ - StringBuilder sb = new StringBuilder(); - sb.append(_type + " (" ); - if (_expr != null){ - sb.append(_expr.toString()); - } - sb.append(");"); - return sb.toString(); + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(_type + " ("); + if ((_type == PRINTTYPE.PRINT) || (_type == PRINTTYPE.STOP)) { + sb.append(expressions.get(0).toString()); + } else if (_type == PRINTTYPE.PRINTF) { + for (int i = 0; i < expressions.size(); i++) { + if (i > 0) { + sb.append(", "); + } + Expression expression = expressions.get(i); + sb.append(expression.toString()); + } + } + + sb.append(");"); + return sb.toString(); } @Override public VariableSet variablesRead() { - VariableSet result = _expr.variablesRead(); - return result; + VariableSet variableSet = new VariableSet(); + for (Expression expression : expressions) { + VariableSet variablesRead = expression.variablesRead(); + variableSet.addVariables(variablesRead); + } + return variableSet; } @Override @@ -101,7 +132,7 @@ public class PrintStatement extends Statement public boolean controlStatement() { // ensure that breakpoints end up in own statement block if (DMLScript.ENABLE_DEBUG_MODE) { - DMLBreakpointManager.insertBreakpoint(_expr.getBeginLine()); + DMLBreakpointManager.insertBreakpoint(expressions.get(0).getBeginLine()); return true; } @@ -112,12 +143,15 @@ public class PrintStatement extends Statement return false; } - public Expression getExpression(){ - return _expr; - } - public PRINTTYPE getType() { return _type; } - + + public List<Expression> getExpressions() { + return expressions; + } + + public void setExpressions(List<Expression> expressions) { + this.expressions = expressions; + } 
} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/parser/StatementBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/StatementBlock.java b/src/main/java/org/apache/sysml/parser/StatementBlock.java index 73fdd94..74e707a 100644 --- a/src/main/java/org/apache/sysml/parser/StatementBlock.java +++ b/src/main/java/org/apache/sysml/parser/StatementBlock.java @@ -22,6 +22,7 @@ package org.apache.sysml.parser; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -727,18 +728,19 @@ public class StatementBlock extends LiveVariableAnalysis } else if(current instanceof ForStatement || current instanceof IfStatement || current instanceof WhileStatement ){ - raiseValidateError("control statement (CVStatement, ELStatement, WhileStatement, IfStatement, ForStatement) should not be in genreric statement block. Likely a parsing error", conditional); + raiseValidateError("control statement (WhileStatement, IfStatement, ForStatement) should not be in generic statement block. 
Likely a parsing error", conditional); } - - else if (current instanceof PrintStatement){ + + else if (current instanceof PrintStatement) { PrintStatement pstmt = (PrintStatement) current; - Expression expr = pstmt.getExpression(); - expr.validateExpression(ids.getVariables(), currConstVars, conditional); - - // check that variables referenced in print statement expression are scalars - if (expr.getOutput().getDataType() != Expression.DataType.SCALAR){ - raiseValidateError("print statement can only print scalars", conditional); + List<Expression> expressions = pstmt.getExpressions(); + for (Expression expression : expressions) { + expression.validateExpression(ids.getVariables(), currConstVars, conditional); + if (expression.getOutput().getDataType() != Expression.DataType.SCALAR) { + raiseValidateError("print statement can only print scalars", conditional); + } } + } // no work to perform for PathStatement or ImportStatement http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java b/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java index 43eabcb..e114f5e 100644 --- a/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java +++ b/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java @@ -22,6 +22,7 @@ package org.apache.sysml.parser.common; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.regex.Pattern; @@ -486,23 +487,57 @@ public abstract class CommonSyntacticValidator { protected void setPrintStatement(ParserRuleContext ctx, String functionName, ArrayList<ParameterExpression> paramExpression, StatementInfo thisinfo) { - 
if(paramExpression.size() != 1) { - notifyErrorListeners(functionName + "() has only one parameter", ctx.start); - return; - } - Expression expr = paramExpression.get(0).getExpr(); - if(expr == null) { - notifyErrorListeners("cannot process " + functionName + "() function", ctx.start); - return; - } - try { - int line = ctx.start.getLine(); - int col = ctx.start.getCharPositionInLine(); - thisinfo.stmt = new PrintStatement(functionName, expr, line, col, line, col); - setFileLineColumn(thisinfo.stmt, ctx); - } catch (LanguageException e) { - notifyErrorListeners("cannot process " + functionName + "() function", ctx.start); + int numParams = paramExpression.size(); + if (numParams == 0) { + notifyErrorListeners(functionName + "() must have more than 0 parameters", ctx.start); return; + } else if (numParams == 1) { + Expression expr = paramExpression.get(0).getExpr(); + if(expr == null) { + notifyErrorListeners("cannot process " + functionName + "() function", ctx.start); + return; + } + try { + int line = ctx.start.getLine(); + int col = ctx.start.getCharPositionInLine(); + ArrayList<Expression> expList = new ArrayList<Expression>(); + expList.add(expr); + thisinfo.stmt = new PrintStatement(functionName, expList, line, col, line, col); + setFileLineColumn(thisinfo.stmt, ctx); + } catch (LanguageException e) { + notifyErrorListeners("cannot process " + functionName + "() function", ctx.start); + return; + } + } else if (numParams > 1) { + if ("stop".equals(functionName)) { + notifyErrorListeners("stop() function cannot have more than 1 parameter", ctx.start); + return; + } + + Expression firstExp = paramExpression.get(0).getExpr(); + if (firstExp == null) { + notifyErrorListeners("cannot process " + functionName + "() function", ctx.start); + return; + } + if (!(firstExp instanceof StringIdentifier)) { + notifyErrorListeners("printf-style functionality requires first print parameter to be a string", ctx.start); + return; + } + try { + int line = ctx.start.getLine(); + 
int col = ctx.start.getCharPositionInLine(); + + List<Expression> expressions = new ArrayList<Expression>(); + for (ParameterExpression pe : paramExpression) { + Expression expression = pe.getExpr(); + // reject unparsable arguments here, consistent with the single-parameter + // branch; a null expression would otherwise fail later during validation + if (expression == null) { + notifyErrorListeners("cannot process " + functionName + "() function", ctx.start); + return; + } + expressions.add(expression); + } + thisinfo.stmt = new PrintStatement(functionName, expressions, line, col, line, col); + setFileLineColumn(thisinfo.stmt, ctx); + } catch (LanguageException e) { + notifyErrorListeners("cannot process " + functionName + "() function", ctx.start); + return; + } } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java index dad8c2d..9673f47 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java @@ -49,7 +49,7 @@ public class Builtin extends ValueFunction private static final long serialVersionUID = 3836744687789840574L; - public enum BuiltinCode { SIN, COS, TAN, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, NROW, NCOL, LENGTH, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID, SELP }; + public enum BuiltinCode { SIN, COS, TAN, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID, SELP }; public BuiltinCode bFunc; private static final boolean FASTMATH = true; @@ -76,6 +76,7 @@ public class Builtin extends ValueFunction String2BuiltinCode.put( "exp" , BuiltinCode.EXP); String2BuiltinCode.put( "plogp" , BuiltinCode.PLOGP); String2BuiltinCode.put( "print" , BuiltinCode.PRINT); + String2BuiltinCode.put(
"printf" , BuiltinCode.PRINTF); String2BuiltinCode.put( "nrow" , BuiltinCode.NROW); String2BuiltinCode.put( "ncol" , BuiltinCode.NCOL); String2BuiltinCode.put( "length" , BuiltinCode.LENGTH); @@ -96,7 +97,7 @@ public class Builtin extends ValueFunction // We should create one object for every builtin function that we support private static Builtin sinObj = null, cosObj = null, tanObj = null, asinObj = null, acosObj = null, atanObj = null; private static Builtin logObj = null, lognzObj = null, minObj = null, maxObj = null, maxindexObj = null, minindexObj=null; - private static Builtin absObj = null, signObj = null, sqrtObj = null, expObj = null, plogpObj = null, printObj = null; + private static Builtin absObj = null, signObj = null, sqrtObj = null, expObj = null, plogpObj = null, printObj = null, printfObj; private static Builtin nrowObj = null, ncolObj = null, lengthObj = null, roundObj = null, ceilObj=null, floorObj=null; private static Builtin inverseObj=null, cumsumObj=null, cumprodObj=null, cumminObj=null, cummaxObj=null; private static Builtin stopObj = null, spropObj = null, sigmoidObj = null, selpObj = null; @@ -195,6 +196,11 @@ public class Builtin extends ValueFunction if ( printObj == null ) printObj = new Builtin(BuiltinCode.PRINT); return printObj; + case PRINTF: + if (printfObj == null) { + printfObj = new Builtin(BuiltinCode.PRINTF); + } + return printfObj; case NROW: if ( nrowObj == null ) nrowObj = new Builtin(BuiltinCode.NROW); @@ -450,7 +456,7 @@ public class Builtin extends ValueFunction } } - // currently, it is used only for PRINT and STOP + // currently, it is used only for PRINT, PRINTF and STOP public String execute (String in1) throws DMLRuntimeException { @@ -459,6 +465,10 @@ public class Builtin extends ValueFunction if (!DMLScript.suppressPrint2Stdout()) System.out.println(in1); return null; + case PRINTF: + if (!DMLScript.suppressPrint2Stdout()) + System.out.println(in1); + return null; case STOP: throw new DMLScriptException(in1); 
default: http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index 18e3a48..3355a6e 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -23,8 +23,8 @@ package org.apache.sysml.runtime.instructions; import java.util.HashMap; import org.apache.sysml.lops.DataGen; -import org.apache.sysml.lops.UnaryCP; import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.lops.UnaryCP; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.instructions.cp.AggregateBinaryCPInstruction; import org.apache.sysml.runtime.instructions.cp.AggregateTernaryCPInstruction; @@ -34,8 +34,10 @@ import org.apache.sysml.runtime.instructions.cp.ArithmeticBinaryCPInstruction; import org.apache.sysml.runtime.instructions.cp.BooleanBinaryCPInstruction; import org.apache.sysml.runtime.instructions.cp.BooleanUnaryCPInstruction; import org.apache.sysml.runtime.instructions.cp.BuiltinBinaryCPInstruction; +import org.apache.sysml.runtime.instructions.cp.BuiltinMultipleCPInstruction; import org.apache.sysml.runtime.instructions.cp.BuiltinUnaryCPInstruction; import org.apache.sysml.runtime.instructions.cp.CPInstruction; +import org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE; import org.apache.sysml.runtime.instructions.cp.CentralMomentCPInstruction; import org.apache.sysml.runtime.instructions.cp.CompressionCPInstruction; import org.apache.sysml.runtime.instructions.cp.ConvolutionCPInstruction; @@ -62,7 +64,6 @@ import 
org.apache.sysml.runtime.instructions.cp.StringInitCPInstruction; import org.apache.sysml.runtime.instructions.cp.TernaryCPInstruction; import org.apache.sysml.runtime.instructions.cp.UaggOuterChainCPInstruction; import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction; -import org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE; import org.apache.sysml.runtime.instructions.cpfile.MatrixIndexingCPFileInstruction; import org.apache.sysml.runtime.instructions.cpfile.ParameterizedBuiltinCPFileInstruction; @@ -178,6 +179,7 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "sigmoid", CPINSTRUCTION_TYPE.BuiltinUnary); String2CPInstructionType.put( "sel+", CPINSTRUCTION_TYPE.BuiltinUnary); + String2CPInstructionType.put( "printf" , CPINSTRUCTION_TYPE.BuiltinMultiple); // Parameterized Builtin Functions String2CPInstructionType.put( "cdf" , CPINSTRUCTION_TYPE.ParameterizedBuiltin); @@ -332,7 +334,8 @@ public class CPInstructionParser extends InstructionParser case BuiltinUnary: return BuiltinUnaryCPInstruction.parseInstruction(str); - + case BuiltinMultiple: + return BuiltinMultipleCPInstruction.parseInstruction(str); case Reorg: return ReorgCPInstruction.parseInstruction(str); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/runtime/instructions/cp/BuiltinMultipleCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/BuiltinMultipleCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/BuiltinMultipleCPInstruction.java new file mode 100644 index 0000000..bab3148 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/BuiltinMultipleCPInstruction.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.instructions.cp; + +import java.util.Arrays; + +import org.apache.sysml.lops.MultipleCP; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.functionobjects.Builtin; +import org.apache.sysml.runtime.functionobjects.ValueFunction; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.matrix.operators.Operator; +import org.apache.sysml.runtime.matrix.operators.SimpleOperator; + +/** + * Instruction to handle a variable number of input operands. It parses an + * instruction string to generate an object that is a subclass of + * BuiltinMultipleCPInstruction. Currently the only subclass of + * BuiltinMultipleCPInstruction is ScalarBuiltinMultipleCPInstruction. The + * ScalarBuiltinMultipleCPInstruction class is responsible for printf-style + * Java-based string formatting. + * + */ +public abstract class BuiltinMultipleCPInstruction extends CPInstruction { + + public CPOperand output; + public CPOperand[] inputs; + + public BuiltinMultipleCPInstruction(Operator op, String opcode, String istr, CPOperand output, + CPOperand... 
inputs) { + super(op, opcode, istr); + _cptype = CPINSTRUCTION_TYPE.BuiltinMultiple; + this.output = output; + this.inputs = inputs; + } + + public static BuiltinMultipleCPInstruction parseInstruction(String str) throws DMLRuntimeException { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + + String opcode = parts[0]; + + String outputString = parts[parts.length - 1]; + CPOperand outputOperand = new CPOperand(outputString); + + String[] inputStrings = null; + CPOperand[] inputOperands = null; + if (parts.length > 2) { + inputStrings = Arrays.copyOfRange(parts, 1, parts.length - 1); + inputOperands = new CPOperand[parts.length - 2]; + for (int i = 0; i < inputStrings.length; i++) { + inputOperands[i] = new CPOperand(inputStrings[i]); + } + } + + if (MultipleCP.OperationType.PRINTF.toString().equalsIgnoreCase(opcode)) { + // printf needs at least the format-string operand; fail here with a clear + // message instead of a NullPointerException when the instruction executes + if (inputOperands == null) { + throw new DMLRuntimeException("Opcode (" + opcode + ") requires at least one input operand"); + } + ValueFunction func = Builtin.getBuiltinFnObject(opcode); + return new ScalarBuiltinMultipleCPInstruction(new SimpleOperator(func), opcode, str, outputOperand, + inputOperands); + } + throw new DMLRuntimeException("Opcode (" + opcode + ") not recognized in BuiltinMultipleCPInstruction"); + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java index 158e26f..a63202a 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java @@ -30,7 +30,7 @@ import org.apache.sysml.runtime.matrix.operators.Operator; public abstract class CPInstruction extends Instruction { - public enum CPINSTRUCTION_TYPE { INVALID, AggregateUnary, AggregateBinary, AggregateTernary, ArithmeticBinary, Ternary, Quaternary,
BooleanBinary, BooleanUnary, BuiltinBinary, BuiltinUnary, MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, Builtin, Reorg, RelationalBinary, File, Variable, External, Append, Rand, QSort, QPick, MatrixIndexing, MMTSJ, PMMJ, MMChain, MatrixReshape, Partition, Compression, StringInit, CentralMoment, Covariance, UaggOuterChain, Convolution }; + public enum CPINSTRUCTION_TYPE { INVALID, AggregateUnary, AggregateBinary, AggregateTernary, ArithmeticBinary, Ternary, Quaternary, BooleanBinary, BooleanUnary, BuiltinBinary, BuiltinUnary, BuiltinMultiple, MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, Builtin, Reorg, RelationalBinary, File, Variable, External, Append, Rand, QSort, QPick, MatrixIndexing, MMTSJ, PMMJ, MMChain, MatrixReshape, Partition, Compression, StringInit, CentralMoment, Covariance, UaggOuterChain, Convolution }; protected CPINSTRUCTION_TYPE _cptype; protected Operator _optr; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java index 345aa7b..5fee5ed 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java @@ -19,6 +19,8 @@ package org.apache.sysml.runtime.instructions.cp; +import org.apache.commons.lang.builder.ToStringBuilder; +import org.apache.commons.lang.builder.ToStringStyle; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.instructions.Instruction; @@ -104,4 +106,15 @@ public class CPOperand } } + public void copy(CPOperand o){ + _name = o.getName(); + _valueType = o.getValueType(); + _dataType = 
o.getDataType(); + } + + @Override + public String toString() { + return ToStringBuilder.reflectionToString(this, ToStringStyle.SHORT_PREFIX_STYLE); + } + } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarBuiltinMultipleCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarBuiltinMultipleCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarBuiltinMultipleCPInstruction.java new file mode 100644 index 0000000..9033228 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarBuiltinMultipleCPInstruction.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.instructions.cp; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.lops.MultipleCP; +import org.apache.sysml.parser.Expression; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.matrix.operators.Operator; + +/** + * The ScalarBuiltinMultipleCPInstruction class is responsible for printf-style + * Java-based string formatting. The first input is the format string. The + * inputs after the first input are the arguments to be formatted in the format + * string. + * + */ +public class ScalarBuiltinMultipleCPInstruction extends BuiltinMultipleCPInstruction { + + public ScalarBuiltinMultipleCPInstruction(Operator op, String opcode, String istr, CPOperand output, + CPOperand... inputs) { + super(op, opcode, istr, output, inputs); + } + + @Override + public void processInstruction(ExecutionContext ec) throws DMLRuntimeException { + if (MultipleCP.OperationType.PRINTF.toString().equalsIgnoreCase(getOpcode())) { + List<ScalarObject> scalarObjects = new ArrayList<ScalarObject>(); + for (CPOperand input : inputs) { + ScalarObject so = ec.getScalarInput(input.getName(), input.getValueType(), input.isLiteral()); + scalarObjects.add(so); + } + + // determine the format string (first argument) to pass to String.format + ScalarObject formatStringObject = scalarObjects.get(0); + if (formatStringObject.getValueType() != Expression.ValueType.STRING) { + throw new DMLRuntimeException("First parameter needs to be a string"); + } + String formatString = formatStringObject.getStringValue(); + + // determine the arguments after the format string to pass to String.format + Object[] objects = null; + if (scalarObjects.size() > 1) { + objects = new Object[scalarObjects.size() - 1]; + for (int i = 1; i < scalarObjects.size(); i++) { + ScalarObject scalarObject = 
scalarObjects.get(i); + switch (scalarObject.getValueType()) { + case INT: + objects[i - 1] = scalarObject.getLongValue(); + break; + case DOUBLE: + objects[i - 1] = scalarObject.getDoubleValue(); + break; + case BOOLEAN: + objects[i - 1] = scalarObject.getBooleanValue(); + break; + case STRING: + objects[i - 1] = scalarObject.getStringValue(); + break; + default: + // do not silently pass null to String.format for unknown operand types + throw new DMLRuntimeException("Unsupported value type (" + scalarObject.getValueType() + + ") for printf argument " + i); + } + } + } + + String result = String.format(formatString, objects); + if (!DMLScript.suppressPrint2Stdout()) { + System.out.println(result); + } + + // this is necessary so that the remove variable operation can be + // performed + ec.setScalarOutput(output.getName(), new StringObject(result)); + } else { + throw new DMLRuntimeException( + "Opcode (" + getOpcode() + ") not recognized in ScalarBuiltinMultipleCPInstruction"); + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/98f50994/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java index 9c1647c..c74b187 100644 --- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java @@ -2336,6 +2336,177 @@ public class MLContextTest extends AutomatedTestBase { ml.execute(script); } + @Test + public void testPrintFormattingStringSubstitution() { + System.out.println("MLContextTest - print formatting string substitution"); + Script script = dml("print('hello %s', 'world');"); + setExpectedStdOut("hello world"); + ml.execute(script); + } + + @Test + public void testPrintFormattingStringSubstitutions() { + System.out.println("MLContextTest - print formatting string substitutions"); + Script script = dml("print('%s %s', 'hello', 'world');"); + setExpectedStdOut("hello world"); +
ml.execute(script); + } + + @Test + public void testPrintFormattingStringSubstitutionAlignment() { + System.out.println("MLContextTest - print formatting string substitution alignment"); + Script script = dml("print(\"'%10s' '%-10s'\", \"hello\", \"world\");"); + setExpectedStdOut("' hello' 'world '"); + ml.execute(script); + } + + @Test + public void testPrintFormattingStringSubstitutionVariables() { + System.out.println("MLContextTest - print formatting string substitution variables"); + Script script = dml("a='hello'; b='world'; print('%s %s', a, b);"); + setExpectedStdOut("hello world"); + ml.execute(script); + } + + @Test + public void testPrintFormattingIntegerSubstitution() { + System.out.println("MLContextTest - print formatting integer substitution"); + Script script = dml("print('int %d', 42);"); + setExpectedStdOut("int 42"); + ml.execute(script); + } + + @Test + public void testPrintFormattingIntegerSubstitutions() { + System.out.println("MLContextTest - print formatting integer substitutions"); + Script script = dml("print('%d %d', 42, 43);"); + setExpectedStdOut("42 43"); + ml.execute(script); + } + + @Test + public void testPrintFormattingIntegerSubstitutionAlignment() { + System.out.println("MLContextTest - print formatting integer substitution alignment"); + Script script = dml("print(\"'%10d' '%-10d'\", 42, 43);"); + setExpectedStdOut("' 42' '43 '"); + ml.execute(script); + } + + @Test + public void testPrintFormattingIntegerSubstitutionVariables() { + System.out.println("MLContextTest - print formatting integer substitution variables"); + Script script = dml("a=42; b=43; print('%d %d', a, b);"); + setExpectedStdOut("42 43"); + ml.execute(script); + } + + @Test + public void testPrintFormattingDoubleSubstitution() { + System.out.println("MLContextTest - print formatting double substitution"); + Script script = dml("print('double %f', 42.0);"); + setExpectedStdOut("double 42.000000"); + ml.execute(script); + } + + @Test + public void 
testPrintFormattingDoubleSubstitutions() { + System.out.println("MLContextTest - print formatting double substitutions"); + Script script = dml("print('%f %f', 42.42, 43.43);"); + setExpectedStdOut("42.420000 43.430000"); + ml.execute(script); + } + + @Test + public void testPrintFormattingDoubleSubstitutionAlignment() { + System.out.println("MLContextTest - print formatting double substitution alignment"); + Script script = dml("print(\"'%10.2f' '%-10.2f'\", 42.53, 43.54);"); + setExpectedStdOut("' 42.53' '43.54 '"); + ml.execute(script); + } + + @Test + public void testPrintFormattingDoubleSubstitutionVariables() { + System.out.println("MLContextTest - print formatting double substitution variables"); + Script script = dml("a=12.34; b=56.78; print('%f %f', a, b);"); + setExpectedStdOut("12.340000 56.780000"); + ml.execute(script); + } + + @Test + public void testPrintFormattingBooleanSubstitution() { + System.out.println("MLContextTest - print formatting boolean substitution"); + Script script = dml("print('boolean %b', TRUE);"); + setExpectedStdOut("boolean true"); + ml.execute(script); + } + + @Test + public void testPrintFormattingBooleanSubstitutions() { + System.out.println("MLContextTest - print formatting boolean substitutions"); + Script script = dml("print('%b %b', TRUE, FALSE);"); + setExpectedStdOut("true false"); + ml.execute(script); + } + + @Test + public void testPrintFormattingBooleanSubstitutionAlignment() { + System.out.println("MLContextTest - print formatting boolean substitution alignment"); + Script script = dml("print(\"'%10b' '%-10b'\", TRUE, FALSE);"); + setExpectedStdOut("' true' 'false '"); + ml.execute(script); + } + + @Test + public void testPrintFormattingBooleanSubstitutionVariables() { + System.out.println("MLContextTest - print formatting boolean substitution variables"); + Script script = dml("a=TRUE; b=FALSE; print('%b %b', a, b);"); + setExpectedStdOut("true false"); + ml.execute(script); + } + + @Test + public void 
testPrintFormattingMultipleTypes() { + System.out.println("MLContextTest - print formatting multiple types"); + Script script = dml("a='hello'; b=3; c=4.5; d=TRUE; print('%s %d %f %b', a, b, c, d);"); + setExpectedStdOut("hello 3 4.500000 true"); + ml.execute(script); + } + + @Test + public void testPrintFormattingMultipleExpressions() { + System.out.println("MLContextTest - print formatting multiple expressions"); + Script script = dml("a='hello'; b='goodbye'; c=4; d=3; e=3.0; f=5.0; g=FALSE; print('%s %d %f %b', (a+b), (c-d), (e*f), !g);"); + setExpectedStdOut("hellogoodbye 1 15.000000 true"); + ml.execute(script); + } + + @Test + public void testPrintFormattingForLoop() { + System.out.println("MLContextTest - print formatting for loop"); + Script script = dml("for (i in 1:3) { print('int value %d', i); }"); + // check that one of the lines is returned + setExpectedStdOut("int value 3"); + ml.execute(script); + } + + @Test + public void testPrintFormattingParforLoop() { + System.out.println("MLContextTest - print formatting parfor loop"); + Script script = dml("parfor (i in 1:3) { print('int value %d', i); }"); + // check that one of the lines is returned + setExpectedStdOut("int value 3"); + ml.execute(script); + } + + @Test + public void testPrintFormattingForLoopMultiply() { + System.out.println("MLContextTest - print formatting for loop multiply"); + Script script = dml("a = 5.0; for (i in 1:3) { print('%d %f', i, a * i); }"); + // check that one of the lines is returned + setExpectedStdOut("3 15.000000"); + ml.execute(script); + } + // NOTE: Uncomment these tests once they work // @SuppressWarnings({ "rawtypes", "unchecked" })
