This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push: new 83400bb [SYSTEMDS-2764] Frame constructor and data-gen operations 83400bb is described below commit 83400bb0f239a60e9c44444fe5977ee304f611f0 Author: Shafaq Siddiqi <shafaq.sidd...@tugraz.at> AuthorDate: Sat Dec 19 22:32:23 2020 +0100 [SYSTEMDS-2764] Frame constructor and data-gen operations DIA project WS2020/21. Closes #1132. --- src/main/java/org/apache/sysds/common/Types.java | 2 +- .../apache/sysds/hops/recompile/Recompiler.java | 2 +- src/main/java/org/apache/sysds/lops/DataGen.java | 75 +- .../org/apache/sysds/parser/DMLTranslator.java | 6 + .../org/apache/sysds/parser/DataExpression.java | 781 ++++++++++++++------- .../java/org/apache/sysds/parser/Expression.java | 2 +- .../runtime/instructions/CPInstructionParser.java | 1 + .../runtime/instructions/SPInstructionParser.java | 3 +- .../instructions/cp/DataGenCPInstruction.java | 90 ++- .../instructions/spark/RandSPInstruction.java | 205 +++++- .../apache/sysds/runtime/util/UtilFunctions.java | 71 +- .../test/functions/builtin/BuiltinDBSCANTest.java | 13 +- .../test/functions/frame/FrameConstructorTest.java | 156 ++++ .../functions/frame/FrameConstructorTest.dml | 33 + 14 files changed, 1136 insertions(+), 304 deletions(-) diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index c76cd1c..f03feea 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -448,7 +448,7 @@ public class Types } public enum OpOpDG { - RAND, SEQ, SINIT, SAMPLE, TIME + RAND, SEQ, FRAMEINIT, SINIT, SAMPLE, TIME } public enum OpOpData { diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java index 6e960e7..c57fbc8 100644 --- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java +++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java @@ 
-1357,7 +1357,7 @@ public class Recompiler DataGenOp d = (DataGenOp) hop; HashMap<String,Integer> params = d.getParamIndexMap(); if ( d.getOp() == OpOpDG.RAND || d.getOp()==OpOpDG.SINIT - || d.getOp() == OpOpDG.SAMPLE ) + || d.getOp() == OpOpDG.SAMPLE || d.getOp() == OpOpDG.FRAMEINIT ) { boolean initUnknown = !d.dimsKnown(); // TODO refresh tensor size information diff --git a/src/main/java/org/apache/sysds/lops/DataGen.java b/src/main/java/org/apache/sysds/lops/DataGen.java index ddc1a8a..7487a59 100644 --- a/src/main/java/org/apache/sysds/lops/DataGen.java +++ b/src/main/java/org/apache/sysds/lops/DataGen.java @@ -42,7 +42,8 @@ public class DataGen extends Lop public static final String SINIT_OPCODE = "sinit"; //string initialize public static final String SAMPLE_OPCODE = "sample"; //sample.int public static final String TIME_OPCODE = "time"; //time - + public static final String FRAME_OPCODE = "frame"; //time + private int _numThreads = 1; /** base dir for rand input */ @@ -111,6 +112,8 @@ public class DataGen extends Lop return getSampleInstructionCPSpark(output); case TIME: return getTimeInstructionCP(output); + case FRAMEINIT: + return getFrameInstructionCPSpark(output); default: throw new LopsException("Unknown data generation method: " + _op); } @@ -206,6 +209,76 @@ public class DataGen extends Lop return sb.toString(); } + private String getFrameInstructionCPSpark(String output) + { + //sanity checks + if ( _op != OpOpDG.FRAMEINIT ) + throw new LopsException("Invalid instruction generation for data generation method " + _op); + if( getInputs().size() != DataExpression.RAND_VALID_PARAM_NAMES.size() - 5 ) { // frame + throw new LopsException(printErrorLocation() + "Invalid number of operands (" + + getInputs().size() + ") for a frame operation"); + } + + StringBuilder sb = new StringBuilder(); + sb.append( getExecType() ); + sb.append( Lop.OPERAND_DELIMITOR ); + + sb.append(FRAME_OPCODE); + sb.append(OPERAND_DELIMITOR); + + Lop iLop = 
_inputParams.get(DataExpression.RAND_DATA); + if ( iLop != null ) { + if(iLop instanceof Nary) { + for(Lop lop : iLop.getInputs()) { + sb.append(((Data)lop).getStringValue()); + sb.append(DataExpression.DELIM_NA_STRING_SEP); + } + } + else if(iLop instanceof Data) { + sb.append(((Data)iLop).getStringValue()); + } + } + + sb.append(OPERAND_DELIMITOR); + + iLop = _inputParams.get(DataExpression.RAND_DIMS); + if (iLop != null) { + sb.append(iLop.prepScalarInputOperand(getExecType())); + sb.append(OPERAND_DELIMITOR); + } + else { + iLop = _inputParams.get(DataExpression.RAND_ROWS); + sb.append(iLop.prepScalarInputOperand(getExecType())); + sb.append(OPERAND_DELIMITOR); + + iLop = _inputParams.get(DataExpression.RAND_COLS); + sb.append(iLop.prepScalarInputOperand(getExecType())); + sb.append(OPERAND_DELIMITOR); + } + iLop = _inputParams.get(DataExpression.SCHEMAPARAM); + if ( iLop != null ) { + if(iLop instanceof Nary) { + for(Lop lop : iLop.getInputs()) { + sb.append(((Data)lop).getStringValue()); + sb.append(DataExpression.DELIM_NA_STRING_SEP); + } + } + else if(iLop instanceof Data) { + sb.append(((Data)iLop).getStringValue()); + } + } + + sb.append(OPERAND_DELIMITOR); + + if( getExecType() == ExecType.SPARK ) { + sb.append(baseDir); + sb.append(OPERAND_DELIMITOR); + } + + sb.append( prepOutputOperand(output)); + return sb.toString(); + } + private String getSInitInstructionCPSpark(String output) { if ( _op != OpOpDG.SINIT ) diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index aab0d22..8e33a15 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2093,6 +2093,12 @@ public class DMLTranslator currBuiltinOp = new DataGenOp(method, target, paramHops); break; + case FRAME: + // We limit RAND_MIN, RAND_MAX, RAND_SPARSITY, RAND_SEED, and RAND_PDF to be constants + method = OpOpDG.FRAMEINIT; + currBuiltinOp = 
new DataGenOp(method, target, paramHops); + break; + case TENSOR: case MATRIX: ArrayList<Hop> tmpMatrix = new ArrayList<>(); diff --git a/src/main/java/org/apache/sysds/parser/DataExpression.java b/src/main/java/org/apache/sysds/parser/DataExpression.java index f9fe5a4..dcea873 100644 --- a/src/main/java/org/apache/sysds/parser/DataExpression.java +++ b/src/main/java/org/apache/sysds/parser/DataExpression.java @@ -28,6 +28,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map.Entry; import java.util.Set; @@ -125,7 +126,10 @@ public class DataExpression extends DataIdentifier public static final Set<String> RESHAPE_VALID_PARAM_NAMES = new HashSet<>( Arrays.asList(RAND_BY_ROW, RAND_DIMNAMES, RAND_DATA, RAND_ROWS, RAND_COLS, RAND_DIMS)); - + + public static final Set<String> FRAME_VALID_PARAM_NAMES = new HashSet<>( + Arrays.asList(SCHEMAPARAM, RAND_DATA, RAND_ROWS, RAND_COLS)); + public static final Set<String> SQL_VALID_PARAM_NAMES = new HashSet<>( Arrays.asList(SQL_CONN, SQL_USER, SQL_PASS, SQL_QUERY)); @@ -156,7 +160,8 @@ public class DataExpression extends DataIdentifier public static final double DEFAULT_DELIM_FILL_VALUE = 0.0; public static final boolean DEFAULT_DELIM_SPARSE = false; public static final String DEFAULT_NA_STRINGS = ""; - + public static final String DEFAULT_SCHEMAPARAM = "NULL"; + private DataOp _opcode; private HashMap<String, Expression> _varParams; private boolean _strInit = false; //string initialize @@ -186,309 +191,361 @@ public class DataExpression extends DataIdentifier + passedParamExprs + " " + parseInfo + " " + errorListener); } // check if the function name is built-in function - // (assign built-in function op if function is built-in) - Expression.DataOp dop; + // (assign built-in function op if function is built-in) DataExpression dataExpr = null; - if (functionName.equals("read") || functionName.equals("readMM") || 
functionName.equals("read.csv")) { - dop = Expression.DataOp.READ; - dataExpr = new DataExpression(dop, new HashMap<String, Expression>(), parseInfo); - - if (functionName.equals("readMM")) - dataExpr.addVarParam(DataExpression.FORMAT_TYPE, - new StringIdentifier(FileFormat.MM.toString(), parseInfo)); - - if (functionName.equals("read.csv")) - dataExpr.addVarParam(DataExpression.FORMAT_TYPE, - new StringIdentifier(FileFormat.CSV.toString(), parseInfo)); - - if (functionName.equals("read.libsvm")) - dataExpr.addVarParam(DataExpression.FORMAT_TYPE, - new StringIdentifier(FileFormat.LIBSVM.toString(), parseInfo)); - - // validate the filename is the first parameter - if (passedParamExprs.size() < 1){ - errorListener.validationError(parseInfo, "read method must have at least filename parameter"); - return null; - } - - ParameterExpression pexpr = (passedParamExprs.size() == 0) ? null : passedParamExprs.get(0); + if (functionName.equals("read") || functionName.equals("readMM") || functionName.equals("read.csv")) + dataExpr = processReadDataExpression(functionName, passedParamExprs, errorListener, parseInfo); + else if (functionName.equalsIgnoreCase("rand")) + dataExpr = processRandDataExpression(functionName, passedParamExprs, errorListener, parseInfo); + else if (functionName.equals("matrix")) + dataExpr = processMatrixExpression(functionName, passedParamExprs, errorListener, parseInfo); + else if (functionName.equals("frame")) + dataExpr = processFrameExpression(functionName, passedParamExprs, errorListener, parseInfo); + else if (functionName.equals("tensor")) + dataExpr = processTensorExpression(functionName, passedParamExprs, errorListener, parseInfo); + else if (functionName.equals("sql")) + dataExpr = processSQLExpression(functionName, passedParamExprs, errorListener, parseInfo); + else if (functionName.equals("federated")) + dataExpr = processFederatedExpression(functionName, passedParamExprs, errorListener, parseInfo); + + if (dataExpr != null) + 
dataExpr.setParseInfo(parseInfo); + return dataExpr; + } + + private static DataExpression processReadDataExpression(String functionName, + List<ParameterExpression> passedParamExprs, CustomErrorListener errorListener, ParseInfo parseInfo) + { + DataExpression dataExpr = new DataExpression(DataOp.READ, new HashMap<>(), parseInfo); + if (functionName.equals("readMM")) + dataExpr.addVarParam(DataExpression.FORMAT_TYPE, + new StringIdentifier(FileFormat.MM.toString(), parseInfo)); + + if (functionName.equals("read.csv")) + dataExpr.addVarParam(DataExpression.FORMAT_TYPE, + new StringIdentifier(FileFormat.CSV.toString(), parseInfo)); + + if (functionName.equals("read.libsvm")) + dataExpr.addVarParam(DataExpression.FORMAT_TYPE, + new StringIdentifier(FileFormat.LIBSVM.toString(), parseInfo)); + + // validate the filename is the first parameter + if (passedParamExprs.size() < 1){ + errorListener.validationError(parseInfo, "read method must have at least filename parameter"); + return null; + } + + ParameterExpression pexpr = (passedParamExprs.size() == 0) ? 
null : passedParamExprs.get(0); + + if ( (pexpr != null) && (!(pexpr.getName() == null) || (pexpr.getName() != null && pexpr.getName().equalsIgnoreCase(DataExpression.IO_FILENAME)))){ + errorListener.validationError(parseInfo, "first parameter to read statement must be filename"); + return null; + } else if( pexpr != null ){ + dataExpr.addVarParam(DataExpression.IO_FILENAME, pexpr.getExpr()); + } + + // validate all parameters are added only once and valid name + for (int i = 1; i < passedParamExprs.size(); i++){ + String currName = passedParamExprs.get(i).getName(); + Expression currExpr = passedParamExprs.get(i).getExpr(); - if ( (pexpr != null) && (!(pexpr.getName() == null) || (pexpr.getName() != null && pexpr.getName().equalsIgnoreCase(DataExpression.IO_FILENAME)))){ - errorListener.validationError(parseInfo, "first parameter to read statement must be filename"); + if (dataExpr.getVarParam(currName) != null){ + errorListener.validationError(parseInfo, "attempted to add IOStatement parameter " + currName + " more than once"); return null; - } else if( pexpr != null ){ - dataExpr.addVarParam(DataExpression.IO_FILENAME, pexpr.getExpr()); } - - // validate all parameters are added only once and valid name - for (int i = 1; i < passedParamExprs.size(); i++){ - String currName = passedParamExprs.get(i).getName(); - Expression currExpr = passedParamExprs.get(i).getExpr(); - - if (dataExpr.getVarParam(currName) != null){ - errorListener.validationError(parseInfo, "attempted to add IOStatement parameter " + currName + " more than once"); - return null; - } - // verify parameter names for read function - boolean isValidName = READ_VALID_PARAM_NAMES.contains(currName); + // verify parameter names for read function + boolean isValidName = READ_VALID_PARAM_NAMES.contains(currName); - if (!isValidName){ - errorListener.validationError(parseInfo, "attempted to add invalid read statement parameter " + currName); - return null; - } - dataExpr.addVarParam(currName, currExpr); + 
if (!isValidName){ + errorListener.validationError(parseInfo, "attempted to add invalid read statement parameter " + currName); + return null; } + dataExpr.addVarParam(currName, currExpr); } - else if (functionName.equalsIgnoreCase("rand")){ - - dop = Expression.DataOp.RAND; - dataExpr = new DataExpression(dop, new HashMap<String, Expression>(), parseInfo); - - for (ParameterExpression currExpr : passedParamExprs){ - String pname = currExpr.getName(); - Expression pexpr = currExpr.getExpr(); - if (pname == null){ - errorListener.validationError(parseInfo, "for rand statement, all arguments must be named parameters"); - return null; - } - dataExpr.addRandExprParam(pname, pexpr); + + return dataExpr; + } + + private static DataExpression processRandDataExpression(String functionName, + List<ParameterExpression> passedParamExprs, CustomErrorListener errorListener, ParseInfo parseInfo) + { + DataExpression dataExpr = new DataExpression(DataOp.RAND, new HashMap<>(), parseInfo); + + for (ParameterExpression currExpr : passedParamExprs){ + String pname = currExpr.getName(); + Expression pexpr = currExpr.getExpr(); + if (pname == null){ + errorListener.validationError(parseInfo, "for rand statement, all arguments must be named parameters"); + return null; } - dataExpr.setRandDefault(); + dataExpr.addRandExprParam(pname, pexpr); } + dataExpr.setRandDefault(); + return dataExpr; + } + + private static DataExpression processMatrixExpression(String functionName, + List<ParameterExpression> passedParamExprs, CustomErrorListener errorListener, ParseInfo parseInfo) + { + DataExpression dataExpr = new DataExpression(DataOp.MATRIX, new HashMap<>(), parseInfo); + int namedParamCount = (int) passedParamExprs.stream().filter(p -> p.getName()!=null).count(); + int unnamedParamCount = passedParamExprs.size() - namedParamCount; - else if (functionName.equals("matrix")){ - dop = Expression.DataOp.MATRIX; - dataExpr = new DataExpression(dop, new HashMap<String, Expression>(), parseInfo); + 
// check whether named or unnamed parameters are used + if (passedParamExprs.size() < 3){ + errorListener.validationError(parseInfo, "for matrix statement, must specify at least 3 arguments: data, rows, cols"); + return null; + } - int namedParamCount = 0, unnamedParamCount = 0; - for (ParameterExpression currExpr : passedParamExprs) { - if (currExpr.getName() == null) - unnamedParamCount++; - else - namedParamCount++; + if (unnamedParamCount > 1){ + if (namedParamCount > 0) { + errorListener.validationError(parseInfo, "for matrix statement, cannot mix named and unnamed parameters"); + return null; } - - // check whether named or unnamed parameters are used - if (passedParamExprs.size() < 3){ + if (unnamedParamCount < 3) { errorListener.validationError(parseInfo, "for matrix statement, must specify at least 3 arguments: data, rows, cols"); return null; } - - if (unnamedParamCount > 1){ - - if (namedParamCount > 0) { - errorListener.validationError(parseInfo, "for matrix statement, cannot mix named and unnamed parameters"); - return null; - } - - if (unnamedParamCount < 3) { - errorListener.validationError(parseInfo, "for matrix statement, must specify at least 3 arguments: data, rows, cols"); - return null; - } - - // assume: data, rows, cols, [byRow], [dimNames] - dataExpr.addMatrixExprParam(DataExpression.RAND_DATA,passedParamExprs.get(0).getExpr()); - dataExpr.addMatrixExprParam(DataExpression.RAND_ROWS,passedParamExprs.get(1).getExpr()); - dataExpr.addMatrixExprParam(DataExpression.RAND_COLS,passedParamExprs.get(2).getExpr()); - - if (unnamedParamCount >= 4) - dataExpr.addMatrixExprParam(DataExpression.RAND_BY_ROW,passedParamExprs.get(3).getExpr()); - - if (unnamedParamCount == 5) - dataExpr.addMatrixExprParam(DataExpression.RAND_DIMNAMES,passedParamExprs.get(4).getExpr()); - - if (unnamedParamCount > 5) { - errorListener.validationError(parseInfo, "for matrix statement, at most 5 arguments supported: data, rows, cols, byrow, dimname"); - return null; - } - + 
// assume: data, rows, cols, [byRow], [dimNames] + dataExpr.addMatrixExprParam(DataExpression.RAND_DATA,passedParamExprs.get(0).getExpr()); + dataExpr.addMatrixExprParam(DataExpression.RAND_ROWS,passedParamExprs.get(1).getExpr()); + dataExpr.addMatrixExprParam(DataExpression.RAND_COLS,passedParamExprs.get(2).getExpr()); + + if (unnamedParamCount >= 4) + dataExpr.addMatrixExprParam(DataExpression.RAND_BY_ROW,passedParamExprs.get(3).getExpr()); + if (unnamedParamCount == 5) + dataExpr.addMatrixExprParam(DataExpression.RAND_DIMNAMES,passedParamExprs.get(4).getExpr()); + if (unnamedParamCount > 5) { + errorListener.validationError(parseInfo, "for matrix statement, at most 5 arguments supported: data, rows, cols, byrow, dimname"); + return null; + } + } + else { + // handle first parameter, which is data and may be unnamed + ParameterExpression firstParam = passedParamExprs.get(0); + if (firstParam.getName() != null && !firstParam.getName().equals(DataExpression.RAND_DATA)){ + errorListener.validationError(parseInfo, "matrix method must have data parameter as first parameter or unnamed parameter"); + return null; } else { - // handle first parameter, which is data and may be unnamed - ParameterExpression firstParam = passedParamExprs.get(0); - if (firstParam.getName() != null && !firstParam.getName().equals(DataExpression.RAND_DATA)){ - errorListener.validationError(parseInfo, "matrix method must have data parameter as first parameter or unnamed parameter"); + dataExpr.addMatrixExprParam(DataExpression.RAND_DATA, passedParamExprs.get(0).getExpr()); + } + + for (int i=1; i<passedParamExprs.size(); i++){ + if (passedParamExprs.get(i).getName() == null){ + errorListener.validationError(parseInfo, "for matrix statement, cannot mix named and unnamed parameters, only data parameter can be unnammed"); return null; } else { - dataExpr.addMatrixExprParam(DataExpression.RAND_DATA, passedParamExprs.get(0).getExpr()); - } - - for (int i=1; i<passedParamExprs.size(); i++){ - if 
(passedParamExprs.get(i).getName() == null){ - errorListener.validationError(parseInfo, "for matrix statement, cannot mix named and unnamed parameters, only data parameter can be unnammed"); - return null; - } else { - dataExpr.addMatrixExprParam(passedParamExprs.get(i).getName(), passedParamExprs.get(i).getExpr()); - } + dataExpr.addMatrixExprParam(passedParamExprs.get(i).getName(), passedParamExprs.get(i).getExpr()); } } - dataExpr.setMatrixDefault(); } - else if (functionName.equals("tensor")){ - dop = Expression.DataOp.TENSOR; - dataExpr = new DataExpression(dop, new HashMap<String, Expression>(), parseInfo); + dataExpr.setMatrixDefault(); + return dataExpr; + } + + private static DataExpression processFrameExpression(String functionName, + List<ParameterExpression> passedParamExprs, CustomErrorListener errorListener, ParseInfo parseInfo) + { + DataExpression dataExpr = new DataExpression(DataOp.FRAME, new HashMap<>(), parseInfo); + int namedParamCount = (int) passedParamExprs.stream().filter(p -> p.getName()!=null).count(); + int unnamedParamCount = passedParamExprs.size() - namedParamCount; + + // check whether named or unnamed parameters are used + if (passedParamExprs.size() < 3) { // it will generate a frame with string schema + errorListener.validationError(parseInfo, "for frame statement, must specify at least 3 arguments: data, rows and cols"); + return null; + } - int namedParamCount = 0, unnamedParamCount = 0; - for (ParameterExpression currExpr : passedParamExprs) { - if (currExpr.getName() == null) - unnamedParamCount++; - else - namedParamCount++; + if (unnamedParamCount > 1) { + if (namedParamCount > 0) { + errorListener.validationError(parseInfo, "for frame statement, cannot mix named and unnamed parameters"); + return null; + } + if (unnamedParamCount < 3) { + errorListener.validationError(parseInfo, "for frame statement, must specify at least 3 arguments: rows, cols"); + return null; } + // assume: data, rows, cols, [Schema] + 
dataExpr.addFrameExprParam(DataExpression.RAND_DATA, passedParamExprs.get(0).getExpr()); + dataExpr.addFrameExprParam(DataExpression.RAND_ROWS, passedParamExprs.get(1).getExpr()); + dataExpr.addFrameExprParam(DataExpression.RAND_COLS, passedParamExprs.get(2).getExpr()); - // check whether named or unnamed parameters are used - if (passedParamExprs.size() < 2){ - errorListener.validationError(parseInfo, "for tensor statement, must specify at least 2 arguments: data, dims[]"); + if (unnamedParamCount == 3) + dataExpr.addFrameExprParam(DataExpression.SCHEMAPARAM, passedParamExprs.get(3).getExpr()); + if (unnamedParamCount > 3) { + errorListener.validationError(parseInfo, "for frame statement, at most 4 arguments supported: data, rows, cols, schema"); return null; } + } + else { + // handle first parameter, which is data and may be unnamed + ParameterExpression firstParam = passedParamExprs.get(0); + if (firstParam.getName() != null && !firstParam.getName().equals(DataExpression.RAND_DATA)){ + errorListener.validationError(parseInfo, "frame method must have data parameter as first parameter or unnamed parameter"); + return null; + } + else { + dataExpr.addFrameExprParam(DataExpression.RAND_DATA, passedParamExprs.get(0).getExpr()); + } - if (unnamedParamCount > 1){ - if (namedParamCount > 0) { - errorListener.validationError(parseInfo, "for tensor statement, cannot mix named and unnamed parameters"); + for (int i=1; i<passedParamExprs.size(); i++){ + if (passedParamExprs.get(i).getName() == null){ + errorListener.validationError(parseInfo, "for frame statement, cannot mix named and unnamed parameters, only data parameter can be unnammed"); return null; + } else { + dataExpr.addFrameExprParam(passedParamExprs.get(i).getName(), passedParamExprs.get(i).getExpr()); } + } + } + dataExpr.setFrameDefault(); + return dataExpr; + } + + private static DataExpression processTensorExpression(String functionName, + List<ParameterExpression> passedParamExprs, CustomErrorListener 
errorListener, ParseInfo parseInfo) + { + DataExpression dataExpr = new DataExpression(DataOp.TENSOR, new HashMap<>(), parseInfo); + int namedParamCount = (int) passedParamExprs.stream().filter(p -> p.getName()!=null).count(); + int unnamedParamCount = passedParamExprs.size() - namedParamCount; - // assume: data, dims[], [byRow], [dimNames] - dataExpr.addTensorExprParam(DataExpression.RAND_DATA,passedParamExprs.get(0).getExpr()); - dataExpr.addTensorExprParam(DataExpression.RAND_DIMS,passedParamExprs.get(1).getExpr()); - - if (unnamedParamCount >= 3) - // TODO use byRow parameter - dataExpr.addTensorExprParam(DataExpression.RAND_BY_ROW,passedParamExprs.get(2).getExpr()); - - if (unnamedParamCount == 4) - dataExpr.addTensorExprParam(DataExpression.RAND_DIMNAMES,passedParamExprs.get(3).getExpr()); + // check whether named or unnamed parameters are used + if (passedParamExprs.size() < 2){ + errorListener.validationError(parseInfo, "for tensor statement, must specify at least 2 arguments: data, dims[]"); + return null; + } + if (unnamedParamCount > 1){ + if (namedParamCount > 0) { + errorListener.validationError(parseInfo, "for tensor statement, cannot mix named and unnamed parameters"); + return null; + } - if (unnamedParamCount > 4) { - errorListener.validationError(parseInfo, "for tensor statement, at most 4 arguments supported: data, dims, byrow, dimname"); - return null; - } + // assume: data, dims[], [byRow], [dimNames] + dataExpr.addTensorExprParam(DataExpression.RAND_DATA,passedParamExprs.get(0).getExpr()); + dataExpr.addTensorExprParam(DataExpression.RAND_DIMS,passedParamExprs.get(1).getExpr()); + if (unnamedParamCount >= 3) + // TODO use byRow parameter + dataExpr.addTensorExprParam(DataExpression.RAND_BY_ROW,passedParamExprs.get(2).getExpr()); + if (unnamedParamCount == 4) + dataExpr.addTensorExprParam(DataExpression.RAND_DIMNAMES,passedParamExprs.get(3).getExpr()); + if (unnamedParamCount > 4) { + errorListener.validationError(parseInfo, "for tensor 
statement, at most 4 arguments supported: data, dims, byrow, dimname"); + return null; + } + } + else { + // handle first parameter, which is data and may be unnamed + ParameterExpression firstParam = passedParamExprs.get(0); + if (firstParam.getName() != null && !firstParam.getName().equals(DataExpression.RAND_DATA)){ + errorListener.validationError(parseInfo, "tensor method must have data parameter as first parameter or unnamed parameter"); + return null; } else { - // handle first parameter, which is data and may be unnamed - ParameterExpression firstParam = passedParamExprs.get(0); - if (firstParam.getName() != null && !firstParam.getName().equals(DataExpression.RAND_DATA)){ - errorListener.validationError(parseInfo, "tensor method must have data parameter as first parameter or unnamed parameter"); + dataExpr.addTensorExprParam(DataExpression.RAND_DATA, passedParamExprs.get(0).getExpr()); + } + + for (int i=1; i<passedParamExprs.size(); i++){ + if (passedParamExprs.get(i).getName() == null){ + errorListener.validationError(parseInfo, "for tensor statement, cannot mix named and unnamed parameters, only data parameter can be unnammed"); return null; } else { - dataExpr.addTensorExprParam(DataExpression.RAND_DATA, passedParamExprs.get(0).getExpr()); - } - - for (int i=1; i<passedParamExprs.size(); i++){ - if (passedParamExprs.get(i).getName() == null){ - errorListener.validationError(parseInfo, "for tensor statement, cannot mix named and unnamed parameters, only data parameter can be unnammed"); - return null; - } - else { - dataExpr.addTensorExprParam(passedParamExprs.get(i).getName(), passedParamExprs.get(i).getExpr()); - } + dataExpr.addTensorExprParam(passedParamExprs.get(i).getName(), passedParamExprs.get(i).getExpr()); } } - dataExpr.setTensorDefault(); } - else if (functionName.equals("sql")) { - dop = DataOp.SQL; - dataExpr = new DataExpression(dop, new HashMap<>(), parseInfo); - - int namedParamCount = 0, unnamedParamCount = 0; - for (ParameterExpression 
currExpr : passedParamExprs) { - if (currExpr.getName() == null) - unnamedParamCount++; - else - namedParamCount++; - } - - // check whether named or unnamed parameters are used - if (passedParamExprs.size() < 2){ - errorListener.validationError(parseInfo, "for sql statement, must specify at least 2 arguments: conn, query"); + dataExpr.setTensorDefault(); + return dataExpr; + } + + private static DataExpression processSQLExpression(String functionName, + List<ParameterExpression> passedParamExprs, CustomErrorListener errorListener, ParseInfo parseInfo) + { + DataExpression dataExpr = new DataExpression(DataOp.SQL, new HashMap<>(), parseInfo); + int namedParamCount = (int) passedParamExprs.stream().filter(p -> p.getName()!=null).count(); + int unnamedParamCount = passedParamExprs.size() - namedParamCount; + + // check whether named or unnamed parameters are used + if (passedParamExprs.size() < 2){ + errorListener.validationError(parseInfo, "for sql statement, must specify at least 2 arguments: conn, query"); + return null; + } + if (unnamedParamCount > 0){ + if (namedParamCount > 0) { + errorListener.validationError(parseInfo, "for sql statement, cannot mix named and unnamed parameters"); return null; } - - if (unnamedParamCount > 0){ - if (namedParamCount > 0) { - errorListener.validationError(parseInfo, "for sql statement, cannot mix named and unnamed parameters"); - return null; - } - - if (unnamedParamCount == 2 || unnamedParamCount == 4 ) { - // assume: conn, query, [password, query] - dataExpr.addSqlExprParam(DataExpression.SQL_CONN, passedParamExprs.get(0).getExpr()); - dataExpr.addSqlExprParam(DataExpression.SQL_QUERY, passedParamExprs.get(1).getExpr()); - if (unnamedParamCount == 4) { - dataExpr.addSqlExprParam(DataExpression.SQL_PASS, passedParamExprs.get(2).getExpr()); - dataExpr.addSqlExprParam(DataExpression.SQL_QUERY, passedParamExprs.get(3).getExpr()); - } + if (unnamedParamCount == 2 || unnamedParamCount == 4 ) { + // assume: conn, query, [password, 
query] + dataExpr.addSqlExprParam(DataExpression.SQL_CONN, passedParamExprs.get(0).getExpr()); + dataExpr.addSqlExprParam(DataExpression.SQL_QUERY, passedParamExprs.get(1).getExpr()); + if (unnamedParamCount == 4) { + dataExpr.addSqlExprParam(DataExpression.SQL_PASS, passedParamExprs.get(2).getExpr()); + dataExpr.addSqlExprParam(DataExpression.SQL_QUERY, passedParamExprs.get(3).getExpr()); } - else { - errorListener.validationError(parseInfo, "for sql statement, " - + "at most 4 arguments supported: conn, user, password, query"); - return null; - } - } else { - for (ParameterExpression passedParamExpr : passedParamExprs) { - dataExpr.addSqlExprParam(passedParamExpr.getName(), passedParamExpr.getExpr()); - } + errorListener.validationError(parseInfo, "for sql statement, " + + "at most 4 arguments supported: conn, user, password, query"); + return null; } - dataExpr.setSqlDefault(); } - else if (functionName.equals("federated")) { - dop = DataOp.FEDERATED; - dataExpr = new DataExpression(dop, new HashMap<>(), parseInfo); - int namedParamCount = 0, unnamedParamCount = 0; - for (ParameterExpression currExpr : passedParamExprs) { - if (currExpr.getName() == null) - unnamedParamCount++; - else - namedParamCount++; + else { + for (ParameterExpression passedParamExpr : passedParamExprs) { + dataExpr.addSqlExprParam(passedParamExpr.getName(), passedParamExpr.getExpr()); } - if(passedParamExprs.size() < 2) { + } + dataExpr.setSqlDefault(); + return dataExpr; + } + + private static DataExpression processFederatedExpression(String functionName, + List<ParameterExpression> passedParamExprs, CustomErrorListener errorListener, ParseInfo parseInfo) + { + DataExpression dataExpr = new DataExpression(DataOp.FEDERATED, new HashMap<>(), parseInfo); + int namedParamCount = (int) passedParamExprs.stream().filter(p -> p.getName()!=null).count(); + int unnamedParamCount = passedParamExprs.size() - namedParamCount; + + if(passedParamExprs.size() < 2) { + 
errorListener.validationError(parseInfo, + "for federated statement, must specify at least 2 arguments: addresses, ranges"); + return null; + } + if(unnamedParamCount > 0) { + if(namedParamCount > 0) { errorListener.validationError(parseInfo, - "for federated statement, must specify at least 2 arguments: addresses, ranges"); + "for federated statement, cannot mix named and unnamed parameters"); return null; } - if(unnamedParamCount > 0) { - if(namedParamCount > 0) { - errorListener.validationError(parseInfo, - "for federated statement, cannot mix named and unnamed parameters"); - return null; - } - if(unnamedParamCount == 2) { - // first parameter addresses second are the ranges (type defaults to Matrix) - ParameterExpression param = passedParamExprs.get(0); - dataExpr.addFederatedExprParam(DataExpression.FED_ADDRESSES, param.getExpr()); - param = passedParamExprs.get(1); - dataExpr.addFederatedExprParam(DataExpression.FED_RANGES, param.getExpr()); - } - else if(unnamedParamCount == 3) { - ParameterExpression param = passedParamExprs.get(0); - dataExpr.addFederatedExprParam(DataExpression.FED_ADDRESSES, param.getExpr()); - param = passedParamExprs.get(1); - dataExpr.addFederatedExprParam(DataExpression.FED_RANGES, param.getExpr()); - param = passedParamExprs.get(2); - dataExpr.addFederatedExprParam(DataExpression.FED_TYPE, param.getExpr()); - } - else { - errorListener.validationError(parseInfo, - "for federated statement, at most 3 arguments are supported: addresses, ranges, type"); - } + if(unnamedParamCount == 2) { + // first parameter addresses second are the ranges (type defaults to Matrix) + ParameterExpression param = passedParamExprs.get(0); + dataExpr.addFederatedExprParam(DataExpression.FED_ADDRESSES, param.getExpr()); + param = passedParamExprs.get(1); + dataExpr.addFederatedExprParam(DataExpression.FED_RANGES, param.getExpr()); + } + else if(unnamedParamCount == 3) { + ParameterExpression param = passedParamExprs.get(0); + 
dataExpr.addFederatedExprParam(DataExpression.FED_ADDRESSES, param.getExpr()); + param = passedParamExprs.get(1); + dataExpr.addFederatedExprParam(DataExpression.FED_RANGES, param.getExpr()); + param = passedParamExprs.get(2); + dataExpr.addFederatedExprParam(DataExpression.FED_TYPE, param.getExpr()); } else { - for (ParameterExpression passedParamExpr : passedParamExprs) { - dataExpr.addFederatedExprParam(passedParamExpr.getName(), passedParamExpr.getExpr()); - } + errorListener.validationError(parseInfo, + "for federated statement, at most 3 arguments are supported: addresses, ranges, type"); } - dataExpr.setFederatedDefault(); } - - if (dataExpr != null) { - dataExpr.setParseInfo(parseInfo); + else { + for (ParameterExpression passedParamExpr : passedParamExprs) { + dataExpr.addFederatedExprParam(passedParamExpr.getName(), passedParamExpr.getExpr()); + } } + dataExpr.setFederatedDefault(); return dataExpr; - } // end method getBuiltinFunctionExpression + } public void addRandExprParam(String paramName, Expression paramValue) { @@ -544,6 +601,32 @@ public class DataExpression extends DataIdentifier paramValue.setParseInfo(this); addVarParam(paramName,paramValue); } + public void addFrameExprParam(String paramName, Expression paramValue) + { + // check name is valid + boolean found = FRAME_VALID_PARAM_NAMES.contains(paramName); + + if (!found){ + raiseValidateError("unexpected parameter \"" + paramName + + "\". 
Legal parameters for frame statement are " + + "(capitalization-sensitive): " + RAND_DATA + ", " + RAND_ROWS + + ", " + RAND_COLS + ", " + SCHEMAPARAM); + } + if (getVarParam(paramName) != null) { + raiseValidateError("attempted to add frame statement parameter " + paramValue + " more than once"); + } +// TODO convert double Matrix to String Frame + // Process the case where user provides double values to rows or cols +// if (paramName.equals(RAND_ROWS) && paramValue instanceof StringIdentifier) { +// paramValue = new IntIdentifier((long) ((DoubleIdentifier) paramValue).getValue(), this); +// } else if (paramName.equals(RAND_COLS) && paramValue instanceof DoubleIdentifier) { +// paramValue = new IntIdentifier((long) ((DoubleIdentifier) paramValue).getValue(), this); +// } + + // add the parameter to expression list + paramValue.setParseInfo(this); + addVarParam(paramName,paramValue); + } public void addTensorExprParam(String paramName, Expression paramValue) { @@ -641,6 +724,13 @@ public class DataExpression extends DataIdentifier addVarParam(RAND_BY_ROW, new BooleanIdentifier(true, this)); } + public void setFrameDefault(){ + if(getVarParam(RAND_DATA) == null) + addVarParam(RAND_DATA, new StringIdentifier(null, this)); + if (getVarParam(SCHEMAPARAM) == null) + addVarParam(SCHEMAPARAM, new StringIdentifier(DEFAULT_SCHEMAPARAM, this)); + } + public void setTensorDefault(){ if (getVarParam(RAND_BY_ROW) == null) addVarParam(RAND_BY_ROW, new BooleanIdentifier(true, this)); @@ -792,7 +882,7 @@ public class DataExpression extends DataIdentifier } inputParamExpr.validateExpression(ids, currConstVars, conditional); if (s != null && !s.equals(RAND_DATA) && !s.equals(RAND_DIMS) && !s.equals(FED_ADDRESSES) && !s.equals(FED_RANGES) - && !s.equals(DELIM_NA_STRINGS) && getVarParam(s).getOutput().getDataType() != DataType.SCALAR ) { + && !s.equals(DELIM_NA_STRINGS) && !s.equals(SCHEMAPARAM) && getVarParam(s).getOutput().getDataType() != DataType.SCALAR ) { 
raiseValidateError("Non-scalar data types are not supported for data expression.", conditional,LanguageErrorCodes.INVALID_PARAMETERS); } } @@ -804,20 +894,19 @@ public class DataExpression extends DataIdentifier // check if data parameter of matrix is scalar or matrix -- if scalar, use Rand instead Expression dataParam1 = getVarParam(RAND_DATA); if (dataParam1 == null && (getOpCode().equals(DataOp.MATRIX) || getOpCode().equals(DataOp.TENSOR))){ - raiseValidateError("for matrix or tensor, must defined data parameter", conditional, LanguageErrorCodes.INVALID_PARAMETERS); + raiseValidateError("for matrix, frame or tensor, must defined data parameter", conditional, LanguageErrorCodes.INVALID_PARAMETERS); } // We need to remember the operation if we replace the OpCode by rand so we have the correct output - if (dataParam1 != null && dataParam1.getOutput().getDataType() == DataType.SCALAR && + if (dataParam1!=null && dataParam1.getOutput()!=null && dataParam1.getOutput().getDataType() == DataType.SCALAR && (_opcode == DataOp.MATRIX || _opcode == DataOp.TENSOR)/*&& dataParam instanceof ConstIdentifier*/ ){ - //MB: note we should not check for const identifiers here, because otherwise all matrix constructors with + //MB: note we must not check for const identifiers here, because otherwise all matrix constructors with //variable input are routed to a reshape operation (but it works only on matrices and hence, crashes) // replace DataOp MATRIX with RAND -- Rand handles matrix generation for Scalar values // replace data parameter with min / max within Rand case below this.setOpCode(DataOp.RAND); } - - + // IMPORTANT: for each operation, one must handle unnamed parameters switch (this.getOpCode()) { @@ -1732,7 +1821,7 @@ public class DataExpression extends DataIdentifier else { raiseValidateError("In matrix statement, can only assign rows a long " + "(integer) value >= 1 -- attempted to assign value: " + colsExpr.toString(), conditional); - } + } } else if (colsExpr 
instanceof DataIdentifier && !(colsExpr instanceof IndexedIdentifier)) { @@ -1758,7 +1847,6 @@ public class DataExpression extends DataIdentifier } // handle double constant else if (constValue instanceof DoubleIdentifier){ - if (((DoubleIdentifier)constValue).getValue() < 1){ raiseValidateError("In matrix statement, can only assign cols a long " + "(integer) value >= 1 -- attempted to assign value: " @@ -1768,8 +1856,7 @@ public class DataExpression extends DataIdentifier long roundedValue = Double.valueOf(Math.floor(((DoubleIdentifier)constValue).getValue())).longValue(); colsExpr = new IntIdentifier(roundedValue, this); addVarParam(RAND_COLS, colsExpr); - colsLong = roundedValue; - + colsLong = roundedValue; } else { // exception -- rows must be integer or double constant @@ -1781,29 +1868,190 @@ public class DataExpression extends DataIdentifier // handle general expression colsExpr.validateExpression(ids, currConstVars, conditional); } - - } + } else { // handle general expression colsExpr.validateExpression(ids, currConstVars, conditional); } - } + } getOutput().setFileFormat(FileFormat.BINARY); getOutput().setDataType(DataType.MATRIX); getOutput().setValueType(ValueType.FP64); getOutput().setDimensions(rowsLong, colsLong); - + if (getOutput() instanceof IndexedIdentifier){ ((IndexedIdentifier) getOutput()).setOriginalDimensions(getOutput().getDim1(), getOutput().getDim2()); - } - //getOutput().computeDataType(); - - if (getOutput() instanceof IndexedIdentifier){ LOG.warn(this.printWarningLocation() + "Output for matrix Statement may have incorrect size information"); } break; + case FRAME: + //handle default and input arguments + setFrameDefault(); + validateParams(conditional, FRAME_VALID_PARAM_NAMES, + "Legal parameters for frame statement are (case-sensitive): " + + RAND_DATA + ", " + RAND_ROWS + ", " + RAND_COLS + ", " + SCHEMAPARAM); + + //validate correct value types + if (getVarParam(RAND_ROWS) != null && (getVarParam(RAND_ROWS) instanceof 
StringIdentifier || getVarParam(RAND_ROWS) instanceof BooleanIdentifier)){ + raiseValidateError("for frame statement " + RAND_ROWS + " has incorrect value type", conditional); + } + if (getVarParam(RAND_COLS) != null && (getVarParam(RAND_COLS) instanceof StringIdentifier || getVarParam(RAND_COLS) instanceof BooleanIdentifier)){ + raiseValidateError("for frame statement " + RAND_COLS + " has incorrect value type", conditional); + } + + //validate general data expression + getVarParam(RAND_DATA).validateExpression(ids, currConstVars, conditional); + + rowsLong = -1L; + colsLong = -1L; + + /////////////////////////////////////////////////////////////////// + // HANDLE ROWS + /////////////////////////////////////////////////////////////////// + rowsExpr = getVarParam(RAND_ROWS); + if (rowsExpr != null){ + if (rowsExpr instanceof IntIdentifier) { + if (((IntIdentifier)rowsExpr).getValue() >= 1 ) + rowsLong = ((IntIdentifier)rowsExpr).getValue(); + else + raiseValidateError("In frame statement, can only assign rows a long " + + "(integer) value >= 1 -- attempted to assign value: " + ((IntIdentifier)rowsExpr).getValue(), conditional); + } + else if (rowsExpr instanceof DoubleIdentifier) { + if (((DoubleIdentifier)rowsExpr).getValue() >= 1 ) + rowsLong = Double.valueOf((Math.floor(((DoubleIdentifier)rowsExpr).getValue()))).longValue(); + else + raiseValidateError("In frame statement, can only assign rows a long " + + "(integer) value >= 1 -- attempted to assign value: " + rowsExpr.toString(), conditional); + } + else if (rowsExpr instanceof DataIdentifier && !(rowsExpr instanceof IndexedIdentifier)) { + // check if the DataIdentifier variable is a ConstIdentifier + String identifierName = ((DataIdentifier)rowsExpr).getName(); + if (currConstVars.containsKey(identifierName)){ + // handle int constant + ConstIdentifier constValue = currConstVars.get(identifierName); + if (constValue instanceof IntIdentifier){ + // check rows is >= 1 --- throw exception + if 
(((IntIdentifier)constValue).getValue() < 1){ + raiseValidateError("In frame statement, can only assign rows a long " + + "(integer) value >= 1 -- attempted to assign value: " + constValue.toString(), conditional); + } + // update row expr with new IntIdentifier + long roundedValue = ((IntIdentifier)constValue).getValue(); + rowsExpr = new IntIdentifier(roundedValue, this); + addVarParam(RAND_ROWS, rowsExpr); + rowsLong = roundedValue; + } + // handle double constant + else if (constValue instanceof DoubleIdentifier){ + if (((DoubleIdentifier)constValue).getValue() < 1.0){ + raiseValidateError("In frame statement, can only assign rows a long " + + "(integer) value >= 1 -- attempted to assign value: " + constValue.toString(), conditional); + } + // update row expr with new IntIdentifier (rounded down) + long roundedValue = Double.valueOf(Math.floor(((DoubleIdentifier)constValue).getValue())).longValue(); + rowsExpr = new IntIdentifier(roundedValue, this); + addVarParam(RAND_ROWS, rowsExpr); + rowsLong = roundedValue; + } + else { + // exception -- rows must be integer or double constant + raiseValidateError("In frame statement, can only assign rows a long " + + "(integer) value >= 1 -- attempted to assign value: " + constValue.toString(), conditional); + } + } + else { + // handle general expression + rowsExpr.validateExpression(ids, currConstVars, conditional); + } + } + else { + // handle general expression + rowsExpr.validateExpression(ids, currConstVars, conditional); + } + } + + /////////////////////////////////////////////////////////////////// + // HANDLE COLUMNS + /////////////////////////////////////////////////////////////////// + + colsExpr = getVarParam(RAND_COLS); + if (colsExpr != null){ + if (colsExpr instanceof IntIdentifier) { + if (((IntIdentifier)colsExpr).getValue() >= 1 ) + colsLong = ((IntIdentifier)colsExpr).getValue(); + else + raiseValidateError("In frame statement, can only assign cols a long " + + "(integer) value >= 1 -- attempted to 
assign value: " + colsExpr.toString(), conditional); + } + else if (colsExpr instanceof DoubleIdentifier) { + if (((DoubleIdentifier)colsExpr).getValue() >= 1 ) + colsLong = Double.valueOf((Math.floor(((DoubleIdentifier)colsExpr).getValue()))).longValue(); + else + raiseValidateError("In frame statement, can only assign rows a long " + + "(integer) value >= 1 -- attempted to assign value: " + colsExpr.toString(), conditional); + } + else if (colsExpr instanceof DataIdentifier && !(colsExpr instanceof IndexedIdentifier)) { + // check if the DataIdentifier variable is a ConstIdentifier + String identifierName = ((DataIdentifier)colsExpr).getName(); + if (currConstVars.containsKey(identifierName)){ + // handle int constant + ConstIdentifier constValue = currConstVars.get(identifierName); + if (constValue instanceof IntIdentifier){ + // check cols is >= 1 --- throw exception + if (((IntIdentifier)constValue).getValue() < 1){ + raiseValidateError("In frame statement, can only assign cols a long " + + "(integer) value >= 1 -- attempted to assign value: " + + constValue.toString(), conditional); + } + // update col expr with new IntIdentifier + long roundedValue = ((IntIdentifier)constValue).getValue(); + colsExpr = new IntIdentifier(roundedValue, this); + addVarParam(RAND_COLS, colsExpr); + colsLong = roundedValue; + } + // handle double constant + else if (constValue instanceof DoubleIdentifier){ + if (((DoubleIdentifier)constValue).getValue() < 1){ + raiseValidateError("In frame statement, can only assign cols a long " + + "(integer) value >= 1 -- attempted to assign value: " + + constValue.toString(), conditional); + } + // update col expr with new IntIdentifier (rounded down) + long roundedValue = Double.valueOf(Math.floor(((DoubleIdentifier)constValue).getValue())).longValue(); + colsExpr = new IntIdentifier(roundedValue, this); + addVarParam(RAND_COLS, colsExpr); + colsLong = roundedValue; + } + else { + // exception -- rows must be integer or double constant + 
raiseValidateError("In frame statement, can only assign cols a long " + + "(integer) value >= 1 -- attempted to assign value: " + constValue.toString(), conditional); + } + } + else { + // handle general expression + colsExpr.validateExpression(ids, currConstVars, conditional); + } + } + else { + // handle general expression + colsExpr.validateExpression(ids, currConstVars, conditional); + } + } + getOutput().setFileFormat(FileFormat.BINARY); + getOutput().setDataType(DataType.FRAME); + getOutput().setValueType(ValueType.UNKNOWN); + getOutput().setDimensions(rowsLong, colsLong); + + if (getOutput() instanceof IndexedIdentifier){ + ((IndexedIdentifier) getOutput()).setOriginalDimensions(getOutput().getDim1(), getOutput().getDim2()); + LOG.warn(this.printWarningLocation() + "Output for frame Statement may have incorrect size information"); + } + break; + case TENSOR: //handle default and input arguments setTensorDefault(); @@ -1840,9 +2088,8 @@ public class DataExpression extends DataIdentifier if (getOutput() instanceof IndexedIdentifier){ LOG.warn(this.printWarningLocation() + "Output for tensor Statement may have incorrect size information"); } - break; - + case SQL: //handle default and input arguments setSqlDefault(); diff --git a/src/main/java/org/apache/sysds/parser/Expression.java b/src/main/java/org/apache/sysds/parser/Expression.java index 059d093..e7b49ca 100644 --- a/src/main/java/org/apache/sysds/parser/Expression.java +++ b/src/main/java/org/apache/sysds/parser/Expression.java @@ -60,7 +60,7 @@ public abstract class Expression implements ParseInfo * Data operators. 
*/ public enum DataOp { - READ, WRITE, RAND, MATRIX, TENSOR, SQL, FEDERATED + READ, WRITE, RAND, MATRIX, FRAME, TENSOR, SQL, FEDERATED } /** diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index da86189..0a7e28d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -285,6 +285,7 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( DataGen.SINIT_OPCODE , CPType.StringInit); String2CPInstructionType.put( DataGen.SAMPLE_OPCODE , CPType.Rand); String2CPInstructionType.put( DataGen.TIME_OPCODE , CPType.Rand); + String2CPInstructionType.put( DataGen.FRAME_OPCODE , CPType.Rand); String2CPInstructionType.put( "ctable", CPType.Ctable); String2CPInstructionType.put( "ctableexpand", CPType.Ctable); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java index b8fdfe8..4b77bff 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java @@ -284,7 +284,8 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( DataGen.RAND_OPCODE , SPType.Rand); String2SPInstructionType.put( DataGen.SEQ_OPCODE , SPType.Rand); String2SPInstructionType.put( DataGen.SAMPLE_OPCODE, SPType.Rand); - + String2SPInstructionType.put( DataGen.FRAME_OPCODE, SPType.Rand); + //ternary instruction opcodes String2SPInstructionType.put( "ctable", SPType.Ctable); String2SPInstructionType.put( "ctableexpand", SPType.Ctable); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java index 86b95a3..553577e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java @@ -29,12 +29,14 @@ import org.apache.sysds.hops.DataGenOp; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.lops.DataGen; import org.apache.sysds.lops.Lop; +import org.apache.sysds.parser.DataExpression; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.lineage.LineageItem; +import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.RandomMatrixGenerator; @@ -42,6 +44,8 @@ import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.util.DataConverter; import org.apache.sysds.runtime.util.UtilFunctions; +import java.util.Arrays; +import java.util.Random; public class DataGenCPInstruction extends UnaryCPInstruction { private static final Log LOG = LogFactory.getLog(DataGenCPInstruction.class.getName()); @@ -52,7 +56,7 @@ public class DataGenCPInstruction extends UnaryCPInstruction { private boolean minMaxAreDoubles; private final String minValueStr, maxValueStr; private final double minValue, maxValue, sparsity; - private final String pdf, pdfParams; + private final String pdf, pdfParams, frame_data, schema; private final long seed; private Long runtimeSeed; @@ -67,10 +71,11 @@ public class DataGenCPInstruction extends UnaryCPInstruction { private static final int SEED_POSITION_RAND = 8; 
private static final int SEED_POSITION_SAMPLE = 4; - private DataGenCPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, - CPOperand rows, CPOperand cols, CPOperand dims, int blen, String minValue, String maxValue, double sparsity, long seed, - String probabilityDensityFunction, String pdfParams, int k, - CPOperand seqFrom, CPOperand seqTo, CPOperand seqIncr, boolean replace, String opcode, String istr) { + private DataGenCPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, + CPOperand rows, CPOperand cols, CPOperand dims, int blen, String minValue, String maxValue, double sparsity, + long seed, String probabilityDensityFunction, String pdfParams, int k, CPOperand seqFrom, CPOperand seqTo, + CPOperand seqIncr, boolean replace, String data, String schema, String opcode, String istr) + { super(CPType.Rand, op, in, out, opcode, istr); this.method = mthd; this.rows = rows; @@ -107,29 +112,38 @@ public class DataGenCPInstruction extends UnaryCPInstruction { this.seq_to = seqTo; this.seq_incr = seqIncr; this.replace = replace; + this.frame_data = data; + this.schema = schema; } private DataGenCPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows, CPOperand cols, CPOperand dims, int blen, String minValue, String maxValue, double sparsity, long seed, - String probabilityDensityFunction, String pdfParams, int k, String opcode, String istr) { + String probabilityDensityFunction, String pdfParams, int k, String opcode, String istr) { this(op, mthd, in, out, rows, cols, dims, blen, minValue, maxValue, sparsity, seed, - probabilityDensityFunction, pdfParams, k, null, null, null, false, opcode, istr); + probabilityDensityFunction, pdfParams, k, null, null, null, + false, null, null, opcode, istr); } private DataGenCPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows, CPOperand cols, CPOperand dims, int blen, String maxValue, boolean replace, long seed, String opcode, String istr) { 
this(op, mthd, in, out, rows, cols, dims, blen, "0", maxValue, 1.0, seed, - null, null, 1, null, null, null, replace, opcode, istr); + null, null, 1, null, null, null, replace, null, null, opcode, istr); } private DataGenCPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows, CPOperand cols, CPOperand dims, int blen, CPOperand seqFrom, CPOperand seqTo, CPOperand seqIncr, String opcode, String istr) { this(op, mthd, in, out, rows, cols, dims, blen, "0", "1", 1.0, -1, - null, null, 1, seqFrom, seqTo, seqIncr, false, opcode, istr); + null, null, 1, seqFrom, seqTo, seqIncr, false, null, null, opcode, istr); } private DataGenCPInstruction(Operator op, OpOpDG mthd, CPOperand out, String opcode, String istr) { this(op, mthd, null, out, null, null, null, 0, "0", "0", 0, 0, - null, null, 1, null, null, null, false, opcode, istr); + null, null, 1, null, null, null, false, null, null, opcode, istr); + } + + public DataGenCPInstruction(Operator op, OpOpDG method, CPOperand out, CPOperand rows, CPOperand cols, String data, + String schema, String opcode, String str) { + this(op, method, null, out, rows, cols, null, 0, "0", "0", 0, 0, + null, null, 1, null, null, null, false, data, schema, opcode, str); } public long getRows() { @@ -217,6 +231,10 @@ public class DataGenCPInstruction extends UnaryCPInstruction { // 1 operand: outvar InstructionUtils.checkNumFields ( s, 1 ); } + else if ( opcode.equalsIgnoreCase(DataGen.FRAME_OPCODE) ) { + method = OpOpDG.FRAMEINIT; + InstructionUtils.checkNumFields ( s, 5 ); + } CPOperand out = new CPOperand(s[s.length-1]); Operator op = null; @@ -247,15 +265,23 @@ public class DataGenCPInstruction extends UnaryCPInstruction { return new DataGenCPInstruction(op, method, null, out, rows, cols, dims, blen, s[5 - missing], s[6 - missing], sparsity, seed, pdf, pdfParams, k, opcode, str); } - else if ( method == OpOpDG.SEQ) + else if ( method == OpOpDG.SEQ) { int blen = Integer.parseInt(s[3]); CPOperand from = new 
CPOperand(s[4]); CPOperand to = new CPOperand(s[5]); CPOperand incr = new CPOperand(s[6]); - + return new DataGenCPInstruction(op, method, null, out, null, null, null, blen, from, to, incr, opcode, str); } + else if ( method == OpOpDG.FRAMEINIT) + { + String data = s[1]; + CPOperand rows = new CPOperand(s[2]); + CPOperand cols = new CPOperand(s[3]); + String valueType = s[4]; + return new DataGenCPInstruction(op, method, out, rows, cols, data, valueType, opcode, str); + } else if ( method == OpOpDG.SAMPLE) { CPOperand rows = new CPOperand(s[2]); @@ -282,6 +308,7 @@ public class DataGenCPInstruction extends UnaryCPInstruction { MatrixBlock soresBlock = null; TensorBlock tensorBlock = null; ScalarObject soresScalar = null; + FrameBlock soresFrame = null; //process specific datagen operator if ( method == OpOpDG.RAND ) { @@ -369,7 +396,41 @@ public class DataGenCPInstruction extends UnaryCPInstruction { else if ( method == OpOpDG.TIME ) { soresScalar = new IntObject(System.nanoTime()); } - + else if(method == OpOpDG.FRAMEINIT) + { + int lrows = (int) ec.getScalarInput(rows).getLongValue(); + int lcols = (int) ec.getScalarInput(cols).getLongValue(); + String schemaValues[] = schema.split(DataExpression.DELIM_NA_STRING_SEP); + ValueType[] vt = schemaValues[0].equals(DataExpression.DEFAULT_SCHEMAPARAM) ? 
+ UtilFunctions.nCopies(lcols, ValueType.STRING) : + UtilFunctions.stringToValueType(schemaValues); + int schemaLength = vt.length; + if(schemaLength != lcols) + throw new DMLRuntimeException("schema-dimension mismatch"); + + if(frame_data.equals("")) { + //TODO fix hard-coded seed, consistently with sparse frame init + soresFrame = UtilFunctions.generateRandomFrameBlock(lrows, lcols, vt, new Random(10)); + } + else { + String[] data = frame_data.split(DataExpression.DELIM_NA_STRING_SEP); + if(data.length != schemaLength && data.length > 1) + throw new DMLRuntimeException("data values should be equal to number of columns," + + " or a single values for all columns"); + if(data.length > 1) { + soresFrame = new FrameBlock(vt); + for(int i = 0; i < lrows; i++) + soresFrame.appendRow(data); + } + else { + soresFrame = new FrameBlock(vt); + String[] data1 = new String[lcols]; + Arrays.fill(data1, frame_data); + for(int i = 0; i < lrows; i++) + soresFrame.appendRow(data1); + } + } + } if( output.isMatrix() ) { //guarded sparse block representation change if( soresBlock.getInMemorySize() < OptimizerUtils.SAFE_REP_CHANGE_THRES ) @@ -386,6 +447,8 @@ public class DataGenCPInstruction extends UnaryCPInstruction { } else if( output.isScalar() ) ec.setScalarOutput(output.getName(), soresScalar); + else if (output.isFrame()) + ec.setFrameOutput(output.getName(), soresFrame); } private static void checkValidDimensions(long rows, long cols) { @@ -444,4 +507,5 @@ public class DataGenCPInstruction extends UnaryCPInstruction { new CPOperand(ec.getScalarInput(op)).getLineageLiteral()); return inst; } + } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java index dac0aa5..3d0c2e5 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java @@ 
-50,6 +50,7 @@ import org.apache.sysds.hops.DataGenOp; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.lops.DataGen; import org.apache.sysds.lops.Lop; +import org.apache.sysds.parser.DataExpression; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; @@ -62,6 +63,7 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.spark.utils.RDDConverterUtils; import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.lineage.LineageItem; +import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixCell; @@ -90,7 +92,7 @@ public class RandSPInstruction extends UnarySPInstruction { private final double minValue, maxValue; private final String minValueStr, maxValueStr; private final double sparsity; - private final String pdf, pdfParams; + private final String pdf, pdfParams, frame_data, schema; private long seed = 0; private final String dir; private final CPOperand seq_from, seq_to, seq_incr; @@ -104,9 +106,10 @@ public class RandSPInstruction extends UnarySPInstruction { private static final int SEED_POSITION_SAMPLE = 4; private RandSPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows, - CPOperand cols, CPOperand dims, int blen, String minValue, String maxValue, - double sparsity, long seed, String dir, String probabilityDensityFunction, String pdfParams, - CPOperand seqFrom, CPOperand seqTo, CPOperand seqIncr, boolean replace, String opcode, String istr) + CPOperand cols, CPOperand dims, int blen, String minValue, String maxValue, double sparsity, + long seed, String dir, String probabilityDensityFunction, String pdfParams, CPOperand 
seqFrom, + CPOperand seqTo, CPOperand seqIncr, boolean replace, String fdata, + String schema, String opcode, String istr) { super(SPType.Rand, op, in, out, opcode, istr); this._method = mthd; @@ -131,7 +134,6 @@ public class RandSPInstruction extends UnarySPInstruction { } minDouble = -1; maxDouble = -1; - //minMaxAreDoubles = false; } this.minValue = minDouble; this.maxValue = maxDouble; @@ -144,6 +146,8 @@ public class RandSPInstruction extends UnarySPInstruction { this.seq_to = seqTo; this.seq_incr = seqIncr; this.replace = replace; + this.frame_data = fdata; + this.schema = schema; } private RandSPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows, @@ -151,20 +155,30 @@ public class RandSPInstruction extends UnarySPInstruction { String dir, String probabilityDensityFunction, String pdfParams, String opcode, String istr) { this(op, mthd, in, out, rows, cols, dims, blen, minValue, maxValue, sparsity, seed, dir, - probabilityDensityFunction, pdfParams, null, null, null, false, opcode, istr); + probabilityDensityFunction, pdfParams, null, null, + null, false, null, null, opcode, istr); } private RandSPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows, CPOperand cols, CPOperand dims, int blen, CPOperand seqFrom, CPOperand seqTo, CPOperand seqIncr, String opcode, String istr) { this(op, mthd, in, out, rows, cols, dims, blen, "-1", "-1", -1, -1, null, - null, null, seqFrom, seqTo, seqIncr, false, opcode, istr); + null, null, seqFrom, seqTo, seqIncr, false, + null, null, opcode, istr); } private RandSPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows, CPOperand cols, CPOperand dims, int blen, String maxValue, boolean replace, long seed, String opcode, String istr) { this(op, mthd, in, out, rows, cols, dims, blen, "-1", maxValue, -1, seed, null, - null, null, null, null, null, replace, opcode, istr); + null, null, null, null, null, replace, + null, null, opcode, istr); + } + 
+ private RandSPInstruction(Operator op, OpOpDG mthd, CPOperand out, CPOperand rows, + CPOperand cols, String fdata, String schema, String opcode, String istr) { + this(op, mthd, null, out, rows, cols, null, 0, "0", "1", 0, + 0, null,null, null, null, null, + null, false,fdata, schema, opcode, istr); } public long getRows() { @@ -224,6 +238,11 @@ public class RandSPInstruction extends UnarySPInstruction { // 7 operands: range, size, replace, seed, blen, outvar InstructionUtils.checkNumFields ( str, 6 ); } + else if ( opcode.equalsIgnoreCase(DataGen.FRAME_OPCODE) ) { + method = OpOpDG.FRAMEINIT; + InstructionUtils.checkNumFields ( str, 6 ); + } + Operator op = null; // output is specified by the last operand @@ -264,8 +283,7 @@ public class RandSPInstruction extends UnarySPInstruction { return new RandSPInstruction(op, method, in, out, null, null, null, blen, from, to, incr, opcode, str); } - else if ( method == OpOpDG.SAMPLE) - { + else if ( method == OpOpDG.SAMPLE) { String max = !s[1].contains(Lop.VARIABLE_NAME_PLACEHOLDER) ? 
s[1] : "0"; CPOperand rows = new CPOperand(s[2]); @@ -279,6 +297,14 @@ public class RandSPInstruction extends UnarySPInstruction { return new RandSPInstruction(op, method, null, out, rows, cols, null, blen, max, replace, seed, opcode, str); } + else if ( method == OpOpDG.FRAMEINIT) { + String data = s[1]; + CPOperand rows = new CPOperand(s[2]); + CPOperand cols = new CPOperand(s[3]); + String valueType = s[4]; + return new RandSPInstruction(op, method, out, rows, cols, data, valueType, opcode, str); + } + else throw new DMLRuntimeException("Unrecognized data generation method: " + method); } @@ -292,17 +318,101 @@ public class RandSPInstruction extends UnarySPInstruction { case RAND: generateRandData(sec); break; case SEQ: generateSequence(sec); break; case SAMPLE: generateSample(sec); break; + case FRAMEINIT: generateFrame(sec); break; default: throw new DMLRuntimeException("Invalid datagen method: "+_method); } } + private void generateFrame(SparkExecutionContext sec) { + long lrows = sec.getScalarInput(rows).getLongValue(); + long lcols = sec.getScalarInput(cols).getLongValue(); + String data = frame_data; + + //step 1: generate pseudo-random seed (because not specified) + long lSeed = generateRandomSeed(); + + if( LOG.isTraceEnabled() ) + LOG.trace("Process RandSPInstruction frame with seed = "+lSeed+"."); + + //step 2: seed generation + JavaPairRDD<Long, Long> seedsRDD = null; + Well1024a bigrand = LibMatrixDatagen.setupSeedsForRand(lSeed); + double totalSize = OptimizerUtils.estimatePartitionedSizeExactSparsity( lrows, lcols, -1, 1); + double hdfsBlkSize = InfrastructureAnalyzer.getHDFSBlockSize(); + int brlen = ConfigurationManager.getBlocksize(); + DataCharacteristics tmp = new MatrixCharacteristics(lrows, lcols, brlen); + + //a) in-memory seed rdd construction + if( tmp.getNumRowBlocks() < INMEMORY_NUMBLOCKS_THRESHOLD ) + { + ArrayList<Tuple2<Long, Long>> seeds = new ArrayList<>(); + for( long i=0; i<tmp.getNumRowBlocks(); i++ ) { + Long seedForBlock = 
bigrand.nextLong(); + seeds.add(new Tuple2<>(i*brlen+1, seedForBlock)); + } + + //for load balancing: degree of parallelism such that ~128MB per partition + int numPartitions = (int) Math.max(Math.min(totalSize/hdfsBlkSize, tmp.getNumRowBlocks()), 1); + + //create seeds rdd + seedsRDD = sec.getSparkContext().parallelizePairs(seeds, numPartitions); + } + //b) file-based seed rdd construction (for robustness wrt large number of blocks) + else + { + Path path = new Path(LibMatrixDatagen.generateUniqueSeedPath(dir)); + PrintWriter pw = null; + try + { + FileSystem fs = IOUtilFunctions.getFileSystem(path); + pw = new PrintWriter(fs.create(path)); + StringBuilder sb = new StringBuilder(); + for( long i=0; i<tmp.getNumRowBlocks(); i++ ) { + sb.append(i*brlen+1); + sb.append(','); + sb.append(bigrand.nextLong()); + pw.println(sb.toString()); + sb.setLength(0); + } + } + catch( IOException ex ) { + throw new DMLRuntimeException(ex); + } + finally { + IOUtilFunctions.closeSilently(pw); + } + + //for load balancing: degree of parallelism such that ~128MB per partition + int numPartitions = (int) Math.max(Math.min(totalSize/hdfsBlkSize, tmp.getNumRowBlocks()), 1); + + //create seeds rdd + seedsRDD = sec.getSparkContext() + .textFile(path.toString(), numPartitions) + .mapToPair(new ExtractFrameSeedTuple()); + } + + //prepare input arguments + String schemaValues[] = schema.split(DataExpression.DELIM_NA_STRING_SEP); + ValueType[] vt = (schemaValues[0].equals(DataExpression.DEFAULT_SCHEMAPARAM)) ? 
+ UtilFunctions.nCopies((int)lcols, ValueType.STRING) : + UtilFunctions.stringToValueType(schemaValues); + if(vt.length != lcols) + throw new DMLRuntimeException("schema-dimension mismatch: "+vt.length+" vs "+lcols); + + //step 4: execute rand instruction over seed input + JavaPairRDD<Long, FrameBlock> out = seedsRDD + .mapToPair(new GenerateRandomFrameBlock(lrows, lcols, brlen, vt, data)); + + //step 5: output handling + sec.setRDDHandleForVariable(output.getName(), out); + } + private void generateRandData(SparkExecutionContext sec) { - if (output.getDataType() == DataType.MATRIX) { + if (output.getDataType() == DataType.MATRIX) generateRandDataMatrix(sec); - } else { + else generateRandDataTensor(sec); - } //reset runtime seed (e.g., when executed in loop) runtimeSeed = null; } @@ -800,6 +910,18 @@ public class RandSPInstruction extends UnarySPInstruction { } } + private static class ExtractFrameSeedTuple implements PairFunction<String, Long, Long> { + private static final long serialVersionUID = 3973794676854157100L; + + @Override + public Tuple2<Long, Long> call(String arg) + throws Exception + { + String[] parts = IOUtilFunctions.split(arg, ","); + Long ix = Long.parseLong(parts[0]); + return new Tuple2<>(ix,Long.parseLong(parts[1])); + } + } private static class ExtractMatrixSeedTuple implements PairFunction<String, MatrixIndexes, Long> { private static final long serialVersionUID = 3973794676854157101L; @@ -836,7 +958,62 @@ public class RandSPInstruction extends UnarySPInstruction { return Double.parseDouble(arg); } } + private static class GenerateRandomFrameBlock implements PairFunction<Tuple2<Long, Long>, Long, FrameBlock> + { + private static final long serialVersionUID = 1616346120426470173L; + + private final long _rlen; + private final long _clen; + private final int _brlen; + private final ValueType[] _schema; + private final String _data; + public GenerateRandomFrameBlock(long rlen, long clen, int brlen, ValueType[] schema, String fdata) { + _rlen 
= rlen; + _clen = clen; + _brlen = brlen; + _schema = schema; + _data = fdata; + } + + @Override + public Tuple2<Long, FrameBlock> call(Tuple2<Long, Long> kv) + throws Exception + { + //compute local block size: + Long ix = kv._1(); + long blockix = UtilFunctions.computeBlockIndex(ix, _brlen); + int lrlen = UtilFunctions.computeBlockSize(_rlen, blockix, _brlen); + //long seed = kv._2; + + FrameBlock out = null; + if(_data.equals("")) { + //TODO fix hard-coded seed + out = UtilFunctions.generateRandomFrameBlock((int)_rlen, (int)_clen, _schema, new Random(10)); + } + else { + String[] data = _data.split(DataExpression.DELIM_NA_STRING_SEP); + if(data.length != _schema.length && data.length > 1) + throw new DMLRuntimeException("data values should be equal " + + "to number of columns, or a single values for all columns"); + if(data.length > 1) { + out = new FrameBlock(_schema); + for(int i = 0; i < lrlen; i++) + out.appendRow(data); + } + else { + out = new FrameBlock(_schema); + String[] data1 = new String[(int)_clen]; + Arrays.fill(data1, _data); + for(int i = 0; i < lrlen; i++) + out.appendRow(data1); + } + } + + return new Tuple2<>(kv._1, out); + } + } + private static class GenerateRandomBlock implements PairFunction<Tuple2<MatrixIndexes, Long>, MatrixIndexes, MatrixBlock> { private static final long serialVersionUID = 1616346120426470173L; @@ -911,7 +1088,7 @@ public class RandSPInstruction extends UnarySPInstruction { @Override public Tuple2<TensorIndexes, TensorBlock> call(Tuple2<TensorIndexes, Long> kv) - throws Exception + throws Exception { //compute local block size: TensorIndexes ix = kv._1(); diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java index f98cceb..a7fdaf4 100644 --- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java +++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java @@ -105,7 +105,7 @@ public class UtilFunctions { public 
static int nextIntPow2( int in ) { int expon = (in==0) ? 0 : 32-Integer.numberOfLeadingZeros(in-1); long pow2 = pow(2, expon); - return (int)((pow2>Integer.MAX_VALUE)?Integer.MAX_VALUE : pow2); + return (int)((pow2>Integer.MAX_VALUE)?Integer.MAX_VALUE : pow2); } public static long pow(int base, int exp) { @@ -835,5 +835,74 @@ public class UtilFunctions { .map(DATE_FORMATS::get).orElseThrow(() -> new NullPointerException("Unknown date format.")); } + /** + * Generates a random FrameBlock with given parameters. + * + * @param rows frame rows + * @param cols frame cols + * @param schema frame schema + * @param random random number generator + * @return FrameBlock + */ + public static FrameBlock generateRandomFrameBlock(int rows, int cols, ValueType[] schema, Random random){ + String[] names = new String[cols]; + for(int i = 0; i < cols; i++) + names[i] = schema[i].toString(); + FrameBlock frameBlock = new FrameBlock(schema, names); + frameBlock.ensureAllocatedColumns(rows); + for(int row = 0; row < rows; row++) + for(int col = 0; col < cols; col++) + frameBlock.set(row, col, generateRandomValueFromValueType(schema[col], random)); + return frameBlock; + } + /** + * Generates a random value for a given Value Type + * + * @param valueType the ValueType of which to generate the value + * @param random random number generator + * @return Object + */ + public static Object generateRandomValueFromValueType(ValueType valueType, Random random){ + switch (valueType){ + case FP32: return random.nextFloat(); + case FP64: return random.nextDouble(); + case INT32: return random.nextInt(); + case INT64: return random.nextLong(); + case BOOLEAN: return random.nextBoolean(); + case STRING: + return random.ints('a', 'z' + 1).limit(10) + .collect(StringBuilder::new, StringBuilder::appendCodePoint, StringBuilder::append) + .toString(); + default: + return null; + } + } + + /** + * Generates a ValueType array from a String array + * + * @param schemaValues the string schema of which to 
generate the ValueType + * @return ValueType[] + */ + public static ValueType[] stringToValueType(String[] schemaValues) { + ValueType[] vt = new ValueType[schemaValues.length]; + for(int i=0; i < schemaValues.length; i++) { + if(schemaValues[i].equalsIgnoreCase("STRING")) + vt[i] = ValueType.STRING; + else if (schemaValues[i].equalsIgnoreCase("FP64")) + vt[i] = ValueType.FP64; + else if (schemaValues[i].equalsIgnoreCase("FP32")) + vt[i] = ValueType.FP32; + else if (schemaValues[i].equalsIgnoreCase("INT64")) + vt[i] = ValueType.INT64; + else if (schemaValues[i].equalsIgnoreCase("INT32")) + vt[i] = ValueType.INT32; + else if (schemaValues[i].equalsIgnoreCase("BOOLEAN")) + vt[i] = ValueType.BOOLEAN; + else + throw new DMLRuntimeException("Invalid column schema. Allowed values are STRING, FP64, FP32, INT64, INT32 and Boolean"); + } + return vt; + } } diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java index 4c4dc20..d4ddfe4 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java @@ -38,19 +38,24 @@ public class BuiltinDBSCANTest extends AutomatedTestBase private final static double eps = 1e-3; private final static int rows = 1700; - //private final static double spDense = 0.99; private final static double epsDBSCAN = 1; private final static int minPts = 5; @Override - public void setUp() { addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"})); } + public void setUp() { + addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"})); + } @Test - public void testDBSCANDefaultCP() { runDBSCAN(true, ExecType.CP); } + public void testDBSCANDefaultCP() { + runDBSCAN(true, ExecType.CP); + } @Test - public void testDBSCANDefaultSP() { runDBSCAN(true, 
ExecType.SPARK); } + public void testDBSCANDefaultSP() { + runDBSCAN(true, ExecType.SPARK); + } private void runDBSCAN(boolean defaultProb, ExecType instType) { diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameConstructorTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameConstructorTest.java new file mode 100644 index 0000000..ef9e8a6 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameConstructorTest.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.frame; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.io.FrameReaderFactory; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.TestConfiguration; +import org.junit.Test; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.matrix.data.FrameBlock; +import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; + +import java.util.Random; + +public class FrameConstructorTest extends AutomatedTestBase { + private final static String TEST_DIR = "functions/frame/"; + private final static String TEST_NAME = "FrameConstructorTest"; + private final static String TEST_CLASS_DIR = TEST_DIR + FrameConstructorTest.class.getSimpleName() + "/"; + + private final static int rows = 40; + private final static int cols = 4; + + private final static ValueType[] schemaStrings1 = new ValueType[]{ + ValueType.INT64, ValueType.STRING, ValueType.FP64, ValueType.BOOLEAN}; + + private final static ValueType[] schemaStrings2 = new ValueType[]{ + ValueType.INT64, ValueType.STRING, ValueType.FP64, ValueType.STRING}; + + private enum TestType { + NAMED, + NO_SCHEMA, + RANDOM_DATA, + SINGLE_DATA + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"B"})); + if (TEST_CACHE_ENABLED) { + setOutAndExpectedDeletionDisabled(true); + } + } + + @Test + public void testFrameNamedParam() { + FrameBlock exp = createExpectedFrame(schemaStrings1, false); + runFrameTest(TestType.NAMED, exp, Types.ExecMode.SINGLE_NODE); + } + + @Test + public void testFrameNamedParamSP() { + FrameBlock exp = createExpectedFrame(schemaStrings1, false); + runFrameTest(TestType.NAMED, exp, 
Types.ExecMode.SPARK); + } + + @Test + public void testNoSchema() { + FrameBlock exp = createExpectedFrame(schemaStrings2, false); + runFrameTest(TestType.NO_SCHEMA, exp, Types.ExecMode.SINGLE_NODE); + } + + @Test + public void testNoSchemaSP() { + FrameBlock exp = createExpectedFrame(schemaStrings2, false); + runFrameTest(TestType.NO_SCHEMA, exp, Types.ExecMode.SPARK); + } + + @Test + public void testRandData() { + FrameBlock exp = UtilFunctions.generateRandomFrameBlock(rows, cols, schemaStrings1, new Random(10)); + runFrameTest(TestType.RANDOM_DATA, exp, Types.ExecMode.SINGLE_NODE); + } + + @Test + public void testRandDataSP() { + FrameBlock exp = UtilFunctions.generateRandomFrameBlock(rows, cols, schemaStrings1, new Random(10)); + runFrameTest(TestType.RANDOM_DATA, exp, Types.ExecMode.SPARK); + } + + @Test + public void testSingleData() { + FrameBlock exp = createExpectedFrame(schemaStrings1, true); + runFrameTest(TestType.SINGLE_DATA, exp, Types.ExecMode.SINGLE_NODE); + } + + @Test + public void testSingleDataSP() { + FrameBlock exp = createExpectedFrame(schemaStrings1, true); + runFrameTest(TestType.SINGLE_DATA, exp, Types.ExecMode.SPARK); + } + + private void runFrameTest(TestType type, FrameBlock expectedOutput, Types.ExecMode et) { + Types.ExecMode platformOld = setExecMode(et); + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + + try { + //setup testcase + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-explain", "-args", String.valueOf(type), output("F2")}; + + runTest(true, false, null, -1); + FrameBlock fB = FrameReaderFactory + .createFrameReader(Types.FileFormat.CSV) + .readFrameFromHDFS(output("F2"), rows, cols); + String[][] R1 = DataConverter.convertToStringFrame(expectedOutput); + String[][] R2 = DataConverter.convertToStringFrame(fB); + 
TestUtils.compareFrames(R1, R2, R1.length, R1[0].length); + } + catch(Exception ex) { + throw new RuntimeException(ex); + } + finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true; + OptimizerUtils.ALLOW_OPERATOR_FUSION = true; + } + } + + private static FrameBlock createExpectedFrame(ValueType[] schema, boolean constant) { + FrameBlock exp = new FrameBlock(schema); + String[] out = constant ? + new String[]{"1", "1", "1", "1"} : + new String[]{"1", "abc", "2.5", "TRUE"}; + for(int i=0; i<rows; i++) + exp.appendRow(out); + return exp; + } +} diff --git a/src/test/scripts/functions/frame/FrameConstructorTest.dml b/src/test/scripts/functions/frame/FrameConstructorTest.dml new file mode 100644 index 0000000..8762d30 --- /dev/null +++ b/src/test/scripts/functions/frame/FrameConstructorTest.dml @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +print("param 1 "+$1) +if($1 == "NAMED") + f1 = frame(data=["1", "abc", "2.5", "TRUE"], rows=40, cols=4, schema=["INT64", "STRING", "FP64", "BOOLEAN"]) # all named +if($1 == "NO_SCHEMA") + f1 = frame(data=["1", "abc", "2.5", "TRUE"], rows=40, cols=4) # no schema +if($1 == "RANDOM_DATA") + f1 = frame("", rows=40, cols=4, schema=["INT64", "STRING", "FP64", "BOOLEAN"]) # empty data string: cell values are randomly generated +if($1 == "SINGLE_DATA") + f1 = frame(1, rows=40, cols=4, schema=["INT64", "STRING", "FP64", "BOOLEAN"]) # single data value, replicated to every column + +# f1 = frame(1, 4, 3) # unnamed parameters not working +write(f1, $2, format="csv")