This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new a25c111401 [MINOR] Dedicated tests for CLA Dictionaries
a25c111401 is described below

commit a25c111401abe0c8bb222555e27a2e55bbe80e9a
Author: baunsgaard <[email protected]>
AuthorDate: Fri Oct 28 15:26:28 2022 +0200

    [MINOR] Dedicated tests for CLA Dictionaries
    
    This commit adds a bunch of dedicated tests for dictionaries in CLA.
    This is the beginning of adding new Dictionary types for specialized
    formats, such as the previous UInt8, that can be reintegrated with this.
    
    Close #1715
---
 .../runtime/compress/colgroup/AColGroupValue.java  |  14 +-
 .../runtime/compress/colgroup/ColGroupConst.java   |   4 +-
 .../runtime/compress/colgroup/ColGroupDDC.java     |   4 +-
 .../runtime/compress/colgroup/ColGroupDDCFOR.java  |   2 +-
 .../runtime/compress/colgroup/ColGroupFactory.java |   3 +
 .../runtime/compress/colgroup/ColGroupSDC.java     |   4 +-
 .../runtime/compress/colgroup/ColGroupSDCFOR.java  |   2 +-
 .../compress/colgroup/ColGroupSDCSingle.java       |   2 +-
 .../compress/colgroup/ColGroupSDCSingleZeros.java  |   2 +-
 .../compress/colgroup/ColGroupSDCZeros.java        |   2 +-
 .../compress/colgroup/ColGroupUncompressed.java    |   2 +-
 .../compress/colgroup/dictionary/ADictionary.java  |  24 +-
 .../compress/colgroup/dictionary/Dictionary.java   |  60 +--
 .../colgroup/dictionary/DictionaryFactory.java     |   8 +-
 .../colgroup/dictionary/MatrixBlockDictionary.java | 160 +++-----
 .../compress/colgroup/dictionary/QDictionary.java  |   8 +-
 .../runtime/compress/lib/CLALibBinaryCellOp.java   |   5 +-
 .../compress/colgroup/ColGroupFactoryTest.java     |  32 +-
 .../compress/dictionary/CustomDictionaryTest.java  | 155 +++++++
 .../compress/dictionary/DeltaDictionaryTest.java   |  10 +
 .../compress/dictionary/DictionaryTest.java        |  47 ---
 .../compress/dictionary/DictionaryTests.java       | 456 +++++++++++++++++++++
 22 files changed, 784 insertions(+), 222 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java
index 601a105d99..23abdf42cf 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.compress.colgroup;
 
 import java.lang.ref.SoftReference;
 
+import org.apache.sysds.runtime.compress.DMLCompressionException;
 import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
 import org.apache.sysds.runtime.compress.utils.Util;
 import org.apache.sysds.runtime.functionobjects.Builtin;
