phaniarnab commented on a change in pull request #1261:
URL: https://github.com/apache/systemds/pull/1261#discussion_r631326992
##########
File path:
src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
##########
@@ -72,11 +96,60 @@ public MatrixBlock encode(FrameBlock in) {
}
public void build(FrameBlock in) {
- for(ColumnEncoder columnEncoder : _columnEncoders)
- columnEncoder.build(in);
+ build(in, 1);
+ }
+
+ public void build(FrameBlock in, int k) {
+ if(MULTI_THREADED && k > 1) {
+ buildMT(in, k);
+ }
+ else {
+ for(ColumnEncoder columnEncoder : _columnEncoders)
+ columnEncoder.build(in);
+ }
legacyBuild(in);
}
+ private void buildMT(FrameBlock in, int k) {
+ int blockSize = BUILD_BLOCKSIZE <= 0 ? in.getNumRows() :
BUILD_BLOCKSIZE;
+ List<Callable<Integer>> tasks = new ArrayList<>();
Review comment:
_Discussion:_ What if we need column-specific block sizes in the future? A single global `BUILD_BLOCKSIZE` would not allow that.
##########
File path:
src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
##########
@@ -463,4 +589,72 @@ public void applyColumnOffset() {
if(_legacyMVImpute != null)
_legacyMVImpute.shiftCols(_colOffset);
}
+
+ private static class ColumnApplyTask implements Callable<Integer> {
+
+ private final ColumnEncoder _encoder;
+ private final FrameBlock _input;
+ private final MatrixBlock _out;
+ private final int _columnOut;
+ private int _rowStart = 0;
+ private int _blk = -1;
+
+ protected ColumnApplyTask(ColumnEncoder encoder, FrameBlock
input, MatrixBlock out, int columnOut) {
+ _encoder = encoder;
+ _input = input;
+ _out = out;
+ _columnOut = columnOut;
+ }
+
+ protected ColumnApplyTask(ColumnEncoder encoder, FrameBlock
input, MatrixBlock out, int columnOut, int rowStart,
+ int blk) {
+ this(encoder, input, out, columnOut);
+ _rowStart = rowStart;
+ _blk = blk;
+ }
+
+ @Override
+ public Integer call() throws Exception {
+ _encoder.apply(_input, _out, _columnOut, _rowStart,
_blk);
+ // TODO return NNZ
+ return 1;
+ }
+ }
+
+ private static class ColumnBuildTask implements Callable<Integer> {
+
+ private final ColumnEncoder _encoder;
+ private final FrameBlock _input;
+
+ // if a pool is passed the task may be split up into multiple
smaller tasks.
+ protected ColumnBuildTask(ColumnEncoder encoder, FrameBlock
input) {
+ _encoder = encoder;
+ _input = input;
+ }
+
+ @Override
+ public Integer call() throws Exception {
+ _encoder.build(_input);
+ return 1;
+ }
+ }
+
+ private static class ColumnMergeBuildPartialTask implements
Callable<Integer> {
+
+ private final ColumnEncoderComposite _encoder;
+ private final List<Future<Object>> _partials;
+
+ // if a pool is passed the task may be split up into multiple
smaller tasks.
+ protected ColumnMergeBuildPartialTask(ColumnEncoderComposite
encoder, List<Future<Object>> partials) {
+ _encoder = encoder;
+ _partials = partials;
+ }
+
+ @Override
+ public Integer call() throws Exception {
+ _encoder.mergeBuildPartial(_partials, 0,
_partials.size());
+ return 1;
Review comment:
What is the significance of this hard-coded return value `1`?
##########
File path:
src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
##########
@@ -33,20 +37,36 @@
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.IndexRange;
public class MultiColumnEncoder implements Encoder {
protected static final Log LOG =
LogFactory.getLog(MultiColumnEncoder.class.getName());
+ private static final boolean MULTI_THREADED = true;
private List<ColumnEncoderComposite> _columnEncoders;
// These encoders are deprecated and will be fazed out soon.
private EncoderMVImpute _legacyMVImpute = null;
private EncoderOmit _legacyOmit = null;
private int _colOffset = 0; // offset for federated Workers who are
using subrange encoders
private FrameBlock _meta = null;
+ // TEMP CONSTANTS for testing only
+ private int APPLY_BLOCKSIZE = 0; // temp only for testing until
automatic calculation of block size
+ public static int BUILD_BLOCKSIZE = 0;
+
+ public void setApplyBlockSize(int blk) {
+ APPLY_BLOCKSIZE = blk;
+ }
Review comment:
Can you please add a test with a non-zero `APPLY_BLOCKSIZE`?
##########
File path:
src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
##########
@@ -72,11 +96,60 @@ public MatrixBlock encode(FrameBlock in) {
}
public void build(FrameBlock in) {
- for(ColumnEncoder columnEncoder : _columnEncoders)
- columnEncoder.build(in);
+ build(in, 1);
+ }
+
+ public void build(FrameBlock in, int k) {
+ if(MULTI_THREADED && k > 1) {
+ buildMT(in, k);
+ }
+ else {
+ for(ColumnEncoder columnEncoder : _columnEncoders)
+ columnEncoder.build(in);
+ }
legacyBuild(in);
}
+ private void buildMT(FrameBlock in, int k) {
+ int blockSize = BUILD_BLOCKSIZE <= 0 ? in.getNumRows() :
BUILD_BLOCKSIZE;
+ List<Callable<Integer>> tasks = new ArrayList<>();
+ ExecutorService pool = CommonThreadPool.get(k);
+ try {
+ if(blockSize != in.getNumRows()) {
+ // Partial builds and merges
+ List<List<Future<Object>>> partials = new
ArrayList<>();
+ for(ColumnEncoderComposite encoder :
_columnEncoders) {
+ List<Callable<Object>>
partialBuildTasks = encoder.getPartialBuildTasks(in, blockSize);
+ if(partialBuildTasks == null) {
+ partials.add(null);
+ continue;
+ }
+
partials.add(pool.invokeAll(partialBuildTasks));
+ }
+ for(int e = 0; e < _columnEncoders.size(); e++)
{
+ List<Future<Object>> partial =
partials.get(e);
+ if(partial == null)
+ continue;
+ tasks.add(new
ColumnMergeBuildPartialTask(_columnEncoders.get(e), partial));
+ }
Review comment:
_Discussion:_ This task-creation logic (a row partition per column)
prevents more sophisticated task creation, e.g., tasks spanning an arbitrary number of
columns. This may not be a problem, though.
##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
##########
@@ -643,7 +643,36 @@ public void quickSetValue(int r, int c, double v)
nonZeros--;
}
}
-
+
+ /*
+ Thread save set.
+ Blocks need to be allocated and in case of MCSR sparse all rows
that are going to be accessed need to be allocated
+ aswell.
+ */
+ public void quickSetValueThreadSafe(int r, int c, double v){
+ if(sparse){
+ if(!(sparseBlock instanceof SparseBlockMCSR))
+ throw new RuntimeException("Only MCSR Blocks
are supported for Multithreaded sparse set.");
+ synchronized (sparseBlock.get(r)){
+ sparseBlock.set(r,c,v);
+ }
+ }else{
+ denseBlock.set(r,c,v);
+ }
Review comment:
Is the `denseBlock`/`sparseBlock` guaranteed to be allocated here?
Why synchronize only the sparse case?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]