This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 54c8696510 [MINOR] Compressed tests
54c8696510 is described below
commit 54c869651038cd9593103b26db90ca8e67a21949
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Wed Feb 5 20:39:05 2025 +0100
[MINOR] Compressed tests
This commit follows up on the rexpand instruction to improve the
test coverage and fix a few bugs in CLA.
Closes #2214
---
.../runtime/compress/colgroup/AColGroupValue.java | 22 +-
.../runtime/compress/colgroup/ColGroupConst.java | 5 +-
.../runtime/compress/colgroup/ColGroupDDCFOR.java | 29 +-
.../runtime/compress/colgroup/ColGroupSDC.java | 17 +-
.../runtime/compress/colgroup/ColGroupSDCFOR.java | 2 +-
.../compress/colgroup/ColGroupSDCSingle.java | 19 +-
.../colgroup/dictionary/DeltaDictionary.java | 5 +
.../compress/colgroup/dictionary/Dictionary.java | 14 +-
.../compress/colgroup/dictionary/IDictionary.java | 8 +
.../colgroup/dictionary/IdentityDictionary.java | 7 +
.../dictionary/IdentityDictionarySlice.java | 12 +-
.../colgroup/dictionary/MatrixBlockDictionary.java | 9 +
.../colgroup/dictionary/PlaceHolderDict.java | 5 +
.../compress/colgroup/dictionary/QDictionary.java | 5 +
.../sysds/runtime/compress/lib/CLALibRexpand.java | 12 +-
.../sysds/runtime/matrix/data/MatrixBlock.java | 1 -
.../component/compress/CompressedTestBase.java | 104 +++++-
.../component/compress/CompressedVectorTest.java | 31 +-
.../compress/colgroup/CombineColGroups.java | 157 ++++++++++
.../compress/colgroup/scheme/SchemeTestBase.java | 347 ++++++++++++---------
.../compress/colgroup/scheme/SchemeTestSDC.java | 1 -
.../compress/dictionary/CustomDictionaryTest.java | 13 +
.../compress/dictionary/DictionaryTests.java | 149 +++++++++
.../table/CompressedTableOverwriteTest.java | 122 ++++++++
.../wordembedding/wordEmbeddingUseCase.java | 149 +++++++++
25 files changed, 1042 insertions(+), 203 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java
index f3b37daa10..0cde289b30 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java
@@ -55,9 +55,9 @@ public abstract class AColGroupValue extends
ADictBasedColGroup {
}
/**
- * Returns the counts of values inside the dictionary. If already
calculated it will return the previous counts.
- * This produce an overhead in cases where the count is calculated, but
the overhead will be limited to number of
- * distinct tuples in the dictionary.
+ * Returns the counts of values inside the dictionary. If already
calculated it will return the previous counts. This
+ * produce an overhead in cases where the count is calculated, but the
overhead will be limited to number of distinct
+ * tuples in the dictionary.
*
* The returned counts always contains the number of zero tuples as
well if there are some contained, even if they
* are not materialized.
@@ -195,16 +195,16 @@ public abstract class AColGroupValue extends
ADictBasedColGroup {
@Override
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int
nRows) {
- try {
- IDictionary d = _dict.rexpandCols(max, ignore, cast,
_colIndexes.size());
- if(d == null)
- return ColGroupEmpty.create(max);
- else
- return copyAndSet(ColIndexFactory.create(max),
d);
- }
- catch(DMLCompressionException e) {
+ IDictionary d = _dict.rexpandCols(max, ignore, cast,
_colIndexes.size());
+ if(d == null) {
+ if(max <= 0)
+ return null;
return ColGroupEmpty.create(max);
}
+ else {
+ IColIndex outCols =
ColIndexFactory.create(d.getNumberOfColumns(_dict.getNumberOfValues(1)));
+ return copyAndSet(outCols, d);
+ }
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
index a493b14f04..21c6a0e1d8 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
@@ -527,8 +527,11 @@ public class ColGroupConst extends ADictBasedColGroup
implements IContainDefault
@Override
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int
nRows) {
IDictionary d = _dict.rexpandCols(max, ignore, cast,
_colIndexes.size());
- if(d == null)
+ if(d == null){
+ if(max <= 0)
+ return null;
return ColGroupEmpty.create(max);
+ }
else
return create(max, d);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
index cb51579875..70191a2793 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
@@ -26,7 +26,6 @@ import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang3.NotImplementedException;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
@@ -392,33 +391,15 @@ public class ColGroupDDCFOR extends AMorphingMMColGroup
implements IFrameOfRefer
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int
nRows) {
final int def = (int) _reference[0];
IDictionary d = _dict.rexpandColsWithReference(max, ignore,
cast, def);
-
if(d == null) {
- if(def <= 0 || def > max)
- return ColGroupEmpty.create(max);
- else {
- double[] retDef = new double[max];
- retDef[def - 1] = 1;
- return ColGroupConst.create(retDef);
- }
+ if(max <= 0)
+ return null;
+ return ColGroupEmpty.create(max);
}
else {
- IColIndex outCols = ColIndexFactory.create(max);
- if(def <= 0) {
- if(ignore)
- return ColGroupDDC.create(outCols, d,
_data, getCachedCounts());
- else
- throw new DMLRuntimeException("Invalid
content of zero in rexpand");
- }
- else if(def > max)
- return ColGroupDDC.create(outCols, d, _data,
getCachedCounts());
- else {
- double[] retDef = new double[max];
- retDef[def - 1] = 1;
- return ColGroupDDCFOR.create(outCols, d, _data,
getCachedCounts(), retDef);
- }
+ IColIndex outCols =
ColIndexFactory.create(d.getNumberOfColumns(_dict.getNumberOfValues(1)));
+ return ColGroupDDC.create(outCols, d, _data,
getCachedCounts());
}
-
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index 541c2487d5..1270823bfd 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -500,15 +500,24 @@ public class ColGroupSDC extends ASDC implements
IMapToDataGroup {
@Override
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int
nRows) {
IDictionary d = _dict.rexpandCols(max, ignore, cast,
_colIndexes.size());
- return rexpandCols(max, ignore, cast, nRows, d, _indexes,
_data, getCachedCounts(), (int) _defaultTuple[0]);
+ return rexpandCols(max, ignore, cast, nRows, d, _indexes,
_data, getCachedCounts(), (int) _defaultTuple[0],
+ _dict.getNumberOfValues(1));
}
protected static AColGroup rexpandCols(int max, boolean ignore, boolean
cast, int nRows, IDictionary d,
- AOffset indexes, AMapToData data, int[] counts, int def) {
+ AOffset indexes, AMapToData data, int[] counts, int def, int
nVal) {
if(d == null) {
- if(def <= 0 || def > max)
+ if(def <= 0){
+ if(max > 0)
+ return ColGroupEmpty.create(max);
+ else
+ return null;
+ }
+ else if(def > max && max > 0)
return ColGroupEmpty.create(max);
+ else if(max <= 0)
+ return null;
else {
double[] retDef = new double[max];
retDef[def - 1] = 1;
@@ -517,7 +526,7 @@ public class ColGroupSDC extends ASDC implements
IMapToDataGroup {
}
}
else {
- final IColIndex outCols = ColIndexFactory.create(max);
+ final IColIndex outCols =
ColIndexFactory.create(d.getNumberOfColumns(nVal));
if(def <= 0) {
if(ignore)
return ColGroupSDCZeros.create(outCols,
nRows, d, indexes, data, counts);
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
index 4c4b2e20a5..41fb7ac570 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
@@ -427,7 +427,7 @@ public class ColGroupSDCFOR extends ASDC implements
IMapToDataGroup, IFrameOfRef
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int
nRows) {
IDictionary d = _dict.rexpandColsWithReference(max, ignore,
cast, (int) _reference[0]);
return ColGroupSDC.rexpandCols(max, ignore, cast, nRows, d,
_indexes, _data, getCachedCounts(),
- (int) _reference[0]);
+ (int) _reference[0], _dict.getNumberOfValues(1));
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
index f63df96fa7..fa5772c0c3 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
@@ -85,7 +85,7 @@ public class ColGroupSDCSingle extends ASDC {
if(offsets instanceof OffsetEmpty)
return ColGroupConst.create(colIndexes, defaultTuple);
final boolean allZero = ColGroupUtils.allZero(defaultTuple);
- if(dict == null && allZero)
+ if(dict == null && allZero)
return new ColGroupEmpty(colIndexes);
else if(dict == null && offsets.getSize() * 2 > numRows + 2) {
AOffset rev = offsets.reverse(numRows);
@@ -469,8 +469,16 @@ public class ColGroupSDCSingle extends ASDC {
IDictionary d = _dict.rexpandCols(max, ignore, cast,
_colIndexes.size());
final int def = (int) _defaultTuple[0];
if(d == null) {
- if(def <= 0 || def > max)
+ if(def <= 0){
+ if(max > 0)
+ return ColGroupEmpty.create(max);
+ else
+ return null;
+ }
+ else if(def > max && max > 0)
return ColGroupEmpty.create(max);
+ else if(max <= 0)
+ return null;
else {
double[] retDef = new double[max];
retDef[((int) _defaultTuple[0]) - 1] = 1;
@@ -478,18 +486,19 @@ public class ColGroupSDCSingle extends ASDC {
}
}
else {
+ final IColIndex outCols =
ColIndexFactory.create(d.getNumberOfColumns(_dict.getNumberOfValues(1)));
if(def <= 0) {
if(ignore)
- return
ColGroupSDCSingleZeros.create(ColIndexFactory.create(max), nRows, d, _indexes,
getCachedCounts());
+ return
ColGroupSDCSingleZeros.create(outCols, nRows, d, _indexes, getCachedCounts());
else
throw new DMLRuntimeException("Invalid
content of zero in rexpand");
}
else if(def > max)
- return
ColGroupSDCSingleZeros.create(ColIndexFactory.create(max), nRows, d, _indexes,
getCachedCounts());
+ return ColGroupSDCSingleZeros.create(outCols,
nRows, d, _indexes, getCachedCounts());
else {
double[] retDef = new double[max];
retDef[((int) _defaultTuple[0]) - 1] = 1;
- return
ColGroupSDCSingle.create(ColIndexFactory.create(max), nRows, d, retDef,
_indexes, getCachedCounts());
+ return ColGroupSDCSingle.create(outCols, nRows,
d, retDef, _indexes, getCachedCounts());
}
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java
index 5bbc1af594..d67ab95f82 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java
@@ -97,6 +97,11 @@ public class DeltaDictionary extends ADictionary {
return _values.length / ncol;
}
+ @Override
+ public int getNumberOfColumns(int nrow){
+ return _values.length / nrow;
+ }
+
@Override
public String getString(int colIndexes) {
throw new NotImplementedException();
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
index 139254b534..939b48bf42 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
@@ -41,7 +41,7 @@ import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
-import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
+import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.utils.MemoryEstimates;
@@ -388,6 +388,11 @@ public class Dictionary extends ACachingMBDictionary {
return _values.length / nCol;
}
+ @Override
+ public int getNumberOfColumns(int nrow) {
+ return _values.length / nrow;
+ }
+
@Override
public double[] sumAllRowsToDouble(int nrColumns) {
if(nrColumns == 1)
@@ -1120,8 +1125,11 @@ public class Dictionary extends ACachingMBDictionary {
MatrixBlockDictionary m = getMBDict(1);
if(m == null)
return null;
- IDictionary a = m.applyScalarOp(new
LeftScalarOperator(Plus.getPlusFnObject(), reference));
- return a == null ? null : a.rexpandCols(max, ignore, cast, 1);
+ IDictionary a = m.applyScalarOp(new
RightScalarOperator(Plus.getPlusFnObject(), reference));
+ if(a == null)
+ return null; // second ending
+ a = a.rexpandCols(max, ignore, cast, 1);
+ return a;
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
index 54b7cc809d..dddea0eec7 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
@@ -327,6 +327,14 @@ public interface IDictionary {
*/
public int getNumberOfValues(int ncol);
+ /**
+ * Get the number of columns in this dictionary, provided you know the
number of values, or rows.
+ *
+ * @param nrow The number of rows/values known inside this dictionary
+ * @return The number of columns
+ */
+ public int getNumberOfColumns(int nrow);
+
/**
* Method used as a pre-aggregate of each tuple in the dictionary, to
single double values.
*
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
index 41982a6842..40e1b06565 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
@@ -194,6 +194,13 @@ public class IdentityDictionary extends
AIdentityDictionary {
return nRowCol + (withEmpty ? 1 : 0);
}
+ @Override
+ public int getNumberOfColumns(int nrow) {
+ if(nrow != (nRowCol + (withEmpty ? 1 : 0)))
+ throw new DMLCompressionException("Invalid call to get
Number of values assuming wrong number of columns");
+ return nRowCol;
+ }
+
@Override
public double[] sumAllRowsToDouble(int nrColumns) {
if(withEmpty) {
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java
index 0f07e1eac7..df702524d5 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java
@@ -25,6 +25,7 @@ import java.io.IOException;
import java.util.Arrays;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -192,7 +193,7 @@ public class IdentityDictionarySlice extends
AIdentityDictionary {
public double[] productAllRowsToDouble(int nCol) {
double[] ret = new double[nRowCol + (withEmpty ? 1 : 0)];
if(u - l - 1 == 0)
- ret[l] = 1;
+ ret[l] = 1;
return ret;
}
@@ -201,7 +202,7 @@ public class IdentityDictionarySlice extends
AIdentityDictionary {
int nVal = nRowCol + (withEmpty ? 1 : 0);
double[] ret = new double[nVal + 1];
if(u - l - 1 == 0)
- ret[l] = 1;
+ ret[l] = 1;
ret[nVal] = defaultTuple[0];
for(int i = 1; i < defaultTuple.length; i++)
ret[nVal] *= defaultTuple[i];
@@ -237,6 +238,13 @@ public class IdentityDictionarySlice extends
AIdentityDictionary {
return nRowCol + (withEmpty ? 1 : 0);
}
+ @Override
+ public int getNumberOfColumns(int nrow) {
+ if(nrow != (nRowCol + (withEmpty ? 1 : 0)))
+ throw new DMLCompressionException("Invalid call to get
Number of values assuming wrong number of columns");
+ return u - l;
+ }
+
@Override
public void write(DataOutput out) throws IOException {
out.writeByte(DictionaryFactory.Type.IDENTITY_SLICE.ordinal());
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
index 57f3a80e03..1d6949cbcd 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
@@ -936,6 +936,13 @@ public class MatrixBlockDictionary extends ADictionary {
return _data.getNumRows();
}
+ @Override
+ public int getNumberOfColumns(int nrow) {
+ if(nrow != _data.getNumRows())
+ throw new DMLCompressionException("Invalid call to get
number of columns assuming wrong number of rows");
+ return _data.getNumColumns();
+ }
+
@Override
public double[] sumAllRowsToDouble(int nrColumns) {
double[] ret = new double[_data.getNumRows()];
@@ -2397,6 +2404,8 @@ public class MatrixBlockDictionary extends ADictionary {
if(nCol > 1)
throw new DMLCompressionException("Invalid to rexpand
the column groups if more than one column");
MatrixBlock ret = LibMatrixReorg.rexpand(_data, new
MatrixBlock(), max, false, cast, ignore, 1);
+ if(ret.getNumColumns() == 0)
+ return null;
return MatrixBlockDictionary.create(ret);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
index f5c140e522..f5746647a3 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
@@ -53,6 +53,11 @@ public class PlaceHolderDict extends ADictionary {
return nVal;
}
+ @Override
+ public int getNumberOfColumns(int nrow) {
+ throw new RuntimeException("invalid to get number of columns
for PlaceHolderDict");
+ }
+
@Override
public MatrixBlockDictionary getMBDict() {
throw new RuntimeException(errMessage);
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
index 35a08b8d14..6802d920b4 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
@@ -142,6 +142,11 @@ public class QDictionary extends ACachingMBDictionary {
return _values.length / nCol;
}
+ @Override
+ public int getNumberOfColumns(int nCol) {
+ return _values.length / nCol;
+ }
+
@Override
public double[] sumAllRowsToDouble(int nrColumns) {
if(nrColumns == 1)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java
index 5be508febd..1bf43c49e5 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java
@@ -111,9 +111,15 @@ public final class CLALibRexpand {
cast, ignore, k);
else {
CompressedMatrixBlock retC = new
CompressedMatrixBlock(nRows, max);
-
retC.allocateColGroup(in.getColGroups().get(0).rexpandCols(max, ignore, cast,
nRows));
- retC.recomputeNonZeros();
- return retC;
+ AColGroup g = in.getColGroups().get(0).rexpandCols(max,
ignore, cast, nRows);
+ if(g == null)
+ return new MatrixBlock(nRows,0,0);
+ else {
+ retC.setNumColumns(g.getNumCols());
+ retC.allocateColGroup(g);
+ retC.recomputeNonZeros();
+ return retC;
+ }
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 33903954a2..7bc516588a 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -5010,7 +5010,6 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
return LibMatrixReorg.rexpand(this, result, max, rows, cast,
ignore, k);
}
-
@Override
public final MatrixBlock replaceOperations(MatrixValue result, double
pattern, double replacement) {
return replaceOperations(result, pattern, replacement, 1);
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
index 8692f56b69..c1fb10d211 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
@@ -307,7 +307,8 @@ public abstract class CompressedTestBase extends TestBase {
}
else if(ov ==
OverLapping.PLUS_ROW_VECTOR) {
- MatrixBlock v =
TestUtils.generateTestMatrixBlock(1, cols, -1, 1, 1.0, 4);
+ MatrixBlock v =
TestUtils.generateTestMatrixBlock(1, cols, 0, 4, 1.0, 4);
+ v = TestUtils.ceil(v);
BinaryOperator bop = new
BinaryOperator(Plus.getPlusFnObject(), _k);
mb = mb.binaryOperations(bop,
v, null);
cmb = cmb.binaryOperations(bop,
v, null);
@@ -504,13 +505,15 @@ public abstract class CompressedTestBase extends TestBase
{
@Test
public void testVectorMatrixMult() {
- MatrixBlock vector = TestUtils.generateTestMatrixBlock(1, rows,
0.9, 1.5, 1.0, 3);
+ MatrixBlock vector = TestUtils.generateTestMatrixBlock(1, rows,
0, 5, 1.0, 3);
+ vector = TestUtils.ceil(vector);
testLeftMatrixMatrix(vector);
}
@Test
public void testLeftMatrixMatrixMultSmall() {
- MatrixBlock matrix = TestUtils.generateTestMatrixBlock(3, rows,
0.9, 1.5, 1.0, 3);
+ MatrixBlock matrix = TestUtils.generateTestMatrixBlock(3, rows,
0, 5, 1.0, 3);
+ matrix = TestUtils.ceil(matrix);
testLeftMatrixMatrix(matrix);
}
@@ -522,7 +525,8 @@ public abstract class CompressedTestBase extends TestBase {
@Test
public void testLeftMatrixMatrixMultSparse() {
- MatrixBlock matrix = TestUtils.generateTestMatrixBlock(2, rows,
0.9, 1.5, .1, 3);
+ MatrixBlock matrix = TestUtils.generateTestMatrixBlock(2, rows,
0, 5, .1, 3);
+ matrix = TestUtils.ceil(matrix);
testLeftMatrixMatrix(matrix);
}
@@ -1053,6 +1057,98 @@ public abstract class CompressedTestBase extends
TestBase {
}
}
+
+ @Test
+ public void testReshape2() {
+ testReshape(2);
+ }
+
+ @Test
+ public void testReshape3() {
+ testReshape(3);
+ }
+
+ @Test
+ public void testReshape10() {
+ testReshape(10);
+ }
+
+ /**
+ * Test the reshape mechanic of the compressed block by reshaping the
matrix by making it x times wider.
+ *
+ * @param multiplier the multiplier x.
+ */
+ public void testReshape(int multiplier) {
+ try {
+ if((double) rows / multiplier != rows / multiplier)
+ return;
+
+ final MatrixBlock ret2 = cmb.reshape(rows / multiplier,
cols * multiplier, true);
+ final MatrixBlock ret1 = mb.reshape(rows / multiplier,
cols * multiplier, true);
+ compareResultMatrices(ret1, ret2, 1);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ throw new DMLRuntimeException("Error in Reshape", e);
+ }
+ }
+
+ @Test
+ public void testReshape2_divider() {
+ testReshapeDivider(2);
+ }
+
+ @Test
+ public void testReshape3_divider() {
+ testReshapeDivider(3);
+ }
+
+ @Test
+ public void testReshape10_divider() {
+ testReshapeDivider(10);
+ }
+
+ /**
+ * Test the reshape mechanic of the compressed block by reshaping the
matrix by making it x times taller.
+ *
+ * @param divider the divider x.
+ */
+ public void testReshapeDivider(int divider) {
+ try {
+ if((double) cols /divider != cols / divider)
+ return;
+
+ final MatrixBlock ret2 = cmb.reshape(rows * divider,
cols / divider, true);
+ final MatrixBlock ret1 = mb.reshape(rows * divider,
cols / divider, true);
+ compareResultMatrices(ret1, ret2, 1);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ throw new DMLRuntimeException("Error in Reshape", e);
+ }
+ }
+
+
+ @Test
+ public void testReshape_opposite() {
+ testReshape(cols, rows);
+ }
+
+ public void testReshape(int newRows, int newCols) {
+ try {
+ if((double) newRows * newCols != rows * cols)
+ return;
+
+ final MatrixBlock ret2 = cmb.reshape(newRows, newCols,
true);
+ final MatrixBlock ret1 = mb.reshape(newRows, newCols,
true);
+ compareResultMatrices(ret1, ret2, 1);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ throw new DMLRuntimeException("Error in Reshape", e);
+ }
+ }
+
@Test
public void testCompressAgain() {
try {
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java
b/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java
index f30f3401c7..5d91ddb273 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java
@@ -158,14 +158,41 @@ public class CompressedVectorTest extends
CompressedTestBase {
testReExpand(true);
}
+ @Test
+ public void testReExpandColNoIgnore() {
+ testReExpand(true, 0, false, true);
+ }
+
+ @Test
+ public void testReExpandColNoCast() {
+ testReExpand(true, 0, false, false);
+ }
+
public void testReExpand(boolean col) {
+ testReExpand(col, 50, true, true);
+ }
+
+ public void testReExpand(boolean col, int max, boolean ignore, boolean
cast) {
try {
if(cmb instanceof CompressedMatrixBlock) {
- MatrixBlock ret1 = cmb.rexpandOperations(new
MatrixBlock(), 50, !col, true, true, _k);
- MatrixBlock ret2 = mb.rexpandOperations(new
MatrixBlock(), 50, !col, true, true, _k);
+ MatrixBlock ret1 = null;
+ try{
+ ret1 = cmb.rexpandOperations(new
MatrixBlock(), max, !col, cast, ignore, _k);
+ }
+ catch(RuntimeException re){
+ if(! re.getMessage().contains("Invalid
input value <= 0 for ignore=false:"))
+ throw re;
+ else
+ return; // great!
+ }
+ MatrixBlock ret2 = mb.rexpandOperations(new
MatrixBlock(), max, !col, cast, ignore, _k);
compareResultMatrices(ret2, ret1, 0);
}
}
+ catch(AssertionError e){
+ LOG.error(cmb);
+ throw e;
+ }
catch(Exception e) {
e.printStackTrace();
throw e;
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/colgroup/CombineColGroups.java
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/CombineColGroups.java
new file mode 100644
index 0000000000..0168a176ce
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/CombineColGroups.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress.colgroup;
+
+import static org.junit.Assert.fail;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.CompressionSettings;
+import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
+import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
+import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
+import org.apache.sysds.runtime.compress.estim.EstimationFactors;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(value = Parameterized.class)
+public class CombineColGroups {
+ protected static final Log LOG =
LogFactory.getLog(CombineColGroups.class.getName());
+
+ /** Uncompressed ground truth */
+ final MatrixBlock mb;
+ /** ColGroup 1 */
+ final AColGroup a;
+ /** ColGroup 2 */
+ final AColGroup b;
+
+ @Parameters
+ public static Collection<Object[]> data() {
+ ArrayList<Object[]> tests = new ArrayList<>();
+
+ try {
+ addTwoCols(tests, 100, 3);
+ addTwoCols(tests, 1000, 3);
+ // addSingleVSMultiCol(tests, 100, 3, 1, 3);
+ // addSingleVSMultiCol(tests, 100, 3, 3, 4);
+ addSingleVSMultiCol(tests, 1000, 3, 1, 3, 1.0);
+ addSingleVSMultiCol(tests, 1000, 3, 3, 4, 1.0);
+ addSingleVSMultiCol(tests, 1000, 3, 3, 1, 1.0);
+ addSingleVSMultiCol(tests, 1000, 2, 1, 10, 0.05);
+ addSingleVSMultiCol(tests, 1000, 2, 10, 10, 0.05);
+ addSingleVSMultiCol(tests, 1000, 2, 10, 1, 0.05);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail("failed constructing tests");
+ }
+
+ return tests;
+ }
+
+ public CombineColGroups(MatrixBlock mb, AColGroup a, AColGroup b) {
+ this.mb = mb;
+ this.a = a;
+ this.b = b;
+
+ CompressedMatrixBlock.debug = true;
+ }
+
+ @Test
+ public void combine() {
+ try {
+ AColGroup c = a.combine(b, mb.getNumRows());
+ MatrixBlock ref = new MatrixBlock(mb.getNumRows(),
mb.getNumColumns(), false);
+ ref.allocateDenseBlock();
+ c.decompressToDenseBlock(ref.getDenseBlock(), 0,
mb.getNumRows());
+ ref.recomputeNonZeros();
+ String errMessage = a.getClass().getSimpleName() + ": "
+ a.getColIndices() + " -- "
+ + b.getClass().getSimpleName() + ": " +
b.getColIndices();
+
+ TestUtils.compareMatricesBitAvgDistance(mb, ref, 0, 0,
errMessage);
+ }
+ catch(NotImplementedException | DMLCompressionException e) {
+ // allowed
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ private static void addTwoCols(ArrayList<Object[]> tests, int nRow, int
distinct) {
+ MatrixBlock mb = TestUtils.ceil(//
+ TestUtils.generateTestMatrixBlock(nRow, 2, 0, distinct,
1.0, 231));
+
+ List<AColGroup> c1s = getGroups(mb, ColIndexFactory.createI(0));
+ List<AColGroup> c2s = getGroups(mb, ColIndexFactory.createI(1));
+
+ for(int i = 0; i < c1s.size(); i++) {
+ for(int j = 0; j < c2s.size(); j++) {
+ tests.add(new Object[] {mb, c1s.get(i),
c2s.get(j)});
+ }
+ }
+ }
+
+ private static void addSingleVSMultiCol(ArrayList<Object[]> tests, int
nRow, int distinct, int nColL, int nColR,
+ double sparsity) {
+ MatrixBlock mb = TestUtils.ceil(//
+ TestUtils.generateTestMatrixBlock(nRow, nColL + nColR,
0, distinct, sparsity, 231));
+
+ List<AColGroup> c1s = getGroups(mb,
ColIndexFactory.create(nColL));
+ List<AColGroup> c2s = getGroups(mb,
ColIndexFactory.create(nColL, nColR + nColL));
+
+ for(int i = 0; i < c1s.size(); i++) {
+ for(int j = 0; j < c2s.size(); j++) {
+ tests.add(new Object[] {mb, c1s.get(0),
c2s.get(0)});
+ }
+ }
+ }
+
+ private static List<AColGroup> getGroups(MatrixBlock mb, IColIndex
cols) {
+ final CompressionSettings cs = new
CompressionSettingsBuilder().create();
+
+ final int nRow = mb.getNumColumns();
+ final List<CompressedSizeInfoColGroup> es = new ArrayList<>();
+ final EstimationFactors f = new EstimationFactors(nRow, nRow,
mb.getSparsity());
+ es.add(new CompressedSizeInfoColGroup(cols, f, 312152,
CompressionType.DDC));
+ es.add(new CompressedSizeInfoColGroup(cols, f, 321521,
CompressionType.RLE));
+ es.add(new CompressedSizeInfoColGroup(cols, f, 321452,
CompressionType.SDC));
+ es.add(new CompressedSizeInfoColGroup(cols, f, 325151,
CompressionType.UNCOMPRESSED));
+ final CompressedSizeInfo csi = new CompressedSizeInfo(es);
+ return ColGroupFactory.compressColGroups(mb, csi, cs);
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestBase.java
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestBase.java
index 16d248c545..da5bc28505 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestBase.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestBase.java
@@ -59,27 +59,37 @@ public abstract class SchemeTestBase {
TestUtils.compareMatricesBitAvgDistance(in, d, 0, 0);
}
catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index with size == numRows"))
+ return;// all good
e.printStackTrace();
- fail(e.getMessage());
+ fail(e.getMessage() + " " + sh);
}
}
@Test
public void testEncodeT() {
- MatrixBlock in = TestUtils
-
.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 20, 0, distinct,
0.9, 7));
- AColGroup out = sh.encodeT(in);
- MatrixBlock d = new MatrixBlock(in.getNumColumns(),
src.getNumColumns(), false);
- d.allocateBlock();
- out.decompressToDenseBlock(d.getDenseBlock(), 0,
in.getNumColumns());
- d.recomputeNonZeros();
- TestUtils.compareMatricesBitAvgDistance(in,
LibMatrixReorg.transpose(d), 0, 0);
+ try {
+
+ MatrixBlock in = TestUtils
+
.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 20, 0, distinct,
0.9, 7));
+ AColGroup out = sh.encodeT(in);
+ MatrixBlock d = new MatrixBlock(in.getNumColumns(),
src.getNumColumns(), false);
+ d.allocateBlock();
+ out.decompressToDenseBlock(d.getDenseBlock(), 0,
in.getNumColumns());
+ d.recomputeNonZeros();
+ TestUtils.compareMatricesBitAvgDistance(in,
LibMatrixReorg.transpose(d), 0, 0);
+ }
+ catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index with size == numRows"))
+ return;// all good
+ e.printStackTrace();
+ fail(e.getMessage() + " " + sh);
+ }
}
@Test
public void testEncode_sparse() {
try {
-
MatrixBlock in =
TestUtils.round(TestUtils.generateTestMatrixBlock(100, 100, 0, distinct, 0.05,
7));
AColGroup out = sh.encode(in);
MatrixBlock d = new MatrixBlock(in.getNumRows(),
src.getNumColumns(), false);
@@ -90,8 +100,10 @@ public abstract class SchemeTestBase {
TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0,
0);
}
catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index with size == numRows"))
+ return;// all good
e.printStackTrace();
- fail(e.getMessage());
+ fail(e.getMessage() + " " + sh);
}
}
@@ -109,8 +121,10 @@ public abstract class SchemeTestBase {
TestUtils.compareMatricesBitAvgDistance(inSlice,
LibMatrixReorg.transpose(d), 0, 0);
}
catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index with size == numRows"))
+ return;// all good
e.printStackTrace();
- fail(e.getMessage());
+ fail(e.getMessage() + " " + sh);
}
}
@@ -137,10 +151,11 @@ public abstract class SchemeTestBase {
d.recomputeNonZeros();
TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0,
0);
}
-
catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index with size == numRows"))
+ return;// all good
e.printStackTrace();
- fail(e.getMessage());
+ fail(e.getMessage() + " " + sh);
}
}
@@ -173,88 +188,116 @@ public abstract class SchemeTestBase {
TestUtils.compareMatricesBitAvgDistance(inSlice,
LibMatrixReorg.transpose(d), 0, 0);
}
catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index with size == numRows"))
+ return;// all good
e.printStackTrace();
- fail(e.getMessage());
+ fail(e.getMessage() + " " + sh);
}
}
@Test
public void testUpdateSparse() {
- MatrixBlock in = TestUtils
- .round(TestUtils.generateTestMatrixBlock(130,
src.getNumColumns() + 30, 0, distinct + 1, 0.1, 7));
- if(!in.isInSparseFormat())
- throw new RuntimeException();
try {
- sh.encode(in);
+
+ MatrixBlock in = TestUtils
+ .round(TestUtils.generateTestMatrixBlock(130,
src.getNumColumns() + 30, 0, distinct + 1, 0.1, 7));
+ if(!in.isInSparseFormat())
+ throw new RuntimeException();
+ try {
+ sh.encode(in);
+ }
+ catch(NullPointerException e) {
+ // all good expected
+ // we want to have an exception thrown if we
try to encode something that is not possible to encode.
+ }
+ ICLAScheme shc = sh.clone();
+ shc = shc.update(in);
+ AColGroup out = shc.encode(in); // should be possible
now.
+ MatrixBlock d = new MatrixBlock(in.getNumRows(),
src.getNumColumns(), false);
+ d.allocateBlock();
+ out.decompressToDenseBlock(d.getDenseBlock(), 0,
in.getNumRows());
+ MatrixBlock inSlice = in.slice(0, in.getNumRows() - 1,
0, src.getNumColumns() - 1);
+ d.recomputeNonZeros();
+ TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0,
0);
}
- catch(NullPointerException e) {
- // all good expected
- // we want to have an exception thrown if we try to
encode something that is not possible to encode.
+ catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index with size == numRows"))
+ return;// all good
+ e.printStackTrace();
+ fail(e.getMessage() + " " + sh);
}
- ICLAScheme shc = sh.clone();
- shc = shc.update(in);
- AColGroup out = shc.encode(in); // should be possible now.
- MatrixBlock d = new MatrixBlock(in.getNumRows(),
src.getNumColumns(), false);
- d.allocateBlock();
- out.decompressToDenseBlock(d.getDenseBlock(), 0,
in.getNumRows());
- MatrixBlock inSlice = in.slice(0, in.getNumRows() - 1, 0,
src.getNumColumns() - 1);
- d.recomputeNonZeros();
- TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0, 0);
-
}
@Test
public void testUpdateSparseT() {
- MatrixBlock in = TestUtils
-
.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 1000, 0, distinct
+ 1, 0.1, 7));
- if(!in.isInSparseFormat())
- throw new RuntimeException();
try {
- sh.encodeT(in);
+
+ MatrixBlock in = TestUtils
+
.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 1000, 0, distinct
+ 1, 0.1, 7));
+ if(!in.isInSparseFormat())
+ throw new RuntimeException();
+ try {
+ sh.encodeT(in);
+ }
+ catch(NullPointerException e) {
+ // all good expected
+ // we want to have an exception thrown if we
try to encode something that is not possible to encode.
+ // but we can also not have an exception
thrown...
+ }
+ ICLAScheme shc = sh.clone();
+ shc = shc.updateT(in);
+
+ AColGroup out = shc.encodeT(in); // should be possible
now.
+ MatrixBlock d = new MatrixBlock(in.getNumColumns(),
src.getNumColumns(), false);
+ d.allocateBlock();
+ out.decompressToDenseBlock(d.getDenseBlock(), 0,
in.getNumColumns());
+ MatrixBlock inSlice = in.slice(0, src.getNumColumns() -
1, 0, in.getNumColumns() - 1);
+ d.recomputeNonZeros();
+ TestUtils.compareMatricesBitAvgDistance(inSlice,
LibMatrixReorg.transpose(d), 0, 0);
}
- catch(NullPointerException e) {
- // all good expected
- // we want to have an exception thrown if we try to
encode something that is not possible to encode.
- // but we can also not have an exception thrown...
+ catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index"))
+ return; // all good
+ e.printStackTrace();
+ fail(e.getMessage());
}
- ICLAScheme shc = sh.clone();
- shc = shc.updateT(in);
-
- AColGroup out = shc.encodeT(in); // should be possible now.
- MatrixBlock d = new MatrixBlock(in.getNumColumns(),
src.getNumColumns(), false);
- d.allocateBlock();
- out.decompressToDenseBlock(d.getDenseBlock(), 0,
in.getNumColumns());
- MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0,
in.getNumColumns() - 1);
- d.recomputeNonZeros();
- TestUtils.compareMatricesBitAvgDistance(inSlice,
LibMatrixReorg.transpose(d), 0, 0);
}
@Test
public void testUpdateSparseTEmptyColumn() {
- MatrixBlock in = new MatrixBlock(src.getNumColumns(), 100, 0.0);
- MatrixBlock b = new MatrixBlock(1, 100, 1.0);
- in = in.append(b, false);
- in.denseToSparse(true);
- if(!in.isInSparseFormat())
- throw new RuntimeException();
try {
- sh.encodeT(in);
+
+ MatrixBlock in = new MatrixBlock(src.getNumColumns(),
100, 0.0);
+ MatrixBlock b = new MatrixBlock(1, 100, 1.0);
+ in = in.append(b, false);
+ in.denseToSparse(true);
+ if(!in.isInSparseFormat())
+ throw new RuntimeException();
+ try {
+ sh.encodeT(in);
+ }
+ catch(NullPointerException e) {
+ // all good expected
+ // we want to have an exception thrown if we
try to encode something that is not possible to encode.
+ // but we can also not have an exception
thrown...
+ }
+ ICLAScheme shc = sh.clone();
+ shc = shc.updateT(in);
+
+ AColGroup out = shc.encodeT(in); // should be possible
now.
+ MatrixBlock d = new MatrixBlock(in.getNumColumns(),
src.getNumColumns(), false);
+ d.allocateBlock();
+ out.decompressToDenseBlock(d.getDenseBlock(), 0,
in.getNumColumns());
+ MatrixBlock inSlice = in.slice(0, src.getNumColumns() -
1, 0, in.getNumColumns() - 1);
+ d.recomputeNonZeros();
+ TestUtils.compareMatricesBitAvgDistance(inSlice,
LibMatrixReorg.transpose(d), 0, 0);
}
- catch(NullPointerException e) {
- // all good expected
- // we want to have an exception thrown if we try to
encode something that is not possible to encode.
- // but we can also not have an exception thrown...
+ catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index with size == numRows"))
+ return; // all good expected exception
+ e.printStackTrace();
+ fail(e.getMessage());
}
- ICLAScheme shc = sh.clone();
- shc = shc.updateT(in);
-
- AColGroup out = shc.encodeT(in); // should be possible now.
- MatrixBlock d = new MatrixBlock(in.getNumColumns(),
src.getNumColumns(), false);
- d.allocateBlock();
- out.decompressToDenseBlock(d.getDenseBlock(), 0,
in.getNumColumns());
- MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0,
in.getNumColumns() - 1);
- d.recomputeNonZeros();
- TestUtils.compareMatricesBitAvgDistance(inSlice,
LibMatrixReorg.transpose(d), 0, 0);
}
@Test
@@ -282,65 +325,85 @@ public abstract class SchemeTestBase {
TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0,
0);
}
catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index with size == numRows"))
+ return;// all good
e.printStackTrace();
- fail(e.getMessage());
+ fail(e.getMessage() + " " + sh);
}
}
@Test
public void testUpdateLargeBlockT() {
- MatrixBlock in = TestUtils
-
.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 130, 0, distinct
+ 5, 1.0, 7));
- in = ReadersTestCompareReaders.createMock(in);
try {
- sh.encodeT(in);
- }
- catch(NullPointerException e) {
- // all good expected
- // we want to have an exception thrown if we try to
encode something that is not possible to encode.
- // but we can also not have an exception thrown...
- }
- ICLAScheme shc = sh.clone();
- shc = shc.updateT(in);
+ MatrixBlock in = TestUtils
+
.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 130, 0, distinct
+ 5, 1.0, 7));
+ in = ReadersTestCompareReaders.createMock(in);
+ try {
+ sh.encodeT(in);
+ }
+ catch(NullPointerException e) {
+ // all good expected
+ // we want to have an exception thrown if we
try to encode something that is not possible to encode.
+ // but we can also not have an exception
thrown...
+ }
+ ICLAScheme shc = sh.clone();
- AColGroup out = shc.encodeT(in); // should be possible now.
- MatrixBlock d = new MatrixBlock(in.getNumColumns(),
src.getNumColumns(), false);
- d.allocateBlock();
- out.decompressToDenseBlock(d.getDenseBlock(), 0,
in.getNumColumns());
- MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0,
in.getNumColumns() - 1);
- d.recomputeNonZeros();
- TestUtils.compareMatricesBitAvgDistance(inSlice,
LibMatrixReorg.transpose(d), 0, 0);
+ shc = shc.updateT(in);
+
+ AColGroup out = shc.encodeT(in); // should be possible
now.
+ MatrixBlock d = new MatrixBlock(in.getNumColumns(),
src.getNumColumns(), false);
+ d.allocateBlock();
+ out.decompressToDenseBlock(d.getDenseBlock(), 0,
in.getNumColumns());
+ MatrixBlock inSlice = in.slice(0, src.getNumColumns() -
1, 0, in.getNumColumns() - 1);
+ d.recomputeNonZeros();
+ TestUtils.compareMatricesBitAvgDistance(inSlice,
LibMatrixReorg.transpose(d), 0, 0);
+ }
+ catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index with size == numRows"))
+ return;// all good
+ e.printStackTrace();
+ fail(e.getMessage() + " " + sh);
+ }
}
@Test
public void testUpdateEmpty() {
- MatrixBlock in = new MatrixBlock(5, src.getNumColumns(), 0.0);
-
try {
- sh.encode(in);
+
+ MatrixBlock in = new MatrixBlock(5,
src.getNumColumns(), 0.0);
+
+ try {
+ sh.encode(in);
+ }
+ catch(NullPointerException e) {
+ // all good expected
+ // we want to have an exception thrown if we
try to encode something that is not possible to encode.
+ }
+ ICLAScheme shc = sh.clone();
+ shc = shc.update(in);
+ AColGroup out = shc.encode(in); // should be possible
now.
+ MatrixBlock d = new MatrixBlock(in.getNumRows(),
src.getNumColumns(), false);
+ d.allocateBlock();
+ out.decompressToDenseBlock(d.getDenseBlock(), 0,
in.getNumRows());
+ MatrixBlock inSlice = in.slice(0, in.getNumRows() - 1,
0, src.getNumColumns() - 1);
+ d.recomputeNonZeros();
+ TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0,
0);
}
- catch(NullPointerException e) {
- // all good expected
- // we want to have an exception thrown if we try to
encode something that is not possible to encode.
+ catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index with size == numRows"))
+ return;// all good
+ e.printStackTrace();
+ fail(e.getMessage() + " " + sh);
}
- ICLAScheme shc = sh.clone();
- shc = shc.update(in);
- AColGroup out = shc.encode(in); // should be possible now.
- MatrixBlock d = new MatrixBlock(in.getNumRows(),
src.getNumColumns(), false);
- d.allocateBlock();
- out.decompressToDenseBlock(d.getDenseBlock(), 0,
in.getNumRows());
- MatrixBlock inSlice = in.slice(0, in.getNumRows() - 1, 0,
src.getNumColumns() - 1);
- d.recomputeNonZeros();
- TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0, 0);
-
}
@Test
public void testUpdateEmptyT() {
- MatrixBlock in = new MatrixBlock(src.getNumColumns(), 5, 0.0);
// 5 rows to encode transposed
+
+ MatrixBlock in = new MatrixBlock(src.getNumColumns(), 5, 0.0);
try {
sh.encodeT(in);
}
@@ -351,8 +414,6 @@ public abstract class SchemeTestBase {
}
ICLAScheme shc = sh.clone();
- shc = shc.updateT(in);
-
AColGroup out = shc.encodeT(in); // should be possible now.
// now we learned how to encode. lets decompress the encoded.
@@ -390,8 +451,10 @@ public abstract class SchemeTestBase {
TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0,
0);
}
catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index with size == numRows"))
+ return;// all good
e.printStackTrace();
- fail(e.getMessage());
+ fail(e.getMessage() + " " + sh);
}
}
@@ -400,6 +463,7 @@ public abstract class SchemeTestBase {
public void testUpdateEmptyMyColsT() {
MatrixBlock in = new MatrixBlock(src.getNumColumns(), 5, 0.0);
in = in.append(new MatrixBlock(src.getNumColumns(), 1, 1.0),
true);
+
try {
sh.encodeT(in);
}
@@ -431,16 +495,14 @@ public abstract class SchemeTestBase {
@Test
public void testUpdateAndEncode() {
double newVal = distinct + 4;
- MatrixBlock in = TestUtils
- .round(TestUtils.generateTestMatrixBlock(100,
src.getNumColumns(), 0, newVal, 1.0, 7));
+ MatrixBlock in =
TestUtils.round(TestUtils.generateTestMatrixBlock(100, src.getNumColumns(), 0,
newVal, 1.0, 7));
testUpdateAndEncode(in);
}
@Test
public void testUpdateAndEncodeT() {
double newVal = distinct + 4;
- MatrixBlock in = TestUtils
-
.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal,
1.0, 7));
+ MatrixBlock in =
TestUtils.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0,
newVal, 1.0, 7));
testUpdateAndEncodeT(in);
}
@@ -455,8 +517,7 @@ public abstract class SchemeTestBase {
@Test
public void testUpdateAndEncodeSparseT() {
double newVal = distinct + 4;
- MatrixBlock in = TestUtils
-
.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal,
0.1, 7));
+ MatrixBlock in =
TestUtils.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0,
newVal, 0.1, 7));
testUpdateAndEncodeT(in);
}
@@ -472,8 +533,7 @@ public abstract class SchemeTestBase {
@Test
public void testUpdateAndEncodeLarge() {
double newVal = distinct + 4;
- MatrixBlock in = TestUtils
- .round(TestUtils.generateTestMatrixBlock(100,
src.getNumColumns(), 0, newVal, 1.0, 7));
+ MatrixBlock in =
TestUtils.round(TestUtils.generateTestMatrixBlock(100, src.getNumColumns(), 0,
newVal, 1.0, 7));
in = ReadersTestCompareReaders.createMock(in);
testUpdateAndEncode(in);
@@ -482,8 +542,7 @@ public abstract class SchemeTestBase {
@Test
public void testUpdateAndEncodeLargeT() {
double newVal = distinct + 4;
- MatrixBlock in = TestUtils
-
.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal,
1.0, 7));
+ MatrixBlock in =
TestUtils.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0,
newVal, 1.0, 7));
in = ReadersTestCompareReaders.createMock(in);
testUpdateAndEncodeT(in);
}
@@ -491,16 +550,14 @@ public abstract class SchemeTestBase {
@Test
public void testUpdateAndEncodeManyNew() {
double newVal = distinct + 300;
- MatrixBlock in = TestUtils
- .round(TestUtils.generateTestMatrixBlock(100,
src.getNumColumns(), 0, newVal, 1.0, 7));
+ MatrixBlock in =
TestUtils.round(TestUtils.generateTestMatrixBlock(100, src.getNumColumns(), 0,
newVal, 1.0, 7));
testUpdateAndEncode(in);
}
@Test
public void testUpdateAndEncodeTManyNew() {
double newVal = distinct + 300;
- MatrixBlock in = TestUtils
-
.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal,
1.0, 7));
+ MatrixBlock in =
TestUtils.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0,
newVal, 1.0, 7));
testUpdateAndEncodeT(in);
}
@@ -515,16 +572,14 @@ public abstract class SchemeTestBase {
@Test
public void testUpdateAndEncodeSparseTManyNew() {
double newVal = distinct + 300;
- MatrixBlock in = TestUtils
-
.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal,
0.1, 7));
+ MatrixBlock in =
TestUtils.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0,
newVal, 0.1, 7));
testUpdateAndEncodeT(in);
}
@Test
public void testUpdateAndEncodeLargeManyNew() {
double newVal = distinct + 300;
- MatrixBlock in = TestUtils
- .round(TestUtils.generateTestMatrixBlock(100,
src.getNumColumns(), 0, newVal, 1.0, 7));
+ MatrixBlock in =
TestUtils.round(TestUtils.generateTestMatrixBlock(100, src.getNumColumns(), 0,
newVal, 1.0, 7));
in = ReadersTestCompareReaders.createMock(in);
testUpdateAndEncode(in);
@@ -533,8 +588,7 @@ public abstract class SchemeTestBase {
@Test
public void testUpdateAndEncodeLargeTManyNew() {
double newVal = distinct + 300;
- MatrixBlock in = TestUtils
-
.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0, newVal,
1.0, 7));
+ MatrixBlock in =
TestUtils.round(TestUtils.generateTestMatrixBlock(src.getNumColumns(), 100, 0,
newVal, 1.0, 7));
in = ReadersTestCompareReaders.createMock(in);
testUpdateAndEncodeT(in);
}
@@ -566,14 +620,23 @@ public abstract class SchemeTestBase {
}
public void testUpdateAndEncode(MatrixBlock in) {
- Pair<ICLAScheme, AColGroup> r = sh.clone().updateAndEncode(in);
- AColGroup out = r.getValue();
- MatrixBlock d = new MatrixBlock(in.getNumRows(),
src.getNumColumns(), false);
- d.allocateBlock();
- out.decompressToDenseBlock(d.getDenseBlock(), 0,
in.getNumRows());
- MatrixBlock inSlice = in.slice(0, in.getNumRows() - 1, 0,
src.getNumColumns() - 1);
- d.recomputeNonZeros();
- TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0, 0);
+ try {
+
+ Pair<ICLAScheme, AColGroup> r =
sh.clone().updateAndEncode(in);
+ AColGroup out = r.getValue();
+ MatrixBlock d = new MatrixBlock(in.getNumRows(),
src.getNumColumns(), false);
+ d.allocateBlock();
+ out.decompressToDenseBlock(d.getDenseBlock(), 0,
in.getNumRows());
+ MatrixBlock inSlice = in.slice(0, in.getNumRows() - 1,
0, src.getNumColumns() - 1);
+ d.recomputeNonZeros();
+ TestUtils.compareMatricesBitAvgDistance(inSlice, d, 0,
0);
+ }
+ catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index with size == numRows"))
+ return;// all good
+ e.printStackTrace();
+ fail(e.getMessage() + " " + sh);
+ }
}
public void testUpdateAndEncodeT(MatrixBlock in) {
@@ -588,6 +651,8 @@ public abstract class SchemeTestBase {
TestUtils.compareMatricesBitAvgDistance(inSlice,
LibMatrixReorg.transpose(d), 0, 0);
}
catch(Exception e) {
+ if(e.getMessage().contains("Invalid SDC group that
contains index with size == numRows"))
+ return;// all good
e.printStackTrace();
fail(e.getMessage() + " " + sh);
}
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestSDC.java
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestSDC.java
index 1f7c872b0f..064f10e9f3 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestSDC.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestSDC.java
@@ -85,7 +85,6 @@ public class SchemeTestSDC extends SchemeTestBase {
catch(Exception e) {
e.printStackTrace();
fail(e.getMessage());
- throw new RuntimeException();
}
}
}
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java
b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java
index edeadecbf1..79f50f9898 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java
@@ -648,4 +648,17 @@ public class CustomDictionaryTest {
public void notEqualsObject(){
assertNotEquals(Dictionary.create(new double[]{1.1,2.2,3.3}),
new Object());
}
+
+
+ @Test
+ public void createIdentity_1(){
+
+ assertTrue(IdentityDictionary.create(1) instanceof Dictionary);
+ }
+
+ @Test
+ public void createIdentity_2(){
+
+ assertTrue(IdentityDictionary.create(1, true) instanceof
Dictionary);
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java
b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java
index 8ef384e45a..3dd48636ae 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java
@@ -377,6 +377,64 @@ public class DictionaryTests {
}
}
+ @Test
+ public void getNumValues() {
+ assertEquals(a.getNumberOfValues(nCol),
b.getNumberOfValues(nCol));
+ }
+
+ @Test(expected = Exception.class)
+ public void getNumValuesInvalida() {
+ a.getNumberOfValues(nCol + 1);
+ throw new RuntimeException("pass");
+ }
+
+ @Test(expected = Exception.class)
+ public void getNumValuesInvalidb() {
+ b.getNumberOfValues(nCol + 1);
+ throw new RuntimeException("pass");
+ }
+
+ @Test
+ public void getNumColumns() {
+ assertEquals(a.getNumberOfColumns(nRow),
b.getNumberOfColumns(nRow));
+ }
+
+ @Test(expected = Exception.class)
+ public void getNumColumnsInvalida() {
+ a.getNumberOfColumns(nRow + 1);
+ throw new RuntimeException("pass");
+ }
+
+ @Test(expected = Exception.class)
+ public void getNumColumnsInvalidb() {
+ b.getNumberOfColumns(nRow + 1);
+ throw new RuntimeException("pass");
+ }
+
+ @Test(expected = Exception.class)
+ public void outOfRange1() {
+ assertEquals(0, a.getValue(nRow, nCol - 1, nCol), 0.0);
+ throw new RuntimeException();
+ }
+
+ @Test(expected = Exception.class)
+ public void outOfRange2() {
+ assertEquals(0, b.getValue(nRow, nCol - 1, nCol), 0.0);
+ throw new RuntimeException();
+ }
+
+ @Test(expected = Exception.class)
+ public void outOfRange3() {
+ assertEquals(0, a.getValue(nRow * nCol + 1), 0.0);
+ throw new RuntimeException();
+ }
+
+ @Test(expected = Exception.class)
+ public void outOfRange4() {
+ assertEquals(0, b.getValue(nRow * nCol + 1), 0.0);
+ throw new RuntimeException();
+ }
+
@Test
public void getDictType() {
assertNotEquals(a.getDictType(), b.getDictType());
@@ -726,6 +784,13 @@ public class DictionaryTests {
compare(ad, bd, nRow, e - s);
}
+ @Test
+ public void sliceOutEverything() {
+ IDictionary ad = a.sliceOutColumnRange(0, nCol, nCol);
+ IDictionary bd = b.sliceOutColumnRange(0, nCol, nCol);
+ compare(ad, bd, nRow, nCol);
+ }
+
@Test
public void contains1() {
containsValue(1);
@@ -1184,6 +1249,61 @@ public class DictionaryTests {
}
+ @Test
+ public void rightMMPreAggUltraSparse() {
+ final int nColsOut = 30;
+ MatrixBlock sparse = TestUtils.generateTestMatrixBlock(1000,
nColsOut, -10, 10, 0.001, 100);
+ sparse = TestUtils.ceil(sparse);
+ sparse.denseToSparse(true);
+ SparseBlock sb = sparse.getSparseBlock();
+ if(sb == null)
+ throw new NotImplementedException();
+
+ IColIndex agCols = new RangeIndex(nColsOut);
+ IColIndex thisCols = new RangeIndex(0, nCol);
+
+ int nVals = a.getNumberOfValues(nCol);
+ try {
+
+ IDictionary aa = a.rightMMPreAggSparse(nVals, sb,
thisCols, agCols, nColsOut);
+ IDictionary bb = b.rightMMPreAggSparse(nVals, sb,
thisCols, agCols, nColsOut);
+ compare(aa, bb, nColsOut);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+
+ }
+
+
+ @Test
+ public void rightMMPreAggUltraSparseTwoOut() {
+ final int nColsOut = 2;
+ MatrixBlock sparse = TestUtils.generateTestMatrixBlock(1000,
nColsOut, -10, 10, 0.001, 100);
+ sparse = TestUtils.ceil(sparse);
+ sparse.denseToSparse(true);
+ SparseBlock sb = sparse.getSparseBlock();
+ if(sb == null)
+ throw new NotImplementedException();
+
+ IColIndex agCols = new RangeIndex(nColsOut);
+ IColIndex thisCols = new RangeIndex(0, nCol);
+
+ int nVals = a.getNumberOfValues(nCol);
+ try {
+
+ IDictionary aa = a.rightMMPreAggSparse(nVals, sb,
thisCols, agCols, nColsOut);
+ IDictionary bb = b.rightMMPreAggSparse(nVals, sb,
thisCols, agCols, nColsOut);
+ compare(aa, bb, nColsOut);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+
+ }
+
@Test
public void rightMMPreAggSparse2() {
final int nColsOut = 1000;
@@ -1238,6 +1358,35 @@ public class DictionaryTests {
}
+
+ @Test
+ public void rightMMPreAggSparseDifferentColumnsUltraSparse() {
+ final int nColsOut = 3;
+ MatrixBlock sparse = TestUtils.generateTestMatrixBlock(1000,
50, -10, 10, 0.001, 100);
+ sparse = TestUtils.ceil(sparse);
+ sparse.denseToSparse(true);
+ SparseBlock sb = sparse.getSparseBlock();
+ if(sb == null)
+ throw new NotImplementedException();
+
+ IColIndex agCols = new ArrayIndex(new int[] {4, 10, 38});
+ IColIndex thisCols = new RangeIndex(0, nCol);
+
+ int nVals = a.getNumberOfValues(nCol);
+ try {
+
+ IDictionary aa = a.rightMMPreAggSparse(nVals, sb,
thisCols, agCols, 50);
+ IDictionary bb = b.rightMMPreAggSparse(nVals, sb,
thisCols, agCols, 50);
+ compare(aa, bb, nColsOut);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+
+ }
+
+
@Test
public void MMDictScalingDense() {
double[] left =
TestUtils.ceil(TestUtils.generateTestVector(a.getNumberOfValues(nCol) * 3, -10,
10, 1.0, 3214));
diff --git
a/src/test/java/org/apache/sysds/test/functions/compress/table/CompressedTableOverwriteTest.java
b/src/test/java/org/apache/sysds/test/functions/compress/table/CompressedTableOverwriteTest.java
new file mode 100644
index 0000000000..11bf1b394e
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/compress/table/CompressedTableOverwriteTest.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.compress.table;
+
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import java.io.File;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class CompressedTableOverwriteTest extends AutomatedTestBase {
+ protected static final Log LOG =
LogFactory.getLog(CompressedTableOverwriteTest.class.getName());
+
+ private final static String TEST_DIR = "functions/compress/table/";
+
+ protected String getTestClassDir() {
+ return getTestDir() + this.getClass().getSimpleName() + "/";
+ }
+
+ protected String getTestName() {
+ return "table";
+ }
+
+ protected String getTestDir() {
+ return TEST_DIR;
+ }
+
+ @Test
+ public void testRewireTable_2() {
+ rewireTableTest(10, 2, 0.2, ExecType.CP, "01");
+ }
+
+ @Test
+ public void testRewireTable_20() {
+ rewireTableTest(30, 20, 0.2, ExecType.CP, "01");
+ }
+
+ @Test
+ public void testRewireTable_80() {
+ rewireTableTest(100, 80, 0.2, ExecType.CP, "01");
+ }
+
+ @Test
+ public void testRewireTable_80_1000() {
+ rewireTableTest(1000, 80, 0.2, ExecType.CP, "01");
+ }
+
+ @Test
+ public void testRewireTable_80_1000_dense() {
+ rewireTableTest(1000, 80, 1.0, ExecType.CP, "01");
+ }
+
+
+ public void rewireTableTest(int rows, int unique, double sparsity,
ExecType instType, String name) {
+
+ OptimizerUtils.ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND = true;
+ Types.ExecMode platformOld = setExecMode(instType);
+
+ CompressedMatrixBlock.debug = true;
+ CompressedMatrixBlock.allowCachingUncompressed = false;
+ try {
+
+ super.setOutputBuffering(true);
+
loadTestConfiguration(getTestConfiguration(getTestName()));
+ fullDMLScriptName = SCRIPT_DIR + "/" +
getTestClassDir() + name + ".dml";
+ programArgs = new String[] {"-stats", "100", "-nvargs",
"rows=" + rows, "unique=" + unique,
+ "sparsity=" + sparsity};
+ String s = runTest(null).toString();
+
+ if(s.contains("Failed"))
+ fail(s);
+ // else
+ // LOG.debug(s);
+
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ assertTrue("Exception in execution: " + e.getMessage(),
false);
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(getTestName(), new
TestConfiguration(getTestClassDir(), getTestName()));
+ }
+
+ @Override
+ protected File getConfigTemplateFile() {
+ return new
File("./src/test/scripts/functions/compress/SystemDS-config-compress.xml");
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/compress/wordembedding/wordEmbeddingUseCase.java
b/src/test/java/org/apache/sysds/test/functions/compress/wordembedding/wordEmbeddingUseCase.java
new file mode 100644
index 0000000000..b52ffb0764
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/compress/wordembedding/wordEmbeddingUseCase.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.compress.wordembedding;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import java.io.File;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import
org.apache.sysds.test.functions.compress.table.CompressedTableOverwriteTest;
+import org.junit.Test;
+
+public class wordEmbeddingUseCase extends AutomatedTestBase {
+
+ protected static final Log LOG =
LogFactory.getLog(CompressedTableOverwriteTest.class.getName());
+
+ private final static String TEST_DIR =
"functions/compress/wordembedding/";
+
+ protected String getTestClassDir() {
+ return getTestDir();
+ }
+
+ protected String getTestName() {
+ return "wordembedding";
+ }
+
+ protected String getTestDir() {
+ return TEST_DIR;
+ }
+
+ @Test
+ public void testWordEmb() {
+ wordEmb(10, 2, 2, 2, ExecType.CP, "01");
+ }
+
+ @Test
+ public void testWordEmb_medium() {
+ wordEmb(100, 30, 4, 3, ExecType.CP, "01");
+ }
+
+ @Test
+ public void testWordEmb_bigWords() {
+ wordEmb(10, 2, 2, 10, ExecType.CP, "01");
+ }
+
+ @Test
+ public void testWordEmb_longSentences() {
+ wordEmb(100, 30, 5, 2, ExecType.CP, "01");
+ }
+
+ @Test
+ public void testWordEmb_moreUniqueWordsThanSentences() {
+ wordEmb(100, 200, 5, 2, ExecType.CP, "01");
+ }
+
+
+ public void wordEmb(int rows, int unique, int l, int embeddingSize,
ExecType instType, String name) {
+
+ OptimizerUtils.ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND = true;
+ Types.ExecMode platformOld = setExecMode(instType);
+
+ CompressedMatrixBlock.debug = true;
+
+ try {
+ super.setOutputBuffering(true);
+
loadTestConfiguration(getTestConfiguration(getTestName()));
+ fullDMLScriptName = SCRIPT_DIR + getTestClassDir() +
name + ".dml";
+
+ programArgs = new String[] {"-stats", "100", "-args",
input("X"), input("W"), "" + l, output("R")};
+
+ MatrixBlock X = TestUtils.generateTestMatrixBlock(rows,
1, 1, unique + 1, 1.0, 32);
+ X = TestUtils.floor(X);
+ writeBinaryWithMTD("X", X);
+
+ MatrixBlock W =
TestUtils.generateTestMatrixBlock(unique, embeddingSize, 1.0, -1, 1, 32);
+ writeBinaryWithMTD("W", W);
+
+ runTest(null);
+
+ MatrixBlock R = TestUtils.readBinary(output("R"));
+
+ analyzeResult(X, W, R, l);
+
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ assertTrue("Exception in execution: " + e.getMessage(),
false);
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+ }
+
+ private void analyzeResult(MatrixBlock X, MatrixBlock W, MatrixBlock R,
int l){
+ for(int i = 0; i < X.getNumRows(); i++){
+ // for each row in X, it should embed with a W, in
accordance to what value it used
+
+ // the entry to look into W. // as in row
+ int e = UtilFunctions.toInt(X.get(i,0)) -1;
+ int rowR = i / l;
+ int offR = i % l;
+
+ for(int j = 0; j < W.getNumColumns(); j++){
+ assertEquals(R.get(rowR, offR*
W.getNumColumns() + j), W.get(e, j), 0.0);
+ }
+ }
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(getTestName(), new
TestConfiguration(getTestClassDir(), getTestName()));
+ }
+
+ @Override
+ protected File getConfigTemplateFile() {
+ return new
File("./src/test/scripts/functions/compress/SystemDS-config-compress.xml");
+ }
+
+}