This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 96fd5da49e [SYSTEMDS-3808] Dictionary Compressed Combine
96fd5da49e is described below
commit 96fd5da49e4f936c378f11cc275c34ff5f8a84d1
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Sun Dec 29 20:51:40 2024 +0100
[SYSTEMDS-3808] Dictionary Compressed Combine
This commit speeds up the combining of dictionaries via custom hashmaps.
Closes #2166
Signed-off-by: Sebastian Baunsgaard <[email protected]>
---
.../colgroup/dictionary/DictionaryFactory.java | 501 ++++++++++++++++-----
.../compress/estim/encoding/ConstEncoding.java | 5 +-
.../compress/estim/encoding/DenseEncoding.java | 229 ++++++++--
.../compress/estim/encoding/EmptyEncoding.java | 5 +-
.../compress/estim/encoding/EncodingFactory.java | 12 +-
.../runtime/compress/estim/encoding/IEncode.java | 9 +-
.../compress/estim/encoding/SparseEncoding.java | 16 +-
.../runtime/compress/lib/CLALibCombineGroups.java | 22 +-
.../runtime/compress/utils/HashMapLongInt.java | 7 +-
.../compress/combine/CombineEncodings.java | 24 +-
.../compress/combine/CombineEncodingsUnique.java | 7 +-
.../component/compress/dictionary/CombineTest.java | 402 +++++++++++------
12 files changed, 888 insertions(+), 351 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
index f7c21b6368..9eb89e489e 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
@@ -21,11 +21,12 @@ package
org.apache.sysds.runtime.compress.colgroup.dictionary;
import java.io.DataInput;
import java.io.IOException;
-import java.util.Map;
+import java.util.List;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.bitmap.ABitmap;
import org.apache.sysds.runtime.compress.bitmap.Bitmap;
import org.apache.sysds.runtime.compress.bitmap.MultiColBitmap;
@@ -34,14 +35,18 @@ import
org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.IContainADictionary;
import org.apache.sysds.runtime.compress.colgroup.IContainDefaultTuple;
+import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.lib.CLALibCombineGroups;
import org.apache.sysds.runtime.compress.utils.ACount;
import org.apache.sysds.runtime.compress.utils.DblArray;
import org.apache.sysds.runtime.compress.utils.DblArrayCountHashMap;
import org.apache.sysds.runtime.compress.utils.DoubleCountHashMap;
+import org.apache.sysds.runtime.compress.utils.HashMapLongInt;
+import org.apache.sysds.runtime.compress.utils.HashMapLongInt.KV;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseRowVector;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.Pair;
public interface DictionaryFactory {
static final Log LOG =
LogFactory.getLog(DictionaryFactory.class.getName());
@@ -235,8 +240,7 @@ public interface DictionaryFactory {
return combineDictionaries(a, b, null);
}
- public static IDictionary combineDictionaries(AColGroupCompressed a,
AColGroupCompressed b,
- Map<Integer, Integer> filter) {
+ public static IDictionary combineDictionaries(AColGroupCompressed a,
AColGroupCompressed b, HashMapLongInt filter) {
if(a instanceof ColGroupEmpty && b instanceof ColGroupEmpty)
return null; // null return is handled elsewhere.
@@ -248,34 +252,37 @@ public interface DictionaryFactory {
if(ae && be) {
- IDictionary ad = ((IContainADictionary)
a).getDictionary();
- IDictionary bd = ((IContainADictionary)
b).getDictionary();
+ final IDictionary ad = ((IContainADictionary)
a).getDictionary();
+ final IDictionary bd = ((IContainADictionary)
b).getDictionary();
if(ac.isConst()) {
if(bc.isConst()) {
return
Dictionary.create(CLALibCombineGroups.constructDefaultTuple(a, b));
}
else if(bc.isDense()) {
final double[] at =
((IContainDefaultTuple) a).getDefaultTuple();
- return combineConstSparseSparseRet(at,
bd, b.getNumCols(), filter);
+ Pair<int[], int[]> r =
IColIndex.reorderingIndexes(a.getColIndices(), b.getColIndices());
+ return combineConstLeft(at, bd,
b.getNumCols(), r.getKey(), r.getValue(), filter);
}
}
else if(ac.isDense()) {
if(bc.isConst()) {
+ final Pair<int[], int[]> r =
IColIndex.reorderingIndexes(a.getColIndices(), b.getColIndices());
final double[] bt =
((IContainDefaultTuple) b).getDefaultTuple();
- return combineSparseConstSparseRet(ad,
a.getNumCols(), bt, filter);
+ return combineSparseConstSparseRet(ad,
a.getNumCols(), bt, r.getKey(), r.getValue(), filter);
+ }
+ else if(bc.isDense()) {
+ return combineFullDictionaries(ad,
a.getColIndices(), bd, b.getColIndices(), filter);
}
- else if(bc.isDense())
- return combineFullDictionaries(ad,
a.getNumCols(), bd, b.getNumCols(), filter);
else if(bc.isSDC()) {
double[] tuple =
((IContainDefaultTuple) b).getDefaultTuple();
- return combineSDCRight(ad,
a.getNumCols(), bd, tuple, filter);
+ return combineSDCRight(ad,
a.getColIndices(), bd, tuple, b.getColIndices(), filter);
}
}
else if(ac.isSDC()) {
if(bc.isSDC()) {
final double[] at =
((IContainDefaultTuple) a).getDefaultTuple();
final double[] bt =
((IContainDefaultTuple) b).getDefaultTuple();
- return combineSDC(ad, at, bd, bt,
filter);
+ return combineSDCFilter(ad, at,
a.getColIndices(), bd, bt, b.getColIndices(), filter);
}
}
}
@@ -291,34 +298,83 @@ public interface DictionaryFactory {
* @return The combined dictionary
*/
public static IDictionary combineDictionariesSparse(AColGroupCompressed
a, AColGroupCompressed b) {
+ return combineDictionariesSparse(a, b, null);
+ }
+
+ /**
+ * Combine the dictionaries assuming a sparse combination where each
dictionary can be a SDC containing a default
+ * element that have to be introduced into the combined dictionary.
+ *
+ * @param a A Dictionary can be SDC or const
+ * @param b A Dictionary can be Const or SDC
+ * @param filter A filter to remove elements in the combined dictionary
+ * @return The combined dictionary
+ */
+ public static IDictionary combineDictionariesSparse(AColGroupCompressed
a, AColGroupCompressed b,
+ HashMapLongInt filter) {
CompressionType ac = a.getCompType();
CompressionType bc = b.getCompType();
+ if(filter != null)
+ throw new NotImplementedException("Not supported filter
for sparse join yet!");
+
if(ac.isSDC()) {
- IDictionary ad = ((IContainADictionary)
a).getDictionary();
+ final IDictionary ad = ((IContainADictionary)
a).getDictionary();
if(bc.isConst()) {
+ final Pair<int[], int[]> r =
IColIndex.reorderingIndexes(a.getColIndices(), b.getColIndices());
double[] bt = ((IContainDefaultTuple)
b).getDefaultTuple();
- return combineSparseConstSparseRet(ad,
a.getNumCols(), bt);
+ return combineSparseConstSparseRet(ad,
a.getNumCols(), bt, r.getKey(), r.getValue());
}
else if(bc.isSDC()) {
- IDictionary bd = ((IContainADictionary)
b).getDictionary();
+ final IDictionary bd = ((IContainADictionary)
b).getDictionary();
if(a.sameIndexStructure(b)) {
- return ad.cbind(bd, b.getNumCols());
+ // in order or other order..
+ if(IColIndex.inOrder(a.getColIndices(),
b.getColIndices()))
+ return ad.cbind(bd,
b.getNumCols());
+ else
if(IColIndex.inOrder(b.getColIndices(), a.getColIndices()))
+ return bd.cbind(ad,
b.getNumCols());
+ else {
+ final Pair<int[], int[]> r =
IColIndex.reorderingIndexes(a.getColIndices(), b.getColIndices());
+ return cbindReorder(ad, bd,
r.getKey(), r.getValue());
+ }
}
- // real combine extract default and combine
like dense but with default before.
}
}
else if(ac.isConst()) {
- double[] at = ((IContainDefaultTuple)
a).getDefaultTuple();
+ final double[] at = ((IContainDefaultTuple)
a).getDefaultTuple();
if(bc.isSDC()) {
- IDictionary bd = ((IContainADictionary)
b).getDictionary();
- return combineConstSparseSparseRet(at, bd,
b.getNumCols());
+ final IDictionary bd = ((IContainADictionary)
b).getDictionary();
+ final Pair<int[], int[]> r =
IColIndex.reorderingIndexes(a.getColIndices(), b.getColIndices());
+ return combineConstLeftAll(at, bd,
b.getNumCols(), r.getKey(), r.getValue());
}
}
throw new NotImplementedException("Not supporting combining
dense: " + a + " " + b);
}
+ private static IDictionary cbindReorder(IDictionary a, IDictionary b,
int[] ai, int[] bi) {
+ final int nca = ai.length;
+ final int ncb = bi.length;
+ final int ra = a.getNumberOfValues(nca);
+ final int rb = b.getNumberOfValues(ncb);
+ final MatrixBlock ma = a.getMBDict(nca).getMatrixBlock();
+ final MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock();
+ if(ra != rb)
+ throw new DMLCompressionException("Invalid cbind
reorder, different sizes of dictionaries");
+ final MatrixBlock out = new MatrixBlock(ra, nca + ncb, false);
+
+ for(int r = 0; r < ra; r++) {// each row
+ //
+ for(int c = 0; c < nca; c++)
+ out.set(r, ai[c], ma.get(r, c));
+
+ for(int c = 0; c < ncb; c++)
+ out.set(r, bi[c], mb.get(r, c));
+ }
+
+ return new MatrixBlockDictionary(out);
+ }
+
/**
* Combine the dictionaries as if the dictionaries contain the full
spectrum of the combined data.
*
@@ -332,6 +388,13 @@ public interface DictionaryFactory {
return combineFullDictionaries(a, nca, b, ncb, null);
}
+ public static IDictionary combineFullDictionaries(IDictionary a,
IColIndex ai, IDictionary b, IColIndex bi,
+ HashMapLongInt filter) {
+ final int nca = ai.size();
+ final int ncb = bi.size();
+ return combineFullDictionaries(a, ai, nca, b, bi, ncb, filter);
+ }
+
/**
* Combine the dictionaries as if the dictionaries only contain the
values in the specified filter.
*
@@ -344,72 +407,116 @@ public interface DictionaryFactory {
* @return A combined dictionary
*/
public static IDictionary combineFullDictionaries(IDictionary a, int
nca, IDictionary b, int ncb,
- Map<Integer, Integer> filter) {
+ HashMapLongInt filter) {
+ return combineFullDictionaries(a, null, nca, b, null, ncb,
filter);
+ }
+
+ public static IDictionary combineFullDictionaries(IDictionary a,
IColIndex ai, int nca, IDictionary b, IColIndex bi,
+ int ncb, HashMapLongInt filter) {
+
final int ra = a.getNumberOfValues(nca);
final int rb = b.getNumberOfValues(ncb);
final MatrixBlock ma = a.getMBDict(nca).getMatrixBlock();
final MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock();
+ final int filterSize = filter != null ? filter.size() : ra * rb;
+ if(filterSize == 0)
+ return null;
+ final MatrixBlock out = new MatrixBlock(filterSize, nca + ncb,
false);
+ out.allocateBlock();
- if(ra == 1 && rb == 1) {
+ if(ai != null && bi != null && !IColIndex.inOrder(ai, bi)) {
- if(filter == null || filter.containsKey(0))
- return new MatrixBlockDictionary(ma.append(mb));
+ Pair<int[], int[]> reordering =
IColIndex.reorderingIndexes(ai, bi);
+ if(filter != null)
+ // throw new NotImplementedException();
+ combineFullDictionariesOOOFilter(out, filter,
ra, rb, nca, ncb, reordering.getKey(), reordering.getValue(),
+ ma, mb);
else
- return null;
- }
+ combineFullDictionariesOOONoFilter(out, ra, rb,
nca, ncb, reordering.getKey(), reordering.getValue(), ma,
+ mb);
- MatrixBlock out = new MatrixBlock(filter != null ?
filter.size() : ra * rb, nca + ncb, false);
+ }
+ else {
+ if(filter != null)
+ combineFullDictionariesFilter(out, filter, ra,
rb, nca, ncb, ma, mb);
+ else
+ combineFullDictionariesNoFilter(out, ra, rb,
nca, ncb, ma, mb);
+ }
- out.allocateBlock();
-
- if(filter != null)
- combineFullWithFilter(nca, ncb, filter, ra, ma, mb,
out);
- else
- combineFullWithoutFilter(nca, ncb, ra, ma, mb, out);
+ out.examSparsity(true);
return new MatrixBlockDictionary(out);
}
- private static void combineFullWithoutFilter(int nca, int ncb, final
int ra, MatrixBlock ma, MatrixBlock mb,
- MatrixBlock out) {
- for(int r = 0; r < out.getNumRows(); r++) {
+ private static void combineFullDictionariesFilter(MatrixBlock out,
HashMapLongInt filter, int ra, int rb, int nca,
+ int ncb, MatrixBlock ma, MatrixBlock mb) {
+
+ for(KV k : filter) {
+ final int r = (int) (k.k);
+ final int o = k.v;
int ia = r % ra;
int ib = r / ra;
for(int c = 0; c < nca; c++)
- out.set(r, c, ma.get(ia, c));
+ out.set(o, c, ma.get(ia, c));
for(int c = 0; c < ncb; c++)
- out.set(r, c + nca, mb.get(ib, c));
-
+ out.set(o, c + nca, mb.get(ib, c));
}
}
- private static void combineFullWithFilter(int nca, int ncb,
Map<Integer, Integer> filter, final int ra,
- MatrixBlock ma, MatrixBlock mb, MatrixBlock out) {
- for(int r : filter.keySet()) {
- int o = filter.get(r);
+ private static void combineFullDictionariesOOOFilter(MatrixBlock out,
HashMapLongInt filter, int ra, int rb, int nca,
+ int ncb, int[] ai, int[] bi, MatrixBlock ma, MatrixBlock mb) {
+ for(KV k : filter) {
+ final int r = (int) (k.k);
+ final int o = k.v;
int ia = r % ra;
int ib = r / ra;
for(int c = 0; c < nca; c++)
- out.set(o, c, ma.get(ia, c));
+ out.set(o, ai[c], ma.get(ia, c));
+ for(int c = 0; c < ncb; c++)
+ out.set(o, bi[c], mb.get(ib, c));
+ }
+ }
+ private static void combineFullDictionariesOOONoFilter(MatrixBlock out,
int ra, int rb, int nca, int ncb, int[] ai,
+ int[] bi, MatrixBlock ma, MatrixBlock mb) {
+ for(int r = 0; r < out.getNumRows(); r++) {
+ int ia = r % ra;
+ int ib = r / ra;
+ for(int c = 0; c < nca; c++)
+ out.set(r, ai[c], ma.get(ia, c));
for(int c = 0; c < ncb; c++)
- out.set(o, c + nca, mb.get(ib, c));
+ out.set(r, bi[c], mb.get(ib, c));
+ }
+ }
+ private static void combineFullDictionariesNoFilter(MatrixBlock out,
int ra, int rb, int nca, int ncb,
+ MatrixBlock ma, MatrixBlock mb) {
+ for(int r = 0; r < out.getNumRows(); r++) {
+ int ia = r % ra;
+ int ib = r / ra;
+ for(int c = 0; c < nca; c++)
+ out.set(r, c, ma.get(ia, c));
+ for(int c = 0; c < ncb; c++)
+ out.set(r, c + nca, mb.get(ib, c));
}
}
- private static IDictionary combineSDCRight(IDictionary a, int nca,
IDictionary b, double[] tub) {
+ public static IDictionary combineSDCRightNoFilter(IDictionary a, int
nca, IDictionary b, double[] tub) {
+ return combineSDCRightNoFilter(a, null, nca, b, tub, null);
+ }
+ public static IDictionary combineSDCRightNoFilter(IDictionary a,
IColIndex ai, int nca, IDictionary b, double[] tub,
+ IColIndex bi) {
+ if(ai != null || bi != null)
+ throw new NotImplementedException();
final int ncb = tub.length;
final int ra = a.getNumberOfValues(nca);
final int rb = b.getNumberOfValues(ncb);
-
- MatrixBlock ma = a.getMBDict(nca).getMatrixBlock();
- MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock();
-
- MatrixBlock out = new MatrixBlock(ra * (rb + 1), nca + ncb,
false);
+ final MatrixBlock ma = a.getMBDict(nca).getMatrixBlock();
+ final MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock();
+ final MatrixBlock out = new MatrixBlock(ra * (rb + 1), nca +
ncb, false);
out.allocateBlock();
@@ -434,65 +541,116 @@ public interface DictionaryFactory {
return new MatrixBlockDictionary(out);
}
+ public static IDictionary combineSDCRight(IDictionary a, IColIndex ai,
IDictionary b, double[] tub, IColIndex bi,
+ HashMapLongInt filter) {
+ return combineSDCRight(a, ai, ai.size(), b, tub, bi, filter);
+ }
+
public static IDictionary combineSDCRight(IDictionary a, int nca,
IDictionary b, double[] tub,
- Map<Integer, Integer> filter) {
+ HashMapLongInt filter) {
+ return combineSDCRight(a, null, nca, b, tub, null, filter);
+ }
+
+ public static IDictionary combineSDCRight(IDictionary a, IColIndex ai,
int nca, IDictionary b, double[] tub,
+ IColIndex bi, HashMapLongInt filter) {
if(filter == null)
- return combineSDCRight(a, nca, b, tub);
+ return combineSDCRightNoFilter(a, ai, nca, b, tub, bi);
+
final int ncb = tub.length;
final int ra = a.getNumberOfValues(nca);
final int rb = b.getNumberOfValues(ncb);
-
- MatrixBlock ma = a.getMBDict(nca).getMatrixBlock();
- MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock();
-
- MatrixBlock out = new MatrixBlock(filter.size(), nca + ncb,
false);
-
+ final MatrixBlock ma = a.getMBDict(nca).getMatrixBlock();
+ final MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock();
+ final MatrixBlock out = new MatrixBlock(filter.size(), nca +
ncb, false);
out.allocateBlock();
- for(int r = 0; r < ra; r++) {
- if(filter.containsKey(r)) {
+ if(ai != null && bi != null) {
+ Pair<int[], int[]> re = IColIndex.reorderingIndexes(ai,
bi);
+ combineSDCRightOOOFilter(out, nca, ncb, tub, ra, rb,
ma, mb, re.getKey(), re.getValue(), filter);
+ }
+ else {
+ combineSDCRightFilter(out, nca, ncb, tub, ra, rb, ma,
mb, filter);
+ }
+ return new MatrixBlockDictionary(out);
+ }
- int o = filter.get(r);
+ private static void combineSDCRightFilter(MatrixBlock out, int nca, int
ncb, double[] tub, int ra, int rb,
+ MatrixBlock ma, MatrixBlock mb, HashMapLongInt filter) {
+ for(int r = 0; r < ra; r++) {
+ int o = filter.get(r);
+ if(o != -1) {
for(int c = 0; c < nca; c++)
out.set(o, c, ma.get(r, c));
for(int c = 0; c < ncb; c++)
out.set(o, c + nca, tub[c]);
}
-
}
-
- for(int r = ra; r < ra * rb; r++) {
- if(filter.containsKey(r)) {
- int o = filter.get(r);
-
+ for(int r = ra; r < ra * rb + ra; r++) {
+ int o = filter.get(r);
+ if(o != -1) {
int ia = r % ra;
int ib = r / ra - 1;
for(int c = 0; c < nca; c++) // all good.
out.set(o, c, ma.get(ia, c));
-
for(int c = 0; c < ncb; c++)
out.set(o, c + nca, mb.get(ib, c));
+ }
+ }
+ }
+ private static void combineSDCRightOOOFilter(MatrixBlock out, int nca,
int ncb, double[] tub, int ra, int rb,
+ MatrixBlock ma, MatrixBlock mb, int[] ai, int[] bi,
HashMapLongInt filter) {
+ for(int r = 0; r < ra; r++) {
+ int o = filter.get(r);
+ if(o != -1) {
+ for(int c = 0; c < nca; c++)
+ out.set(o, ai[c], ma.get(r, c));
+ for(int c = 0; c < ncb; c++)
+ out.set(o, bi[c], tub[c]);
}
}
- return new MatrixBlockDictionary(out);
+ for(int r = ra; r < ra * rb + ra; r++) {
+ int o = filter.get(r);
+ if(o != -1) {
+ int ia = r % ra;
+ int ib = r / ra - 1;
+ for(int c = 0; c < nca; c++) // all good.
+ out.set(o, ai[c], ma.get(ia, c));
+ for(int c = 0; c < ncb; c++)
+ out.set(o, bi[c], mb.get(ib, c));
+ }
+ }
+ }
+
+ public static IDictionary combineSDCNoFilter(IDictionary a, double[]
tua, IDictionary b, double[] tub) {
+ return combineSDCNoFilter(a, tua, null, b, tub, null);
}
- public static IDictionary combineSDC(IDictionary a, double[] tua,
IDictionary b, double[] tub) {
+ public static IDictionary combineSDCNoFilter(IDictionary a, double[]
tua, IColIndex ai, IDictionary b, double[] tub,
+ IColIndex bi) {
final int nca = tua.length;
final int ncb = tub.length;
final int ra = a.getNumberOfValues(nca);
final int rb = b.getNumberOfValues(ncb);
+ final MatrixBlock ma = a.getMBDict(nca).getMatrixBlock();
+ final MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock();
+ final MatrixBlock out = new MatrixBlock((ra + 1) * (rb + 1),
nca + ncb, false);
- MatrixBlock ma = a.getMBDict(nca).getMatrixBlock();
- MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock();
+ out.allocateBlock();
- MatrixBlock out = new MatrixBlock((ra + 1) * (rb + 1), nca +
ncb, false);
+ if(ai != null || bi != null) {
+ final Pair<int[], int[]> re =
IColIndex.reorderingIndexes(ai, bi);
+ combineSDCNoFilterOOO(nca, ncb, tua, tub, out, ma, mb,
ra, rb, re.getKey(), re.getValue());
+ }
+ else
+ combineSDCNoFilter(nca, ncb, tua, tub, out, ma, mb, ra,
rb);
+ return new MatrixBlockDictionary(out);
+ }
- out.allocateBlock();
+ private static void combineSDCNoFilter(int nca, int ncb, double[] tua,
double[] tub, MatrixBlock out, MatrixBlock ma,
+ MatrixBlock mb, int ra, int rb) {
// 0 row both default tuples
-
for(int c = 0; c < nca; c++)
out.set(0, c, tua[c]);
@@ -508,8 +666,8 @@ public interface DictionaryFactory {
}
for(int r = ra + 1; r < out.getNumRows(); r++) {
- int ia = r % (ra + 1) - 1;
- int ib = r / (ra + 1) - 1;
+ final int ia = r % (ra + 1) - 1;
+ final int ib = r / (ra + 1) - 1;
if(ia == -1)
for(int c = 0; c < nca; c++)
@@ -520,42 +678,89 @@ public interface DictionaryFactory {
for(int c = 0; c < ncb; c++) // all good here.
out.set(r, c + nca, mb.get(ib, c));
+ }
+ }
+ private static void combineSDCNoFilterOOO(int nca, int ncb, double[]
tua, double[] tub, MatrixBlock out,
+ MatrixBlock ma, MatrixBlock mb, int ra, int rb, int[] ai, int[]
bi) {
+
+ // 0 row both default tuples
+ for(int c = 0; c < nca; c++)
+ out.set(0, ai[c], tua[c]);
+
+ for(int c = 0; c < ncb; c++)
+ out.set(0, bi[c], tub[c]);
+
+ // default case for b and all cases for a.
+ for(int r = 1; r < ra + 1; r++) {
+ for(int c = 0; c < nca; c++)
+ out.set(r, ai[c], ma.get(r - 1, c));
+ for(int c = 0; c < ncb; c++)
+ out.set(r, bi[c], tub[c]);
}
- return new MatrixBlockDictionary(out);
+ for(int r = ra + 1; r < out.getNumRows(); r++) {
+ final int ia = r % (ra + 1) - 1;
+ final int ib = r / (ra + 1) - 1;
+
+ if(ia == -1)
+ for(int c = 0; c < nca; c++)
+ out.set(r, ai[c], tua[c]);
+ else
+ for(int c = 0; c < nca; c++)
+ out.set(r, ai[c], ma.get(ia, c));
+
+ for(int c = 0; c < ncb; c++) // all good here.
+ out.set(r, bi[c], mb.get(ib, c));
+ }
}
- public static IDictionary combineSDC(IDictionary a, double[] tua,
IDictionary b, double[] tub,
- Map<Integer, Integer> filter) {
+ public static IDictionary combineSDCFilter(IDictionary a, double[] tua,
IDictionary b, double[] tub,
+ HashMapLongInt filter) {
+ return combineSDCFilter(a, tua, null, b, tub, null, filter);
+ }
+
+ public static IDictionary combineSDCFilter(IDictionary a, double[] tua,
IColIndex ai, IDictionary b, double[] tub,
+ IColIndex bi, HashMapLongInt filter) {
if(filter == null)
- return combineSDC(a, tua, b, tub);
+ return combineSDCNoFilter(a, tua, ai, b, tub, bi);
+
final int nca = tua.length;
final int ncb = tub.length;
final int ra = a.getNumberOfValues(nca);
- final int rb = b.getNumberOfValues(nca);
+ final int rb = b.getNumberOfValues(ncb);
+ final MatrixBlock ma = a.getMBDict(nca).getMatrixBlock();
+ final MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock();
+ final MatrixBlock out = new MatrixBlock(filter.size(), nca +
ncb, false);
+ out.allocateBlock();
- MatrixBlock ma = a.getMBDict(nca).getMatrixBlock();
- MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock();
+ if(ai != null && bi != null) {
+ Pair<int[], int[]> re = IColIndex.reorderingIndexes(ai,
bi);
+ combineSDCFilterOOO(filter, nca, ncb, tua, tub, out,
ma, mb, ra, rb, re.getKey(), re.getValue());
+ }
+ else
+ combineSDCFilter(filter, nca, ncb, tua, tub, out, ma,
mb, ra, rb);
- MatrixBlock out = new MatrixBlock(filter.size(), nca + ncb,
false);
+ return new MatrixBlockDictionary(out);
+ }
- out.allocateBlock();
+ private static void combineSDCFilter(HashMapLongInt filter, int nca,
int ncb, double[] tua, double[] tub,
+ MatrixBlock out, MatrixBlock ma, MatrixBlock mb, int ra, int
rb) {
// 0 row both default tuples
- if(filter.containsKey(0)) {
- int o = filter.get(0);
+ final int o0 = filter.get(0);
+ if(o0 != -1) {
for(int c = 0; c < nca; c++)
- out.set(o, c, tua[c]);
+ out.set(o0, c, tua[c]);
for(int c = 0; c < ncb; c++)
- out.set(o, c + nca, tub[c]);
+ out.set(o0, c + nca, tub[c]);
}
// default case for b and all cases for a.
for(int r = 1; r < ra + 1; r++) {
- if(filter.containsKey(r)) {
- int o = filter.get(r);
+ final int o = filter.get(r);
+ if(o != -1) {
for(int c = 0; c < nca; c++)
out.set(o, c, ma.get(r - 1, c));
for(int c = 0; c < ncb; c++)
@@ -563,13 +768,11 @@ public interface DictionaryFactory {
}
}
- for(int r = ra + 1; r < ra * rb; r++) {
-
- if(filter.containsKey(r)) {
- int o = filter.get(r);
-
- int ia = r % (ra + 1) - 1;
- int ib = r / (ra + 1) - 1;
+ for(int r = ra + 1; r < ra * rb + ra + rb + 1; r++) {
+ final int o = filter.get(r);
+ if(o != -1) {
+ final int ia = r % (ra + 1) - 1;
+ final int ib = r / (ra + 1) - 1;
if(ia == -1)
for(int c = 0; c < nca; c++)
@@ -582,12 +785,50 @@ public interface DictionaryFactory {
out.set(o, c + nca, mb.get(ib, c));
}
}
+ }
- return new MatrixBlockDictionary(out);
+ private static void combineSDCFilterOOO(HashMapLongInt filter, int nca,
int ncb, double[] tua, double[] tub,
+ MatrixBlock out, MatrixBlock ma, MatrixBlock mb, int ra, int
rb, int[] ai, int[] bi) {
+ // 0 row both default tuples
+ final int o0 = filter.get(0);
+ if(o0 != -1) {
+ for(int c = 0; c < nca; c++)
+ out.set(o0, ai[c], tua[c]);
+ for(int c = 0; c < ncb; c++)
+ out.set(o0, bi[c], tub[c]);
+ }
+ // default case for b and all cases for a.
+ for(int r = 1; r < ra + 1; r++) {
+ final int o = filter.get(r);
+ if(o != -1) {
+ for(int c = 0; c < nca; c++)
+ out.set(o, ai[c], ma.get(r - 1, c));
+ for(int c = 0; c < ncb; c++)
+ out.set(o, bi[c], tub[c]);
+ }
+ }
+
+ for(int r = ra + 1; r < ra * rb + ra + rb + 1; r++) {
+ final int o = filter.get(r);
+ if(o != -1) {
+ final int ia = r % (ra + 1) - 1;
+ final int ib = r / (ra + 1) - 1;
+
+ if(ia == -1)
+ for(int c = 0; c < nca; c++)
+ out.set(o, ai[c], tua[c]);
+ else
+ for(int c = 0; c < nca; c++)
+ out.set(o, ai[c], ma.get(ia,
c));
+
+ for(int c = 0; c < ncb; c++) // all good here.
+ out.set(o, bi[c], mb.get(ib, c));
+ }
+ }
}
- private static IDictionary combineSparseConstSparseRet(IDictionary a,
int nca, double[] tub) {
+ private static IDictionary combineSparseConstSparseRet(IDictionary a,
int nca, double[] tub, int[] ai, int[] bi) {
final int ncb = tub.length;
final int ra = a.getNumberOfValues(nca);
@@ -600,19 +841,19 @@ public interface DictionaryFactory {
// default case for b and all cases for a.
for(int r = 0; r < ra; r++) {
for(int c = 0; c < nca; c++)
- out.set(r, c, ma.get(r, c));
+ out.set(r, ai[c], ma.get(r, c));
for(int c = 0; c < ncb; c++)
- out.set(r, c + nca, tub[c]);
+ out.set(r, bi[c], tub[c]);
}
return new MatrixBlockDictionary(out);
}
- private static IDictionary combineSparseConstSparseRet(IDictionary a,
int nca, double[] tub,
- Map<Integer, Integer> filter) {
+ private static IDictionary combineSparseConstSparseRet(IDictionary a,
int nca, double[] tub, int[] ai, int[] bi,
+ HashMapLongInt filter) {
if(filter == null)
- return combineSparseConstSparseRet(a, nca, tub);
+ return combineSparseConstSparseRet(a, nca, tub, ai, bi);
else
throw new NotImplementedException();
// final int ncb = tub.length;
@@ -636,7 +877,8 @@ public interface DictionaryFactory {
}
- private static IDictionary combineConstSparseSparseRet(double[] tua,
IDictionary b, int ncb) {
+ private static IDictionary combineConstLeftAll(double[] tua,
IDictionary b, int ncb, int[] ai, int[] bi) {
+
final int nca = tua.length;
final int rb = b.getNumberOfValues(ncb);
@@ -649,19 +891,19 @@ public interface DictionaryFactory {
// default case for b and all cases for a.
for(int r = 0; r < rb; r++) {
for(int c = 0; c < nca; c++)
- out.set(r, c, tua[c]);
+ out.set(r, ai[c], tua[c]);
for(int c = 0; c < ncb; c++)
- out.set(r, c + nca, mb.get(r, c));
+ out.set(r, bi[c], mb.get(r, c));
}
return new MatrixBlockDictionary(out);
}
- private static IDictionary combineConstSparseSparseRet(double[] tua,
IDictionary b, int ncb,
- Map<Integer, Integer> filter) {
+ private static IDictionary combineConstLeft(double[] tua, IDictionary
b, int ncb, int[] ai, int[] bi,
+ HashMapLongInt filter) {
if(filter == null)
- return combineConstSparseSparseRet(tua, b, ncb);
+ return combineConstLeftAll(tua, b, ncb, ai, bi);
else
throw new NotImplementedException();
// final int nca = tua.length;
@@ -684,4 +926,39 @@ public interface DictionaryFactory {
// return new MatrixBlockDictionary(out);
}
+
+ public static IDictionary cBindDictionaries(int nCol, List<IDictionary>
dicts) {
+ MatrixBlockDictionary baseDict = dicts.get(0).getMBDict(nCol);
+ MatrixBlock base = baseDict == null ? new MatrixBlock(1, nCol,
true) : baseDict.getMatrixBlock();
+ MatrixBlock[] others = new MatrixBlock[dicts.size() - 1];
+ for(int i = 1; i < dicts.size(); i++) {
+ MatrixBlockDictionary otherDict =
dicts.get(i).getMBDict(nCol);
+ MatrixBlock otherBase = otherDict == null ? new
MatrixBlock(1, nCol, true) : otherDict.getMatrixBlock();
+ others[i - 1] = otherBase;
+ }
+ MatrixBlock ret = base.append(others, null, true);
+ return MatrixBlockDictionary.create(ret, true);
+ }
+
+ // public static IDictionary cBindDictionaries(List<Pair<Integer,
IDictionary>> dicts) {
+ // MatrixBlock base =
dicts.get(0).getValue().getMBDict(dicts.get(0).getKey()).getMatrixBlock();
+ // MatrixBlock[] others = new MatrixBlock[dicts.size() - 1];
+ // for(int i = 1; i < dicts.size(); i++) {
+ // Pair<Integer, IDictionary> p = dicts.get(i);
+ // others[i - 1] = p.getValue().getMBDict(p.getKey()).getMatrixBlock();
+ // }
+ // MatrixBlock ret = base.append(others, null, true);
+ // return new MatrixBlockDictionary(ret);
+ // }
+
+ public static IDictionary cBindDictionaries(IDictionary left,
IDictionary right, int nColLeft, int nColRight) {
+ MatrixBlockDictionary base = left.getMBDict(nColLeft);
+ MatrixBlockDictionary add = right.getMBDict(nColRight);
+
+ MatrixBlock a = base == null ? (add != null ? new
MatrixBlock(add.getNumberOfValues(nColRight), nColLeft,
+ true) : new MatrixBlock(1, nColLeft, true)) :
base.getMatrixBlock();
+ MatrixBlock b = add == null ? new MatrixBlock(a.getNumRows(),
nColRight, true) : add.getMatrixBlock();
+ MatrixBlock ret = a.append(b, null, true);
+ return MatrixBlockDictionary.create(ret, true);
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/ConstEncoding.java
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/ConstEncoding.java
index b0fe390f61..2111b85cbf 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/ConstEncoding.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/ConstEncoding.java
@@ -19,12 +19,11 @@
package org.apache.sysds.runtime.compress.estim.encoding;
-import java.util.Map;
-
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
+import org.apache.sysds.runtime.compress.utils.HashMapLongInt;
/** Const encoding for cases where the entire group of columns is the same
value */
public class ConstEncoding extends AEncode {
@@ -41,7 +40,7 @@ public class ConstEncoding extends AEncode {
}
@Override
- public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e) {
+ public Pair<IEncode, HashMapLongInt> combineWithMap(IEncode e) {
return new ImmutablePair<>(e, null);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java
index 8c612955aa..8fc9d96f72 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java
@@ -19,34 +19,40 @@
package org.apache.sysds.runtime.compress.estim.encoding;
-import java.util.HashMap;
-import java.util.Map;
-
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
+import org.apache.sysds.runtime.compress.colgroup.mapping.MapToChar;
+import org.apache.sysds.runtime.compress.colgroup.mapping.MapToCharPByte;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
+import org.apache.sysds.runtime.compress.utils.HashMapLongInt;
/**
* An Encoding that contains a value on each row of the input.
*/
public class DenseEncoding extends AEncode {
+ private static boolean zeroWarn = false;
+
private final AMapToData map;
public DenseEncoding(AMapToData map) {
this.map = map;
if(CompressedMatrixBlock.debug) {
+ // if(!zeroWarn) {
int[] freq = map.getCounts();
- for(int i = 0; i < freq.length; i++) {
- if(freq[i] == 0)
- throw new
DMLCompressionException("Invalid counts in fact contains 0");
+ for(int i = 0; i < freq.length && !zeroWarn; i++) {
+ if(freq[i] == 0) {
+ LOG.warn("Dense encoding contains zero
encoding, indicating not all dictionary entries are in use");
+ zeroWarn = true;
+
+ }
}
}
}
@@ -62,7 +68,7 @@ public class DenseEncoding extends AEncode {
}
@Override
- public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e) {
+ public Pair<IEncode, HashMapLongInt> combineWithMap(IEncode e) {
if(e instanceof EmptyEncoding || e instanceof ConstEncoding)
return new ImmutablePair<>(this, null);
else if(e instanceof SparseEncoding)
@@ -106,14 +112,14 @@ public class DenseEncoding extends AEncode {
return ret;
}
- private final Pair<IEncode, Map<Integer, Integer>>
combineSparseHashMap(final AMapToData ret) {
+ private final Pair<IEncode, HashMapLongInt> combineSparseHashMap(final
AMapToData ret) {
final int size = ret.size();
- final Map<Integer, Integer> m = new HashMap<>(size);
+ final HashMapLongInt m = new HashMapLongInt(100);
for(int r = 0; r < size; r++) {
final int prev = ret.getIndex(r);
final int v = m.size();
- final Integer mv = m.putIfAbsent(prev, v);
- if(mv == null)
+ final int mv = m.putIfAbsent(prev, v);
+ if(mv == -1)
ret.set(r, v);
else
ret.set(r, mv);
@@ -146,28 +152,44 @@ public class DenseEncoding extends AEncode {
final int nVL = lm.getUnique();
final int nVR = rm.getUnique();
final int size = map.size();
- final int maxUnique = nVL * nVR;
-
+ int maxUnique = nVL * nVR;
+ final DenseEncoding retE;
final AMapToData ret = MapToFactory.create(size, maxUnique);
-
- if(maxUnique > size && maxUnique > 2048) {
+ if(maxUnique < Math.max(nVL, nVR)) {// overflow
+ final HashMapLongInt m = new
HashMapLongInt(Math.max(100, size / 100));
+ retE = combineDenseWithHashMapLong(lm, rm, size, nVL,
ret, m);
+ }
+ else if(maxUnique > size && maxUnique > 2048) {
// aka there is more maxUnique than rows.
- final Map<Integer, Integer> m = new HashMap<>(size);
- return combineDenseWithHashMap(lm, rm, size, nVL, ret,
m);
+ final HashMapLongInt m = new
HashMapLongInt(Math.max(100, maxUnique / 100));
+ retE = combineDenseWithHashMap(lm, rm, size, nVL, ret,
m);
}
else {
final AMapToData m = MapToFactory.create(maxUnique,
maxUnique + 1);
- return combineDenseWithMapToData(lm, rm, size, nVL,
ret, maxUnique, m);
+ retE = combineDenseWithMapToData(lm, rm, size, nVL,
ret, maxUnique, m);
+ }
+
+ if(retE.getUnique() < 0) {
+ String th = this.toString();
+ String ot = other.toString();
+ String cm = retE.toString();
+
+ if(th.length() > 1000)
+ th = th.substring(0, 1000);
+ if(ot.length() > 1000)
+ ot = ot.substring(0, 1000);
+ if(cm.length() > 1000)
+ cm = cm.substring(0, 1000);
+ throw new DMLCompressionException(
+ "Failed to combine dense encodings correctly:
Number unique values is lower than max input: \n\n" + th
+ + "\n\n" + ot + "\n\n" + cm);
}
+ return retE;
}
- private Pair<IEncode, Map<Integer, Integer>> combineDenseNoResize(final
DenseEncoding other) {
- if(map == other.map) {
- LOG.warn("Constructing perfect mapping, this could be
optimized to skip hashmap");
- final Map<Integer, Integer> m = new
HashMap<>(map.size());
- for(int i = 0; i < map.getUnique(); i++)
- m.put(i * i, i);
- return new ImmutablePair<>(this, m); // same object
+ private Pair<IEncode, HashMapLongInt> combineDenseNoResize(final
DenseEncoding other) {
+ if(map.equals(other.map)) {
+ return combineSameMapping();
}
final AMapToData lm = map;
@@ -176,40 +198,115 @@ public class DenseEncoding extends AEncode {
final int nVL = lm.getUnique();
final int nVR = rm.getUnique();
final int size = map.size();
- final int maxUnique = nVL * nVR;
+ final int maxUnique = (int) Math.min((long) nVL * nVR, (long)
size);
final AMapToData ret = MapToFactory.create(size, maxUnique);
- final Map<Integer, Integer> m = new HashMap<>(Math.min(size,
maxUnique));
+ final HashMapLongInt m = new HashMapLongInt(Math.max(100,
maxUnique / 1000));
return new ImmutablePair<>(combineDenseWithHashMap(lm, rm,
size, nVL, ret, m), m);
+ }
- // there can be less unique.
-
- // return new DenseEncoding(ret);
+ private Pair<IEncode, HashMapLongInt> combineSameMapping() {
+ LOG.warn("Constructing perfect mapping, this could be optimized
to skip hashmap");
+ final HashMapLongInt m = new HashMapLongInt(Math.max(100,
map.size() / 100));
+ for(int i = 0; i < map.getUnique(); i++)
+ m.putIfAbsent(i * (map.getUnique() + 1), i);
+ return new ImmutablePair<>(this, m); // same object
}
- private Pair<IEncode, Map<Integer, Integer>>
combineSparseNoResize(final SparseEncoding other) {
+ private Pair<IEncode, HashMapLongInt> combineSparseNoResize(final
SparseEncoding other) {
final AMapToData a = assignSparse(other);
return combineSparseHashMap(a);
}
+ protected final DenseEncoding combineDenseWithHashMapLong(final
AMapToData lm, final AMapToData rm, final int size,
+ final long nVL, final AMapToData ret, HashMapLongInt m) {
+ if(ret instanceof MapToChar)
+ for(int r = 0; r < size; r++)
+ addValHashMapChar((long) lm.getIndex(r) +
rm.getIndex(r) * nVL, r, m, (MapToChar) ret);
+ else
+ for(int r = 0; r < size; r++)
+ addValHashMap((long) lm.getIndex(r) +
rm.getIndex(r) * nVL, r, m, ret);
+ return new DenseEncoding(ret.resize(m.size()));
+ }
+
protected final DenseEncoding combineDenseWithHashMap(final AMapToData
lm, final AMapToData rm, final int size,
- final int nVL, final AMapToData ret, Map<Integer, Integer> m) {
+ final int nVL, final AMapToData ret, HashMapLongInt m) {
+ // JIT compile instance checks.
+ if(ret instanceof MapToChar)
+ combineDenseWIthHashMapCharOut(lm, rm, size, nVL,
(MapToChar) ret, m);
+ else if(ret instanceof MapToCharPByte)
+ combineDenseWIthHashMapPByteOut(lm, rm, size, nVL,
(MapToCharPByte) ret, m);
+ else
+ combineDenseWithHashMapGeneric(lm, rm, size, nVL, ret,
m);
+ ret.setUnique(m.size());
+ return new DenseEncoding(ret);
+ }
+
+ private final void combineDenseWIthHashMapPByteOut(final AMapToData lm,
final AMapToData rm, final int size,
+ final int nVL, final MapToCharPByte ret, HashMapLongInt m) {
+ for(int r = 0; r < size; r++)
+ addValHashMapCharByte(lm.getIndex(r) + rm.getIndex(r) *
nVL, r, m, ret);
+ }
+
+ private final void combineDenseWIthHashMapCharOut(final AMapToData lm,
final AMapToData rm, final int size,
+ final int nVL, final MapToChar ret, HashMapLongInt m) {
+ if(lm instanceof MapToChar && rm instanceof MapToChar)
+ combineDenseWIthHashMapAllChar(lm, rm, size, nVL, ret,
m);
+ else// some other combination
+ combineDenseWIthHashMapCharOutGeneric(lm, rm, size,
nVL, ret, m);
+ }
+
+ private final void combineDenseWIthHashMapCharOutGeneric(final
AMapToData lm, final AMapToData rm, final int size,
+ final int nVL, final MapToChar ret, HashMapLongInt m) {
+ for(int r = 0; r < size; r++)
+ addValHashMapChar(lm.getIndex(r) + rm.getIndex(r) *
nVL, r, m, ret);
+ }
+
+ private final void combineDenseWIthHashMapAllChar(final AMapToData lm,
final AMapToData rm, final int size,
+ final int nVL, final MapToChar ret, HashMapLongInt m) {
+ final MapToChar lmC = (MapToChar) lm;
+ final MapToChar rmC = (MapToChar) rm;
+ for(int r = 0; r < size; r++)
+ addValHashMapChar(lmC.getIndex(r) + rmC.getIndex(r) *
nVL, r, m, ret);
+
+ }
+
+ protected final void combineDenseWithHashMapGeneric(final AMapToData
lm, final AMapToData rm, final int size,
+ final int nVL, final AMapToData ret, HashMapLongInt m) {
for(int r = 0; r < size; r++)
addValHashMap(lm.getIndex(r) + rm.getIndex(r) * nVL, r,
m, ret);
- return new DenseEncoding(ret.resize(m.size()));
}
protected final DenseEncoding combineDenseWithMapToData(final
AMapToData lm, final AMapToData rm, final int size,
final int nVL, final AMapToData ret, final int maxUnique, final
AMapToData m) {
+ if(m instanceof MapToChar)
+ return combineDenseWithMapToDataToChar(lm, rm, size,
nVL, ret, maxUnique, (MapToChar) m);
+ else
+ return combineDenseWithMapToDataGeneric(lm, rm, size,
nVL, ret, maxUnique, m);
+
+ }
+
+ protected final DenseEncoding combineDenseWithMapToDataToChar(final
AMapToData lm, final AMapToData rm,
+ final int size, final int nVL, final AMapToData ret, final int
maxUnique, final MapToChar m) {
+ int newUID = 1;
+ for(int r = 0; r < size; r++)
+ newUID = addValMapToDataChar(lm.getIndex(r) +
rm.getIndex(r) * nVL, r, m, newUID, ret);
+ ret.setUnique(newUID - 1);
+ return new DenseEncoding(ret);
+ }
+
+ protected final DenseEncoding combineDenseWithMapToDataGeneric(final
AMapToData lm, final AMapToData rm,
+ final int size, final int nVL, final AMapToData ret, final int
maxUnique, final AMapToData m) {
int newUID = 1;
for(int r = 0; r < size; r++)
newUID = addValMapToData(lm.getIndex(r) +
rm.getIndex(r) * nVL, r, m, newUID, ret);
- return new DenseEncoding(ret.resize(newUID - 1));
+ ret.setUnique(newUID - 1);
+ return new DenseEncoding(ret);
}
- protected static int addValMapToData(final int nv, final int r, final
AMapToData map, int newId,
+ protected static int addValMapToDataChar(final int nv, final int r,
final MapToChar map, int newId,
final AMapToData d) {
int mv = map.getIndex(nv);
if(mv == 0)
@@ -218,11 +315,56 @@ public class DenseEncoding extends AEncode {
return newId;
}
- protected static void addValHashMap(final int nv, final int r, final
Map<Integer, Integer> map,
+ protected static int addValMapToData(final int nv, final int r, final
AMapToData map, int newId,
final AMapToData d) {
+ int mv = map.getIndex(nv);
+ if(mv == 0)
+ mv = map.setAndGet(nv, newId++);
+ d.set(r, mv - 1);
+ return newId;
+ }
+
+ protected static void addValHashMap(final int nv, final int r, final
HashMapLongInt map, final AMapToData d) {
final int v = map.size();
- final Integer mv = map.putIfAbsent(nv, v);
- if(mv == null)
+ final int mv = map.putIfAbsent(nv, v);
+ if(mv == -1)
+ d.set(r, v);
+ else
+ d.set(r, mv);
+ }
+
+ protected static void addValHashMapChar(final int nv, final int r,
final HashMapLongInt map, final MapToChar d) {
+ final int v = map.size();
+ final int mv = map.putIfAbsent(nv, v);
+ if(mv == -1)
+ d.set(r, v);
+ else
+ d.set(r, mv);
+ }
+
+ protected static void addValHashMapCharByte(final int nv, final int r,
final HashMapLongInt map,
+ final MapToCharPByte d) {
+ final int v = map.size();
+ final int mv = map.putIfAbsent(nv, v);
+ if(mv == -1)
+ d.set(r, v);
+ else
+ d.set(r, mv);
+ }
+
+ protected static void addValHashMapChar(final long nv, final int r,
final HashMapLongInt map, final MapToChar d) {
+ final int v = map.size();
+ final int mv = map.putIfAbsent(nv, v);
+ if(mv == -1)
+ d.set(r, v);
+ else
+ d.set(r, mv);
+ }
+
+ protected static void addValHashMap(final long nv, final int r, final
HashMapLongInt map, final AMapToData d) {
+ final int v = map.size();
+ final int mv = map.putIfAbsent(nv, v);
+ if(mv == -1)
d.set(r, v);
else
d.set(r, mv);
@@ -237,13 +379,18 @@ public class DenseEncoding extends AEncode {
public EstimationFactors extractFacts(int nRows, double tupleSparsity,
double matrixSparsity,
CompressionSettings cs) {
int largestOffs = 0;
-
int[] counts = map.getCounts();
for(int i = 0; i < counts.length; i++)
if(counts[i] > largestOffs)
largestOffs = counts[i];
- else if(counts[i] == 0)
- throw new DMLCompressionException("Invalid
count of 0 all values should have at least one instance");
+ else if(counts[i] == 0) {
+ if(!zeroWarn) {
+ LOG.warn("Invalid count of 0 all values
should have at least one instance index: " + i + " of "
+ + counts.length);
+ zeroWarn = true;
+ }
+ counts[i] = 1;
+ }
if(cs.isRLEAllowed())
return new EstimationFactors(map.getUnique(), nRows,
largestOffs, counts, 0, nRows, map.countRuns(), false,
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EmptyEncoding.java
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EmptyEncoding.java
index 0d386f1424..806027ef52 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EmptyEncoding.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EmptyEncoding.java
@@ -19,12 +19,11 @@
package org.apache.sysds.runtime.compress.estim.encoding;
-import java.util.Map;
-
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
+import org.apache.sysds.runtime.compress.utils.HashMapLongInt;
/**
* Empty encoding for cases where the entire group of columns is zero
@@ -41,7 +40,7 @@ public class EmptyEncoding extends AEncode {
}
@Override
- public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e) {
+ public Pair<IEncode, HashMapLongInt> combineWithMap(IEncode e) {
return new ImmutablePair<>(e, null);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
index d8ab0f0f7c..257ddf6f3c 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
@@ -229,8 +229,16 @@ public interface EncodingFactory {
// Iteration 3 of non zero indexes, make a Offset
Encoding to know what cells are zero and not.
// not done yet
- final AOffset o = OffsetFactory.createOffset(aix, apos,
alen);
- return new SparseEncoding(d, o, m.getNumColumns());
+ try{
+
+ final AOffset o =
OffsetFactory.createOffset(aix, apos, alen);
+ return new SparseEncoding(d, o,
m.getNumColumns());
+ }
+ catch(Exception e){
+ String mes =
Arrays.toString(Arrays.copyOfRange(aix, apos, alen)) + "\n" + apos + " " +
alen;
+ mes +=
Arrays.toString(Arrays.copyOfRange(avals, apos, alen));
+ throw new DMLRuntimeException(mes, e);
+ }
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java
index 15393a947b..a4a36fb019 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java
@@ -19,13 +19,12 @@
package org.apache.sysds.runtime.compress.estim.encoding;
-import java.util.Map;
-
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
+import org.apache.sysds.runtime.compress.utils.HashMapLongInt;
/**
* This interface covers an intermediate encoding for the samples to improve
the efficiency of the joining of sample
@@ -47,11 +46,15 @@ public interface IEncode {
/**
* Combine two encodings without resizing the output. meaning the
mapping of the indexes should be consistent with
* left hand side Dictionary indexes and right hand side indexes.
+ * <p>
+ *
+ *
+ * NOTE: Requires both encodings to contain correct metadata for the
number of unique values.
*
* @param e The other side to combine with
* @return The combined encoding
*/
- public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e);
+ public Pair<IEncode, HashMapLongInt> combineWithMap(IEncode e);
/**
* Get the number of unique values in this encoding
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
index ffe365127a..970ea3a8d1 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
@@ -19,9 +19,6 @@
package org.apache.sysds.runtime.compress.estim.encoding;
-import java.util.HashMap;
-import java.util.Map;
-
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
@@ -33,6 +30,7 @@ import
org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
+import org.apache.sysds.runtime.compress.utils.HashMapLongInt;
import org.apache.sysds.runtime.compress.utils.IntArrayList;
/**
@@ -80,7 +78,7 @@ public class SparseEncoding extends AEncode {
}
@Override
- public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e) {
+ public Pair<IEncode, HashMapLongInt> combineWithMap(IEncode e) {
if(e instanceof EmptyEncoding || e instanceof ConstEncoding)
return new ImmutablePair<>(this, null);
else if(e instanceof SparseEncoding) {
@@ -132,7 +130,7 @@ public class SparseEncoding extends AEncode {
}
}
- private Pair<IEncode, Map<Integer, Integer>>
combineSparseNoResizeDense(SparseEncoding e) {
+ private Pair<IEncode, HashMapLongInt>
combineSparseNoResizeDense(SparseEncoding e) {
final int fl = off.getOffsetToLast();
final int fr = e.off.getOffsetToLast();
@@ -162,7 +160,7 @@ public class SparseEncoding extends AEncode {
retMap.set(fr, retMap.getIndex(fr) +
(e.map.getIndex(itr.getDataIndex()) + 1) * nVl);
// Full iteration to set unique elements.
- final Map<Integer, Integer> m = new HashMap<>();
+ final HashMapLongInt m = new HashMapLongInt(100);
for(int i = 0; i < retMap.size(); i++)
addValHashMap(retMap.getIndex(i), i, m, retMap);
@@ -170,11 +168,11 @@ public class SparseEncoding extends AEncode {
}
- protected static void addValHashMap(final int nv, final int r, final
Map<Integer, Integer> map,
+ protected static void addValHashMap(final int nv, final int r, final
HashMapLongInt map,
final AMapToData d) {
final int v = map.size();
- final Integer mv = map.putIfAbsent(nv, v);
- if(mv == null)
+ final int mv = map.putIfAbsent(nv, v);
+ if(mv == -1)
d.set(r, v);
else
d.set(r, mv);
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java
index 32ec9c0f32..285dbb96b3 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java
@@ -50,6 +50,7 @@ import
org.apache.sysds.runtime.compress.estim.encoding.DenseEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.EmptyEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.compress.estim.encoding.SparseEncoding;
+import org.apache.sysds.runtime.compress.utils.HashMapLongInt;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
@@ -177,24 +178,25 @@ public final class CLALibCombineGroups {
}
// add if encodings are equal make shortcut.
- Pair<IEncode, Map<Integer, Integer>> cec =
ae.combineWithMap(be);
- IEncode ce = cec.getLeft();
- Map<Integer, Integer> filter = cec.getRight();
- if(ce instanceof DenseEncoding) {
- DenseEncoding ced = (DenseEncoding) (ce);
- IDictionary cd =
DictionaryFactory.combineDictionaries(ac, bc, filter);
- return ColGroupDDC.create(combinedColumns, cd,
ced.getMap(), null);
- }
- else if(ce instanceof EmptyEncoding) {
+ final Pair<IEncode, HashMapLongInt> cec = ae.combineWithMap(be);
+ final IEncode ce = cec.getLeft();
+ final HashMapLongInt filter = cec.getRight();
+
+ if(ce instanceof EmptyEncoding) {
return new ColGroupEmpty(combinedColumns);
}
else if(ce instanceof ConstEncoding) {
IDictionary cd =
DictionaryFactory.combineDictionaries(ac, bc, filter);
return ColGroupConst.create(combinedColumns, cd);
}
+ else if(ce instanceof DenseEncoding) {
+ DenseEncoding ced = (DenseEncoding) (ce);
+ IDictionary cd =
DictionaryFactory.combineDictionaries(ac, bc, filter);
+ return ColGroupDDC.create(combinedColumns, cd,
ced.getMap(), null);
+ }
else if(ce instanceof SparseEncoding) {
SparseEncoding sed = (SparseEncoding) ce;
- IDictionary cd =
DictionaryFactory.combineDictionariesSparse(ac, bc);
+ IDictionary cd =
DictionaryFactory.combineDictionariesSparse(ac, bc, filter);
double[] defaultTuple = constructDefaultTuple(ac, bc);
return ColGroupSDC.create(combinedColumns,
sed.getNumRows(), cd, defaultTuple, sed.getOffsets(), sed.getMap(),
null);
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapLongInt.java
b/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapLongInt.java
index 8379a06698..a221c7fb36 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapLongInt.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapLongInt.java
@@ -144,8 +144,8 @@ public class HashMapLongInt implements Iterable<KV> {
protected Itt() {
if(size == 0) {
- lastBucket = 0;
- lastCell = 0;
+ lastBucket = -1;
+ lastCell = -1;
}
else {
int tmpLastBucket = keys.length - 1;
@@ -164,7 +164,8 @@ public class HashMapLongInt implements Iterable<KV> {
@Override
public boolean hasNext() {
- return bucketId < lastBucket || (bucketId == lastBucket
&& bucketCell <= lastCell);
+ return lastBucket != -1 && //
+ (bucketId < lastBucket || (bucketId ==
lastBucket && bucketCell <= lastCell));
}
@Override
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/combine/CombineEncodings.java
b/src/test/java/org/apache/sysds/test/component/compress/combine/CombineEncodings.java
index 847e7ff1e7..4f49f0bbe3 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/combine/CombineEncodings.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/combine/CombineEncodings.java
@@ -21,14 +21,13 @@ package org.apache.sysds.test.component.compress.combine;
import static org.junit.Assert.assertTrue;
-import java.util.Map;
-
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.estim.encoding.DenseEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
+import org.apache.sysds.runtime.compress.utils.HashMapLongInt;
import org.junit.Test;
public class CombineEncodings {
@@ -39,9 +38,9 @@ public class CombineEncodings {
public void combineCustom() {
IEncode ae = new DenseEncoding(MapToFactory.create(new int[]
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, 10));
IEncode be = new DenseEncoding(MapToFactory.create(new int[]
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, 10));
- Pair<IEncode, Map<Integer, Integer>> cec =
ae.combineWithMap(be);
+ Pair<IEncode, HashMapLongInt> cec = ae.combineWithMap(be);
IEncode ce = cec.getLeft();
- Map<Integer, Integer> cem = cec.getRight();
+ HashMapLongInt cem = cec.getRight();
assertTrue(cem.size() == 10);
assertTrue(cem.size() == ce.getUnique());
assertTrue(ce.equals(new DenseEncoding(MapToFactory.create(new
int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, 10))));
@@ -52,9 +51,9 @@ public class CombineEncodings {
public void combineCustom2() {
IEncode ae = new DenseEncoding(MapToFactory.create(new int[]
{0, 1, 2, 3, 4, 5, 6, 7, 8, 8}, 10));
IEncode be = new DenseEncoding(MapToFactory.create(new int[]
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, 10));
- Pair<IEncode, Map<Integer, Integer>> cec =
ae.combineWithMap(be);
+ Pair<IEncode, HashMapLongInt> cec = ae.combineWithMap(be);
IEncode ce = cec.getLeft();
- Map<Integer, Integer> cem = cec.getRight();
+ HashMapLongInt cem = cec.getRight();
assertTrue(cem.size() == 10);
assertTrue(cem.size() == ce.getUnique());
assertTrue(ce.equals(new DenseEncoding(MapToFactory.create(new
int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, 10))));
@@ -65,9 +64,9 @@ public class CombineEncodings {
public void combineCustom3() {
IEncode ae = new DenseEncoding(MapToFactory.create(new int[]
{0, 1, 2, 3, 4, 5, 6, 7, 7, 8}, 10));
IEncode be = new DenseEncoding(MapToFactory.create(new int[]
{0, 1, 2, 3, 4, 5, 6, 7, 7, 9}, 10));
- Pair<IEncode, Map<Integer, Integer>> cec =
ae.combineWithMap(be);
+ Pair<IEncode, HashMapLongInt> cec = ae.combineWithMap(be);
IEncode ce = cec.getLeft();
- Map<Integer, Integer> cem = cec.getRight();
+ HashMapLongInt cem = cec.getRight();
assertTrue(cem.size() == 9);
assertTrue(cem.size() == ce.getUnique());
assertTrue(ce.equals(new DenseEncoding(MapToFactory.create(new
int[] {0, 1, 2, 3, 4, 5, 6, 7, 7, 8}, 9))));
@@ -76,11 +75,12 @@ public class CombineEncodings {
@Test
public void combineCustom4() {
- IEncode ae = new DenseEncoding(MapToFactory.create(new int[]
{0, 1, 2, 3, 4, 5, 6, 7, 7, 0}, 10));
- IEncode be = new DenseEncoding(MapToFactory.create(new int[]
{0, 1, 2, 3, 4, 5, 6, 7, 7, 0}, 10));
- Pair<IEncode, Map<Integer, Integer>> cec =
ae.combineWithMap(be);
+ // The same-mapping shortcut is built from getUnique(), so the unique count must be correct.
+ IEncode ae = new DenseEncoding(MapToFactory.create(new int[]
{0, 1, 2, 3, 4, 5, 6, 7, 7, 0}, 8));
+ IEncode be = new DenseEncoding(MapToFactory.create(new int[]
{0, 1, 2, 3, 4, 5, 6, 7, 7, 0}, 8));
+ Pair<IEncode, HashMapLongInt> cec = ae.combineWithMap(be);
IEncode ce = cec.getLeft();
- Map<Integer, Integer> cem = cec.getRight();
+ HashMapLongInt cem = cec.getRight();
assertTrue(cem.size() == 8);
assertTrue(cem.size() == ce.getUnique());
assertTrue(ce.equals(new DenseEncoding(MapToFactory.create(new
int[] {0, 1, 2, 3, 4, 5, 6, 7, 7, 0}, 8))));
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/combine/CombineEncodingsUnique.java
b/src/test/java/org/apache/sysds/test/component/compress/combine/CombineEncodingsUnique.java
index bff3ba1709..9c1f64e0f9 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/combine/CombineEncodingsUnique.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/combine/CombineEncodingsUnique.java
@@ -26,7 +26,6 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
-import java.util.Map;
import java.util.Random;
import java.util.Set;
@@ -42,6 +41,7 @@ import
org.apache.sysds.runtime.compress.estim.encoding.DenseEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.compress.estim.encoding.SparseEncoding;
+import org.apache.sysds.runtime.compress.utils.HashMapLongInt;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -99,10 +99,9 @@ public class CombineEncodingsUnique {
public void combineUnique() {
try {
- Pair<IEncode, Map<Integer, Integer>> cec =
ae.combineWithMap(be);
+ Pair<IEncode, HashMapLongInt> cec =
ae.combineWithMap(be);
IEncode ce = cec.getLeft();
- Map<Integer, Integer> cem = cec.getRight();
- // LOG.error(ae + "\n" + be + "\n" + ce + "\n" + cem);
+ HashMapLongInt cem = cec.getRight();
assertEquals(cem.size(), ce.getUnique());
// check all unique values are contained.
checkContainsAllUnique(ce);
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CombineTest.java
b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CombineTest.java
index 119412fbb6..eb4bd4b41f 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CombineTest.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CombineTest.java
@@ -25,9 +25,6 @@ import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
-import java.util.HashMap;
-import java.util.Map;
-
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -44,8 +41,10 @@ import
org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
+import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
+import org.apache.sysds.runtime.compress.utils.HashMapLongInt;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.TestUtils;
import org.junit.Test;
@@ -78,8 +77,8 @@ public class CombineTest {
IDictionary a = Dictionary.create(new double[] {1.2});
IDictionary b = Dictionary.create(new double[] {1.4});
- Map<Integer, Integer> filter = new HashMap<>();
- filter.put(0, 0);
+ HashMapLongInt filter = new HashMapLongInt(3);
+ filter.putIfAbsent(0, 0);
IDictionary c =
DictionaryFactory.combineFullDictionaries(a, 1, b, 1, filter);
assertEquals(c.getValue(0, 0, 2), 1.2, 0.0);
@@ -97,8 +96,7 @@ public class CombineTest {
IDictionary a = Dictionary.create(new double[] {1.2});
IDictionary b = Dictionary.create(new double[] {1.4});
- Map<Integer, Integer> filter = new HashMap<>();
- // filter.put(0, 0);
+ HashMapLongInt filter = new HashMapLongInt(3);
IDictionary c =
DictionaryFactory.combineFullDictionaries(a, 1, b, 1, filter);
assertEquals(c, null);
@@ -175,8 +173,8 @@ public class CombineTest {
try {
IDictionary a = Dictionary.create(new double[] {1.2,
1.3});
IDictionary b = Dictionary.create(new double[] {1.4,
1.5});
- Map<Integer, Integer> filter = new HashMap<>();
- filter.put(0,0);
+ HashMapLongInt filter = new HashMapLongInt(3);
+ filter.putIfAbsent(0, 0);
IDictionary c =
DictionaryFactory.combineFullDictionaries(a, 1, b, 1, filter);
@@ -196,14 +194,13 @@ public class CombineTest {
}
}
-
@Test
public void twoBothSidesFilter2() {
try {
IDictionary a = Dictionary.create(new double[] {1.2,
1.3});
IDictionary b = Dictionary.create(new double[] {1.4,
1.5});
- Map<Integer, Integer> filter = new HashMap<>();
- filter.put(3,0);
+ HashMapLongInt filter = new HashMapLongInt(3);
+ filter.putIfAbsent(3, 0);
IDictionary c =
DictionaryFactory.combineFullDictionaries(a, 1, b, 1, filter);
@@ -223,15 +220,14 @@ public class CombineTest {
}
}
-
@Test
public void twoBothSidesFilter3() {
try {
IDictionary a = Dictionary.create(new double[] {1.2,
1.3});
IDictionary b = Dictionary.create(new double[] {1.4,
1.5});
- Map<Integer, Integer> filter = new HashMap<>();
- filter.put(3,0);
- filter.put(1,1);
+ HashMapLongInt filter = new HashMapLongInt(3);
+ filter.putIfAbsent(3, 0);
+ filter.putIfAbsent(1, 1);
IDictionary c =
DictionaryFactory.combineFullDictionaries(a, 1, b, 1, filter);
@@ -259,7 +255,7 @@ public class CombineTest {
double[] ad = new double[] {0};
double[] bd = new double[] {0};
- IDictionary c = DictionaryFactory.combineSDC(a, ad, b,
bd);
+ IDictionary c = DictionaryFactory.combineSDCNoFilter(a,
ad, b, bd);
MatrixBlock ret = c.getMBDict(2).getMatrixBlock();
MatrixBlock exp = new MatrixBlock(4, 2, new double[]
{0, 0, 3, 0, 0, 4, 3, 4});
@@ -279,7 +275,7 @@ public class CombineTest {
double[] ad = new double[] {0};
double[] bd = new double[] {0, 0};
- IDictionary c = DictionaryFactory.combineSDC(a, ad, b,
bd);
+ IDictionary c = DictionaryFactory.combineSDCNoFilter(a,
ad, b, bd);
MatrixBlock ret = c.getMBDict(2).getMatrixBlock();
MatrixBlock exp = new MatrixBlock(4, 3, new double[]
{0, 0, 0, 3, 0, 0, 0, 4, 4, 3, 4, 4});
@@ -299,7 +295,7 @@ public class CombineTest {
double[] ad = new double[] {1};
double[] bd = new double[] {2};
- IDictionary c = DictionaryFactory.combineSDC(a, ad, b,
bd);
+ IDictionary c = DictionaryFactory.combineSDCNoFilter(a,
ad, b, bd);
MatrixBlock ret = c.getMBDict(2).getMatrixBlock();
MatrixBlock exp = new MatrixBlock(4, 2, new double[] {//
@@ -323,7 +319,7 @@ public class CombineTest {
double[] ad = new double[] {0, 1};
double[] bd = new double[] {0, 2};
- IDictionary c = DictionaryFactory.combineSDC(a, ad, b,
bd);
+ IDictionary c = DictionaryFactory.combineSDCNoFilter(a,
ad, b, bd);
MatrixBlock ret = c.getMBDict(2).getMatrixBlock();
MatrixBlock exp = new MatrixBlock(4, 4, new double[] {//
@@ -347,7 +343,7 @@ public class CombineTest {
double[] ad = new double[] {0, 1};
double[] bd = new double[] {0, 2};
- IDictionary c = DictionaryFactory.combineSDC(a, ad, b,
bd);
+ IDictionary c = DictionaryFactory.combineSDCNoFilter(a,
ad, b, bd);
MatrixBlock ret = c.getMBDict(2).getMatrixBlock();
MatrixBlock exp = new MatrixBlock(6, 4, new double[] {//
@@ -373,7 +369,7 @@ public class CombineTest {
double[] ad = new double[] {0, 1};
double[] bd = new double[] {0, 2};
- IDictionary c = DictionaryFactory.combineSDC(a, ad, b,
bd);
+ IDictionary c = DictionaryFactory.combineSDCNoFilter(a,
ad, b, bd);
MatrixBlock ret = c.getMBDict(2).getMatrixBlock();
MatrixBlock exp = new MatrixBlock(9, 4, new double[] {//
@@ -532,58 +528,12 @@ public class CombineTest {
DictionaryFactory.combineDictionariesSparse(m, s);
}
- // @Test
- // public void sparseSparseConst1() {
- // try {
- // IDictionary a = Dictionary.create(new double[] {3, 2,
7, 8});
- // // IDictionary b = Dictionary.create(new double[] {4,
4, 9, 5});
-
- // double[] bd = new double[] {0, 2};
-
- // IDictionary c =
DictionaryFactory.combineSparseConstSparseRet(a, 2, bd);
- // MatrixBlock ret = c.getMBDict(2).getMatrixBlock();
-
- // MatrixBlock exp = new MatrixBlock(2, 4, new double[] {//
- // 3, 2, 0, 2, //
- // 7, 8, 0, 2,});
- // TestUtils.compareMatricesBitAvgDistance(ret, exp, 0, 0);
- // }
- // catch(Exception e) {
- // e.printStackTrace();
- // fail(e.getMessage());
- // }
- // }
-
- // @Test
- // public void sparseSparseConst2() {
- // try {
- // IDictionary a = Dictionary.create(new double[] {3, 2,
7, 8});
- // // IDictionary b = Dictionary.create(new double[] {4,
4, 9, 5});
-
- // double[] bd = new double[] {0, 2};
-
- // IDictionary c =
DictionaryFactory.combineSparseConstSparseRet(a, 1, bd);
- // MatrixBlock ret = c.getMBDict(2).getMatrixBlock();
-
- // MatrixBlock exp = new MatrixBlock(4, 3, new double[] {//
- // 3, 0, 2, //
- // 2, 0, 2, //
- // 7, 0, 2, //
- // 8, 0, 2,});
- // TestUtils.compareMatricesBitAvgDistance(exp, ret, 0, 0);
- // }
- // catch(Exception e) {
- // e.printStackTrace();
- // fail(e.getMessage());
- // }
- // }
-
@Test
public void testEmpty() {
try {
IDictionary d = Dictionary.create(new double[] {3, 2,
7, 8});
- AColGroup a =
ColGroupDDC.create(ColIndexFactory.create(2), d, MapToFactory.create(10, 2),
null);
- ColGroupEmpty b = new
ColGroupEmpty(ColIndexFactory.create(4));
+ AColGroup a =
ColGroupDDC.create(ColIndexFactory.createI(1, 2), d, MapToFactory.create(10,
2), null);
+ ColGroupEmpty b = new
ColGroupEmpty(ColIndexFactory.createI(3, 4, 5, 6));
IDictionary c =
DictionaryFactory.combineDictionaries((AColGroupCompressed) a,
(AColGroupCompressed) b);
MatrixBlock ret = c.getMBDict(2).getMatrixBlock();
@@ -603,9 +553,9 @@ public class CombineTest {
public void combineDictionariesSparse1() {
try {
IDictionary d = Dictionary.create(new double[] {3, 2,
7, 8});
- AColGroup a =
ColGroupSDC.create(ColIndexFactory.create(2), 500, d, new double[] {1, 2},
+ AColGroup a =
ColGroupSDC.create(ColIndexFactory.createI(1, 2), 500, d, new double[] {1, 2},
OffsetFactory.createOffset(new int[] {3, 4}),
MapToFactory.create(10, 2), null);
- ColGroupEmpty b = new
ColGroupEmpty(ColIndexFactory.create(4));
+ ColGroupEmpty b = new
ColGroupEmpty(ColIndexFactory.createI(3, 4, 5, 6));
IDictionary c =
DictionaryFactory.combineDictionariesSparse((AColGroupCompressed) a,
(AColGroupCompressed) b);
MatrixBlock ret = c.getMBDict(2).getMatrixBlock();
@@ -624,17 +574,19 @@ public class CombineTest {
@Test
public void combineDictionariesSparse2() {
try {
- IDictionary d = Dictionary.create(new double[] {3, 2,
7, 8});
- AColGroup b =
ColGroupSDC.create(ColIndexFactory.create(2), 500, d, new double[] {1, 2},
+ IDictionary d = Dictionary.create(new double[] {//
+ 3, 2, //
+ 7, 8});
+ AColGroup a =
ColGroupSDC.create(ColIndexFactory.createI(1, 2), 500, d, new double[] {1, 2},
OffsetFactory.createOffset(new int[] {3, 4}),
MapToFactory.create(10, 2), null);
- ColGroupEmpty a = new
ColGroupEmpty(ColIndexFactory.create(4));
+ ColGroupEmpty b = new
ColGroupEmpty(ColIndexFactory.createI(3, 4, 5, 6));
IDictionary c =
DictionaryFactory.combineDictionariesSparse((AColGroupCompressed) a,
(AColGroupCompressed) b);
MatrixBlock ret = c.getMBDict(2).getMatrixBlock();
MatrixBlock exp = new MatrixBlock(2, 6, new double[] {//
- 0, 0, 0, 0, 3, 2, //
- 0, 0, 0, 0, 7, 8,});
+ 3, 2, 0, 0, 0, 0, //
+ 7, 8, 0, 0, 0, 0,});
TestUtils.compareMatricesBitAvgDistance(ret, exp, 0, 0);
}
catch(Exception e) {
@@ -647,10 +599,10 @@ public class CombineTest {
public void combineMockingEmpty() {
IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4});
double[] ade = new double[] {0};
- AColGroupCompressed a = mockSDC(ad, ade);
- AColGroupCompressed b = mockSDC(ad, ade);
+ AColGroupCompressed a = mockSDC(ad, ade,
ColIndexFactory.create(1));
+ AColGroupCompressed b = mockSDC(ad, ade,
ColIndexFactory.create(2));
- Map<Integer, Integer> m = new HashMap<>();
+ HashMapLongInt m = new HashMapLongInt(10);
IDictionary red = DictionaryFactory.combineDictionaries(a, b,
m);
assertEquals(red.getNumberOfValues(2), 0);
@@ -658,82 +610,87 @@ public class CombineTest {
@Test
public void combineMockingDefault() {
- IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4});
- double[] ade = new double[] {0};
- AColGroupCompressed a = mockSDC(ad, ade);
- AColGroupCompressed b = mockSDC(ad, ade);
-
- Map<Integer, Integer> m = new HashMap<>();
- m.put(0, 0);
- IDictionary red = DictionaryFactory.combineDictionaries(a, b,
m);
-
- assertEquals(red.getNumberOfValues(2), 1);
- assertEquals(red, Dictionary.createNoCheck(new double[] {0,
0}));
+ try {
+ IDictionary ad = Dictionary.create(new double[] {1, 2,
3, 4});
+ double[] ade = new double[] {0};
+ AColGroupCompressed a = mockSDC(ad, ade,
ColIndexFactory.create(1));
+ AColGroupCompressed b = mockSDC(ad, ade,
ColIndexFactory.create(2));
+ HashMapLongInt m = new HashMapLongInt(10);
+ m.putIfAbsent(0, 0);
+ IDictionary red =
DictionaryFactory.combineDictionaries(a, b, m);
+ assertEquals(red.getNumberOfValues(2), 1);
+ assertEquals(Dictionary.createNoCheck(new double[] {0,
0}), red);
+ assertEquals(red, Dictionary.createNoCheck(new double[]
{0, 0}));
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
}
@Test
public void combineMockingFirstValue() {
IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4});
double[] ade = new double[] {0};
- AColGroupCompressed a = mockSDC(ad, ade);
- AColGroupCompressed b = mockSDC(ad, ade);
+ AColGroupCompressed a = mockSDC(ad, ade,
ColIndexFactory.create(1));
+ AColGroupCompressed b = mockSDC(ad, ade,
ColIndexFactory.create(2));
- Map<Integer, Integer> m = new HashMap<>();
- m.put(1, 0);
+ HashMapLongInt m = new HashMapLongInt(10);
+ m.putIfAbsent(1, 0);
IDictionary red = DictionaryFactory.combineDictionaries(a, b,
m);
assertEquals(red.getNumberOfValues(2), 1);
- assertEquals(red, Dictionary.create(new double[] {1, 0}));
+ assertEquals(red, Dictionary.create(new double[] {0, 1}));
}
@Test
public void combineMockingFirstAndDefault() {
IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4});
double[] ade = new double[] {0};
- AColGroupCompressed a = mockSDC(ad, ade);
- AColGroupCompressed b = mockSDC(ad, ade);
+ AColGroupCompressed a = mockSDC(ad, ade,
ColIndexFactory.create(1));
+ AColGroupCompressed b = mockSDC(ad, ade,
ColIndexFactory.create(2));
- Map<Integer, Integer> m = new HashMap<>();
- m.put(1, 0);
- m.put(0, 1);
+ HashMapLongInt m = new HashMapLongInt(10);
+ m.putIfAbsent(1, 0);
+ m.putIfAbsent(0, 1);
IDictionary red = DictionaryFactory.combineDictionaries(a, b,
m);
assertEquals(red.getNumberOfValues(2), 2);
- assertEquals(red, Dictionary.create(new double[] {1, 0, 0, 0}));
+ assertEquals(red, Dictionary.create(new double[] {0, 1, 0, 0}));
}
@Test
public void combineMockingMixed() {
IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4});
double[] ade = new double[] {0};
- AColGroupCompressed a = mockSDC(ad, ade);
- AColGroupCompressed b = mockSDC(ad, ade);
+ AColGroupCompressed a = mockSDC(ad, ade,
ColIndexFactory.create(1));
+ AColGroupCompressed b = mockSDC(ad, ade,
ColIndexFactory.create(2));
- Map<Integer, Integer> m = new HashMap<>();
- m.put(1, 0);
- m.put(0, 1);
- m.put(5, 2);
+ HashMapLongInt m = new HashMapLongInt(10);
+ m.putIfAbsent(1, 0);
+ m.putIfAbsent(0, 1);
+ m.putIfAbsent(5, 2);
IDictionary red = DictionaryFactory.combineDictionaries(a, b,
m);
assertEquals(red.getNumberOfValues(2), 3);
- assertEquals(Dictionary.create(new double[] {1, 0, 0, 0, 0,
1}), red);
+ assertEquals(Dictionary.create(new double[] {0, 1, 0, 0, 1,
0}), red);
}
@Test
public void combineMockingMixed2() {
IDictionary ad = Dictionary.create(new double[] {1, 2, 3, 4});
double[] ade = new double[] {0};
- AColGroupCompressed a = mockSDC(ad, ade);
- AColGroupCompressed b = mockSDC(ad, ade);
+ AColGroupCompressed a = mockSDC(ad, ade,
ColIndexFactory.create(1));
+ AColGroupCompressed b = mockSDC(ad, ade,
ColIndexFactory.create(2));
- Map<Integer, Integer> m = new HashMap<>();
- m.put(1, 0);
- m.put(0, 1);
- m.put(10, 2);
+ HashMapLongInt m = new HashMapLongInt(10);
+ m.putIfAbsent(1, 0);
+ m.putIfAbsent(0, 1);
+ m.putIfAbsent(10, 2);
IDictionary red = DictionaryFactory.combineDictionaries(a, b,
m);
assertEquals(red.getNumberOfValues(2), 3);
- assertEquals(Dictionary.create(new double[] {1, 0, 0, 0, 0,
2}), red);
+ assertEquals(Dictionary.create(new double[] {0, 1, 0, 0, 2,
0}), red);
}
@Test
@@ -742,10 +699,10 @@ public class CombineTest {
IDictionary ad = Dictionary.create(new double[] {1, 2,
3, 4});
double[] ade = new double[] {0};
- AColGroupCompressed a = mockDDC(ad, 1);
- AColGroupCompressed b = mockSDC(ad, ade);
+ AColGroupCompressed a = mockDDC(ad,
ColIndexFactory.create(1));
+ AColGroupCompressed b = mockSDC(ad, ade,
ColIndexFactory.create(2));
- Map<Integer, Integer> m = new HashMap<>();
+ HashMapLongInt m = new HashMapLongInt(10);
IDictionary red =
DictionaryFactory.combineDictionaries(a, b, m);
assertEquals(0, red.getNumberOfValues(2));
@@ -763,14 +720,14 @@ public class CombineTest {
IDictionary ad = Dictionary.create(new double[] {1, 2,
3, 4});
double[] ade = new double[] {0};
- AColGroupCompressed a = mockDDC(ad, 1);
- AColGroupCompressed b = mockSDC(ad, ade);
+ AColGroupCompressed a = mockDDC(ad,
ColIndexFactory.create(1));
+ AColGroupCompressed b = mockSDC(ad, ade,
ColIndexFactory.create(2));
- Map<Integer, Integer> m = new HashMap<>();
- m.put(0, 0);
+ HashMapLongInt m = new HashMapLongInt(10);
+ m.putIfAbsent(0, 0);
IDictionary red =
DictionaryFactory.combineDictionaries(a, b, m);
assertEquals(1, red.getNumberOfValues(2));
- assertEquals(Dictionary.createNoCheck(new double[] {1,
0}), red);
+ assertEquals(Dictionary.createNoCheck(new double[] {0,
1}), red);
}
catch(Exception e) {
e.printStackTrace();
@@ -784,16 +741,16 @@ public class CombineTest {
IDictionary ad = Dictionary.create(new double[] {1, 2,
3, 4});
double[] ade = new double[] {0};
- AColGroupCompressed a = mockDDC(ad, 1);
- AColGroupCompressed b = mockSDC(ad, ade);
+ AColGroupCompressed a = mockDDC(ad,
ColIndexFactory.create(1));
+ AColGroupCompressed b = mockSDC(ad, ade,
ColIndexFactory.create(2));
- Map<Integer, Integer> m = new HashMap<>();
- m.put(0, 1);
- m.put(1, 0);
+ HashMapLongInt m = new HashMapLongInt(10);
+ m.putIfAbsent(0, 1);
+ m.putIfAbsent(1, 0);
IDictionary red =
DictionaryFactory.combineDictionaries(a, b, m);
assertEquals(2, red.getNumberOfValues(2));
- assertEquals(Dictionary.createNoCheck(new double[] {2,
0, 1, 0}), red);
+ assertEquals(Dictionary.createNoCheck(new double[] {0,
2, 0, 1}), red);
}
catch(Exception e) {
e.printStackTrace();
@@ -807,17 +764,17 @@ public class CombineTest {
IDictionary ad = Dictionary.create(new double[] {1, 2,
3, 4});
double[] ade = new double[] {0};
- AColGroupCompressed a = mockDDC(ad, 1);
- AColGroupCompressed b = mockSDC(ad, ade);
+ AColGroupCompressed a = mockDDC(ad,
ColIndexFactory.create(1));
+ AColGroupCompressed b = mockSDC(ad, ade,
ColIndexFactory.create(2));
- Map<Integer, Integer> m = new HashMap<>();
- m.put(0, 1);
- m.put(1, 0);
- m.put(4, 2);
+ HashMapLongInt m = new HashMapLongInt(10);
+ m.putIfAbsent(0, 1);
+ m.putIfAbsent(1, 0);
+ m.putIfAbsent(4, 2);
IDictionary red =
DictionaryFactory.combineDictionaries(a, b, m);
assertEquals(3, red.getNumberOfValues(2));
- assertEquals(Dictionary.createNoCheck(new double[] {2,
0, 1, 0, 1, 1}), red);
+ assertEquals(Dictionary.createNoCheck(new double[] {0,
2, 0, 1, 1, 1}), red);
}
catch(Exception e) {
e.printStackTrace();
@@ -831,18 +788,98 @@ public class CombineTest {
IDictionary ad = Dictionary.create(new double[] {1, 2,
3, 4});
double[] ade = new double[] {0};
- AColGroupCompressed a = mockDDC(ad, 1);
- AColGroupCompressed b = mockSDC(ad, ade);
-
- Map<Integer, Integer> m = new HashMap<>();
- m.put(0, 1);
- m.put(1, 0);
- m.put(5, 2);
- m.put(4, 3);
+ AColGroupCompressed a = mockDDC(ad,
ColIndexFactory.create(1));
+ AColGroupCompressed b = mockSDC(ad, ade,
ColIndexFactory.create(2));
+
+ HashMapLongInt m = new HashMapLongInt(10);
+ m.putIfAbsent(0, 1);
+ m.putIfAbsent(1, 0);
+ m.putIfAbsent(5, 2);
+ m.putIfAbsent(4, 3);
IDictionary red =
DictionaryFactory.combineDictionaries(a, b, m);
assertEquals(4, red.getNumberOfValues(2));
- assertEquals(Dictionary.createNoCheck(new double[] {2,
0, 1, 0, 2, 1, 1, 1}), red);
+ assertEquals(Dictionary.createNoCheck(new double[] {0,
2, 0, 1, 1, 2, 1, 1}), red);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void combineFailCase1() {
+ try {
+
+ IDictionary ad = Dictionary.create(new double[] {3, 1,
2});
+ IDictionary ab = Dictionary.create(new double[] {2, 3});
+ double[] ade = new double[] {1};
+ AColGroupCompressed a = mockDDC(ad,
ColIndexFactory.create(1));
+ AColGroupCompressed b = mockSDC(ab, ade,
ColIndexFactory.create(2));
+
+ HashMapLongInt m = new HashMapLongInt(10);
+ // 0=8, 1=7, 2=5, 3=0, 4=6, 5=2, 6=4, 7=1, 8=3
+ m.putIfAbsent(0, 8);
+ m.putIfAbsent(1, 7);
+ m.putIfAbsent(2, 5);
+ m.putIfAbsent(3, 0);
+ m.putIfAbsent(4, 6);
+ m.putIfAbsent(5, 2);
+ m.putIfAbsent(6, 4);
+ m.putIfAbsent(7, 1);
+ m.putIfAbsent(8, 3);
+ IDictionary red =
DictionaryFactory.combineDictionaries(a, b, m);
+
+ assertEquals(9, red.getNumberOfValues(2));
+ assertEquals(Dictionary.createNoCheck(//
+ new double[] {//
+ 2, 3, //
+ 3, 1, //
+ 2, 2, //
+ 3, 2, //
+ 3, 3, //
+ 1, 2, //
+ 2, 1, //
+ 1, 1, //
+ 1, 3,//
+ }), red);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void combineFailCase2() {
+ try {
+
+ IDictionary ad = Dictionary.create(new double[] {3, 1,
2});
+ IDictionary ab = Dictionary.create(new double[] {2, 3});
+ double[] ade = new double[] {1};
+ AColGroupCompressed a = mockDDC(ad,
ColIndexFactory.createI(1));
+ AColGroupCompressed b = mockSDC(ab, ade,
ColIndexFactory.createI(2));
+
+ HashMapLongInt m = new HashMapLongInt(10);
+ for(int i = 0; i < 9; i++) {
+ m.putIfAbsent(i, i);
+ }
+
+ IDictionary red =
DictionaryFactory.combineDictionaries(a, b, m);
+
+ assertEquals(9, red.getNumberOfValues(2));
+ assertEquals(Dictionary.createNoCheck(//
+ new double[] {//
+ 3, 1, //
+ 1, 1, //
+ 2, 1, //
+ 3, 2, //
+ 1, 2, //
+ 2, 2, //
+ 3, 3, //
+ 1, 3, //
+ 2, 3,//
+ }), red);
}
catch(Exception e) {
e.printStackTrace();
@@ -850,20 +887,87 @@ public class CombineTest {
}
}
- private ASDC mockSDC(IDictionary ad, double[] def) {
+ @Test
+ public void testCombineSDC() {
+ IDictionary ad = Dictionary.create(new double[] {2, 3});
+ IDictionary ab = Dictionary.create(new double[] {1, 2});
+ double[] ade = new double[] {1.0};
+ double[] abe = new double[] {3.0};
+ AColGroupCompressed a = mockSDC(ad, ade,
ColIndexFactory.createI(1));
+ AColGroupCompressed b = mockSDC(ab, abe,
ColIndexFactory.createI(2));
+ HashMapLongInt m = new HashMapLongInt(10);
+ m.putIfAbsent(0, 8);
+ m.putIfAbsent(1, 0);
+ m.putIfAbsent(2, 4);
+ m.putIfAbsent(3, 7);
+ m.putIfAbsent(4, 6);
+ m.putIfAbsent(5, 1);
+ m.putIfAbsent(6, 5);
+ m.putIfAbsent(7, 2);
+ m.putIfAbsent(8, 3);
+
+ IDictionary red = DictionaryFactory.combineDictionaries(a, b,
m);
+
+ assertEquals(9, red.getNumberOfValues(2));
+ assertEquals(Dictionary.createNoCheck(//
+ new double[] {//
+ 2, 3, //
+ 3, 1, //
+ 2, 2, //
+ 3, 2, //
+ 3, 3, //
+ 1, 2, //
+ 2, 1, //
+ 1, 1, //
+ 1, 3,//
+ }), red);
+ }
+
+ @Test
+ public void testCombineSDCRange() {
+ IDictionary ad = Dictionary.create(new double[] {2, 3});
+ IDictionary ab = Dictionary.create(new double[] {1, 2});
+ double[] ade = new double[] {1.0};
+ double[] abe = new double[] {3.0};
+ AColGroupCompressed a = mockSDC(ad, ade,
ColIndexFactory.createI(1));
+ AColGroupCompressed b = mockSDC(ab, abe,
ColIndexFactory.createI(2));
+ HashMapLongInt m = new HashMapLongInt(10);
+ for(int i = 0; i < 9; i++) {
+ m.putIfAbsent(i, i);
+ }
+ IDictionary red = DictionaryFactory.combineDictionaries(a, b,
m);
+
+ assertEquals(9, red.getNumberOfValues(2));
+ assertEquals(Dictionary.createNoCheck(//
+ new double[] {//
+ 1, 3, //
+ 2, 3, //
+ 3, 3, //
+ 1, 1, //
+ 2, 1, //
+ 3, 1, //
+ 1, 2, //
+ 2, 2, //
+ 3, 2,//
+ }), red);
+ }
+
+ private ASDC mockSDC(IDictionary ad, double[] def, IColIndex c) {
ASDC a = mock(ASDC.class);
when(a.getCompType()).thenReturn(CompressionType.SDC);
when(a.getDictionary()).thenReturn(ad);
when(a.getDefaultTuple()).thenReturn(def);
when(a.getNumCols()).thenReturn(def.length);
+ when(a.getColIndices()).thenReturn(c);
return a;
}
- private ColGroupDDC mockDDC(IDictionary ad, int nCol) {
+ private ColGroupDDC mockDDC(IDictionary ad, IColIndex c) {
ColGroupDDC a = mock(ColGroupDDC.class);
when(a.getCompType()).thenReturn(CompressionType.DDC);
when(a.getDictionary()).thenReturn(ad);
- when(a.getNumCols()).thenReturn(nCol);
+ when(a.getNumCols()).thenReturn(c.size());
+ when(a.getColIndices()).thenReturn(c);
return a;
}
}