[SYSTEMML-2214] New builtin functions lower.tri and upper.tri (CP/SP)

This patch introduces the new builtin functions lower.tri and upper.tri, which extract the lower and upper triangular parts of a matrix, for both the CP and SP backends, along with their compiler and runtime integration.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/4eb1b935 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/4eb1b935 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/4eb1b935 Branch: refs/heads/master Commit: 4eb1b935b66dd787a9d417cbc128b5113d49c461 Parents: 17ccc09 Author: Matthias Boehm <[email protected]> Authored: Tue Mar 27 20:01:52 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Tue Mar 27 20:01:52 2018 -0700 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/Hop.java | 5 +- .../sysml/hops/ParameterizedBuiltinOp.java | 9 + .../apache/sysml/lops/ParameterizedBuiltin.java | 16 +- .../org/apache/sysml/parser/DMLTranslator.java | 12 +- .../org/apache/sysml/parser/Expression.java | 2 +- .../ParameterizedBuiltinFunctionExpression.java | 83 ++++-- .../java/org/apache/sysml/parser/dml/Dml.g4 | 1 + .../java/org/apache/sysml/parser/pydml/Pydml.g4 | 1 + .../functionobjects/ParameterizedBuiltin.java | 10 +- .../instructions/CPInstructionParser.java | 28 +- .../instructions/SPInstructionParser.java | 14 +- .../cp/ParameterizedBuiltinCPInstruction.java | 18 +- .../ParameterizedBuiltinSPInstruction.java | 104 ++++++-- .../sysml/runtime/matrix/data/MatrixBlock.java | 63 ++++- .../unary/matrix/ExtractTriangularTest.java | 267 +++++++++++++++++++ .../functions/unary/matrix/ExtractLowerTri.R | 34 +++ .../functions/unary/matrix/ExtractLowerTri.dml | 24 ++ .../functions/unary/matrix/ExtractUpperTri.R | 34 +++ .../functions/unary/matrix/ExtractUpperTri.dml | 24 ++ .../functions/unary/matrix/ZPackageSuite.java | 1 + 20 files changed, 669 insertions(+), 81 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- 
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index dd168d7..1a93a27 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -1114,7 +1114,8 @@ public abstract class Hop implements ParseInfo } public enum ParamBuiltinOp { - INVALID, CDF, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND, + INVALID, CDF, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND, + LOWER_TRI, UPPER_TRI, TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA, TOSTRING } @@ -1418,6 +1419,8 @@ public abstract class Hop implements ParseInfo HopsParameterizedBuiltinLops.put(ParamBuiltinOp.RMEMPTY, org.apache.sysml.lops.ParameterizedBuiltin.OperationTypes.RMEMPTY); HopsParameterizedBuiltinLops.put(ParamBuiltinOp.REPLACE, org.apache.sysml.lops.ParameterizedBuiltin.OperationTypes.REPLACE); HopsParameterizedBuiltinLops.put(ParamBuiltinOp.REXPAND, org.apache.sysml.lops.ParameterizedBuiltin.OperationTypes.REXPAND); + HopsParameterizedBuiltinLops.put(ParamBuiltinOp.LOWER_TRI, org.apache.sysml.lops.ParameterizedBuiltin.OperationTypes.LOWER_TRI); + HopsParameterizedBuiltinLops.put(ParamBuiltinOp.UPPER_TRI, org.apache.sysml.lops.ParameterizedBuiltin.OperationTypes.UPPER_TRI); HopsParameterizedBuiltinLops.put(ParamBuiltinOp.TRANSFORMAPPLY, org.apache.sysml.lops.ParameterizedBuiltin.OperationTypes.TRANSFORMAPPLY); HopsParameterizedBuiltinLops.put(ParamBuiltinOp.TRANSFORMDECODE, org.apache.sysml.lops.ParameterizedBuiltin.OperationTypes.TRANSFORMDECODE); HopsParameterizedBuiltinLops.put(ParamBuiltinOp.TRANSFORMCOLMAP, org.apache.sysml.lops.ParameterizedBuiltin.OperationTypes.TRANSFORMCOLMAP); http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java 
b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java index 1010db1..b94ff5c 100644 --- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java +++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java @@ -185,6 +185,8 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop case CDF: case INVCDF: case REPLACE: + case LOWER_TRI: + case UPPER_TRI: case TRANSFORMAPPLY: case TRANSFORMDECODE: case TRANSFORMCOLMAP: @@ -1117,6 +1119,13 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop setNnz( target.getNnz() ); break; } + case LOWER_TRI: + case UPPER_TRI: { + Hop target = getTargetHop(); + setDim1(target.getDim1()); + setDim2(target.getDim2()); + break; + } case REPLACE: { //dimensions are exactly known from input, sparsity might increase/decrease if pattern/replacement 0 Hop target = getTargetHop(); http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/main/java/org/apache/sysml/lops/ParameterizedBuiltin.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/ParameterizedBuiltin.java b/src/main/java/org/apache/sysml/lops/ParameterizedBuiltin.java index 1670412..f011ba4 100644 --- a/src/main/java/org/apache/sysml/lops/ParameterizedBuiltin.java +++ b/src/main/java/org/apache/sysml/lops/ParameterizedBuiltin.java @@ -36,7 +36,7 @@ import org.apache.sysml.parser.Expression.ValueType; public class ParameterizedBuiltin extends Lop { public enum OperationTypes { - CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, + CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI, TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA, TOSTRING } @@ -176,6 +176,20 @@ public class ParameterizedBuiltin extends Lop break; } + case LOWER_TRI: { + sb.append( "lowertri" ); + sb.append( OPERAND_DELIMITOR ); + sb.append(compileGenericParamMap(_inputParams)); + break; + } + + case UPPER_TRI: { + sb.append( "uppertri" ); + 
sb.append( OPERAND_DELIMITOR ); + sb.append(compileGenericParamMap(_inputParams)); + break; + } + case REXPAND: sb.append("rexpand"); sb.append(OPERAND_DELIMITOR); http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index e0510b5..3250883 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -2095,15 +2095,23 @@ public class DMLTranslator target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.REPLACE, paramHops); break; + case LOWER_TRI: + currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), + target.getValueType(), ParamBuiltinOp.LOWER_TRI, paramHops); + break; + + case UPPER_TRI: + currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), + target.getValueType(), ParamBuiltinOp.UPPER_TRI, paramHops); + break; + case ORDER: ArrayList<Hop> inputs = new ArrayList<>(); inputs.add(paramHops.get("target")); inputs.add(paramHops.get("by")); inputs.add(paramHops.get("decreasing")); inputs.add(paramHops.get("index.return")); - currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), ReOrgOp.SORT, inputs); - break; case TRANSFORMAPPLY: http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/main/java/org/apache/sysml/parser/Expression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java index dbbe8b8..e20a908 100644 --- a/src/main/java/org/apache/sysml/parser/Expression.java +++ b/src/main/java/org/apache/sysml/parser/Expression.java @@ -152,7 +152,7 @@ public 
abstract class Expression implements ParseInfo * Parameterized built-in function operators. */ public enum ParameterizedBuiltinFunctionOp { - GROUPEDAGG, RMEMPTY, REPLACE, ORDER, + GROUPEDAGG, RMEMPTY, REPLACE, ORDER, LOWER_TRI, UPPER_TRI, // Distribution Functions CDF, INVCDF, PNORM, QNORM, PT, QT, PF, QF, PCHISQ, QCHISQ, PEXP, QEXP, TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMENCODE, TRANSFORMCOLMAP, TRANSFORMMETA, http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java index e80b46f..6f9d6f7 100644 --- a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java @@ -47,11 +47,13 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier private static HashMap<String, Expression.ParameterizedBuiltinFunctionOp> opcodeMap; static { opcodeMap = new HashMap<>(); - opcodeMap.put("aggregate", Expression.ParameterizedBuiltinFunctionOp.GROUPEDAGG); + opcodeMap.put("aggregate", Expression.ParameterizedBuiltinFunctionOp.GROUPEDAGG); opcodeMap.put("groupedAggregate", Expression.ParameterizedBuiltinFunctionOp.GROUPEDAGG); - opcodeMap.put("removeEmpty",Expression.ParameterizedBuiltinFunctionOp.RMEMPTY); - opcodeMap.put("replace", Expression.ParameterizedBuiltinFunctionOp.REPLACE); - opcodeMap.put("order", Expression.ParameterizedBuiltinFunctionOp.ORDER); + opcodeMap.put("removeEmpty", Expression.ParameterizedBuiltinFunctionOp.RMEMPTY); + opcodeMap.put("replace", Expression.ParameterizedBuiltinFunctionOp.REPLACE); + opcodeMap.put("order", Expression.ParameterizedBuiltinFunctionOp.ORDER); + opcodeMap.put("lower.tri", 
Expression.ParameterizedBuiltinFunctionOp.LOWER_TRI); + opcodeMap.put("upper.tri", Expression.ParameterizedBuiltinFunctionOp.UPPER_TRI); // Distribution Functions opcodeMap.put("cdf", Expression.ParameterizedBuiltinFunctionOp.CDF); @@ -86,6 +88,8 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier pbHopMap.put(Expression.ParameterizedBuiltinFunctionOp.GROUPEDAGG, ParamBuiltinOp.GROUPEDAGG); pbHopMap.put(Expression.ParameterizedBuiltinFunctionOp.RMEMPTY, ParamBuiltinOp.RMEMPTY); pbHopMap.put(Expression.ParameterizedBuiltinFunctionOp.REPLACE, ParamBuiltinOp.REPLACE); + pbHopMap.put(Expression.ParameterizedBuiltinFunctionOp.LOWER_TRI, ParamBuiltinOp.LOWER_TRI); + pbHopMap.put(Expression.ParameterizedBuiltinFunctionOp.UPPER_TRI, ParamBuiltinOp.UPPER_TRI); // For order, a ReorgOp is constructed with ReorgOp.SORT type pbHopMap.put(Expression.ParameterizedBuiltinFunctionOp.ORDER, ParamBuiltinOp.INVALID); @@ -246,6 +250,11 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier validateTransformMeta(output, conditional); break; + case LOWER_TRI: + case UPPER_TRI: + validateExtractTriangular(output, getOpCode(), conditional); + break; + case TOSTRING: validateCastAsString(output, conditional); break; @@ -388,15 +397,36 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier } } + private void validateExtractTriangular(DataIdentifier output, ParameterizedBuiltinFunctionOp op, boolean conditional) { + + //check for invalid parameters + Set<String> valid = UtilFunctions.asSet("target", "diag", "values"); + Set<String> invalid = _varParams.keySet().stream() + .filter(k -> !valid.contains(k)).collect(Collectors.toSet()); + if( !invalid.isEmpty() ) + raiseValidateError("Invalid parameters for " + op.name() + ": " + + Arrays.toString(invalid.toArray(new String[0])), false); + + //check existence and correctness of arguments + checkTargetParam(getVarParam("target"), conditional); + 
checkOptionalBooleanParam(getVarParam("diag"), "diag", conditional); + checkOptionalBooleanParam(getVarParam("values"), "values", conditional); + if( getVarParam("diag") == null ) //default handling + _varParams.put("diag", new BooleanIdentifier(false)); + if( getVarParam("values") == null ) //default handling + _varParams.put("values", new BooleanIdentifier(false)); + + // Output is a matrix with unknown dims + Identifier in = getVarParam("target").getOutput(); + output.setDataType(DataType.MATRIX); + output.setValueType(ValueType.DOUBLE); + output.setDimensions(in.getDim1(), in.getDim2()); + } + private void validateReplace(DataIdentifier output, boolean conditional) { //check existence and correctness of arguments Expression target = getVarParam("target"); - if( target==null ) { - raiseValidateError("Named parameter 'target' missing. Please specify the input matrix.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); - } - else if( target.getOutput().getDataType() != DataType.MATRIX ){ - raiseValidateError("Input matrix 'target' is of type '"+target.getOutput().getDataType()+"'. Please specify the input matrix.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); - } + checkTargetParam(target, conditional); Expression pattern = getVarParam("pattern"); if( pattern==null ) { @@ -422,13 +452,8 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier private void validateOrder(DataIdentifier output, boolean conditional) { //check existence and correctness of arguments - Expression target = getVarParam("target"); //[MANDATORY] TARGET - if( target==null ) { - raiseValidateError("Named parameter 'target' missing. Please specify the input matrix.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); - } - else if( target.getOutput().getDataType() != DataType.MATRIX ){ - raiseValidateError("Input matrix 'target' is of type '"+target.getOutput().getDataType()+"'. 
Please specify the input matrix.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); - } + Expression target = getVarParam("target"); + checkTargetParam(target, conditional); //check for unsupported parameters for(String param : getVarParams().keySet()) @@ -482,13 +507,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier + Arrays.toString(invalid.toArray(new String[0])), false); //check existence and correctness of arguments - Expression target = getVarParam("target"); - if( target==null ) { - raiseValidateError("Named parameter 'target' missing. Please specify the input matrix.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); - } - else if( target.getOutput().getDataType() != DataType.MATRIX ){ - raiseValidateError("Input matrix 'target' is of type '"+target.getOutput().getDataType()+"'. Please specify the input matrix.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); - } + checkTargetParam(getVarParam("target"), conditional); Expression margin = getVarParam("margin"); if( margin==null ){ @@ -622,6 +641,22 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier output.setDimensions(outputDim1, outputDim2); } + private void checkTargetParam(Expression target, boolean conditional) { + if( target==null ) + raiseValidateError("Named parameter 'target' missing. Please specify the input matrix.", + conditional, LanguageErrorCodes.INVALID_PARAMETERS); + else if( target.getOutput().getDataType() != DataType.MATRIX ) + raiseValidateError("Input matrix 'target' is of type '"+target.getOutput().getDataType() + +"'. 
Please specify the input matrix.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); + } + + private void checkOptionalBooleanParam(Expression param, String name, boolean conditional) { + if( param!=null && (!param.getOutput().getDataType().isScalar() || param.getOutput().getValueType() != ValueType.BOOLEAN) ){ + raiseValidateError("Boolean parameter '"+name+"' is of type "+param.getOutput().getDataType() + +"["+param.getOutput().getValueType()+"].", conditional, LanguageErrorCodes.INVALID_PARAMETERS); + } + } + private void validateDistributionFunctions(DataIdentifier output, boolean conditional) { // CDF and INVCDF expects one unnamed parameter, it must be renamed as "quantile" // (i.e., we must compute P(X <= x) where x is called as "quantile" ) http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/main/java/org/apache/sysml/parser/dml/Dml.g4 ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/Dml.g4 b/src/main/java/org/apache/sysml/parser/dml/Dml.g4 index 9491855..8723b3f 100644 --- a/src/main/java/org/apache/sysml/parser/dml/Dml.g4 +++ b/src/main/java/org/apache/sysml/parser/dml/Dml.g4 @@ -183,6 +183,7 @@ ID : (ALPHABET (ALPHABET|DIGIT|'_')* '::')? 
ALPHABET (ALPHABET|DIGIT|'_')* // Special ID cases: // | 'matrix' // --> This is a special case which causes lot of headache | 'as.scalar' | 'as.matrix' | 'as.frame' | 'as.double' | 'as.integer' | 'as.logical' | 'index.return' | 'empty.return' | 'lower.tail' + | 'lower.tri' | 'upper.tri' ; // Unfortunately, we have datatype name clashing with builtin function name: matrix :( // Therefore, ugly work around for checking datatype http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/main/java/org/apache/sysml/parser/pydml/Pydml.g4 ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/pydml/Pydml.g4 b/src/main/java/org/apache/sysml/parser/pydml/Pydml.g4 index 34d0c34..d0211c6 100644 --- a/src/main/java/org/apache/sysml/parser/pydml/Pydml.g4 +++ b/src/main/java/org/apache/sysml/parser/pydml/Pydml.g4 @@ -303,6 +303,7 @@ ID : (ALPHABET (ALPHABET|DIGIT|'_')* '.')? ALPHABET (ALPHABET|DIGIT|'_')* // | 'matrix' // --> This is a special case which causes lot of headache // | 'scalar' | 'float' | 'int' | 'bool' // corresponds to as.scalar, as.double, as.integer and as.logical | 'index.return' | 'empty.return' + | 'lower.tri' | 'upper.tri' ; // Unfortunately, we have datatype name clashing with builtin function name: matrix :( // Therefore, ugly work around for checking datatype http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/main/java/org/apache/sysml/runtime/functionobjects/ParameterizedBuiltin.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/ParameterizedBuiltin.java b/src/main/java/org/apache/sysml/runtime/functionobjects/ParameterizedBuiltin.java index 909fdc8..302f680 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/ParameterizedBuiltin.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/ParameterizedBuiltin.java @@ -44,7 +44,7 @@ public class 
ParameterizedBuiltin extends ValueFunction private static final long serialVersionUID = -5966242955816522697L; public enum ParameterizedBuiltinCode { - CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, + CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI, TRANSFORMAPPLY, TRANSFORMDECODE } public enum ProbabilityDistributionCode { INVALID, NORMAL, EXP, CHISQ, F, T } @@ -59,6 +59,8 @@ public class ParameterizedBuiltin extends ValueFunction String2ParameterizedBuiltinCode.put( "invcdf", ParameterizedBuiltinCode.INVCDF); String2ParameterizedBuiltinCode.put( "rmempty", ParameterizedBuiltinCode.RMEMPTY); String2ParameterizedBuiltinCode.put( "replace", ParameterizedBuiltinCode.REPLACE); + String2ParameterizedBuiltinCode.put( "lowertri", ParameterizedBuiltinCode.LOWER_TRI); + String2ParameterizedBuiltinCode.put( "uppertri", ParameterizedBuiltinCode.UPPER_TRI); String2ParameterizedBuiltinCode.put( "rexpand", ParameterizedBuiltinCode.REXPAND); String2ParameterizedBuiltinCode.put( "transformapply", ParameterizedBuiltinCode.TRANSFORMAPPLY); String2ParameterizedBuiltinCode.put( "transformdecode", ParameterizedBuiltinCode.TRANSFORMDECODE); @@ -162,6 +164,12 @@ public class ParameterizedBuiltin extends ValueFunction case REPLACE: return new ParameterizedBuiltin(ParameterizedBuiltinCode.REPLACE); + case LOWER_TRI: + return new ParameterizedBuiltin(ParameterizedBuiltinCode.LOWER_TRI); + + case UPPER_TRI: + return new ParameterizedBuiltin(ParameterizedBuiltinCode.UPPER_TRI); + case REXPAND: return new ParameterizedBuiltin(ParameterizedBuiltinCode.REXPAND); http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index 6b875d2..00ad286 100644 --- 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -67,10 +67,7 @@ import org.apache.sysml.runtime.instructions.cpfile.MatrixIndexingCPFileInstruct public class CPInstructionParser extends InstructionParser { - public static final HashMap<String, CPType> String2CPInstructionType; - public static final HashMap<String, CPType> String2CPFileInstructionType; - static { String2CPInstructionType = new HashMap<>(); String2CPInstructionType.put( "ba+*" , CPType.AggregateBinary); @@ -185,17 +182,19 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "eval" , CPType.BuiltinNary); // Parameterized Builtin Functions - String2CPInstructionType.put( "cdf" , CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "invcdf" , CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "groupedagg" , CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "rmempty" , CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "replace" , CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "rexpand" , CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "toString" , CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "transformapply",CPType.ParameterizedBuiltin); + String2CPInstructionType.put( "cdf", CPType.ParameterizedBuiltin); + String2CPInstructionType.put( "invcdf", CPType.ParameterizedBuiltin); + String2CPInstructionType.put( "groupedagg", CPType.ParameterizedBuiltin); + String2CPInstructionType.put( "rmempty" , CPType.ParameterizedBuiltin); + String2CPInstructionType.put( "replace", CPType.ParameterizedBuiltin); + String2CPInstructionType.put( "lowertri", CPType.ParameterizedBuiltin); + String2CPInstructionType.put( "uppertri", CPType.ParameterizedBuiltin); + String2CPInstructionType.put( "rexpand", CPType.ParameterizedBuiltin); + String2CPInstructionType.put( "toString", 
CPType.ParameterizedBuiltin); + String2CPInstructionType.put( "transformapply", CPType.ParameterizedBuiltin); String2CPInstructionType.put( "transformdecode",CPType.ParameterizedBuiltin); String2CPInstructionType.put( "transformcolmap",CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "transformmeta",CPType.ParameterizedBuiltin); + String2CPInstructionType.put( "transformmeta", CPType.ParameterizedBuiltin); String2CPInstructionType.put( "transformencode",CPType.MultiReturnParameterizedBuiltin); // Ternary Instruction Opcodes @@ -286,11 +285,6 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "partition", CPType.Partition); String2CPInstructionType.put( "compress", CPType.Compression); String2CPInstructionType.put( "spoof", CPType.SpoofFused); - - - //CP FILE instruction - String2CPFileInstructionType = new HashMap<>(); - String2CPFileInstructionType.put( "rmempty" , CPType.ParameterizedBuiltin); } public static CPInstruction parseSingleInstruction (String str ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java index 8d0fb54..799a77a 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java @@ -245,12 +245,14 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( "sigmoid", SPType.Unary); // Parameterized Builtin Functions - String2SPInstructionType.put( "groupedagg" , SPType.ParameterizedBuiltin); - String2SPInstructionType.put( "mapgroupedagg", SPType.ParameterizedBuiltin); - String2SPInstructionType.put( "rmempty" , 
SPType.ParameterizedBuiltin); - String2SPInstructionType.put( "replace" , SPType.ParameterizedBuiltin); - String2SPInstructionType.put( "rexpand" , SPType.ParameterizedBuiltin); - String2SPInstructionType.put( "transformapply",SPType.ParameterizedBuiltin); + String2SPInstructionType.put( "groupedagg", SPType.ParameterizedBuiltin); + String2SPInstructionType.put( "mapgroupedagg", SPType.ParameterizedBuiltin); + String2SPInstructionType.put( "rmempty", SPType.ParameterizedBuiltin); + String2SPInstructionType.put( "replace", SPType.ParameterizedBuiltin); + String2SPInstructionType.put( "rexpand", SPType.ParameterizedBuiltin); + String2SPInstructionType.put( "lowertri", SPType.ParameterizedBuiltin); + String2SPInstructionType.put( "uppertri", SPType.ParameterizedBuiltin); + String2SPInstructionType.put( "transformapply", SPType.ParameterizedBuiltin); String2SPInstructionType.put( "transformdecode",SPType.ParameterizedBuiltin); String2SPInstructionType.put( "transformencode",SPType.MultiReturnBuiltin); http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index c5909ac..fe18dbc 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -126,7 +126,9 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction } else if( opcode.equalsIgnoreCase("rmempty") || opcode.equalsIgnoreCase("replace") - || opcode.equalsIgnoreCase("rexpand") ) + || opcode.equalsIgnoreCase("rexpand") + || opcode.equalsIgnoreCase("lowertri") + || 
opcode.equalsIgnoreCase("uppertri")) { func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode); return new ParameterizedBuiltinCPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str); @@ -211,15 +213,19 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction ec.releaseMatrixInput(params.get("select"), getExtendedOpcode()); } else if ( opcode.equalsIgnoreCase("replace") ) { - // acquire locks MatrixBlock target = ec.getMatrixInput(params.get("target"), getExtendedOpcode()); - - // compute the result double pattern = Double.parseDouble( params.get("pattern") ); double replacement = Double.parseDouble( params.get("replacement") ); MatrixBlock ret = (MatrixBlock) target.replaceOperations(new MatrixBlock(), pattern, replacement); - - //release locks + ec.setMatrixOutput(output.getName(), ret, getExtendedOpcode()); + ec.releaseMatrixInput(params.get("target"), getExtendedOpcode()); + } + else if ( opcode.equals("lowertri") || opcode.equals("uppertri")) { + MatrixBlock target = ec.getMatrixInput(params.get("target"), getExtendedOpcode()); + boolean lower = opcode.equals("lowertri"); + boolean diag = Boolean.parseBoolean(params.get("diag")); + boolean values = Boolean.parseBoolean(params.get("values")); + MatrixBlock ret = (MatrixBlock) target.extractTriangular(new MatrixBlock(), lower, diag, values); ec.setMatrixOutput(output.getName(), ret, getExtendedOpcode()); ec.releaseMatrixInput(params.get("target"), getExtendedOpcode()); } http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java index 0b37bd8..425dbd3 100644 --- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java @@ -48,6 +48,7 @@ import org.apache.sysml.runtime.functionobjects.ValueFunction; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction; +import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator; import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast; import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroup.ExtractGroupBroadcast; import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroup.ExtractGroupJoin; @@ -166,6 +167,8 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction } else if( opcode.equalsIgnoreCase("rexpand") || opcode.equalsIgnoreCase("replace") + || opcode.equalsIgnoreCase("lowertri") + || opcode.equalsIgnoreCase("uppertri") || opcode.equalsIgnoreCase("transformapply") || opcode.equalsIgnoreCase("transformdecode")) { @@ -368,25 +371,42 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction } } else if ( opcode.equalsIgnoreCase("replace") ) - { - //get input rdd handle - String rddVar = params.get("target"); - JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( rddVar ); - MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(rddVar); + { + JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(params.get("target")); + MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(params.get("target")); //execute replace operation double pattern = Double.parseDouble( params.get("pattern") ); double replacement = Double.parseDouble( params.get("replacement") ); JavaPairRDD<MatrixIndexes,MatrixBlock> out = - in1.mapValues(new 
RDDReplaceFunction(pattern, replacement)); + in1.mapValues(new RDDReplaceFunction(pattern, replacement)); //store output rdd handle sec.setRDDHandleForVariable(output.getName(), out); - sec.addLineageRDD(output.getName(), rddVar); + sec.addLineageRDD(output.getName(), params.get("target")); //update output statistics (required for correctness) MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName()); - mcOut.set(mcIn.getRows(), mcIn.getCols(), mcIn.getRowsPerBlock(), mcIn.getColsPerBlock(), (pattern!=0 && replacement!=0)?mcIn.getNonZeros():-1); + mcOut.set(mcIn.getRows(), mcIn.getCols(), mcIn.getRowsPerBlock(), + mcIn.getColsPerBlock(), (pattern!=0 && replacement!=0)?mcIn.getNonZeros():-1); + } + else if ( opcode.equalsIgnoreCase("lowertri") || opcode.equalsIgnoreCase("uppertri") ) + { + JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(params.get("target")); + MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(params.get("target")); + boolean lower = opcode.equalsIgnoreCase("lowertri"); + boolean diag = Boolean.parseBoolean(params.get("diag")); + boolean values = Boolean.parseBoolean(params.get("values")); + + JavaPairRDD<MatrixIndexes,MatrixBlock> out = in1.mapPartitionsToPair( + new RDDExtractTriangularFunction(lower, diag, values), true); + + //store output rdd handle + sec.setRDDHandleForVariable(output.getName(), out); + sec.addLineageRDD(output.getName(), params.get("target")); + + //update output statistics (required for correctness) + sec.getMatrixCharacteristics(output.getName()).setDimension(mcIn.getRows(), mcIn.getCols()); } else if ( opcode.equalsIgnoreCase("rexpand") ) { @@ -414,7 +434,7 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction //execute rexpand rows/cols operation (no shuffle required because outputs are //block-aligned with the input, i.e., one input block generates n output blocks) JavaPairRDD<MatrixIndexes,MatrixBlock> out = in - 
.flatMapToPair(new RDDRExpandFunction(maxVal, dirRows, cast, ignore, brlen, bclen)); + .flatMapToPair(new RDDRExpandFunction(maxVal, dirRows, cast, ignore, brlen, bclen)); //store output rdd handle sec.setRDDHandleForVariable(output.getName(), out); @@ -430,7 +450,7 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction FrameObject fo = sec.getFrameObject(params.get("target")); JavaPairRDD<Long,FrameBlock> in = (JavaPairRDD<Long,FrameBlock>) sec.getRDDHandleForFrameObject(fo, InputInfo.BinaryBlockInputInfo); - FrameBlock meta = sec.getFrameInput(params.get("meta")); + FrameBlock meta = sec.getFrameInput(params.get("meta")); MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(params.get("target")); MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName()); String[] colnames = !TfMetaUtils.isIDSpec(params.get("spec")) ? @@ -466,7 +486,7 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction //get input RDD and meta data JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(params.get("target")); MatrixCharacteristics mc = sec.getMatrixCharacteristics(params.get("target")); - FrameBlock meta = sec.getFrameInput(params.get("meta")); + FrameBlock meta = sec.getFrameInput(params.get("meta")); String[] colnames = meta.getColumnNames(); //reblock if necessary (clen > bclen) @@ -494,25 +514,69 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction } } - public static class RDDReplaceFunction implements Function<MatrixBlock, MatrixBlock> - { + public static class RDDReplaceFunction implements Function<MatrixBlock, MatrixBlock> { private static final long serialVersionUID = 6576713401901671659L; - private double _pattern; private double _replacement; - public RDDReplaceFunction(double pattern, double replacement) - { + public RDDReplaceFunction(double pattern, double replacement) { _pattern = pattern; _replacement = replacement; } @Override 
- public MatrixBlock call(MatrixBlock arg0) - throws Exception - { + public MatrixBlock call(MatrixBlock arg0) { return (MatrixBlock) arg0.replaceOperations(new MatrixBlock(), _pattern, _replacement); - } + } + } + + private static class RDDExtractTriangularFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> + { + private static final long serialVersionUID = 2754868819184155702L; + private final boolean _lower, _diag, _values; + + public RDDExtractTriangularFunction(boolean lower, boolean diag, boolean values) { + _lower = lower; + _diag = diag; + _values = values; + } + + @Override + public LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg0) { + return new ExtractTriangularIterator(arg0); + } + + private class ExtractTriangularIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> + { + public ExtractTriangularIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) { + super(in); + } + + @Override + protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) { + MatrixIndexes ix = arg._1(); + MatrixBlock mb = arg._2(); + + //handle cases of pass-through and reset block + if( (_lower && ix.getRowIndex() > ix.getColumnIndex()) + || (!_lower && ix.getRowIndex() < ix.getColumnIndex()) ) { + return _values ? 
arg : new Tuple2<MatrixIndexes,MatrixBlock>( + ix, new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), 1d)); + } + + //handle cases of empty blocks + if( (_lower && ix.getRowIndex() < ix.getColumnIndex()) + || (!_lower && ix.getRowIndex() > ix.getColumnIndex()) ) { + return new Tuple2<MatrixIndexes,MatrixBlock>(ix, + new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), true)); + } + + //extract triangular blocks for blocks on diagonal + assert(ix.getRowIndex() == ix.getColumnIndex()); + return new Tuple2<MatrixIndexes,MatrixBlock>(ix, + mb.extractTriangular(new MatrixBlock(), _lower, _diag, _values)); + } + } } public static class RDDRemoveEmptyFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes,Tuple2<MatrixBlock, MatrixBlock>>,MatrixIndexes,MatrixBlock> http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java index fc4a82f..9807340 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java @@ -164,7 +164,10 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab public MatrixBlock(double val) { reset(1, 1, false, 1, val); - nonZeros = (val != 0) ? 1 : 0; + } + + public MatrixBlock(int rl, int cl, double val) { + reset(rl, cl, false, (long)rl*cl, val); } /** @@ -245,7 +248,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab rlen = rl; clen = cl; sparse = (val == 0) ? sp : false; - nonZeros = (val == 0) ? 0 : rl*cl; + nonZeros = (val == 0) ? 0 : (long)rl*cl; estimatedNNzsPerRow = (estnnz < 0 || !sparse) ? 
-1 : (int)Math.ceil((double)estnnz/(double)rlen); @@ -4853,6 +4856,62 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab return ret; } + public MatrixBlock extractTriangular(MatrixBlock ret, boolean lower, boolean diag, boolean values) { + ret.reset(rlen, clen, sparse); + if( isEmptyBlock(false) ) + return ret; //sparse-safe + ret.allocateBlock(); + + long nnz = 0; + if( sparse ) { //SPARSE + SparseBlock a = sparseBlock; + SparseBlock c = ret.sparseBlock; + for( int i=0; i<rlen; i++ ) { + if( a.isEmpty(i) ) continue; + int jbeg = Math.min(lower ? 0 : (diag ? i : i+1), clen); + int jend = Math.min(lower ? (diag ? i+1 : i) : clen, clen); + if( values ) { + int k1 = a.posFIndexGTE(i, jbeg); + int k2 = a.posFIndexGTE(i, jend); + k1 = (k1 >= 0) ? k1 : a.size(i); + k2 = (k2 >= 0) ? k2 : a.size(i); + int apos = a.pos(i); + int[] aix = a.indexes(i); + double[] avals = a.values(i); + c.allocate(i, k2-k1); + for( int k=apos+k1; k<apos+k2; k++ ) + ret.appendValue(i, aix[k], avals[k]); + } + else { + c.allocate(i, jend-jbeg); + for( int j=jbeg; j<jend; j++ ) + ret.appendValue(i, j, 1); + } + } + //nnz maintained internally + } + else { //DENSE <- DENSE + DenseBlock a = denseBlock; + DenseBlock c = ret.getDenseBlock(); + for(int i = 0; i < rlen; i++) { + int jbeg = Math.min(lower ? 0 : (diag ? i : i+1), clen); + int jend = Math.min(lower ? (diag ? 
i+1 : i) : clen, clen); + double[] avals = a.values(i), cvals = c.values(i); + int aix = a.pos(i,jbeg), cix = c.pos(i,jbeg); + if( values ) { + System.arraycopy(avals, aix, cvals, cix, jend-jbeg); + nnz += UtilFunctions.countNonZeros(avals, aix, jend-jbeg); + } + else { //R semantics full reset, not just nnz + Arrays.fill(cvals, cix, cix+(jend-jbeg), 1); + nnz += (jend-jbeg); + } + } + } + ret.setNonZeros(nnz); + ret.examSparsity(); + return ret; + } /** * D = ctable(A,v2,W) http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/ExtractTriangularTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/ExtractTriangularTest.java b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/ExtractTriangularTest.java new file mode 100644 index 0000000..9fd346a --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/ExtractTriangularTest.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.integration.functions.unary.matrix; + +import org.junit.Test; + +import java.util.HashMap; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + +public class ExtractTriangularTest extends AutomatedTestBase +{ + private final static String TEST_NAME1 = "extractLowerTri"; + private final static String TEST_NAME2 = "extractUpperTri"; + private final static String TEST_DIR = "functions/unary/matrix/"; + private static final String TEST_CLASS_DIR = TEST_DIR + ExtractTriangularTest.class.getSimpleName() + "/"; + + private final static int _rows = 1321; + private final static int _cols = 1123; + private final static double _sparsityDense = 0.5; + private final static double _sparsitySparse = 0.05; + private final static double eps = 1e-8; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); + } + + @Test + public void testExtractLowerTriDenseBoolCP() { + runExtractTriangular(TEST_NAME1, false, false, false, ExecType.CP); + } + + @Test + public void testExtractLowerTriDenseValuesCP() { + runExtractTriangular(TEST_NAME1, false, false, true, ExecType.CP); + } + + @Test + public void testExtractLowerTriDenseDiagBoolCP() { + runExtractTriangular(TEST_NAME1, false, true, false, ExecType.CP); + } + + @Test + public void testExtractLowerTriDenseDiagValuesCP() { + runExtractTriangular(TEST_NAME1, false, true, true, ExecType.CP); + } + + @Test + public void testExtractLowerTriSparseBoolCP() { + 
runExtractTriangular(TEST_NAME1, true, false, false, ExecType.CP); + } + + @Test + public void testExtractLowerTriSparseValuesCP() { + runExtractTriangular(TEST_NAME1, true, false, true, ExecType.CP); + } + + @Test + public void testExtractLowerTriSparseDiagBoolCP() { + runExtractTriangular(TEST_NAME1, true, true, false, ExecType.CP); + } + + @Test + public void testExtractLowerTriSparseDiagValuesCP() { + runExtractTriangular(TEST_NAME1, true, true, true, ExecType.CP); + } + + @Test + public void testExtractUpperTriDenseBoolCP() { + runExtractTriangular(TEST_NAME2, false, false, false, ExecType.CP); + } + + @Test + public void testExtractUpperTriDenseValuesCP() { + runExtractTriangular(TEST_NAME2, false, false, true, ExecType.CP); + } + + @Test + public void testExtractUpperTriDenseDiagBoolCP() { + runExtractTriangular(TEST_NAME2, false, true, false, ExecType.CP); + } + + @Test + public void testExtractUpperTriDenseDiagValuesCP() { + runExtractTriangular(TEST_NAME2, false, true, true, ExecType.CP); + } + + @Test + public void testExtractUpperTriSparseBoolCP() { + runExtractTriangular(TEST_NAME2, true, false, false, ExecType.CP); + } + + @Test + public void testExtractUpperTriSparseValuesCP() { + runExtractTriangular(TEST_NAME2, true, false, true, ExecType.CP); + } + + @Test + public void testExtractUpperTriSparseDiagBoolCP() { + runExtractTriangular(TEST_NAME2, true, true, false, ExecType.CP); + } + + @Test + public void testExtractUpperTriSparseDiagValuesCP() { + runExtractTriangular(TEST_NAME2, true, true, true, ExecType.CP); + } + + @Test + public void testExtractLowerTriDenseBoolSP() { + runExtractTriangular(TEST_NAME1, false, false, false, ExecType.SPARK); + } + + @Test + public void testExtractLowerTriDenseValuesSP() { + runExtractTriangular(TEST_NAME1, false, false, true, ExecType.SPARK); + } + + @Test + public void testExtractLowerTriDenseDiagBoolSP() { + runExtractTriangular(TEST_NAME1, false, true, false, ExecType.SPARK); + } + + @Test + public void 
testExtractLowerTriDenseDiagValuesSP() { + runExtractTriangular(TEST_NAME1, false, true, true, ExecType.SPARK); + } + + @Test + public void testExtractLowerTriSparseBoolSP() { + runExtractTriangular(TEST_NAME1, true, false, false, ExecType.SPARK); + } + + @Test + public void testExtractLowerTriSparseValuesSP() { + runExtractTriangular(TEST_NAME1, true, false, true, ExecType.SPARK); + } + + @Test + public void testExtractLowerTriSparseDiagBoolSP() { + runExtractTriangular(TEST_NAME1, true, true, false, ExecType.SPARK); + } + + @Test + public void testExtractLowerTriSparseDiagValuesSP() { + runExtractTriangular(TEST_NAME1, true, true, true, ExecType.SPARK); + } + + @Test + public void testExtractUpperTriDenseBoolSP() { + runExtractTriangular(TEST_NAME2, false, false, false, ExecType.SPARK); + } + + @Test + public void testExtractUpperTriDenseValuesSP() { + runExtractTriangular(TEST_NAME2, false, false, true, ExecType.SPARK); + } + + @Test + public void testExtractUpperTriDenseDiagBoolSP() { + runExtractTriangular(TEST_NAME2, false, true, false, ExecType.SPARK); + } + + @Test + public void testExtractUpperTriDenseDiagValuesSP() { + runExtractTriangular(TEST_NAME2, false, true, true, ExecType.SPARK); + } + + @Test + public void testExtractUpperTriSparseBoolSP() { + runExtractTriangular(TEST_NAME2, true, false, false, ExecType.SPARK); + } + + @Test + public void testExtractUpperTriSparseValuesSP() { + runExtractTriangular(TEST_NAME2, true, false, true, ExecType.SPARK); + } + + @Test + public void testExtractUpperTriSparseDiagBoolSP() { + runExtractTriangular(TEST_NAME2, true, true, false, ExecType.SPARK); + } + + @Test + public void testExtractUpperTriSparseDiagValuesSP() { + runExtractTriangular(TEST_NAME2, true, true, true, ExecType.SPARK); + } + + private void runExtractTriangular( String testname, boolean sparse, boolean diag, boolean values, ExecType et) + { + //rtplatform for MR + RUNTIME_PLATFORM platformOld = rtplatform; + switch( et ){ + case MR: rtplatform = 
RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID; break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + try + { + //setup dims and sparsity + double sparsity = sparse ? _sparsitySparse : _sparsityDense; + + //register test configuration + TestConfiguration config = getTestConfiguration(testname); + config.addVariable("rows", _rows); + config.addVariable("cols", _cols); + loadTestConfiguration(config); + + String sdiag = String.valueOf(diag).toUpperCase(); + String svalues = String.valueOf(values).toUpperCase(); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{"-explain", "-args", + input("X"), sdiag, svalues, output("R") }; + fullRScriptName = HOME + testname + ".R"; + rCmd = "Rscript "+fullRScriptName+" " + +inputDir()+" "+sdiag+" "+svalues+" "+expectedDir(); + + //generate actual dataset + double[][] X = getRandomMatrix(_rows, _cols, -0.05, 1, sparsity, 7); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + } + finally { + //reset platform for additional tests + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/test/scripts/functions/unary/matrix/ExtractLowerTri.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/unary/matrix/ExtractLowerTri.R b/src/test/scripts/functions/unary/matrix/ExtractLowerTri.R new file mode 100644 index 0000000..6cb6397 --- /dev/null +++ 
b/src/test/scripts/functions/unary/matrix/ExtractLowerTri.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) +options(digits=22) + +library("Matrix") + +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) + +R = lower.tri(X, diag=as.logical(args[2])); +if( as.logical(args[3]) ) { + R = R * X; +} + +writeMM(as(R, "CsparseMatrix"), paste(args[4], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/test/scripts/functions/unary/matrix/ExtractLowerTri.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/unary/matrix/ExtractLowerTri.dml b/src/test/scripts/functions/unary/matrix/ExtractLowerTri.dml new file mode 100644 index 0000000..e1a4dd0 --- /dev/null +++ b/src/test/scripts/functions/unary/matrix/ExtractLowerTri.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = read($1); +R = lower.tri(target=X, diag=$2, values=$3) +write(R, $4); http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/test/scripts/functions/unary/matrix/ExtractUpperTri.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/unary/matrix/ExtractUpperTri.R b/src/test/scripts/functions/unary/matrix/ExtractUpperTri.R new file mode 100644 index 0000000..08ca120 --- /dev/null +++ b/src/test/scripts/functions/unary/matrix/ExtractUpperTri.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) +options(digits=22) + +library("Matrix") + +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) + +R = upper.tri(X, diag=as.logical(args[2])); +if( as.logical(args[3]) ) { + R = R * X; +} + +writeMM(as(R, "CsparseMatrix"), paste(args[4], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/test/scripts/functions/unary/matrix/ExtractUpperTri.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/unary/matrix/ExtractUpperTri.dml b/src/test/scripts/functions/unary/matrix/ExtractUpperTri.dml new file mode 100644 index 0000000..0392213 --- /dev/null +++ b/src/test/scripts/functions/unary/matrix/ExtractUpperTri.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = read($1); +R = upper.tri(target=X, diag=$2, values=$3) +write(R, $4); http://git-wip-us.apache.org/repos/asf/systemml/blob/4eb1b935/src/test_suites/java/org/apache/sysml/test/integration/functions/unary/matrix/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/unary/matrix/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/unary/matrix/ZPackageSuite.java index 09f4113..27def61 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/unary/matrix/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/unary/matrix/ZPackageSuite.java @@ -35,6 +35,7 @@ import org.junit.runners.Suite; CholeskyTest.class, DiagTest.class, EigenFactorizeTest.class, + ExtractTriangularTest.class, FullCummaxTest.class, FullCumminTest.class, FullCumprodTest.class,
