This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 89056f1ec9 [SYSTEMDS-3293] Optimize partitions count with memory estimate
89056f1ec9 is described below

commit 89056f1ec97082a6c720bec7ba4fdcae65f3a8f1
Author: arnabp <[email protected]>
AuthorDate: Wed May 4 13:12:41 2022 +0200

    [SYSTEMDS-3293] Optimize partitions count with memory estimate
    
    This patch extends the optimizer for transformencode to reduce the
    number of build partitions if the partial recode maps do not fit in the memory budget.
---
 .../runtime/transform/encode/ColumnEncoder.java    |  9 +++++
 .../transform/encode/ColumnEncoderComposite.java   |  6 ++-
 .../transform/encode/ColumnEncoderRecode.java      |  1 +
 .../transform/encode/MultiColumnEncoder.java       | 46 +++++++++++++++++-----
 .../apache/sysds/runtime/util/DependencyTask.java  |  2 +-
 .../sysds/utils/stats/TransformStatistics.java     |  5 ++-
 6 files changed, 56 insertions(+), 13 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index 89423521b1..b243c857c2 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -60,6 +60,7 @@ public abstract class ColumnEncoder implements Encoder, 
Comparable<ColumnEncoder
        protected int _colID;
        protected ArrayList<Integer> _sparseRowsWZeros = null;
        protected long _estMetaSize = 0;
+       protected int _estNumDistincts = 0;
        protected int _nBuildPartitions = 0;
        protected int _nApplyPartitions = 0;
 
@@ -291,6 +292,14 @@ public abstract class ColumnEncoder implements Encoder, 
Comparable<ColumnEncoder
                return _estMetaSize;
        }
 