@@ -189,11 +190,16 @@ public abstract class AColGroupValue extends 
ADictBasedColGroup {
 
        @Override
        public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int 
nRows) {
-               ADictionary d = _dict.rexpandCols(max, ignore, cast, 
_colIndexes.length);
-               if(d == null)
+               try {
+                       ADictionary d = _dict.rexpandCols(max, ignore, cast, 
_colIndexes.length);
+                       if(d == null)
+                               return ColGroupEmpty.create(max);
+                       else
+                               return copyAndSet(Util.genColsIndices(max), d);
+               }
+               catch(DMLCompressionException e) {
                        return ColGroupEmpty.create(max);
-               else
-                       return copyAndSet(Util.genColsIndices(max), d);
+               }
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
index 0fa17e6a0b..79ad64f332 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
@@ -543,7 +543,7 @@ public class ColGroupConst extends ADictBasedColGroup {
        @Override
        public AColGroup append(AColGroup g) {
                if(g instanceof ColGroupConst && g._colIndexes.length == 
_colIndexes.length &&
-                       ((ColGroupConst) g)._dict.eq(_dict))
+                       ((ColGroupConst) g)._dict.equals(_dict))
                        return this;
                return null;
        }
@@ -551,7 +551,7 @@ public class ColGroupConst extends ADictBasedColGroup {
        @Override
        public AColGroup appendNInternal(AColGroup[] g) {
                for(int i = 0; i < g.length; i++)
-                       if(!Arrays.equals(_colIndexes, g[i]._colIndexes) || 
!this._dict.eq(((ColGroupConst) g[i])._dict))
+                       if(!Arrays.equals(_colIndexes, g[i]._colIndexes) || 
!this._dict.equals(((ColGroupConst) g[i])._dict))
                                return null;
                return this;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
index 4a1dea31db..a7ec0d8b6a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
@@ -497,7 +497,7 @@ public class ColGroupDDC extends APreAgg implements 
AMapToDataGroup {
                        if(Arrays.equals(g.getColIndices(), _colIndexes)) {
 
                                ColGroupDDC gDDC = (ColGroupDDC) g;
-                               if(gDDC._dict.eq(_dict)) {
+                               if(gDDC._dict.equals(_dict)) {
                                        AMapToData nd = 
_data.append(gDDC._data);
                                        return create(_colIndexes, _dict, nd, 
null);
                                }
@@ -528,7 +528,7 @@ public class ColGroupDDC extends APreAgg implements 
AMapToDataGroup {
                        }
 
                        final ColGroupDDC gDDC = (ColGroupDDC) g[i];
-                       if(!gDDC._dict.eq(_dict)) {
+                       if(!gDDC._dict.equals(_dict)) {
                                LOG.warn("Not same Dictionaries therefore not 
appending DDC\n" + _dict + "\n\n" + gDDC._dict);
                                return null;
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
index 060915fbb5..0a58c78f29 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
@@ -438,7 +438,7 @@ public class ColGroupDDCFOR extends AMorphingMMColGroup {
        public AColGroup append(AColGroup g) {
                if(g instanceof ColGroupDDCFOR && 
Arrays.equals(g.getColIndices(), _colIndexes)) {
                        ColGroupDDCFOR gDDC = (ColGroupDDCFOR) g;
-                       if(Arrays.equals(_reference , gDDC._reference) && 
gDDC._dict.eq(_dict)){
+                       if(Arrays.equals(_reference , gDDC._reference) && 
gDDC._dict.equals(_dict)){
                                AMapToData nd = _data.append(gDDC._data);
                                return create(_colIndexes, _dict, nd, null, 
_reference);
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
index 1f94655521..178527d2b2 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
@@ -321,6 +321,9 @@ public class ColGroupFactory {
                else
                        readToMapDDC(col, map, d);
 
+               if(map.size() == 0)
+                       return new ColGroupEmpty(colIndexes);
+               
                ADictionary dict = DictionaryFactory.create(map);
                final int nUnique = map.size();
                final AMapToData resData = MapToFactory.resize(d, nUnique);
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index db9cc6f1a5..d1c12f535c 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -575,7 +575,7 @@ public class ColGroupSDC extends ASDC implements 
AMapToDataGroup {
        public AColGroup append(AColGroup g) {
                if(g instanceof ColGroupSDC && Arrays.equals(g.getColIndices(), 
_colIndexes)) {
                        final ColGroupSDC gSDC = (ColGroupSDC) g;
-                       if(Arrays.equals(_defaultTuple, gSDC._defaultTuple) && 
gSDC._dict.eq(_dict)) {
+                       if(Arrays.equals(_defaultTuple, gSDC._defaultTuple) && 
gSDC._dict.equals(_dict)) {
                                final AMapToData nd = _data.append(gSDC._data);
                                final AOffset ofd = 
_indexes.append(gSDC._indexes, getNumRows());
                                return create(_colIndexes, _numRows + 
gSDC._numRows, _dict, _defaultTuple, ofd, nd, null);
@@ -600,7 +600,7 @@ public class ColGroupSDC extends ASDC implements 
AMapToDataGroup {
                        }
 
                        final ColGroupSDC gc = (ColGroupSDC) g[i];
-                       if(!gc._dict.eq(_dict)) {
+                       if(!gc._dict.equals(_dict)) {
                                LOG.warn("Not same Dictionaries therefore not 
appending \n" + _dict + "\n\n" + gc._dict);
                                return null;
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
index dd953c283a..276d5703aa 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
@@ -466,7 +466,7 @@ public class ColGroupSDCFOR extends ASDC implements 
AMapToDataGroup {
                        }
 
                        final ColGroupSDCFOR gc = (ColGroupSDCFOR) g[i];
-                       if(!gc._dict.eq(_dict)) {
+                       if(!gc._dict.equals(_dict)) {
                                LOG.warn("Not same Dictionaries therefore not 
appending \n" + _dict + "\n\n" + gc._dict);
                                return null;
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
index 739eb33379..60b567996a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
@@ -598,7 +598,7 @@ public class ColGroupSDCSingle extends ASDC {
                        }
 
                        final ColGroupSDCSingle gc = (ColGroupSDCSingle) g[i];
-                       if(!gc._dict.eq(_dict)) {
+                       if(!gc._dict.equals(_dict)) {
                                LOG.warn("Not same Dictionaries therefore not 
appending \n" + _dict + "\n\n" + gc._dict);
                                return null;
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
index e4a53b3cfb..bc6756b18c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
@@ -831,7 +831,7 @@ public class ColGroupSDCSingleZeros extends ASDCZero {
                        }
 
                        final ColGroupSDCSingleZeros gc = 
(ColGroupSDCSingleZeros) g[i];
-                       if(!gc._dict.eq(_dict)) {
+                       if(!gc._dict.equals(_dict)) {
                                LOG.warn("Not same Dictionaries therefore not 
appending \n" + _dict + "\n\n" + gc._dict);
                                return null;
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
index f926cdac6f..ae2ecd334e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
@@ -747,7 +747,7 @@ public class ColGroupSDCZeros extends ASDCZero implements 
AMapToDataGroup{
                        }
 
                        final ColGroupSDCZeros gc = (ColGroupSDCZeros) g[i];
-                       if(!gc._dict.eq(_dict)) {
+                       if(!gc._dict.equals(_dict)) {
                                LOG.warn("Not same Dictionaries therefore not 
appending \n" + _dict + "\n\n" + gc._dict);
                                return null;
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
index 80756c67f0..fb2550ece7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
@@ -358,7 +358,7 @@ public class ColGroupUncompressed extends AColGroup {
        public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean 
isRowSafe) {
                LOG.warn("Binary row op left is not supported for Uncompressed 
Matrix, "
                        + "Implement support for VMr in MatrixBlock Binary Cell 
operations");
-               MatrixBlockDictionary d = new MatrixBlockDictionary(_data);
+               MatrixBlockDictionary d = MatrixBlockDictionary.create(_data);
                ADictionary dm = d.binOpLeft(op, v, _colIndexes);
                if(dm == null)
                        return create(null, _colIndexes);
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
index 9c30b71f6f..dd9557dc41 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
@@ -43,6 +43,10 @@ public abstract class ADictionary implements Serializable {
 
        protected static final Log LOG = 
LogFactory.getLog(ADictionary.class.getName());
 
+       public static enum DictType {
+               Delta, Dict, MatrixBlock, UInt8;
+       }
+
        /**
         * Get all the values contained in the dictionary as a linearized 
double array.
         * 
@@ -311,11 +315,11 @@ public abstract class ADictionary implements Serializable 
{
        public abstract long getExactSizeOnDisk();
 
        /**
-        * Specify if the Dictionary is lossy.
+        * Get the dictionary type this dictionary is.
         * 
-        * @return A boolean
+        * @return The Dictionary type this is.
         */
-       public abstract boolean isLossy();
+       public abstract DictType getDictType();
 
        /**
         * Get the number of distinct tuples given that the column group has n 
columns
@@ -668,7 +672,7 @@ public abstract class ADictionary implements Serializable {
         * @param nRows  The number of rows in total of the column group
         * @return The central moment Object
         */
-       public CM_COV_Object centralMoment(ValueFunction fn, int[] counts, int 
nRows) {
+       public final CM_COV_Object centralMoment(ValueFunction fn, int[] 
counts, int nRows) {
                return centralMoment(new CM_COV_Object(), fn, counts, nRows);
        }
 
@@ -694,7 +698,7 @@ public abstract class ADictionary implements Serializable {
         * @param nRows  The number of rows in total of the column group
         * @return The central moment Object
         */
-       public CM_COV_Object centralMomentWithDefault(ValueFunction fn, int[] 
counts, double def, int nRows) {
+       public final CM_COV_Object centralMomentWithDefault(ValueFunction fn, 
int[] counts, double def, int nRows) {
                return centralMomentWithDefault(new CM_COV_Object(), fn, 
counts, def, nRows);
        }
 
@@ -722,7 +726,7 @@ public abstract class ADictionary implements Serializable {
         * @param nRows     The number of rows in total of the column group
         * @return The central moment Object
         */
-       public CM_COV_Object centralMomentWithReference(ValueFunction fn, int[] 
counts, double reference, int nRows) {
+       public final CM_COV_Object centralMomentWithReference(ValueFunction fn, 
int[] counts, double reference, int nRows) {
                return centralMomentWithReference(new CM_COV_Object(), fn, 
counts, reference, nRows);
        }
 
@@ -890,7 +894,7 @@ public abstract class ADictionary implements Serializable {
        protected abstract void TSMMToUpperTriangleSparseScaling(SparseBlock 
left, int[] rowsLeft, int[] colsRight,
                int[] scale, MatrixBlock result);
 
-       protected String doubleToString(double v) {
+       protected static String doubleToString(double v) {
                if(v == (long) v)
                        return Long.toString(((long) v));
                else
@@ -905,11 +909,11 @@ public abstract class ADictionary implements Serializable 
{
        }
 
        @Override
-       public boolean equals(Object o) {
+       public final boolean equals(Object o) {
                if(o instanceof ADictionary)
-                       return eq((ADictionary) o);
+                       return equals((ADictionary) o);
                return false;
        }
 
-       public abstract boolean eq(ADictionary o);
+       public abstract boolean equals(ADictionary o);
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
index 852cc733c1..552e262e57 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
@@ -26,6 +26,7 @@ import java.math.BigDecimal;
 import java.math.MathContext;
 import java.util.Arrays;
 
+import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.compress.DMLCompressionException;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.functionobjects.Builtin;
@@ -51,12 +52,14 @@ public class Dictionary extends ADictionary {
        protected final double[] _values;
 
        protected Dictionary(double[] values) {
-               if(values == null || values.length == 0)
-                       throw new DMLCompressionException("Invalid construction 
of dictionary with null array");
                _values = values;
        }
 
        public static Dictionary create(double[] values) {
+               if(values == null)
+                       throw new DMLCompressionException("Invalid construction 
of dictionary with null array");
+               else if(values.length == 0)
+                       throw new DMLCompressionException("Invalid construction 
of dictionary with empty array");
                boolean nonZero = false;
                for(double d : values) {
                        if(d != 0) {
@@ -633,17 +636,26 @@ public class Dictionary extends ADictionary {
 
        @Override
        public String toString() {
-               StringBuilder sb = new StringBuilder();
-
+               StringBuilder sb = new StringBuilder(_values.length * 3 + 10);
                sb.append("Dictionary : ");
-               sb.append(Arrays.toString(_values));
+               stringArray(sb, _values);
                return sb.toString();
        }
 
+       private static void stringArray(StringBuilder sb, double[] val) {
+               sb.append("[");
+               sb.append(doubleToString(val[0]));
+               for(int i = 1; i < val.length; i++) {
+                       sb.append(", ");
+                       sb.append(doubleToString(val[i]));
+               }
+               sb.append("]");
+       }
+
        public String getString(int colIndexes) {
                StringBuilder sb = new StringBuilder();
                if(colIndexes == 1)
-                       sb.append(Arrays.toString(_values));
+                       stringArray(sb, _values);
                else {
                        sb.append("[\n\t");
                        for(int i = 0; i < _values.length - 1; i++) {
@@ -771,8 +783,8 @@ public class Dictionary extends ADictionary {
        }
 
        @Override
-       public boolean isLossy() {
-               return false;
+       public DictType getDictType() {
+               return DictType.Dict;
        }
 
        @Override
@@ -858,9 +870,10 @@ public class Dictionary extends ADictionary {
                int off = 0;
                for(int i = 0; i < nRow; i++) {
                        for(int j = 0; j < nCol; j++) {
+                               final double ref = reference[j];
                                final double v = _values[off];
-                               retV[off++] = v + reference[j] == pattern ? 
replace - reference[j] : v;
-
+                               retV[off] = Math.abs(v + ref - pattern) < 
0.000001 ? replace - ref : v;
+                               off++;
                        }
                }
                return create(retV);
@@ -980,27 +993,28 @@ public class Dictionary extends ADictionary {
 
        @Override
        public ADictionary rexpandCols(int max, boolean ignore, boolean cast, 
int nCol) {
-               MatrixBlockDictionary a = getMBDict(nCol);
-               if(a == null)
-                       return null;
-               return a.rexpandCols(max, ignore, cast, nCol);
+               if(nCol > 1)
+                       throw new DMLCompressionException("Invalid to rexpand 
the column groups if more than one column");
+               MatrixBlockDictionary m = getMBDict(nCol);
+               return m == null ? null : m.rexpandCols(max, ignore, cast, 
nCol);
        }
 
        @Override
        public ADictionary rexpandColsWithReference(int max, boolean ignore, 
boolean cast, int reference) {
-               MatrixBlockDictionary a = getMBDict(1);
-               if(a == null)
-                       a = new MatrixBlockDictionary(new 
MatrixBlock(_values.length, 1, (double) reference));
-               else
-                       a = (MatrixBlockDictionary) a.applyScalarOp(new 
LeftScalarOperator(Plus.getPlusFnObject(), reference));
-               if(a == null)
+               MatrixBlockDictionary m = getMBDict(1);
+               if(m == null)
                        return null;
-               return a.rexpandCols(max, ignore, cast, 1);
+               ADictionary a = m.applyScalarOp(new 
LeftScalarOperator(Plus.getPlusFnObject(), reference));
+               return a == null ? null : a.rexpandCols(max, ignore, cast, 1);
        }
 
        @Override
        public double getSparsity() {
-               return 1;
+               int zeros = 0;
+               for(double v : _values)
+                       if(v == 0.0)
+                               zeros++;
+               return OptimizerUtils.getSparsity(_values.length, 1L, 
_values.length - zeros);
        }
 
        @Override
@@ -1066,7 +1080,7 @@ public class Dictionary extends ADictionary {
        }
 
        @Override
-       public boolean eq(ADictionary o) {
+       public boolean equals(ADictionary o) {
                if(o instanceof Dictionary)
                        return Arrays.equals(_values, ((Dictionary) o)._values);
                else if(o instanceof MatrixBlockDictionary) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
index c9bc6524bc..a777201d00 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
@@ -85,7 +85,7 @@ public interface DictionaryFactory {
                                }
                                retB.recomputeNonZeros();
                                retB.examSparsity(true);
-                               return new MatrixBlockDictionary(retB);
+                               return MatrixBlockDictionary.create(retB);
                        }
                        else {
 
@@ -130,7 +130,7 @@ public interface DictionaryFactory {
                        }
                        m.recomputeNonZeros();
                        m.examSparsity(true);
-                       return new MatrixBlockDictionary(m);
+                       return MatrixBlockDictionary.create(m);
                }
                else if(ubm instanceof MultiColBitmap) {
                        MultiColBitmap mcbm = (MultiColBitmap) ubm;
@@ -166,7 +166,7 @@ public interface DictionaryFactory {
 
                        m.recomputeNonZeros();
                        m.examSparsity(true);
-                       return new MatrixBlockDictionary(m);
+                       return MatrixBlockDictionary.create(m);
                }
                else {
                        double[] dict = new double[nCol * nVal];
@@ -216,7 +216,7 @@ public interface DictionaryFactory {
                        }
                        m.recomputeNonZeros();
                        m.examSparsity(true);
-                       return new MatrixBlockDictionary(m);
+                       return MatrixBlockDictionary.create(m);
                }
 
                final double[] resValues = new double[nRows * nCols];
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
index b13efbe4c1..d447308e24 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
@@ -60,35 +60,35 @@ public class MatrixBlockDictionary extends ADictionary {
         * 
         * @param data The matrix block data.
         */
-       public MatrixBlockDictionary(MatrixBlock data) {
-               this(data, true);
+       protected MatrixBlockDictionary(MatrixBlock data) {
+               _data = data;
        }
 
-       /**
-        * Unsafe private constructor that does not check the data validity. 
USE WITH CAUTION.
-        * 
-        * @param data  The matrix block data.
-        * @param check Check the nonZeros in the dict
-        */
-       public MatrixBlockDictionary(MatrixBlock data, boolean check) {
-               if(check) {
-                       data.examSparsity(true);
-                       if(data.isEmpty())
-                               throw new DMLCompressionException("Invalid 
construction of empty dictionary");
-                       else if(data.isInSparseFormat() && 
data.getSparseBlock() instanceof SparseBlockMCSR) {
-                               SparseBlock csr = 
SparseBlockFactory.copySparseBlock(SparseBlock.Type.CSR, data.getSparseBlock(), 
false);
-                               data.setSparseBlock(csr);
+       public static MatrixBlockDictionary create(MatrixBlock mb) {
+               return create(mb, true);
+       }
+
+       public static MatrixBlockDictionary create(MatrixBlock mb, boolean 
check) {
+               if(mb == null)
+                       throw new DMLCompressionException("Invalid construction 
of dictionary with null array");
+               else if(mb.getNumRows() == 0 || mb.getNumColumns() == 0)
+                       throw new DMLCompressionException("Invalid construction 
of dictionary with zero rows and/or cols array");
+               else if(mb.isEmpty())
+                       return null;
+               else if(check) {
+                       mb.examSparsity(true);
+                       if(mb.isInSparseFormat() && mb.getSparseBlock() 
instanceof SparseBlockMCSR) {
+                               // make CSR sparse block to make it smaller.
+                               SparseBlock csr = 
SparseBlockFactory.copySparseBlock(SparseBlock.Type.CSR, mb.getSparseBlock(), 
false);
+                               mb.setSparseBlock(csr);
                        }
                }
-               _data = data;
+               return new MatrixBlockDictionary(mb);
        }
 
        public static MatrixBlockDictionary createDictionary(double[] values, 
int nCol, boolean check) {
                final MatrixBlock mb = Util.matrixBlockFromDenseArray(values, 
nCol, check);
-               if(mb.isEmpty())
-                       return null;
-               else
-                       return new MatrixBlockDictionary(mb, check);
+               return create(mb, check);
        }
 
        public MatrixBlock getMatrixBlock() {
@@ -398,10 +398,7 @@ public class MatrixBlockDictionary extends ADictionary {
        @Override
        public ADictionary applyScalarOp(ScalarOperator op) {
                MatrixBlock res = _data.scalarOperations(op, new MatrixBlock());
-               if(res.isEmpty())
-                       return null;
-               else
-                       return new MatrixBlockDictionary(res);
+               return MatrixBlockDictionary.create(res);
        }
 
        @Override
@@ -433,16 +430,13 @@ public class MatrixBlockDictionary extends ADictionary {
                }
 
                ret.recomputeNonZeros();
-               return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
        }
 
        @Override
        public ADictionary applyUnaryOp(UnaryOperator op) {
                MatrixBlock res = _data.unaryOperations(op, new MatrixBlock());
-               if(res.isEmpty())
-                       return null;
-               else
-                       return new MatrixBlockDictionary(res);
+               return MatrixBlockDictionary.create(res);
        }
 
        @Override
@@ -474,7 +468,7 @@ public class MatrixBlockDictionary extends ADictionary {
                }
 
                ret.recomputeNonZeros();
-               return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
        }
 
        @Override
@@ -518,10 +512,7 @@ public class MatrixBlockDictionary extends ADictionary {
 
                ret.recomputeNonZeros();
                ret.examSparsity();
-               if(ret.isEmpty())
-                       return null;
-               else
-                       return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
 
        }
 
@@ -566,10 +557,7 @@ public class MatrixBlockDictionary extends ADictionary {
 
                ret.recomputeNonZeros();
                ret.examSparsity();
-               if(ret.isEmpty())
-                       return null;
-               else
-                       return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
 
        }
 
@@ -618,10 +606,7 @@ public class MatrixBlockDictionary extends ADictionary {
 
                ret.recomputeNonZeros();
                ret.examSparsity();
-               if(ret.isEmpty())
-                       return null;
-               else
-                       return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
        }
 
        @Override
@@ -671,10 +656,7 @@ public class MatrixBlockDictionary extends ADictionary {
 
                ret.recomputeNonZeros();
                ret.examSparsity();
-               if(ret.isEmpty())
-                       return null;
-               else
-                       return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
        }
 
        @Override
@@ -721,10 +703,7 @@ public class MatrixBlockDictionary extends ADictionary {
 
                ret.recomputeNonZeros();
                ret.examSparsity();
-               if(ret.isEmpty())
-                       return null;
-               else
-                       return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
 
        }
 
@@ -732,9 +711,7 @@ public class MatrixBlockDictionary extends ADictionary {
        public MatrixBlockDictionary binOpRight(BinaryOperator op, double[] v, 
int[] colIndexes) {
                final MatrixBlock rowVector = Util.extractValues(v, colIndexes);
                final MatrixBlock ret = _data.binaryOperations(op, rowVector, 
null);
-               if(ret.isEmpty())
-                       return null;
-               return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
        }
 
        @Override
@@ -784,19 +761,14 @@ public class MatrixBlockDictionary extends ADictionary {
 
                ret.recomputeNonZeros();
                ret.examSparsity();
-               if(ret.isEmpty())
-                       return null;
-               else
-                       return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
        }
 
        @Override
        public MatrixBlockDictionary binOpRight(BinaryOperator op, double[] v) {
                final MatrixBlock rowVector = new MatrixBlock(1, v.length, v);
                final MatrixBlock ret = _data.binaryOperations(op, rowVector, 
null);
-               if(ret.isEmpty())
-                       return null;
-               return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
        }
 
        @Override
@@ -843,10 +815,7 @@ public class MatrixBlockDictionary extends ADictionary {
 
                ret.recomputeNonZeros();
                ret.examSparsity();
-               if(ret.isEmpty())
-                       return null;
-               else
-                       return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
 
        }
 
@@ -854,12 +823,12 @@ public class MatrixBlockDictionary extends ADictionary {
        public ADictionary clone() {
                MatrixBlock ret = new MatrixBlock();
                ret.copy(_data);
-               return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
        }
 
        @Override
-       public boolean isLossy() {
-               return false;
+       public DictType getDictType() {
+               return DictType.MatrixBlock;
        }
 
        @Override
@@ -1460,10 +1429,8 @@ public class MatrixBlockDictionary extends ADictionary {
 
        @Override
        public ADictionary sliceOutColumnRange(int idxStart, int idxEnd, int 
previousNumberOfColumns) {
-               final MatrixBlock retBlock = _data.slice(0, _data.getNumRows() - 1, idxStart, idxEnd - 1);
-               if(retBlock.isEmpty())
-                       return null;
-               return new MatrixBlockDictionary(retBlock);
+               final MatrixBlock ret = _data.slice(0, _data.getNumRows() - 1, 
idxStart, idxEnd - 1);
+               return MatrixBlockDictionary.create(ret);
        }
 
        @Override
@@ -1694,9 +1661,7 @@ public class MatrixBlockDictionary extends ADictionary {
                MatrixBlock v = new MatrixBlock(1, tuple.length, tuple);
                BinaryOperator op = new 
BinaryOperator(Minus.getMinusFnObject());
                MatrixBlock ret = _data.binaryOperations(op, v, null);
-               if(ret.isEmpty())
-                       return null;
-               return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
        }
 
        @Override
@@ -1741,7 +1706,7 @@ public class MatrixBlockDictionary extends ADictionary {
                                }
                        }
                        retBlock.setNonZeros(_data.getNonZeros());
-                       return new MatrixBlockDictionary(retBlock);
+                       return MatrixBlockDictionary.create(retBlock);
                }
                else {
                        final double[] _values = _data.getDenseBlockValues();
@@ -1757,7 +1722,7 @@ public class MatrixBlockDictionary extends ADictionary {
                        DenseBlockFP64 db = new DenseBlockFP64(new int[] 
{_data.getNumRows(), _data.getNumColumns()}, scaledValues);
                        MatrixBlock retBlock = new 
MatrixBlock(_data.getNumRows(), _data.getNumColumns(), db);
                        retBlock.setNonZeros(_data.getNonZeros());
-                       return new MatrixBlockDictionary(retBlock);
+                       return MatrixBlockDictionary.create(retBlock);
                }
        }
 
@@ -1770,7 +1735,7 @@ public class MatrixBlockDictionary extends ADictionary {
        public static MatrixBlockDictionary read(DataInput in) throws 
IOException {
                MatrixBlock ret = new MatrixBlock();
                ret.readFields(in);
-               return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
        }
 
        @Override
@@ -1817,21 +1782,17 @@ public class MatrixBlockDictionary extends ADictionary {
                }
 
                DenseBlock dictV = new DenseBlockFP64(new int[] {numVals, 
aggregateColumns.length}, ret);
-               MatrixBlock dictM = new MatrixBlock(numVals, 
aggregateColumns.length, dictV);
-               dictM.recomputeNonZeros();
-               dictM.examSparsity();
-               if(dictM.isEmpty())
-                       return null;
-               return new MatrixBlockDictionary(dictM);
+               MatrixBlock r = new MatrixBlock(numVals, 
aggregateColumns.length, dictV);
+               r.recomputeNonZeros();
+               r.examSparsity();
+               return MatrixBlockDictionary.create(r);
 
        }
 
        @Override
        public ADictionary replace(double pattern, double replace, int nCol) {
                final MatrixBlock ret = _data.replaceOperations(new 
MatrixBlock(), pattern, replace);
-               if(ret.isEmpty())
-                       return null;
-               return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
        }
 
        @Override
@@ -1856,7 +1817,7 @@ public class MatrixBlockDictionary extends ADictionary {
                                        int j = 0;
                                        for(int k = apos; j < nCol && k < alen; 
j++) {
                                                final double v = aix[k] == j ? 
avals[k++] + reference[j] : reference[j];
-                                               retV[off++] = pattern == v ? 
replace - reference[j] : v - reference[j];
+                                               retV[off++] = Math.abs(v - 
pattern) < 0.00001 ? replace - reference[j] : v - reference[j];
                                        }
                                        for(; j < nCol; j++)
                                                retV[off++] = pattern == 
reference[j] ? replace - reference[j] : 0;
@@ -1868,17 +1829,14 @@ public class MatrixBlockDictionary extends ADictionary {
                        for(int i = 0; i < nRow; i++) {
                                for(int j = 0; j < nCol; j++) {
                                        final double v = values[off];
-                                       retV[off++] = pattern == v + 
reference[j] ? replace - reference[j] : v;
+                                       retV[off++] = Math.abs(v + reference[j] - pattern) < 0.00001 ? replace - reference[j] : v;
                                }
                        }
                }
 
                ret.recomputeNonZeros();
                ret.examSparsity();
-               if(ret.isEmpty())
-                       return null;
-               else
-                       return new MatrixBlockDictionary(ret);
+               return MatrixBlockDictionary.create(ret);
 
        }
 
@@ -2008,20 +1966,16 @@ public class MatrixBlockDictionary extends ADictionary {
 
        @Override
        public ADictionary rexpandCols(int max, boolean ignore, boolean cast, 
int nCol) {
-               MatrixBlock ex = LibMatrixReorg.rexpand(_data, new 
MatrixBlock(), max, false, cast, ignore, 1);
-               if(ex.isEmpty())
-                       return null;
-               else
-                       return new MatrixBlockDictionary(ex);
+               if(nCol > 1)
+                       throw new DMLCompressionException("Invalid to rexpand 
the column groups if more than one column");
+               MatrixBlock ret = LibMatrixReorg.rexpand(_data, new 
MatrixBlock(), max, false, cast, ignore, 1);
+               return MatrixBlockDictionary.create(ret);
        }
 
        @Override
        public ADictionary rexpandColsWithReference(int max, boolean ignore, 
boolean cast, int reference) {
                ADictionary a = applyScalarOp(new 
LeftScalarOperator(Plus.getPlusFnObject(), reference));
-               if(a == null)
-                       return null;
-               else
-                       return a.rexpandCols(max, ignore, cast, 1);
+               return a == null ? null : a.rexpandCols(max, ignore, cast, 1);
        }
 
        @Override
@@ -2145,7 +2099,7 @@ public class MatrixBlockDictionary extends ADictionary {
        }
 
        @Override
-       public boolean eq(ADictionary o) {
+       public boolean equals(ADictionary o) {
                if(o instanceof MatrixBlockDictionary)
                        return _data.equals(((MatrixBlockDictionary) o)._data);
                else if(o instanceof Dictionary) {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
index 3a12445dbe..baa0d49405 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
@@ -402,8 +402,8 @@ public class QDictionary extends ADictionary {
        }
 
        @Override
-       public boolean isLossy() {
-               return false;
+       public DictType getDictType() {
+               return DictType.UInt8;
        }
 
        @Override
@@ -597,7 +597,7 @@ public class QDictionary extends ADictionary {
        }
 
        @Override
-       public boolean eq(ADictionary o) {
-                       throw new NotImplementedException();
+       public boolean equals(ADictionary o) {
+               throw new NotImplementedException();
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java
index 3bd171855d..64e45039a2 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java
@@ -313,8 +313,9 @@ public class CLALibBinaryCellOp {
                // apply overlap
                if(smallestSize == Integer.MAX_VALUE) {
                        // if there was no smallest colgroup
-                       ADictionary newDict = new MatrixBlockDictionary(m2);
-                       newColGroups.add(ColGroupConst.create(nCol, newDict));
+                       ADictionary newDict = MatrixBlockDictionary.create(m2);
+                       if(newDict != null)     
+                               newColGroups.add(ColGroupConst.create(nCol, 
newDict));
                }
                else {
                        // apply to the found group
diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupFactoryTest.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupFactoryTest.java
index e38cd2081d..a807cb0343 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupFactoryTest.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupFactoryTest.java
@@ -222,19 +222,25 @@ public class ColGroupFactoryTest {
 
        @Test
        public void testCompressMultipleTimes() {
-               final int offs = Math.min((int) (mbt.getSparsity() * nRow * 
nCol), nRow);
-               final EstimationFactors f = new 
EstimationFactors(Math.min(nRow, offs), nRow, offs, mbt.getSparsity());
-               final List<CompressedSizeInfoColGroup> es = new ArrayList<>();
-               es.add(new CompressedSizeInfoColGroup(cols, f, 312152, ct));
-               es.add(new CompressedSizeInfoColGroup(cols, f, 312152, ct));// 
second time.
-               final CompressedSizeInfo csi = new CompressedSizeInfo(es);
-               CompressionSettings cs = csb.create();
-
-               cs.transposed = true;
-               if(ce != null)
-                       ColGroupFactory.compressColGroups(mbt, csi, cs, ce, 4);
-               else
-                       ColGroupFactory.compressColGroups(mbt, csi, cs, 4);
+               try{
+
+                       final int offs = Math.min((int) (mbt.getSparsity() * 
nRow * nCol), nRow);
+                       final EstimationFactors f = new 
EstimationFactors(Math.min(nRow, offs), nRow, offs, mbt.getSparsity());
+                       final List<CompressedSizeInfoColGroup> es = new 
ArrayList<>();
+                       es.add(new CompressedSizeInfoColGroup(cols, f, 312152, 
ct));
+                       es.add(new CompressedSizeInfoColGroup(cols, f, 312152, 
ct));// second time.
+                       final CompressedSizeInfo csi = new 
CompressedSizeInfo(es);
+                       CompressionSettings cs = csb.create();
+       
+                       cs.transposed = true;
+                       if(ce != null)
+                               ColGroupFactory.compressColGroups(mbt, csi, cs, 
ce, 4);
+                       else
+                               ColGroupFactory.compressColGroups(mbt, csi, cs, 
4);
+               }catch(Exception e){
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
        }
 
        private void compare(List<AColGroup> gt) {
diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java
new file mode 100644
index 0000000000..31f301749a
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress.dictionary;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
+import 
org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.junit.Test;
+
+public class CustomDictionaryTest {
+
+       protected static final Log LOG = 
LogFactory.getLog(CustomDictionaryTest.class.getName());
+
+       @Test
+       public void testContainsValue() {
+               Dictionary d = Dictionary.createNoCheck(new double[] {1, 2, 3});
+               assertTrue(d.containsValue(1));
+               assertTrue(!d.containsValue(-1));
+       }
+
+       @Test
+       public void testContainsValue_nan() {
+               Dictionary d = Dictionary.createNoCheck(new double[] 
{Double.NaN, 2, 3});
+               assertTrue(d.containsValue(Double.NaN));
+       }
+
+       @Test
+       public void testContainsValue_nan_not() {
+               Dictionary d = Dictionary.createNoCheck(new double[] {1, 2, 3});
+               assertTrue(!d.containsValue(Double.NaN));
+       }
+
+       @Test
+       public void testToString() {
+               ADictionary d = Dictionary.create(new double[] {1.0, 2.0, 3.3, 
4.0, 5.0, 6.0});
+               String s = d.toString();
+               assertFalse(s.contains("0"));
+               assertTrue(s.contains("1"));
+               assertTrue(s.contains("2"));
+               assertTrue(s.contains("3.3"));
+               assertTrue(s.contains("4"));
+               assertTrue(s.contains("5"));
+               assertTrue(s.contains("6"));
+               assertTrue(s.contains(","));
+       }
+
+       @Test
+       public void testGetString2() {
+               ADictionary d = Dictionary.create(new double[] {1.0, 2.0, 3.3, 
4.0, 5.0, 6.0});
+               String s = d.getString(2);
+               assertFalse(s.contains("0"));
+               assertTrue(s.contains("1"));
+               assertTrue(s.contains("2"));
+               assertTrue(s.contains("3.3"));
+               assertTrue(s.contains("4"));
+               assertTrue(s.contains("5"));
+               assertTrue(s.contains("6"));
+               assertTrue(s.contains(","));
+       }
+
+       @Test
+       public void testGetString1() {
+               ADictionary d = Dictionary.create(new double[] {1.0, 2.0, 3.3, 
4.0, 5.0, 6.0});
+               String s = d.getString(1);
+               assertFalse(s.contains("0"));
+               assertTrue(s.contains("1"));
+               assertTrue(s.contains("2"));
+               assertTrue(s.contains("3.3"));
+               assertTrue(s.contains("4"));
+               assertTrue(s.contains("5"));
+               assertTrue(s.contains("6"));
+               assertTrue(s.contains(","));
+       }
+
+       @Test
+       public void testGetString3() {
+               ADictionary d = Dictionary.create(new double[] {1.0, 2.0, 3.3, 
4.0, 5.0, 6.0});
+               String s = d.getString(3);
+               assertFalse(s.contains("0"));
+               assertTrue(s.contains("1"));
+               assertTrue(s.contains("2"));
+               assertTrue(s.contains("3.3"));
+               assertTrue(s.contains("4"));
+               assertTrue(s.contains("5"));
+               assertTrue(s.contains("6"));
+               assertTrue(s.contains(","));
+       }
+
+       @Test
+       public void isNullIfEmpty() {
+               ADictionary d = Dictionary.create(new double[] {0, 0, 0, 0});
+               assertNull("This should be null if empty creation", d);
+       }
+
+       @Test
+       public void isNullIfEmptyMatrixBlock() {
+               ADictionary d = MatrixBlockDictionary.create(new 
MatrixBlock(10, 10, 0.0));
+               assertNull("This should be null if empty creation", d);
+       }
+
+       @Test(expected = DMLCompressionException.class)
+       public void createEmpty() {
+               Dictionary.create(new double[] {});
+       }
+
+       @Test(expected = DMLCompressionException.class)
+       public void createNull() {
+               Dictionary.create(null);
+       }
+
+       @Test(expected = DMLCompressionException.class)
+       public void createNullMatrixBlock() {
+               MatrixBlockDictionary.create(null);
+       }
+
+       @Test(expected = DMLCompressionException.class)
+       public void createZeroRowAndColMatrixBlock() {
+               MatrixBlockDictionary.create(new MatrixBlock(0, 0, 10.0));
+       }
+
+       @Test(expected = DMLCompressionException.class)
+       public void createZeroColMatrixBlock() {
+               MatrixBlockDictionary.create(new MatrixBlock(10, 0, 10.0));
+       }
+
+       @Test(expected = DMLCompressionException.class)
+       public void createZeroRowMatrixBlock() {
+               MatrixBlockDictionary.create(new MatrixBlock(0, 10, 10.0));
+       }
+}
diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DeltaDictionaryTest.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DeltaDictionaryTest.java
index e94f0b9db0..b61d02737a 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DeltaDictionaryTest.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DeltaDictionaryTest.java
@@ -18,7 +18,9 @@
  */
 package org.apache.sysds.test.component.compress.dictionary;
 
+import org.apache.commons.lang.NotImplementedException;
 import org.apache.sysds.runtime.compress.colgroup.dictionary.DeltaDictionary;
+import org.apache.sysds.runtime.functionobjects.And;
 import org.apache.sysds.runtime.functionobjects.Divide;
 import org.apache.sysds.runtime.functionobjects.Minus;
 import org.apache.sysds.runtime.functionobjects.Multiply;
@@ -120,4 +122,12 @@ public class DeltaDictionaryTest {
                double[] expected = new double[] {3, 4, 3, 4};
                Assert.assertArrayEquals(expected, d.getValues(), 0.01);
        }
+
+       @Test(expected = NotImplementedException.class)
+       public void testNotImplemented() {
+               double scalar = 2;
+               DeltaDictionary d = new DeltaDictionary(new double[] {1, 2, 3, 
4}, 2);
+               ScalarOperator sop = new 
LeftScalarOperator(And.getAndFnObject(), scalar, 1);
+               d = d.applyScalarOp(sop);
+       }
 }
diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTest.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTest.java
deleted file mode 100644
index 8ee6a9fa92..0000000000
--- a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTest.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.test.component.compress.dictionary;
-
-import static org.junit.Assert.assertTrue;
-
-import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
-import org.junit.Test;
-
-public class DictionaryTest {
-
-       @Test
-       public void testContainsValue() {
-               Dictionary d = Dictionary.createNoCheck(new double[] {1, 2, 3});
-               assertTrue(d.containsValue(1));
-               assertTrue(!d.containsValue(-1));
-       }
-
-       @Test
-       public void testContainsValue_nan() {
-               Dictionary d = Dictionary.createNoCheck(new double[] 
{Double.NaN, 2, 3});
-               assertTrue(d.containsValue(Double.NaN));
-       }
-
-       @Test
-       public void testContainsValue_nan_not() {
-               Dictionary d = Dictionary.createNoCheck(new double[] {1, 2, 3});
-               assertTrue(!d.containsValue(Double.NaN));
-       }
-}
diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java
new file mode 100644
index 0000000000..e1827d03bf
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java
@@ -0,0 +1,456 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress.dictionary;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.fail;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
+import 
org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
+import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+import scala.util.Random;
+
+/**
+ * Parameterized tests that run the same operation on two {@link ADictionary} instances built from the same values
+ * (a {@link Dictionary} and a {@link MatrixBlockDictionary}) and verify that both produce equivalent results.
+ */
+@RunWith(value = Parameterized.class)
+public class DictionaryTests {
+
+       protected static final Log LOG = LogFactory.getLog(DictionaryTests.class.getName());
+
+       private final int nRow;
+       private final int nCol;
+       private final ADictionary a;
+       private final ADictionary b;
+
+       /**
+        * Construct one parameterized test instance.
+        * 
+        * @param a    The first dictionary
+        * @param b    The second dictionary, built from the same values as a
+        * @param nRow The number of rows (dictionary entries) passed to the operations under test
+        * @param nCol The number of columns passed to the operations under test
+        */
+       public DictionaryTests(ADictionary a, ADictionary b, int nRow, int nCol) {
+               this.nRow = nRow;
+               this.nCol = nCol;
+               this.a = a;
+               this.b = b;
+       }
+
+       /**
+        * Build the parameter sets: a handful of small hand-written value arrays and one generated matrix.
+        * 
+        * @return The list of constructor arguments {a, b, nRow, nCol}
+        */
+       @Parameters
+       public static Collection<Object[]> data() {
+               List<Object[]> tests = new ArrayList<>();
+
+               try {
+                       addAll(tests, new double[] {1, 1, 1, 1, 1}, 1);
+                       addAll(tests, new double[] {-3, 0.0, 132, 43, 1}, 1);
+                       addAll(tests, new double[] {1, 2, 3, 4, 5}, 1);
+                       addAll(tests, new double[] {1, 2, 3, 4, 5, 6}, 2);
+                       addAll(tests, new double[] {1, 2.2, 3.3, 4.4, 5.5, 6.6}, 3);
+
+                       create(tests, 30, 300, 0.2);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail("failed constructing tests");
+               }
+
+               return tests;
+       }
+
+       /**
+        * Generate a test matrix block (seed 1342, range arguments -3 and 3) and register two test cases for it: one
+        * where the MatrixBlockDictionary is created from the generated block, and one where it is created from a copy
+        * converted with sparseToDense(). Both are paired with a Dictionary over the dense values.
+        * 
+        * @param tests    The list to add the test cases to
+        * @param rows     The number of rows to generate
+        * @param cols     The number of columns to generate
+        * @param sparsity The sparsity forwarded to the matrix generator
+        */
+       private static void create(List<Object[]> tests, int rows, int cols, double sparsity) {
+               // Forward the sparsity argument instead of a hard-coded constant.
+               MatrixBlock mb = TestUtils.generateTestMatrixBlock(rows, cols, -3, 3, sparsity, 1342);
+               mb.recomputeNonZeros();
+               MatrixBlock dense = new MatrixBlock();
+
+               dense.copy(mb);
+               dense.sparseToDense();
+               double[] values = dense.getDenseBlockValues();
+
+               tests.add(new Object[] {//
+                       Dictionary.create(values), //
+                       MatrixBlockDictionary.create(mb), //
+                       rows, cols});
+
+               tests.add(new Object[] {//
+                       Dictionary.create(values), //
+                       MatrixBlockDictionary.create(dense), //
+                       rows, cols});
+       }
+
+       /**
+        * Register one test case pairing a Dictionary and a MatrixBlockDictionary created from the same values.
+        * 
+        * @param tests The list to add the test case to
+        * @param vals  The dictionary values
+        * @param cols  The number of columns; the number of rows is vals.length / cols
+        */
+       private static void addAll(List<Object[]> tests, double[] vals, int cols) {
+               tests.add(new Object[] {//
+                       Dictionary.create(vals), //
+                       MatrixBlockDictionary.createDictionary(vals, cols, true), //
+                       vals.length / cols, cols});
+       }
+
+       @Test
+       public void sum() {
+               int[] counts = getCounts(nRow, 1324);
+               double as = a.sum(counts, nCol);
+               double bs = b.sum(counts, nCol);
+               assertEquals(as, bs, 0.0000001);
+       }
+
+       @Test
+       public void getValues() {
+               try {
+                       double[] av = a.getValues();
+                       double[] bv = b.getValues();
+                       TestUtils.compareMatricesBitAvgDistance(av, bv, 10, 10, "Not Equivalent values from getValues");
+               }
+               catch(DMLCompressionException e) {
+                       // okay since some cases are safeguarded by not allowing extraction of dense values.
+               }
+       }
+
+       @Test
+       public void getDictType() {
+               // The two dictionaries of a test case are expected to be different implementations.
+               assertNotEquals(a.getDictType(), b.getDictType());
+       }
+
+       @Test
+       public void getSparsity() {
+               assertEquals(a.getSparsity(), b.getSparsity(), 0.001);
+       }
+
+       @Test
+       public void productZero() {
+               product(0.0);
+       }
+
+       @Test
+       public void productOne() {
+               product(1.0);
+       }
+
+       @Test
+       public void productMore() {
+               product(30.0);
+       }
+
+       /**
+        * Call product on both dictionaries with the same counts and compare the results.
+        * 
+        * @param retV The initial value placed in the single-cell result array
+        */
+       public void product(double retV) {
+               // Shared
+               final int[] counts = getCounts(nRow, 1324);
+
+               // A
+               final double[] aRet = new double[] {retV};
+               a.product(aRet, counts, nCol);
+
+               // B
+               final double[] bRet = new double[] {retV};
+               b.product(bRet, counts, nCol);
+
+               TestUtils.compareMatricesBitAvgDistance(//
+                       aRet, bRet, 10, 10, "Not Equivalent values from product");
+       }
+
+       @Test
+       public void productWithReferenceZero() {
+               final double[] reference = getReference(nCol, 132, -3, 3);
+               productWithReference(0.0, reference);
+       }
+
+       @Test
+       public void productWithReferenceOne() {
+               final double[] reference = getReference(nCol, 132, -3, 3);
+               productWithReference(1.0, reference);
+       }
+
+       @Test
+       public void productWithDoctoredReference() {
+               final double[] reference = getReference(nCol, 132, 0.0, 0.0);
+               productWithReference(1.0, reference);
+       }
+
+       @Test
+       public void productWithDoctoredReference2() {
+               final double[] reference = getReference(nCol, 132, 1.0, 1.0);
+               productWithReference(1.0, reference);
+       }
+
+       /**
+        * Call productWithReference on both dictionaries with the same counts and reference and compare the results.
+        * 
+        * @param retV      The initial value placed in the single-cell result array
+        * @param reference The reference values, one per column
+        */
+       public void productWithReference(double retV, double[] reference) {
+               // Shared
+               final int[] counts = getCounts(nRow, 1324);
+
+               // A
+               final double[] aRet = new double[] {retV};
+               a.productWithReference(aRet, counts, reference, nCol);
+
+               // B
+               final double[] bRet = new double[] {retV};
+               b.productWithReference(bRet, counts, reference, nCol);
+
+               TestUtils.compareMatricesBitAvgDistance(//
+                       aRet, bRet, 10, 10, "Not Equivalent values from productWithReference");
+       }
+
+       @Test
+       public void productWithdefZero() {
+               final double[] def = getReference(nCol, 132, -3, 3);
+               productWithDefault(0.0, def);
+       }
+
+       @Test
+       public void productWithdefOne() {
+               final double[] def = getReference(nCol, 132, -3, 3);
+               productWithDefault(1.0, def);
+       }
+
+       @Test
+       public void productWithDoctoreddef() {
+               final double[] def = getReference(nCol, 132, 0.0, 0.0);
+               productWithDefault(1.0, def);
+       }
+
+       @Test
+       public void productWithDoctoreddef2() {
+               final double[] def = getReference(nCol, 132, 1.0, 1.0);
+               productWithDefault(1.0, def);
+       }
+
+       @Test
+       public void replace() {
+               final Random rand = new Random(13);
+               final int r = rand.nextInt(nRow);
+               final int c = rand.nextInt(nCol);
+               // Replace a value known to be present at (r, c) and verify that cell in both results.
+               final double v = a.getValue(r, c, nCol);
+               final double rep = rand.nextDouble();
+               final ADictionary aRep = a.replace(v, rep, nCol);
+               final ADictionary bRep = b.replace(v, rep, nCol);
+               assertEquals(rep, aRep.getValue(r, c, nCol), 0.0000001);
+               assertEquals(rep, bRep.getValue(r, c, nCol), 0.0000001);
+       }
+
+       @Test
+       public void replaceWitReference() {
+               final Random rand = new Random(444);
+               final int r = rand.nextInt(nRow);
+               final int c = rand.nextInt(nCol);
+               final double[] reference = getReference(nCol, 44, 1.0, 1.0);
+               final double before = a.getValue(r, c, nCol);
+               // The reference is all ones, so the pattern searched for is the stored value plus one.
+               final double v = before + 1.0;
+               final double rep = rand.nextDouble() * 500;
+               final ADictionary aRep = a.replaceWithReference(v, rep, reference);
+               final ADictionary bRep = b.replaceWithReference(v, rep, reference);
+               assertEquals(aRep.getValue(r, c, nCol), bRep.getValue(r, c, nCol), 0.0000001);
+               assertNotEquals(before, aRep.getValue(r, c, nCol), 0.00001);
+       }
+
+       @Test
+       public void rexpandCols() {
+               // Only exercised for single column dictionaries.
+               if(nCol == 1) {
+                       int max = (int) a.aggregate(0, Builtin.getBuiltinFnObject(BuiltinCode.MAX));
+                       final ADictionary aR = a.rexpandCols(max + 1, true, false, nCol);
+                       final ADictionary bR = b.rexpandCols(max + 1, true, false, nCol);
+                       compare(aR, bR, nRow, max + 1);
+               }
+       }
+
+       @Test(expected = DMLCompressionException.class)
+       public void rexpandColsException() {
+               if(nCol > 1) {
+                       int max = (int) a.aggregate(0, Builtin.getBuiltinFnObject(BuiltinCode.MAX));
+                       b.rexpandCols(max + 1, true, false, nCol);
+               }
+               else
+                       // Single column cases throw the expected exception themselves so the test passes.
+                       throw new DMLCompressionException("to test pase");
+       }
+
+       @Test(expected = DMLCompressionException.class)
+       public void rexpandColsExceptionOtherOrder() {
+               if(nCol > 1) {
+                       int max = (int) a.aggregate(0, Builtin.getBuiltinFnObject(BuiltinCode.MAX));
+                       a.rexpandCols(max + 1, true, false, nCol);
+               }
+               else
+                       // Single column cases throw the expected exception themselves so the test passes.
+                       throw new DMLCompressionException("to test pase");
+       }
+
+       @Test
+       public void rexpandColsWithReference1() {
+               rexpandColsWithReference(1);
+       }
+
+       @Test
+       public void rexpandColsWithReference33() {
+               rexpandColsWithReference(33);
+       }
+
+       @Test
+       public void rexpandColsWithReference_neg23() {
+               rexpandColsWithReference(-23);
+       }
+
+       @Test
+       public void rexpandColsWithReference_neg1() {
+               rexpandColsWithReference(-1);
+       }
+
+       /**
+        * Call rexpandColsWithReference on both dictionaries and compare the results. Only exercised for single column
+        * dictionaries.
+        * 
+        * @param reference The reference value passed to both dictionaries
+        */
+       public void rexpandColsWithReference(int reference) {
+               if(nCol == 1) {
+                       int max = (int) a.aggregate(0, Builtin.getBuiltinFnObject(BuiltinCode.MAX));
+
+                       final ADictionary aR = a.rexpandColsWithReference(max + 1, true, false, reference);
+                       final ADictionary bR = b.rexpandColsWithReference(max + 1, true, false, reference);
+                       // Fail with a clear message, instead of a NullPointerException, if only one result is null.
+                       assertEquals("Only one dictionary returned null from rexpandColsWithReference", aR == null, bR == null);
+                       if(aR == null)
+                               return; // valid, both are null
+                       compare(aR, bR, nRow, max + 1);
+               }
+       }
+
+       @Test
+       public void sumSq() {
+               int[] counts = getCounts(nRow, 2323);
+               double as = a.sumSq(counts, nCol);
+               double bs = b.sumSq(counts, nCol);
+               assertEquals(as, bs, 0.0001);
+       }
+
+       @Test
+       public void sumSqWithReference() {
+               int[] counts = getCounts(nRow, 2323);
+               double[] reference = getReference(nCol, 323, -10, 23);
+               double as = a.sumSqWithReference(counts, reference);
+               double bs = b.sumSqWithReference(counts, reference);
+               assertEquals(as, bs, 0.0001);
+       }
+
+       @Test
+       public void sliceOutColumnRange() {
+               Random r = new Random(2323);
+               // Pick a non-empty column range [s, e) with 0 <= s < e <= nCol.
+               int s = r.nextInt(nCol);
+               int e = r.nextInt(nCol - s) + s + 1;
+               ADictionary ad = a.sliceOutColumnRange(s, e, nCol);
+               ADictionary bd = b.sliceOutColumnRange(s, e, nCol);
+               compare(ad, bd, nRow, e - s);
+       }
+
+       @Test
+       public void contains1() {
+               containsValue(1);
+       }
+
+       @Test
+       public void contains2() {
+               containsValue(2);
+       }
+
+       @Test
+       public void contains100() {
+               containsValue(100);
+       }
+
+       @Test
+       public void contains0() {
+               containsValue(0);
+       }
+
+       @Test
+       public void contains1p1() {
+               containsValue(1.1);
+       }
+
+       /**
+        * Verify both dictionaries agree on whether they contain the given value.
+        * 
+        * @param value The value to look for
+        */
+       public void containsValue(double value) {
+               assertEquals(a.containsValue(value), b.containsValue(value));
+       }
+
+       @Test
+       public void contains1WithReference() {
+               containsValueWithReference(1, getReference(nCol, 3241, 1.0, 1.0));
+       }
+
+       @Test
+       public void contains1WithReference2() {
+               containsValueWithReference(1, getReference(nCol, 3241, 1.0, 1.32));
+       }
+
+       @Test
+       public void contains32WithReference2() {
+               containsValueWithReference(32, getReference(nCol, 3241, -1.0, 1.32));
+       }
+
+       @Test
+       public void contains0WithReference1() {
+               containsValueWithReference(0, getReference(nCol, 3241, 1.0, 1.0));
+       }
+
+       @Test
+       public void contains1WithReferenceMinus1() {
+               containsValueWithReference(1.0, getReference(nCol, 3241, -1.0, -1.0));
+       }
+
+       /**
+        * Verify both dictionaries agree on whether they contain the given value when using the given reference.
+        * 
+        * @param value     The value to look for
+        * @param reference The reference values, one per column
+        */
+       public void containsValueWithReference(double value, double[] reference) {
+               assertEquals(//
+                       a.containsValueWithReference(value, reference), //
+                       b.containsValueWithReference(value, reference));
+       }
+
+       /**
+        * Compare two dictionaries cell by cell through getValue with a tolerance of 0.0001.
+        * 
+        * @param a    The first dictionary
+        * @param b    The second dictionary
+        * @param nRow The number of rows to compare
+        * @param nCol The number of columns to compare
+        */
+       private static void compare(ADictionary a, ADictionary b, int nRow, int nCol) {
+               for(int i = 0; i < nRow; i++)
+                       for(int j = 0; j < nCol; j++)
+                               assertEquals(a.getValue(i, j, nCol), b.getValue(i, j, nCol), 0.0001);
+       }
+
+       /**
+        * Call productWithDefault on both dictionaries with the same counts and default tuple and compare the results.
+        * 
+        * @param retV The initial value placed in the single-cell result array
+        * @param def  The default values, one per column
+        */
+       public void productWithDefault(double retV, double[] def) {
+               // Shared
+               final int[] counts = getCounts(nRow, 1324);
+
+               // A
+               final double[] aRet = new double[] {retV};
+               a.productWithDefault(aRet, counts, def, nCol);
+
+               // B
+               final double[] bRet = new double[] {retV};
+               b.productWithDefault(bRet, counts, def, nCol);
+
+               TestUtils.compareMatricesBitAvgDistance(//
+                       aRet, bRet, 10, 10, "Not Equivalent values from productWithDefault");
+       }
+
+       /**
+        * Generate a deterministic array of counts, each drawn with nextInt(100) from a generator seeded with seed.
+        * 
+        * @param nRows The number of counts to generate
+        * @param seed  The seed of the random generator
+        * @return The generated counts
+        */
+       private static int[] getCounts(int nRows, int seed) {
+               int[] counts = new int[nRows];
+               Random r = new Random(seed);
+               for(int i = 0; i < nRows; i++)
+                       counts[i] = r.nextInt(100);
+               return counts;
+       }
+
+       /**
+        * Generate a deterministic reference array. If min and max are equal every entry is that constant, otherwise
+        * each entry is a nextDouble() draw scaled by (max - min) and shifted by min.
+        * 
+        * @param nCol The number of reference values to generate
+        * @param seed The seed of the random generator
+        * @param min  The lower bound of the generated values
+        * @param max  The upper bound of the generated values
+        * @return The generated reference values
+        */
+       private static double[] getReference(int nCol, int seed, double min, double max) {
+               double[] reference = new double[nCol];
+               Random r = new Random(seed);
+               double diff = max - min;
+               if(diff == 0)
+                       for(int i = 0; i < nCol; i++)
+                               reference[i] = max;
+               else
+                       for(int i = 0; i < nCol; i++)
+                               // Shift by +min so the value lands between min and max.
+                               reference[i] = r.nextDouble() * diff + min;
+               return reference;
+       }
+}

Reply via email to