This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 69d3358 [SYSTEMDS-3143] Frame rm empty instruction
69d3358 is described below
commit 69d33589de1258ba68b3652dfb9a5884adea213e
Author: OlgaOvcharenko <[email protected]>
AuthorDate: Wed Sep 22 00:17:50 2021 +0200
[SYSTEMDS-3143] Frame rm empty instruction
This commit adds the removeEmpty instruction for frames; previously, this
instruction was only supported on matrices.
Closes #1397
---
.../ParameterizedBuiltinFunctionExpression.java | 16 +-
.../cp/ParameterizedBuiltinCPInstruction.java | 34 ++-
.../fed/ParameterizedBuiltinFEDInstruction.java | 251 ++++++++++++++++++++-
.../sysds/runtime/matrix/data/FrameBlock.java | 110 +++++++++
.../apache/sysds/runtime/util/UtilFunctions.java | 7 +
.../test/component/frame/FrameRemoveEmptyTest.java | 195 ++++++++++++++++
src/test/scripts/functions/frame/removeEmpty1.dml | 30 +++
7 files changed, 623 insertions(+), 20 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 444ab54..442d1e6 100644
---
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -584,7 +584,8 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
+ Arrays.toString(invalid.toArray(new
String[0])), false);
//check existence and correctness of arguments
- checkTargetParam(getVarParam("target"), conditional);
+ Expression target = getVarParam("target");
+ checkEmptyTargetParam(target, conditional);
Expression margin = getVarParam("margin");
if( margin==null ){
@@ -608,8 +609,11 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
_varParams.put("empty.return", new
BooleanIdentifier(true));
// Output is a matrix with unknown dims
- output.setDataType(DataType.MATRIX);
- output.setValueType(ValueType.FP64);
+ output.setDataType(target.getOutput().getDataType());
+ if(target.getOutput().getDataType() == DataType.FRAME)
+ output.setValueType(ValueType.STRING);
+ else
+ output.setValueType(ValueType.FP64);
output.setDimensions(-1, -1);
}
@@ -726,6 +730,12 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
raiseValidateError("Input matrix 'target' is of type
'"+target.getOutput().getDataType()
+"'. Please specify the input matrix.",
conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
+
+	/**
+	 * Validates the 'target' parameter of removeEmpty, which may be a matrix or a frame.
+	 *
+	 * @param target expression bound to the named parameter 'target', or null if it is missing
+	 * @param conditional forwarded unchanged to raiseValidateError
+	 */
+	private void checkEmptyTargetParam(Expression target, boolean conditional) {
+		if( target==null )
+			raiseValidateError("Named parameter 'target' missing. Please specify the input matrix or frame.",
+				conditional, LanguageErrorCodes.INVALID_PARAMETERS);
+		else if( target.getOutput().getDataType() != DataType.MATRIX
+			&& target.getOutput().getDataType() != DataType.FRAME )
+			// the output data type is derived from the target, so only matrices and frames are valid
+			raiseValidateError("Input 'target' is of type '"+target.getOutput().getDataType()
+				+"'. Please specify the input matrix or frame.",
+				conditional, LanguageErrorCodes.INVALID_PARAMETERS);
+	}
private void checkOptionalBooleanParam(Expression param, String name,
boolean conditional) {
if( param!=null && (!param.getOutput().getDataType().isScalar()
|| param.getOutput().getValueType() != ValueType.BOOLEAN) ){
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index ccced11..233154a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -208,21 +208,31 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
String margin = params.get("margin");
if(!(margin.equals("rows") || margin.equals("cols")))
throw new DMLRuntimeException("Unspupported
margin identifier '" + margin + "'.");
+ if(ec.isFrameObject(params.get("target"))) {
+ FrameBlock target =
ec.getFrameInput(params.get("target"));
+ MatrixBlock select =
params.containsKey("select") ? ec.getMatrixInput(params.get("select")) : null;
- // acquire locks
- MatrixBlock target =
ec.getMatrixInput(params.get("target"));
- MatrixBlock select = params.containsKey("select") ?
ec.getMatrixInput(params.get("select")) : null;
+ boolean emptyReturn =
Boolean.parseBoolean(params.get("empty.return").toLowerCase());
+ FrameBlock soresBlock =
target.removeEmptyOperations(margin.equals("rows"), emptyReturn, select);
+ ec.setFrameOutput(output.getName(), soresBlock);
+ ec.releaseFrameInput(params.get("target"));
+ if(params.containsKey("select"))
+
ec.releaseMatrixInput(params.get("select"));
+ } else {
+ // acquire locks
+ MatrixBlock target =
ec.getMatrixInput(params.get("target"));
+ MatrixBlock select =
params.containsKey("select") ? ec.getMatrixInput(params.get("select")) : null;
- // compute the result
- boolean emptyReturn =
Boolean.parseBoolean(params.get("empty.return").toLowerCase());
- MatrixBlock soresBlock = target
- .removeEmptyOperations(new MatrixBlock(),
margin.equals("rows"), emptyReturn, select);
+ // compute the result
+ boolean emptyReturn =
Boolean.parseBoolean(params.get("empty.return").toLowerCase());
+ MatrixBlock soresBlock =
target.removeEmptyOperations(new MatrixBlock(), margin.equals("rows"),
emptyReturn, select);
- // release locks
- ec.setMatrixOutput(output.getName(), soresBlock);
- ec.releaseMatrixInput(params.get("target"));
- if(params.containsKey("select"))
- ec.releaseMatrixInput(params.get("select"));
+ // release locks
+ ec.setMatrixOutput(output.getName(),
soresBlock);
+ ec.releaseMatrixInput(params.get("target"));
+ if(params.containsKey("select"))
+
ec.releaseMatrixInput(params.get("select"));
+ }
}
else if(opcode.equalsIgnoreCase("replace")) {
if(ec.isFrameObject(params.get("target"))){
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index 02d34a1..a6c5ef1 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -28,10 +28,12 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Future;
+import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.zip.Adler32;
import java.util.zip.Checksum;
+import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
@@ -73,6 +75,7 @@ import
org.apache.sysds.runtime.transform.decode.DecoderFactory;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.EncoderOmit;
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
+import org.apache.sysds.runtime.util.UtilFunctions;
public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstruction {
protected final LinkedHashMap<String, String> params;
@@ -151,7 +154,10 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
}
else if(opcode.equals("rmempty"))
- rmempty(ec);
+ if (getTarget(ec) instanceof FrameObject)
+ rmemptyFrame(ec);
+ else
+ rmemptyMatrix(ec);
else if(opcode.equals("lowertri") || opcode.equals("uppertri"))
triangle(ec, opcode);
else if(opcode.equalsIgnoreCase("transformdecode"))
@@ -329,7 +335,170 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
}
}
- private void rmempty(ExecutionContext ec) {
+	/**
+	 * Executes removeEmpty on a federated frame target.
+	 *
+	 * If no 'select' vector is given and the margin is not aligned with the federated
+	 * partitioning (margin=rows on a column-partitioned frame, or margin=cols on a
+	 * row-partitioned frame), a global 0/1 selection vector is first assembled from the
+	 * per-worker non-empty indicators (GetFrameVector) and broadcast to all workers.
+	 * Afterwards the federated ranges, schema and data characteristics of the output are
+	 * recomputed from the per-worker result sizes (GetFrameCharacteristics).
+	 *
+	 * @param ec execution context holding the target, the optional select input and the output
+	 */
+	private void rmemptyFrame(ExecutionContext ec) {
+		String margin = params.get("margin");
+		if(!(margin.equals("rows") || margin.equals("cols")))
+			throw new DMLRuntimeException("Unsupported margin identifier '" + margin + "'.");
+
+		FrameObject mo = (FrameObject) getTarget(ec);
+		MatrixObject select = params.containsKey("select") ? ec.getMatrixObject(params.get("select")) : null;
+		FrameObject out = ec.getFrameObject(output);
+
+		boolean marginRow = margin.equals("rows");
+		boolean isNotAligned = ((marginRow && mo.getFedMapping().getType().isColPartitioned()) ||
+			(!marginRow && mo.getFedMapping().getType().isRowPartitioned()));
+
+		MatrixBlock s = new MatrixBlock();
+		// name of the coordinator-side temporary select variable (only set if it is created below)
+		String tmpSelectName = null;
+		if(select == null && isNotAligned) {
+			List<MatrixBlock> colSums = new ArrayList<>();
+			mo.getFedMapping().forEachParallel((range, data) -> {
+				try {
+					FederatedResponse response = data
+						.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+							new GetFrameVector(data.getVarID(), marginRow)))
+						.get();
+
+					if(!response.isSuccessful())
+						response.throwExceptionFromResponse();
+					MatrixBlock vector = (MatrixBlock) response.getData()[0];
+					synchronized(colSums) {
+						colSums.add(vector);
+					}
+				}
+				catch(Exception e) {
+					throw new DMLRuntimeException(e);
+				}
+				return null;
+			});
+			// sum the per-worker indicators: a row/column is kept if any worker reported it non-empty
+			BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
+			BinaryOperator greater = InstructionUtils.parseBinaryOperator(">");
+			s = colSums.get(0);
+			for(int i = 1; i < colSums.size(); i++)
+				s = s.binaryOperationsInPlace(plus, colSums.get(i));
+			s = s.binaryOperationsInPlace(greater, new MatrixBlock(s.getNumRows(), s.getNumColumns(), 0.0));
+			select = ExecutionContext.createMatrixObject(s);
+
+			long varID = FederationUtils.getNextFedDataID();
+			tmpSelectName = String.valueOf(varID);
+			ec.setVariable(tmpSelectName, select);
+			params.put("select", tmpSelectName);
+			// rewrite the instruction string so that it carries the new select=<varID> operand
+			String[] oldString = InstructionUtils.getInstructionParts(instString);
+			String[] newString = new String[oldString.length + 1];
+			newString[2] = "select=" + varID;
+			System.arraycopy(oldString, 0, newString, 0, 2);
+			System.arraycopy(oldString, 2, newString, 3, newString.length - 3);
+			instString = instString.replace(InstructionUtils.concatOperands(oldString),
+				InstructionUtils.concatOperands(newString));
+		}
+
+		if(select == null) {
+			FederatedRequest fr1 = FederationUtils.callInstruction(instString,
+				output,
+				new CPOperand[] {getTargetOperand()},
+				new long[] {mo.getFedMapping().getID()});
+			mo.getFedMapping().execute(getTID(), true, fr1);
+			out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
+		}
+		else if(!isNotAligned) {
+			// construct commands: broadcast (sliced), fed rmempty, clean broadcast
+			FederatedRequest[] fr1 = mo.getFedMapping().broadcastSliced(select, !marginRow);
+			FederatedRequest fr2 = FederationUtils.callInstruction(instString,
+				output,
+				new CPOperand[] {getTargetOperand(),
+					new CPOperand(params.get("select"), ValueType.FP64, DataType.MATRIX)},
+				new long[] {mo.getFedMapping().getID(), fr1[0].getID()});
+			// release the broadcast select slices on the workers once rmempty has consumed them
+			FederatedRequest fr3 = mo.getFedMapping().cleanup(getTID(), fr1[0].getID());
+
+			// execute federated operations and set output
+			mo.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+			out.setFedMapping(mo.getFedMapping().copyWithNewID(fr2.getID()));
+		}
+		else {
+			// construct commands: broadcast, fed rmempty, clean broadcast
+			FederatedRequest fr1 = mo.getFedMapping().broadcast(select);
+			FederatedRequest fr2 = FederationUtils.callInstruction(instString,
+				output,
+				new CPOperand[] {getTargetOperand(),
+					new CPOperand(params.get("select"), ValueType.FP64, DataType.MATRIX)},
+				new long[] {mo.getFedMapping().getID(), fr1.getID()});
+			// release the broadcast select vector on the workers once rmempty has consumed it
+			FederatedRequest fr3 = mo.getFedMapping().cleanup(getTID(), fr1.getID());
+
+			// execute federated operations and set output
+			mo.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+			out.setFedMapping(mo.getFedMapping().copyWithNewID(fr2.getID()));
+		}
+
+		// collect the per-worker output sizes {rows, cols} and schemas
+		Map<FederatedRange, int[]> dcs = new HashMap<>();
+		Map<FederatedRange, ValueType[]> finalSchema = new HashMap<>();
+		out.getFedMapping().forEachParallel((range, data) -> {
+			try {
+				FederatedResponse response = data
+					.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+						new GetFrameCharacteristics(data.getVarID())))
+					.get();
+
+				if(!response.isSuccessful())
+					response.throwExceptionFromResponse();
+				Object[] ret = response.getData();
+				int[] subRangeCharacteristics = new int[] {(int) ret[0], (int) ret[1]};
+				ValueType[] schema = (ValueType[]) ret[2];
+				synchronized(dcs) {
+					dcs.put(range, subRangeCharacteristics);
+				}
+				synchronized(finalSchema) {
+					finalSchema.put(range, schema);
+				}
+			}
+			catch(Exception e) {
+				throw new DMLRuntimeException(e);
+			}
+			return null;
+		});
+
+		out.getDataCharacteristics().set(mo.getDataCharacteristics());
+		int len = marginRow ? mo.getSchema().length : (int) (mo.isFederated(FederationMap.FType.ROW) ?
+			s.getNonZeros() : finalSchema.values().stream().mapToInt(e -> e.length).sum());
+		ValueType[] schema = new ValueType[len];
+		int pos = 0;
+		for(int i = 0; i < mo.getFedMapping().getFederatedRanges().length; i++) {
+			FederatedRange federatedRange = new FederatedRange(out.getFedMapping().getFederatedRanges()[i]);
+
+			if(marginRow) {
+				schema = mo.getSchema();
+			}
+			else if(mo.isFederated(FederationMap.FType.ROW)) {
+				schema = finalSchema.get(federatedRange);
+			}
+			else {
+				// column-partitioned cols-margin: concatenate the per-worker schemas
+				ValueType[] tmp = finalSchema.get(federatedRange);
+				System.arraycopy(tmp, 0, schema, pos, tmp.length);
+				pos += tmp.length;
+			}
+
+			int[] newRange = dcs.get(federatedRange);
+
+			// shift each range so that it starts where the previous one ends in the compacted output
+			out.getFedMapping().getFederatedRanges()[i].setBeginDim(0,
+				(out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] == 0 ||
+					i == 0) ? 0 : out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[0]);
+
+			out.getFedMapping().getFederatedRanges()[i].setEndDim(0,
+				out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
+
+			out.getFedMapping().getFederatedRanges()[i].setBeginDim(1,
+				(out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] == 0 ||
+					i == 0) ? 0 : out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[1]);
+
+			out.getFedMapping().getFederatedRanges()[i].setEndDim(1,
+				out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
+		}
+
+		out.setSchema(schema);
+		out.getDataCharacteristics().set(out.getFedMapping().getMaxIndexInRange(0),
+			out.getFedMapping().getMaxIndexInRange(1),
+			(int) mo.getBlocksize());
+
+		// drop the temporary select variable registered in the symbol table above
+		if(tmpSelectName != null)
+			ec.removeVariable(tmpSelectName);
+	}
+
+
+ private void rmemptyMatrix(ExecutionContext ec) {
String margin = params.get("margin");
if(!(margin.equals("rows") || margin.equals("cols")))
throw new DMLRuntimeException("Unsupported margin
identifier '" + margin + "'.");
@@ -428,7 +597,7 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
try {
FederatedResponse response = data
.executeFederatedOperation(new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
- new
GetDataCharacteristics(data.getVarID())))
+ new
GetMatrixCharacteristics(data.getVarID())))
.get();
if(!response.isSuccessful())
@@ -724,11 +893,11 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
}
}
- private static class GetDataCharacteristics extends FederatedUDF {
+ private static class GetMatrixCharacteristics extends FederatedUDF {
private static final long serialVersionUID =
578461386177730925L;
- public GetDataCharacteristics(long varID) {
+ public GetMatrixCharacteristics(long varID) {
super(new long[] {varID});
}
@@ -746,6 +915,28 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
}
}
+	/**
+	 * Federated UDF reporting the size and schema of a worker-local frame.
+	 * The response data is {numRows, numColumns, schema}.
+	 */
+	private static class GetFrameCharacteristics extends FederatedUDF {
+
+		private static final long serialVersionUID = 578461386177730925L;
+
+		public GetFrameCharacteristics(long varID) {
+			super(new long[] {varID});
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			FrameBlock fb = ((FrameObject) data[0]).acquireReadAndRelease();
+			// clamp negative (unknown) dimensions to 0
+			int r = Math.max(fb.getNumRows(), 0);
+			int c = Math.max(fb.getNumColumns(), 0);
+			return new FederatedResponse(ResponseType.SUCCESS, new Object[] {r, c, fb.getSchema()});
+		}
+
+		@Override
+		public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+			return null;
+		}
+	}
+
private static class GetVector extends FederatedUDF {
private static final long serialVersionUID =
-1003061862215703768L;
@@ -779,4 +970,54 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
return null;
}
}
+
+	/**
+	 * Federated UDF computing a 0/1 indicator of the non-empty rows (nrow x 1 result) or
+	 * non-empty columns (1 x ncol result) of a worker-local frame. A cell counts as empty
+	 * if it is null or its numeric value is 0 or NaN.
+	 */
+	private static class GetFrameVector extends FederatedUDF {
+
+		private static final long serialVersionUID = -1003061862215703768L;
+		private final boolean _marginRow;
+
+		public GetFrameVector(long varID, boolean marginRow) {
+			super(new long[] {varID});
+			_marginRow = marginRow;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			FrameBlock fb = ((FrameObject) data[0]).acquireReadAndRelease();
+			int rows = fb.getNumRows();
+			int cols = fb.getNumColumns();
+			ValueType[] schema = fb.getSchema();
+
+			MatrixBlock ret = _marginRow ? new MatrixBlock(rows, 1, 0.0) : new MatrixBlock(1, cols, 0.0);
+
+			if(_marginRow) {
+				for(int i = 0; i < rows; i++) {
+					boolean isEmpty = true;
+					for(int j = 0; j < cols && isEmpty; j++)
+						isEmpty = isEmptyCell(schema[j], fb.get(i, j));
+					if(!isEmpty)
+						ret.setValue(i, 0, 1.0);
+				}
+			}
+			else {
+				for(int j = 0; j < cols; j++) {
+					boolean isEmpty = true;
+					for(int i = 0; i < rows && isEmpty; i++)
+						isEmpty = isEmptyCell(schema[j], fb.get(i, j));
+					if(!isEmpty)
+						ret.setValue(0, j, 1.0);
+				}
+			}
+
+			return new FederatedResponse(ResponseType.SUCCESS, ret);
+		}
+
+		/**
+		 * Returns true if the cell is null or converts to 0 or NaN. NaN is tested explicitly
+		 * because an equality-based array lookup never matches NaN.
+		 */
+		private static boolean isEmptyCell(ValueType vt, Object value) {
+			if(value == null)
+				return true;
+			double d = UtilFunctions.objectToDoubleSafe(vt, value);
+			return d == 0 || Double.isNaN(d);
+		}
+
+		@Override
+		public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+			return null;
+		}
+	}
}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index 86bbdab..64f6e80 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -37,6 +37,8 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;
+import java.util.function.IntFunction;
+import java.util.stream.IntStream;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.NotImplementedException;
@@ -46,6 +48,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.Writable;
import org.apache.sysds.api.DMLException;
+import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.CodegenUtils;
@@ -598,6 +601,31 @@ public class FrameBlock implements CacheBlock,
Externalizable {
_msize = -1;
}
+ public void appendColumn(ValueType vt, Array col) {
+ switch (vt) {
+ case STRING:
+ appendColumn(((StringArray) col).get());
+ break;
+ case BOOLEAN:
+ appendColumn(((BooleanArray) col).get());
+ break;
+ case INT32:
+ appendColumn(((IntegerArray) col).get());
+ break;
+ case INT64:
+ appendColumn(((LongArray) col).get());
+ break;
+ case FP32:
+ appendColumn(((FloatArray) col).get());
+ break;
+ case FP64:
+ appendColumn(((DoubleArray) col).get());
+ break;
+ default:
+ throw new RuntimeException("Unsupported value
type: " + vt);
+ }
+ }
+
public Object getColumnData(int c) {
switch(_schema[c]) {
case STRING: return ((StringArray)_coldata[c])._data;
@@ -1640,10 +1668,13 @@ public class FrameBlock implements CacheBlock,
Externalizable {
_data = data;
_size = _data.length;
}
+ /** Returns the backing array of this column directly (no copy); NOTE(review): its length may differ from the tracked _size — confirm before relying on it. */
+ public String[] get() { return _data; }
+
@Override
public String get(int index) {
return _data[index];
}
+
@Override
public void set(int index, String value) {
_data[index] = value;
@@ -1705,10 +1736,13 @@ public class FrameBlock implements CacheBlock,
Externalizable {
_data = data;
_size = _data.length;
}
+ /** Returns the backing array of this column directly (no copy); NOTE(review): its length may differ from the tracked _size — confirm before relying on it. */
+ public boolean[] get() { return _data; }
+
@Override
public Boolean get(int index) {
return _data[index];
}
+
@Override
public void set(int index, Boolean value) {
_data[index] = (value!=null) ? value : false;
@@ -1772,6 +1806,7 @@ public class FrameBlock implements CacheBlock,
Externalizable {
_data = data;
_size = _data.length;
}
+ /** Returns the backing array of this column directly (no copy); NOTE(review): its length may differ from the tracked _size — confirm before relying on it. */
+ public long[] get() { return _data; }
@Override
public Long get(int index) {
return _data[index];
@@ -1839,6 +1874,7 @@ public class FrameBlock implements CacheBlock,
Externalizable {
_data = data;
_size = _data.length;
}
+ /** Returns the backing array of this column directly (no copy); NOTE(review): its length may differ from the tracked _size — confirm before relying on it. */
+ public int[] get() { return _data; }
@Override
public Integer get(int index) {
@@ -1906,6 +1942,8 @@ public class FrameBlock implements CacheBlock,
Externalizable {
_data = data;
_size = _data.length;
}
+ /** Returns the backing array of this column directly (no copy); NOTE(review): its length may differ from the tracked _size — confirm before relying on it. */
+ public float[] get() { return _data; }
+
@Override
public Float get(int index) {
return _data[index];
@@ -1972,6 +2010,7 @@ public class FrameBlock implements CacheBlock,
Externalizable {
_data = data;
_size = _data.length;
}
+ /** Returns the backing array of this column directly (no copy); NOTE(review): its length may differ from the tracked _size — confirm before relying on it. */
+ public double[] get() { return _data; }
@Override
public Double get(int index) {
return _data[index];
@@ -2473,6 +2512,77 @@ public class FrameBlock implements CacheBlock,
Externalizable {
return ret;
}
+	/**
+	 * Removes empty rows or columns from this frame, analogous to removeEmpty on matrices.
+	 * A cell is empty if it is null or its numeric value is 0 or NaN, and a row (column)
+	 * is empty if all of its cells are empty. If {@code select} is given, it decides
+	 * instead: row i (column j) is kept iff select(i,0) (select(0,j)) is non-zero.
+	 *
+	 * @param rows true to remove rows, false to remove columns
+	 * @param emptyReturn if true and nothing remains, return a single all-null STRING row
+	 *                    (column) instead of an empty frame
+	 * @param select optional selection vector (nrow x 1 for rows, 1 x ncol for cols), may be null
+	 * @return a new frame containing the kept rows or columns
+	 */
+	public FrameBlock removeEmptyOperations(boolean rows, boolean emptyReturn, MatrixBlock select) {
+		if( rows )
+			return removeEmptyRows(select, emptyReturn);
+		else //cols
+			return removeEmptyColumns(select, emptyReturn);
+	}
+
+	/** Copies the non-empty (or selected) rows into a new frame with the same schema and column names. */
+	private FrameBlock removeEmptyRows(MatrixBlock select, boolean emptyReturn) {
+		final int ncol = getNumColumns();
+		FrameBlock ret = new FrameBlock(_schema, _colnames);
+
+		for(int i = 0; i < _numRows; i++) {
+			boolean keep = false;
+			if(select != null)
+				keep = select.getValue(i, 0) != 0;
+			else
+				for(int j = 0; j < ncol && !keep; j++)
+					keep = !isEmptyValue(_schema[j], _coldata[j].get(i));
+
+			if(keep) {
+				// read cells in place; no per-cell column copies are needed
+				Object[] row = new Object[ncol];
+				for(int j = 0; j < ncol; j++)
+					row[j] = _coldata[j].get(i);
+				ret.appendRow(row);
+			}
+		}
+
+		if(ret.getNumRows() == 0 && emptyReturn) {
+			// a single row of ncol null cells (new String[1][ncol] is all null)
+			ValueType[] schema = new ValueType[ncol];
+			Arrays.fill(schema, ValueType.STRING);
+			return new FrameBlock(schema, new String[1][ncol]);
+		}
+
+		return ret;
+	}
+
+	/** Builds a new frame with only the non-empty (or selected) columns, keeping their schema, names and metadata. */
+	private FrameBlock removeEmptyColumns(MatrixBlock select, boolean emptyReturn) {
+		final int ncol = getNumColumns();
+		int[] keep = new int[ncol];
+		int nsel = 0;
+		for(int j = 0; j < ncol; j++) {
+			boolean selected = false;
+			if(select != null)
+				selected = select.getValue(0, j) != 0;
+			else
+				for(int i = 0; i < _numRows && !selected; i++)
+					selected = !isEmptyValue(_schema[j], _coldata[j].get(i));
+			if(selected)
+				keep[nsel++] = j;
+		}
+
+		if(nsel == 0 && emptyReturn)
+			// a single STRING column of _numRows null cells
+			return new FrameBlock(new ValueType[]{ValueType.STRING}, new String[_numRows][1]);
+
+		ValueType[] schema = new ValueType[nsel];
+		String[] names = (_colnames != null) ? new String[nsel] : null;
+		ColumnMetadata[] meta = new ColumnMetadata[nsel];
+		for(int k = 0; k < nsel; k++) {
+			int j = keep[k];
+			schema[k] = _schema[j];
+			if(names != null)
+				names[k] = _colnames[j];
+			meta[k] = new ColumnMetadata(_colmeta[j]);
+		}
+
+		FrameBlock ret = new FrameBlock(schema, names);
+		if(nsel > 0) {
+			for(int i = 0; i < _numRows; i++) {
+				Object[] row = new Object[nsel];
+				for(int k = 0; k < nsel; k++)
+					row[k] = _coldata[keep[k]].get(i);
+				ret.appendRow(row);
+			}
+		}
+
+		ret._colmeta = meta;
+		ret.setColumnMetadata(ret._colmeta);
+		return ret;
+	}
+
+	/**
+	 * Returns true if a cell counts as empty for removeEmpty: null, 0, or NaN. NaN is tested
+	 * explicitly because an equality-based array lookup never matches NaN.
+	 */
+	private static boolean isEmptyValue(ValueType vt, Object value) {
+		if(value == null)
+			return true;
+		double d = UtilFunctions.objectToDoubleSafe(vt, value);
+		return d == 0 || Double.isNaN(d);
+	}
+
@Override
public String toString(){
StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index ee6d913..ee64bc8 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.util;
import org.apache.commons.lang.ArrayUtils;
+import org.apache.commons.lang3.math.NumberUtils;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -487,6 +488,12 @@ public class UtilFunctions {
}
}
+ public static double objectToDoubleSafe(ValueType vt, Object in) {
+ if(vt == ValueType.STRING && !NumberUtils.isCreatable((String)
in)) {
+ return 1.0;
+ } else return objectToDouble(vt, in);
+ }
+
public static double objectToDouble(ValueType vt, Object in) {
if( in == null ) return Double.NaN;
switch( vt ) {
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/FrameRemoveEmptyTest.java
b/src/test/java/org/apache/sysds/test/component/frame/FrameRemoveEmptyTest.java
new file mode 100644
index 0000000..d3bbdc4
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/frame/FrameRemoveEmptyTest.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.frame;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.unary.matrix.RemoveEmptyTest;
+import org.junit.Ignore;
+import org.junit.Test;
+
+public class FrameRemoveEmptyTest extends AutomatedTestBase {
+ private final static String TEST_NAME1 = "removeEmpty1";
+ private final static String TEST_DIR = "functions/frame/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
RemoveEmptyTest.class.getSimpleName() + "/";
+
+ private final static int _rows = 10;
+ private final static int _cols = 6;
+
+ private final static double _sparsityDense = 0.7;
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"V"}));
+ }
+
+ @Test
+ public void testRemoveEmptyRowsDenseCP() {
+ runTestRemoveEmpty(TEST_NAME1, "rows", Types.ExecType.CP,
false);
+ }
+
+ @Test
+ public void testRemoveEmptyRowsSparseCP() {
+ runTestRemoveEmpty(TEST_NAME1, "cols", Types.ExecType.CP, true);
+ }
+
+ @Test
+ @Ignore
+ public void testRemoveEmptyRowsDenseSP() {
+ runTestRemoveEmpty(TEST_NAME1, "rows", Types.ExecType.SPARK,
false);
+ }
+
+ @Test
+ @Ignore
+ public void testRemoveEmptyRowsSparseSP() {
+ runTestRemoveEmpty(TEST_NAME1, "rows", Types.ExecType.SPARK,
true);
+ }
+
+ private void runTestRemoveEmpty(String testname, String margin,
Types.ExecType et, boolean bSelectIndex) {
+ // rtplatform for MR
+ Types.ExecMode platformOld = rtplatform;
+ switch(et) {
+ case SPARK:
+ rtplatform = Types.ExecMode.SPARK;
+ break;
+ default:
+ rtplatform = Types.ExecMode.HYBRID;
+ break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if(rtplatform == Types.ExecMode.SPARK)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ try {
+ // register test configuration
+ TestConfiguration config =
getTestConfiguration(testname);
+ config.addVariable("rows", _rows);
+ config.addVariable("cols", _cols);
+ loadTestConfiguration(config);
+
+ /* This is for running the junit test the new way,
i.e., construct the arguments directly */
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[] {"-explain", "-args",
input("V"), margin, output("V")};
+
+ MatrixBlock in = createInputMatrix(margin, _rows,
_cols, _sparsityDense, bSelectIndex);
+
+ runTest(true, false, null, -1);
+ double[][] outArray =
TestUtils.convertHashMapToDoubleArray(readDMLMatrixFromOutputDir("V"));
+ MatrixBlock out = new MatrixBlock(outArray.length,
outArray[0].length, false);
+ out.init(outArray, outArray.length, outArray[0].length);
+
+ MatrixBlock in2 = new MatrixBlock(_rows, _cols + 2,
0.0);
+ in2.copy(0, _rows - 1, 0, _cols - 1, in, true);
+ in2.copy(0, (_rows / 2) - 1, _cols, _cols + 1, new
MatrixBlock(_rows / 2, 2, 1.0), true);
+ MatrixBlock expected = in2.removeEmptyOperations(new
MatrixBlock(), margin.equals("rows"), false, null);
+ expected = expected.slice(0, expected.getNumRows() - 1,
0, expected.getNumColumns() - 3);
+
+ TestUtils.compareMatrices(expected, out, 0);
+ }
+ finally {
+ // reset platform for additional tests
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+
+ private MatrixBlock createInputMatrix(String margin, int rows, int
cols, double sparsity, boolean bSelectIndex) {
+ int rowsp = -1, colsp = -1;
+ if(margin.equals("rows")) {
+ rowsp = rows / 2;
+ colsp = cols;
+ }
+ else {
+ rowsp = rows;
+ colsp = cols / 2;
+ }
+
+ // long seed = System.nanoTime();
+ double[][] V = getRandomMatrix(rows, cols, 0, 1, sparsity, 7);
+ double[][] Vp = new double[rowsp][colsp];
+ double[][] Ix = null;
+ int innz = 0, vnnz = 0;
+
+ // clear out every other row/column
+ if(margin.equals("rows")) {
+ Ix = new double[rows][1];
+ for(int i = 0; i < rows; i++) {
+ boolean clear = i % 2 != 0;
+ if(clear) {
+ for(int j = 0; j < cols; j++)
+ V[i][j] = 0;
+ Ix[i][0] = 0;
+ }
+ else {
+ boolean bNonEmpty = false;
+ for(int j = 0; j < cols; j++) {
+ Vp[i / 2][j] = V[i][j];
+ bNonEmpty |= (V[i][j] != 0.0) ?
true : false;
+ vnnz += (V[i][j] == 0.0) ? 0 :
1;
+ }
+ Ix[i][0] = (bNonEmpty) ? 1 : 0;
+ innz += Ix[i][0];
+ }
+ }
+ }
+ else {
+ Ix = new double[1][cols];
+ for(int j = 0; j < cols; j++) {
+ boolean clear = j % 2 != 0;
+ if(clear) {
+ for(int i = 0; i < rows; i++)
+ V[i][j] = 0;
+ Ix[0][j] = 0;
+ }
+ else {
+ boolean bNonEmpty = false;
+ for(int i = 0; i < rows; i++) {
+ Vp[i][j / 2] = V[i][j];
+ bNonEmpty |= (V[i][j] != 0.0) ?
true : false;
+ vnnz += (V[i][j] == 0.0) ? 0 :
1;
+ }
+ Ix[0][j] = (bNonEmpty) ? 1 : 0;
+ innz += Ix[0][j];
+ }
+ }
+ }
+
+ MatrixCharacteristics imc = new
MatrixCharacteristics(margin.equals("rows") ? rows : 1,
+ margin.equals("rows") ? 1 : cols, 1000, innz);
+ MatrixCharacteristics vmc = new MatrixCharacteristics(rows,
cols, 1000, vnnz);
+
+ MatrixBlock in = new MatrixBlock(rows, cols, false);
+ in.init(V, _rows, _cols);
+
+ writeInputMatrixWithMTD("V", V, false, vmc); // always text
+ writeExpectedMatrix("V", Vp);
+ if(bSelectIndex)
+ writeInputMatrixWithMTD("I", Ix, false, imc);
+
+ return in;
+ }
+}
diff --git a/src/test/scripts/functions/frame/removeEmpty1.dml
b/src/test/scripts/functions/frame/removeEmpty1.dml
new file mode 100644
index 0000000..696880e
--- /dev/null
+++ b/src/test/scripts/functions/frame/removeEmpty1.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Test script for removeEmpty on frames.
+# Arguments: $1 = input matrix, $2 = margin ("rows" or "cols"), $3 = output path.
+A = read($1, naStrings= ["NA", "null"," ","NaN", "nan", "", "?", "99999"])
+# Two helper columns (BOOLEAN, STRING): the first half of the rows holds non-empty values
+# ("TRUE", "abc"), the second half holds values expected to be treated as empty ("FALSE", "0.0").
+B = frame(data=["TRUE", "abc"], rows=nrow(A) / 2, cols=2, schema=["BOOLEAN", "STRING"])
+C = frame(data=["FALSE", "0.0"], rows=nrow(A) / 2, cols=2, schema=["BOOLEAN", "STRING"])
+D = rbind(B, C)
+V = cbind(as.frame(A), D)
+Vp = removeEmpty(target=V, margin=$2)
+# Drop the two trailing helper columns and write the numeric part for comparison.
+X = as.matrix(Vp[, 1:(ncol(Vp)-2)])
+write(X, $3);