This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 58746bc777 [SYSTEMDS-3495] Parallel Compressed Encode
58746bc777 is described below
commit 58746bc777ac922cbabc48d484cb86164581066b
Author: baunsgaard <[email protected]>
AuthorDate: Sun Feb 5 17:21:23 2023 +0100
[SYSTEMDS-3495] Parallel Compressed Encode
This commit updates the compressed encode to process each column encoder
in parallel, while also updating the recode map construction
to a faster version that uses putIfAbsent on hash maps.
For Criteo 1Mil:
- Parallel encoding reduced the runtime from 8.688 sec to 6 sec.
- putIfAbsent reduced the runtime from 6 sec to 4.5 sec.
Closes #1781
---
.../dictionary/IdentityDictionarySlice.java | 2 +-
.../sysds/runtime/frame/data/columns/Array.java | 7 +-
.../runtime/transform/encode/CompressedEncode.java | 83 +++++++++++++++++-----
.../transform/encode/MultiColumnEncoder.java | 2 +-
4 files changed, 73 insertions(+), 21 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java
index 8e2512b698..53dfc3227b 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java
@@ -120,7 +120,7 @@ public class IdentityDictionarySlice extends
IdentityDictionary {
@Override
public int getNumberOfValues(int ncol) {
- return ncol;
+ return nRowCol;
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
index e706672e17..d5fbda7874 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
@@ -100,8 +100,11 @@ public abstract class Array<T> implements Writable {
long id = 0;
for(int i = 0; i < size(); i++) {
T val = get(i);
- if(val != null && !map.containsKey(val))
- map.put(val, id++);
+ if(val != null){
+ Long v = map.putIfAbsent(val, id);
+ if(v == null)
+ id++;
+ }
}
return map;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
index 2ab83a381c..32690484d8 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
@@ -22,6 +22,10 @@ package org.apache.sysds.runtime.transform.encode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
@@ -29,6 +33,7 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
@@ -44,35 +49,69 @@ import
org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.frame.data.columns.Array;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+
public class CompressedEncode {
protected static final Log LOG =
LogFactory.getLog(CompressedEncode.class.getName());
+ /** The encoding scheme plan */
private final MultiColumnEncoder enc;
+ /** The Input FrameBlock */
private final FrameBlock in;
+ /** The thread count of the instruction */
+ private final int k;
- private CompressedEncode(MultiColumnEncoder enc, FrameBlock in) {
+ private CompressedEncode(MultiColumnEncoder enc, FrameBlock in, int k) {
this.enc = enc;
this.in = in;
+ this.k = k;
}
- public static MatrixBlock encode(MultiColumnEncoder enc, FrameBlock in)
{
- return new CompressedEncode(enc, in).apply();
+ public static MatrixBlock encode(MultiColumnEncoder enc, FrameBlock in,
int k) {
+ return new CompressedEncode(enc, in, k).apply();
}
private MatrixBlock apply() {
- List<ColumnEncoderComposite> encoders = enc.getColumnEncoders();
+ final List<ColumnEncoderComposite> encoders =
enc.getColumnEncoders();
+ final List<AColGroup> groups = isParallel() ?
multiThread(encoders) : singleThread(encoders);
+ final int cols = shiftGroups(groups);
+ final MatrixBlock mb = new
CompressedMatrixBlock(in.getNumRows(), cols, -1, false, groups);
+ mb.recomputeNonZeros();
+ logging(mb);
+ return mb;
+ }
- List<AColGroup> groups = new ArrayList<>(encoders.size());
+ private boolean isParallel() {
+ return k > 1 && enc.getEncoders().size() > 1;
+ }
+ private List<AColGroup> singleThread(List<ColumnEncoderComposite>
encoders) {
+ List<AColGroup> groups = new ArrayList<>(encoders.size());
for(ColumnEncoderComposite c : encoders)
groups.add(encode(c));
+ return groups;
+ }
- int cols = shiftGroups(groups);
+ private List<AColGroup> multiThread(List<ColumnEncoderComposite>
encoders) {
- MatrixBlock mb = new CompressedMatrixBlock(in.getNumRows(),
cols, -1, false, groups);
- mb.recomputeNonZeros();
- logging(mb);
- return mb;
+ final ExecutorService pool = CommonThreadPool.get(k);
+ try {
+ List<EncodeTask> tasks = new
ArrayList<>(encoders.size());
+
+ for(ColumnEncoderComposite c : encoders)
+ tasks.add(new EncodeTask(c));
+
+ List<AColGroup> groups = new
ArrayList<>(encoders.size());
+ for(Future<AColGroup> t : pool.invokeAll(tasks))
+ groups.add(t.get());
+
+ pool.shutdown();
+ return groups;
+ }
+ catch(InterruptedException | ExecutionException ex) {
+ pool.shutdown();
+ throw new DMLRuntimeException("Failed parallel
compressed transform encode", ex);
+ }
}
/**
@@ -108,7 +147,6 @@ public class CompressedEncode {
HashMap<?, Long> map = a.getRecodeMap();
int domain = map.size();
- // int domain = c.getDomainSize();
IColIndex colIndexes = ColIndexFactory.create(0, domain);
ADictionary d = new IdentityDictionary(colIndexes.size());
@@ -153,15 +191,14 @@ public class CompressedEncode {
Array<?> a = in.getColumn(colId - 1);
HashMap<Object, Long> map = (HashMap<Object, Long>)
a.getRecodeMap();
final int blockSz =
ConfigurationManager.getDMLConfig().getIntValue(DMLConfig.DEFAULT_BLOCK_SIZE);
- if(map.size() >= blockSz){
+ if(map.size() >= blockSz) {
double[] vals = (double[])
a.changeType(ValueType.FP64).get();
MatrixBlock col = new MatrixBlock(a.size(), 1, vals);
col.recomputeNonZeros();
// lets make it an uncompressed column group.
return ColGroupUncompressed.create(colIndexes, col,
false);
}
- else{
-
+ else {
double[] vals = new double[map.size() +
(a.containsNull() ? 1 : 0)];
for(int i = 0; i < a.size(); i++) {
Object v = a.get(i);
@@ -173,7 +210,7 @@ public class CompressedEncode {
vals[map.get(v).intValue()] =
a.getAsDouble(i);
}
}
-
+
ADictionary d = Dictionary.create(vals);
AMapToData m = createMappingAMapToData(a, map);
return ColGroupDDC.create(colIndexes, d, m, null);
@@ -186,13 +223,25 @@ public class CompressedEncode {
Array<?>.ArrayIterator it = a.getIterator();
while(it.hasNext()) {
Object v = it.next();
- if(v != null) {
+ if(v != null)
m.set(it.getIndex(), map.get(v).intValue());
- }
}
return m;
}
+ private class EncodeTask implements Callable<AColGroup> {
+
+ ColumnEncoderComposite c;
+
+ protected EncodeTask(ColumnEncoderComposite c) {
+ this.c = c;
+ }
+
+ public AColGroup call() throws Exception {
+ return encode(c);
+ }
+ }
+
private void logging(MatrixBlock mb) {
if(LOG.isDebugEnabled()) {
LOG.debug(String.format("Uncompressed transform encode
Dense size: %16d", mb.estimateSizeDenseInMemory()));
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 34a2ba8a76..140a5e18b3 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -103,7 +103,7 @@ public class MultiColumnEncoder implements Encoder {
deriveNumRowPartitions(in, k);
try {
if(isCompressedTransformEncode(in, compressedOut))
- return CompressedEncode.encode(this,
(FrameBlock ) in);
+ return CompressedEncode.encode(this,
(FrameBlock ) in, k);
else if(k > 1 && !MULTI_THREADED_STAGES &&
!hasLegacyEncoder()) {
MatrixBlock out = new MatrixBlock();
DependencyThreadPool pool = new
DependencyThreadPool(k);