This is an automated email from the ASF dual-hosted git repository. ssiddiqi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new d8a996a [SYSTEMDS-3194] Builtins for detecting and correcting value swaps 1. fixInvalidLengths() to detect and correct value swaps in columns of the same schema using string length information 2. valueSwap() - to detect and correct value swaps in columns of different schemas using detectSchema() and a distance function. d8a996a is described below commit d8a996a482cc8f16c72386f3127f1db1427764f7 Author: Shafaq Siddiqi <shafaq.sidd...@tugraz.at> AuthorDate: Tue Nov 2 12:17:56 2021 +0100 [SYSTEMDS-3194] Builtins for detecting and correcting value swaps 1. fixInvalidLengths() to detect and correct value swaps in columns of the same schema using string length information 2. valueSwap() - to detect and correct value swaps in columns of different schemas using detectSchema() and a distance function. --- scripts/builtin/fixInvalidLengths.dml | 66 ++++++++++ scripts/pipelines/scripts/utils.dml | 26 ++-- .../java/org/apache/sysds/common/Builtins.java | 2 + src/main/java/org/apache/sysds/common/Types.java | 6 +- .../sysds/parser/BuiltinFunctionExpression.java | 1 + .../org/apache/sysds/parser/DMLTranslator.java | 1 + .../sysds/runtime/functionobjects/Builtin.java | 3 +- .../runtime/instructions/CPInstructionParser.java | 1 + .../runtime/instructions/InstructionUtils.java | 6 +- .../runtime/instructions/SPInstructionParser.java | 1 + .../cp/BinaryFrameFrameCPInstruction.java | 6 + .../spark/BinaryFrameFrameSPInstruction.java | 26 +++- .../sysds/runtime/matrix/data/FrameBlock.java | 143 ++++++++++++++++----- .../builtin/BuiltinFixInvalidLengths.java | 82 ++++++++++++ .../test/functions/frame/FrameValueSwapTest.java | 88 +++++++++++++ .../functions/frame/fixInvalidLengthstest.dml | 46 +++++++ src/test/scripts/functions/frame/valueSwaps.dml | 41 ++++++ .../intermediates/classification/bestAcc.csv | 6 +- .../pipelines/intermediates/classification/hp.csv | 6 +- .../pipelines/intermediates/classification/pip.csv | 2 +- 20 
files changed, 508 insertions(+), 51 deletions(-) diff --git a/scripts/builtin/fixInvalidLengths.dml b/scripts/builtin/fixInvalidLengths.dml new file mode 100644 index 0000000..8b0ec8a --- /dev/null +++ b/scripts/builtin/fixInvalidLengths.dml @@ -0,0 +1,66 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +s_fixInvalidLengths = function(Frame[Unknown] F1, Double ql = 0.05, Double qu = 0.99) +return (Frame[Unknown] out, Matrix[Double] M) +{ + + length = map(F1, "x -> x.length()") + length = as.matrix(length) + M = getInvalidsMask(length, ql, qu) + # # # check if mask vector has 1 in more than one column + # # # this indicates that two values are being swapped and can be fixed + rowCount = rowSums(M) > 1 + if(sum(rowCount) > 0) + { + countTotalSwaps = sum(rowCount) + # # get the row index for swapping + rowIds = rowCount * seq(1, nrow(rowCount)) + rowIds = removeEmpty(target=rowIds, margin="rows") + colIds = M * t(seq(1, ncol(M))) + for(i in 1:countTotalSwaps) + { + rowIdx = as.scalar(rowIds[i, 1]) + colIdx = removeEmpty(target = colIds[rowIdx], margin="cols") + id1 = as.scalar(colIdx[1, 1]) + id2 = as.scalar(colIdx[1, 2]) + tmp = F1[rowIdx, id1] + F1[rowIdx, id1] = F1[rowIdx, id2] + F1[rowIdx, id2] = tmp + # # remove the mask for fixed entries + M[rowIdx, id1] = 0 + M[rowIdx, id2] = 0 + } + } + M = replace(target = M, pattern = 1, replacement = NaN) + out = F1 +} + +getInvalidsMask = function(Matrix[Double] X, Double ql = 0.05, Double qu = 0.99) +return (Matrix[Double] Y) { + + Y = matrix(0, nrow(X), ncol(X)) + parfor(i in 1:ncol(X), check=0) { + q1 = quantile(X[,i], ql) + q2 = quantile(X[,i], qu) + Y[, i] = ( X[, i] < q1 | X[, i] > q2) + } +} diff --git a/scripts/pipelines/scripts/utils.dml b/scripts/pipelines/scripts/utils.dml index 09c681a..40d3d5b 100644 --- a/scripts/pipelines/scripts/utils.dml +++ b/scripts/pipelines/scripts/utils.dml @@ -149,21 +149,31 @@ return(Boolean validForResources) ###################################### stringProcessing = function(Frame[Unknown] data, Matrix[Double] mask, Frame[String] schema, Boolean CorrectTypos, List[Unknown] ctx = list(prefix="--")) -return(Frame[Unknown] processedData) -{ +return(Frame[Unknown] processedData, Matrix[Double] M) +{ + M = mask prefix 
= as.scalar(ctx["prefix"]); - - # step 1 drop invalid types + # step 1 fix swap values + print(prefix+" value swap fixing"); + data = valueSwap(data, schema) + + # step 2 fix invalid lengths + print(prefix+" fixing invalid lengths between 5th and 95th quantile"); + q0 = 0.05 + q1 = 0.95 + [data, M] = fixInvalidLengths(data, q0, q1) + + # step 3 drop invalid types print(prefix+" drop values with type mismatch"); data = dropInvalidType(data, schema) - print("dropped invalids") - # step 2 do the case transformations + # step 4 do the case transformations print(prefix+" convert strings to lower case"); for(i in 1:ncol(mask)) if(as.scalar(schema[1,i]) == "STRING") data[, i] = map(data[, i], "x -> x.toLowerCase()") - + + # step 5 typo correction if(CorrectTypos) { # recode data to get null mask @@ -183,7 +193,7 @@ return(Frame[Unknown] processedData) if(as.scalar(schema[1,i]) == "STRING") data[, i] = correctTypos(data[, i], nullMask[, i], 0.2, 0.9, FALSE, TRUE, FALSE); } - + # step 6 porter stemming on all features print(prefix+" porter-stemming on all features"); data = map(data, "x -> PorterStemmer.stem(x)") diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index a5ad588..6b94b55 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -122,6 +122,7 @@ public enum Builtins { EXECUTE_PIPELINE("executePipeline", true), EXP("exp", false), EVAL("eval", false), + FIX_INVALID_LENGTHS("fixInvalidLengths", true), FF_TRAIN("ffTrain", true), FF_PREDICT("ffPredict", true), FLOOR("floor", false), @@ -267,6 +268,7 @@ public enum Builtins { TYPEOF("typeof", false), UNIVAR("univar", true), VAR("var", false), + VALUE_SWAP("valueSwap", false), VECTOR_TO_CSV("vectorToCsv", true), WINSORIZE("winsorize", true, false), //TODO parameterize w/ prob, min/max val XGBOOST("xgboost", true), diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java index 4e2cef7..b5d1330 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -308,8 +308,8 @@ public class Types GREATEREQUAL(true), INTDIV(true), INTERQUANTILE(false), IQM(false), LESS(true), LESSEQUAL(true), LOG(true), MAP(false), MAX(true), MEDIAN(false), MIN(true), MINUS(true), MODULUS(true), MOMENT(false), MULT(true), NOTEQUAL(true), OR(true), - PLUS(true), POW(true), PRINT(false), QUANTILE(false), SOLVE(false), RBIND(false), - XOR(true), + PLUS(true), POW(true), PRINT(false), QUANTILE(false), SOLVE(false), + RBIND(false), VALUE_SWAP(false), XOR(true), //fused ML-specific operators for performance MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=)) LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5) @@ -354,6 +354,7 @@ public class Types case BITWSHIFTR: return "bitwShiftR"; case DROP_INVALID_TYPE: return "dropInvalidType"; case DROP_INVALID_LENGTH: return "dropInvalidLength"; + case VALUE_SWAP: return "valueSwap"; case MAP: return "_map"; default: return name().toLowerCase(); } @@ -388,6 +389,7 @@ public class Types case "bitwShiftR": return BITWSHIFTR; case "dropInvalidType": return DROP_INVALID_TYPE; case "dropInvalidLength": return DROP_INVALID_LENGTH; + case "valueSwap": return VALUE_SWAP; case "map": return MAP; default: return valueOf(opcode.toUpperCase()); } diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index ea58c02..87b1a23 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -1544,6 +1544,7 @@ public class BuiltinFunctionExpression extends DataIdentifier break; case DROP_INVALID_TYPE: + case VALUE_SWAP: checkNumParameters(2); checkMatrixFrameParam(getFirstExpr()); checkMatrixFrameParam(getSecondExpr()); 
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index 430fe7f..50e2b09 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2536,6 +2536,7 @@ public class DMLTranslator break; case DROP_INVALID_TYPE: case DROP_INVALID_LENGTH: + case VALUE_SWAP: case MAP: currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.valueOf(source.getOpCode().name()), expr, expr2); diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java index 45b6c33..4f423c2 100644 --- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java +++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java @@ -50,7 +50,7 @@ public class Builtin extends ValueFunction public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST, - TYPEOF, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, DROP_INVALID_LENGTH, MAP, + TYPEOF, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, DROP_INVALID_LENGTH, VALUE_SWAP, MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX} @@ -109,6 +109,7 @@ public class Builtin extends ValueFunction String2BuiltinCode.put( "dropInvalidType", BuiltinCode.DROP_INVALID_TYPE); String2BuiltinCode.put( "dropInvalidLength", BuiltinCode.DROP_INVALID_LENGTH); String2BuiltinCode.put( "_map", BuiltinCode.MAP); + String2BuiltinCode.put( "valueSwap", BuiltinCode.VALUE_SWAP); } private Builtin(BuiltinCode bf) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index d207a6c..236f3e0 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -157,6 +157,7 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "min" , CPType.Binary); String2CPInstructionType.put( "dropInvalidType" , CPType.Binary); String2CPInstructionType.put( "dropInvalidLength" , CPType.Binary); + String2CPInstructionType.put( "valueSwap" , CPType.Binary); String2CPInstructionType.put( "_map" , CPType.Binary); // _map represents the operation map String2CPInstructionType.put( "nmax", CPType.BuiltinNary); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java index 9991edf..a13768d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java @@ -605,6 +605,8 @@ public class InstructionUtils return new BinaryOperator(Builtin.getBuiltinFnObject("dropInvalidType")); else if( opcode.equalsIgnoreCase("dropInvalidLength")) return new BinaryOperator(Builtin.getBuiltinFnObject("dropInvalidLength")); + else if( opcode.equalsIgnoreCase("valueSwap")) + return new BinaryOperator(Builtin.getBuiltinFnObject("valueSwap")); throw new RuntimeException("Unknown binary opcode " + opcode); } @@ -840,7 +842,9 @@ public class InstructionUtils return new BinaryOperator(Builtin.getBuiltinFnObject("min")); else if ( opcode.equalsIgnoreCase("dropInvalidLength") || opcode.equalsIgnoreCase("mapdropInvalidLength") ) return new BinaryOperator(Builtin.getBuiltinFnObject("dropInvalidLength")); - + else if ( opcode.equalsIgnoreCase("valueSwap") || opcode.equalsIgnoreCase("mapValueSwap") ) + return new 
BinaryOperator(Builtin.getBuiltinFnObject("valueSwap")); + throw new DMLRuntimeException("Unknown binary opcode " + opcode); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java index a02490f..85abf32 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java @@ -182,6 +182,7 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( "map-*", SPType.Binary); String2SPInstructionType.put( "dropInvalidType", SPType.Binary); String2SPInstructionType.put( "mapdropInvalidLength", SPType.Binary); + String2SPInstructionType.put( "valueSwap", SPType.Binary); String2SPInstructionType.put( "_map", SPType.Binary); // _map refers to the operation map // Relational Instruction Opcodes String2SPInstructionType.put( "==" , SPType.Binary); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java index 8bc8744..59b4591 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java @@ -43,6 +43,12 @@ public class BinaryFrameFrameCPInstruction extends BinaryCPInstruction // Attach result frame with FrameBlock associated with output_name ec.setFrameOutput(output.getName(), retBlock); } + else if(getOpcode().equals("valueSwap")) { + // Perform computation using input frames, and produce the result frame + FrameBlock retBlock = inBlock1.valueSwap(inBlock2); + // Attach result frame with FrameBlock associated with output_name + ec.setFrameOutput(output.getName(), retBlock); + } else { // Execute binary operations BinaryOperator dop = (BinaryOperator) _optr; 
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java index deb8fb4..acf5bf3 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java @@ -52,6 +52,13 @@ public class BinaryFrameFrameSPInstruction extends BinarySPInstruction { //release input frame sec.releaseFrameInput(input2.getName()); } + else if(getOpcode().equals("valueSwap")) { + // Perform computation using input frames, and produce the result frame + Broadcast<FrameBlock> fb = sec.getSparkContext().broadcast(sec.getFrameInput(input2.getName())); + out = in1.mapValues(new valueSwapBySchema(fb.getValue())); + // Attach result frame with FrameBlock associated with output_name + sec.releaseFrameInput(input2.getName()); + } else { JavaPairRDD<Long, FrameBlock> in2 = sec.getFrameBinaryBlockRDDHandleForVariable(input2.getName()); // create output frame @@ -63,14 +70,14 @@ public class BinaryFrameFrameSPInstruction extends BinarySPInstruction { //set output RDD and maintain dependencies sec.setRDDHandleForVariable(output.getName(), out); sec.addLineageRDD(output.getName(), input1.getName()); - if( !getOpcode().equals("dropInvalidType") ) + if( !getOpcode().equals("dropInvalidType") && !getOpcode().equals("valueSwap")) sec.addLineageRDD(output.getName(), input2.getName()); } private static class isCorrectbySchema implements Function<FrameBlock,FrameBlock> { private static final long serialVersionUID = 5850400295183766400L; - private FrameBlock schema_frame = null; + private FrameBlock schema_frame; public isCorrectbySchema(FrameBlock fb_name ) { schema_frame = fb_name; @@ -94,4 +101,19 @@ public class BinaryFrameFrameSPInstruction extends BinarySPInstruction { return arg0._1().binaryOperations(bop, arg0._2(), null); } } + + 
private static class valueSwapBySchema implements Function<FrameBlock,FrameBlock> { + private static final long serialVersionUID = 5850400295183766402L; + + private FrameBlock schema_frame; + + public valueSwapBySchema(FrameBlock fb_name ) { + schema_frame = fb_name; + } + + @Override + public FrameBlock call(FrameBlock arg0) throws Exception { + return arg0.valueSwap(schema_frame); + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java index d146d1d..fce1b38 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java @@ -30,13 +30,7 @@ import java.lang.ref.SoftReference; import java.lang.reflect.InvocationTargetException; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -46,6 +40,7 @@ import java.util.function.Function; import org.apache.commons.lang.ArrayUtils; import org.apache.commons.lang.NotImplementedException; +import org.apache.commons.lang.StringUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.io.Writable; @@ -1052,54 +1047,62 @@ public class FrameBlock implements CacheBlock, Externalizable { throw new DMLRuntimeException("Frame dimension mismatch "+getNumRows()+" * "+getNumColumns()+ " != "+that.getNumRows()+" * "+that.getNumColumns()); String[][] outputData = new String[getNumRows()][getNumColumns()]; - //compare output value, incl implicit type promotion if necessary - if( !(bop.fn instanceof ValueComparisonFunction) ) + if(bop.fn instanceof 
ValueComparisonFunction) { + ValueComparisonFunction vcomp = (ValueComparisonFunction) bop.fn; + out = executeValueComparisons(this, that, vcomp, outputData); + } + else throw new DMLRuntimeException("Unsupported binary operation on frames (only comparisons supported)"); - ValueComparisonFunction vcomp = (ValueComparisonFunction) bop.fn; - for (int i = 0; i < getNumColumns(); i++) { - if (getSchema()[i] == ValueType.STRING || that.getSchema()[i] == ValueType.STRING) { - for (int j = 0; j < getNumRows(); j++) { - if(checkAndSetEmpty(this, that, outputData, j, i)) + return out; + } + + private FrameBlock executeValueComparisons(FrameBlock frameBlock, FrameBlock that, ValueComparisonFunction vcomp, + String[][] outputData) { + for(int i = 0; i < getNumColumns(); i++) { + if(getSchema()[i] == ValueType.STRING || that.getSchema()[i] == ValueType.STRING) { + for(int j = 0; j < getNumRows(); j++) { + if(checkAndSetEmpty(frameBlock, that, outputData, j, i)) continue; String v1 = UtilFunctions.objectToString(get(j, i)); String v2 = UtilFunctions.objectToString(that.get(j, i)); outputData[j][i] = String.valueOf(vcomp.compare(v1, v2)); } } - else if (getSchema()[i] == ValueType.FP64 || that.getSchema()[i] == ValueType.FP64 || - getSchema()[i] == ValueType.FP32 || that.getSchema()[i] == ValueType.FP32) { - for (int j = 0; j < getNumRows(); j++) { - if(checkAndSetEmpty(this, that, outputData, j, i)) + else if(getSchema()[i] == ValueType.FP64 || that + .getSchema()[i] == ValueType.FP64 || getSchema()[i] == ValueType.FP32 || that + .getSchema()[i] == ValueType.FP32) { + for(int j = 0; j < getNumRows(); j++) { + if(checkAndSetEmpty(frameBlock, that, outputData, j, i)) continue; ScalarObject so1 = new DoubleObject(Double.parseDouble(get(j, i).toString())); ScalarObject so2 = new DoubleObject(Double.parseDouble(that.get(j, i).toString())); outputData[j][i] = String.valueOf(vcomp.compare(so1.getDoubleValue(), so2.getDoubleValue())); } } - else if (getSchema()[i] == ValueType.INT64 || 
that.getSchema()[i] == ValueType.INT64 || - getSchema()[i] == ValueType.INT32 || that.getSchema()[i] == ValueType.INT32) { - for (int j = 0; j < this.getNumRows(); j++) { - if(checkAndSetEmpty(this, that, outputData, j, i)) + else if(getSchema()[i] == ValueType.INT64 || that + .getSchema()[i] == ValueType.INT64 || getSchema()[i] == ValueType.INT32 || that + .getSchema()[i] == ValueType.INT32) { + for(int j = 0; j < this.getNumRows(); j++) { + if(checkAndSetEmpty(frameBlock, that, outputData, j, i)) continue; ScalarObject so1 = new IntObject(Integer.parseInt(get(j, i).toString())); ScalarObject so2 = new IntObject(Integer.parseInt(that.get(j, i).toString())); - outputData[j][i] = String.valueOf(vcomp.compare(so1.getLongValue(), so2.getLongValue())); + outputData[j][i] = String.valueOf(vcomp.compare(so1.getLongValue(), so2.getLongValue())); } } else { - for (int j = 0; j < getNumRows(); j++) { - if(checkAndSetEmpty(this, that, outputData, j, i)) + for(int j = 0; j < getNumRows(); j++) { + if(checkAndSetEmpty(frameBlock, that, outputData, j, i)) continue; - ScalarObject so1 = new BooleanObject( Boolean.parseBoolean(get(j, i).toString())); - ScalarObject so2 = new BooleanObject( Boolean.parseBoolean(that.get(j, i).toString())); + ScalarObject so1 = new BooleanObject(Boolean.parseBoolean(get(j, i).toString())); + ScalarObject so2 = new BooleanObject(Boolean.parseBoolean(that.get(j, i).toString())); outputData[j][i] = String.valueOf(vcomp.compare(so1.getBooleanValue(), so2.getBooleanValue())); } } } - - return new FrameBlock(UtilFunctions.nCopies(this.getNumColumns(), ValueType.BOOLEAN), outputData); + return new FrameBlock(UtilFunctions.nCopies(frameBlock.getNumColumns(), ValueType.BOOLEAN), outputData); } private static boolean checkAndSetEmpty(FrameBlock fb1, FrameBlock fb2, String[][] out, int r, int c) { @@ -2285,7 +2288,8 @@ public class FrameBlock implements CacheBlock, Externalizable { return DMVUtils.syntacticalPatternDiscovery(this, 
Double.parseDouble(arguments[0]), arguments[1]); } else if (args.contains(";")) { String[] arguments = args.split(";"); - return EMAUtils.exponentialMovingAverageImputation(this, Integer.parseInt(arguments[0]), arguments[1], Integer.parseInt(arguments[2]), Double.parseDouble(arguments[3]), Double.parseDouble(arguments[4]), Double.parseDouble(arguments[5])); + return EMAUtils.exponentialMovingAverageImputation(this, Integer.parseInt(arguments[0]), arguments[1], + Integer.parseInt(arguments[2]), Double.parseDouble(arguments[3]), Double.parseDouble(arguments[4]), Double.parseDouble(arguments[5])); } } if(lambdaExpr.contains("jaccardSim")) @@ -2293,6 +2297,85 @@ public class FrameBlock implements CacheBlock, Externalizable { return map(getCompiledFunction(lambdaExpr)); } + public FrameBlock valueSwap(FrameBlock schema) { + String[] schemaString = schema.getStringRowIterator().next(); + String dataValue2 = null; + double minSimScore = 0; + int bestIdx = 0; + // remove the precision info + for(int i = 0; i < schemaString.length; i++) + schemaString[i] = schemaString[i].replaceAll("\\d", ""); + + double[] minColLength = new double[this.getNumColumns()]; + double[] maxColLength = new double[this.getNumColumns()]; + + for(int k = 0; k < this.getNumColumns(); k++) { + String[] data = ((StringArray) this.getColumn(k))._data; + + double minLength = Arrays.stream(data).filter(Objects::nonNull).mapToDouble(String::length).min().orElse(Double.NaN); + double maxLength = Arrays.stream(data).filter(Objects::nonNull).mapToDouble(String::length).max().orElse(Double.NaN); + maxColLength[k] = maxLength; + minColLength[k] = minLength; + } + ArrayList<Integer> probColList = new ArrayList(); + for(int i = 0; i < this.getNumColumns(); i++) { + for(int j = 0; j < this.getNumRows(); j++) { + if(this.get(j, i) == null) + continue; + String dataValue = this.get(j, i).toString().trim().replace("\"", "").toLowerCase(); + ValueType dataType = isType(dataValue); + + String type = 
dataType.toString().replaceAll("\\d", ""); + // get the avergae column length + if(!dataType.toString().contains(schemaString[i]) && !(dataType == ValueType.BOOLEAN && schemaString[i] + .equals("INT")) && !(dataType == ValueType.BOOLEAN && schemaString[i].equals("FP")) && !(dataType + .toString().contains("INT") && schemaString[i].equals("FP"))) { + LOG.warn("conflict " + dataType + " " + schemaString[i] + " " + dataValue); + // check the other column with satisfy the data type of this value + for(int w = 0; w < schemaString.length; w++) { + if(schemaString[w].equals(type) && dataValue.length() > minColLength[w] && dataValue + .length() < maxColLength[w] && (w != i)) { + Object item = this.get(j, w); + String dataValueProb = (item != null) ? item.toString().trim().replace("\"", "") + .toLowerCase() : "0"; + ValueType dataTypeProb = isType(dataValueProb); + if(!dataTypeProb.toString().equals(schemaString[w])) { + bestIdx = w; + break; + } + probColList.add(w); + } + } + // if we have more than one column that is the probable match for this value then find the most + // appropriate one by using the similarity score + if(probColList.size() > 1) { + for(int w : probColList) { + int randomIndex = ThreadLocalRandom.current().nextInt(0, _numRows - 1); + Object value = this.get(randomIndex, w); + if(value != null) { + dataValue2 = value.toString(); + } + + // compute distance between sample and invalid value + double simScore = StringUtils.getLevenshteinDistance(dataValue, dataValue2); + if(simScore < minSimScore) { + minSimScore = simScore; + bestIdx = w; + } + } + } + else if(probColList.size() > 0) { + bestIdx = probColList.get(0); + } + String tmp = dataValue; + this.set(j, i, this.get(j, bestIdx)); + this.set(j, bestIdx, tmp); + } + } + } + return this; + } + public FrameBlock map (FrameMapFunction lambdaExpr) { // Prepare temporary output array String[][] output = new String[getNumRows()][getNumColumns()]; diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinFixInvalidLengths.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinFixInvalidLengths.java new file mode 100644 index 0000000..6ea48d5 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinFixInvalidLengths.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sysds.test.functions.builtin; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +public class BuiltinFixInvalidLengths extends AutomatedTestBase { + private final static String TEST_NAME = "fixInvalidLengthstest"; + private final static String TEST_DIR = "functions/frame/"; + private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinFixInvalidLengths.class.getSimpleName() + "/"; + private final static String INPUT = DATASET_DIR+"/Salaries.csv"; + + public static void init() { + TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR); + } + + public static void cleanUp() { + if (TEST_CACHE_ENABLED) { + TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR); + } + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"})); + } + @Test + public void fixInvalidTestCP() { + runFixInvalidLength(Types.ExecType.CP); + } + + @Test + public void fixInvalidTestSP() { + runFixInvalidLength(Types.ExecType.SPARK); + } + + private void runFixInvalidLength(Types.ExecType et) + { + Types.ExecMode platformOld = setExecMode(et); + + try { + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + System.out.println(fullDMLScriptName); + programArgs = new String[] {"-args", INPUT, output("B")}; + runTest(true, false, null, -1); + boolean retCondition = HDFSTool.readBooleanFromHDFSFile(output("B")); + Assert.assertEquals(true, retCondition); + + } + catch (Exception ex) { + throw new RuntimeException(ex); + } + finally { + rtplatform = platformOld; + } + } +} diff --git 
a/src/test/java/org/apache/sysds/test/functions/frame/FrameValueSwapTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameValueSwapTest.java new file mode 100644 index 0000000..bc45b4e --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameValueSwapTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.frame; + +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; + +public class FrameValueSwapTest extends AutomatedTestBase +{ + private final static String TEST_NAME = "valueSwaps"; + private final static String TEST_DIR = "functions/frame/"; + private static final String TEST_CLASS_DIR = TEST_DIR + FrameValueSwapTest.class.getSimpleName() + "/"; + + private final static String INPUT = DATASET_DIR+"/homes3/homes.csv"; + + public static void init() { + TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR); + } + + public static void cleanUp() { + if (TEST_CACHE_ENABLED) { + TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR); + } + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"})); + } + // + @Test + public void tesSwapValueTestCP() { + runValueSwapTest(ExecType.CP); + } + + // TODO fix frame comparisons in spark context + @Ignore + public void tesSwapValueTestSP() { + runValueSwapTest(ExecType.SPARK); + } + + private void runValueSwapTest(ExecType et) + { + Types.ExecMode platformOld = setExecMode(et); + + try { + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-args", INPUT, output("B")}; + runTest(true, false, null, -1); + boolean retCondition = HDFSTool.readBooleanFromHDFSFile(output("B")); + Assert.assertEquals(true, retCondition); + + } + catch (Exception ex) { + throw new RuntimeException(ex); + } + finally { + rtplatform = platformOld; + } + } +} diff --git 
a/src/test/scripts/functions/frame/fixInvalidLengthstest.dml b/src/test/scripts/functions/frame/fixInvalidLengthstest.dml new file mode 100644 index 0000000..6c6e199 --- /dev/null +++ b/src/test/scripts/functions/frame/fixInvalidLengthstest.dml @@ -0,0 +1,46 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + + +F = read($1, data_type="frame", format="csv", header=TRUE, + naStrings= ["NA", "null"," ","NaN", "nan", "", "?", "99999"]); + +# drop the first column; W (the repaired F1) is compared against this reduced F below +F = F[, 2:ncol(F)] + +F1 = F +idx = sample(nrow(F), 15) +# corrupt F1: swap columns 1 and 2 in 15 sampled rows, which fixInvalidLengths is expected to undo +for(i in 1:nrow(idx)) +{ + r = as.scalar(idx[i]) + tmp = F1[r, 1] + F1[r, 1] = F1[r, 2] + F1[r, 2] = tmp +} +q0 = 0.05 +q1 = 0.95 + +[W, M] = fixInvalidLengths(F1, q0, q1) +comp = as.matrix(W != F) +out = sum(comp) == 0 +print(out) +write(out, $2) diff --git a/src/test/scripts/functions/frame/valueSwaps.dml b/src/test/scripts/functions/frame/valueSwaps.dml new file mode 100644 index 0000000..68e948f --- /dev/null +++ b/src/test/scripts/functions/frame/valueSwaps.dml @@ -0,0 +1,41 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+# +#------------------------------------------------------------- + +# read the input frame; the boolean check result is written to $2 +F = read($1, data_type="frame", format="csv", header=TRUE, naStrings= ["NA", "null"," ","NaN", "nan", "", "?", "99999"]); + +d = detectSchema(F) # schema of the clean frame, passed to valueSwap below +idx = sample(20, 10) # 10 row indices drawn from 1..20 (assumes F has at least 20 rows) +F1 = F +# corrupt F1: swap columns 1 and 2 in every sampled row +for(i in 1:nrow(idx)) +{ + r = as.scalar(idx[i]) + tmp = F1[r, 1] + F1[r, 1] = F1[r, 2] + F1[r, 2] = tmp +} + +R = valueSwap(F1, d) +f1 = as.matrix(F == R) +result = ((ncol(F) * nrow(F)) == sum(f1)) + +write(result, $2) \ No newline at end of file diff --git a/src/test/scripts/functions/pipelines/intermediates/classification/bestAcc.csv b/src/test/scripts/functions/pipelines/intermediates/classification/bestAcc.csv index 646eef1..b0f99b2 100644 --- a/src/test/scripts/functions/pipelines/intermediates/classification/bestAcc.csv +++ b/src/test/scripts/functions/pipelines/intermediates/classification/bestAcc.csv @@ -1,3 +1,3 @@ -94.5945945945946 -94.5945945945946 -94.5945945945946 +95.4954954954955 +95.4954954954955 +95.4954954954955 diff --git a/src/test/scripts/functions/pipelines/intermediates/classification/hp.csv b/src/test/scripts/functions/pipelines/intermediates/classification/hp.csv index af4ceec..d2046a7 100644 --- a/src/test/scripts/functions/pipelines/intermediates/classification/hp.csv +++ b/src/test/scripts/functions/pipelines/intermediates/classification/hp.csv @@ -1,3 +1,3 @@ -32.0,2.0,0.029000277257674192,0.9510406998977287,0,0,0,1.0,0,0,0,0,1.0,0,0,0,2.0,0,0,0,0,0,1.0,0,2.0,0,0,0,1.0,0,0,0,2.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 -32.0,2.0,0.033677950367757156,0.9519989979315087,0,0,0,1.0,0,0,0,0,1.0,0,0,0,2.0,0,0,0,0,0,1.0,0,2.0,0,0,0,1.0,0,0,0,2.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 -36.0,3.0,3.0,2.0,1.0,0,0,0,1.0,0,0,0,0,0,1.0,0,0,0,2.0,1.0,0.6200453262235062,0,0,0,0,1.0,1.0,2.0,0,0,0,0,1.0,0,0,0,2.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
+32.0,2.0,0.021905734704206918,0.9702203388691355,0,0,0,1.0,0,0,0,0,1.0,0,0,0,2.0,0,0,0,0,0,1.0,0,2.0,0,0,0,1.0,0,0,0,2.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +32.0,2.0,0.01621980591011659,0.9590392517606071,0,0,0,1.0,0,0,0,0,1.0,0,0,0,2.0,0,0,0,0,0,1.0,0,2.0,0,0,0,1.0,0,0,0,2.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +32.0,2.0,0.013397523041969582,0.9683942733160031,0,0,0,1.0,0,0,0,0,1.0,0,0,0,2.0,0,0,0,0,0,1.0,0,2.0,0,0,0,1.0,0,0,0,2.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 diff --git a/src/test/scripts/functions/pipelines/intermediates/classification/pip.csv b/src/test/scripts/functions/pipelines/intermediates/classification/pip.csv index b88ec19..1264630 100644 --- a/src/test/scripts/functions/pipelines/intermediates/classification/pip.csv +++ b/src/test/scripts/functions/pipelines/intermediates/classification/pip.csv @@ -1,3 +1,3 @@ winsorize,imputeByMedian,wtomeklink,dummycoding winsorize,imputeByMedian,wtomeklink,dummycoding -outlierBySd,imputeByMean,abstain,dummycoding +winsorize,imputeByMedian,wtomeklink,dummycoding