This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new e96877eea2 [SYSTEMDS-3489] CLA Compress NaN
e96877eea2 is described below
commit e96877eea2d36dd92cdc79146da5774be403ccc8
Author: baunsgaard <[email protected]>
AuthorDate: Mon Jan 23 14:06:30 2023 +0100
[SYSTEMDS-3489] CLA Compress NaN
Currently, the compression framework refuses to compress a MatrixBlock
that contains NaN values. Although compressed blocks do not support
operations on matrices containing NaN, this should not prevent us from
compressing such a block.
This commit enables compressing a MatrixBlock that contains NaN values.
During compression all NaN values are replaced with 0, and a warning is
logged the first time a NaN is encountered.
This commit also fixes edge cases when slicing compressed blocks, where
slicing an SDC column group could return a slice that should have been
an SDCSingle group.
Closes #1771
---
.../runtime/compress/colgroup/ColGroupFactory.java | 50 +++++-
.../runtime/compress/colgroup/ColGroupSDC.java | 10 +-
.../runtime/compress/colgroup/ColGroupSDCFOR.java | 13 +-
.../compress/colgroup/ColGroupSDCSingle.java | 48 ++----
.../compress/colgroup/ColGroupSDCSingleZeros.java | 8 +-
.../compress/colgroup/ColGroupSDCZeros.java | 14 +-
.../compress/colgroup/mapping/AMapToData.java | 1 -
.../compress/colgroup/mapping/MapToBit.java | 12 +-
.../runtime/compress/colgroup/offset/AOffset.java | 45 +++++
.../compress/colgroup/offset/OffsetByte.java | 2 +-
.../compress/estim/encoding/EncodingFactory.java | 78 ++++++---
.../runtime/compress/estim/encoding/IEncode.java | 5 +
.../compress/estim/encoding/SparseEncoding.java | 4 -
.../compress/readers/ReaderColumnSelection.java | 9 +
.../ReaderColumnSelectionDenseMultiBlock.java | 12 +-
...erColumnSelectionDenseMultiBlockTransposed.java | 13 +-
.../ReaderColumnSelectionDenseSingleBlock.java | 11 +-
...rColumnSelectionDenseSingleBlockTransposed.java | 15 +-
.../readers/ReaderColumnSelectionSparse.java | 14 +-
.../ReaderColumnSelectionSparseTransposed.java | 24 ++-
.../runtime/compress/utils/DoubleCountHashMap.java | 18 +-
.../component/compress/CompressedCustomTests.java | 52 ++++++
.../component/compress/colgroup/ColGroupTest.java | 13 +-
.../estim/encoding/EncodeSampleCustom.java | 53 ++++++
.../compress/offset/OffsetReverseTest.java | 146 +++++++++++++++++
.../component/compress/readers/ReadersTest.java | 182 +++++++++++++++++++++
.../readers/ReadersTestCompareReaders.java | 10 +-
27 files changed, 756 insertions(+), 106 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
index 178527d2b2..9eee973a70 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
@@ -252,7 +252,7 @@ public class ColGroupFactory {
final boolean t = cs.transposed;
// Fast path compressions
- if(ct == CompressionType.EMPTY && !t)
+ if(ct == CompressionType.EMPTY && (!t ||
isAllNanTransposed(cg)))
return new ColGroupEmpty(colIndexes);
else if(ct == CompressionType.UNCOMPRESSED) // don't construct
mapping if uncompressed
return ColGroupUncompressed.create(colIndexes, in, t);
@@ -323,7 +323,7 @@ public class ColGroupFactory {
if(map.size() == 0)
return new ColGroupEmpty(colIndexes);
-
+
ADictionary dict = DictionaryFactory.create(map);
final int nUnique = map.size();
final AMapToData resData = MapToFactory.resize(d, nUnique);
@@ -704,6 +704,52 @@ public class ColGroupFactory {
}
}
+ private boolean isAllNanTransposed(CompressedSizeInfoColGroup cg) {
+ final int[] cols = cg.getColumns();
+ return in.isInSparseFormat() ? isAllNanTransposedSparse(cols) :
isAllNanTransposedDense(cols);
+ }
+
+ private boolean isAllNanTransposedSparse(int[] cols) {
+ SparseBlock sb = in.getSparseBlock();
+ for(int c : cols){
+ if(sb.isEmpty(c))
+ continue;
+ double[] vl = sb.values(c);
+ for(double v : vl){
+ if(!Double.isNaN(v))
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private boolean isAllNanTransposedDense(int[] cols) {
+ if(in.getDenseBlock().isContiguous()){
+ double[] vals = in.getDenseBlockValues();
+ for(int c : cols){
+ int off = c *nRow;
+ for(int r = 0; r < nRow; r++ ){
+ if(!Double.isNaN(vals[off + r])){
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+ else{
+ DenseBlock db = in.getDenseBlock();
+ for(int c : cols){
+ double[] vals = db.values(c);
+ int off = db.pos(c);
+ for(int r = 0; r < nRow; r++ ){
+ if(!Double.isNaN(vals[off + r]))
+ return false;
+ }
+ }
+ return true;
+ }
+ }
+
private class CompressTask implements Callable<Object> {
private final List<CompressedSizeInfoColGroup> _groups;
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index d1c12f535c..7e928e668f 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -29,6 +29,7 @@ import
org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
+import
org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
@@ -40,6 +41,7 @@ import
org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
@@ -86,6 +88,10 @@ public class ColGroupSDC extends ASDC implements
AMapToDataGroup {
return ColGroupSDCSingle.create(colIndices, numRows,
dict, defaultTuple, offsets, null);
else if(allZero)
return ColGroupSDCZeros.create(colIndices, numRows,
dict, offsets, data, cachedCounts);
+ else if(data.getUnique() == 1){
+ MatrixBlock mb =
dict.getMBDict(colIndices.length).getMatrixBlock().slice(0,0);
+ return ColGroupSDCSingle.create(colIndices, numRows,
MatrixBlockDictionary.create(mb), defaultTuple, offsets, null);
+ }
else
return new ColGroupSDC(colIndices, numRows, dict,
defaultTuple, offsets, data, cachedCounts);
}
@@ -559,11 +565,13 @@ public class ColGroupSDC extends ASDC implements
AMapToDataGroup {
@Override
public AColGroup sliceRows(int rl, int ru) {
+ if(ru > _numRows)
+ throw new DMLRuntimeException("Invalid row range");
OffsetSliceInfo off = _indexes.slice(rl, ru);
if(off.lIndex == -1)
return ColGroupConst.create(_colIndexes,
Dictionary.create(_defaultTuple));
AMapToData newData = _data.slice(off.lIndex, off.uIndex);
- return new ColGroupSDC(_colIndexes, _numRows, _dict,
_defaultTuple, off.offsetSlice, newData, null);
+ return create(_colIndexes, ru - rl, _dict, _defaultTuple,
off.offsetSlice, newData, null);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
index 276d5703aa..b306c09862 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
@@ -70,7 +70,10 @@ public class ColGroupSDCFOR extends ASDC implements
AMapToDataGroup {
private ColGroupSDCFOR(int[] colIndices, int numRows, ADictionary dict,
AOffset indexes, AMapToData data,
int[] cachedCounts, double[] reference) {
super(colIndices, numRows, dict, indexes, cachedCounts);
- if(data.getUnique() !=
dict.getNumberOfValues(colIndices.length))
+ // allow for now 1 data unique.
+ if(data.getUnique() == 1)
+ LOG.warn("SDCFor unique is 1, indicate it should have
been SDCSingle please add support");
+ else if(data.getUnique() !=
dict.getNumberOfValues(colIndices.length))
throw new DMLCompressionException("Invalid construction
of SDCZero group");
_data = data;
_reference = reference;
@@ -85,8 +88,10 @@ public class ColGroupSDCFOR extends ASDC implements
AMapToDataGroup {
return ColGroupConst.create(colIndexes, reference);
else if(allZero)
return ColGroupSDCZeros.create(colIndexes, numRows,
dict, offsets, data, cachedCounts);
- else
- return new ColGroupSDCFOR(colIndexes, numRows, dict,
offsets, data, cachedCounts, reference);
+ // else if(data.getUnique() == 1){
+ // TODO add support for changing to SDCSINGLE.
+ // }
+ return new ColGroupSDCFOR(colIndexes, numRows, dict, offsets,
data, cachedCounts, reference);
}
public static AColGroup sparsifyFOR(ColGroupSDC g) {
@@ -437,7 +442,7 @@ public class ColGroupSDCFOR extends ASDC implements
AMapToDataGroup {
if(off.lIndex == -1)
return ColGroupConst.create(_colIndexes,
Dictionary.create(_reference));
AMapToData newData = _data.slice(off.lIndex, off.uIndex);
- return new ColGroupSDCFOR(_colIndexes, _numRows, _dict,
off.offsetSlice, newData, null, _reference);
+ return create(_colIndexes, ru - rl, _dict, off.offsetSlice,
newData, null, _reference);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
index 60b567996a..23a38bfbe4 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
@@ -70,48 +70,24 @@ public class ColGroupSDCSingle extends ASDC {
final boolean allZero = ColGroupUtils.allZero(defaultTuple);
if(dict == null && allZero)
return new ColGroupEmpty(colIndexes);
- else if(dict == null) {
- if(offsets.getSize() * 2 > numRows + 2) {
- AOffset rev = reverse(numRows, offsets);
- return
ColGroupSDCSingleZeros.create(colIndexes, numRows,
Dictionary.create(defaultTuple), rev,
- cachedCounts);
- }
- else
- return new ColGroupSDCSingle(colIndexes,
numRows, null, defaultTuple, offsets, cachedCounts);
+ else if(dict == null && offsets.getSize() * 2 > numRows + 2) {
+ AOffset rev = AOffset.reverse(numRows, offsets);
+ return ColGroupSDCSingleZeros.create(colIndexes,
numRows, Dictionary.create(defaultTuple), rev, cachedCounts);
}
+ else if(dict == null)
+ return new ColGroupSDCSingle(colIndexes, numRows, null,
defaultTuple, offsets, cachedCounts);
else if(allZero)
return ColGroupSDCSingleZeros.create(colIndexes,
numRows, dict, offsets, cachedCounts);
- else {
- if(offsets.getSize() * 2.0 > numRows + 2.0) {
- AOffset rev = reverse(numRows, offsets);
- return new ColGroupSDCSingle(colIndexes,
numRows, null, dict.getValues(), rev, null);
- }
- else
- return new ColGroupSDCSingle(colIndexes,
numRows, dict, defaultTuple, offsets, cachedCounts);
+ else if(offsets.getSize() * 2 > numRows + 2) {
+ AOffset rev = AOffset.reverse(numRows, offsets);
+ return new ColGroupSDCSingle(colIndexes, numRows,
Dictionary.create(defaultTuple), dict.getValues(), rev, null);
}
- }
-
- private static AOffset reverse(int numRows, AOffset offsets) {
- int[] newOff = new int[numRows - offsets.getSize()];
- final AOffsetIterator it = offsets.getOffsetIterator();
- final int last = offsets.getOffsetToLast();
- int i = 0;
- int j = 0;
+ else
+ return new ColGroupSDCSingle(colIndexes, numRows, dict,
defaultTuple, offsets, cachedCounts);
- while(i < last) {
- if(i == it.value()) {
- i++;
- it.next();
- }
- else
- newOff[j++] = i++;
- }
- i++; // last
- while(i < numRows)
- newOff[j++] = i++;
- return OffsetFactory.createOffset(newOff);
}
+
@Override
public CompressionType getCompType() {
return CompressionType.SDC;
@@ -569,7 +545,7 @@ public class ColGroupSDCSingle extends ASDC {
OffsetSliceInfo off = _indexes.slice(rl, ru);
if(off.lIndex == -1)
return ColGroupConst.create(_colIndexes,
Dictionary.create(_defaultTuple));
- return new ColGroupSDCSingle(_colIndexes, _numRows, _dict,
_defaultTuple, off.offsetSlice, null);
+ return create(_colIndexes, ru -rl, _dict, _defaultTuple,
off.offsetSlice, null);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
index bc6756b18c..59fe6ec26d 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
@@ -29,6 +29,7 @@ import
org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
+import
org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
import
org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo;
@@ -67,6 +68,11 @@ public class ColGroupSDCSingleZeros extends ASDCZero {
int[] cachedCounts) {
if(dict == null)
return new ColGroupEmpty(colIndices);
+ else if(offsets.getSize() * 2 > numRows + 2) {
+ AOffset rev = AOffset.reverse(numRows, offsets);
+ ADictionary empty = MatrixBlockDictionary.create(new
MatrixBlock(1, colIndices.length, true));
+ return ColGroupSDCSingle.create(colIndices, numRows,
empty, dict.getValues(), rev, null);
+ }
else
return new ColGroupSDCSingleZeros(colIndices, numRows,
dict, offsets, cachedCounts);
}
@@ -802,7 +808,7 @@ public class ColGroupSDCSingleZeros extends ASDCZero {
OffsetSliceInfo off = _indexes.slice(rl, ru);
if(off.lIndex == -1)
return null;
- return new ColGroupSDCSingleZeros(_colIndexes, _numRows, _dict,
off.offsetSlice, null);
+ return create(_colIndexes, ru - rl, _dict, off.offsetSlice,
null);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
index ae2ecd334e..a678f8cedc 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
@@ -24,10 +24,12 @@ import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
+import
org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
@@ -56,7 +58,7 @@ import
org.apache.sysds.runtime.matrix.operators.UnaryOperator;
*
* This column group is handy in cases where sparse unsafe operations is
executed on very sparse columns.
*/
-public class ColGroupSDCZeros extends ASDCZero implements AMapToDataGroup{
+public class ColGroupSDCZeros extends ASDCZero implements AMapToDataGroup {
private static final long serialVersionUID = -3703199743391937991L;
/** Pointers to row indexes in the dictionary. Note the dictionary has
one extra entry. */
@@ -75,6 +77,10 @@ public class ColGroupSDCZeros extends ASDCZero implements
AMapToDataGroup{
int[] cachedCounts) {
if(dict == null)
return new ColGroupEmpty(colIndices);
+ else if(data.getUnique() == 1){
+ MatrixBlock mb =
dict.getMBDict(colIndices.length).getMatrixBlock().slice(0,0);
+ return ColGroupSDCSingleZeros.create(colIndices,
numRows, MatrixBlockDictionary.create(mb), offsets, null);
+ }
else
return new ColGroupSDCZeros(colIndices, numRows, dict,
offsets, data, cachedCounts);
}
@@ -90,7 +96,7 @@ public class ColGroupSDCZeros extends ASDCZero implements
AMapToDataGroup{
}
@Override
- public AMapToData getMapToData(){
+ public AMapToData getMapToData() {
return _data;
}
@@ -714,11 +720,13 @@ public class ColGroupSDCZeros extends ASDCZero implements
AMapToDataGroup{
@Override
public AColGroup sliceRows(int rl, int ru) {
+ if(ru > _numRows)
+ throw new DMLRuntimeException("Invalid row range");
OffsetSliceInfo off = _indexes.slice(rl, ru);
if(off.lIndex == -1)
return null;
AMapToData newData = _data.slice(off.lIndex, off.uIndex);
- return new ColGroupSDCZeros(_colIndexes, _numRows, _dict,
off.offsetSlice, newData, null);
+ return create(_colIndexes, ru - rl, _dict, off.offsetSlice,
newData, null);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java
index 4e781cf5bc..a506734a9c 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java
@@ -816,7 +816,6 @@ public abstract class AMapToData implements Serializable {
*/
public abstract AMapToData slice(int l, int u);
-
public abstract AMapToData append(AMapToData t);
public abstract AMapToData appendN(AMapToDataGroup[] d);
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java
index a69bb0cb93..01ab2a4bb1 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java
@@ -25,6 +25,7 @@ import java.io.IOException;
import java.util.BitSet;
import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.colgroup.AMapToDataGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import
org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE;
@@ -52,8 +53,9 @@ public class MapToBit extends AMapToData {
_data = d;
_size = size;
if(_data.isEmpty()) {
- unique = 1;
- LOG.warn("Empty bit set should not happen");
+ // unique = 1;
+ throw new DMLRuntimeException("Empty BitSet should not
happen it should return MapToZero");
+ // LOG.warn("Empty bit set should not happen");
}
}
@@ -324,7 +326,11 @@ public class MapToBit extends AMapToData {
@Override
public AMapToData slice(int l, int u) {
- return new MapToBit(getUnique(), _data.get(l, u), u - l);
+ BitSet s = _data.get(l,u);
+ if(s.isEmpty())
+ return new MapToZero(u-l);
+ else
+ return new MapToBit(getUnique(), s, u - l);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
index e871e1bd27..5af9b46ce9 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
@@ -21,10 +21,12 @@ package org.apache.sysds.runtime.compress.colgroup.offset;
import java.io.DataOutput;
import java.io.IOException;
import java.io.Serializable;
+import java.util.Arrays;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AOffsetsGroup;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
@@ -523,6 +525,36 @@ public abstract class AOffset implements Serializable {
return sb.toString();
}
+ public static AOffset reverse(int numRows, AOffset offsets) {
+ if(numRows < offsets.getOffsetToLast()) {
+ throw new DMLRuntimeException("Invalid number of rows
for reverse");
+ }
+
+ int[] newOff = new int[numRows - offsets.getSize()];
+ final AOffsetIterator it = offsets.getOffsetIterator();
+ final int last = offsets.getOffsetToLast();
+ int i = 0;
+ int j = 0;
+
+ while(i < last) {
+ if(i == it.value()) {
+ i++;
+ it.next();
+ }
+ else
+ newOff[j++] = i++;
+ }
+ i++; // last
+ while(i < numRows)
+ newOff[j++] = i++;
+
+ if(j != newOff.length)
+ throw new DMLRuntimeException(
+ "Not assigned all offsets ... something must be
wrong:\n" + offsets + "\n" + Arrays.toString(newOff));
+ return OffsetFactory.createOffset(newOff);
+
+ }
+
public static final class OffsetSliceInfo {
public final int lIndex;
public final int uIndex;
@@ -533,6 +565,19 @@ public abstract class AOffset implements Serializable {
this.uIndex = u;
this.offsetSlice = off;
}
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append("sliceInfo: ");
+ sb.append(lIndex);
+ sb.append("->");
+ sb.append(uIndex);
+ sb.append(" -- ");
+ sb.append(offsetSlice);
+ return sb.toString();
+ }
+
}
protected static class OffsetCache {
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java
index f29feae572..08140a2d50 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java
@@ -127,7 +127,7 @@ public class OffsetByte extends AOffset {
}
protected OffsetSliceInfo slice(int lowOff, int highOff, int lowValue,
int highValue, int low, int high) {
- int newSize = high - low - 1;
+ int newSize = high - low +1 ;
byte[] newOffsets = Arrays.copyOfRange(offsets, lowOff,
highOff);
AOffset off = new OffsetByte(newOffsets, lowValue, highValue,
newSize, noOverHalf, noZero);
return new OffsetSliceInfo(low, high + 1, off);
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
index 2fe6d5e6e5..92a8cea937 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
@@ -22,6 +22,8 @@ package org.apache.sysds.runtime.compress.estim.encoding;
import java.util.Arrays;
import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
@@ -37,6 +39,8 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
public interface EncodingFactory {
+ public static final Log LOG =
LogFactory.getLog(EncodingFactory.class.getName());
+
/**
* Encode a list of columns together from the input matrix, as if it is
cocoded.
*
@@ -122,14 +126,18 @@ public interface EncodingFactory {
// Iteration 1, make Count HashMap.
for(int i = off; i < end; i++) // sequential access
- map.increment(vals[i]);
+ if(!Double.isNaN(vals[i]))
+ map.increment(vals[i]);
+ else
+ map.increment(0);
final int nUnique = map.size();
if(nUnique == 1)
return new ConstEncoding(m.getNumColumns());
-
- if(map.getOrDefault(0, -1) > nCol / 4) {
+ else if(nUnique == 0)
+ return new EmptyEncoding();
+ else if(map.getOrDefault(0, -1) > nCol / 4) {
map.replaceWithUIDsNoZero();
final int zeroCount = map.get(0);
final int nV = nCol - zeroCount;
@@ -141,7 +149,10 @@ public interface EncodingFactory {
for(int i = off, r = 0, di = 0; i < end; i++, r++) {
if(vals[i] != 0) {
offsets.appendValue(r);
- d.set(di++, map.get(vals[i]));
+ if(!Double.isNaN(vals[i]))
+ d.set(di++, map.get(vals[i]));
+ else
+ d.set(di++, map.get(0.0));
}
}
@@ -154,8 +165,12 @@ public interface EncodingFactory {
final AMapToData d = MapToFactory.create(nCol, nUnique);
// Iteration 2, make final map
- for(int i = off, r = 0; i < end; i++, r++)
- d.set(r, map.get(vals[i]));
+ for(int i = off, r = 0; i < end; i++, r++) {
+ if(!Double.isNaN(vals[i]))
+ d.set(r, map.get(vals[i]));
+ else
+ d.set(r, map.get(0.0));
+ }
return new DenseEncoding(d);
}
@@ -172,25 +187,30 @@ public interface EncodingFactory {
final int[] aix = sb.indexes(row);
// Iteration 1 of non zero values, make Count HashMap.
- for(int i = apos; i < alen; i++) // sequential of non zero
cells.
- map.increment(avals[i]);
+ for(int i = apos; i < alen; i++) {
+ // sequential of non zero cells.
+ if(!Double.isNaN(avals[i]))
+ map.increment(avals[i]);
- final int nUnique = map.size();
+ }
+ final int nUnique = map.size();
map.replaceWithUIDs();
final int nCol = m.getNumColumns();
- if(alen - apos > nCol / 4) { // return a dense encoding
+ if(nUnique == 0) // only if all NaN
+ return new EmptyEncoding();
+ else if(alen - apos > nCol / 4) { // return a dense encoding
// If the row was full but the overall matrix is sparse.
final int correct = (alen - apos == m.getNumColumns())
? 0 : 1;
final AMapToData d = MapToFactory.create(nCol, nUnique
+ correct);
// Since the dictionary is allocated with zero then we
exploit that here and
// only iterate through non zero entries.
for(int i = apos; i < alen; i++)
- // correction one to assign unique IDs taking
into account zero
- d.set(aix[i], map.get(avals[i]) + correct);
- // the rest is automatically set to zero.
+ if(!Double.isNaN(avals[i])) // correction one
to assign unique IDs taking into account zero
+ d.set(aix[i], map.get(avals[i]) +
correct);
+ // the rest is automatically set to zero.
return new DenseEncoding(d);
}
else { // return a sparse encoding
@@ -199,7 +219,8 @@ public interface EncodingFactory {
// Iteration 2 of non zero values, make either a
IEncode Dense or sparse map.
for(int i = apos, j = 0; i < alen; i++, j++)
- d.set(j, map.get(avals[i]));
+ if(!Double.isNaN(avals[i]))
+ d.set(j, map.get(avals[i]));
// Iteration 3 of non zero indexes, make a Offset
Encoding to know what cells are zero and not.
// not done yet
@@ -221,8 +242,10 @@ public interface EncodingFactory {
// Iteration 1, make Count HashMap.
for(int i = off; i < end; i += nCol) // jump down through rows.
- map.increment(vals[i]);
-
+ if(!Double.isNaN(vals[i]))
+ map.increment(vals[i]);
+ else
+ map.increment(0);
final int nUnique = map.size();
if(nUnique == 1)
return new ConstEncoding(m.getNumColumns());
@@ -236,7 +259,7 @@ public interface EncodingFactory {
final AMapToData d = MapToFactory.create(nV, nUnique -
1);
for(int i = off, r = 0, di = 0; i < end; i += nCol,
r++) {
- if(vals[i] != 0) {
+ if(vals[i] != 0 && !Double.isNaN(vals[i])) {
offsets.appendValue(r);
d.set(di++, map.get(vals[i]));
}
@@ -252,7 +275,10 @@ public interface EncodingFactory {
final AMapToData d = MapToFactory.create(nRow, nUnique);
// Iteration 2, make final map
for(int i = off, r = 0; i < end; i += nCol, r++)
- d.set(r, map.get(vals[i]));
+ if(!Double.isNaN(vals[i]))
+ d.set(r, map.get(vals[i]));
+ else
+ d.set(r, map.get(0));
return new DenseEncoding(d);
}
}
@@ -274,8 +300,11 @@ public interface EncodingFactory {
final int[] aix = sb.indexes(r);
final int index = Arrays.binarySearch(aix, apos, alen,
col);
if(index >= 0) {
- offsets.appendValue(r);
- map.increment(sb.values(r)[index]);
+ final double v = sb.values(r)[index];
+ if(!Double.isNaN(v)) {
+ offsets.appendValue(r);
+ map.increment(sb.values(r)[index]);
+ }
}
}
if(offsets.size() == 0)
@@ -295,8 +324,11 @@ public interface EncodingFactory {
final int[] aix = sb.indexes(r);
// Performance hit because of binary search for each
row.
final int index = Arrays.binarySearch(aix, apos, alen,
col);
- if(index >= 0)
- d.set(off++, map.get(sb.values(r)[index]));
+ if(index >= 0) {
+ final double v = sb.values(r)[index];
+ if(index >= 0 && !Double.isNaN(v))
+ d.set(off++, map.get(v));
+ }
}
// Iteration 3 of non zero indexes, make a Offset Encoding to
know what cells are zero and not.
@@ -369,7 +401,7 @@ public interface EncodingFactory {
return new SparseEncoding(d, o, nRows);
}
- public static SparseEncoding createSparse(AMapToData map, AOffset off,
int nRows){
+ public static SparseEncoding createSparse(AMapToData map, AOffset off,
int nRows) {
return new SparseEncoding(map, off, nRows);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java
index 5f15c147ac..26b2201428 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java
@@ -19,6 +19,8 @@
package org.apache.sysds.runtime.compress.estim.encoding;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
@@ -27,6 +29,9 @@ import
org.apache.sysds.runtime.compress.estim.EstimationFactors;
* column groups.
*/
public interface IEncode {
+
+ public static final Log LOG =
LogFactory.getLog(IEncode.class.getName());
+
/**
* Combine two encodings, note it should be guaranteed by the caller
that the number of unique multiplied does not
* overflow Integer.
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
index 66af6c7ec5..2b242e0ffb 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
@@ -19,8 +19,6 @@
package org.apache.sysds.runtime.compress.estim.encoding;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
@@ -33,8 +31,6 @@ import org.apache.sysds.runtime.compress.utils.IntArrayList;
/** Most common is zero encoding */
public class SparseEncoding implements IEncode {
- static final Log LOG =
LogFactory.getLog(SparseEncoding.class.getName());
-
/** A map to the distinct values contained */
protected final AMapToData map;
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java
b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java
index d02e73536b..54d7ff28cf 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java
@@ -30,6 +30,8 @@ public abstract class ReaderColumnSelection {
protected static final Log LOG =
LogFactory.getLog(ReaderColumnSelection.class.getName());
+ protected static boolean nanEncountered = false; // static: shared by all readers; unsynchronized, so concurrent readers may still warn more than once
+
protected final int[] _colIndexes;
protected final DblArray reusableReturn;
protected final double[] reusableArr;
@@ -101,4 +103,11 @@ public abstract class ReaderColumnSelection {
else if(rl >= ru)
throw new DMLCompressionException("Invalid inverse
range for reader " + rl + " to " + ru);
}
+
+ protected void warnNaN(){ // logs that a NaN was read and replaced by 0; normally only the first call logs
+ if(!nanEncountered){
+ LOG.warn("NaN value encountered, replaced by 0 in
compression, since nan is not supported");
+ nanEncountered = true;
+ }
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlock.java
b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlock.java
index 9619b08b2d..0daa79e7fa 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlock.java
@@ -37,8 +37,16 @@ public class ReaderColumnSelectionDenseMultiBlock extends
ReaderColumnSelection
_rl++;
for(int i = 0; i < _colIndexes.length; i++) {
final double v = _data.get(_rl, _colIndexes[i]);
- empty &= v == 0;
- reusableArr[i] = v;
+ boolean isNan = Double.isNaN(v);
+ if(isNan){
+ warnNaN(); // NaN is not supported: store 0 and leave 'empty' unchanged
+ reusableArr[i] = 0;
+ }
+ else{
+
+ empty &= v == 0 ;
+ reusableArr[i] = v;
+ }
}
}
return empty ? null : reusableReturn;
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlockTransposed.java
b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlockTransposed.java
index 507fa03191..1273d3fc4d 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlockTransposed.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlockTransposed.java
@@ -36,9 +36,16 @@ public class ReaderColumnSelectionDenseMultiBlockTransposed
extends ReaderColumn
while(empty && _rl < _ru) {
_rl++;
for(int i = 0; i < _colIndexes.length; i++) {
- double v = _data.get(_colIndexes[i], _rl);
- empty &= v == 0;
- reusableArr[i] = v;
+ final double v = _data.get(_colIndexes[i], _rl);
+ boolean isNan = Double.isNaN(v);
+ if(isNan){
+ warnNaN(); // NaN is not supported: store 0 and leave 'empty' unchanged
+ reusableArr[i] = 0;
+ }
+ else{
+ empty &= v == 0;
+ reusableArr[i] = v;
+ }
}
}
return empty ? null : reusableReturn;
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlock.java
b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlock.java
index 6a073a7532..c24ab5c17b 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlock.java
@@ -40,8 +40,15 @@ public class ReaderColumnSelectionDenseSingleBlock extends
ReaderColumnSelection
final int indexOff = _rl * _numCols;
for(int i = 0; i < _colIndexes.length; i++) {
double v = _data[indexOff + _colIndexes[i]];
- empty &= v == 0;
- reusableArr[i] = v;
+ boolean isNan = Double.isNaN(v);
+ if(isNan){
+ warnNaN(); // NaN is not supported: store 0 and leave 'empty' unchanged
+ reusableArr[i] = 0;
+ }
+ else{
+ empty &= v == 0;
+ reusableArr[i] = v;
+ }
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlockTransposed.java
b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlockTransposed.java
index 291a3de2c9..ebdd2548de 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlockTransposed.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlockTransposed.java
@@ -26,7 +26,7 @@ public class ReaderColumnSelectionDenseSingleBlockTransposed
extends ReaderColum
private final double[] _data;
protected ReaderColumnSelectionDenseSingleBlockTransposed(MatrixBlock
data, int[] colIndexes, int rl, int ru) {
- super(colIndexes.clone(), rl, Math.min(ru,
data.getNumColumns()) -1 );
+ super(colIndexes.clone(), rl, Math.min(ru,
data.getNumColumns()) - 1);
_data = data.getDenseBlockValues();
for(int i = 0; i < _colIndexes.length; i++)
_colIndexes[i] = _colIndexes[i] * data.getNumColumns();
@@ -34,12 +34,19 @@ public class
ReaderColumnSelectionDenseSingleBlockTransposed extends ReaderColum
protected DblArray getNextRow() {
boolean empty = true;
- while(empty && _rl < _ru ) {
+ while(empty && _rl < _ru) {
_rl++;
for(int i = 0; i < _colIndexes.length; i++) {
final double v = _data[_colIndexes[i] + _rl];
- empty &= v == 0;
- reusableArr[i] = v;
+ boolean isNan = Double.isNaN(v);
+ if(isNan) {
+ warnNaN(); // NaN is not supported: store 0 and leave 'empty' unchanged
+ reusableArr[i] = 0;
+ }
+ else {
+ empty &= v == 0;
+ reusableArr[i] = v;
+ }
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionSparse.java
b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionSparse.java
index 0379b5850b..f9af95b440 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionSparse.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionSparse.java
@@ -66,12 +66,20 @@ public class ReaderColumnSelectionSparse extends
ReaderColumnSelection {
int skip = 0;
int j = Arrays.binarySearch(aix, apos, alen, _colIndexes[0]);
if(j < 0)
- j = Math.abs(j+1);
+ j = Math.abs(j + 1);
while(skip < _colIndexes.length && j < alen) {
if(_colIndexes[skip] == aix[j]) {
- reusableArr[skip] = avals[j];
- zeroResult = false;
+ final double v = avals[j];
+ boolean isNan = Double.isNaN(v);
+ if(isNan) {
+ warnNaN(); // NaN is not supported: store 0 and leave zeroResult unchanged
+ reusableArr[skip] = 0;
+ }
+ else {
+ reusableArr[skip] = v;
+ zeroResult = false;
+ }
skip++;
j++;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionSparseTransposed.java
b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionSparseTransposed.java
index b6bfa1b92c..021e1489e5 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionSparseTransposed.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionSparseTransposed.java
@@ -87,13 +87,23 @@ public class ReaderColumnSelectionSparseTransposed extends
ReaderColumnSelection
_rl = _ru;
return null;
}
+ boolean empty = true;
for(int i = 0; i < _colIndexes.length; i++) {
final int c = _colIndexes[i];
final int sp = sparsePos[i];
final int[] aix = a.indexes(c);
if(aix[sp] == _rl) {
final double[] avals = a.values(c);
- reusableArr[i] = avals[sp];
+ final double v = avals[sp];
+ boolean isNan = Double.isNaN(v);
+ if(isNan) {
+ warnNaN(); // NaN is not supported: store 0 and leave 'empty' unchanged
+ reusableArr[i] = 0;
+ }
+ else {
+ empty = false;
+ reusableArr[i] = v;
+ }
final int spa = sparsePos[i]++;
final int len = a.size(c) + a.pos(c) - 1;
if(spa >= len || aix[spa] >= _ru) {
@@ -105,7 +115,7 @@ public class ReaderColumnSelectionSparseTransposed extends
ReaderColumnSelection
reusableArr[i] = 0;
}
- return reusableReturn;
+ return empty ? getNextRow() : reusableReturn; // NOTE(review): recurses once per consecutive row whose selected values are all NaN; a loop would bound stack depth
}
private void skipToRow() {
@@ -130,7 +140,15 @@ public class ReaderColumnSelectionSparseTransposed extends
ReaderColumnSelection
final int[] aix = a.indexes(c);
if(aix[sp] == _rl) {
final double[] avals = a.values(c);
- reusableArr[i] = avals[sp];
+ final double v = avals[sp];
+ boolean isNan = Double.isNaN(v);
+ if(isNan) {
+ warnNaN(); // NaN is not supported: store 0 instead
+ reusableArr[i] = 0;
+ }
+ else {
+ reusableArr[i] = v;
+ }
if(++sparsePos[i] >= a.size(c) +
a.pos(c))
sparsePos[i] = -1;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java
b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java
index ddf344c71d..15b0f8ebb9 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java
@@ -105,12 +105,16 @@ public class DoubleCountHashMap {
* @return count on key
*/
public int get(double key) {
+ // NaN never compares equal to a stored key (NaN != NaN), so look it up as 0.0,
+ // the value the compression readers substitute for NaN.
+ if(Double.isNaN(key))
+ return get(0.0);
int ix = hashIndex(key);
Bucket l = _data[ix];
while(!(l.v.key == key))
l = l.n;
return l.v.count;
}
public int getOrDefault(double key, int def) {
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java
b/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java
new file mode 100644
index 0000000000..2816504c98
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.junit.Test;
+
+public class CompressedCustomTests {
+ @Test
+ public void compressNaNDense() {
+ MatrixBlock m = new MatrixBlock(100, 100, Double.NaN);
+
+ MatrixBlock m2 =
CompressedMatrixBlockFactory.compress(m).getLeft();
+
+ for(int i = 0; i < m.getNumRows(); i++)
+ for(int j = 0; j < m.getNumColumns(); j++)
+ assertEquals(0.0, m2.quickGetValue(i, j), 0.0); // NaN must be read back as 0; a leftover NaN fails this assertion
+ }
+
+ @Test
+ public void compressNaNSparse() {
+ MatrixBlock m = new MatrixBlock(100, 100, true);
+ for(int i = 0; i < m.getNumRows(); i++)
+ m.setValue(i, i, Double.NaN); // NaN only on the diagonal, rest of the sparse block is 0
+ assertTrue(m.isInSparseFormat());
+ MatrixBlock m2 =
CompressedMatrixBlockFactory.compress(m).getLeft();
+ for(int i = 0; i < m.getNumRows(); i++)
+ for(int j = 0; j < m.getNumColumns(); j++)
+ assertEquals(0.0, m2.quickGetValue(i, j), 0.0); // diagonal NaNs must have become 0 as well
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java
index f609f0bcea..ed25cbc59f 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java
@@ -2166,17 +2166,18 @@ public class ColGroupTest extends ColGroupBase {
return;
assertTrue(a.getColIndices() == base.getColIndices());
assertTrue(b.getColIndices() == other.getColIndices());
-
UA_ROW(InstructionUtils.parseBasicAggregateUnaryOperator("uar+", 1), 0,
newNRow, a, b, newNRow);
-
+
int nRow = ru - rl;
MatrixBlock ot = sparseMB(ru - rl, maxCol);
MatrixBlock bt = sparseMB(ru - rl, maxCol);
decompressToSparseBlock(a, b, ot, bt, 0, nRow);
-
+
MatrixBlock otd = denseMB(ru - rl, maxCol);
MatrixBlock btd = denseMB(ru - rl, maxCol);
decompressToDenseBlock(otd, btd, a, b, 0, nRow);
+
UA_ROW(InstructionUtils.parseBasicAggregateUnaryOperator("uar+", 1), 0,
newNRow, a, b, newNRow);
+
}
catch(Exception e) {
e.printStackTrace();
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleCustom.java
b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleCustom.java
index 7d3d9d3a83..9807ca2f91 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleCustom.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleCustom.java
@@ -19,6 +19,8 @@
package org.apache.sysds.test.component.compress.estim.encoding;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -39,6 +41,7 @@ import
org.apache.sysds.runtime.compress.estim.encoding.DenseEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.compress.estim.encoding.SparseEncoding;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.component.compress.offset.OffsetTests;
import org.junit.Test;
@@ -254,4 +257,53 @@ public class EncodeSampleCustom {
throw new NotImplementedError();
}
}
+
+ private static MatrixBlock getSparseNaNMatrix() {
+ MatrixBlock m = new MatrixBlock(100, 100, true);
+ for(int i = 0; i < m.getNumRows(); i++)
+ m.setValue(i, i, Double.NaN);
+ return m;
+ }
+
+ private static MatrixBlock getDenseNaNMatrix() {
+ return new MatrixBlock(100, 100, Double.NaN);
+ }
+
+ @Test
+ public void testNaNDense() {
+ MatrixBlock m = getDenseNaNMatrix();
+ IEncode e = EncodingFactory.createFromMatrixBlock(m, false, 0);
+ // A sparse encoding would also be valid, but NaN cells are counted as values when the
+ // encoding type is chosen, so a column that is entirely NaN gets a dense encoding.
+ // This only matters in the edge case where all or most of a column is NaN.
+ assertTrue(e.isDense());
+ assertEquals(1, e.getUnique());
+ }
+
+ @Test
+ public void testNaNDenseTransposed() {
+ MatrixBlock m = getDenseNaNMatrix();
+ IEncode e = EncodingFactory.createFromMatrixBlock(m, true, 0);
+ // A sparse encoding would also be valid, but NaN cells are counted as values when the
+ // encoding type is chosen, so a column that is entirely NaN gets a dense encoding.
+ // This only matters in the edge case where all or most of a column is NaN.
+ assertTrue(e.isDense());
+ assertEquals(1, e.getUnique());
+ }
+
+ @Test
+ public void testNaNSparse() {
+ MatrixBlock m = getSparseNaNMatrix();
+ IEncode e = EncodingFactory.createFromMatrixBlock(m, false, 0);
+ assertFalse(e.isDense());
+ assertEquals(1, e.getUnique());
+ }
+
+ @Test
+ public void testNaNSparseTransposed() {
+ MatrixBlock m = getSparseNaNMatrix();
+ IEncode e = EncodingFactory.createFromMatrixBlock(m, true, 0);
+ assertFalse(e.isDense());
+ assertEquals(1, e.getUnique());
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetReverseTest.java
b/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetReverseTest.java
new file mode 100644
index 0000000000..23e63bd7ae
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetReverseTest.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress.offset;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
+import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
+import org.junit.Test;
+
+public class OffsetReverseTest {
+
+ @Test
+ public void reverse1() {
+ AOffset off = OffsetFactory.createOffset(new int[] {1, 10, 13,
14});
+ AOffset rev = AOffset.reverse(16, off);
+ OffsetTests.compare(rev, new int[] {0, 2, 3, 4, 5, 6, 7, 8, 9,
11, 12, 15});
+ }
+
+ @Test
+ public void reverse2() {
+ AOffset off = OffsetFactory.createOffset(new int[] {1, 10});
+ AOffset rev = AOffset.reverse(16, off);
+ OffsetTests.compare(rev, new int[] {0, 2, 3, 4, 5, 6, 7, 8, 9,
11, 12, 13, 14, 15});
+ }
+
+ @Test
+ public void reverse3() {
+ AOffset off = OffsetFactory.createOffset(new int[] {1});
+ AOffset rev = AOffset.reverse(16, off);
+ OffsetTests.compare(rev, new int[] {0, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15});
+ }
+
+ @Test
+ public void reverse4() {
+ AOffset off = OffsetFactory.createOffset(new int[] {1, 10, 13,
14, 15});
+ AOffset rev = AOffset.reverse(16, off);
+ OffsetTests.compare(rev, new int[] {0, 2, 3, 4, 5, 6, 7, 8, 9,
11, 12});
+ }
+
+ @Test
+ public void reverse4_withCreateMethod() {
+ int[] exp = new int[] {0, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12};
+ AOffset off = OffsetFactory.createOffset(create(exp, 16));
+ AOffset rev = AOffset.reverse(16, off);
+ OffsetTests.compare(rev, exp);
+ }
+
+ @Test
+ public void reverse1_withCreateMethod() {
+ int[] exp = new int[] {0, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 15};
+ AOffset off = OffsetFactory.createOffset(create(exp, 16));
+ AOffset rev = AOffset.reverse(16, off);
+ OffsetTests.compare(rev, exp);
+ }
+
+ @Test
+ public void reverseCreate1() {
+ test(new int[] {100, 132, 520}, 1000);
+ }
+
+ @Test
+ public void reverseCreate2() {
+ test(new int[] {100, 132, 520, 999}, 1000);
+ }
+
+ @Test
+ public void reverseCreate3() {
+ test(new int[] {1, 999}, 1000);
+ }
+
+ @Test
+ public void reverseCreate4() {
+ test(new int[] {256, 512, 999}, 1000);
+ }
+
+ @Test
+ public void reverse() {
+ final int[] offsets = new int[] {1, 3, 7, 8, 10, 11, 14, 15, 16, 23,
25, 26, 28, 31, 32, 34, 36, 38, 42, 43, 44, 46, 47, 52, 55,
+ 56, 57, 62, 63, 67, 68, 69, 70, 72, 74,
75, 79, 81, 82, 83, 84, 85, 87, 88, 92, 93, 94, 95, 96, 98, 100,
+ 105, 108, 109, 110, 111, 117, 120, 121,
124, 125, 126, 128, 129, 132, 135, 137, 139, 144, 147, 148, 149,
+ 150, 152, 155, 156, 157, 158, 159, 161,
165, 166, 167, 168, 170, 173, 176, 179, 180, 182, 183, 185, 187,
+ 188, 190, 191, 192, 194, 195, 196, 197,
200, 202, 203, 206, 209, 211, 215, 216, 217, 220, 221, 222, 223,
+ 224, 225, 226, 227, 228, 230, 234, 239,
240, 241, 246, 249, 253, 255, 256, 257, 261, 262, 263, 266, 268,
+ 269, 270, 271, 277, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 292, 293, 294, 297, 299, 302, 305,
+ 308, 313, 314, 318, 319, 323, 324, 325,
329, 330, 331, 332, 333, 338, 339, 341, 342, 343, 344, 345, 346,
+ 347, 350, 351, 352, 354, 355, 356, 358,
362, 363, 365, 367, 373, 374, 375, 376, 379, 380, 381, 382, 384,
+ 385, 387, 388, 390, 391, 392, 395, 397,
401, 402, 405, 406, 407, 411, 415, 416, 418, 419, 420, 423, 424,
+ 426, 427, 428, 429, 431, 435, 436, 438,
439, 440, 441, 445, 446, 447, 450, 451, 452, 456, 458, 461, 462,
+ 464, 465, 467, 468, 469, 470, 477, 481,
484, 485, 487, 488, 489, 494, 495, 500, 504, 505, 506, 508, 510,
+ 512, 513, 517, 518, 520, 524, 525, 526,
527, 528, 529, 531, 532, 534, 538, 540, 543, 544, 546, 548, 551,
+ 553, 554, 556, 560, 562, 563, 564, 567,
569, 570, 571, 575, 577, 578, 579, 581, 582, 585, 586, 587, 589,
+ 592, 593, 594, 598, 600, 605, 607, 613,
615, 617, 618, 623, 624, 629, 630, 632, 633, 634, 635, 636, 637,
+ 638, 639, 641, 644, 645, 646, 649, 651,
652, 654, 657, 659, 663, 664, 669, 671, 672, 673, 677, 678, 679,
+ 680, 683, 684, 685, 686, 687, 691, 692,
694, 696, 698, 700, 702, 705, 706, 713, 715, 720, 722, 723, 724,
+ 728, 729, 730, 733, 735, 736, 737, 739,
740, 741, 742, 743, 744, 745, 746, 747, 750, 751, 752, 758, 762,
+ 763, 764, 767, 768, 771, 772, 775, 776,
778, 779, 781, 785, 788, 789, 791, 792, 793, 794, 797, 804, 806,
+ 807, 809, 810, 811, 812, 813, 815, 816,
818, 819, 820, 821, 822, 824, 825, 827, 831, 833, 834, 835, 837,
+ 838, 839, 840, 841, 843, 848, 849, 851,
852, 853, 859, 862, 863, 864, 865, 866, 870, 871, 872, 873, 874,
+ 875, 877, 879, 880, 882, 883, 887, 889,
891, 892, 894, 896, 897, 898, 899, 901, 902, 903, 905, 906, 908,
+ 911, 912, 913, 917, 919, 920, 922, 926,
927, 931, 933, 935, 936, 938, 940, 941, 943, 944, 945, 946, 948,
+ 950, 953, 957, 959, 961, 967, 968, 974,
979, 980, 982, 983, 984, 986, 989, 990, 991, 993, 995, 996, 997};
+ AOffset off = OffsetFactory.createOffset(offsets);
+ AOffset rev = AOffset.reverse(1000, off);
+ // Row 0 is not in offsets, so the reversed offset must start at 0, and in total it must
+ // contain exactly the rows of [0, 1000) that are absent from offsets.
+ assertEquals(0, rev.getOffsetIterator().value());
+ OffsetTests.compare(rev, create(offsets, 1000));
+ }
+
+ private void test(int[] missing, int max) {
+ AOffset off = OffsetFactory.createOffset(create(missing, max));
+ AOffset rev = AOffset.reverse(max, off);
+ OffsetTests.compare(rev, missing);
+ }
+
+ private static int[] create(int[] missing, int max) {
+ int[] ret = new int[max - missing.length];
+ int j = 0;
+ int k = 0;
+ for(int i = 0; i < max; i++) {
+ if(j < missing.length && missing[j] == i)
+ j++;
+ else
+ ret[k++] = i;
+ }
+ return ret;
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTest.java
b/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTest.java
index a834f42848..7b90eefede 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTest.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTest.java
@@ -20,6 +20,9 @@
package org.apache.sysds.test.component.compress.readers;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -132,4 +135,184 @@ public class ReadersTest {
mb.allocateDenseBlock();
ReaderColumnSelection.createReader(mb, new int[] {0, 1}, false,
10, 9);
}
+
+ @Test
+ public void isEmptyNan() {
+ try {
+
+ MatrixBlock mb = new MatrixBlock(10, 5, Double.NaN);
+
+ ReaderColumnSelection reader =
ReaderColumnSelection.createReader(mb, new int[] {0, 1}, false, 0,
+ mb.getNumRows());
+ assertNull(reader.nextRow());
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void isNaN() {
+ try {
+
+ MatrixBlock mb = new MatrixBlock(10, 5, Double.NaN);
+ mb.setValue(1, 1, 3214);
+
+ ReaderColumnSelection reader =
ReaderColumnSelection.createReader(mb, new int[] {0, 1}, false, 0,
+ mb.getNumRows());
+ DblArray a = reader.nextRow();
+ assertNotNull(a);
+ assertEquals(3214.0, a.getData()[1], 0.0);
+ assertEquals(0.0, a.getData()[0], 0.0);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void isEmptyNanTransposed() {
+ try {
+
+ MatrixBlock mb = new MatrixBlock(10, 5, Double.NaN);
+
+ ReaderColumnSelection reader =
ReaderColumnSelection.createReader(mb, new int[] {0, 1}, true, 0,
+ mb.getNumRows());
+ assertNull(reader.nextRow());
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void isNaNTransposed() {
+ try {
+
+ MatrixBlock mb = new MatrixBlock(10, 5, Double.NaN);
+ mb.setValue(1, 1, 3214);
+
+ ReaderColumnSelection reader =
ReaderColumnSelection.createReader(mb, new int[] {0, 1}, true, 0,
+ mb.getNumRows());
+ DblArray a = reader.nextRow();
+ assertNotNull(a);
+ assertEquals(3214.0, a.getData()[1], 0.0);
+ assertEquals(0.0, a.getData()[0], 0.0);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void isEmptyNanMultiBlock() {
+ try {
+
+ MatrixBlock mb =
ReadersTestCompareReaders.createMock(new MatrixBlock(10, 5, Double.NaN));
+
+ ReaderColumnSelection reader =
ReaderColumnSelection.createReader(mb, new int[] {0, 1}, false, 0,
+ mb.getNumRows());
+ assertNull(reader.nextRow());
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void isNaNMultiBlock() {
+ try {
+
+ MatrixBlock mb =
ReadersTestCompareReaders.createMock(new MatrixBlock(10, 5, Double.NaN));
+ mb.setValue(1, 1, 3214);
+
+ ReaderColumnSelection reader =
ReaderColumnSelection.createReader(mb, new int[] {0, 1}, false, 0,
+ mb.getNumRows());
+ DblArray a = reader.nextRow();
+ assertNotNull(a);
+ assertEquals(3214.0, a.getData()[1], 0.0);
+ assertEquals(0.0, a.getData()[0], 0.0);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void isEmptyNanMultiBlockTransposed() {
+ try {
+
+ MatrixBlock mb =
ReadersTestCompareReaders.createMock(new MatrixBlock(10, 5, Double.NaN));
+
+ ReaderColumnSelection reader =
ReaderColumnSelection.createReader(mb, new int[] {0, 1}, true, 0,
+ mb.getNumRows());
+ assertNull(reader.nextRow());
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void isNaNMultiBlockTransposed() {
+ try {
+
+ MatrixBlock mb =
ReadersTestCompareReaders.createMock(new MatrixBlock(10, 5, Double.NaN));
+ mb.setValue(1, 1, 3214);
+
+ ReaderColumnSelection reader =
ReaderColumnSelection.createReader(mb, new int[] {0, 1}, true, 0,
+ mb.getNumRows());
+ DblArray a = reader.nextRow();
+ assertNotNull(a);
+ assertEquals(3214.0, a.getData()[1], 0.0);
+ assertEquals(0.0, a.getData()[0], 0.0);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void isNanSparseBlock() {
+ MatrixBlock mbs = new MatrixBlock(10, 10, true);
+ mbs.setValue(1, 1, 3214);
+ mbs.setValue(0, 0, Double.NaN);
+ mbs.setValue(0, 1, Double.NaN);
+ mbs.setValue(1, 0, Double.NaN);
+
+ ReaderColumnSelection reader =
ReaderColumnSelection.createReader(mbs, new int[] {0, 1}, false, 0,
+ mbs.getNumRows());
+
+ DblArray a = reader.nextRow();
+ assertNotNull(a);
+ assertEquals(3214.0, a.getData()[1], 0.0);
+ assertEquals(0.0, a.getData()[0], 0.0);
+ assertNull(reader.nextRow());
+ }
+
+ @Test
+ public void isNanSparseBlockTransposed() {
+ MatrixBlock mbs = new MatrixBlock(10, 10, true);
+ mbs.setValue(1, 1, 3214);
+ mbs.setValue(0, 0, Double.NaN);
+ mbs.setValue(0, 1, Double.NaN);
+ mbs.setValue(1, 0, Double.NaN);
+
+ ReaderColumnSelection reader =
ReaderColumnSelection.createReader(mbs, new int[] {0, 1}, true, 0,
+ mbs.getNumRows());
+
+ DblArray a = reader.nextRow();
+ assertNotNull(a);
+ assertEquals(3214.0, a.getData()[1], 0.0);
+ assertEquals(0.0, a.getData()[0], 0.0);
+ assertNull(reader.nextRow());
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTestCompareReaders.java
b/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTestCompareReaders.java
index 724ea79a63..f9958f73d0 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTestCompareReaders.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTestCompareReaders.java
@@ -575,9 +575,17 @@ public class ReadersTestCompareReaders {
return subCols;
}
- private class DenseBlockFP64Mock extends DenseBlockFP64 {
+ public static MatrixBlock createMock(MatrixBlock d){
+ DenseBlockFP64 a = new DenseBlockFP64Mock(new
int[]{d.getNumRows(), d.getNumColumns()}, d.getDenseBlockValues());
+ MatrixBlock b = new MatrixBlock(d.getNumRows(),
d.getNumColumns(), a);
+ b.setNonZeros(d.getNumRows() * d.getNumColumns());
+ return b;
+ }
+
+ protected static class DenseBlockFP64Mock extends DenseBlockFP64 {
private static final long serialVersionUID =
-3601232958390554672L;
+
public DenseBlockFP64Mock(int[] dims, double[] data) {
super(dims, data);
}