+       public void setEstNumDistincts(int numDistincts) {
+               _estNumDistincts = numDistincts;
+       }
+
+       public int getEstNumDistincts() {
+               return _estNumDistincts;
+       }
+
        @Override
        public int compareTo(ColumnEncoder o) {
                return Integer.compare(getEncoderType(this), getEncoderType(o));
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index a22cab19ab..7194939853 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -361,11 +361,15 @@ public class ColumnEncoderComposite extends ColumnEncoder 
{
        }
 
        public void computeRCDMapSizeEstimate(CacheBlock in, int[] 
sampleIndices) {
+               int estNumDist = 0;
                for (ColumnEncoder e : _columnEncoders)
-                       if (e.getClass().equals(ColumnEncoderRecode.class))
+                       if (e.getClass().equals(ColumnEncoderRecode.class)) {
                                ((ColumnEncoderRecode) 
e).computeRCDMapSizeEstimate(in, sampleIndices);
+                               estNumDist = e.getEstNumDistincts();
+                       }
                long totEstSize = 
_columnEncoders.stream().mapToLong(ColumnEncoder::getEstMetaSize).sum();
                setEstMetaSize(totEstSize);
+               setEstNumDistincts(estNumDist);
        }
 
        public void setNumPartitions(int nBuild, int nApply) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
index 8ed89856d8..a6e0329d3a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
@@ -154,6 +154,7 @@ public class ColumnEncoderRecode extends ColumnEncoder {
                int[] freq = distinctFreq.values().stream().mapToInt(v -> 
v).toArray();
                int estDistCount = SampleEstimatorFactory.distinctCount(freq, 
in.getNumRows(),
                        sampleIndices.length, 
SampleEstimatorFactory.EstimationType.HassAndStokes);
+               setEstNumDistincts(estDistCount);
 
                // Compute total size estimates for each partial recode map
                // We assume each partial map contains all distinct values and 
have the same size
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index b34a152fc7..a869ef8208 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -46,7 +46,6 @@ import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorSample;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
-import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.data.SparseBlockCSR;
 import org.apache.sysds.runtime.data.SparseRowVector;
@@ -427,18 +426,23 @@ public class MultiColumnEncoder implements Encoder {
                while (numBlocks[1] > 1 && nRow/numBlocks[1] < minNumRows)
                        numBlocks[1]--;
 
-               // Reduce #build blocks if all don't fit in memory
+               // Reduce #build blocks for the recoders if all don't fit in 
memory
+               int rcdNumBuildBlks = numBlocks[0];
                if (numBlocks[0] > 1) {
                        // Estimate recode map sizes
                        estimateRCMapSize(in, recodeEncoders);
-                       long totEstSize = 
recodeEncoders.stream().mapToLong(ColumnEncoderComposite::getEstMetaSize).sum();
-                       // Worst case scenario: all partial maps contain all 
distinct values
-                       long totPartMapSize = totEstSize * numBlocks[0];
-                       if (totPartMapSize > 
InfrastructureAnalyzer.getLocalMaxMemory())
-                               numBlocks[0] = 1;
-                       // TODO: Maintain #blocks per encoder. Reduce only the 
ones with large maps
-                       // TODO: If this not enough, add dependencies between 
recode build tasks
+                       // Memory budget for maps = 70% of heap - sizeof(input)
+                       long memBudget = (long) 
(OptimizerUtils.getLocalMemBudget() - in.getInMemorySize());
+                       // Worst case scenario: all partial maps contain all 
distinct values (if < #rows)
+                       long totMemOverhead = getTotalMemOverhead(in, 
rcdNumBuildBlks, recodeEncoders);
+                       // Reduce recode build blocks count till they fit in 
the memory budget
+                       while (rcdNumBuildBlks > 1 && totMemOverhead > 
memBudget) {
+                               rcdNumBuildBlks--;
+                               totMemOverhead = getTotalMemOverhead(in, 
rcdNumBuildBlks, recodeEncoders);
+                               // TODO: Reduce only the ones with large maps
+                       }
                }
+               // TODO: If they still don't fit, serialize the column encoders
 
                // Set to 1 if not set by the above logics
                for (int i=0; i<2; i++)
@@ -448,6 +452,11 @@ public class MultiColumnEncoder implements Encoder {
                _partitionDone = true;
                // Materialize the partition counts in the encoders
                _columnEncoders.forEach(e -> e.setNumPartitions(numBlocks[0], 
numBlocks[1]));
+               if (rcdNumBuildBlks > 0 && rcdNumBuildBlks != numBlocks[0]) {
+                       int rcdNumBlocks = rcdNumBuildBlks;
+                       recodeEncoders.forEach(e -> 
e.setNumPartitions(rcdNumBlocks, numBlocks[1]));
+               }
+               //System.out.println("Block count = ["+numBlocks[0]+", 
"+numBlocks[1]+"], Recode block count = "+rcdNumBuildBlks);
        }
 
        private void estimateRCMapSize(CacheBlock in, 
List<ColumnEncoderComposite> rcList) {
@@ -477,6 +486,25 @@ public class MultiColumnEncoder implements Encoder {
                }
        }
 
+       // Estimate total memory overhead of the partial recode maps of all 
recoders
+       private long getTotalMemOverhead(CacheBlock in, int nBuildpart, 
List<ColumnEncoderComposite> rcEncoders) {
+               long totMemOverhead = 0;
+               if (nBuildpart == 1) {
+                       // Sum the estimated map sizes
+                       totMemOverhead = 
rcEncoders.stream().mapToLong(ColumnEncoderComposite::getEstMetaSize).sum();
+                       return totMemOverhead;
+               }
+               // Estimate map size of each partition and sum
+               for (ColumnEncoderComposite rce : rcEncoders) {
+                       long avgEntrySize = rce.getEstMetaSize()/ 
rce.getEstNumDistincts();
+                       int partSize = in.getNumRows()/nBuildpart;
+                       int partNumDist = Math.min(partSize, 
rce.getEstNumDistincts()); //#distincts not more than #rows
+                       long allMapsSize = partNumDist * avgEntrySize * 
nBuildpart; //worst-case scenario
+                       totMemOverhead += allMapsSize;
+               }
+               return totMemOverhead;
+       }
+
        private static void outputMatrixPreProcessing(MatrixBlock output, 
CacheBlock input, boolean hasDC) {
                long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
                if(output.isInSparseFormat()) {
diff --git a/src/main/java/org/apache/sysds/runtime/util/DependencyTask.java 
b/src/main/java/org/apache/sysds/runtime/util/DependencyTask.java
index 69c25fede6..943b344502 100644
--- a/src/main/java/org/apache/sysds/runtime/util/DependencyTask.java
+++ b/src/main/java/org/apache/sysds/runtime/util/DependencyTask.java
@@ -30,7 +30,7 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.DMLRuntimeException;
 
 public class DependencyTask<E> implements Comparable<DependencyTask<?>>, 
Callable<E> {
-       public static final boolean ENABLE_DEBUG_DATA = false;
+       public static final boolean ENABLE_DEBUG_DATA = false; // explain task 
graph
        protected static final Log LOG = 
LogFactory.getLog(DependencyTask.class.getName());
 
        private final Callable<E> _task;
diff --git 
a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java 
b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
index f1a4f1f3d8..05f06b065c 100644
--- a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
@@ -174,8 +174,9 @@ public class TransformStatistics {
                                
outMatrixPreProcessingTime.longValue()*1e-9)).append(" sec.\n");
                        sb.append("TransformEncode PostProc. 
time:\t").append(String.format("%.3f",
                                
outMatrixPostProcessingTime.longValue()*1e-9)).append(" sec.\n");
-                       sb.append("TransformEncode SizeEst. 
time:\t").append(String.format("%.3f",
-                               
mapSizeEstimationTime.longValue()*1e-9)).append(" sec.\n");
+                       if(mapSizeEstimationTime.longValue() > 0)
+                               sb.append("TransformEncode SizeEst. 
time:\t").append(String.format("%.3f",
+                                       
mapSizeEstimationTime.longValue()*1e-9)).append(" sec.\n");
                        return sb.toString();
                }
                return "";

Reply via email to