This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 9446bf6362 [SYSTEMDS-3500] Fix perftest regression via new
contains-value function
9446bf6362 is described below
commit 9446bf6362cd46cc4049bcab361f7cc6388809b2
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Feb 23 22:37:19 2023 +0100
[SYSTEMDS-3500] Fix perftest regression via new contains-value function
A while ago the MLogReg (multiLogReg) script was extended with robustness
checks for NaN inputs. In the perftest MLogReg 1M_1K_dense (8GB), this led to a
performance regression due to unnecessary distributed (Spark) operations with
a 20GB driver because input and output (16GB) exceed the 70% memory budget.
Given that sum(isNaN(X)) > 0 is likely false, we now expose an already existing
block operation contains(X, pattern) that has only half the memory requirements.
We added the CP, SPARK, and FED instructions as well as related tests.
---
scripts/builtin/multiLogReg.dml | 6 +-
scripts/builtin/multiLogRegPredict.dml | 6 +-
.../java/org/apache/sysds/common/Builtins.java | 1 +
src/main/java/org/apache/sysds/common/Types.java | 2 +-
.../apache/sysds/hops/ParameterizedBuiltinOp.java | 3 +-
.../apache/sysds/lops/ParameterizedBuiltin.java | 21 +--
.../org/apache/sysds/parser/DMLTranslator.java | 1 +
.../ParameterizedBuiltinFunctionExpression.java | 45 ++++---
.../controlprogram/federated/FederationUtils.java | 14 +-
.../runtime/instructions/CPInstructionParser.java | 1 +
.../runtime/instructions/SPInstructionParser.java | 3 +-
.../cp/ParameterizedBuiltinCPInstruction.java | 18 ++-
.../fed/ParameterizedBuiltinFEDInstruction.java | 35 +++--
.../spark/ParameterizedBuiltinSPInstruction.java | 30 ++++-
.../test/functions/aggregate/ContainsTest.java | 142 +++++++++++++++++++++
.../federated/algorithms/FederatedLogRegTest.java | 2 +-
src/test/scripts/functions/aggregate/Contains.dml | 24 ++++
17 files changed, 294 insertions(+), 60 deletions(-)
diff --git a/scripts/builtin/multiLogReg.dml b/scripts/builtin/multiLogReg.dml
index 9b7d7da79e..528931ad8e 100644
--- a/scripts/builtin/multiLogReg.dml
+++ b/scripts/builtin/multiLogReg.dml
@@ -59,10 +59,10 @@ m_multiLogReg = function(Matrix[Double] X, Matrix[Double]
Y, Int icpt = 2,
D = ncol (X);
# Robustness for datasets with missing values (causing NaN gradients)
- numNaNs = sum(isNaN(X))
- if( numNaNs > 0 ) {
+ hasNaNs = contains(target=X, pattern=NaN);
+ if( hasNaNs > 0 ) {
if(verbose)
- print("multiLogReg: matrix X contains "+numNaNs+" missing values,
replacing with 0.")
+ print("multiLogReg: matrix X contains "+sum(isNaN(X))+" missing values,
replacing with 0.")
X = replace(target=X, pattern=NaN, replacement=0);
}
diff --git a/scripts/builtin/multiLogRegPredict.dml
b/scripts/builtin/multiLogRegPredict.dml
index dc5c0332ab..16bf08316a 100644
--- a/scripts/builtin/multiLogRegPredict.dml
+++ b/scripts/builtin/multiLogRegPredict.dml
@@ -49,9 +49,9 @@ m_multiLogRegPredict = function(Matrix[Double] X,
Matrix[Double] B, Matrix[Doubl
stop("multiLogRegPredict: mismatching ncol(X) and nrow(B): "+ncol(X)+"
"+nrow(B));
# Robustness for datasets with missing values (causing NaN probabilities)
- numNaNs = sum(isNaN(X))
- if( numNaNs > 0 ) {
- print("multiLogRegPredict: matrix X contains "+numNaNs+" missing values,
replacing with 0.")
+ hasNaNs = contains(target=X, pattern=NaN);
+ if( hasNaNs > 0 ) {
+ print("multiLogRegPredict: matrix X contains "+sum(isNaN(X))+" missing
values, replacing with 0.")
X = replace(target=X, pattern=NaN, replacement=0);
}
accuracy = 0.0 # initialize variable
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index f7cbb972df..e627adb286 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -310,6 +310,7 @@ public enum Builtins {
// Parameterized functions with parameters
AUTODIFF("autoDiff", false, true),
CDF("cdf", false, true),
+ CONTAINS("contains", false, true),
COUNT_DISTINCT("countDistinct",false, true),
COUNT_DISTINCT_APPROX("countDistinctApprox", false, true),
COUNT_DISTINCT_APPROX_ROW("rowCountDistinctApprox", false, true),
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index ab81ff4e31..49221cee89 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -569,7 +569,7 @@ public class Types
}
public enum ParamBuiltinOp {
- AUTODIFF, INVALID, CDF, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE,
REXPAND,
+ AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, RMEMPTY,
REPLACE, REXPAND,
LOWER_TRI, UPPER_TRI,
TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA,
TOKENIZE, TOSTRING, LIST, PARAMSERV
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index 55e6d79c7b..4404579894 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -181,7 +181,8 @@ public class ParameterizedBuiltinOp extends
MultiThreadedHop {
case REXPAND: {
constructLopsRExpand(inputlops, et);
break;
- }
+ }
+ case CONTAINS:
case CDF:
case INVCDF:
case REPLACE:
diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
index a0f9331adf..eb8174dbca 100644
--- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
+++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
@@ -131,13 +131,6 @@ public class ParameterizedBuiltin extends Lop
break;
- case REPLACE: {
- sb.append( "replace" );
- sb.append( OPERAND_DELIMITOR );
- sb.append(compileGenericParamMap(_inputParams));
- break;
- }
-
case LOWER_TRI: {
sb.append( "lowertri" );
sb.append( OPERAND_DELIMITOR );
@@ -174,11 +167,14 @@ public class ParameterizedBuiltin extends Lop
break;
+ case CONTAINS:
+ case REPLACE:
case TOKENIZE:
case TRANSFORMAPPLY:
case TRANSFORMDECODE:
case TRANSFORMCOLMAP:
- case TRANSFORMMETA:{
+ case TRANSFORMMETA:
+ case PARAMSERV: {
sb.append(_operation.name().toLowerCase());
//opcode
sb.append(OPERAND_DELIMITOR);
sb.append(compileGenericParamMap(_inputParams));
@@ -202,14 +198,7 @@ public class ParameterizedBuiltin extends Lop
sb.append(compileGenericParamMap(_inputParams));
break;
}
-
- case PARAMSERV: {
- sb.append("paramserv");
- sb.append(OPERAND_DELIMITOR);
- sb.append(compileGenericParamMap(_inputParams));
- break;
- }
-
+
default:
throw new
LopsException(this.printErrorLocation() + "In ParameterizedBuiltin Lop, Unknown
operation: " + _operation);
}
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index c6ed2c5b84..98eaf0bbfb 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2007,6 +2007,7 @@ public class DMLTranslator
target.getValueType(),
source.getOpCode(), paramHops);
break;
+ case CONTAINS:
case GROUPEDAGG:
case RMEMPTY:
case REPLACE:
diff --git
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 293ca7312e..1d30d13fea 100644
---
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -202,6 +202,10 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
validateReplace(output, conditional);
break;
+ case CONTAINS:
+ validateContains(output, conditional);
+ break;
+
case ORDER:
validateOrder(output, conditional);
break;
@@ -725,28 +729,24 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
output.setDimensions(in.getDim1(), in.getDim2());
}
+ private void validateContains(DataIdentifier output, boolean
conditional) {
+ //check existence and correctness of arguments
+ Expression target = getVarParam("target");
+ checkTargetParam(target, conditional);
+ checkScalarParam("contains", "pattern", conditional);
+
+ //set boolean scalar
+ output.setBooleanProperties();
+ }
+
private void validateReplace(DataIdentifier output, boolean
conditional) {
//check existence and correctness of arguments
Expression target = getVarParam("target");
if( target.getOutput().getDataType() != DataType.FRAME ){
checkTargetParam(target, conditional);
}
-
- Expression pattern = getVarParam("pattern");
- if( pattern==null ) {
- raiseValidateError("Named parameter 'pattern' missing.
Please specify the replacement pattern.", conditional,
LanguageErrorCodes.INVALID_PARAMETERS);
- }
- else if( pattern.getOutput().getDataType() != DataType.SCALAR
){
- raiseValidateError("Replacement pattern 'pattern' is of
type '"+pattern.getOutput().getDataType()+"'. Please, specify a scalar
replacement pattern.", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
- }
-
- Expression replacement = getVarParam("replacement");
- if( replacement==null ) {
- raiseValidateError("Named parameter 'replacement'
missing. Please specify the replacement value.", conditional,
LanguageErrorCodes.INVALID_PARAMETERS);
- }
- else if( replacement.getOutput().getDataType() !=
DataType.SCALAR ){
- raiseValidateError("Replacement value 'replacement' is
of type '"+replacement.getOutput().getDataType()+"'. Please, specify a scalar
replacement value.", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
- }
+ checkScalarParam("replace", "pattern", conditional);
+ checkScalarParam("replace", "replacement", conditional);
// Output is a matrix with same dims as input
output.setDataType(target.getOutput().getDataType());
@@ -756,6 +756,19 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
output.setValueType(ValueType.FP64);
output.setDimensions(target.getOutput().getDim1(),
target.getOutput().getDim2());
}
+
+ private void checkScalarParam(String group, String param, boolean
conditional) {
+ Expression eparam = getVarParam(param);
+ if( eparam==null ) {
+ raiseValidateError("Named parameter '"+param+"'
missing. Please specify the "+group+" pattern.",
+ conditional,
LanguageErrorCodes.INVALID_PARAMETERS);
+ }
+ else if( eparam.getOutput().getDataType() != DataType.SCALAR ){
+ raiseValidateError(group + " parameter '"+param+"' is
of type '"
+ + eparam.getOutput().getDataType()+"'. Please,
specify a scalar "+param+".",
+ conditional,
LanguageErrorCodes.INVALID_PARAMETERS);
+ }
+ }
private void validateOrder(DataIdentifier output, boolean conditional) {
//check existence and correctness of arguments
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 73939117ce..cabf4887a6 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -487,7 +487,19 @@ public class FederationUtils {
throw new DMLRuntimeException(ex);
}
}
-
+
+ public static boolean aggBooleanScalar(Future<FederatedResponse>[] tmp)
{
+ boolean ret = false;
+ try {
+ for( Future<FederatedResponse> fr : tmp )
+ ret |=
((ScalarObject)fr.get().getData()[0]).getBooleanValue();
+ }
+ catch (Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ return ret;
+ }
+
public static MatrixBlock aggMatrix(AggregateUnaryOperator aop,
Future<FederatedResponse>[] ffr, FederationMap map) {
if (aop.isRowAggregate() && map.getType() == FType.ROW)
return bind(ffr, false);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 1dc7b068b8..07ce7d620b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -222,6 +222,7 @@ public class CPInstructionParser extends InstructionParser {
// Parameterized Builtin Functions
String2CPInstructionType.put( "autoDiff" ,
CPType.ParameterizedBuiltin);
+ String2CPInstructionType.put( "contains",
CPType.ParameterizedBuiltin);
String2CPInstructionType.put("paramserv",
CPType.ParameterizedBuiltin);
String2CPInstructionType.put( "nvlist",
CPType.ParameterizedBuiltin);
String2CPInstructionType.put( "cdf",
CPType.ParameterizedBuiltin);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 773153d6d4..06e68a63d5 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -275,7 +275,8 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "isinf", SPType.Unary);
// Parameterized Builtin Functions
- String2SPInstructionType.put( "autoDiff" ,
SPType.ParameterizedBuiltin);
+ String2SPInstructionType.put( "autoDiff",
SPType.ParameterizedBuiltin);
+ String2SPInstructionType.put( "contains",
SPType.ParameterizedBuiltin);
String2SPInstructionType.put( "groupedagg",
SPType.ParameterizedBuiltin);
String2SPInstructionType.put( "mapgroupedagg",
SPType.ParameterizedBuiltin);
String2SPInstructionType.put( "rmempty",
SPType.ParameterizedBuiltin);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index baf4f25139..d3c88fd5ff 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -138,13 +138,14 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
}
else if(opcode.equalsIgnoreCase("rmempty") ||
opcode.equalsIgnoreCase("replace") ||
opcode.equalsIgnoreCase("rexpand") ||
opcode.equalsIgnoreCase("lowertri") ||
- opcode.equalsIgnoreCase("uppertri")) {
+ opcode.equalsIgnoreCase("uppertri") ) {
func =
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
return new ParameterizedBuiltinCPInstruction(new
SimpleOperator(func), paramsMap, out, opcode, str);
}
- else if(opcode.equals("transformapply") ||
opcode.equals("transformdecode") ||
- opcode.equals("transformcolmap") ||
opcode.equals("transformmeta") || opcode.equals("tokenize") ||
- opcode.equals("toString") || opcode.equals("nvlist") ||
opcode.equals("autoDiff")) {
+ else if(opcode.equals("transformapply") ||
opcode.equals("transformdecode")
+ || opcode.equalsIgnoreCase("contains") ||
opcode.equals("transformcolmap")
+ || opcode.equals("transformmeta") ||
opcode.equals("tokenize")
+ || opcode.equals("toString") || opcode.equals("nvlist")
|| opcode.equals("autoDiff")) {
return new ParameterizedBuiltinCPInstruction(null,
paramsMap, out, opcode, str);
}
else if("paramserv".equals(opcode)) {
@@ -235,6 +236,14 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
ec.releaseMatrixInput(params.get("select"));
}
}
+ else if(opcode.equalsIgnoreCase("contains")) {
+ String varName = params.get("target");
+ MatrixBlock target = ec.getMatrixInput(varName);
+ double pattern =
Double.parseDouble(params.get("pattern"));
+ boolean ret = target.containsValue(pattern);
+ ec.releaseMatrixInput(varName);
+ ec.setScalarOutput(output.getName(), new
BooleanObject(ret));
+ }
else if(opcode.equalsIgnoreCase("replace")) {
if(ec.isFrameObject(params.get("target"))){
FrameBlock target =
ec.getFrameInput(params.get("target"));
@@ -255,7 +264,6 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
ec.setMatrixOutput(output.getName(),
ret);
targetObj.release();
}
-
}
else if(opcode.equals("lowertri") || opcode.equals("uppertri"))
{
MatrixBlock target =
ec.getMatrixInput(params.get("target"));
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index 12f2e597ef..7654b92ecc 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -58,9 +58,11 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
+import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.BooleanObject;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import
org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
@@ -85,8 +87,8 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
protected final HashMap<String, String> params;
private static final String[] PARAM_BUILTINS = new String[]{
- "replace", "rmempty", "lowertri", "uppertri",
"transformdecode", "transformapply", "tokenize"};
-
+ "contains", "replace", "rmempty", "lowertri", "uppertri",
+ "transformdecode", "transformapply", "tokenize"};
protected ParameterizedBuiltinFEDInstruction(Operator op,
HashMap<String, String> paramsMap, CPOperand out,
String opcode, String istr) {
@@ -110,7 +112,8 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
ValueFunction func =
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
return new ParameterizedBuiltinFEDInstruction(new
SimpleOperator(func), paramsMap, out, opcode, str);
}
- else if(opcode.equals("transformapply") ||
opcode.equals("transformdecode") || opcode.equals("tokenize")) {
+ else if(opcode.equals("transformapply") ||
opcode.equals("transformdecode")
+ || opcode.equals("tokenize") ||
opcode.equals("contains") ) {
return new ParameterizedBuiltinFEDInstruction(null,
paramsMap, out, opcode, str);
}
else {
@@ -140,15 +143,17 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
return paramMap;
}
- public static ParameterizedBuiltinFEDInstruction
parseInstruction(ParameterizedBuiltinCPInstruction inst,
- ExecutionContext ec) {
+ public static ParameterizedBuiltinFEDInstruction parseInstruction(
+ ParameterizedBuiltinCPInstruction inst, ExecutionContext ec)
+ {
if(ArrayUtils.contains(PARAM_BUILTINS, inst.getOpcode()) &&
inst.getTarget(ec).isFederatedExcept(FType.BROADCAST))
return
ParameterizedBuiltinFEDInstruction.parseInstruction(inst);
return null;
}
- public static ParameterizedBuiltinFEDInstruction
parseInstruction(ParameterizedBuiltinSPInstruction inst,
- ExecutionContext ec) {
+ public static ParameterizedBuiltinFEDInstruction parseInstruction(
+ ParameterizedBuiltinSPInstruction inst, ExecutionContext ec)
+ {
if( inst.getOpcode().equalsIgnoreCase("replace") &&
inst.getTarget(ec).isFederatedExcept(FType.BROADCAST) )
return
ParameterizedBuiltinFEDInstruction.parseInstruction(inst);
return null;
@@ -167,13 +172,21 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
@Override
public void processInstruction(ExecutionContext ec) {
String opcode = getOpcode();
- if(opcode.equalsIgnoreCase("replace")) {
+ if(opcode.equalsIgnoreCase("contains")) {
+ FederationMap map = getTarget(ec).getFedMapping();
+ FederatedRequest fr1 =
FederationUtils.callInstruction(instString,
+ output, new CPOperand[] {getTargetOperand()},
new long[] {map.getID()});
+ FederatedRequest fr2 = new
FederatedRequest(RequestType.GET_VAR, fr1.getID());
+ Future<FederatedResponse>[] tmp = map.execute(getTID(),
fr1, fr2);
+ boolean ret = FederationUtils.aggBooleanScalar(tmp);
+ ec.setVariable(output.getName(), new
BooleanObject(ret));
+ }
+ else if(opcode.equalsIgnoreCase("replace")) {
// similar to unary federated instructions, get
federated input
// execute instruction, and derive federated output
matrix
CacheableData<?> mo = getTarget(ec);
- FederatedRequest fr1 =
FederationUtils.callInstruction(instString,
- output,
- new CPOperand[] {getTargetOperand()},
+ FederatedRequest fr1 = FederationUtils.callInstruction(
+ instString, output, new CPOperand[]
{getTargetOperand()},
new long[] {mo.getFedMapping().getID()});
Future<FederatedResponse>[] ret =
mo.getFedMapping().execute(getTID(), true, fr1);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index e5b8fea07a..cc3ce6d03f 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -47,6 +47,7 @@ import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.BooleanObject;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
@@ -169,6 +170,9 @@ public class ParameterizedBuiltinSPInstruction extends
ComputationSPInstruction
func =
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
return new
ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out,
opcode, str);
}
+ else if(opcode.equalsIgnoreCase("contains")) {
+ return new
ParameterizedBuiltinSPInstruction(null, paramsMap, out, opcode, str);
+ }
else {
throw new DMLRuntimeException("Unknown opcode
(" + opcode + ") for ParameterizedBuiltin Instruction.");
}
@@ -363,6 +367,17 @@ public class ParameterizedBuiltinSPInstruction extends
ComputationSPInstruction
sec.setMatrixOutput(output.getName(), out);
}
}
+ else if(opcode.equalsIgnoreCase("contains")) {
+ JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec
+
.getBinaryMatrixBlockRDDHandleForVariable(params.get("target"));
+
+ // execute contains operation
+ double pattern =
Double.parseDouble(params.get("pattern"));
+ Double ret = in1.values() //num blocks containing
pattern
+ .map(new RDDContainsFunction(pattern))
+ .reduce((a,b) -> a+b);
+ ec.setScalarOutput(output.getName(), new
BooleanObject(ret>0));
+ }
else if(opcode.equalsIgnoreCase("replace")) {
if(sec.isFrameObject(params.get("target"))){
params.get("target");
@@ -395,7 +410,6 @@ public class ParameterizedBuiltinSPInstruction extends
ComputationSPInstruction
mcIn.getBlocksize(),
(pattern != 0 && replacement != 0) ?
mcIn.getNonZeros() : -1);
}
-
}
else if(opcode.equalsIgnoreCase("lowertri") ||
opcode.equalsIgnoreCase("uppertri")) {
JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec
@@ -566,6 +580,20 @@ public class ParameterizedBuiltinSPInstruction extends
ComputationSPInstruction
return arg0.replaceOperations(new MatrixBlock(),
_pattern, _replacement);
}
}
+
+ public static class RDDContainsFunction implements
Function<MatrixBlock, Double> {
+ private static final long serialVersionUID =
6576713401901671659L;
+ private final double _pattern;
+
+ public RDDContainsFunction(double pattern) {
+ _pattern = pattern;
+ }
+
+ @Override
+ public Double call(MatrixBlock arg0) {
+ return arg0.containsValue(_pattern) ? 1d : 0d;
+ }
+ }
public static class RDDFrameReplaceFunction implements
Function<FrameBlock, FrameBlock>{
private static final long serialVersionUID =
6576713401901671660L;
diff --git
a/src/test/java/org/apache/sysds/test/functions/aggregate/ContainsTest.java
b/src/test/java/org/apache/sysds/test/functions/aggregate/ContainsTest.java
new file mode 100644
index 0000000000..4ea6d917d8
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/aggregate/ContainsTest.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.aggregate;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+
+public class ContainsTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME = "Contains";
+
+ private final static String TEST_DIR = "functions/aggregate/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
AggregateInfTest.class.getSimpleName() + "/";
+
+ private final static int rows = 1205;
+ private final static int cols = 1179;
+ private final static double sparsity1 = 0.1;
+ private final static double sparsity2 = 0.7;
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME,
+ new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new
String[]{"B"}));
+ }
+
+
+ @Test
+ public void testNaNTrueDenseCP() {
+ runContainsTest(Double.NaN, true, false, ExecType.CP);
+ }
+
+ @Test
+ public void testNaNFalseDenseCP() {
+ runContainsTest(Double.NaN, false, false, ExecType.CP);
+ }
+
+ @Test
+ public void testNaNTrueSparseCP() {
+ runContainsTest(Double.NaN, true, true, ExecType.CP);
+ }
+
+ @Test
+ public void testNaNFalseSpaseCP() {
+ runContainsTest(Double.NaN, false, true, ExecType.CP);
+ }
+
+ @Test
+ public void testInfTrueDenseCP() {
+ runContainsTest(Double.POSITIVE_INFINITY, true, false,
ExecType.CP);
+ }
+
+ @Test
+ public void testInfFalseDenseCP() {
+ runContainsTest(Double.POSITIVE_INFINITY, false, false,
ExecType.CP);
+ }
+
+ @Test
+ public void testInfTrueSparseCP() {
+ runContainsTest(Double.POSITIVE_INFINITY, true, true,
ExecType.CP);
+ }
+
+ @Test
+ public void testInfFalseSpaseCP() {
+ runContainsTest(Double.POSITIVE_INFINITY, false, true,
ExecType.CP);
+ }
+
+ @Test
+ public void testNaNTrueDenseSpark() {
+ runContainsTest(Double.NaN, true, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testNaNFalseDenseSpark() {
+ runContainsTest(Double.NaN, false, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testNaNTrueSparseSpark() {
+ runContainsTest(Double.NaN, true, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testNaNFalseSpaseSpark() {
+ runContainsTest(Double.NaN, false, true, ExecType.SPARK);
+ }
+
+ private void runContainsTest( double check, boolean expected, boolean
sparse, ExecType instType)
+ {
+ ExecMode oldMode = setExecMode(instType);
+
+ try
+ {
+ double sparsity = (sparse) ? sparsity1 : sparsity2;
+ getAndLoadTestConfiguration(TEST_NAME);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{"-args",
+ input("A"), String.valueOf(check), output("B")
};
+
+ //generate actual dataset
+ double[][] A = getRandomMatrix(rows, cols, -0.05, 1,
sparsity, 7);
+ A[7][7] = expected ? check : 7;
+ writeInputMatrixWithMTD("A", A, false);
+
+ //run test
+ runTest(true, false, null, -1);
+ boolean ret = TestUtils.readDMLBoolean(output("B"));
+ Assert.assertEquals(expected, ret);
+ if( instType == ExecType.CP ) {
+
Assert.assertEquals(Statistics.getNoOfCompiledSPInst(), 1); //reblock
+
Assert.assertEquals(Statistics.getNoOfExecutedSPInst(), 0);
+ }
+ }
+ finally {
+ resetExecMode(oldMode);
+ }
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
index 7abb1a8125..a3e91ef37d 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
@@ -123,7 +123,7 @@ public class FederatedLogRegTest extends AutomatedTestBase {
Assert.assertTrue("contains fed_ba+*",
heavyHittersContainsString("fed_ba+*"));
Assert.assertTrue("contains fed_uar",
heavyHittersContainsString("fed_uark+", "fed_uarsqk+"));
Assert.assertTrue("contains fed_mmchain & r'",
heavyHittersContainsString("fed_mmchain", "fed_r'"));
- Assert.assertTrue("contains fed_isnan",
heavyHittersContainsString("fed_isnan"));
+ Assert.assertTrue("contains fed_contains",
heavyHittersContainsString("fed_contains"));
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git a/src/test/scripts/functions/aggregate/Contains.dml
b/src/test/scripts/functions/aggregate/Contains.dml
new file mode 100644
index 0000000000..0576b6e1cd
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/Contains.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+ret = contains(target=A, pattern=$2);
+write(ret, $3);
\ No newline at end of file