This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new f4b2a2b [SYSTEMDS-2972] Initial Multi-threaded transformencode
f4b2a2b is described below
commit f4b2a2b6c974a2584a97119ade9fb8cf031250b6
Author: Lukas Erlbacher <[email protected]>
AuthorDate: Wed May 12 22:37:52 2021 +0200
[SYSTEMDS-2972] Initial Multi-threaded transformencode
This patch adds basic multithreading to the transformencode logic.
Each ColumnEncoder can be executed on a separate thread or can be
split up into even smaller subjobs which only apply to a certain row range.
Initial benchmarks with 16CPUs show up to a 50x speed improvement
in comparison to the old SystemML implementation.
Currently, this code is dormant, which means a call to transformencode
in a DML script still uses a single-threaded implementation.
Large matrices (e.g., 1000000x1000) are still not viable due to suspected
thread starvation. This will be addressed in a future PR with some form of
access partitioning (Radix/Range).
Closes #1261.
---
.../sysds/runtime/matrix/data/FrameBlock.java | 14 ++
.../sysds/runtime/matrix/data/MatrixBlock.java | 31 ++-
.../runtime/transform/encode/ColumnEncoder.java | 13 ++
.../runtime/transform/encode/ColumnEncoderBin.java | 95 +++++++--
.../transform/encode/ColumnEncoderComposite.java | 56 ++++-
.../transform/encode/ColumnEncoderDummycode.java | 36 +++-
.../transform/encode/ColumnEncoderFeatureHash.java | 39 +++-
.../transform/encode/ColumnEncoderPassThrough.java | 46 ++++-
.../transform/encode/ColumnEncoderRecode.java | 116 +++++++++--
.../transform/encode/MultiColumnEncoder.java | 225 +++++++++++++++++++--
.../apache/sysds/runtime/util/UtilFunctions.java | 7 +
.../mt/TransformFrameBuildMultithreadedTest.java | 192 ++++++++++++++++++
.../mt/TransformFrameEncodeMultithreadedTest.java | 198 ++++++++++++++++++
.../datasets/homes3/homes.tfspec_dummy_all.json | 1 +
14 files changed, 991 insertions(+), 78 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index b280903..bc4bfc3 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -651,6 +651,20 @@ public class FrameBlock implements CacheBlock,
Externalizable {
}
/**
+ * Get a row iterator over the frame where all selected fields are
+ * encoded as strings independent of their value types.
+ *
+ * @param rl lower row index
+ * @param ru upper row index
+ * @param colID columnID, 1-based
+ * @return string array iterator
+ */
+ public Iterator<String[]> getStringRowIterator(int rl, int ru, int
colID) {
+ return new StringRowIterator(rl, ru, new int[] {colID});
+ }
+
+
+ /**
* Get a row iterator over the frame where all fields are encoded
* as boxed objects according to their value types.
*
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 3e8e305..4031a67 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -643,7 +643,36 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock, Externalizab
nonZeros--;
}
}
-
+
+ /*
+ Thread save set.
+ Blocks need to be allocated, and in case of MCSR sparse, all
rows
+ that are going to be accessed need to be allocated as well.
+ */
+ public void quickSetValueThreadSafe(int r, int c, double v) {
+ if(sparse) {
+ if(!(sparseBlock instanceof SparseBlockMCSR))
+ throw new RuntimeException("Only MCSR Blocks
are supported for Multithreaded sparse set.");
+ synchronized (sparseBlock.get(r)) {
+ sparseBlock.set(r,c,v);
+ }
+ }
+ else
+ denseBlock.set(r,c,v);
+ }
+
+ public double quickGetValueThreadSafe(int r, int c) {
+ if(sparse) {
+ if(!(sparseBlock instanceof SparseBlockMCSR))
+ throw new RuntimeException("Only MCSR Blocks
are supported for Multithreaded sparse get.");
+ synchronized (sparseBlock.get(r)) {
+ return sparseBlock.get(r,c);
+ }
+ }
+ else
+ return denseBlock.get(r,c);
+ }
+
public double getValueDenseUnsafe(int r, int c) {
if(denseBlock==null)
return 0;
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index 943e0a7..47afe0e 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -25,6 +25,10 @@ import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -47,6 +51,10 @@ public abstract class ColumnEncoder implements
Externalizable, Encoder, Comparab
public abstract MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol);
+ public abstract MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk);
+
+ public abstract MatrixBlock apply(FrameBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk);
+
/**
* Indicates if this encoder is applicable, i.e, if there is a column
to encode.
*
@@ -156,6 +164,11 @@ public abstract class ColumnEncoder implements
Externalizable, Encoder, Comparab
return Integer.compare(getEncoderType(this), getEncoderType(o));
}
+ public abstract List<Callable<Object>> getPartialBuildTasks(FrameBlock
in, int blockSize);
+
+ public abstract void mergeBuildPartial(List<Future<Object>>
futurePartials, int start, int end)
+ throws ExecutionException, InterruptedException;
+
public enum EncoderType {
Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit,
MVImpute, Composite
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
index 45dbfde..0a00a05 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
@@ -19,10 +19,17 @@
package org.apache.sysds.runtime.transform.encode;
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
import org.apache.commons.lang3.tuple.MutableTriple;
import org.apache.sysds.lops.Lop;
@@ -82,15 +89,42 @@ public class ColumnEncoderBin extends ColumnEncoder {
public void build(FrameBlock in) {
if(!isApplicable())
return;
+ double[] pairMinMax = getMinMaxOfCol(in, _colID, 0, -1);
+ computeBins(pairMinMax[0], pairMinMax[1]);
+ }
+ private static double[] getMinMaxOfCol(FrameBlock in, int colID, int
startRow, int blockSize){
// derive bin boundaries from min/max per column
double min = Double.POSITIVE_INFINITY;
double max = Double.NEGATIVE_INFINITY;
- for(int i = 0; i < in.getNumRows(); i++) {
- double inVal =
UtilFunctions.objectToDouble(in.getSchema()[_colID - 1], in.get(i, _colID - 1));
+ for(int i = startRow; i < getEndIndex(in.getNumRows(),
startRow, blockSize); i++) {
+ double inVal =
UtilFunctions.objectToDouble(in.getSchema()[colID - 1], in.get(i, colID - 1));
min = Math.min(min, inVal);
max = Math.max(max, inVal);
}
+ return new double[]{min, max};
+ }
+
+ @Override
+ public List<Callable<Object>> getPartialBuildTasks(FrameBlock in, int
blockSize){
+ List<Callable<Object>> tasks = new ArrayList<>();
+ for(int i = 0; i < in.getNumRows(); i=i+blockSize)
+ tasks.add(new BinPartialBuildTask(in, _colID, i,
blockSize));
+ if(in.getNumRows() % blockSize != 0)
+ tasks.add(new BinPartialBuildTask(in, _colID,
+
in.getNumRows()-in.getNumRows()%blockSize, -1));
+ return tasks;
+ }
+
+ @Override
+ public void mergeBuildPartial(List<Future<Object>> futurePartials, int
start, int end) throws ExecutionException, InterruptedException {
+ double min = Double.POSITIVE_INFINITY;
+ double max = Double.NEGATIVE_INFINITY;
+ for(int i = start; i < end; i++){
+ double[] pairMinMax = (double[])
futurePartials.get(i).get();
+ min = Math.min(min, pairMinMax[0]);
+ max = Math.max(max, pairMinMax[1]);
+ }
computeBins(min, max);
}
@@ -116,35 +150,40 @@ public class ColumnEncoderBin extends ColumnEncoder {
if(!isApplicable())
return;
// derive bin boundaries from min/max per column
- double min = Double.POSITIVE_INFINITY;
- double max = Double.NEGATIVE_INFINITY;
- for(int i = 0; i < in.getNumRows(); i++) {
- double inVal =
UtilFunctions.objectToDouble(in.getSchema()[_colID - 1], in.get(i, _colID - 1));
- min = Math.min(min, inVal);
- max = Math.max(max, inVal);
- }
- _colMins = min;
- _colMaxs = max;
+ double[] pairMinMax = getMinMaxOfCol(in, _colID, 0 ,-1);
+ _colMins = pairMinMax[0];
+ _colMaxs = pairMinMax[1];
}
@Override
public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol)
{
- for(int i = 0; i < in.getNumRows(); i++) {
+ return apply(in, out, outputCol, 0, -1);
+ }
+
+ @Override
+ public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol) {
+ return apply(in, out, outputCol, 0, -1);
+ }
+
+ @Override
+ public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol,
int rowStart, int blk) {
+ for(int i = rowStart; i < getEndIndex(in.getNumRows(),
rowStart, blk); i++) {
double inVal =
UtilFunctions.objectToDouble(in.getSchema()[_colID - 1], in.get(i, _colID - 1));
int ix = Arrays.binarySearch(_binMaxs, inVal);
int binID = ((ix < 0) ? Math.abs(ix + 1) : ix) + 1;
- out.quickSetValue(i, outputCol, binID);
+ out.quickSetValueThreadSafe(i, outputCol, binID);
}
return out;
}
@Override
- public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol) {
- for(int i = 0; i < in.getNumRows(); i++) {
- double inVal = in.quickGetValue(i, _colID - 1);
+ public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk) {
+ int end = (blk <= 0)? in.getNumRows(): in.getNumRows() <
rowStart + blk ? in.getNumRows() : rowStart + blk;
+ for(int i = rowStart; i < end; i++) {
+ double inVal = in.quickGetValueThreadSafe(i, _colID -
1);
int ix = Arrays.binarySearch(_binMaxs, inVal);
int binID = ((ix < 0) ? Math.abs(ix + 1) : ix) + 1;
- out.quickSetValue(i, outputCol, binID);
+ out.quickSetValueThreadSafe(i, outputCol, binID);
}
return out;
}
@@ -236,4 +275,26 @@ public class ColumnEncoderBin extends ColumnEncoder {
_binMins[j] = in.readDouble();
}
}
+
+ private static class BinPartialBuildTask implements Callable<Object> {
+
+ private final FrameBlock _input;
+ private final int _blockSize;
+ private final int _startRow;
+ private final int _colID;
+
+ // if a pool is passed the task may be split up into multiple
smaller tasks.
+ protected BinPartialBuildTask(FrameBlock input, int colID, int
startRow, int blocksize){
+ _input = input;
+ _blockSize = blocksize;
+ _colID = colID;
+ _startRow = startRow;
+ }
+
+ @Override
+ public double[] call() throws Exception {
+ return getMinMaxOfCol(_input, _colID, _startRow,
_blockSize);
+ }
+ }
+
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index ae3507d..f3640f3 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -23,8 +23,13 @@ import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.Objects;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
@@ -42,6 +47,9 @@ public class ColumnEncoderComposite extends ColumnEncoder {
private List<ColumnEncoder> _columnEncoders = null;
private FrameBlock _meta = null;
+ // map to keep track of which encoder has how many build tasks
+ private Map<ColumnEncoder, Integer> _partialBuildTaskMap;
+
public ColumnEncoderComposite() {
super(-1);
}
@@ -93,6 +101,32 @@ public class ColumnEncoderComposite extends ColumnEncoder {
}
@Override
+ public List<Callable<Object>> getPartialBuildTasks(FrameBlock in, int
blockSize) {
+ List<Callable<Object>> tasks = new ArrayList<>();
+ _partialBuildTaskMap = new HashMap<>();
+ for(ColumnEncoder columnEncoder : _columnEncoders) {
+ List<Callable<Object>> _tasks =
columnEncoder.getPartialBuildTasks(in, blockSize);
+ if(_tasks != null)
+ tasks.addAll(_tasks);
+ _partialBuildTaskMap.put(columnEncoder, _tasks != null
? _tasks.size() : 0);
+ }
+ return tasks.size() == 0 ? null : tasks;
+ }
+
+ @Override
+ public void mergeBuildPartial(List<Future<Object>> futurePartials, int
start, int end)
+ throws ExecutionException, InterruptedException {
+ int endLocal;
+ for(ColumnEncoder columnEncoder : _columnEncoders) {
+ endLocal = start +
_partialBuildTaskMap.get(columnEncoder);
+ columnEncoder.mergeBuildPartial(futurePartials, start,
endLocal);
+ start = endLocal;
+ if(start >= end)
+ break;
+ }
+ }
+
+ @Override
public void prepareBuildPartial() {
for(ColumnEncoder columnEncoder : _columnEncoders)
columnEncoder.prepareBuildPartial();
@@ -106,14 +140,24 @@ public class ColumnEncoderComposite extends ColumnEncoder
{
@Override
public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol)
{
+ return apply(in, out, outputCol, 0, -1);
+ }
+
+ @Override
+ public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol) {
+ return apply(in, out, outputCol, 0, -1);
+ }
+
+ @Override
+ public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol,
int rowStart, int blk) {
try {
for(int i = 0; i < _columnEncoders.size(); i++) {
if(i == 0) {
// 1. encoder writes data into
MatrixBlock Column all others use this column for further encoding
- _columnEncoders.get(i).apply(in, out,
outputCol);
+ _columnEncoders.get(i).apply(in, out,
outputCol, rowStart, blk);
}
else {
- _columnEncoders.get(i).apply(out, out,
outputCol);
+ _columnEncoders.get(i).apply(out, out,
outputCol, rowStart, blk);
}
}
}
@@ -125,20 +169,20 @@ public class ColumnEncoderComposite extends ColumnEncoder
{
}
@Override
- public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol) {
+ public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk) {
try {
for(int i = 0; i < _columnEncoders.size(); i++) {
if(i == 0) {
// 1. encoder writes data into
MatrixBlock Column all others use this column for further encoding
- _columnEncoders.get(i).apply(in, out,
outputCol);
+ _columnEncoders.get(i).apply(in, out,
outputCol, rowStart, blk);
}
else {
- _columnEncoders.get(i).apply(out, out,
outputCol);
+ _columnEncoders.get(i).apply(out, out,
outputCol, rowStart, blk);
}
}
}
catch(Exception ex) {
- LOG.error("Failed to transform-apply frame with \n" +
this);
+ LOG.error("Failed to transform-apply matrix with \n" +
this);
throw ex;
}
return in;
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
index 5fc2883..708fead 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
@@ -19,11 +19,15 @@
package org.apache.sysds.runtime.transform.encode;
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.List;
import java.util.Objects;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
@@ -53,21 +57,43 @@ public class ColumnEncoderDummycode extends ColumnEncoder {
}
@Override
+ public List<Callable<Object>> getPartialBuildTasks(FrameBlock in, int
blockSize) {
+ // do nothing
+ return null;
+ }
+
+ @Override
+ public void mergeBuildPartial(List<Future<Object>> futurePartials, int
start, int end) {
+
+ }
+
+ @Override
public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol)
{
- throw new DMLRuntimeException("Called DummyCoder with
FrameBlock");
+ return apply(in, out, outputCol, 0, -1);
}
public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol) {
+ return apply(in, out, outputCol, 0, -1);
+ }
+
+ @Override
+ public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol,
int rowStart, int blk) {
+ throw new DMLRuntimeException("Called DummyCoder with
FrameBlock");
+ }
+
+ @Override
+ public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk) {
// Out Matrix should already be correct size!
// append dummy coded or unchanged values to output
- for(int i = 0; i < in.getNumRows(); i++) {
+ for(int i = rowStart; i < getEndIndex(in.getNumRows(),
rowStart, blk); i++) {
// Using outputCol here as index since we have a
MatrixBlock as input where dummycoding could have been
// applied in a previous encoder
- double val = in.quickGetValue(i, outputCol);
+ double val = in.quickGetValueThreadSafe(i, outputCol);
int nCol = outputCol + (int) val - 1;
- out.quickSetValue(i, nCol, 1);
+ // Setting value to 0 first in case of sparse so the
row vector does not need to be resized
if(nCol != outputCol)
- out.quickSetValue(i, outputCol, 0);
+ out.quickSetValueThreadSafe(i, outputCol, 0);
+ out.quickSetValueThreadSafe(i, nCol, 1);
}
return out;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
index cd4b272..84d09b4 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
@@ -19,9 +19,14 @@
package org.apache.sysds.runtime.transform.encode;
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
@@ -61,27 +66,49 @@ public class ColumnEncoderFeatureHash extends ColumnEncoder
{
}
@Override
+ public List<Callable<Object>> getPartialBuildTasks(FrameBlock in, int
blockSize) {
+ // do nothing
+ return null;
+ }
+
+ @Override
+ public void mergeBuildPartial(List<Future<Object>> futurePartials, int
start, int end) {
+
+ }
+
+ @Override
public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol)
{
+ return apply(in, out, outputCol, 0, -1);
+ }
+
+ @Override
+ public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol) {
+ return apply(in, out, outputCol, 0, -1);
+ }
+
+ @Override
+ public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol,
int rowStart, int blk) {
// apply feature hashing column wise
- for(int i = 0; i < in.getNumRows(); i++) {
+ for(int i = rowStart; i < getEndIndex(in.getNumRows(),
rowStart, blk); i++) {
Object okey = in.get(i, _colID - 1);
String key = (okey != null) ? okey.toString() : null;
if(key == null)
throw new DMLRuntimeException("Missing Value
encountered in input Frame for FeatureHash");
long code = getCode(key);
- out.quickSetValue(i, outputCol, (code >= 0) ? code :
Double.NaN);
+ out.quickSetValueThreadSafe(i, outputCol, (code >= 0) ?
code : Double.NaN);
}
return out;
}
@Override
- public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol) {
+ public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk) {
+ int end = (blk <= 0) ? in.getNumRows() : in.getNumRows() <
rowStart + blk ? in.getNumRows() : rowStart + blk;
// apply feature hashing column wise
- for(int i = 0; i < in.getNumRows(); i++) {
- Object okey = in.quickGetValue(i, _colID - 1);
+ for(int i = rowStart; i < end; i++) {
+ Object okey = in.quickGetValueThreadSafe(i, _colID - 1);
String key = okey.toString();
long code = getCode(key);
- out.quickSetValue(i, outputCol, (code >= 0) ? code :
Double.NaN);
+ out.quickSetValueThreadSafe(i, outputCol, (code >= 0) ?
code : Double.NaN);
}
return out;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
index 3c3cfdc..7e4a02f 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
@@ -19,6 +19,12 @@
package org.apache.sysds.runtime.transform.encode;
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -41,26 +47,48 @@ public class ColumnEncoderPassThrough extends ColumnEncoder
{
}
@Override
+ public List<Callable<Object>> getPartialBuildTasks(FrameBlock in, int
blockSize) {
+ // do nothing
+ return null;
+ }
+
+ @Override
+ public void mergeBuildPartial(List<Future<Object>> futurePartials, int
start, int end) {
+
+ }
+
+ @Override
public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol)
{
+ return apply(in, out, outputCol, 0, -1);
+ }
+
+ @Override
+ public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol) {
+ return apply(in, out, outputCol, 0, -1);
+ }
+
+ @Override
+ public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol,
int rowStart, int blk) {
int col = _colID - 1; // 1-based
ValueType vt = in.getSchema()[col];
- for(int i = 0; i < in.getNumRows(); i++) {
+ for(int i = rowStart; i < getEndIndex(in.getNumRows(),
rowStart, blk); i++) {
Object val = in.get(i, col);
- out.quickSetValue(i,
- outputCol,
- (val == null || (vt == ValueType.STRING &&
val.toString().isEmpty())) ? Double.NaN : UtilFunctions
- .objectToDouble(vt, val));
+ double v = (val == null ||
+ (vt == ValueType.STRING &&
val.toString().isEmpty()))
+ ? Double.NaN :
UtilFunctions.objectToDouble(vt, val);
+ out.quickSetValueThreadSafe(i, outputCol, v);
}
return out;
}
@Override
- public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol) {
+ public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk) {
// only transfer from in to out
+ int end = (blk <= 0) ? in.getNumRows() : in.getNumRows() <
rowStart + blk ? in.getNumRows() : rowStart + blk;
int col = _colID - 1; // 1-based
- for(int i = 0; i < in.getNumRows(); i++) {
- double val = in.quickGetValue(i, col);
- out.quickSetValue(i, outputCol, val);
+ for(int i = rowStart; i < end; i++) {
+ double val = in.quickGetValueThreadSafe(i, col);
+ out.quickSetValueThreadSafe(i, outputCol, val);
}
return out;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
index 25c4550..d15db5b 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
@@ -19,16 +19,23 @@
package org.apache.sysds.runtime.transform.encode;
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
+import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -97,11 +104,30 @@ public class ColumnEncoderRecode extends ColumnEncoder {
}
public void sortCPRecodeMaps() {
- String[] keys = _rcdMap.keySet().toArray(new String[0]);
+ sortCPRecodeMaps(_rcdMap);
+ }
+
+ private static void sortCPRecodeMaps(HashMap<String, Long> map) {
+ String[] keys = map.keySet().toArray(new String[0]);
Arrays.sort(keys);
- _rcdMap.clear();
+ map.clear();
for(String key : keys)
- putCode(_rcdMap, key);
+ putCode(map, key);
+ }
+
+ private static void makeRcdMap(FrameBlock in, HashMap<String, Long>
map, int colID, int startRow, int blk) {
+ Iterator<String[]> iter = in.getStringRowIterator(startRow,
getEndIndex(in.getNumRows(), startRow, blk), colID);
+ while(iter.hasNext()) {
+ String[] row = iter.next();
+ // probe and build column map
+ String key = row[0]; // 0 since there is only one
column in the row
+ if(key != null && !key.isEmpty() &&
!map.containsKey(key))
+ putCode(map, key);
+ }
+
+ if(SORT_RECODE_MAP) {
+ sortCPRecodeMaps(map);
+ }
}
private long lookupRCDMap(String key) {
@@ -113,18 +139,33 @@ public class ColumnEncoderRecode extends ColumnEncoder {
public void build(FrameBlock in) {
if(!isApplicable())
return;
+ makeRcdMap(in, _rcdMap, _colID, 0, in.getNumRows());
+ }
- Iterator<String[]> iter = in.getStringRowIterator(_colID);
- while(iter.hasNext()) {
- String[] row = iter.next();
- // probe and build column map
- String key = row[0]; // 0 since there is only one
column in the row
- if(key != null && !key.isEmpty() &&
!_rcdMap.containsKey(key))
- putCode(_rcdMap, key);
- }
+ @Override
+ public List<Callable<Object>> getPartialBuildTasks(FrameBlock in, int
blockSize) {
+ List<Callable<Object>> tasks = new ArrayList<>();
+ for(int i = 0; i < in.getNumRows(); i = i + blockSize)
+ tasks.add(new RecodePartialBuildTask(in, _colID, i,
blockSize));
+ if(in.getNumRows() % blockSize != 0)
+ tasks.add(new RecodePartialBuildTask(in, _colID,
in.getNumRows() - in.getNumRows() % blockSize, -1));
+ return tasks;
+ }
- if(SORT_RECODE_MAP) {
- sortCPRecodeMaps();
+ @Override
+ public void mergeBuildPartial(List<Future<Object>> futurePartials, int
start, int end)
+ throws ExecutionException, InterruptedException {
+ for(int i = start; i < end; i++) {
+ Object partial = futurePartials.get(i).get();
+ if(!(partial instanceof HashMap)) {
+ throw new DMLRuntimeException(
+ "Tried to merge " + partial.getClass()
+ " object into RecodeEncoder. " + "HashMap was expected.");
+ }
+ HashMap<?, ?> partialMap = (HashMap<?, ?>) partial;
+ partialMap.forEach((k, v) -> {
+ if(!_rcdMap.containsKey((String) k))
+ putCode(_rcdMap, (String) k);
+ });
}
}
@@ -134,7 +175,7 @@ public class ColumnEncoderRecode extends ColumnEncoder {
* @param map column map
* @param key key for the new entry
*/
- protected void putCode(HashMap<String, Long> map, String key) {
+ protected static void putCode(HashMap<String, Long> map, String key) {
map.put(key, (long) (map.size() + 1));
}
@@ -161,16 +202,28 @@ public class ColumnEncoderRecode extends ColumnEncoder {
@Override
public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol)
{
- for(int i = 0; i < in.getNumRows(); i++) {
+ return apply(in, out, outputCol, 0, -1);
+ }
+
+ @Override
+ public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol,
int rowStart, int blk) {
+ // FrameBlock is column Major and MatrixBlock row Major this
results in cache inefficiencies :(
+ for(int i = rowStart; i < getEndIndex(in.getNumRows(),
rowStart, blk); i++) {
Object okey = in.get(i, _colID - 1);
String key = (okey != null) ? okey.toString() : null;
long code = lookupRCDMap(key);
- out.quickSetValue(i, outputCol, (code >= 0) ? code :
Double.NaN);
+ out.quickSetValueThreadSafe(i, outputCol, (code >= 0) ?
code : Double.NaN);
}
return out;
}
@Override
+ public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk) {
+ throw new DMLRuntimeException(
+ "Recode called with MatrixBlock. Should not happen
since Recode is the first " + "encoder in the Stack");
+ }
+
+ @Override
public MatrixBlock apply(MatrixBlock in, MatrixBlock out, int
outputCol) {
throw new DMLRuntimeException(
"Recode called with MatrixBlock. Should not happen
since Recode is the first " + "encoder in the Stack");
@@ -215,8 +268,7 @@ public class ColumnEncoderRecode extends ColumnEncoder {
StringBuilder sb = new StringBuilder(); // for reuse
int rowID = 0;
for(Entry<String, Long> e : _rcdMap.entrySet()) {
- meta.set(rowID++,
- _colID - 1, // 1-based
+ meta.set(rowID++, _colID - 1, // 1-based
constructRecodeMapEntry(e.getKey(),
e.getValue(), sb));
}
meta.getColumnMetadata(_colID -
1).setNumDistinct(getNumDistinctValues());
@@ -271,4 +323,32 @@ public class ColumnEncoderRecode extends ColumnEncoder {
public int hashCode() {
return Objects.hash(_rcdMap);
}
+
+ public HashMap<String, Long> getRcdMap() {
+ return _rcdMap;
+ }
+
+ private static class RecodePartialBuildTask implements Callable<Object>
{
+
+ private final FrameBlock _input;
+ private final int _blockSize;
+ private final int _startRow;
+ private final int _colID;
+
+ // if a pool is passed the task may be split up into multiple
smaller tasks.
+ protected RecodePartialBuildTask(FrameBlock input, int colID,
int startRow, int blocksize) {
+ _input = input;
+ _blockSize = blocksize;
+ _colID = colID;
+ _startRow = startRow;
+ }
+
+ @Override
+ public HashMap<String, Long> call() throws Exception {
+ HashMap<String, Long> partialMap = new HashMap<>();
+ makeRcdMap(_input, partialMap, _colID, _startRow,
_blockSize);
+ return partialMap;
+ }
+ }
+
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index fa8ef6d..4a8570a 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -25,6 +25,10 @@ import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -33,13 +37,17 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.IndexRange;
public class MultiColumnEncoder implements Encoder {
protected static final Log LOG =
LogFactory.getLog(MultiColumnEncoder.class.getName());
+ private static final boolean MULTI_THREADED = true;
private List<ColumnEncoderComposite> _columnEncoders;
	// These encoders are deprecated and will be phased out soon.
private EncoderMVImpute _legacyMVImpute = null;
@@ -47,6 +55,18 @@ public class MultiColumnEncoder implements Encoder {
private int _colOffset = 0; // offset for federated Workers who are
using subrange encoders
private FrameBlock _meta = null;
	// TEMP CONSTANTS for testing only
	private int APPLY_BLOCKSIZE = 0; // temp only for testing until automatic calculation of block size
	public static int BUILD_BLOCKSIZE = 0;

	// Sets the per-task row block size used by the multi-threaded apply (<= 0 means whole input).
	public void setApplyBlockSize(int blk) {
		APPLY_BLOCKSIZE = blk;
	}

	// Sets the per-task row block size used by the multi-threaded build (<= 0 means whole input).
	// NOTE(review): this instance method writes a static field, so the value is shared
	// across all MultiColumnEncoder instances — confirm this is intended.
	public void setBuildBlockSize(int blk) {
		BUILD_BLOCKSIZE = blk;
	}
+
public MultiColumnEncoder(List<ColumnEncoderComposite> columnEncoders) {
_columnEncoders = columnEncoders;
}
@@ -56,13 +76,17 @@ public class MultiColumnEncoder implements Encoder {
}
	public MatrixBlock encode(FrameBlock in) {
		// default: single-threaded encode (k = 1)
		return encode(in, 1);
	}
+
+ public MatrixBlock encode(FrameBlock in, int k) {
MatrixBlock out;
try {
- build(in);
+ build(in, k);
_meta = getMetaData(new FrameBlock(in.getNumColumns(),
Types.ValueType.STRING));
initMetaData(_meta);
// apply meta data
- out = apply(in);
+ out = apply(in, k);
}
catch(Exception ex) {
LOG.error("Failed transform-encode frame with \n" +
this);
@@ -72,11 +96,60 @@ public class MultiColumnEncoder implements Encoder {
}
public void build(FrameBlock in) {
- for(ColumnEncoder columnEncoder : _columnEncoders)
- columnEncoder.build(in);
+ build(in, 1);
+ }
+
+ public void build(FrameBlock in, int k) {
+ if(MULTI_THREADED && k > 1) {
+ buildMT(in, k);
+ }
+ else {
+ for(ColumnEncoder columnEncoder : _columnEncoders)
+ columnEncoder.build(in);
+ }
legacyBuild(in);
}
+ private void buildMT(FrameBlock in, int k) {
+ int blockSize = BUILD_BLOCKSIZE <= 0 ? in.getNumRows() :
BUILD_BLOCKSIZE;
+ List<Callable<Integer>> tasks = new ArrayList<>();
+ ExecutorService pool = CommonThreadPool.get(k);
+ try {
+ if(blockSize != in.getNumRows()) {
+ // Partial builds and merges
+ List<List<Future<Object>>> partials = new
ArrayList<>();
+ for(ColumnEncoderComposite encoder :
_columnEncoders) {
+ List<Callable<Object>>
partialBuildTasks = encoder.getPartialBuildTasks(in, blockSize);
+ if(partialBuildTasks == null) {
+ partials.add(null);
+ continue;
+ }
+
partials.add(pool.invokeAll(partialBuildTasks));
+ }
+ for(int e = 0; e < _columnEncoders.size(); e++)
{
+ List<Future<Object>> partial =
partials.get(e);
+ if(partial == null)
+ continue;
+ tasks.add(new
ColumnMergeBuildPartialTask(_columnEncoders.get(e), partial));
+ }
+ }
+ else {
+ // building every column in one thread
+ for(ColumnEncoderComposite e : _columnEncoders)
{
+ tasks.add(new ColumnBuildTask(e, in));
+ }
+ }
+ List<Future<Integer>> rtasks = pool.invokeAll(tasks);
+ pool.shutdown();
+ for(Future<Integer> t : rtasks)
+ t.get();
+ }
+ catch(InterruptedException | ExecutionException e) {
+ LOG.error("MT Column encode failed");
+ e.printStackTrace();
+ }
+ }
+
public void legacyBuild(FrameBlock in) {
if(_legacyOmit != null)
_legacyOmit.build(in);
@@ -85,37 +158,90 @@ public class MultiColumnEncoder implements Encoder {
}
public MatrixBlock apply(FrameBlock in) {
+ return apply(in, 1);
+ }
+
+ public MatrixBlock apply(FrameBlock in, int k) {
int numCols = in.getNumColumns() + getNumExtraCols();
- MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols,
false);
- return apply(in, out, 0);
+ long estNNz = (long) in.getNumColumns() * (long)
in.getNumRows();
+ boolean sparse =
MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz);
+ MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols,
sparse, estNNz);
+ return apply(in, out, 0, k);
}
public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol)
{
+ return apply(in, out, outputCol, 1);
+ }
+
+ public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol,
int k) {
// There should be a encoder for every column
int numEncoders = getFromAll(ColumnEncoderComposite.class,
ColumnEncoder::getColID).size();
if(in.getNumColumns() != numEncoders)
throw new DMLRuntimeException("Not every column in has
a CompositeEncoder. Please make sure every column "
+ "has a encoder or slice the input
accordingly");
-
- try {
+ // Denseblock allocation since access is only on the DenseBlock
+ out.allocateBlock();
+ if(out.isInSparseFormat()) {
+ SparseBlock block = out.getSparseBlock();
+ if(!(block instanceof SparseBlockMCSR))
+ throw new RuntimeException(
+ "Transform apply currently only
supported for MCSR sparse and dense output Matrices");
+ for(int r = 0; r < out.getNumRows(); r++) {
+ // allocate all sparse rows so MT sync can be
done.
+ // should be rare that rows have only 0
+ block.allocate(r, in.getNumColumns());
+ }
+ }
+ // TODO smart checks
+ if(MULTI_THREADED && k > 1) {
+ applyMT(in, out, outputCol, k);
+ }
+ else {
int offset = outputCol;
for(ColumnEncoderComposite columnEncoder :
_columnEncoders) {
columnEncoder.apply(in, out,
columnEncoder._colID - 1 + offset);
if(columnEncoder.hasEncoder(ColumnEncoderDummycode.class))
offset +=
columnEncoder.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
}
- if(_legacyOmit != null)
- out = _legacyOmit.apply(in, out);
- if(_legacyMVImpute != null)
- out = _legacyMVImpute.apply(in, out);
- }
- catch(Exception ex) {
- LOG.error("Failed to transform-apply frame with \n" +
this);
- throw ex;
}
+ // Recomputing NNZ since we access the Dense block directly
+ // TODO set NNZ explicit count them in the encoders
+ out.recomputeNonZeros();
+ if(_legacyOmit != null)
+ out = _legacyOmit.apply(in, out);
+ if(_legacyMVImpute != null)
+ out = _legacyMVImpute.apply(in, out);
+
return out;
}
+ private void applyMT(FrameBlock in, MatrixBlock out, int outputCol, int
k) {
+ try {
+ ExecutorService pool = CommonThreadPool.get(k);
+ ArrayList<ColumnApplyTask> tasks = new ArrayList<>();
+ int offset = outputCol;
+ // TODO calculate smart blocksize
+ int blockSize = APPLY_BLOCKSIZE <= 0 ? in.getNumRows()
: APPLY_BLOCKSIZE;
+ for(ColumnEncoderComposite e : _columnEncoders) {
+ for(int i = 0; i < in.getNumRows(); i = i +
blockSize)
+ tasks.add(new ColumnApplyTask(e, in,
out, e._colID - 1 + offset, i, blockSize));
+ if(in.getNumRows() % blockSize != 0)
+ tasks.add(new ColumnApplyTask(e, in,
out, e._colID - 1 + offset,
+ in.getNumRows() -
in.getNumRows() % blockSize, -1));
+ if(e.hasEncoder(ColumnEncoderDummycode.class))
+ offset +=
e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
+ }
+ List<Future<Integer>> rtasks = pool.invokeAll(tasks);
+ pool.shutdown();
+ for(Future<Integer> t : rtasks)
+ t.get();
+ }
+ catch(InterruptedException | ExecutionException e) {
+ LOG.error("MT Column encode failed");
+ e.printStackTrace();
+ }
+ }
+
@Override
public FrameBlock getMetaData(FrameBlock meta) {
if(_meta != null)
@@ -463,4 +589,71 @@ public class MultiColumnEncoder implements Encoder {
if(_legacyMVImpute != null)
_legacyMVImpute.shiftCols(_colOffset);
}
+
+ private static class ColumnApplyTask implements Callable<Integer> {
+
+ private final ColumnEncoder _encoder;
+ private final FrameBlock _input;
+ private final MatrixBlock _out;
+ private final int _columnOut;
+ private int _rowStart = 0;
+ private int _blk = -1;
+
+ protected ColumnApplyTask(ColumnEncoder encoder, FrameBlock
input, MatrixBlock out, int columnOut) {
+ _encoder = encoder;
+ _input = input;
+ _out = out;
+ _columnOut = columnOut;
+ }
+
+ protected ColumnApplyTask(ColumnEncoder encoder, FrameBlock
input, MatrixBlock out, int columnOut, int rowStart, int blk) {
+ this(encoder, input, out, columnOut);
+ _rowStart = rowStart;
+ _blk = blk;
+ }
+
+ @Override
+ public Integer call() throws Exception {
+ _encoder.apply(_input, _out, _columnOut, _rowStart,
_blk);
+ // TODO return NNZ
+ return 1;
+ }
+ }
+
+ private static class ColumnBuildTask implements Callable<Integer> {
+
+ private final ColumnEncoder _encoder;
+ private final FrameBlock _input;
+
+ // if a pool is passed the task may be split up into multiple
smaller tasks.
+ protected ColumnBuildTask(ColumnEncoder encoder, FrameBlock
input) {
+ _encoder = encoder;
+ _input = input;
+ }
+
+ @Override
+ public Integer call() throws Exception {
+ _encoder.build(_input);
+ return 1;
+ }
+ }
+
+ private static class ColumnMergeBuildPartialTask implements
Callable<Integer> {
+
+ private final ColumnEncoderComposite _encoder;
+ private final List<Future<Object>> _partials;
+
+ // if a pool is passed the task may be split up into multiple
smaller tasks.
+ protected ColumnMergeBuildPartialTask(ColumnEncoderComposite
encoder, List<Future<Object>> partials) {
+ _encoder = encoder;
+ _partials = partials;
+ }
+
+ @Override
+ public Integer call() throws Exception {
+ _encoder.mergeBuildPartial(_partials, 0,
_partials.size());
+ return 1;
+ }
+ }
+
}
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index f038994..c0507ce 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -983,4 +983,11 @@ public class UtilFunctions {
}
return vt;
}
+
+
+ public static int getEndIndex(int arrayLength, int startIndex, int
blockSize){
+ return (blockSize <= 0)? arrayLength: Math.min(arrayLength,
startIndex + blockSize);
+ }
+
+
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameBuildMultithreadedTest.java
b/src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameBuildMultithreadedTest.java
new file mode 100644
index 0000000..f40a150
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameBuildMultithreadedTest.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform.mt;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
+import org.apache.sysds.runtime.io.FrameReaderFactory;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.List;
+
+import static org.junit.Assert.*;
+
+public class TransformFrameBuildMultithreadedTest extends AutomatedTestBase {
+ private final static String TEST_NAME1 =
"TransformFrameBuildMultithreadedTest";
+ private final static String TEST_DIR = "functions/transform/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
TransformFrameBuildMultithreadedTest.class.getSimpleName() + "/";
+
+ // dataset and transform tasks without missing values
+ private final static String DATASET1 = "homes3/homes.csv";
+ private final static String SPEC1 = "homes3/homes.tfspec_recode.json";
+ private final static String SPEC1b = "homes3/homes.tfspec_recode2.json";
+ private final static String SPEC2 = "homes3/homes.tfspec_dummy.json";
+ private final static String SPEC2b = "homes3/homes.tfspec_dummy2.json";
+ private final static String SPEC3 = "homes3/homes.tfspec_bin.json"; //
recode
+ private final static String SPEC3b = "homes3/homes.tfspec_bin2.json";
// recode
+ private final static String SPEC6 =
"homes3/homes.tfspec_recode_dummy.json";
+ private final static String SPEC6b =
"homes3/homes.tfspec_recode_dummy2.json";
+ private final static String SPEC7 =
"homes3/homes.tfspec_binDummy.json"; // recode+dummy
+ private final static String SPEC7b =
"homes3/homes.tfspec_binDummy2.json"; // recode+dummy
+ private final static String SPEC8 = "homes3/homes.tfspec_hash.json";
+ private final static String SPEC8b = "homes3/homes.tfspec_hash2.json";
+ private final static String SPEC9 =
"homes3/homes.tfspec_hash_recode.json";
+ private final static String SPEC9b =
"homes3/homes.tfspec_hash_recode2.json";
+
+ public enum TransformType {
+ RECODE, DUMMY, RECODE_DUMMY, BIN, BIN_DUMMY, HASH, HASH_RECODE,
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"y"}));
+ }
+
+ @Test
+ public void testHomesRecodeIDsSingleNodeCSV() {
+ runTransformTest(Types.ExecMode.SINGLE_NODE, "csv",
TransformType.RECODE, false);
+ }
+
+ @Test
+ public void testHomesDummyCodeIDsSingleNodeCSV() {
+ runTransformTest(Types.ExecMode.SINGLE_NODE, "csv",
TransformType.DUMMY, false);
+ }
+
+ @Test
+ public void testHomesRecodeDummyCodeIDsSingleNodeCSV() {
+ runTransformTest(Types.ExecMode.SINGLE_NODE, "csv",
TransformType.RECODE_DUMMY, false);
+ }
+
+ @Test
+ public void testHomesBinIDsSingleNodeCSV() {
+ runTransformTest(Types.ExecMode.SINGLE_NODE, "csv",
TransformType.BIN, false);
+ }
+
+ @Test
+ public void testHomesBinDummyIDsSingleNodeCSV() {
+ runTransformTest(Types.ExecMode.SINGLE_NODE, "csv",
TransformType.BIN_DUMMY, false);
+ }
+
+ @Test
+ public void testHomesHashIDsSingleNodeCSV() {
+ runTransformTest(Types.ExecMode.SINGLE_NODE, "csv",
TransformType.HASH, false);
+ }
+
+ @Test
+ public void testHomesHashRecodeIDsSingleNodeCSV() {
+ runTransformTest(Types.ExecMode.SINGLE_NODE, "csv",
TransformType.HASH_RECODE, false);
+ }
+
+
+ private void runTransformTest(Types.ExecMode rt, String ofmt,
TransformType type, boolean colnames)
+ {
+ // set transform specification
+ String SPEC = null;
+ String DATASET = null;
+ switch (type) {
+ case RECODE:
+ SPEC = colnames ? SPEC1b : SPEC1;
+ DATASET = DATASET1;
+ break;
+ case DUMMY:
+ SPEC = colnames ? SPEC2b : SPEC2;
+ DATASET = DATASET1;
+ break;
+ case BIN:
+ SPEC = colnames ? SPEC3b : SPEC3;
+ DATASET = DATASET1;
+ break;
+ case RECODE_DUMMY:
+ SPEC = colnames ? SPEC6b : SPEC6;
+ DATASET = DATASET1;
+ break;
+ case BIN_DUMMY:
+ SPEC = colnames ? SPEC7b : SPEC7;
+ DATASET = DATASET1;
+ break;
+ case HASH:
+ SPEC = colnames ? SPEC8b : SPEC8;
+ DATASET = DATASET1;
+ break;
+ case HASH_RECODE:
+ SPEC = colnames ? SPEC9b : SPEC9;
+ DATASET = DATASET1;
+ break;
+ }
+
+ if (!ofmt.equals("csv"))
+ throw new RuntimeException("Unsupported test output
format");
+
+ try {
+ getAndLoadTestConfiguration(TEST_NAME1);
+
+ //String HOME = SCRIPT_DIR + TEST_DIR;
+ DATASET = DATASET_DIR + DATASET;
+ SPEC = DATASET_DIR + SPEC;
+
+ FileFormatPropertiesCSV props = new
FileFormatPropertiesCSV();
+ props.setHeader(true);
+ FrameBlock input =
FrameReaderFactory.createFrameReader(Types.FileFormat.CSV, props)
+ .readFrameFromHDFS(DATASET, -1L, -1L);
+ StringBuilder specSb = new StringBuilder();
+ Files.readAllLines(Paths.get(SPEC)).forEach(s ->
specSb.append(s).append("\n"));
+ MultiColumnEncoder encoderS =
EncoderFactory.createEncoder(specSb.toString(),
+ input.getColumnNames(),
input.getNumColumns(), null);
+ MultiColumnEncoder encoderM =
EncoderFactory.createEncoder(specSb.toString(),
+ input.getColumnNames(),
input.getNumColumns(), null);
+
+ encoderM.setBuildBlockSize(10);
+ encoderS.build(input, 1);
+ encoderM.build(input, 12);
+ if (type == TransformType.RECODE) {
+ List<ColumnEncoderRecode> encodersS =
encoderS.getColumnEncoders(ColumnEncoderRecode.class);
+ List<ColumnEncoderRecode> encodersM =
encoderM.getColumnEncoders(ColumnEncoderRecode.class);
+ assertEquals(encodersS.size(),
encodersM.size());
+ for (int i = 0; i < encodersS.size(); i++) {
+
assertEquals(encodersS.get(i).getRcdMap().keySet(),
encodersM.get(i).getRcdMap().keySet());
+ }
+ }
+ else if (type == TransformType.BIN) {
+ List<ColumnEncoderBin> encodersS =
encoderS.getColumnEncoders(ColumnEncoderBin.class);
+ List<ColumnEncoderBin> encodersM =
encoderM.getColumnEncoders(ColumnEncoderBin.class);
+ assertEquals(encodersS.size(),
encodersM.size());
+ for (int i = 0; i < encodersS.size(); i++) {
+
assertArrayEquals(encodersS.get(i).getBinMins(), encodersM.get(i).getBinMins(),
0);
+
assertArrayEquals(encodersS.get(i).getBinMaxs(), encodersM.get(i).getBinMaxs(),
0);
+ }
+ }
+ }
+ catch (Exception ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameEncodeMultithreadedTest.java
b/src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameEncodeMultithreadedTest.java
new file mode 100644
index 0000000..75ac71c
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameEncodeMultithreadedTest.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform.mt;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
+import org.apache.sysds.runtime.io.FrameReaderFactory;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.nio.file.Paths;
+
+public class TransformFrameEncodeMultithreadedTest extends AutomatedTestBase {
+ private final static String TEST_NAME1 =
"TransformFrameEncodeMultithreadedTest";
+ private final static String TEST_DIR = "functions/transform/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
TransformFrameEncodeMultithreadedTest.class.getSimpleName() + "/";
+
+ //dataset and transform tasks without missing values
+ private final static String DATASET1 = "homes3/homes.csv";
+ private final static String SPEC1 =
"homes3/homes.tfspec_recode.json";
+ private final static String SPEC1b =
"homes3/homes.tfspec_recode2.json";
+ private final static String SPEC2 = "homes3/homes.tfspec_dummy.json";
+ private final static String SPEC2all =
"homes3/homes.tfspec_dummy_all.json";
+ private final static String SPEC2b =
"homes3/homes.tfspec_dummy2.json";
+ private final static String SPEC3 = "homes3/homes.tfspec_bin.json";
//recode
+ private final static String SPEC3b = "homes3/homes.tfspec_bin2.json";
//recode
+ private final static String SPEC6 =
"homes3/homes.tfspec_recode_dummy.json";
+ private final static String SPEC6b =
"homes3/homes.tfspec_recode_dummy2.json";
+ private final static String SPEC7 =
"homes3/homes.tfspec_binDummy.json"; //recode+dummy
+ private final static String SPEC7b =
"homes3/homes.tfspec_binDummy2.json"; //recode+dummy
+ private final static String SPEC8 = "homes3/homes.tfspec_hash.json";
+ private final static String SPEC8b = "homes3/homes.tfspec_hash2.json";
+ private final static String SPEC9 =
"homes3/homes.tfspec_hash_recode.json";
+ private final static String SPEC9b =
"homes3/homes.tfspec_hash_recode2.json";
+
+ private static final int[] BIN_col3 = new int[]{1,4,2,3,3,2,4};
+ private static final int[] BIN_col8 = new int[]{1,2,2,2,2,2,3};
+
+ public enum TransformType {
+ RECODE,
+ DUMMY,
+ DUMMY_ALL, //to test sparse
+ RECODE_DUMMY,
+ BIN,
+ BIN_DUMMY,
+ HASH,
+ HASH_RECODE,
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "y" }) );
+ }
+
+ @Test
+ public void testHomesRecodeIDsSingleNodeCSV() {
+ runTransformTest(ExecMode.SINGLE_NODE, "csv",
TransformType.RECODE, false);
+ }
+
+ @Test
+ public void testHomesDummyCodeIDsSingleNodeCSV() {
+ runTransformTest(ExecMode.SINGLE_NODE, "csv",
TransformType.DUMMY, false);
+ }
+
+ @Test
+ public void testHomesDummyAllCodeIDsSingleNodeCSV() {
+ runTransformTest(ExecMode.SINGLE_NODE, "csv",
TransformType.DUMMY_ALL, false);
+ }
+
+
+ @Test
+ public void testHomesRecodeDummyCodeIDsSingleNodeCSV() {
+ runTransformTest(ExecMode.SINGLE_NODE, "csv",
TransformType.RECODE_DUMMY, false);
+ }
+
+ @Test
+ public void testHomesBinIDsSingleNodeCSV() {
+ runTransformTest(ExecMode.SINGLE_NODE, "csv",
TransformType.BIN, false);
+ }
+
+ @Test
+ public void testHomesBinDummyIDsSingleNodeCSV() {
+ runTransformTest(ExecMode.SINGLE_NODE, "csv",
TransformType.BIN_DUMMY, false);
+ }
+
+ @Test
+ public void testHomesHashIDsSingleNodeCSV() {
+ runTransformTest(ExecMode.SINGLE_NODE, "csv",
TransformType.HASH, false);
+ }
+
+ @Test
+ public void testHomesHashRecodeIDsSingleNodeCSV() {
+ runTransformTest(ExecMode.SINGLE_NODE, "csv",
TransformType.HASH_RECODE, false);
+ }
+
+ private void runTransformTest( ExecMode rt, String ofmt, TransformType
type, boolean colnames) {
+
+ //set transform specification
+ String SPEC = null; String DATASET = null;
+ switch( type ) {
+ case RECODE: SPEC = colnames?SPEC1b:SPEC1; DATASET =
DATASET1; break;
+ case DUMMY: SPEC = colnames?SPEC2b:SPEC2; DATASET =
DATASET1; break;
+ case DUMMY_ALL: SPEC = SPEC2all; DATASET = DATASET1;
break;
+ case BIN: SPEC = colnames?SPEC3b:SPEC3; DATASET =
DATASET1; break;
+ case RECODE_DUMMY: SPEC = colnames?SPEC6b:SPEC6;
DATASET = DATASET1; break;
+ case BIN_DUMMY: SPEC = colnames?SPEC7b:SPEC7; DATASET =
DATASET1; break;
+ case HASH: SPEC = colnames?SPEC8b:SPEC8; DATASET
= DATASET1; break;
+ case HASH_RECODE: SPEC = colnames?SPEC9b:SPEC9; DATASET
= DATASET1; break;
+ }
+
+ if( !ofmt.equals("csv") )
+ throw new RuntimeException("Unsupported test output
format");
+
+ try
+ {
+ getAndLoadTestConfiguration(TEST_NAME1);
+
+ //String HOME = SCRIPT_DIR + TEST_DIR;
+ DATASET = DATASET_DIR + DATASET;
+ SPEC = DATASET_DIR + SPEC;
+
+ FileFormatPropertiesCSV props = new
FileFormatPropertiesCSV();
+ props.setHeader(true);
+ FrameBlock input =
FrameReaderFactory.createFrameReader(FileFormat.CSV,
props).readFrameFromHDFS(DATASET, -1L,-1L);
+ StringBuilder specSb = new StringBuilder();
+ Files.readAllLines(Paths.get(SPEC)).forEach(s ->
specSb.append(s).append("\n"));
+ MultiColumnEncoder encoder =
EncoderFactory.createEncoder(specSb.toString(), input.getColumnNames(),
input.getNumColumns(), null);
+
+ MatrixBlock outputS = encoder.encode(input, 1);
+ MatrixBlock outputM = encoder.encode(input, 12);
+
+ double[][] R1 =
DataConverter.convertToDoubleMatrix(outputS);
+ double[][] R2 =
DataConverter.convertToDoubleMatrix(outputM);
+ TestUtils.compareMatrices(R1, R2, R1.length,
R1[0].length, 0);
+ Assert.assertEquals(outputS.getNonZeros(),
outputM.getNonZeros());
+ Assert.assertTrue(outputM.getNonZeros() > 0);
+
+ if( rt == ExecMode.HYBRID ) {
+ Assert.assertEquals("Wrong number of executed
Spark instructions: " +
+ Statistics.getNoOfExecutedSPInst(), new
Long(0), new Long(Statistics.getNoOfExecutedSPInst()));
+ }
+
+ //additional checks for binning as encode-decode
impossible
+ //TODO fix distributed binning as well
+ if( type == TransformType.BIN ) {
+ for(int i=0; i<7; i++) {
+ Assert.assertEquals(BIN_col3[i],
R1[i][2], 1e-8);
+ Assert.assertEquals(BIN_col8[i],
R1[i][7], 1e-8);
+ }
+ }
+ else if( type == TransformType.BIN_DUMMY ) {
+ Assert.assertEquals(14, R1[0].length);
+ for(int i=0; i<7; i++) {
+ for(int j=0; j<4; j++) { //check dummy
coded
+
Assert.assertEquals((j==BIN_col3[i]-1)?
+ 1:0, R1[i][2+j], 1e-8);
+ }
+ for(int j=0; j<3; j++) { //check dummy
coded
+
Assert.assertEquals((j==BIN_col8[i]-1)?
+ 1:0, R1[i][10+j], 1e-8);
+ }
+ }
+ }
+ }
+ catch(Exception ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+}
diff --git a/src/test/resources/datasets/homes3/homes.tfspec_dummy_all.json
b/src/test/resources/datasets/homes3/homes.tfspec_dummy_all.json
new file mode 100644
index 0000000..65b8fee
--- /dev/null
+++ b/src/test/resources/datasets/homes3/homes.tfspec_dummy_all.json
@@ -0,0 +1 @@
+{"ids": true, "dummycode": [ 2, 7, 1, 3, 4, 5, 6, 8, 9 ] }
\ No newline at end of file