This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 3a96238 [SYSTEMDS-2851] Compressed n+ operation
3a96238 is described below
commit 3a96238ef43441b3f7ef1cdd85a9ebf0145fd20d
Author: baunsgaard <[email protected]>
AuthorDate: Thu Feb 4 23:18:12 2021 +0100
[SYSTEMDS-2851] Compressed n+ operation
This commit adds support for the n+ operation in compressed space.
This operation takes multiple matrices and adds them together.
The implementation is the bare minimum and not very well optimized, but it
allows execution of the Multilogreg algorithm with all the icpt
variations.
---
scripts/builtin/pcaTransform.dml | 7 +-
.../builtin/{pcaTransform.dml => scaleApply.dml} | 26 ++--
src/main/java/org/apache/sysds/api/DMLOptions.java | 2 +
.../java/org/apache/sysds/common/Builtins.java | 3 +-
.../compress/AbstractCompressedMatrixBlock.java | 8 --
.../runtime/compress/CompressedMatrixBlock.java | 91 ++++++++++++-
.../compress/CompressedMatrixBlockFactory.java | 16 +--
.../sysds/runtime/compress/colgroup/ColGroup.java | 4 +-
.../runtime/compress/colgroup/ColGroupConst.java | 5 -
.../runtime/compress/colgroup/ColGroupDDC.java | 11 --
.../runtime/compress/colgroup/ColGroupOLE.java | 20 +--
.../runtime/compress/colgroup/ColGroupOffset.java | 4 -
.../runtime/compress/colgroup/ColGroupRLE.java | 9 +-
.../compress/colgroup/ColGroupUncompressed.java | 7 -
.../runtime/compress/colgroup/ColGroupValue.java | 22 ++--
.../runtime/compress/lib/LibBinaryCellOp.java | 98 ++++++++++++--
.../sysds/runtime/compress/lib/LibCompAgg.java | 83 ++++++++++--
.../cp/BinaryMatrixMatrixCPInstruction.java | 13 +-
.../sysds/runtime/matrix/data/MatrixBlock.java | 39 +++---
.../sysds/utils/DMLCompressionStatistics.java | 22 +++-
.../java/org/apache/sysds/utils/Statistics.java | 2 +
.../python/systemds/context/systemds_context.py | 4 +-
.../org/apache/sysds/test/AutomatedTestBase.java | 37 ++++--
.../test/functions/compress/compressScale.java | 145 +++++++++++++++++++++
.../SystemDS-config-compress-cost-DDC.xml | 25 ++++
.../SystemDS-config-compress-cost-OLE.xml | 25 ++++
.../SystemDS-config-compress-cost-RLE.xml | 25 ++++
.../SystemDS-config-compress-cost.xml | 24 ++++
.../compressScale/SystemDS-config-default.xml | 24 ++++
.../functions/compress/compressScale/scale.dml | 30 +++++
30 files changed, 671 insertions(+), 160 deletions(-)
diff --git a/scripts/builtin/pcaTransform.dml b/scripts/builtin/pcaTransform.dml
index 5574dfa..429342d 100644
--- a/scripts/builtin/pcaTransform.dml
+++ b/scripts/builtin/pcaTransform.dml
@@ -40,12 +40,7 @@ m_pcaTransform = function(Matrix[Double] X, Matrix[Double]
Clusters,
return (Matrix[Double] Y)
{
- if(nrow(Centering) > 0 & ncol(Centering) > 0){
- X = X - Centering
- }
- if(nrow(ScaleFactor) > 0 & ncol(ScaleFactor) > 0){
- X = X / ScaleFactor
- }
+ X = scaleApply(X, Centering, ScaleFactor)
Y = X %*% Clusters
}
diff --git a/scripts/builtin/pcaTransform.dml b/scripts/builtin/scaleApply.dml
similarity index 62%
copy from scripts/builtin/pcaTransform.dml
copy to scripts/builtin/scaleApply.dml
index 5574dfa..f0d06d2 100644
--- a/scripts/builtin/pcaTransform.dml
+++ b/scripts/builtin/scaleApply.dml
@@ -19,33 +19,25 @@
#
#-------------------------------------------------------------
-# Principal Component Analysis (PCA) for dimensionality reduction prediciton
-#
-# This method is used to transpose data, which the PCA model was not trained
on. To validate how good
-# The PCA is, and to apply in production.
-#
+# Scale and center individual features in the input matrix (column wise.)
using the input matrices.
#
---------------------------------------------------------------------------------------------
-# NAME TYPE DEFAULT MEANING
+# NAME TYPE DEFAULT MEANING
#
---------------------------------------------------------------------------------------------
-# X Matrix --- Input feature matrix
-# Centering Matrix empty matrix The column means of the PCA model,
subtracted to construct the PCA
-# ScaleFactor Matrix empty matrix The scaling of each dimension in the PCA
model
+# X Matrix --- Input feature matrix
+# ColMean Matrix --- The column means to subtract from X (not
done if empty)
+# Centering Matrix --- The column scaling to multiply with X (not
done if empty)
#
---------------------------------------------------------------------------------------------
-# Y Matrix --- Output feature matrix dimensionally
reduced by PCA
+# Y Matrix --- Output feature matrix with K columns
#
---------------------------------------------------------------------------------------------
-m_pcaTransform = function(Matrix[Double] X, Matrix[Double] Clusters,
- Matrix[Double] Centering = matrix(0, rows= 0, cols=0),
- Matrix[Double] ScaleFactor = matrix(0, rows= 0, cols=0))
+m_scaleApply = function(Matrix[Double] X, Matrix[Double] Centering,
Matrix[Double] ScaleFactor)
return (Matrix[Double] Y)
{
-
if(nrow(Centering) > 0 & ncol(Centering) > 0){
- X = X - Centering
+ Y = X - Centering
}
if(nrow(ScaleFactor) > 0 & ncol(ScaleFactor) > 0){
- X = X / ScaleFactor
+ Y = X / ScaleFactor
}
- Y = X %*% Clusters
}
diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java
b/src/main/java/org/apache/sysds/api/DMLOptions.java
index 2e95456..a4dae4b 100644
--- a/src/main/java/org/apache/sysds/api/DMLOptions.java
+++ b/src/main/java/org/apache/sysds/api/DMLOptions.java
@@ -44,6 +44,8 @@ import org.apache.sysds.utils.Explain.ExplainType;
* to keep it consistent with {@link DMLOptions} and {@link DMLOptions}
*/
public class DMLOptions {
+ // private static final Log LOG =
LogFactory.getLog(DMLOptions.class.getName());
+
public final Options options;
public Map<String, String> argVals = new HashMap<>(); //
Arguments map containing either named arguments or arguments by position for a
DML program
public String configFile = null; // Path
to config file if default config and default config is to be overridden
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index 2d5659d..ad9f141 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -247,7 +247,8 @@ public enum Builtins {
QEXP("qexp", false, true),
REPLACE("replace", false, true),
RMEMPTY("removeEmpty", false, true),
- SCALE("scale", true, false), //TODO parameterize center & scale
+ SCALE("scale", true, false),
+ SCALEAPPLY("scaleApply", true, false),
TIME("time", false),
CVLM("cvlm", true, false),
TOSTRING("toString", false, true),
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
index ffbd017..e1f255a 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
@@ -50,7 +50,6 @@ import org.apache.sysds.runtime.matrix.operators.COVOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
-import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.SortUtils;
@@ -148,13 +147,6 @@ public abstract class AbstractCompressedMatrixBlock
extends MatrixBlock {
// Graceful fallback to uncompressed linear algebra
@Override
- public MatrixBlock unaryOperations(UnaryOperator op, MatrixValue
result) {
- printDecompressWarning("unaryOperations");
- MatrixBlock tmp = decompress();
- return tmp.unaryOperations(op, result);
- }
-
- @Override
public MatrixBlock binaryOperationsInPlace(BinaryOperator op,
MatrixValue thatValue) {
printDecompressWarning("binaryOperationsInPlace", (MatrixBlock)
thatValue);
MatrixBlock left = decompress();
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index 6186d1a..f5db18e 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -67,6 +67,7 @@ import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
import org.apache.sysds.runtime.functionobjects.Mean;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -77,6 +78,7 @@ import
org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
+import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.utils.DMLCompressionStatistics;
@@ -279,14 +281,14 @@ public class CompressedMatrixBlock extends
AbstractCompressedMatrixBlock {
// TODO Optimize Quick Get Value, to located the correct column
group without having to search for it
if(isOverlapping()) {
double v = 0.0;
- for(ColGroup group : _colGroups)
- if(Arrays.binarySearch(group.getColIndices(),
c) >= 0)
+ for(ColGroup group : _colGroups)
+ if(Arrays.binarySearch(group.getColIndices(),
c) >= 0)
v += group.get(r, c);
return v;
}
else {
- for(ColGroup group : _colGroups)
- if(Arrays.binarySearch(group.getColIndices(),
c) >= 0)
+ for(ColGroup group : _colGroups)
+ if(Arrays.binarySearch(group.getColIndices(),
c) >= 0)
return group.get(r, c);
return 0;
}
@@ -361,6 +363,10 @@ public class CompressedMatrixBlock extends
AbstractCompressedMatrixBlock {
return LibBinaryCellOp.binaryOperations(op, this, thatValue,
result);
}
+ public MatrixBlock binaryOperationsLeft(BinaryOperator op, MatrixValue
thatValue, MatrixValue result){
+ return LibBinaryCellOp.binaryOperationsLeft(op, this,
thatValue, result);
+ }
+
@Override
public MatrixBlock append(MatrixBlock that, MatrixBlock ret) {
@@ -794,4 +800,81 @@ public class CompressedMatrixBlock extends
AbstractCompressedMatrixBlock {
MatrixBlock tmp = decompress();
tmp.slice(outlist, range, rowCut, colCut, blen, boundaryRlen,
boundaryClen);
}
+
+ @Override
+ public MatrixBlock unaryOperations(UnaryOperator op, MatrixValue
result) {
+ printDecompressWarning("unaryOperations");
+ MatrixBlock tmp = decompress();
+ return tmp.unaryOperations(op, result);
+ }
+
+ // @Override
+ // public MatrixBlock unaryOperations(UnaryOperator op, MatrixValue
result) {
+ // MatrixBlock ret = checkType(result);
+
+ // // estimate the sparsity structure of result matrix
+ // // by default, we guess result.sparsity=input.sparsity, unless not
sparse safe
+ // boolean sp = this.sparse && op.sparseSafe;
+
+ // //allocate output
+ // int n = Builtin.isBuiltinCode(op.fn, BuiltinCode.CUMSUMPROD) ? 1 :
clen;
+ // if( ret == null )
+ // ret = new MatrixBlock(rlen, n, sp, sp ? nonZeros : rlen*n);
+ // else
+ // ret.reset(rlen, n, sp);
+
+ // //core execute
+ // if( LibMatrixAgg.isSupportedUnaryOperator(op) ) {
+ // //e.g., cumsum/cumprod/cummin/cumax/cumsumprod
+ // if( op.getNumThreads() > 1 )
+ // ret = LibMatrixAgg.cumaggregateUnaryMatrix(this, ret, op,
op.getNumThreads());
+ // else
+ // ret = LibMatrixAgg.cumaggregateUnaryMatrix(this, ret, op);
+ // }
+ // else if(!sparse && !isEmptyBlock(false)
+ // && OptimizerUtils.isMaxLocalParallelism(op.getNumThreads())) {
+ // //note: we apply multi-threading in a best-effort manner here
+ // //only for expensive operators such as exp, log, sigmoid, because
+ // //otherwise allocation, read and write anyway dominates
+ // ret.allocateDenseBlock(false);
+ // DenseBlock a = getDenseBlock();
+ // DenseBlock c = ret.getDenseBlock();
+ // for(int bi=0; bi<a.numBlocks(); bi++) {
+ // double[] avals = a.valuesAt(bi), cvals = c.valuesAt(bi);
+ // Arrays.parallelSetAll(cvals, i -> op.fn.execute(avals[i]));
+ // }
+ // ret.recomputeNonZeros();
+ // }
+ // else {
+ // //default execute unary operations
+ // if(op.sparseSafe)
+ // sparseUnaryOperations(op, ret);
+ // else
+ // denseUnaryOperations(op, ret);
+ // }
+
+ // //ensure empty results sparse representation
+ // //(no additional memory requirements)
+ // if( ret.isEmptyBlock(false) )
+ // ret.examSparsity();
+
+ // return ret;
+ // }
+ @Override
+ public double max() {
+ AggregateUnaryOperator op =
InstructionUtils.parseBasicAggregateUnaryOperator("uamax", -1);
+ return aggregateUnaryOperations(op, null, 1000,
null).getValue(0, 0);
+ }
+
+ @Override
+ public double sum() {
+ AggregateUnaryOperator op =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+", -1);
+ return aggregateUnaryOperations(op, null, 1000,
null).getValue(0, 0);
+ }
+
+ @Override
+ public double sumSq() {
+ AggregateUnaryOperator op =
InstructionUtils.parseBasicAggregateUnaryOperator("uasqk+", -1);
+ return aggregateUnaryOperations(op, null, 1000,
null).getValue(0, 0);
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
index 1f1fea0..46a38f9 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
@@ -28,7 +28,6 @@ import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.cocode.PlanningCoCoder;
import org.apache.sysds.runtime.compress.colgroup.ColGroup;
@@ -206,22 +205,22 @@ public class CompressedMatrixBlockFactory {
private void logPhase() {
_stats.setNextTimePhase(time.stop());
- if(DMLScript.STATISTICS) {
+ // if(DMLScript.STATISTICS) {
DMLCompressionStatistics.addCompressionTime(_stats.getLastTimePhase(), phase);
- }
+ // }
if(LOG.isDebugEnabled()) {
switch(phase) {
case 0:
- LOG.debug("--compression phase " +
phase++ + " Classify : " + _stats.getLastTimePhase());
+ LOG.debug("--compression phase " +
phase + " Classify : " + _stats.getLastTimePhase());
break;
case 1:
- LOG.debug("--compression phase " +
phase++ + " Grouping : " + _stats.getLastTimePhase());
+ LOG.debug("--compression phase " +
phase + " Grouping : " + _stats.getLastTimePhase());
break;
case 2:
- LOG.debug("--compression phase " +
phase++ + " Transpose : " + _stats.getLastTimePhase());
+ LOG.debug("--compression phase " +
phase + " Transpose : " + _stats.getLastTimePhase());
break;
case 3:
- LOG.debug("--compression phase " +
phase++ + " Compress : " + _stats.getLastTimePhase());
+ LOG.debug("--compression phase " +
phase + " Compress : " + _stats.getLastTimePhase());
LOG.debug("--compression Hash
collisions:" + DblArrayIntListHashMap.hashMissCount);
DblArrayIntListHashMap.hashMissCount =
0;
break;
@@ -230,7 +229,7 @@ public class CompressedMatrixBlockFactory {
// break;
case 4:
LOG.debug("--num col groups: " +
res.getColGroups().size());
- LOG.debug("--compression phase " +
phase++ + " Cleanup : " + _stats.getLastTimePhase());
+ LOG.debug("--compression phase " +
phase + " Cleanup : " + _stats.getLastTimePhase());
LOG.debug("--col groups types " +
_stats.getGroupsTypesString());
LOG.debug("--col groups sizes " +
_stats.getGroupsSizesString());
LOG.debug("--compressed size: " +
_stats.size);
@@ -251,6 +250,7 @@ public class CompressedMatrixBlockFactory {
default:
}
}
+ phase++;
}
private List<ColGroup> assignColumns(int numCols, ColGroup[] colGroups,
MatrixBlock rawBlock,
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroup.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroup.java
index e88f1d0..e76c559 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroup.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroup.java
@@ -565,7 +565,9 @@ public abstract class ColGroup implements Serializable {
*
* @return returns if the colgroup is allocated in a dense fashion.
*/
- public abstract boolean isDense();
+ public boolean isDense(){
+ return ! _zeros;
+ }
/**
* Slice out the columns within the range of cl and cu to remove the
dictionary values related to these columns.
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
index 932426a..47056f6 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
@@ -253,9 +253,4 @@ public class ColGroupConst extends ColGroupValue {
rnnz[i] = base;
}
}
-
- @Override
- public boolean isDense() {
- return true;
- }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
index 440236f..035971a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
@@ -180,13 +180,6 @@ public abstract class ColGroupDDC extends ColGroupValue {
}
}
-
-
- @Override
- protected void computeColSums(double[] c, KahanFunction kplus) {
- _dict.colSum(c, getCounts(), _colIndexes, kplus);
- }
-
@Override
protected void computeRowSums(double[] c, KahanFunction kplus, int rl,
int ru, boolean mean) {
final int numVals = getNumValues();
@@ -412,8 +405,4 @@ public abstract class ColGroupDDC extends ColGroupValue {
*/
protected abstract void setData(int r, int code);
- @Override
- public boolean isDense(){
- return true;
- }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java
index f4a272f..5d4785b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java
@@ -275,13 +275,16 @@ public class ColGroupOLE extends ColGroupOffset {
int sum = 0;
for(int k = 0; k < numVals; k++) {
int blen = len(k);
- int blocks = _numRows /
CompressionSettings.BITMAP_BLOCK_SZ + 1;
- int count = blen - blocks;
+ int count = 0;
+ int boff = _ptr[k];
+ int bix = 0;
+ for(; bix < blen ; bix += _data[boff + bix] + 1)
+ count += _data[boff + bix];
sum += count;
counts[k] = count;
}
if(_zeros) {
- counts[counts.length - 1] = _numRows *
_colIndexes.length - sum;
+ counts[counts.length - 1] = _numRows - sum;
}
return counts;
}
@@ -302,7 +305,7 @@ public class ColGroupOLE extends ColGroupOffset {
counts[k] = count;
}
if(_zeros) {
- counts[counts.length - 1] = (ru - rl) *
_colIndexes.length - sum;
+ counts[counts.length - 1] = (ru - rl) - sum;
}
return counts;
}
@@ -820,10 +823,7 @@ public class ColGroupOLE extends ColGroupOffset {
}
}
- @Override
- protected final void computeColSums(double[] c, KahanFunction kplus) {
- _dict.colSum(c, getCounts(), _colIndexes, kplus);
- }
+
@Override
protected final void computeRowMxx(MatrixBlock c, Builtin builtin, int
rl, int ru) {
@@ -843,7 +843,7 @@ public class ColGroupOLE extends ColGroupOffset {
// iterate over bitmap blocks and add values
int slen;
int bix = skipScanVal(k, rl);
- for(int off = bix * blksz; bix < blen && off < ru; bix
+= slen + 1, off += blksz) {
+ for(int off = ((rl + 1) / blksz) * blksz; bix < blen &&
off < ru; bix += slen + 1, off += blksz) {
slen = _data[boff + bix];
for(int i = 1; i <= slen; i++) {
int rix = off + _data[boff + bix + i];
@@ -1049,7 +1049,7 @@ public class ColGroupOLE extends ColGroupOffset {
inputIx += blockSz;
blockStartIx += blockSz + 1;
}
-
+
return encodedBlocks;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOffset.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOffset.java
index d05bb97..1711c2e 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOffset.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOffset.java
@@ -212,8 +212,4 @@ public abstract class ColGroupOffset extends ColGroupValue {
return sb.toString();
}
- @Override
- public boolean isDense(){
- return false;
- }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java
index df8160e..04e5f3b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java
@@ -321,7 +321,7 @@ public class ColGroupRLE extends ColGroupOffset {
counts[k] = count;
}
if(_zeros) {
- counts[counts.length - 1] = _numRows *
_colIndexes.length - sum;
+ counts[counts.length - 1] = _numRows - sum;
}
return counts;
}
@@ -347,7 +347,7 @@ public class ColGroupRLE extends ColGroupOffset {
counts[k] = count;
}
if(_zeros) {
- counts[counts.length - 1] = (ru - rl) *
_colIndexes.length - sum;
+ counts[counts.length - 1] = (ru - rl) - sum;
}
return counts;
}
@@ -801,10 +801,7 @@ public class ColGroupRLE extends ColGroupOffset {
}
}
- @Override
- protected final void computeColSums(double[] c, KahanFunction kplus) {
- _dict.colSum(c, getCounts(), _colIndexes, kplus);
- }
+
@Override
protected final void computeRowMxx(MatrixBlock c, Builtin builtin, int
rl, int ru) {
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
index b0aa9df..e7b5eef 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
@@ -504,13 +504,6 @@ public class ColGroupUncompressed extends ColGroup {
}
@Override
- public boolean isDense() {
- // Even if the uncompressed column groups can be sparse
allocated,
- // they are dense in the sense of compression.
- return true;
- }
-
- @Override
public ColGroup sliceColumns(int cl, int cu){
throw new NotImplementedException("Not implemented slice
columns");
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
index 3f3db1a..af9b8b3 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
@@ -166,19 +166,13 @@ public abstract class ColGroupValue extends ColGroup
implements Cloneable {
*/
public final int[] getCounts() {
if(counts == null) {
-
- counts = new int[getNumValues() + 1];
- // if(_zeros) {
- // tmp = allocIVector(getNumValues() + 1, true);
- // }
- // else {
- // tmp = allocIVector(getNumValues(), true);
- // }
- return getCounts(counts);
- }
- else {
+ counts = getCounts(new int[getNumValues() + (_zeros ? 1
: 0)]);
+ // LOG.error(Arrays.toString(counts));
return counts;
}
+ else
+ return counts;
+
}
public final int[] getCachedCounts() {
@@ -443,7 +437,6 @@ public abstract class ColGroupValue extends ColGroup
implements Cloneable {
op.aggOp.increOp.fn instanceof Mean) ? KahanPlus
.getKahanPlusFnObject() :
KahanPlusSq.getKahanPlusSqFnObject();
boolean mean = op.aggOp.increOp.fn instanceof Mean;
-
if(op.indexFn instanceof ReduceAll)
computeSum(c.getDenseBlockValues(), kplus);
else if(op.indexFn instanceof ReduceCol)
@@ -599,7 +592,6 @@ public abstract class ColGroupValue extends ColGroup
implements Cloneable {
public abstract int[] getCounts(int rl, int ru, int[] out);
protected void computeSum(double[] c, KahanFunction kplus){
-
if(kplus instanceof KahanPlusSq)
c[0] += _dict.sumsq(getCounts(), _colIndexes.length);
else
@@ -608,7 +600,9 @@ public abstract class ColGroupValue extends ColGroup
implements Cloneable {
protected abstract void computeRowSums(double[] c, KahanFunction kplus,
int rl, int ru, boolean mean);
- protected abstract void computeColSums(double[] c, KahanFunction kplus);
+ protected void computeColSums(double[] c, KahanFunction kplus) {
+ _dict.colSum(c, getCounts(), _colIndexes, kplus);
+ }
protected abstract void computeRowMxx(MatrixBlock c, Builtin builtin,
int rl, int ru);
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/LibBinaryCellOp.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/LibBinaryCellOp.java
index e47e902..e60956e 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/LibBinaryCellOp.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/LibBinaryCellOp.java
@@ -66,16 +66,24 @@ public class LibBinaryCellOp {
MatrixValue result) {
MatrixBlock that =
AbstractCompressedMatrixBlock.getUncompressed(thatValue);
LibMatrixBincell.isValidDimensionsBinary(m1, that);
+ BinaryAccessType atype =
LibMatrixBincell.getBinaryAccessType(m1, that);
+ return selectProcessingBasedOnAccessType(op, m1, that,
thatValue, result, atype);
+ }
- return selectProcessingBasedOnAccessType(op, m1, that,
thatValue, result);
+ public static MatrixBlock binaryOperationsLeft(BinaryOperator op,
CompressedMatrixBlock m1, MatrixValue thatValue,
+ MatrixValue result) {
+ MatrixBlock that =
AbstractCompressedMatrixBlock.getUncompressed(thatValue);
+ LibMatrixBincell.isValidDimensionsBinary(that, m1);
+ BinaryAccessType atype =
LibMatrixBincell.getBinaryAccessType(that, m1);
+ return selectProcessingBasedOnAccessType(op, m1, that,
thatValue, result, atype);
}
private static MatrixBlock
selectProcessingBasedOnAccessType(BinaryOperator op, CompressedMatrixBlock m1,
- MatrixBlock that, MatrixValue thatValue, MatrixValue result) {
- BinaryAccessType atype =
LibMatrixBincell.getBinaryAccessType(m1, that);
-
- if(atype == BinaryAccessType.MATRIX_COL_VECTOR || atype ==
BinaryAccessType.MATRIX_MATRIX)
+ MatrixBlock that, MatrixValue thatValue, MatrixValue result,
BinaryAccessType atype) {
+ if(atype == BinaryAccessType.MATRIX_COL_VECTOR)
return binaryMVCol(m1, that, op);
+ else if(atype == BinaryAccessType.MATRIX_MATRIX)
+ return binaryMM(m1, that, op);
else if(isSupportedBinaryCellOp(op.fn))
return bincellOp(m1, that,
setupCompressedReturnMatrixBlock(m1, result), op);
else {
@@ -147,11 +155,13 @@ public class LibBinaryCellOp {
}
}
- protected static CompressedMatrixBlock
binaryMVRow(CompressedMatrixBlock m1, MatrixBlock m2,
- CompressedMatrixBlock ret, BinaryOperator op) {
+ public static CompressedMatrixBlock binaryMVRow(CompressedMatrixBlock
m1, double[] v, CompressedMatrixBlock ret, BinaryOperator op){
List<ColGroup> oldColGroups = m1.getColGroups();
- double[] v = forceMatrixBlockToDense(m2);
+
+ if(ret == null)
+ ret = new CompressedMatrixBlock(m1.getNumRows(),
m1.getNumColumns());
+
boolean sparseSafe = true;
for(double x : v) {
if(op.fn.execute(0.0, x) != 0.0) {
@@ -188,6 +198,11 @@ public class LibBinaryCellOp {
ret.allocateColGroupList(newColGroups);
ret.setNonZeros(m1.getNumColumns() * m1.getNumRows());
return ret;
+ }
+
+ protected static CompressedMatrixBlock
binaryMVRow(CompressedMatrixBlock m1, MatrixBlock m2,
+ CompressedMatrixBlock ret, BinaryOperator op) {
+ return binaryMVRow(m1, forceMatrixBlockToDense(m2), ret, op);
}
@@ -271,6 +286,32 @@ public class LibBinaryCellOp {
return ret;
}
+ private static MatrixBlock binaryMM(CompressedMatrixBlock m1,
MatrixBlock m2, BinaryOperator op){
+
+ MatrixBlock ret = new MatrixBlock(m1.getNumRows(),
m1.getNumColumns(), false, -1).allocateBlock();
+
+ final int blkz = CompressionSettings.BITMAP_BLOCK_SZ * 2 /
m1.getNumColumns();
+ int k = OptimizerUtils.getConstrainedNumThreads(-1);
+ ExecutorService pool = CommonThreadPool.get(k);
+ ArrayList<BinaryMMTask> tasks = new ArrayList<>();
+
+ try {
+ for(int i = 0; i * blkz < m1.getNumRows(); i++) {
+ tasks.add(new BinaryMMTask(m1, m2, ret, i *
blkz, Math.min(m1.getNumRows(), (i + 1) * blkz), op));
+ }
+ long nnz = 0;
+ for(Future<Integer> f : pool.invokeAll(tasks))
+ nnz += f.get();
+ ret.setNonZeros(nnz);
+ pool.shutdown();
+ }
+ catch(InterruptedException | ExecutionException e) {
+ e.printStackTrace();
+ throw new DMLRuntimeException(e);
+ }
+ return ret;
+ }
+
private static class BinaryMVColTask implements Callable<Integer> {
private final int _rl;
private final int _ru;
@@ -295,7 +336,6 @@ public class LibBinaryCellOp {
// unsafe decompress, since we count nonzeros
afterwards.
g.decompressToBlockSafe(_ret, _rl, _ru,
g.getValues(), false);
}
-
int nnz = 0;
DenseBlock db = _ret.getDenseBlock();
for(int row = _rl; row < _ru; row++) {
@@ -311,6 +351,46 @@ public class LibBinaryCellOp {
}
}
+
+ private static class BinaryMMTask implements Callable<Integer> {
+ private final int _rl;
+ private final int _ru;
+ private final CompressedMatrixBlock _m1;
+ private final MatrixBlock _m2;
+ private final MatrixBlock _ret;
+ private final BinaryOperator _op;
+
+ protected BinaryMMTask(CompressedMatrixBlock m1, MatrixBlock
m2, MatrixBlock ret, int rl, int ru,
+ BinaryOperator op) {
+ _m1 = m1;
+ _m2 = m2;
+ _ret = ret;
+ _op = op;
+ _rl = rl;
+ _ru = ru;
+ }
+
+ @Override
+ public Integer call() {
+ for(ColGroup g : _m1.getColGroups()) {
+ // unsafe decompress, since we count nonzeros
afterwards.
+ g.decompressToBlockSafe(_ret, _rl, _ru,
g.getValues(), false);
+ }
+ int nnz = 0;
+ DenseBlock db = _ret.getDenseBlock();
+ for(int row = _rl; row < _ru; row++) {
+ for(int col = 0; col < _m1.getNumColumns();
col++) {
+ double vr = _m2.quickGetValue(row, col);
+ double v =
_op.fn.execute(_ret.quickGetValue(row, col), vr);
+ nnz += (v != 0) ? 1 : 0;
+ db.set(row, col, v);
+ }
+ }
+
+ return nnz;
+ }
+ }
+
private static class BinaryMVRowTask implements Callable<ColGroup> {
private final ColGroup _group;
private final double[] _v;
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/LibCompAgg.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/LibCompAgg.java
index 585b7ad..437cf97 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/LibCompAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/LibCompAgg.java
@@ -83,7 +83,6 @@ public class LibCompAgg {
outputMatrix.dropLastRowsOrColumns(op.aggOp.correction);
outputMatrix.recomputeNonZeros();
- memPool.remove();
return outputMatrix;
}
@@ -131,10 +130,10 @@ public class LibCompAgg {
// compute all compressed column groups
if(op.indexFn instanceof ReduceCol) {
final int blkz = CompressionSettings.BITMAP_BLOCK_SZ;
- int blklen = Math.max((int) Math.ceil((double) m1.getNumRows()
/ (op.getNumThreads() * 2)), blkz);
- blklen += (blklen % blkz != 0) ? blkz - blklen % blkz : 0;
-
- for(int i = 0; i < op.getNumThreads() & i * blklen <
m1.getNumRows(); i++)
+ // int blklen = Math.max((int) Math.ceil((double)
m1.getNumRows() / (op.getNumThreads() * 2)), blkz);
+ // blklen += (blklen % blkz != 0) ? blkz - blklen % blkz : 0;
+ int blklen = blkz * 4;
+ for(int i = 0; i * blklen < m1.getNumRows(); i++)
tasks.add(new UnaryAggregateTask(m1.getColGroups(), ret, i
* blklen,
Math.min((i + 1) * blklen, m1.getNumRows()), op,
m1.getNumColumns()));
@@ -142,7 +141,8 @@ public class LibCompAgg {
else {
List<List<ColGroup>> grpParts =
createTaskPartitionNotIncludingUncompressable(m1.getColGroups(), k);
for(List<ColGroup> grp : grpParts)
- tasks.add(new UnaryAggregateTask(grp, ret, 0,
m1.getNumRows(), op, m1.getNumColumns()));
+ tasks.add(new UnaryAggregateTask(grp, ret, 0,
m1.getNumRows(), op, m1.getNumColumns(),
+ m1.isOverlapping()));
}
List<Future<MatrixBlock>> futures = pool.invokeAll(tasks);
@@ -154,6 +154,15 @@ public class LibCompAgg {
aggregateResults(ret, futures, op);
else
sumResults(ret, futures);
+ else if(op.indexFn instanceof ReduceRow && m1.isOverlapping()) {
+ if(op.aggOp.increOp.fn instanceof Builtin)
+ aggregateResultVectors(ret, futures, op);
+ else
+ sumResultVectors(ret, futures);
+ }
+ else
+ for(Future<MatrixBlock> f : futures)
+ f.get();
}
catch(InterruptedException | ExecutionException e) {
throw new DMLRuntimeException(e);
@@ -171,6 +180,19 @@ public class LibCompAgg {
}
+ private static void sumResultVectors(MatrixBlock ret,
List<Future<MatrixBlock>> futures)
+ throws InterruptedException, ExecutionException {
+
+ double[] retVals = ret.getDenseBlockValues();
+ for(Future<MatrixBlock> rtask : futures) {
+ double[] taskResult = rtask.get().getDenseBlockValues();
+ for(int i = 0; i < retVals.length; i++) {
+ retVals[i] += taskResult[i];
+ }
+ }
+ ret.setNonZeros(ret.getNumColumns());
+ }
+
private static void aggregateResults(MatrixBlock ret,
List<Future<MatrixBlock>> futures, AggregateUnaryOperator op)
throws InterruptedException, ExecutionException {
double val = ret.quickGetValue(0, 0);
@@ -181,6 +203,18 @@ public class LibCompAgg {
ret.quickSetValue(0, 0, val);
}
+ private static void aggregateResultVectors(MatrixBlock ret,
List<Future<MatrixBlock>> futures,
+ AggregateUnaryOperator op) throws InterruptedException,
ExecutionException {
+ double[] retVals = ret.getDenseBlockValues();
+ for(Future<MatrixBlock> rtask : futures) {
+ double[] taskResult = rtask.get().getDenseBlockValues();
+ for(int i = 0; i < retVals.length; i++) {
+ retVals[i] = op.aggOp.increOp.fn.execute(retVals[i] ,
taskResult[i]);
+ }
+ }
+ ret.setNonZeros(ret.getNumColumns());
+ }
+
private static void aggregateSingleThreaded(CompressedMatrixBlock m1,
MatrixBlock ret, AggregateUnaryOperator op) {
aggregateUnaryOperations(op, m1.getColGroups(), ret, 0,
m1.getNumRows(), m1.getNumColumns());
}
@@ -337,19 +371,22 @@ public class LibCompAgg {
private static void
aggregateUnaryBuiltinRowOperation(AggregateUnaryOperator op, List<ColGroup>
groups,
MatrixBlock ret, int rl, int ru, int numColumns) {
- int[] rnnz = new int[ru - rl];
+ int[] rnnz = null;
int numberDenseColumns = 0;
for(ColGroup grp : groups) {
grp.unaryAggregateOperations(op, ret, rl, ru);
if(grp.isDense())
numberDenseColumns += grp.getNumCols();
- else
+ else{
+ if (rnnz == null)
+ rnnz = new int[ru - rl];
grp.countNonZerosPerRow(rnnz, rl, ru);
+ }
}
-
- for(int row = rl; row < ru; row++)
- if(rnnz[row] + numberDenseColumns < numColumns)
- ret.quickSetValue(row, 0,
op.aggOp.increOp.fn.execute(ret.quickGetValue(row, 0), 0.0));
+ if(rnnz != null)
+ for(int row = rl; row < ru; row++)
+ if(rnnz[row-rl] + numberDenseColumns < numColumns)
+ ret.quickSetValue(row, 0,
op.aggOp.increOp.fn.execute(ret.quickGetValue(row, 0), 0.0));
}
@@ -403,6 +440,28 @@ public class LibCompAgg {
}
+ protected UnaryAggregateTask(List<ColGroup> groups, MatrixBlock ret,
int rl, int ru, AggregateUnaryOperator op,
+ int numColumns, boolean overlapping) {
+ _groups = groups;
+ _op = op;
+ _rl = rl;
+ _ru = ru;
+ _numColumns = numColumns;
+
+ if(_op.indexFn instanceof ReduceAll || (_op.indexFn instanceof
ReduceRow && overlapping)) {
+ _ret = new MatrixBlock(ret.getNumRows(), ret.getNumColumns(),
false);
+ _ret.allocateDenseBlock();
+ if(_op.aggOp.increOp.fn instanceof Builtin)
+ System.arraycopy(ret.getDenseBlockValues(),
+ 0,
+ _ret.getDenseBlockValues(),
+ 0,
+ ret.getNumRows() * ret.getNumColumns());
+ }
+ else // colSums / rowSums
+ _ret = ret;
+ }
+
@Override
public MatrixBlock call() {
aggregateUnaryOperations(_op, _groups, _ret, _rl, _ru,
_numColumns);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
index 784b46e..20ddfb1 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
@@ -24,6 +24,7 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.AbstractCompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -60,10 +61,14 @@ public class BinaryMatrixMatrixCPInstruction extends
BinaryCPInstruction {
if(inBlock1 instanceof CompressedMatrixBlock && inBlock2
instanceof CompressedMatrixBlock){
retBlock = inBlock1.binaryOperations(bop, inBlock2, new
MatrixBlock());
} else if(inBlock2 instanceof CompressedMatrixBlock){
- LOG.error("Binary CP instruction decompressing " + bop);
- LOG.error("inBlock2 stats: " + inBlock2.getNumRows() +
" " +inBlock2.getNumColumns());
- inBlock2 =
AbstractCompressedMatrixBlock.getUncompressed(inBlock2);
- retBlock = inBlock1.binaryOperations(bop, inBlock2, new
MatrixBlock());
+ if((bop.fn instanceof Multiply)){
+ retBlock =
((CompressedMatrixBlock)inBlock2).binaryOperationsLeft(bop, inBlock1, new
MatrixBlock());
+ }else{
+ LOG.error("Binary CP instruction decompressing
" + bop);
+ LOG.error("inBlock2 stats: " +
inBlock2.getNumRows() + " " +inBlock2.getNumColumns());
+ inBlock2 =
AbstractCompressedMatrixBlock.getUncompressed(inBlock2);
+ retBlock = inBlock1.binaryOperations(bop,
inBlock2, new MatrixBlock());
+ }
} else {
retBlock = inBlock1.binaryOperations(bop, inBlock2, new
MatrixBlock());
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index aa0bb45..ed090c3 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -19,6 +19,22 @@
package org.apache.sysds.runtime.matrix.data;
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutput;
+import java.io.ObjectOutputStream;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.stream.IntStream;
+
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.concurrent.ConcurrentUtils;
import org.apache.commons.math3.random.Well1024a;
@@ -30,6 +46,8 @@ import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.MMTSJ.MMTSJType;
import org.apache.sysds.lops.MapMultChain.ChainType;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.lib.LibBinaryCellOp;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
@@ -95,22 +113,6 @@ import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.NativeHelper;
-import java.io.DataInput;
-import java.io.DataOutput;
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutput;
-import java.io.ObjectOutputStream;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.Iterator;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Future;
-import java.util.stream.IntStream;
-
public class MatrixBlock extends MatrixValue implements CacheBlock,
Externalizable
{
@@ -3619,7 +3621,10 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock, Externalizab
for( MatrixBlock in : inputs ) {
if( in.isEmptyBlock(false) )
continue;
- if( in.isInSparseFormat() ) {
+ if(in instanceof CompressedMatrixBlock){
+ in =
LibBinaryCellOp.binaryMVRow((CompressedMatrixBlock) in,c, null, new
BinaryOperator(Plus.getPlusFnObject()) );
+ }
+ else if( in.isInSparseFormat() ) {
SparseBlock a = in.getSparseBlock();
if( a.isEmpty(i) ) continue;
LibMatrixMult.vectAdd(a.values(i), c,
a.indexes(i), a.pos(i), cix, a.size(i));
diff --git a/src/main/java/org/apache/sysds/utils/DMLCompressionStatistics.java
b/src/main/java/org/apache/sysds/utils/DMLCompressionStatistics.java
index cb8031e..0f7fda5 100644
--- a/src/main/java/org/apache/sysds/utils/DMLCompressionStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/DMLCompressionStatistics.java
@@ -33,6 +33,23 @@ public class DMLCompressionStatistics {
private static int DecompressMTCount = 0;
private static double DecompressMT = 0.0;
+ public static void reset() {
+ Phase0 = 0.0;
+ Phase1 = 0.0;
+ Phase2 = 0.0;
+ Phase3 = 0.0;
+ Phase4 = 0.0;
+ Phase5 = 0.0;
+ DecompressSTCount = 0;
+ DecompressST = 0.0;
+ DecompressMTCount = 0;
+ DecompressMT = 0.0;
+ }
+
+ public static boolean haveCompressed(){
+ return Phase0 > 0;
+ }
+
public static void addCompressionTime(double time, int phase) {
switch(phase) {
case 0:
@@ -77,9 +94,8 @@ public class DMLCompressionStatistics {
}
public static void display(StringBuilder sb) {
- if(Phase0 > 0.0){ // If compression have been used
- sb.append(String.format(
- "CLA Compression Phases
:\t%.3f/%.3f/%.3f/%.3f/%.3f/%.3f\n",
+ if(haveCompressed()) { // If compression has been used
+ sb.append(String.format("CLA Compression Phases
:\t%.3f/%.3f/%.3f/%.3f/%.3f/%.3f\n",
Phase0 / 1000,
Phase1 / 1000,
Phase2 / 1000,
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java
b/src/main/java/org/apache/sysds/utils/Statistics.java
index b40e905..320f610 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -515,6 +515,8 @@ public class Statistics
federatedGetCount.reset();
federatedExecuteInstructionCount.reset();
federatedExecuteUDFCount.reset();
+
+ DMLCompressionStatistics.reset();
}
public static void resetJITCompileTime(){
diff --git a/src/main/python/systemds/context/systemds_context.py
b/src/main/python/systemds/context/systemds_context.py
index 3e29f62..107d875 100644
--- a/src/main/python/systemds/context/systemds_context.py
+++ b/src/main/python/systemds/context/systemds_context.py
@@ -123,10 +123,10 @@ class SystemDSContext(object):
if rep > 3:
raise Exception("Failed to start SystemDS context with " + rep
+ " repeated tries")
else:
- ret += 1
+ rep += 1
print("Failed to startup JVM process, retrying: " + rep)
sleep(rep) # Sleeping increasingly long time, maybe this helps.
- return self.__try_startup()
+ return self.__try_startup(command, rep)
def __verify_startup(self, process):
first_stdout = process.stdout.readline()
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index d51f05b..5b7127b 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -142,6 +142,7 @@ public abstract class AutomatedTestBase {
* Location of the SystemDS config file that we use as a template when
generating the configs for each test case.
*/
private static final File CONFIG_TEMPLATE_FILE = new File(CONFIG_DIR,
"SystemDS-config.xml");
+ protected boolean disableConfigFile = false;
protected enum CodegenTestType {
DEFAULT, FUSE_ALL, FUSE_NO_REDUNDANCY;
@@ -1069,14 +1070,17 @@ public abstract class AutomatedTestBase {
String localTemp = curLocalTempDir.getPath();
String configContents = configTemplate
.replace(createXMLElement(DMLConfig.SCRATCH_SPACE, "scratch_space"),
-
createXMLElement(DMLConfig.SCRATCH_SPACE, localTemp + "/scratch_space"))
+
createXMLElement(DMLConfig.SCRATCH_SPACE, localTemp + "/target/scratch_space"))
.replace(createXMLElement(DMLConfig.LOCAL_TMP_DIR, "/tmp/systemds"),
createXMLElement(DMLConfig.LOCAL_TMP_DIR, localTemp + "/localtmp"));
- FileUtils.write(getCurConfigFile(), configContents,
"UTF-8");
+ if(!disableConfigFile){
- if(LOG.isDebugEnabled())
- LOG.debug("This test case will use SystemDS
config file %s\n" + getCurConfigFile());
+ FileUtils.write(getCurConfigFile(),
configContents, "UTF-8");
+
+ if(LOG.isDebugEnabled())
+ LOG.debug("This test case will use
SystemDS config file %s\n" + getCurConfigFile());
+ }
}
catch(IOException e) {
throw new RuntimeException(e);
@@ -1388,7 +1392,7 @@ public abstract class AutomatedTestBase {
args.add(executionFile);
}
- addProgramIndependentArguments(args);
+ addProgramIndependentArguments(args, programArgs);
// program-specific parameters
if(newWay) {
@@ -1460,7 +1464,7 @@ public abstract class AutomatedTestBase {
DMLScript.executeScript(conf, otherArgs);
}
- private void addProgramIndependentArguments(ArrayList<String> args) {
+ private void addProgramIndependentArguments(ArrayList<String> args,
String[] otherArgs) {
// program-independent parameters
args.add("-exec");
@@ -1472,12 +1476,23 @@ public abstract class AutomatedTestBase {
args.add("singlenode");
else if(rtplatform == ExecMode.SPARK)
args.add("spark");
- else {
+ else
throw new RuntimeException("Unknown runtime platform: "
+ rtplatform);
- }
+
// use optional config file since default under SystemDS/DML
- args.add("-config");
- args.add(getCurConfigFile().getPath());
+ boolean configSpecified = false;
+ if(otherArgs != null)
+ for(String i: otherArgs)
+ if(i.equals("-config")){
+ configSpecified = true;
+ break;
+ }
+
+
+ if(!configSpecified){
+ args.add("-config");
+ args.add(getCurConfigFile().getPath());
+ }
if(TEST_GPU)
args.add("-gpu");
@@ -1555,7 +1570,7 @@ public abstract class AutomatedTestBase {
String[] fedWorkArgs = {"-w", Integer.toString(port)};
ArrayList<String> args = new ArrayList<>();
- addProgramIndependentArguments(args);
+ addProgramIndependentArguments(args, otherArgs);
if (otherArgs != null)
args.addAll(Arrays.stream(otherArgs).collect(Collectors.toList()));
diff --git
a/src/test/java/org/apache/sysds/test/functions/compress/compressScale.java
b/src/test/java/org/apache/sysds/test/functions/compress/compressScale.java
new file mode 100644
index 0000000..f1ac926
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/compress/compressScale.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.compress;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.DMLCompressionStatistics;
+import org.junit.Test;
+
+public class compressScale extends AutomatedTestBase {
+ private static final Log LOG =
LogFactory.getLog(compressScale.class.getName());
+
+ protected String getTestClassDir() {
+ return getTestDir() + this.getClass().getSimpleName() + "/";
+ }
+
+ protected String getTestName() {
+ return "scale";
+ }
+
+ protected String getTestDir() {
+ return "functions/compress/compressScale/";
+ }
+
+ // @Test
+ // public void testInstruction_01() {
+ // compressTest(4, 1000000, 0.2, ExecType.CP, -100, 1000, 1, 1);
+ // }
+
+ @Test
+ public void testInstruction_01_1() {
+ compressTest(50, 1000000, 0.2, ExecType.CP, 1, 2, 1, 1);
+ }
+
+ // @Test
+ // public void testInstruction_02() {
+ // compressTest(10, 200000, 0.2, ExecType.CP, 0, 5, 0, 1);
+ // }
+
+ // @Test
+ // public void testInstruction_03() {
+ // compressTest(10, 200000, 0.2, ExecType.CP, 0, 5, 0, 0);
+ // }
+
+ // @Test
+ // public void testInstruction_04() {
+ // compressTest(10, 200000, 0.2, ExecType.CP, 0, 5, 1, 0);
+ // }
+
+ public void compressTest(int cols, int rows, double sparsity,
LopProperties.ExecType instType, int min, int max,
+ int scale, int center) {
+
+ Types.ExecMode platformOld = setExecMode(instType);
+ try {
+
+ fullDMLScriptName = SCRIPT_DIR + "/" + getTestDir() +
getTestName() + ".dml";
+ loadTestConfiguration(getTestConfiguration(getTestName()));
+
+ // Default arguments
+ programArgs = new String[] {"-config", "", "-nvargs", "cols=" +
cols, "rows=" + rows,
+ "sparsity=" + sparsity, "min=" + min, "max= " + max, "scale="
+ scale, "center=" + center};
+
+ // Default execution
+ programArgs[1] = configPath("SystemDS-config-default.xml");
+ double outStd =
Double.parseDouble(runTest(null).toString().split("\n")[0].split(" ")[0]);
+ LOG.debug("ULA : " + outStd);
+
+ programArgs[1] =
configPath("SystemDS-config-compress-cost-RLE.xml");
+ double RLEoutC =
Double.parseDouble(runTest(null).toString().split("\n")[0].split(" ")[0]);
+ assertTrue(DMLCompressionStatistics.haveCompressed());
+ DMLCompressionStatistics.reset();
+ LOG.debug("RLE : " + RLEoutC);
+
+ programArgs[1] =
configPath("SystemDS-config-compress-cost-OLE.xml");
+ double OLEOutC =
Double.parseDouble(runTest(null).toString().split("\n")[0].split(" ")[0]);
+ assertTrue(DMLCompressionStatistics.haveCompressed());
+ DMLCompressionStatistics.reset();
+ LOG.debug("OLE : " + OLEOutC);
+
+ programArgs[1] =
configPath("SystemDS-config-compress-cost-DDC.xml");
+ double DDCoutC =
Double.parseDouble(runTest(null).toString().split("\n")[0].split(" ")[0]);
+ assertTrue(DMLCompressionStatistics.haveCompressed());
+ DMLCompressionStatistics.reset();
+ LOG.debug("DDC : " + DDCoutC);
+
+ programArgs[1] = configPath("SystemDS-config-compress-cost.xml");
+ double ALLoutC =
Double.parseDouble(runTest(null).toString().split("\n")[0].split(" ")[0]);
+ assertTrue(DMLCompressionStatistics.haveCompressed());
+ DMLCompressionStatistics.reset();
+ LOG.debug("CLA : " + ALLoutC);
+
+ assertEquals(outStd, OLEOutC, 0.1);
+ assertEquals(outStd, RLEoutC, 0.1);
+ assertEquals(outStd, DDCoutC, 0.1);
+ assertEquals(outStd, ALLoutC, 0.1);
+
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ assertTrue("Exception in execution: " + e.getMessage(), false);
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+ }
+
+ @Override
+ public void setUp() {
+ disableConfigFile = true;
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(getTestName(), new
TestConfiguration(getTestClassDir(), getTestName()));
+ }
+
+ private String configPath(String file) {
+ String out = (SCRIPT_DIR + getTestDir() + file).substring(2);
+ return out;
+ }
+
+}
diff --git
a/src/test/scripts/functions/compress/compressScale/SystemDS-config-compress-cost-DDC.xml
b/src/test/scripts/functions/compress/compressScale/SystemDS-config-compress-cost-DDC.xml
new file mode 100644
index 0000000..6f374e6
--- /dev/null
+++
b/src/test/scripts/functions/compress/compressScale/SystemDS-config-compress-cost-DDC.xml
@@ -0,0 +1,25 @@
+<!--
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+-->
+
+<root>
+ <sysds.compressed.linalg>cost</sysds.compressed.linalg>
+ <sysds.cp.parallel.ops>true</sysds.cp.parallel.ops>
+ <sysds.scratch>target/cost_scale_scratch_space_DDC</sysds.scratch>
+
<sysds.compressed.valid.compressions>DDC</sysds.compressed.valid.compressions>
+</root>
diff --git
a/src/test/scripts/functions/compress/compressScale/SystemDS-config-compress-cost-OLE.xml
b/src/test/scripts/functions/compress/compressScale/SystemDS-config-compress-cost-OLE.xml
new file mode 100644
index 0000000..3e35db1
--- /dev/null
+++
b/src/test/scripts/functions/compress/compressScale/SystemDS-config-compress-cost-OLE.xml
@@ -0,0 +1,25 @@
+<!--
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+-->
+
+<root>
+ <sysds.compressed.linalg>cost</sysds.compressed.linalg>
+ <sysds.cp.parallel.ops>true</sysds.cp.parallel.ops>
+ <sysds.scratch>target/cost_scale_scratch_space_OLE</sysds.scratch>
+
<sysds.compressed.valid.compressions>OLE</sysds.compressed.valid.compressions>
+</root>
diff --git
a/src/test/scripts/functions/compress/compressScale/SystemDS-config-compress-cost-RLE.xml
b/src/test/scripts/functions/compress/compressScale/SystemDS-config-compress-cost-RLE.xml
new file mode 100644
index 0000000..d81b04e
--- /dev/null
+++
b/src/test/scripts/functions/compress/compressScale/SystemDS-config-compress-cost-RLE.xml
@@ -0,0 +1,25 @@
+<!--
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+-->
+
+<root>
+ <sysds.compressed.linalg>cost</sysds.compressed.linalg>
+ <sysds.cp.parallel.ops>true</sysds.cp.parallel.ops>
+ <sysds.scratch>target/cost_scale_scratch_space_RLE</sysds.scratch>
+
<sysds.compressed.valid.compressions>RLE</sysds.compressed.valid.compressions>
+</root>
diff --git
a/src/test/scripts/functions/compress/compressScale/SystemDS-config-compress-cost.xml
b/src/test/scripts/functions/compress/compressScale/SystemDS-config-compress-cost.xml
new file mode 100644
index 0000000..45d91bf
--- /dev/null
+++
b/src/test/scripts/functions/compress/compressScale/SystemDS-config-compress-cost.xml
@@ -0,0 +1,24 @@
+<!--
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+-->
+
+<root>
+ <sysds.compressed.linalg>cost</sysds.compressed.linalg>
+ <sysds.cp.parallel.ops>true</sysds.cp.parallel.ops>
+ <sysds.scratch>target/cost_scale_scratch_space</sysds.scratch>
+</root>
diff --git
a/src/test/scripts/functions/compress/compressScale/SystemDS-config-default.xml
b/src/test/scripts/functions/compress/compressScale/SystemDS-config-default.xml
new file mode 100644
index 0000000..38079ab
--- /dev/null
+++
b/src/test/scripts/functions/compress/compressScale/SystemDS-config-default.xml
@@ -0,0 +1,24 @@
+<!--
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+-->
+
+<root>
+ <sysds.compressed.linalg>false</sysds.compressed.linalg>
+ <sysds.cp.parallel.ops>true</sysds.cp.parallel.ops>
+ <sysds.scratch>target/default_scale_scratch_space</sysds.scratch>
+</root>
diff --git a/src/test/scripts/functions/compress/compressScale/scale.dml
b/src/test/scripts/functions/compress/compressScale/scale.dml
new file mode 100644
index 0000000..6d9c5eb
--- /dev/null
+++ b/src/test/scripts/functions/compress/compressScale/scale.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = rand(rows=$rows, cols=$cols, sparsity=$sparsity, min=$min, max=$max, seed
=14)
+A = round(A)
+for(i in 1:10)
+ A = A * i
+
+[A,c,s] = scale(A,$center, $scale)
+print(sum(A))
+
+