This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new b3aac0d95b [SYSTEMDS-3592] Frame Compress Sample based
b3aac0d95b is described below
commit b3aac0d95b9e624c0122a69441f9d7c4e02d0296
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Fri Jan 5 12:48:09 2024 +0100
[SYSTEMDS-3592] Frame Compress Sample based
This commit changes the frame compression to be sample based;
it also changes the detect schema back to be sample based.
Closes #1970
---
.../runtime/frame/data/columns/ABooleanArray.java | 18 +++
.../sysds/runtime/frame/data/columns/Array.java | 124 ++++++++++++++--
.../runtime/frame/data/columns/ArrayFactory.java | 9 +-
.../runtime/frame/data/columns/BitSetArray.java | 4 +-
.../runtime/frame/data/columns/BooleanArray.java | 8 +-
.../runtime/frame/data/columns/CharArray.java | 6 +-
.../sysds/runtime/frame/data/columns/DDCArray.java | 165 ++++++++++++++++-----
.../runtime/frame/data/columns/DoubleArray.java | 10 +-
.../runtime/frame/data/columns/FloatArray.java | 21 ++-
.../runtime/frame/data/columns/HashLongArray.java | 53 ++++++-
.../runtime/frame/data/columns/IntegerArray.java | 8 +-
.../runtime/frame/data/columns/LongArray.java | 8 +-
.../runtime/frame/data/columns/OptionalArray.java | 95 +++++++++++-
.../runtime/frame/data/columns/RaggedArray.java | 4 +-
.../runtime/frame/data/columns/StringArray.java | 87 +++++------
.../data/compress/ArrayCompressionStatistics.java | 12 +-
.../data/compress/CompressedFrameBlockFactory.java | 28 ++--
.../frame/data/lib/FrameLibApplySchema.java | 14 +-
.../frame/data/lib/FrameLibDetectSchema.java | 25 +++-
.../sysds/runtime/frame/data/lib/FrameUtil.java | 4 +-
.../component/frame/FrameSerializationTest.java | 5 +
.../sysds/test/component/frame/FrameUtilTest.java | 92 ++++++++----
.../component/frame/array/CustomArrayTests.java | 20 ++-
.../component/frame/array/FrameArrayTests.java | 7 +-
.../frame/compress/FrameCompressTest.java | 17 +++
.../frame/compress/FrameCompressTestUtils.java | 8 +-
26 files changed, 663 insertions(+), 189 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java
index 206a0722d7..6d2f28d3dd 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java
@@ -19,6 +19,9 @@
package org.apache.sysds.runtime.frame.data.columns;
+import java.util.HashMap;
+import java.util.Map;
+
public abstract class ABooleanArray extends Array<Boolean> {
public ABooleanArray(int size) {
@@ -43,4 +46,19 @@ public abstract class ABooleanArray extends Array<Boolean> {
public boolean possiblyContainsNaN(){
return false;
}
+
+ @Override
+ protected Map<Boolean, Long> createRecodeMap() {
+ Map<Boolean, Long> map = new HashMap<>();
+ long id = 1;
+ for(int i = 0; i < size() && id <= 2; i++) {
+ Boolean val = get(i);
+ if(val != null) {
+ Long v = map.putIfAbsent(val, id);
+ if(v == null)
+ id++;
+ }
+ }
+ return map;
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
index 11accc814b..d2021872ba 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
@@ -31,6 +31,8 @@ import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.Writable;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
+import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.estim.sample.SampleEstimatorFactory;
import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType;
import org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics;
@@ -79,7 +81,8 @@ public abstract class Array<T> implements Writable {
/**
* Get a recode map that maps each unique value in the array, to a long
ID. Null values are ignored, and not included
- * in the mapping. The resulting recode map in stored in a soft
reference to speed up repeated calls to the same column.
+ * in the mapping. The resulting recode map in stored in a soft
reference to speed up repeated calls to the same
+ * column.
*
* @return A recode map
*/
@@ -128,7 +131,8 @@ public abstract class Array<T> implements Writable {
protected Map<T, Integer> getDictionary() {
final Map<T, Integer> dict = new HashMap<>();
Integer id = 0;
- for(int i = 0; i < size(); i++) {
+ final int s = size();
+ for(int i = 0; i < s; i++) {
final T val = get(i);
final Integer v = dict.get(val);
if(v == null)
@@ -138,6 +142,30 @@ public abstract class Array<T> implements Writable {
return dict;
}
+ /**
+ * Get the dictionary of contained values, including null with
threshold.
+ *
+ * If the number of distinct values are found to be above the threshold
value, then abort constructing the
+ * dictionary.
+ *
+ * @return a dictionary containing all unique values or null if
threshold of distinct is exceeded.
+ */
+ protected Map<T, Integer> tryGetDictionary(int threshold) {
+ final Map<T, Integer> dict = new HashMap<>();
+ Integer id = 0;
+ final int s = size();
+ for(int i = 0; i < s && id < threshold; i++) {
+ final T val = get(i);
+ final Integer v = dict.get(val);
+ if(v == null)
+ dict.put(val, id++);
+ }
+ if(id >= threshold)
+ return null;
+ else
+ return dict;
+ }
+
/**
* Get the number of elements in the array, this does not necessarily
reflect the current allocated size.
*
@@ -233,7 +261,7 @@ public abstract class Array<T> implements Writable {
* @param ru row upper (inclusive)
* @param value value array to take values from (same type)
*/
- public void set(int rl, int ru, Array<T> value){
+ public void set(int rl, int ru, Array<T> value) {
for(int i = rl; i <= ru; i++)
set(i, value.get(i));
}
@@ -246,7 +274,7 @@ public abstract class Array<T> implements Writable {
* @param value value array to take values from
* @param rlSrc the offset into the value array to take values from
*/
- public void set(int rl, int ru, Array<T> value, int rlSrc){
+ public void set(int rl, int ru, Array<T> value, int rlSrc) {
for(int i = rl, off = rlSrc; i <= ru; i++, off++)
set(i, value.get(off));
}
@@ -354,7 +382,18 @@ public abstract class Array<T> implements Writable {
*
* @return A better or equivalent value type to represent the column,
including null information.
*/
- public abstract Pair<ValueType, Boolean> analyzeValueType();
+ public final Pair<ValueType, Boolean> analyzeValueType() {
+ return analyzeValueType(size());
+ }
+
+ /**
+ * Analyze the column to figure out if the value type can be refined to
a better type. The return is in two parts,
+ * first the type it can be, second if it contains nulls.
+ *
+ * @param maxCells maximum number of cells to analyze
+ * @return A better or equivalent value type to represent the column,
including null information.
+ */
+ public abstract Pair<ValueType, Boolean> analyzeValueType(int maxCells);
/**
* Get the internal FrameArrayType, to specify the encoding of the
Types, note there are more Frame Array Types than
@@ -405,7 +444,22 @@ public abstract class Array<T> implements Writable {
public abstract boolean possiblyContainsNaN();
+ public Array<?> safeChangeType(ValueType t, boolean containsNull){
+ try{
+ return changeType(t, containsNull);
+ }
+ catch(Exception e){
+ Pair<ValueType, Boolean> ct = analyzeValueType(); //
full analysis
+ return changeType(ct.getKey(), ct.getValue());
+ }
+ }
+
+ public Array<?> changeType(ValueType t, boolean containsNull) {
+ return containsNull ? changeTypeWithNulls(t) : changeType(t);
+ }
+
public Array<?> changeTypeWithNulls(ValueType t) {
+
final ABooleanArray nulls = getNulls();
if(nulls == null)
return changeType(t);
@@ -520,7 +574,7 @@ public abstract class Array<T> implements Writable {
/**
* Change type to a Hash46 array type
*
- * @return A Hash64 array
+ * @return A Hash64 array
*/
protected abstract Array<Object> changeTypeHash64();
@@ -653,6 +707,12 @@ public abstract class Array<T> implements Writable {
}
+ public double[] extractDouble(double[] ret, int rl, int ru) {
+ for(int i = rl; i < ru; i++)
+ ret[i - rl] = getAsDouble(i);
+ return ret;
+ }
+
public abstract boolean equals(Array<T> other);
public ArrayCompressionStatistics statistics(int nSamples) {
@@ -666,6 +726,12 @@ public abstract class Array<T> implements Writable {
else
d.put(key, 1);
}
+ Pair<ValueType, Boolean> vt = analyzeValueType(nSamples);
+ if(vt.getKey() == ValueType.UNKNOWN)
+ vt = analyzeValueType();
+
+ if(vt.getKey() == ValueType.UNKNOWN)
+ vt = new Pair<>(ValueType.STRING, false);
final int[] freq = new int[d.size()];
int id = 0;
@@ -673,18 +739,56 @@ public abstract class Array<T> implements Writable {
freq[id++] = e.getValue();
int estDistinct = SampleEstimatorFactory.distinctCount(freq,
size(), nSamples);
- long memSize = getInMemorySize(); // uncompressed size
- int memSizePerElement = (int) ((memSize * 8L) / size());
+
+ // memory size is different depending on valuetype.
+ long memSize = vt.getKey() != getValueType() ? //
+ ArrayFactory.getInMemorySize(vt.getKey(), size(),
containsNull()) : //
+ getInMemorySize(); // uncompressed size
+
+ int memSizePerElement;
+ switch(vt.getKey()) {
+ case UINT4:
+ case UINT8:
+ case INT32:
+ case FP32:
+ memSizePerElement = 4;
+ break;
+ case INT64:
+ case FP64:
+ case HASH64:
+ memSizePerElement = 8;
+ break;
+ case CHARACTER:
+ memSizePerElement = 2;
+ break;
+ case BOOLEAN:
+ memSizePerElement = 1;
+ case UNKNOWN:
+ case STRING:
+ default:
+ memSizePerElement = (int) (memSize / size());
+ }
long ddcSize = DDCArray.estimateInMemorySize(memSizePerElement,
estDistinct, size());
if(ddcSize < memSize)
return new
ArrayCompressionStatistics(memSizePerElement, //
- estDistinct, true, getValueType(),
FrameArrayType.DDC, memSize, ddcSize);
-
+ estDistinct, true, vt.getKey(), vt.getValue(),
FrameArrayType.DDC, getInMemorySize(), ddcSize);
+ else if(vt.getKey() != getValueType() )
+ return new
ArrayCompressionStatistics(memSizePerElement, //
+ estDistinct, true, vt.getKey(), vt.getValue(),
null, getInMemorySize(), memSize);
return null;
}
+ public AMapToData createMapping(Map<T, Integer> d) {
+ final int s = size();
+ final AMapToData m = MapToFactory.create(s, d.size());
+
+ for(int i = 0; i < s; i++)
+ m.set(i, d.get(get(i)));
+ return m;
+ }
+
public class ArrayIterator implements Iterator<T> {
int index = -1;
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
index ad2a3c2d57..4ea341313f 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
@@ -28,6 +28,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.utils.MemoryEstimates;
public interface ArrayFactory {
@@ -309,8 +310,12 @@ public interface ArrayFactory {
if(src.getFrameArrayType() == FrameArrayType.OPTIONAL)
target = allocateOptional(src.getValueType(),
rlen);
else if(src.getFrameArrayType() == FrameArrayType.DDC) {
- Array<?> ddcDict = ((DDCArray<?>)
src).getDict();
- if(ddcDict.getFrameArrayType() ==
FrameArrayType.OPTIONAL) {
+ final DDCArray<?> ddcA = ((DDCArray<?>) src);
+ final Array<?> ddcDict = ddcA.getDict();
+ if(ddcDict == null){ // read empty dict.
+ target = new DDCArray<>(null,
MapToFactory.create(rlen, ddcA.getMap().getUnique()));
+ }
+ else if(ddcDict.getFrameArrayType() ==
FrameArrayType.OPTIONAL) {
target =
allocateOptional(src.getValueType(), rlen);
}
else {
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java
index 710d8a8deb..cd23ce60b6 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java
@@ -390,7 +390,7 @@ public class BitSetArray extends ABooleanArray {
}
@Override
- public Pair<ValueType, Boolean> analyzeValueType() {
+ public Pair<ValueType, Boolean> analyzeValueType(int maxCells) {
return new Pair<>(ValueType.BOOLEAN, false);
}
@@ -512,7 +512,7 @@ public class BitSetArray extends ABooleanArray {
@Override
public boolean isEmpty() {
- for(int i = 0; i < _data.length; i++)
+ for(int i = 0; i < _size / 64 + 1; i++)
if(_data[i] != 0L)
return false;
return true;
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java
index b44845bc34..ae0307ba41 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java
@@ -198,7 +198,7 @@ public class BooleanArray extends ABooleanArray {
}
@Override
- public Pair<ValueType, Boolean> analyzeValueType() {
+ public Pair<ValueType, Boolean> analyzeValueType(int maxCells) {
return new Pair<>(ValueType.BOOLEAN, false);
}
@@ -312,7 +312,7 @@ public class BooleanArray extends ABooleanArray {
@Override
public boolean isEmpty() {
- for(int i = 0; i < _data.length; i++)
+ for(int i = 0; i < _size; i++)
if(_data[i])
return false;
return true;
@@ -320,7 +320,7 @@ public class BooleanArray extends ABooleanArray {
@Override
public boolean isAllTrue() {
- for(int i = 0; i < _data.length; i++)
+ for(int i = 0; i < _size; i++)
if(!_data[i])
return false;
return true;
@@ -375,7 +375,7 @@ public class BooleanArray extends ABooleanArray {
@Override
public String toString() {
- StringBuilder sb = new StringBuilder(_data.length * 2 + 10);
+ StringBuilder sb = new StringBuilder(_size * 2 + 10);
sb.append(super.toString() + ":[");
for(int i = 0; i < _size - 1; i++)
sb.append((_data[i] ? 1 : 0) + ",");
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java
index 14fcfd9f69..f597b8ec62 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java
@@ -181,7 +181,7 @@ public class CharArray extends Array<Character> {
}
@Override
- public Pair<ValueType, Boolean> analyzeValueType() {
+ public Pair<ValueType, Boolean> analyzeValueType(int maxCells) {
return new Pair<>(ValueType.CHARACTER, false);
}
@@ -308,7 +308,7 @@ public class CharArray extends Array<Character> {
@Override
public boolean isEmpty() {
- for(int i = 0; i < _data.length; i++)
+ for(int i = 0; i < _size; i++)
if(_data[i] != 0)
return false;
return true;
@@ -357,7 +357,7 @@ public class CharArray extends Array<Character> {
@Override
public String toString() {
- StringBuilder sb = new StringBuilder(_data.length * 2 + 15);
+ StringBuilder sb = new StringBuilder(_size * 2 + 15);
sb.append(super.toString());
sb.append(":[");
for(int i = 0; i < _size - 1; i++) {
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java
index 8f3dcd9dcb..3b7200c7be 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java
@@ -55,14 +55,53 @@ public class DDCArray<T> extends ACompressedArray<T> {
}
}
- public Array<T> getDict(){
+ public Array<T> getDict() {
return dict;
}
- public AMapToData getMap(){
+ public AMapToData getMap() {
return map;
}
+ public <J> DDCArray<J> setDict(Array<J> dict) {
+ return new DDCArray<J>(dict, map);
+ }
+
+ public DDCArray<T> nullDict() {
+ return new DDCArray<T>(null, map);
+ }
+
+ private static int getTryThreshold(ValueType t, int allRows, long
inMemSize) {
+ switch(t) {
+ case BOOLEAN:
+ return 1; // booleans do not compress well
unless all constant.
+ case UINT4:
+ case UINT8:
+ return 2;
+ case CHARACTER:
+ return 256;
+ case FP32:
+ case INT32:
+ return 65536; // char distinct
+ case HASH64:
+ case FP64:
+ case INT64:
+ case STRING:
+ case UNKNOWN:
+ default:
+ long MapSize =
MapToFactory.estimateInMemorySize(allRows, allRows);
+ int i = 2;
+
+ while(allRows/i >= 1 && inMemSize - MapSize <
ArrayFactory.getInMemorySize(t, allRows/i, false)){
+ i = i*2;
+ }
+
+ int d = Math.max(0, allRows/i);
+ return d;
+
+ }
+ }
+
/**
* Try to compress array into DDC format.
*
@@ -78,10 +117,13 @@ public class DDCArray<T> extends ACompressedArray<T> {
// or if the instance if RaggedArray where all values typically
are unique.
if(s <= 10 || arr instanceof RaggedArray)
return arr;
+ final int t = getTryThreshold(arr.getValueType(), s,
arr.getInMemorySize());
// Two pass algorithm
// 1.full iteration: Get unique
- Map<T, Integer> rcd = arr.getDictionary();
+ Map<T, Integer> rcd = arr.tryGetDictionary(t);
+ if(rcd == null)
+ return arr;
// Abort if there are to many unique values.
if(rcd.size() > s / 2)
@@ -99,18 +141,49 @@ public class DDCArray<T> extends ACompressedArray<T> {
ar.set(e.getValue(), e.getKey());
// 2. full iteration: Make map
- final AMapToData m = MapToFactory.create(s, rcd.size());
- for(int i = 0; i < s; i++)
- m.set(i, rcd.get(arr.get(i)));
+ final AMapToData m = arr.createMapping(rcd);
return new DDCArray<>(ar, m);
}
+ @Override
+ protected Map<T, Long> createRecodeMap() {
+ return dict.createRecodeMap();
+ }
+
+ /**
+ * compress and change value.
+ *
+ * @param <T> The type of the array.
+ * @param arr The array to compress
+ * @param vt The value type to target
+ * @param containsNull If the array contains null.
+ * @return a compressed column group.
+ */
+ public static <T> Array<?> compressToDDC(Array<T> arr, ValueType vt,
boolean containsNull) {
+ Array<?> arrT;
+ try {
+ arrT = containsNull ? arr.changeTypeWithNulls(vt) :
arr.changeType(vt);
+ }
+ catch(Exception e) {
+ // fall back to full analysis.
+ Pair<ValueType, Boolean> ct = arr.analyzeValueType();
+ arrT = ct.getValue() ?
arr.changeTypeWithNulls(ct.getKey()) : arr.changeType(ct.getKey());
+ }
+
+ return compressToDDC(arrT);
+ }
+
@Override
public void write(DataOutput out) throws IOException {
out.writeByte(FrameArrayType.DDC.ordinal());
map.write(out);
- dict.write(out);
+ if(dict == null)
+ out.writeBoolean(false);
+ else{
+ out.writeBoolean(true);
+ dict.write(out);
+ }
}
@Override
@@ -121,25 +194,30 @@ public class DDCArray<T> extends ACompressedArray<T> {
@SuppressWarnings("unchecked")
public static DDCArray<?> read(DataInput in) throws IOException {
AMapToData map = MapToFactory.readIn(in);
- Array<?> dict = ArrayFactory.read(in, map.getUnique());
- switch(dict.getValueType()) {
- case BOOLEAN:
- // Interesting case, that does not make much
sense.
- return new DDCArray<>((Array<Boolean>) dict,
map);
- case FP32:
- return new DDCArray<>((Array<Float>) dict, map);
- case FP64:
- return new DDCArray<>((Array<Double>) dict,
map);
- case UINT8:
- case INT32:
- return new DDCArray<>((Array<Integer>) dict,
map);
- case INT64:
- return new DDCArray<>((Array<Long>) dict, map);
- case CHARACTER:
- return new DDCArray<>((Array<Character>) dict,
map);
- case STRING:
- default:
- return new DDCArray<>((Array<String>) dict,
map);
+ if(in.readBoolean()) {
+ Array<?> dict = ArrayFactory.read(in, map.getUnique());
+ switch(dict.getValueType()) {
+ case BOOLEAN:
+ // Interesting case, that does not make
much sense.
+ return new DDCArray<>((Array<Boolean>)
dict, map);
+ case FP32:
+ return new DDCArray<>((Array<Float>)
dict, map);
+ case FP64:
+ return new DDCArray<>((Array<Double>)
dict, map);
+ case UINT8:
+ case INT32:
+ return new DDCArray<>((Array<Integer>)
dict, map);
+ case INT64:
+ return new DDCArray<>((Array<Long>)
dict, map);
+ case CHARACTER:
+ return new
DDCArray<>((Array<Character>) dict, map);
+ case STRING:
+ default:
+ return new DDCArray<>((Array<String>)
dict, map);
+ }
+ }
+ else {
+ return new DDCArray<>((Array<String>) null, map);
}
}
@@ -148,6 +226,14 @@ public class DDCArray<T> extends ACompressedArray<T> {
return dict.get(map.getIndex(index));
}
+ @Override
+ public double[] extractDouble(double[] ret, int rl, int ru) {
+ // overridden to allow GIT compile
+ for(int i = rl; i < ru; i++)
+ ret[i - rl] = getAsDouble(i);
+ return ret;
+ }
+
@Override
public double getAsDouble(int i) {
return dict.getAsDouble(map.getIndex(i));
@@ -176,22 +262,24 @@ public class DDCArray<T> extends ACompressedArray<T> {
@Override
public ValueType getValueType() {
- return dict.getValueType();
+ return dict == null ? ValueType.STRING : dict.getValueType();
}
@Override
- public Pair<ValueType, Boolean> analyzeValueType() {
- return dict.analyzeValueType();
+ public Pair<ValueType, Boolean> analyzeValueType(int maxCells) {
+ return dict.analyzeValueType(maxCells);
}
@Override
protected void set(int rl, int ru, DDCArray<T> value) {
- if(value.dict.size() != dict.size() || (FrameBlock.debug &&
!value.dict.equals(dict)))
+ if((dict != null && value.dict != null)
+ &&( value.dict.size() != dict.size() //
+ || (FrameBlock.debug && !value.dict.equals(dict))))
throw new DMLCompressionException("Invalid setting of
DDC Array, of incompatible instance.");
final AMapToData tm = value.map;
for(int i = rl; i <= ru; i++) {
- map.set(i, tm.getIndex(i-rl));
+ map.set(i, tm.getIndex(i - rl));
}
}
@@ -202,7 +290,7 @@ public class DDCArray<T> extends ACompressedArray<T> {
@Override
public long getExactSerializedSize() {
- return 1L + map.getExactSizeOnDisk() +
dict.getExactSerializedSize();
+ return 1L +1L+ map.getExactSizeOnDisk() +
dict.getExactSerializedSize();
}
@Override
@@ -236,7 +324,7 @@ public class DDCArray<T> extends ACompressedArray<T> {
}
@Override
- protected Array<Object> changeTypeHash64(){
+ protected Array<Object> changeTypeHash64() {
return new DDCArray<>(dict.changeTypeHash64(), map);
}
@@ -250,6 +338,12 @@ public class DDCArray<T> extends ACompressedArray<T> {
return new DDCArray<>(dict.changeTypeCharacter(), map);
}
+ @Override
+ public Array<?> changeTypeWithNulls(ValueType t) {
+ Array<?> d2 = dict.changeTypeWithNulls(t);
+ return new DDCArray<>(d2, map);
+ }
+
@Override
public boolean isShallowSerialize() {
return true; // Always the case if we use this compression
scheme.
@@ -306,7 +400,7 @@ public class DDCArray<T> extends ACompressedArray<T> {
}
public static long estimateInMemorySize(int memSizeBitPerElement, int
estDistinct, int nRow) {
- return (estDistinct * memSizeBitPerElement) / 8 +
MapToFactory.estimateInMemorySize(nRow, estDistinct);
+ return (long)estDistinct * memSizeBitPerElement +
MapToFactory.estimateInMemorySize(nRow, estDistinct);
}
protected DDCArray<T> allocateLarger(int nRow) {
@@ -330,11 +424,10 @@ public class DDCArray<T> extends ACompressedArray<T> {
}
@Override
- public boolean possiblyContainsNaN(){
+ public boolean possiblyContainsNaN() {
return dict.possiblyContainsNaN();
}
-
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java
index e4e1a76b6a..68672c5d73 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java
@@ -184,9 +184,9 @@ public class DoubleArray extends Array<Double> {
}
@Override
- public Pair<ValueType, Boolean> analyzeValueType() {
+ public Pair<ValueType, Boolean> analyzeValueType(int maxCells) {
ValueType state = FrameUtil.isType(_data[0]);
- for(int i = 0; i < _size; i++) {
+ for(int i = 0; i < Math.min(maxCells,_size); i++) {
ValueType c = FrameUtil.isType(_data[i], state);
if(state == ValueType.FP64)
return new Pair<>(ValueType.FP64, false);
@@ -250,7 +250,7 @@ public class DoubleArray extends Array<Double> {
@Override
public long getExactSerializedSize() {
- return 1 + 8 * _data.length;
+ return 1 + 8 * _size;
}
@Override
@@ -379,7 +379,7 @@ public class DoubleArray extends Array<Double> {
@Override
public boolean isEmpty() {
- for(int i = 0; i < _data.length; i++)
+ for(int i = 0; i < _size; i++)
if(isNotEmpty(i))
return false;
return true;
@@ -428,7 +428,7 @@ public class DoubleArray extends Array<Double> {
@Override
public String toString() {
- StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
+ StringBuilder sb = new StringBuilder(_size * 5 + 2);
sb.append(super.toString() + ":[");
for(int i = 0; i < _size - 1; i++)
sb.append(_data[i] + ",");
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
index 47627894d9..03709fd14a 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
@@ -34,6 +34,8 @@ import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.MemoryEstimates;
+import ch.randelshofer.fastdoubleparser.JavaFloatParser;
+
public class FloatArray extends Array<Float> {
private float[] _data;
@@ -181,7 +183,7 @@ public class FloatArray extends Array<Float> {
}
@Override
- public Pair<ValueType, Boolean> analyzeValueType() {
+ public Pair<ValueType, Boolean> analyzeValueType(int maxCells) {
return new Pair<>(ValueType.FP32, false);
}
@@ -199,7 +201,7 @@ public class FloatArray extends Array<Float> {
@Override
public long getExactSerializedSize() {
- return 1 + 4 * _data.length;
+ return 1 + 4 * _size;
}
@Override
@@ -299,13 +301,16 @@ public class FloatArray extends Array<Float> {
}
public static float parseFloat(String value) {
+ if(value == null)
+ return 0.0f;
+
+ final int len = value.length();
+ if(len == 0)
+ return 0.0f;
try {
- if(value == null || value.isEmpty())
- return 0.0f;
- return Float.parseFloat(value);
+ return JavaFloatParser.parseFloat(value, 0, len);
}
catch(NumberFormatException e) {
- final int len = value.length();
// check for common extra cases.
if(len == 3 && value.compareToIgnoreCase("Inf") == 0)
return Float.POSITIVE_INFINITY;
@@ -322,7 +327,7 @@ public class FloatArray extends Array<Float> {
@Override
public boolean isEmpty() {
- for(int i = 0; i < _data.length; i++)
+ for(int i = 0; i < _size; i++)
if(isNotEmpty(i))
return false;
return true;
@@ -371,7 +376,7 @@ public class FloatArray extends Array<Float> {
@Override
public String toString() {
- StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
+ StringBuilder sb = new StringBuilder(_size * 5 + 2);
sb.append(super.toString() + ":[");
for(int i = 0; i < _size - 1; i++)
sb.append(_data[i] + ",");
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java
index 39d326e0bc..459164b21b 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java
@@ -24,10 +24,14 @@ import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.BitSet;
+import java.util.HashMap;
+import java.util.Map;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
+import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.utils.MemoryEstimates;
@@ -212,8 +216,38 @@ public class HashLongArray extends Array<Object> {
return ValueType.HASH64;
}
+ @Override
+ protected Map<Object, Integer> getDictionary() {
+ final Map<Object, Integer> dict = new HashMap<>();
+ Integer id = 0;
+ for(int i = 0; i < size(); i++) {
+ final Integer v = dict.get(_data[i]);
+ if(v == null)
+ dict.put(_data[i], id++);
+ }
+
+ return dict;
+ }
+
+ @Override
+ protected Map<Object, Integer> tryGetDictionary(int threshold) {
+ final Map<Object, Integer> dict = new HashMap<>();
+ Integer id = 0;
+ final int s = size();
+ for(int i = 0; i < s && id < threshold; i++) {
+ final Integer v = dict.get(_data[i]);
+ if(v == null)
+ dict.put(_data[i], id++);
+ }
+
+ if (id >= threshold)
+ return null;
+ else
+ return dict;
+ }
+
@Override
- public Pair<ValueType, Boolean> analyzeValueType() {
+ public Pair<ValueType, Boolean> analyzeValueType(int maxCells) {
return new Pair<>(ValueType.HASH64, false);
}
@@ -231,7 +265,7 @@ public class HashLongArray extends Array<Object> {
@Override
public long getExactSerializedSize() {
- return 1 + 8 * _data.length;
+ return 1 + 8 * _size;
}
@Override
@@ -354,7 +388,7 @@ public class HashLongArray extends Array<Object> {
@Override
public boolean isEmpty() {
- for(int i = 0; i < _data.length; i++)
+ for(int i = 0; i < _size; i++)
if(_data[i] != 0L)
return false;
return true;
@@ -401,9 +435,20 @@ public class HashLongArray extends Array<Object> {
return false;
}
+ @Override
+ public AMapToData createMapping(Map<Object, Integer> d) {
+ // assuming the dictionary is correctly constructed.
+ final int s = size();
+ final AMapToData m = MapToFactory.create(s, d.size());
+
+ for(int i = 0; i < s; i++)
+ m.set(i, d.get(_data[i]));
+ return m;
+ }
+
@Override
public String toString() {
- StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
+ StringBuilder sb = new StringBuilder(_size * 5 + 2);
sb.append(super.toString() + ":[");
for(int i = 0; i < _size - 1; i++)
sb.append(_data[i] + ",");
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java
index 4a180e264c..a07e499f9e 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java
@@ -181,7 +181,7 @@ public class IntegerArray extends Array<Integer> {
}
@Override
- public Pair<ValueType, Boolean> analyzeValueType() {
+ public Pair<ValueType, Boolean> analyzeValueType(int maxCells) {
return new Pair<>(ValueType.INT32, false);
}
@@ -199,7 +199,7 @@ public class IntegerArray extends Array<Integer> {
@Override
public long getExactSerializedSize() {
- return 1 + 4 * _data.length;
+ return 1 + 4 * _size;
}
@Override
@@ -317,7 +317,7 @@ public class IntegerArray extends Array<Integer> {
@Override
public boolean isEmpty() {
- for(int i = 0; i < _data.length; i++)
+ for(int i = 0; i < _size; i++)
if(_data[i] != 0)
return false;
return true;
@@ -366,7 +366,7 @@ public class IntegerArray extends Array<Integer> {
@Override
public String toString() {
- StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
+ StringBuilder sb = new StringBuilder(_size * 5 + 2);
sb.append(super.toString() + ":[");
for(int i = 0; i < _size - 1; i++)
sb.append(_data[i] + ",");
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
index 4d90190f67..ddf724ecf8 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
@@ -181,7 +181,7 @@ public class LongArray extends Array<Long> {
}
@Override
- public Pair<ValueType, Boolean> analyzeValueType() {
+ public Pair<ValueType, Boolean> analyzeValueType(int maxCells) {
return new Pair<>(ValueType.INT64, false);
}
@@ -199,7 +199,7 @@ public class LongArray extends Array<Long> {
@Override
public long getExactSerializedSize() {
- return 1 + 8 * _data.length;
+ return 1 + 8 * _size;
}
@Override
@@ -316,7 +316,7 @@ public class LongArray extends Array<Long> {
@Override
public boolean isEmpty() {
- for(int i = 0; i < _data.length; i++)
+ for(int i = 0; i < _size; i++)
if(_data[i] != 0L)
return false;
return true;
@@ -365,7 +365,7 @@ public class LongArray extends Array<Long> {
@Override
public String toString() {
- StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
+ StringBuilder sb = new StringBuilder(_size * 5 + 2);
sb.append(super.toString() + ":[");
for(int i = 0; i < _size - 1; i++)
sb.append(_data[i] + ",");
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
index 6699f1050a..f653b7f321 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
@@ -22,10 +22,14 @@ package org.apache.sysds.runtime.frame.data.columns;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
+import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -64,7 +68,7 @@ public class OptionalArray<T> extends Array<T> {
}
@SuppressWarnings("unchecked")
- public OptionalArray(T[] a, ValueType vt){
+ public OptionalArray(T[] a, ValueType vt) {
super(a.length);
_a = (Array<T>) ArrayFactory.allocate(vt, a.length);
_n = ArrayFactory.allocateBoolean(a.length);
@@ -308,7 +312,7 @@ public class OptionalArray<T> extends Array<T> {
}
@Override
- public Pair<ValueType, Boolean> analyzeValueType() {
+ public Pair<ValueType, Boolean> analyzeValueType(int maxCells) {
return new Pair<>(getValueType(), true);
}
@@ -474,10 +478,95 @@ public class OptionalArray<T> extends Array<T> {
}
@Override
- public boolean possiblyContainsNaN(){
+ public boolean possiblyContainsNaN() {
return true;
}
+ @SuppressWarnings("unchecked")
+ @Override
+ public AMapToData createMapping(Map<T, Integer> d) {
+ if(_a instanceof HashLongArray) {
+ Map<Long, Integer> dl = (Map<Long, Integer>) d;
+ HashLongArray ha = (HashLongArray) _a;
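+ // assumes d maps every non-null value of this array and also null: dl.get(null) below is unboxed unconditionally, so a dictionary without a null key fails with a NullPointerException.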
+ final int s = size();
+ final AMapToData m = MapToFactory.create(s, d.size());
+
+ final int n = dl.get(null);
+ for(int i = 0; i < s; i++) {
+ if(_n.get(i)) {
+ m.set(i, dl.get(ha.getLong(i)));
+ }
+ else {
+ m.set(i, n);
+ }
+ }
+ return m;
+ }
+ else {
+ return super.createMapping(d);
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ protected Map<T, Integer> getDictionary() {
+ if(_a instanceof HashLongArray) {
+ final Map<Long, Integer> dict = new HashMap<>();
+ HashLongArray ha = (HashLongArray) _a;
+ Integer id = 0;
+ boolean nullFound = false;
+ for(int i = 0; i < size(); i++) {
+ if(_n.get(i)) {
+ final long l = ha.getLong(i);
+ final Integer v =
dict.get(ha.getLong(i));
+ if(v == null)
+ dict.put(l, id++);
+ }
+ else if(!nullFound &&
!dict.keySet().contains(null)) {
+ dict.put(null, id++);
+ nullFound = true;
+ }
+ }
+ return (Map<T, Integer>) dict;
+ }
+ else {
+ return super.getDictionary();
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ protected Map<T, Integer> tryGetDictionary(int threshold) {
+ if(_a instanceof HashLongArray) {
+ final Map<Long, Integer> dict = new HashMap<>();
+ HashLongArray ha = (HashLongArray) _a;
+ Integer id = 0;
+ boolean nullFound = false;
+ final int s = size();
+ for(int i = 0; i < s && id < threshold; i++) {
+ if(_n.get(i)) {
+ final long l = ha.getLong(i);
+ final Integer v =
dict.get(ha.getLong(i));
+ if(v == null)
+ dict.put(l, id++);
+ }
+ else if(!nullFound &&
!dict.keySet().contains(null)) {
+ dict.put(null, id++);
+ nullFound = true;
+ }
+ }
+ if(id >= threshold)
+ return null;
+
+ else
+ return (Map<T, Integer>) dict;
+ }
+ else {
+ return super.tryGetDictionary(threshold);
+ }
+ }
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder(_size + 2);
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java
index 94a30f4980..b97ee68d55 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java
@@ -244,8 +244,8 @@ public class RaggedArray<T> extends Array<T> {
}
@Override
- public Pair<ValueType, Boolean> analyzeValueType() {
- return _a.analyzeValueType();
+ public Pair<ValueType, Boolean> analyzeValueType(int maxCells) {
+ return _a.analyzeValueType(maxCells);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
index 03c2c7cc82..46a9050538 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
@@ -147,20 +147,7 @@ public class StringArray extends Array<String> {
out.writeByte(FrameArrayType.STRING.ordinal());
out.writeLong(getInMemorySize());
- // final Charset cs = Charset.defaultCharset();
for(int i = 0; i < _size; i++)
- // {
- // if(_data[i] == null){
- // out.writeInt(0);
- // }
- // else{
- // // cs.encode(_data[i]);
- // byte[] bs = _data[i].getBytes(cs);
- // out.writeInt(bs.length);
- // out.write(bs);
- // }
- // }
-
out.writeUTF((_data[i] != null) ? _data[i] : "");
}
@@ -168,25 +155,9 @@ public class StringArray extends Array<String> {
public void readFields(DataInput in) throws IOException {
_size = _data.length;
materializedSize = in.readLong();
- // byte[] bs = new byte[16];
- // final Charset cs = Charset.defaultCharset();
for(int i = 0; i < _size; i++) {
- // int l = in.readInt();
- // if(l == 0){
- // _data[i] = null;
- // }
- // else{
- // if(l > bs.length)
- // bs = new byte[l];
- // in.readFully(bs, 0, l);
- // String tmp = new String(bs, 0, l, cs);
- // // String tmp = in.readUTF();
- // _data[i] = tmp;
- // }
- {
- String tmp = in.readUTF();
- _data[i] = tmp.isEmpty() ? null : tmp;
- }
+ String tmp = in.readUTF();
+ _data[i] = tmp.isEmpty() ? null : tmp;
}
}
@@ -289,10 +260,10 @@ public class StringArray extends Array<String> {
}
@Override
- public Pair<ValueType, Boolean> analyzeValueType() {
+ public Pair<ValueType, Boolean> analyzeValueType(int maxCells) {
ValueType state = ValueType.UNKNOWN;
boolean nulls = false;
- for(int i = 0; i < _size; i++) {
+ for(int i = 0; i < Math.min(maxCells, _size); i++) {
final ValueType c = FrameUtil.isType(_data[i], state);
if(c == ValueType.STRING)
return new Pair<>(ValueType.STRING, false);
@@ -537,16 +508,30 @@ public class StringArray extends Array<String> {
}
protected Array<Integer> changeTypeIntegerNormal() {
- int[] ret = new int[size()];
- for(int i = 0; i < size(); i++) {
- final String s = _data[i];
- try {
+ try {
+ int[] ret = new int[size()];
+ for(int i = 0; i < size(); i++) {
+ final String s = _data[i];
if(s != null)
ret[i] = Integer.parseInt(s);
}
- catch(NumberFormatException e) {
- throw new DMLRuntimeException("Unable to change
to Integer from String array", e);
+ return new IntegerArray(ret);
+ }
+ catch(NumberFormatException e) {
+ if(e.getMessage().contains("For input string: \"\"")) {
+ LOG.warn("inefficient safe cast");
+ return changeTypeIntegerSafe();
}
+ throw new DMLRuntimeException("Unable to change to
Integer from String array", e);
+ }
+ }
+
+ protected Array<Integer> changeTypeIntegerSafe() {
+ int[] ret = new int[size()];
+ for(int i = 0; i < size(); i++) {
+ final String s = _data[i];
+ if(s != null && s.length() > 0)
+ ret[i] = Integer.parseInt(s);
}
return new IntegerArray(ret);
}
@@ -574,15 +559,31 @@ public class StringArray extends Array<String> {
for(int i = 0; i < size(); i++) {
final String s = _data[i];
if(s != null)
- ret[i] = Long.parseLong(s, 16);
+ ret[i] = HashLongArray.parseHashLong(s);
}
return new HashLongArray(ret);
}
catch(NumberFormatException e) {
+ if(e.getMessage().contains("For input string: \"\"")) {
+ LOG.warn("inefficient safe cast");
+ return changeTypeHash64Safe();
+ }
throw new DMLRuntimeException("Unable to change to
Hash64 from String array", e);
}
}
+ protected Array<Object> changeTypeHash64Safe() {
+
+ long[] ret = new long[size()];
+ for(int i = 0; i < size(); i++) {
+ final String s = _data[i];
+ if(s != null && s.length() > 0)
+ ret[i] = HashLongArray.parseHashLong(s);
+ }
+ return new HashLongArray(ret);
+
+ }
+
@Override
public Array<Character> changeTypeCharacter() {
char[] ret = new char[size()];
@@ -658,13 +659,13 @@ public class StringArray extends Array<String> {
@Override
public boolean isShallowSerialize() {
- long s = getInMemorySize();
+ final long s = getInMemorySize();
return _size < 100 || s / _size < 100;
}
@Override
public boolean isEmpty() {
- for(int i = 0; i < _data.length; i++)
+ for(int i = 0; i < _size; i++)
if(_data[i] != null && !_data[i].equals("0"))
return false;
return true;
@@ -672,7 +673,7 @@ public class StringArray extends Array<String> {
@Override
public boolean containsNull() {
- for(int i = 0; i < _data.length; i++)
+ for(int i = 0; i < _size; i++)
if(_data[i] == null)
return true;
return false;
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/compress/ArrayCompressionStatistics.java
b/src/main/java/org/apache/sysds/runtime/frame/data/compress/ArrayCompressionStatistics.java
index ec98e9847a..8323060f81 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/compress/ArrayCompressionStatistics.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/compress/ArrayCompressionStatistics.java
@@ -28,16 +28,18 @@ public class ArrayCompressionStatistics {
public final long compressedSizeEstimate;
public final boolean shouldCompress;
public final ValueType valueType;
+ public final boolean containsNull;
public final FrameArrayType bestType;
- public final int bitPerValue;
+ public final int bytePerValue;
public final int nUnique;
- public ArrayCompressionStatistics(int bitPerValue, int nUnique, boolean
shouldCompress, ValueType valueType,
- FrameArrayType bestType, long originalSize, long
compressedSizeEstimate) {
- this.bitPerValue = bitPerValue;
+ public ArrayCompressionStatistics(int bytePerValue, int nUnique,
boolean shouldCompress, ValueType valueType,
+ boolean containsNull, FrameArrayType bestType, long
originalSize, long compressedSizeEstimate) {
+ this.bytePerValue = bytePerValue;
this.nUnique = nUnique;
this.shouldCompress = shouldCompress;
this.valueType = valueType;
+ this.containsNull = containsNull;
this.bestType = bestType;
this.originalSize = originalSize;
this.compressedSizeEstimate = compressedSizeEstimate;
@@ -47,7 +49,7 @@ public class ArrayCompressionStatistics {
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(String.format("Compressed Stats: size:%8d->%8d,
Use:%10s, Unique:%6d, ValueType:%7s", originalSize,
- compressedSizeEstimate, bestType.toString(), nUnique,
valueType));
+ compressedSizeEstimate, bestType == null ? "None" :
bestType.toString(), nUnique, valueType));
return sb.toString();
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java
b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java
index 482e6a129e..869f97919a 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java
@@ -48,9 +48,7 @@ public class CompressedFrameBlockFactory {
this.cs = cs;
this.stats = new ArrayCompressionStatistics[in.getNumColumns()];
this.compressedColumns = new Array<?>[in.getNumColumns()];
-
this.nSamples = Math.min(in.getNumRows(), (int)
Math.ceil(in.getNumRows() * cs.sampleRatio));
-
}
public static FrameBlock compress(FrameBlock fb) {
@@ -116,15 +114,23 @@ public class CompressedFrameBlockFactory {
private void compressCol(int i) {
stats[i] = in.getColumn(i).statistics(nSamples);
if(stats[i] != null) {
- // commented out because no other encodings are
supported yet
- // switch(stats[i].bestType) {
- // case DDC:
- compressedColumns[i] =
DDCArray.compressToDDC(in.getColumn(i));
- // break;
- // default:
- // compressedColumns[i] = in.getColumn(i);
- // break;
- // }
+ if(stats[i].bestType == null){
+ // just cast to other value type.
+ compressedColumns[i] =
in.getColumn(i).safeChangeType(stats[i].valueType, stats[i].containsNull);
+ }
+ else{
+ // only DDC is supported so far; any other encoding is logged as an error and the column is kept as is
+ switch(stats[i].bestType) {
+ case DDC:
+ compressedColumns[i] =
DDCArray.compressToDDC(in.getColumn(i), stats[i].valueType,
+ stats[i].containsNull);
+ break;
+ default:
+ LOG.error("Unsupported encoding
default to do nothing: " + stats[i].bestType);
+ compressedColumns[i] =
in.getColumn(i);
+ break;
+ }
+ }
}
else
compressedColumns[i] = in.getColumn(i);
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java
index 92372ecab2..0c8ceb9d87 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java
@@ -107,6 +107,7 @@ public class FrameLibApplySchema {
}
private FrameBlock apply() {
+
if(k <= 1 || nCol == 1)
applySingleThread();
else
@@ -121,7 +122,17 @@ public class FrameLibApplySchema {
final String[] colNames = fb.getColumnNames(false);
final ColumnMetadata[] meta = fb.getColumnMetadata();
- return new FrameBlock(schema, colNames, meta, columnsOut);
+
+ FrameBlock out = new FrameBlock(schema, colNames, meta,
columnsOut);
+ if(LOG.isDebugEnabled()){
+
+ long inMem = fb.getInMemorySize();
+ long outMem = out.getInMemorySize();
+ LOG.debug(String.format("Schema Apply Input Size: %16d"
, inMem));
+ LOG.debug(String.format(" Output Size: %16d"
, outMem));
+ LOG.debug(String.format(" Ratio :
%4.3f" , ((double) inMem / outMem)));
+ }
+ return out;
}
private void applySingleThread() {
@@ -136,7 +147,6 @@ public class FrameLibApplySchema {
columnsIn[i].changeType(schema[i]);
else
columnsOut[i] = columnsIn[i].changeType(schema[i]);
-
}
private void applyMultiThread() {
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java
index 383e8b205f..71e3788a1c 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java
@@ -26,6 +26,8 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.frame.data.FrameBlock;
@@ -35,16 +37,22 @@ import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;
public final class FrameLibDetectSchema {
- // private static final Log LOG =
LogFactory.getLog(FrameLibDetectSchema.class.getName());
+ protected static final Log LOG =
LogFactory.getLog(FrameLibDetectSchema.class.getName());
+ /** Default minimum sample size */
+ private static final int DEFAULT_MIN_CELLS = 100000;
+ /** Frame block to sample from */
private final FrameBlock in;
- // private final double sampleFraction;
+ /** parallelization degree */
private final int k;
+ /** Max cells analyzed per column: rows * sampleFraction, at least DEFAULT_MIN_CELLS, at most the number of rows */
+ private final int sampleSize;
private FrameLibDetectSchema(FrameBlock in, double sampleFraction, int
k) {
this.in = in;
- // this.sampleFraction = sampleFraction;
this.k = k;
+ final int inRows = in.getNumRows();
+ this.sampleSize = Math.min(inRows, Math.max((int) (inRows *
sampleFraction), DEFAULT_MIN_CELLS));
}
public static FrameBlock detectSchema(FrameBlock in, int k) {
@@ -66,8 +74,9 @@ public final class FrameLibDetectSchema {
private String[] singleThreadApply() {
final int cols = in.getNumColumns();
final String[] schemaInfo = new String[cols];
+
for(int i = 0; i < cols; i++)
- assign(schemaInfo, in.getColumn(i).analyzeValueType(),
i);
+ assign(schemaInfo,
in.getColumn(i).analyzeValueType(sampleSize), i);
return schemaInfo;
}
@@ -78,7 +87,7 @@ public final class FrameLibDetectSchema {
final int cols = in.getNumColumns();
final ArrayList<DetectValueTypeTask> tasks = new
ArrayList<>(cols);
for(int i = 0; i < cols; i++)
- tasks.add(new
DetectValueTypeTask(in.getColumn(i)));
+ tasks.add(new
DetectValueTypeTask(in.getColumn(i), sampleSize));
final List<Future<Pair<ValueType, Boolean>>> ret =
pool.invokeAll(tasks);
final String[] schemaInfo = new String[cols];
pool.shutdown();
@@ -103,14 +112,16 @@ public final class FrameLibDetectSchema {
private static class DetectValueTypeTask implements
Callable<Pair<ValueType, Boolean>> {
private final Array<?> _obj;
+ final int _nCells;
- protected DetectValueTypeTask(Array<?> obj) {
+ protected DetectValueTypeTask(Array<?> obj, int nCells) {
_obj = obj;
+ _nCells = nCells;
}
@Override
public Pair<ValueType, Boolean> call() {
- return _obj.analyzeValueType();
+ return _obj.analyzeValueType(_nCells);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
index 309560c46d..752bf3b983 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
@@ -46,7 +46,7 @@ public interface FrameUtil {
public static final Pattern booleanPattern = Pattern
.compile("([tT]((rue)|(RUE))?|[fF]((alse)|(ALSE))?|0\\.0+|1\\.0+|0|1)");
public static final Pattern integerFloatPattern =
Pattern.compile("[-+]?\\d+(\\.0+)?");
- public static final Pattern floatPattern =
Pattern.compile("[-+]?[0-9]*\\.?[0-9]*([eE][-+]?[0-9]+)?");
+ public static final Pattern floatPattern =
Pattern.compile("[-+]?[0-9][0-9]*\\.?[0-9]*([eE][-+]?[0-9]+)?");
public static final Pattern dotSplitPattern = Pattern.compile("\\.");
@@ -123,7 +123,7 @@ public interface FrameUtil {
}
public static ValueType isHash(final String val, final int len) {
- if(len == 8) {
+ if(len == 8 || len == 16) {
for(int i = 0; i < 8; i++) {
char v = val.charAt(i);
if(v < '0' || v > 'f')
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/FrameSerializationTest.java
b/src/test/java/org/apache/sysds/test/component/frame/FrameSerializationTest.java
index 249680080a..272fe760b0 100644
---
a/src/test/java/org/apache/sysds/test/component/frame/FrameSerializationTest.java
+++
b/src/test/java/org/apache/sysds/test/component/frame/FrameSerializationTest.java
@@ -32,6 +32,8 @@ import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collection;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.test.TestUtils;
@@ -44,6 +46,7 @@ import org.junit.runners.Parameterized.Parameters;
@RunWith(value = Parameterized.class)
public class FrameSerializationTest {
+ protected static final Log LOG =
LogFactory.getLog(FrameSerializationTest.class.getName());
private enum SerType {
WRITABLE_SER, JAVA_SER,
@@ -93,10 +96,12 @@ public class FrameSerializationTest {
// init data frame
FrameBlock back;
// core serialization and deserialization
+
if(type == SerType.WRITABLE_SER)
back = writableSerialize(frame);
else // if(stype == SerType.JAVA_SER)
back = javaSerialize(frame);
+
TestUtils.compareFrames(frame, back, true);
}
catch(Exception ex) {
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java
b/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java
index 746762ce1b..672c857414 100644
--- a/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java
+++ b/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java
@@ -22,6 +22,9 @@ package org.apache.sysds.test.component.frame;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
+import java.util.Arrays;
+import java.util.List;
+
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
@@ -30,9 +33,8 @@ import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.frame.data.lib.FrameUtil;
import
org.apache.sysds.runtime.instructions.spark.utils.FrameRDDAggregateUtils;
import org.junit.Test;
+
import scala.Tuple2;
-import java.util.Arrays;
-import java.util.List;
public class FrameUtilTest {
@@ -156,6 +158,36 @@ public class FrameUtilTest {
assertEquals(ValueType.UNKNOWN, FrameUtil.isType(""));
}
+ @Test
+ public void testEHash() {
+ assertEquals(ValueType.HASH64, FrameUtil.isType("e1232142"));
+ }
+
+ @Test
+ public void testEHash2() {
+ assertEquals(ValueType.HASH64, FrameUtil.isType("e6138002"));
+ }
+
+ @Test
+ public void testEHash3() {
+ assertEquals(ValueType.FP64, FrameUtil.isType("32e68002"));
+ }
+
+ @Test
+ public void testEHash4() {
+ assertEquals(ValueType.HASH64, FrameUtil.isType("3268002e"));
+ }
+
+ @Test
+ public void testEHash5() {
+ assertEquals(ValueType.FP64, FrameUtil.isType("3e268002"));
+ }
+
+ @Test
+ public void testEHash6() {
+ assertEquals(ValueType.FP64, FrameUtil.isType("3268000e2"));
+ }
+
@Test
public void testMinType() {
for(ValueType v : ValueType.values())
@@ -188,7 +220,6 @@ public class FrameUtilTest {
assertEquals(ValueType.INT32,
FrameUtil.isType(Integer.MIN_VALUE + ""));
}
-
@Test
public void testIntegerMinComma() {
assertEquals(ValueType.INT32,
FrameUtil.isType(Integer.MIN_VALUE + ".0"));
@@ -250,7 +281,7 @@ public class FrameUtilTest {
}
@Test
- public void testSparkFrameBlockALignment(){
+ public void testSparkFrameBlockALignment() {
ValueType[] schema = new ValueType[0];
FrameBlock f1 = new FrameBlock(schema, 1000);
FrameBlock f2 = new FrameBlock(schema, 500);
@@ -259,78 +290,85 @@ public class FrameUtilTest {
SparkConf sparkConf = new
SparkConf().setAppName("DirectPairRDDExample").setMaster("local");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
- //Test1 (1000, 1000, 500)
- List<Tuple2<Long, FrameBlock>> t1 = Arrays.asList(new
Tuple2<>(1L, f1),new Tuple2<>(1001L, f1),new Tuple2<>(2001L, f2));
+ // Test1 (1000, 1000, 500)
+ List<Tuple2<Long, FrameBlock>> t1 = Arrays.asList(new
Tuple2<>(1L, f1), new Tuple2<>(1001L, f1),
+ new Tuple2<>(2001L, f2));
JavaPairRDD<Long, FrameBlock> pairRDD = sc.parallelizePairs(t1);
Tuple2<Boolean, Integer> result =
FrameRDDAggregateUtils.checkRowAlignment(pairRDD, -1);
assertTrue(result._1);
assertEquals(1000L, (long) result._2);
- //Test2 (1000, 500, 1000)
- t1 = Arrays.asList(new Tuple2<>(1L, f1),new Tuple2<>(1001L,
f2),new Tuple2<>(1501L, f1));
+ // Test2 (1000, 500, 1000)
+ t1 = Arrays.asList(new Tuple2<>(1L, f1), new Tuple2<>(1001L,
f2), new Tuple2<>(1501L, f1));
pairRDD = sc.parallelizePairs(t1);
result = FrameRDDAggregateUtils.checkRowAlignment(pairRDD, -1);
assertTrue(!result._1);
- //Test3 (1000, 500, 1000, 250)
- t1 = Arrays.asList(new Tuple2<>(1L, f1), new Tuple2<>(1001L,
f2), new Tuple2<>(1501L, f1), new Tuple2<>(2501L, f3));
+ // Test3 (1000, 500, 1000, 250)
+ t1 = Arrays.asList(new Tuple2<>(1L, f1), new Tuple2<>(1001L,
f2), new Tuple2<>(1501L, f1),
+ new Tuple2<>(2501L, f3));
pairRDD = sc.parallelizePairs(t1);
result = FrameRDDAggregateUtils.checkRowAlignment(pairRDD, -1);
assertTrue(!result._1);
- //Test4 (500, 500, 250)
- t1 = Arrays.asList(new Tuple2<>(1L, f2), new Tuple2<>(501L,
f2), new Tuple2<>(1001L, f3));
+ // Test4 (500, 500, 250)
+ t1 = Arrays.asList(new Tuple2<>(1L, f2), new Tuple2<>(501L,
f2), new Tuple2<>(1001L, f3));
pairRDD = sc.parallelizePairs(t1);
result = FrameRDDAggregateUtils.checkRowAlignment(pairRDD, -1);
assertTrue(result._1);
assertEquals(500L, (long) result._2);
- //Test5 (1000, 500, 1000, 250)
- t1 = Arrays.asList(new Tuple2<>(1L, f1), new Tuple2<>(1001L,
f2), new Tuple2<>(1501L, f1), new Tuple2<>(2501L, f3));
+ // Test5 (1000, 500, 1000, 250)
+ t1 = Arrays.asList(new Tuple2<>(1L, f1), new Tuple2<>(1001L,
f2), new Tuple2<>(1501L, f1),
+ new Tuple2<>(2501L, f3));
pairRDD = sc.parallelizePairs(t1);
result = FrameRDDAggregateUtils.checkRowAlignment(pairRDD, -1);
assertTrue(!result._1);
- //Test6 (1000, 1000, 500, 500)
- t1 = Arrays.asList(new Tuple2<>(1L, f1), new Tuple2<>(1001L,
f1), new Tuple2<>(2001L, f2), new Tuple2<>(2501L, f2));
+ // Test6 (1000, 1000, 500, 500)
+ t1 = Arrays.asList(new Tuple2<>(1L, f1), new Tuple2<>(1001L,
f1), new Tuple2<>(2001L, f2),
+ new Tuple2<>(2501L, f2));
pairRDD = sc.parallelizePairs(t1);
result = FrameRDDAggregateUtils.checkRowAlignment(pairRDD, -1);
assertTrue(!result._1);
- //Test7 (500, 500, 250)
- t1 = Arrays.asList(new Tuple2<>(501L, f2), new Tuple2<>(1001L,
f3), new Tuple2<>(1L, f2));
+ // Test7 (500, 500, 250)
+ t1 = Arrays.asList(new Tuple2<>(501L, f2), new Tuple2<>(1001L,
f3), new Tuple2<>(1L, f2));
pairRDD = sc.parallelizePairs(t1);
result = FrameRDDAggregateUtils.checkRowAlignment(pairRDD, -1);
assertTrue(result._1);
assertEquals(500L, (long) result._2);
- //Test8 (500, 500, 250)
- t1 = Arrays.asList( new Tuple2<>(1001L, f3), new
Tuple2<>(501L, f2), new Tuple2<>(1L, f2));
+ // Test8 (500, 500, 250)
+ t1 = Arrays.asList(new Tuple2<>(1001L, f3), new Tuple2<>(501L,
f2), new Tuple2<>(1L, f2));
pairRDD = sc.parallelizePairs(t1);
result = FrameRDDAggregateUtils.checkRowAlignment(pairRDD, -1);
assertTrue(result._1);
assertEquals(500L, (long) result._2);
- //Test9 (1000, 1000, 1000, 500)
- t1 = Arrays.asList(new Tuple2<>(1L, f1), new Tuple2<>(1001L,
f1), new Tuple2<>(2001L, f1), new Tuple2<>(3001L, f2));
+ // Test9 (1000, 1000, 1000, 500)
+ t1 = Arrays.asList(new Tuple2<>(1L, f1), new Tuple2<>(1001L,
f1), new Tuple2<>(2001L, f1),
+ new Tuple2<>(3001L, f2));
pairRDD = sc.parallelizePairs(t1).repartition(2);
result = FrameRDDAggregateUtils.checkRowAlignment(pairRDD, -1);
assertTrue(result._1);
assertEquals(1000L, (long) result._2);
- //Test10 (1000, 1000, 1000, 500)
- t1 = Arrays.asList(new Tuple2<>(1L, f1), new Tuple2<>(1001L,
f1), new Tuple2<>(2001L, f1), new Tuple2<>(3001L, f2));
+ // Test10 (1000, 1000, 1000, 500)
+ t1 = Arrays.asList(new Tuple2<>(1L, f1), new Tuple2<>(1001L,
f1), new Tuple2<>(2001L, f1),
+ new Tuple2<>(3001L, f2));
pairRDD = sc.parallelizePairs(t1).repartition(2);
result = FrameRDDAggregateUtils.checkRowAlignment(pairRDD,
1000);
assertTrue(result._1);
assertEquals(1000L, (long) result._2);
- //Test11 (1000, 1000, 1000, 500)
- t1 = Arrays.asList(new Tuple2<>(1L, f1), new Tuple2<>(1001L,
f1), new Tuple2<>(2001L, f1), new Tuple2<>(3001L, f2));
+ // Test11 (1000, 1000, 1000, 500)
+ t1 = Arrays.asList(new Tuple2<>(1L, f1), new Tuple2<>(1001L,
f1), new Tuple2<>(2001L, f1),
+ new Tuple2<>(3001L, f2));
pairRDD = sc.parallelizePairs(t1).repartition(2);
result = FrameRDDAggregateUtils.checkRowAlignment(pairRDD, 500);
assertTrue(!result._1);
-
+
sc.close();
}
}
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
index 90c41db1a4..f52420ce54 100644
---
a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
+++
b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
@@ -1384,7 +1384,6 @@ public class CustomArrayTests {
assertEquals(65535, HashLongArray.parseHashLong("ffff"));
}
-
@Test
public void parseHash_fffff() {
assertEquals(1048575, HashLongArray.parseHashLong("fffff"));
@@ -1400,7 +1399,6 @@ public class CustomArrayTests {
assertEquals(268435455L,
HashLongArray.parseHashLong("fffffff"));
}
-
@Test
public void parseHash_ffffffff() {
assertEquals(4294967295L,
HashLongArray.parseHashLong("ffffffff"));
@@ -1411,4 +1409,22 @@ public class CustomArrayTests {
assertEquals(-1,
HashLongArray.parseHashLong("ffffffffffffffff"));
}
+ @Test
+ public void compressWithNull() {
+ Array<Double> a = ArrayFactory
+ .create(new Double[] {0.02, null, null, 0.03, null,
null, null, null, null, null, null, null});
+ Array<Double> c = DDCArray.compressToDDC(a);
+ FrameArrayTests.compare(a, c);
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void compressHashColumn() {
+ Array<String> a = ArrayFactory
+ .create(new String[] {"aaaaaaaa", null, null,
"ffffffff", null, null, null, null, null, null, null, null});
+ Array<Object> b =
(Array<Object>)a.changeTypeWithNulls(ValueType.HASH64);
+ Array<Object> c = DDCArray.compressToDDC(b);
+ FrameArrayTests.compare(b, c);
+ }
+
}
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java
b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java
index 165f1327b2..dc0f03c58e 100644
---
a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java
+++
b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java
@@ -420,8 +420,10 @@ public class FrameArrayTests {
return;
if(a.getFrameArrayType() == FrameArrayType.DDC)
return; // can happen where DDC is wrapping Optional.
+ if(a.getFrameArrayType() == FrameArrayType.OPTIONAL)
+ return;
- assertEquals(t, a.getFrameArrayType());
+ assertEquals(a.toString(),t, a.getFrameArrayType());
}
@Test
@@ -1244,7 +1246,7 @@ public class FrameArrayTests {
a.write(fos);
long s = fos.size();
long e = a.getExactSerializedSize();
- assertEquals(s, e);
+ assertEquals(a.toString(),s, e);
}
catch(IOException e) {
throw new RuntimeException("Error in io", e);
@@ -1916,6 +1918,7 @@ public class FrameArrayTests {
String err = a.getClass().getSimpleName() + " " +
a.getValueType() + " " + b.getClass().getSimpleName() + " "
+ b.getValueType();
assertTrue(a.size() == b.size());
+
for(int i = 0; i < size; i++) {
final Object av = a.get(i);
final Object bv = b.get(i);
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTest.java
b/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTest.java
index 9ed4fa6747..8518de226e 100644
---
a/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTest.java
+++
b/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTest.java
@@ -28,6 +28,7 @@ import org.apache.sysds.runtime.frame.data.FrameBlock;
import
org.apache.sysds.runtime.frame.data.compress.CompressedFrameBlockFactory;
import org.apache.sysds.runtime.frame.data.compress.FrameCompressionSettings;
import org.apache.sysds.runtime.frame.data.lib.FrameLibCompress;
+import org.apache.sysds.runtime.frame.data.lib.FrameLibDetectSchema;
import org.apache.sysds.test.TestUtils;
import org.junit.Test;
@@ -46,6 +47,22 @@ public class FrameCompressTest {
runTest(a, 4);
}
+ @Test
+ public void testParallelWithSchema() { // compress a block after converting it to its detected schema
+ FrameBlock a =
FrameCompressTestUtils.generateCompressableBlock(200, 5, 1232,
ValueType.STRING);
+ FrameBlock sc = FrameLibDetectSchema.detectSchema(a, 4);
+ a = a.applySchema(sc); // keep the returned block; otherwise the unconverted STRING block is what gets compressed
+ runTest(a, 4);
+ }
+
+ @Test
+ public void testParallelWithRandom() { // compress a block built by generateCompressableBlockRandomTypes
+ FrameBlock a =
FrameCompressTestUtils.generateCompressableBlockRandomTypes(200, 5, 1232);
+ FrameBlock sc = FrameLibDetectSchema.detectSchema(a, 4); // NOTE(review): 4 is presumably the thread count - confirm
+ a = a.applySchema(sc); // continue with the block returned by applySchema
+ runTest(a, 4); // runTest forwards its second argument as k to FrameLibCompress.compress
+ }
+
public void runTest(FrameBlock a, int k) {
try {
FrameBlock b = FrameLibCompress.compress(a, k);
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTestUtils.java
b/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTestUtils.java
index bc512c5567..2837cdde93 100644
---
a/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTestUtils.java
+++
b/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTestUtils.java
@@ -21,6 +21,7 @@ package org.apache.sysds.test.component.frame.compress;
import java.util.Random;
+import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ValueType;
@@ -55,6 +56,8 @@ public class FrameCompressTestUtils {
switch(vt) {
case BOOLEAN:
return
ArrayFactory.create(FrameArrayTests.generateRandomBooleanOpt(size, seed));
+ case UINT8:
+ case UINT4:
case INT32:
return
ArrayFactory.create(FrameArrayTests.generateRandomIntegerNUniqueLengthOpt(size,
seed, nUnique));
case INT64:
@@ -65,9 +68,12 @@ public class FrameCompressTestUtils {
return
ArrayFactory.create(FrameArrayTests.generateRandomDoubleNUniqueLengthOpt(size,
seed, nUnique));
case CHARACTER:
return
ArrayFactory.create(FrameArrayTests.generateRandomCharacterNUniqueLengthOpt(size,
seed, nUnique));
+ case HASH64:
+ return
ArrayFactory.create(FrameArrayTests.generateRandomHash64OptNUnique(size, seed,
nUnique));
case STRING:
- default:
return
ArrayFactory.create(FrameArrayTests.generateRandomStringNUniqueLengthOpt(size,
seed, nUnique, 132));
+ default:
+ throw new NotImplementedException(vt + "");
}
}
}