This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new a47d266dfd [SYSTEMDS-3490] Compressed Transform Encode Hash
a47d266dfd is described below
commit a47d266dfd4d4901b4694b3965502d1ae3915ff3
Author: baunsgaard <[email protected]>
AuthorDate: Tue May 16 11:51:15 2023 +0200
[SYSTEMDS-3490] Compressed Transform Encode Hash
This commit adds compressed transform encode for feature hashing (hash
and hash to dummycode). Furthermore, it optimizes feature hashing in
general to call into and utilize the underlying data types of the frame
Arrays, instead of calling toString().hashCode() on all data types.
This also means that the commit breaks backward compatibility of the
hash codes for columns whose data type is not String.
Closes #1825
---
.../colgroup/dictionary/IdentityDictionary.java | 45 +++++--
.../compress/colgroup/mapping/MapToUByte.java | 2 +-
.../sysds/runtime/data/SparseBlockFactory.java | 17 +++
.../sysds/runtime/frame/data/columns/Array.java | 8 ++
.../runtime/frame/data/columns/BitSetArray.java | 5 +
.../runtime/frame/data/columns/BooleanArray.java | 5 +
.../runtime/frame/data/columns/CharArray.java | 5 +
.../runtime/frame/data/columns/DoubleArray.java | 6 +
.../runtime/frame/data/columns/FloatArray.java | 6 +
.../runtime/frame/data/columns/IntegerArray.java | 6 +
.../runtime/frame/data/columns/LongArray.java | 6 +
.../runtime/frame/data/columns/OptionalArray.java | 9 ++
.../runtime/frame/data/columns/StringArray.java | 7 +
.../runtime/transform/encode/ColumnEncoderBin.java | 20 +--
.../transform/encode/ColumnEncoderComposite.java | 11 ++
.../transform/encode/ColumnEncoderFeatureHash.java | 42 +++---
.../runtime/transform/encode/CompressedEncode.java | 133 ++++++++++++++-----
src/test/java/org/apache/sysds/test/TestUtils.java | 144 +++++++++++++++++++++
.../component/frame/array/FrameArrayTests.java | 5 +-
.../transform/TransformCompressedTestMultiCol.java | 10 ++
.../TransformCompressedTestSingleCol.java | 30 +++--
21 files changed, 446 insertions(+), 76 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
index d80a5cca62..996d1f1b4d 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
@@ -39,24 +39,45 @@ import
org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
+/**
+ * A specialized dictionary that exploits the fact that the contained
dictionary is an Identity Matrix.
+ */
public class IdentityDictionary extends ADictionary {
private static final long serialVersionUID = 2535887782150955098L;
+ /** The number of rows or columns, rows can be +1 if withEmpty is set.
*/
protected final int nRowCol;
-
+ /** Specify if the Identity matrix should contain an empty row in the
end. */
+ protected final boolean withEmpty;
+ /** A Cache to contain a materialized version of the identity matrix. */
protected SoftReference<MatrixBlockDictionary> cache = null;
/**
- * Create a Identity matrix dictionary. It behaves as if allocated a
Sparse Matrix block but exploits that the
+ * Create an identity matrix dictionary. It behaves as if allocated a
Sparse Matrix block but exploits that the
* structure is known to have certain properties.
*
- * @param nRowCol the number of rows and columns in this identity
matrix.
+ * @param nRowCol The number of rows and columns in this identity
matrix.
*/
public IdentityDictionary(int nRowCol) {
if(nRowCol <= 0)
throw new DMLCompressionException("Invalid Identity
Dictionary");
this.nRowCol = nRowCol;
+ this.withEmpty = false;
+ }
+
+ /**
+ * Create an identity matrix dictionary, It behaves as if allocated a
Sparse Matrix block but exploits that the
+ * structure is known to have certain properties.
+ *
+ * @param nRowCol The number of rows and columns in this identity
matrix.
+ * @param withEmpty If the matrix should contain an empty row in the
end.
+ */
+ public IdentityDictionary(int nRowCol, boolean withEmpty) {
+ if(nRowCol <= 0)
+ throw new DMLCompressionException("Invalid Identity
Dictionary");
+ this.nRowCol = nRowCol;
+ this.withEmpty = withEmpty;
}
@Override
@@ -65,7 +86,7 @@ public class IdentityDictionary extends ADictionary {
// LOG.warn("Should not call getValues on Identity Dictionary");
// double[] ret = new double[nRowCol * nRowCol];
// for(int i = 0; i < nRowCol; i++) {
- // ret[(i * nRowCol) + i] = 1;
+ // ret[(i * nRowCol) + i] = 1;
// }
// return ret;
}
@@ -222,7 +243,7 @@ public class IdentityDictionary extends ADictionary {
@Override
public int getNumberOfValues(int ncol) {
- return nRowCol;
+ return nRowCol + (withEmpty ? 1 : 0);
}
@Override
@@ -399,9 +420,17 @@ public class IdentityDictionary extends ADictionary {
}
private MatrixBlockDictionary createMBDict() {
- final SparseBlock sb =
SparseBlockFactory.createIdentityMatrix(nRowCol);
- final MatrixBlock identity = new MatrixBlock(nRowCol, nRowCol,
nRowCol, sb);
- return new MatrixBlockDictionary(identity);
+ if(withEmpty) {
+ final SparseBlock sb =
SparseBlockFactory.createIdentityMatrixWithEmptyRow(nRowCol);
+ final MatrixBlock identity = new MatrixBlock(nRowCol +
1, nRowCol, nRowCol, sb);
+ return new MatrixBlockDictionary(identity);
+ }
+ else {
+
+ final SparseBlock sb =
SparseBlockFactory.createIdentityMatrix(nRowCol);
+ final MatrixBlock identity = new MatrixBlock(nRowCol,
nRowCol, nRowCol, sb);
+ return new MatrixBlockDictionary(identity);
+ }
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToUByte.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToUByte.java
index f7a9091689..5df496a517 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToUByte.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToUByte.java
@@ -117,7 +117,7 @@ public class MapToUByte extends MapToByte {
public int[] getCounts(int[] ret) {
for(int i = 0; i < _data.length; i++)
ret[_data[i]]++;
- return ret;
+ return ret;
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
index 18bd68489e..66f07ab6ad 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
@@ -100,4 +100,21 @@ public abstract class SparseBlockFactory
return new SparseBlockCSR(rowPtr, colIdx, vals, nnz);
}
+
+ public static SparseBlock createIdentityMatrixWithEmptyRow(int nRowCol){
+ final int[] rowPtr = new int[nRowCol+2];
+ final int[] colIdx = new int[nRowCol];
+ final double[] vals = new double[nRowCol];
+ int nnz = nRowCol;
+
+ for(int i = 0; i < nRowCol; i++){
+ rowPtr[i] = i;
+ colIdx[i] = i;
+ vals[i] = 1;
+ }
+ // add last index for row pointers.
+ rowPtr[nRowCol] = nRowCol;
+ rowPtr[nRowCol+1] = nRowCol;
+ return new SparseBlockCSR(rowPtr, colIdx, vals, nnz);
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
index d5fbda7874..9111550725 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
@@ -570,6 +570,14 @@ public abstract class Array<T> implements Writable {
return this.getClass().getSimpleName();
}
+ /**
+ * Hash the given index of the array.
+ * It is allowed to return NaN on null elements.
+ *
+ * @param idx The index to hash
+ * @return The hash value of that index.
+ */
+ public abstract double hashDouble(int idx);
public ArrayIterator getIterator(){
return new ArrayIterator();
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java
index 70dbde179e..5eed5ce3e0 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java
@@ -538,6 +538,11 @@ public class BitSetArray extends ABooleanArray {
return sb.toString();
}
+ @Override
+ public double hashDouble(int idx){
+ return get(idx) ? 1.0 : 0.0;
+ }
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder(_size + 10);
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java
index ba91f96d5d..e74f8bcd65 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java
@@ -338,6 +338,11 @@ public class BooleanArray extends ABooleanArray {
(Boolean.parseBoolean(value) || value.equals("1") ||
value.equals("1.0") || value.equals("t"));
}
+ @Override
+ public double hashDouble(int idx){
+ return get(idx) ? 1.0 : 0.0;
+ }
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder(_data.length * 2 + 10);
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java
index 73389beb9c..b08232cde2 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java
@@ -322,6 +322,11 @@ public class CharArray extends Array<Character> {
return _data[i] != 0;
}
+ @Override
+ public double hashDouble(int idx){
+ return Character.hashCode(_data[idx]);
+ }
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder(_data.length * 2 + 15);
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java
index b529a4287b..de8def92de 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java
@@ -374,6 +374,12 @@ public class DoubleArray extends Array<Double> {
return _data[i] != 0.0d;
}
+
+ @Override
+ public double hashDouble(int idx){
+ return Double.hashCode(_data[idx]);
+ }
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
index e53486506e..4659ec34b4 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
@@ -326,6 +326,12 @@ public class FloatArray extends Array<Float> {
return _data[i] != 0.0f;
}
+
+ @Override
+ public double hashDouble(int idx){
+ return Float.hashCode(_data[idx]);
+ }
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java
index 999025524c..3a3b05da11 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java
@@ -331,6 +331,12 @@ public class IntegerArray extends Array<Integer> {
return _data[i] != 0;
}
+
+ @Override
+ public double hashDouble(int idx){
+ return Integer.hashCode(_data[idx]);
+ }
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
index aa2e1d5961..8b931f3ad5 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
@@ -333,6 +333,12 @@ public class LongArray extends Array<Long> {
return _data[i] != 0;
}
+
+ @Override
+ public double hashDouble(int idx){
+ return Long.hashCode(_data[idx]);
+ }
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
index fe85c2530a..bf69c4f56d 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
@@ -430,6 +430,15 @@ public class OptionalArray<T> extends Array<T> {
return !_n.isAllTrue();
}
+
+ @Override
+ public double hashDouble(int idx){
+ if(_n.get(idx))
+ return _a.hashDouble(idx);
+ else
+ return Double.NaN;
+ }
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder(_size + 2);
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
index 4e89dbf548..00be24e0c7 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
@@ -661,6 +661,13 @@ public class StringArray extends Array<String> {
}
}
+ @Override
+ public double hashDouble(int idx){
+ if(_data[idx] != null)
+ return _data[idx].hashCode();
+ else
+ return Double.NaN;
+ }
@Override
public String toString() {
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
index 7051d3f301..895141db07 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
@@ -222,7 +222,7 @@ public class ColumnEncoderBin extends ColumnEncoder {
int endRow = getEndIndex(in.getNumRows(), startRow, blockSize);
double[] vals = new double[endRow-startRow];
for(int i = startRow; i < endRow; i++) {
- double inVal = in.getDouble(i, colID - 1);
+ double inVal = in.getDoubleNaN(i, colID - 1);
//FIXME current NaN handling introduces 0s and thus
// impacts the computation of bin boundaries
if(Double.isNaN(inVal))
@@ -404,15 +404,15 @@ public class ColumnEncoderBin extends ColumnEncoder {
sb.append(": ");
sb.append(_colID);
sb.append(" --- Method: " + _binMethod + " num Bin: " +
_numBin);
- if(_binMethod == BinMethod.EQUI_WIDTH){
-
- sb.append("\n---- BinMin: "+
Arrays.toString(_binMins));
- sb.append("\n---- BinMax: "+
Arrays.toString(_binMaxs));
- }
- else{
-
- sb.append(" --- MinMax: "+ _colMins + " " + _colMaxs);
- }
+ // if(_binMethod == BinMethod.EQUI_WIDTH) {
+ sb.append("\n---- BinMin: " + Arrays.toString(_binMins));
+ sb.append("\n---- BinMax: " + Arrays.toString(_binMaxs));
+ // }
+ // else {
+ // // sb.append(" --- MinMax: "+ _colMins + " " + _colMaxs);
+ // sb.append("\n---- BinMin: " + Arrays.toString(_binMins));
+ // sb.append("\n---- BinMax: " + Arrays.toString(_binMaxs));
+ // }
return sb.toString();
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index 8b9710f71d..6f18263a26 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -446,6 +446,17 @@ public class ColumnEncoderComposite extends ColumnEncoder {
&& _columnEncoders.get(1) instanceof
ColumnEncoderDummycode;
}
+ public boolean isHash() {
+ return _columnEncoders.size() == 1//
+ && _columnEncoders.get(0) instanceof
ColumnEncoderFeatureHash;//
+ }
+
+ public boolean isHashToDummy() {
+ return _columnEncoders.size() == 2//
+ && _columnEncoders.get(0) instanceof
ColumnEncoderFeatureHash//
+ && _columnEncoders.get(1) instanceof
ColumnEncoderDummycode;
+ }
+
private static class ColumnCompositeUpdateDCTask implements
Callable<Object> {
private final ColumnEncoderComposite _encoder;
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
index 12e3f80b70..467d16ae7c 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.transform.encode;
import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
@@ -28,6 +29,7 @@ import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.frame.data.columns.Array;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DependencyTask;
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -66,29 +68,35 @@ public class ColumnEncoderFeatureHash extends ColumnEncoder
{
@Override
protected double getCode(CacheBlock<?> in, int row) {
- // hash a single row
- String key = in.getString(row, _colID - 1);
- if(key == null)
- return Double.NaN;
- return (key.hashCode() % _K) + 1;
+ if(in instanceof FrameBlock){
+ Array<?> a = ((FrameBlock)in).getColumn(_colID -1);
+ return getCode(a, row);
+ }
+ else{ // default
+ // hash a single row; null or empty keys are missing (NaN), all other codes are in [1, K]
+ String key = in.getString(row, _colID - 1);
+ if(key == null || key.isEmpty())
+ return Double.NaN;
+ return Math.floorMod(key.hashCode(), _K) + 1;
+ }
+ }
+
+ protected double getCode(Array<?> a, int row){
+ return ((a.hashDouble(row) % _K) + _K) % _K + 1; // non-negative modulo -> [1, K]; a NaN (null) hash propagates
}
protected double[] getCodeCol(CacheBlock<?> in, int startInd, int blkSize) {
// hash a block of rows
int endInd = getEndIndex(in.getNumRows(), startInd, blkSize);
double codes[] = new double[endInd-startInd];
- for (int i=startInd; i<endInd; i++) {
- String key = in.getString(i, _colID - 1);
- if(key == null || key.isEmpty())
- codes[i-startInd] = Double.NaN;
- else {
- // Calculate non-negative modulo
- //double mod = key.hashCode() % _K > 0 ? key.hashCode() % _K : _K + key.hashCode() % _K;
- double mod = (key.hashCode() % _K) + 1;
- if (mod < 0)
- mod += _K;
- codes[i - startInd] = mod;
- }
+ if( in instanceof FrameBlock) {
+ Array<?> a = ((FrameBlock) in).getColumn(_colID-1);
+ for(int i = startInd; i < endInd; i++)
+ codes[i - startInd] = getCode(a, i);
+ }
+ else {// default
+ for(int i = startInd; i < endInd; i++)
+ codes[i - startInd] = getCode(in, i);
}
return codes;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
index 8a60aae6ff..63eb81e008 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
@@ -51,6 +51,7 @@ import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.frame.data.columns.Array;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
+import org.apache.sysds.runtime.util.UtilFunctions;
public class CompressedEncode {
protected static final Log LOG =
LogFactory.getLog(CompressedEncode.class.getName());
@@ -141,6 +142,10 @@ public class CompressedEncode {
return bin(c);
else if(c.isBinToDummy())
return binToDummy(c);
+ else if(c.isHash())
+ return hash(c);
+ else if(c.isHashToDummy())
+ return hashToDummy(c);
else
throw new NotImplementedException("Not supporting : " +
c);
}
@@ -149,13 +154,14 @@ public class CompressedEncode {
private AColGroup recodeToDummy(ColumnEncoderComposite c) {
int colId = c._colID;
Array<?> a = in.getColumn(colId - 1);
+ boolean containsNull = a.containsNull();
HashMap<?, Long> map = a.getRecodeMap();
int domain = map.size();
IColIndex colIndexes = ColIndexFactory.create(0, domain);
if(domain == 1)
return ColGroupConst.create(colIndexes, new double[]
{1});
- ADictionary d = new IdentityDictionary(colIndexes.size());
- AMapToData m = createMappingAMapToData(a, map);
+ ADictionary d = new IdentityDictionary(colIndexes.size(),
containsNull);
+ AMapToData m = createMappingAMapToData(a, map, containsNull);
List<ColumnEncoder> r = c.getEncoders();
r.set(0, new ColumnEncoderRecode(colId, (HashMap<Object, Long>)
map));
return ColGroupDDC.create(colIndexes, d, m, null);
@@ -174,12 +180,11 @@ public class CompressedEncode {
AMapToData m = binEncode(a, b, containsNull);
AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null);
- try{
-
+ try {
ret.getNumberNonZeros(a.size());
}
- catch(Exception e){
- throw new DMLRuntimeException("Failed binning \n\n" + a
+ "\n" + b + "\n" + d + "\n" + m,e);
+ catch(Exception e) {
+ throw new DMLRuntimeException("Failed binning \n\n" + a
+ "\n" + b + "\n" + d + "\n" + m, e);
}
return ret;
}
@@ -221,9 +226,8 @@ public class CompressedEncode {
final List<ColumnEncoder> r = c.getEncoders();
final ColumnEncoderBin b = (ColumnEncoderBin) r.get(0);
b.build(in);
-
- IColIndex colIndexes = ColIndexFactory.create(0, b._numBin +
(containsNull ? 1 : 0));
- ADictionary d = new IdentityDictionary(colIndexes.size());
+ IColIndex colIndexes = ColIndexFactory.create(0, b._numBin);
+ ADictionary d = new IdentityDictionary(colIndexes.size(),
containsNull);
AMapToData m = binEncode(a, b, containsNull);
AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null);
ret.getNumberNonZeros(a.size());
@@ -235,19 +239,22 @@ public class CompressedEncode {
int colId = c._colID;
Array<?> a = in.getColumn(colId - 1);
HashMap<?, Long> map = a.getRecodeMap();
+ boolean containsNull = a.containsNull();
int domain = map.size();
// int domain = c.getDomainSize();
IColIndex colIndexes = ColIndexFactory.create(1);
if(domain == 1)
return ColGroupConst.create(colIndexes, new double[]
{1});
- MatrixBlock incrementing = new MatrixBlock(domain, 1, false);
+ MatrixBlock incrementing = new MatrixBlock(domain +
(containsNull ? 1 : 0) , 1, false);
for(int i = 0; i < domain; i++)
incrementing.quickSetValue(i, 0, i + 1);
+ if(containsNull)
+ incrementing.quickSetValue(domain, 0 , Double.NaN);
ADictionary d = MatrixBlockDictionary.create(incrementing);
- AMapToData m = createMappingAMapToData(a, map);
+ AMapToData m = createMappingAMapToData(a, map, containsNull);
List<ColumnEncoder> r = c.getEncoders();
r.set(0, new ColumnEncoderRecode(colId, (HashMap<Object, Long>)
map));
@@ -261,6 +268,7 @@ public class CompressedEncode {
IColIndex colIndexes = ColIndexFactory.create(1);
int colId = c._colID;
Array<?> a = in.getColumn(colId - 1);
+ boolean containsNull = a.containsNull();
HashMap<Object, Long> map = (HashMap<Object, Long>)
a.getRecodeMap();
final int blockSz =
ConfigurationManager.getDMLConfig().getIntValue(DMLConfig.DEFAULT_BLOCK_SIZE);
if(map.size() >= blockSz) {
@@ -271,36 +279,103 @@ public class CompressedEncode {
return ColGroupUncompressed.create(colIndexes, col,
false);
}
else {
- double[] vals = new double[map.size() +
(a.containsNull() ? 1 : 0)];
- for(int i = 0; i < a.size(); i++) {
- Object v = a.get(i);
- if(map.containsKey(v)) {
- vals[map.get(v).intValue()] =
a.getAsDouble(i);
- }
- else {
- map.put(null, (long) map.size());
- vals[map.get(v).intValue()] =
a.getAsDouble(i);
- }
- }
-
+ double[] vals = new double[map.size() + (containsNull ?
1 : 0)];
+ if(containsNull)
+ vals[map.size()] = Double.NaN;
+ ValueType t = a.getValueType();
+ map.forEach((k,v) -> vals[v.intValue()] =
UtilFunctions.objectToDouble(t,k));
ADictionary d = Dictionary.create(vals);
- AMapToData m = createMappingAMapToData(a, map);
+ AMapToData m = createMappingAMapToData(a, map,
containsNull);
return ColGroupDDC.create(colIndexes, d, m, null);
}
}
- private AMapToData createMappingAMapToData(Array<?> a, HashMap<?, Long>
map) {
- AMapToData m = MapToFactory.create(in.getNumRows(), map.size());
+ private AMapToData createMappingAMapToData(Array<?> a, HashMap<?, Long>
map, boolean containsNull) {
+ final int si = map.size();
+ AMapToData m = MapToFactory.create(in.getNumRows(), si +
(containsNull ? 1 : 0));
Array<?>.ArrayIterator it = a.getIterator();
- while(it.hasNext()) {
- Object v = it.next();
- if(v != null)
+ if(containsNull){
+
+ while(it.hasNext()) {
+ Object v = it.next();
+ if(v != null)
+ m.set(it.getIndex(),
map.get(v).intValue());
+ else
+ m.set(it.getIndex(),si);
+ }
+ }
+ else{
+ while(it.hasNext()) {
+ Object v = it.next();
m.set(it.getIndex(), map.get(v).intValue());
+ }
+ }
+ return m;
+ }
+
+ private AMapToData createHashMappingAMapToData(Array<?> a, int k, boolean nulls) {
+ AMapToData m = MapToFactory.create(a.size(), k + (nulls ? 1 : 0));
+ if(nulls) {
+ for(int i = 0; i < a.size(); i++) {
+ double h = a.hashDouble(i);
+ if(Double.isNaN(h)) {
+ m.set(i, k); // null rows map to the extra last dictionary entry
+ }
+ else {
+ m.set(i, Math.floorMod((int) h, k)); // hash codes may be negative; keep the index in [0, k)
+ }
+ }
+ }
+ else {
+ for(int i = 0; i < a.size(); i++) {
+ double h = a.hashDouble(i);
+ m.set(i, Math.floorMod((int) h, k)); // hash codes may be negative; keep the index in [0, k)
+ }
}
return m;
}
+ private AColGroup hash(ColumnEncoderComposite c) {
+ int colId = c._colID;
+ Array<?> a = in.getColumn(colId - 1);
+ ColumnEncoderFeatureHash hashEncoder = (ColumnEncoderFeatureHash) c.getEncoders().get(0);
+
+ // hash domain K: non-null rows are coded 1..K, null rows (if any) use an extra trailing NaN entry
+ int domain = (int) hashEncoder.getK();
+ boolean nulls = a.containsNull();
+ IColIndex colIndexes = ColIndexFactory.create(0, 1);
+ if(domain == 1 && !nulls) // only constant if no null row needs the NaN entry
+ return ColGroupConst.create(colIndexes, new double[] {1});
+
+ MatrixBlock incrementing = new MatrixBlock(domain + (nulls ? 1 : 0), 1, false);
+ for(int i = 0; i < domain; i++)
+ incrementing.quickSetValue(i, 0, i + 1);
+ if(nulls)
+ incrementing.quickSetValue(domain, 0, Double.NaN);
+
+ ADictionary d = MatrixBlockDictionary.create(incrementing);
+
+ AMapToData m = createHashMappingAMapToData(a, domain, nulls);
+ AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null);
+ ret.getNumberNonZeros(a.size());
+ return ret;
+ }
+
+ private AColGroup hashToDummy(ColumnEncoderComposite c) {
+ int colId = c._colID;
+ Array<?> a = in.getColumn(colId - 1);
+ ColumnEncoderFeatureHash hashEncoder = (ColumnEncoderFeatureHash) c.getEncoders().get(0);
+ int domain = (int) hashEncoder.getK();
+ boolean nulls = a.containsNull();
+ IColIndex colIndexes = ColIndexFactory.create(0, domain);
+ if(domain == 1 && !nulls) // only constant if no null row needs the empty identity row
+ return ColGroupConst.create(colIndexes, new double[] {1});
+ ADictionary d = new IdentityDictionary(colIndexes.size(), nulls);
+ AMapToData m = createHashMappingAMapToData(a, domain, nulls);
+ return ColGroupDDC.create(colIndexes, d, m, null);
+ }
+
private class EncodeTask implements Callable<AColGroup> {
ColumnEncoderComposite c;
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java
b/src/test/java/org/apache/sysds/test/TestUtils.java
index 65e89f6e6a..bc5d48b8a2 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -52,6 +52,7 @@ import java.util.StringTokenizer;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
+import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
@@ -68,8 +69,11 @@ import org.apache.sysds.runtime.data.DenseBlockFP64;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.frame.data.columns.Array;
import org.apache.sysds.runtime.frame.data.columns.ArrayFactory;
import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata;
+import org.apache.sysds.runtime.frame.data.columns.OptionalArray;
+import org.apache.sysds.runtime.frame.data.columns.StringArray;
import org.apache.sysds.runtime.frame.data.lib.FrameLibApplySchema;
import org.apache.sysds.runtime.frame.data.lib.FrameUtil;
import org.apache.sysds.runtime.functionobjects.Builtin;
@@ -2268,6 +2272,146 @@ public class TestUtils
return generateRandomFrameBlock(rows, schema, random);
}
+ public static FrameBlock generateRandomFrameBlock(int rows, ValueType[]
schema, long seed, double nullChance){
+ Random random = (seed == -1) ? TestUtils.random : new
Random(seed);
+
+ FrameBlock frameBlock = new FrameBlock();
+ for(int col = 0; col < schema.length; col++){
+ Array<?> column = generateColumn(rows, schema[col],
random, nullChance);
+ frameBlock.appendColumn(column);
+ }
+ return frameBlock;
+ }
+
+ @SuppressWarnings("unchecked")
+ private static Array<?> generateColumn(int rows, ValueType type, Random rand, double nullChance) {
+ if(nullChance == 0) {
+ switch(type) {
+ case BOOLEAN:
+ Array<Boolean> a = (Array<Boolean>) ArrayFactory.allocate(type, rows);
+ for(int r = 0; r < rows; r++)
+ a.set(r, rand.nextBoolean());
+ return a;
+ case CHARACTER:
+ Array<Character> c = (Array<Character>) ArrayFactory.allocate(type, rows);
+ for(int r = 0; r < rows; r++)
+ c.set(r, rand.nextInt(Character.MAX_VALUE));
+ return c;
+ case FP32:
+ Array<Float> f = (Array<Float>) ArrayFactory.allocate(type, rows);
+ for(int r = 0; r < rows; r++)
+ f.set(r, rand.nextFloat());
+ return f;
+ case FP64:
+ Array<Double> d = (Array<Double>) ArrayFactory.allocate(type, rows);
+ for(int r = 0; r < rows; r++)
+ d.set(r, rand.nextDouble());
+ return d;
+ case INT32:
+ case UINT4:
+ case UINT8:
+ Array<Integer> i = (Array<Integer>) ArrayFactory.allocate(type, rows);
+ int limit = type == ValueType.UINT4 ? 16 : type == ValueType.UINT8 ? 256 : Integer.MAX_VALUE;
+ for(int r = 0; r < rows; r++)
+ i.set(r, rand.nextInt(limit));
+ return i;
+ case INT64:
+ Array<Long> l = (Array<Long>) ArrayFactory.allocate(type, rows);
+ for(int r = 0; r < rows; r++)
+ l.set(r, rand.nextLong());
+ return l;
+ case STRING:
+ StringArray s = (StringArray) ArrayFactory.allocate(type, rows);
+ for(int r = 0; r < rows; r++) {
+ String st = rand.ints('a', 'z' + 1).limit(10) // seeded generator keeps columns reproducible
+ .collect(StringBuilder::new, StringBuilder::appendCodePoint, StringBuilder::append).toString();
+ s.set(r, st);
+ }
+ return s;
+ case UNKNOWN:
+ default:
+ throw new NotImplementedException();
+ }
+
+ }
+ else {
+ switch(type) {
+ case BOOLEAN:
+ OptionalArray<Boolean> a = (OptionalArray<Boolean>) ArrayFactory.allocateOptional(type, rows);
+ for(int r = 0; r < rows; r++) {
+ if(rand.nextDouble() < nullChance)
+ a.set(r, (Boolean) null);
+ else
+ a.set(r, rand.nextBoolean());
+ }
+ return a;
+ case CHARACTER:
+ OptionalArray<Character> c = (OptionalArray<Character>) ArrayFactory.allocateOptional(type, rows);
+ for(int r = 0; r < rows; r++) {
+ if(rand.nextDouble() < nullChance)
+ c.set(r, (Character) null);
+ else
+ c.set(r, rand.nextInt(Character.MAX_VALUE));
+ }
+ return c;
+ case FP32:
+ OptionalArray<Float> f = (OptionalArray<Float>) ArrayFactory.allocateOptional(type, rows);
+ for(int r = 0; r < rows; r++) {
+ if(rand.nextDouble() < nullChance)
+ f.set(r, (Float) null);
+ else
+ f.set(r, rand.nextFloat());
+ }
+ return f;
+ case FP64:
+ OptionalArray<Double> d = (OptionalArray<Double>) ArrayFactory.allocateOptional(type, rows);
+ for(int r = 0; r < rows; r++) {
+ if(rand.nextDouble() < nullChance)
+ d.set(r, (Double) null);
+ else
+ d.set(r, rand.nextDouble());
+ }
+ return d;
+ case INT32:
+ case UINT4:
+ case UINT8:
+ Array<Integer> i = (Array<Integer>) ArrayFactory.allocateOptional(type, rows);
+ int limit = type == ValueType.UINT4 ? 16 : type == ValueType.UINT8 ? 256 : Integer.MAX_VALUE;
+ for(int r = 0; r < rows; r++) {
+ if(rand.nextDouble() < nullChance)
+ i.set(r, (Integer) null);
+ else
+ i.set(r, rand.nextInt(limit));
+ }
+ return i;
+ case INT64:
+ OptionalArray<Long> l = (OptionalArray<Long>) ArrayFactory.allocateOptional(type, rows);
+ for(int r = 0; r < rows; r++) {
+ if(rand.nextDouble() < nullChance)
+ l.set(r, (Long) null);
+ else
+ l.set(r, rand.nextLong());
+ }
+ return l;
+ case STRING:
+ StringArray s = (StringArray) ArrayFactory.allocateOptional(type, rows);
+ for(int r = 0; r < rows; r++) {
+ if(rand.nextDouble() < nullChance)
+ s.set(r, (String) null);
+ else {
+ String st = rand.ints('a', 'z' + 1).limit(10) // seeded generator keeps columns reproducible
+ .collect(StringBuilder::new, StringBuilder::appendCodePoint, StringBuilder::append).toString();
+ s.set(r, st);
+ }
+ }
+ return s;
+ case UNKNOWN:
+ default:
+ throw new NotImplementedException();
+ }
+ }
+ }
+
public static FrameBlock generateRandomFrameBlock(int rows, int cols,
long seed){
ValueType[] schema = generateRandomSchema(cols, seed);
return generateRandomFrameBlock(rows, schema ,seed);
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java
b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java
index fae84b39d0..601555576b 100644
---
a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java
+++
b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java
@@ -1408,7 +1408,7 @@ public class FrameArrayTests {
((Array<Character>)
aa).fill((Character) null);
if(!isOptional)
for(int i = 0; i < aa.size();
i++)
- assertEquals(aa.get(i),
(char)0);
+ assertEquals(aa.get(i),
(char) 0);
break;
case FP32:
((Array<Float>) aa).fill((Float) null);
@@ -1565,8 +1565,7 @@ public class FrameArrayTests {
switch(t) {
case STRING:
return
ArrayFactory.create(generateRandomStringOpt(size, seed));
- case BITSET:
- // return
ArrayFactory.create(generateRandomBitSet(size, seed), size);
+ case BITSET:// not a thing
case BOOLEAN:
return
ArrayFactory.create(generateRandomBooleanOpt(size, seed));
case INT32:
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java
b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java
index e945ad5b26..2cea03489d 100644
---
a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java
+++
b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java
@@ -113,6 +113,16 @@ public class TransformCompressedTestMultiCol {
"{ids:true, bin:[{id:1, method:equi-height,
numbins:10},{id:2, method:equi-height, numbins:10},{id:3, method:equi-height,
numbins:40}], dummycode:[1,2,3] }");
}
+ @Test
+ public void testHash(){ // hash-encode columns 1-3 with K:10 (K presumably the hash bucket count - confirm)
+ test("{ids:true, hash:[1,2,3], K:10}");
+ }
+
+ @Test
+ public void testHashToDummy(){ // hash columns 1-3 (K:10), then dummycode hashed columns 1 and 2
+ test("{ids:true, hash:[1,2,3], K:10, dummycode:[1,2]}");
+ }
+
public void test(String spec) {
try {
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java
b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java
index 9855b359fa..a573783f6e 100644
---
a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java
+++
b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java
@@ -56,10 +56,12 @@ public class TransformCompressedTestSingleCol {
try {
FrameBlock data =
TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231);
- data.setSchema(new ValueType[] {ValueType.INT32});
- for(int k : threads) {
+ for(int k : threads)
+ tests.add(new Object[] {data, k});
+
+ data = TestUtils.generateRandomFrameBlock(100, new
ValueType[] {ValueType.UINT4}, 231, 0.2);
+ for(int k : threads)
tests.add(new Object[] {data, k});
- }
}
catch(Exception e) {
e.printStackTrace();
@@ -113,20 +115,32 @@ public class TransformCompressedTestSingleCol {
test("{ids:true, bin:[{id:1, method:equi-height, numbins:10}],
dummycode:[1] }");
}
+ @Test
+ public void testHash() { // hash-encode column 1 with K:10; test() compares normal vs. compressed encode output
+ test("{ids:true, hash:[1], K:10}");
+ }
+
+ @Test
+ public void testHashToDummy() { // hash column 1 (K:10) then dummycode it; test() compares normal vs. compressed encode
+ test("{ids:true, hash:[1], K:10, dummycode:[1]}");
+ }
+
public void test(String spec) {
try {
FrameBlock meta = null;
- MultiColumnEncoder encoderCompressed =
EncoderFactory.createEncoder(spec, data.getColumnNames(),
- data.getNumColumns(), meta);
-
- MatrixBlock outCompressed =
encoderCompressed.encode(data, k, true);
- FrameBlock outCompressedMD =
encoderCompressed.getMetaData(null);
MultiColumnEncoder encoderNormal =
EncoderFactory.createEncoder(spec, data.getColumnNames(),
data.getNumColumns(), meta);
MatrixBlock outNormal = encoderNormal.encode(data, k);
FrameBlock outNormalMD =
encoderNormal.getMetaData(null);
+ MultiColumnEncoder encoderCompressed =
EncoderFactory.createEncoder(spec, data.getColumnNames(),
+ data.getNumColumns(), meta);
+ MatrixBlock outCompressed =
encoderCompressed.encode(data, k, true);
+ FrameBlock outCompressedMD =
encoderCompressed.getMetaData(null);
+ // LOG.error(data.slice(0,10));
+ // LOG.error(outNormal.slice(0,10));
+ // LOG.error(outCompressed.slice(0,10));
TestUtils.compareMatrices(outNormal, outCompressed, 0,
"Not Equal after apply");
TestUtils.compareFrames(outNormalMD, outCompressedMD,
true);
}