This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 7ddc86c687 [SYSTEMDS-3486] CharArray Primitive
7ddc86c687 is described below
commit 7ddc86c6875dd10aad352d0fa4d0ece1cc65f9f3
Author: baunsgaard <[email protected]>
AuthorDate: Wed Jan 11 21:28:30 2023 +0100
[SYSTEMDS-3486] CharArray Primitive
This commit adds a CharArray primitive to the FrameBlock columns.
The char array is used in cases where there is a column in frame blocks
with a single character in each cell.
Overall, this reduces the allocation cost of each value from a full
Object + pointer + array pointer to a single char.
One design decision is that if an int column is changed into
a char column, the bit encoding is prioritized over the visually similar character.
For instance, the int value 0 becomes the char with code 0 (the NUL character, i.e. the first char),
which is bitwise the same value (just with fewer bits).
If a string is converted, then "0" becomes the char value 48 (the character '0').
Closes #1761
---
src/main/java/org/apache/sysds/common/Types.java | 17 +-
.../sysds/runtime/frame/data/columns/Array.java | 11 +-
.../runtime/frame/data/columns/ArrayFactory.java | 16 +-
.../runtime/frame/data/columns/BitSetArray.java | 8 +
.../runtime/frame/data/columns/BooleanArray.java | 14 +-
.../columns/{LongArray.java => CharArray.java} | 223 +++++++++++----------
.../runtime/frame/data/columns/DoubleArray.java | 17 +-
.../runtime/frame/data/columns/FloatArray.java | 23 ++-
.../runtime/frame/data/columns/IntegerArray.java | 11 +-
.../runtime/frame/data/columns/LongArray.java | 11 +-
.../runtime/frame/data/columns/StringArray.java | 13 +-
.../sysds/runtime/frame/data/lib/FrameUtil.java | 12 +-
.../apache/sysds/runtime/util/UtilFunctions.java | 20 ++
.../sysds/test/component/frame/FrameTest.java | 14 +-
.../component/frame/array/CustomArrayTests.java | 42 ++++
.../frame/array/FrameArrayConstantTests.java | 13 +-
.../component/frame/array/FrameArrayTests.java | 74 ++++++-
.../functions/frame/FrameDropInvalidTypeTest.java | 1 +
.../io/proto/FrameReaderWriterProtoTest.java | 29 ++-
19 files changed, 403 insertions(+), 166 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index 68cc801a1c..58dfd8bfae 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -76,7 +76,8 @@ public class Types
*/
public enum ValueType {
UINT8, // Used for parsing in UINT values from numpy.
- FP32, FP64, INT32, INT64, BOOLEAN, STRING, UNKNOWN;
+ FP32, FP64, INT32, INT64, BOOLEAN, STRING, UNKNOWN,
+ CHARACTER;
public boolean isNumeric() {
return this == UINT8 || this == INT32 || this == INT64
|| this == FP32 || this == FP64;
@@ -85,7 +86,7 @@ public class Types
return this == UNKNOWN;
}
public boolean isPseudoNumeric() {
- return isNumeric() || this == BOOLEAN;
+ return isNumeric() || this == BOOLEAN || this ==
CHARACTER;
}
public String toExternalString() {
switch(this) {
@@ -112,6 +113,7 @@ public class Types
case "INT": return INT64;
case "BOOLEAN": return BOOLEAN;
case "STRING": return STRING;
+ case "CHARACTER": return CHARACTER;
case "UNKNOWN": return UNKNOWN;
default:
throw new DMLRuntimeException("Unknown
value type: "+value);
@@ -127,20 +129,23 @@ public class Types
*
* @param a First ValueType
* @param b Second ValueType
- * @return The common higest type to represent both
+ * @return The common highest type to represent both
*/
public static ValueType getHighestCommonType(ValueType a,
ValueType b){
if(a == b)
return a;
- if(b == UNKNOWN)
+ else if(b == UNKNOWN)
throw new DMLRuntimeException(
"Invalid or not implemented support for
comparing valueType of: " + a + " and " + b);
switch(a){
+ case CHARACTER:
+ return STRING;
case STRING:
return a;
case FP64:
switch(b){
+ case CHARACTER:
case STRING:
return b;
default:
@@ -148,6 +153,7 @@ public class Types
}
case FP32:
switch(b){
+ case CHARACTER:
case STRING:
case FP64:
return b;
@@ -156,6 +162,7 @@ public class Types
}
case INT64:
switch(b){
+ case CHARACTER:
case STRING:
case FP64:
case FP32:
@@ -165,6 +172,7 @@ public class Types
}
case INT32:
switch(b){
+ case CHARACTER:
case STRING:
case FP64:
case FP32:
@@ -175,6 +183,7 @@ public class Types
}
case UINT8:
switch(b){
+ case CHARACTER:
case STRING:
case FP64:
case FP32:
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
index c757e3c7e8..465fad79ac 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
@@ -181,7 +181,7 @@ public abstract class Array<T> implements Writable {
* @param value array of other type
*/
public final void setFromOtherTypeNz(Array<?> value) {
- setFromOtherTypeNz(0, value.size()-1, value);
+ setFromOtherTypeNz(0, value.size() - 1, value);
}
/**
@@ -320,6 +320,8 @@ public abstract class Array<T> implements Writable {
return changeTypeLong();
case STRING:
return changeTypeString();
+ case CHARACTER:
+ return changeTypeCharacter();
case UNKNOWN:
default:
throw new DMLRuntimeException("Not a valid type
to change to : " + t);
@@ -375,6 +377,13 @@ public abstract class Array<T> implements Writable {
*/
protected abstract Array<String> changeTypeString();
+ /**
+ * Change type to a Character array type
+ *
+ * @return Character type of array
+ */
+ protected abstract Array<Character> changeTypeCharacter();
+
/**
* Get the minimum and maximum length of the contained values as string
type.
*
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
index c98c9eb11f..c39add704f 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
@@ -32,7 +32,7 @@ public interface ArrayFactory {
public final static int bitSetSwitchPoint = 64;
public enum FrameArrayType {
- STRING, BOOLEAN, BITSET, INT32, INT64, FP32, FP64;
+ STRING, BOOLEAN, BITSET, INT32, INT64, FP32, FP64, CHARACTER;
}
public static StringArray create(String[] col) {
@@ -63,6 +63,10 @@ public interface ArrayFactory {
return new DoubleArray(col);
}
+ public static CharArray create(char[] col) {
+ return new CharArray(col);
+ }
+
public static long getInMemorySize(ValueType type, int _numRows) {
switch(type) {
case BOOLEAN:
@@ -83,6 +87,8 @@ public interface ArrayFactory {
// cannot be known since strings have dynamic
length
// lets assume something large to make it
somewhat safe.
return Array.baseMemoryCost() + (long)
MemoryEstimates.stringCost(12) * _numRows;
+ case CHARACTER:
+ return Array.baseMemoryCost() + (long)
MemoryEstimates.charArrayCost(_numRows);
default: // not applicable
throw new DMLRuntimeException("Invalid type to
estimate size of :" + type);
}
@@ -112,6 +118,8 @@ public interface ArrayFactory {
return new FloatArray(new float[nRow]);
case FP64:
return new DoubleArray(new double[nRow]);
+ case CHARACTER:
+ return new CharArray(new char[nRow]);
default:
throw new DMLRuntimeException("Unsupported
value type: " + v);
}
@@ -139,7 +147,11 @@ public interface ArrayFactory {
case FP32:
arr = new FloatArray(new float[nRow]);
break;
- default: // String
+ case CHARACTER:
+ arr = new CharArray(new char[nRow]);
+ break;
+ case STRING:
+ default:
arr = new StringArray(new String[nRow]);
break;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java
index a648b8b3fd..af1385bd9e 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java
@@ -441,6 +441,14 @@ public class BitSetArray extends Array<Boolean> {
return new StringArray(ret);
}
+ @Override
+ public Array<Character> changeTypeCharacter() {
+ char[] ret = new char[size()];
+ for(int i = 0; i < size(); i++)
+ ret[i] = (char) (get(i) ? 1 : 0);
+ return new CharArray(ret);
+ }
+
@Override
public void fill(String value) {
fill(BooleanArray.parseBoolean(value));
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java
index d14cc23428..cd1e402d8e 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java
@@ -270,6 +270,14 @@ public class BooleanArray extends Array<Boolean> {
return new StringArray(ret);
}
+ @Override
+ public Array<Character> changeTypeCharacter() {
+ char[] ret = new char[size()];
+ for(int i = 0; i < size(); i++)
+ ret[i] = (char) (_data[i] ? 1 : 0);
+ return new CharArray(ret);
+ }
+
@Override
public void fill(String value) {
fill(parseBoolean(value));
@@ -277,9 +285,7 @@ public class BooleanArray extends Array<Boolean> {
@Override
public void fill(Boolean value) {
- for(int i = 0; i < _size; i++)
- _data[i] = value;
-
+ Arrays.fill(_data, value);
}
@Override
@@ -298,7 +304,7 @@ public class BooleanArray extends Array<Boolean> {
@Override
public String toString() {
- StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
+ StringBuilder sb = new StringBuilder(_data.length + 2);
sb.append(super.toString() + ":[");
for(int i = 0; i < _size - 1; i++)
sb.append((_data[i] ? 1 : 0) + ",");
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java
similarity index 59%
copy from
src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
copy to src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java
index 4dbdc91d32..af4c83e6e7 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java
@@ -16,7 +16,6 @@
* specific language governing permissions and limitations
* under the License.
*/
-
package org.apache.sysds.runtime.frame.data.columns;
import java.io.DataInput;
@@ -30,61 +29,80 @@ import java.util.BitSet;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType;
+import org.apache.sysds.runtime.frame.data.lib.FrameUtil;
import org.apache.sysds.runtime.util.UtilFunctions;
-import org.apache.sysds.utils.MemoryEstimates;
-public class LongArray extends Array<Long> {
- private long[] _data;
+public class CharArray extends Array<Character> {
+
+ protected char[] _data;
- public LongArray(long[] data) {
+ public CharArray(char[] data) {
super(data.length);
_data = data;
}
- public long[] get() {
+ public char[] get() {
return _data;
}
@Override
- public Long get(int index) {
+ public void write(DataOutput out) throws IOException {
+ out.writeByte(FrameArrayType.CHARACTER.ordinal());
+ for(int i = 0; i < _size; i++)
+ out.writeChar(_data[i]);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ _size = _data.length;
+ for(int i = 0; i < _size; i++)
+ _data[i] = in.readChar();
+ }
+
+ @Override
+ public Character get(int index) {
return _data[index];
}
@Override
- public void set(int index, Long value) {
- _data[index] = (value != null) ? value : 0L;
+ public double getAsDouble(int i) {
+ return (int) _data[i];
}
@Override
- public void set(int index, double value) {
- _data[index] = (long) value;
+ public void set(int index, Character value) {
+ _data[index] = value != null ? value : 0;
}
@Override
- public void set(int index, String value) {
- set(index, parseLong(value));
+ public void set(int index, double value) {
+ _data[index] = parseChar(Double.toString(value));
}
@Override
- public void set(int rl, int ru, Array<Long> value) {
- set(rl, ru, value, 0);
+ public void set(int index, String value) {
+ _data[index] = parseChar(value);
}
@Override
public void setFromOtherType(int rl, int ru, Array<?> value) {
- final ValueType vt = value.getValueType();
for(int i = rl; i <= ru; i++)
- _data[i] = UtilFunctions.objectToLong(vt, value.get(i));
+ _data[i] = value.get(i).toString().charAt(0);
}
@Override
- public void set(int rl, int ru, Array<Long> value, int rlSrc) {
- System.arraycopy(((LongArray) value)._data, rlSrc, _data, rl,
ru - rl + 1);
+ public void set(int rl, int ru, Array<Character> value) {
+ set(rl, ru, value, 0);
}
@Override
- public void setNz(int rl, int ru, Array<Long> value) {
- long[] data2 = ((LongArray) value)._data;
+ public void set(int rl, int ru, Array<Character> value, int rlSrc) {
+ System.arraycopy(((CharArray) value)._data, rlSrc, _data, rl,
ru - rl + 1);
+ }
+
+ @Override
+ public void setNz(int rl, int ru, Array<Character> value) {
+ char[] data2 = ((CharArray) value)._data;
for(int i = rl; i <= ru; i++)
if(data2[i] != 0)
_data[i] = data2[i];
@@ -94,7 +112,7 @@ public class LongArray extends Array<Long> {
public void setFromOtherTypeNz(int rl, int ru, Array<?> value) {
final ValueType vt = value.getValueType();
for(int i = rl; i <= ru; i++) {
- long v = UtilFunctions.objectToLong(vt, value.get(i));
+ char v = UtilFunctions.objectToCharacter(vt,
value.get(i));
if(v != 0)
_data[i] = v;
}
@@ -102,53 +120,34 @@ public class LongArray extends Array<Long> {
@Override
public void append(String value) {
- append(parseLong(value));
+ append(parseChar(value));
}
@Override
- public void append(Long value) {
+ public void append(Character value) {
if(_data.length <= _size)
_data = Arrays.copyOf(_data, newSize());
- _data[_size++] = (value != null) ? value : 0L;
+ _data[_size++] = (value != null) ? value : 0;
}
@Override
- public LongArray append(Array<Long> other) {
+ public Array<Character> append(Array<Character> other) {
final int endSize = this._size + other.size();
- final long[] ret = new long[endSize];
+ final char[] ret = new char[endSize];
System.arraycopy(_data, 0, ret, 0, this._size);
- System.arraycopy((long[]) other.get(), 0, ret, this._size,
other.size());
- return new LongArray(ret);
- }
-
- @Override
- public void write(DataOutput out) throws IOException {
- out.writeByte(FrameArrayType.INT64.ordinal());
- for(int i = 0; i < _size; i++)
- out.writeLong(_data[i]);
- }
-
- @Override
- public void readFields(DataInput in) throws IOException {
- _size = _data.length;
- for(int i = 0; i < _size; i++)
- _data[i] = in.readLong();
+ System.arraycopy((char[]) other.get(), 0, ret, this._size,
other.size());
+ return new CharArray(ret);
}
@Override
- public Array<Long> clone() {
- return new LongArray(Arrays.copyOf(_data, _size));
- }
-
- @Override
- public Array<Long> slice(int rl, int ru) {
- return new LongArray(Arrays.copyOfRange(_data, rl, ru));
+ public Array<Character> slice(int rl, int ru) {
+ return new CharArray(Arrays.copyOfRange(_data, rl, ru));
}
@Override
public void reset(int size) {
if(_data.length < size || _data.length > 2 * size)
- _data = new long[size];
+ _data = new char[size];
else
for(int i = 0; i < size; i++)
_data[i] = 0;
@@ -157,38 +156,32 @@ public class LongArray extends Array<Long> {
@Override
public byte[] getAsByteArray() {
- ByteBuffer longBuffer = ByteBuffer.allocate(8 * _size);
- longBuffer.order(ByteOrder.LITTLE_ENDIAN);
+ ByteBuffer charBuffer = ByteBuffer.allocate(2 * _size);
+ charBuffer.order(ByteOrder.nativeOrder());
for(int i = 0; i < _size; i++)
- longBuffer.putLong(_data[i]);
- return longBuffer.array();
+ charBuffer.putChar(_data[i]);
+ return charBuffer.array();
+
}
@Override
public ValueType getValueType() {
- return ValueType.INT64;
+ return ValueType.CHARACTER;
}
@Override
public ValueType analyzeValueType() {
- return ValueType.INT64;
+ return ValueType.CHARACTER;
}
@Override
public FrameArrayType getFrameArrayType() {
- return FrameArrayType.INT64;
- }
-
- @Override
- public long getInMemorySize() {
- long size = super.getInMemorySize(); // object header + object
reference
- size += MemoryEstimates.longArrayCost(_data.length);
- return size;
+ return FrameArrayType.CHARACTER;
}
@Override
public long getExactSerializedSize() {
- return 1 + 8 * _data.length;
+ return 1 + 2 * _data.length;
}
@Override
@@ -196,8 +189,7 @@ public class LongArray extends Array<Long> {
BitSet ret = new BitSet(size());
for(int i = 0; i < size(); i++) {
if(_data[i] != 0 && _data[i] != 1)
- throw new DMLRuntimeException(
- "Unable to change to Boolean from
Integer array because of value:" + _data[i]);
+ throw new DMLRuntimeException("Unable to change
to Boolean from char array because of value:" + _data[i]);
ret.set(i, _data[i] == 0 ? false : true);
}
return new BitSetArray(ret, size());
@@ -207,7 +199,7 @@ public class LongArray extends Array<Long> {
protected Array<Boolean> changeTypeBoolean() {
boolean[] ret = new boolean[size()];
for(int i = 0; i < size(); i++) {
- if(_data[i] < 0 || _data[i] > 1)
+ if(_data[i] != 0 && _data[i] != 1)
throw new DMLRuntimeException(
"Unable to change to Boolean from
Integer array because of value:" + _data[i]);
ret[i] = _data[i] == 0 ? false : true;
@@ -217,72 +209,68 @@ public class LongArray extends Array<Long> {
@Override
protected Array<Double> changeTypeDouble() {
- double[] ret = new double[size()];
- for(int i = 0; i < size(); i++)
- ret[i] = (double) _data[i];
- return new DoubleArray(ret);
+ try {
+ double[] ret = new double[size()];
+ for(int i = 0; i < size(); i++)
+ ret[i] = (int) _data[i];
+ return new DoubleArray(ret);
+ }
+ catch(NumberFormatException e) {
+ throw new DMLRuntimeException("Invalid parsing of char
to double", e);
+ }
}
@Override
protected Array<Float> changeTypeFloat() {
- float[] ret = new float[size()];
- for(int i = 0; i < size(); i++)
- ret[i] = (float) _data[i];
- return new FloatArray(ret);
+ try {
+ float[] ret = new float[size()];
+ for(int i = 0; i < size(); i++)
+ ret[i] = Float.parseFloat(_data[i] + "");
+ return new FloatArray(ret);
+ }
+ catch(NumberFormatException e) {
+ throw new DMLRuntimeException("Invalid parsing of char
to float", e);
+ }
}
@Override
protected Array<Integer> changeTypeInteger() {
int[] ret = new int[size()];
- for(int i = 0; i < size(); i++) {
- if(_data[i] != (long) (int) _data[i])
- throw new DMLRuntimeException("Unable to change
to integer from long array because of value:" + _data[i]);
+ for(int i = 0; i < size(); i++)
ret[i] = (int) _data[i];
- }
+
return new IntegerArray(ret);
}
@Override
protected Array<Long> changeTypeLong() {
- return this;
+ long[] ret = new long[size()];
+ for(int i = 0; i < size(); i++)
+ ret[i] = (int) _data[i];
+ return new LongArray(ret);
}
@Override
protected Array<String> changeTypeString() {
String[] ret = new String[size()];
for(int i = 0; i < size(); i++)
- ret[i] = get(i).toString();
+ ret[i] = _data[i] + "";
return new StringArray(ret);
}
@Override
- public void fill(String value) {
- fill(parseLong(value));
+ public Array<Character> changeTypeCharacter() {
+ return this;
}
@Override
- public void fill(Long value) {
- for(int i = 0; i < _size; i++)
- _data[i] = value;
+ public void fill(String val) {
+ fill(parseChar(val));
}
@Override
- public double getAsDouble(int i) {
- return _data[i];
- }
-
- protected static long parseLong(String s) {
- if(s == null)
- return 0;
- try {
- return Long.parseLong(s);
- }
- catch(NumberFormatException e) {
- if(s.contains("."))
- return Long.parseLong(s.split("\\.")[0]);
- else
- throw e;
- }
+ public void fill(Character val) {
+ Arrays.fill(_data, (char) val);
}
@Override
@@ -290,12 +278,31 @@ public class LongArray extends Array<Long> {
return true;
}
+ @Override
+ public Array<Character> clone() {
+ return new CharArray(Arrays.copyOf(_data, _size));
+ }
+
+ public static char parseChar(String value) {
+ if(value == null)
+ return 0;
+ else if(value.length() == 1)
+ return value.charAt(0);
+ else if(FrameUtil.isIntType(value, value.length()) != null)
+ return (char) Double.parseDouble(value);
+ else
+ throw new DMLRuntimeException("Invalid parsing of
Character");
+ }
+
@Override
public String toString() {
- StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
- sb.append(super.toString() + ":[");
- for(int i = 0; i < _size - 1; i++)
- sb.append(_data[i] + ",");
+ StringBuilder sb = new StringBuilder(_data.length);
+ sb.append(super.toString());
+ sb.append(":[");
+ for(int i = 0; i < _size - 1; i++) {
+ sb.append(_data[i]);
+ sb.append(',');
+ }
sb.append(_data[_size - 1]);
sb.append("]");
return sb.toString();
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java
index 6c37faf8b3..acf2962621 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java
@@ -62,8 +62,8 @@ public class DoubleArray extends Array<Double> {
}
@Override
- public void set(int index, String value){
- set(index, parseDouble(value) );
+ public void set(int index, String value) {
+ set(index, parseDouble(value));
}
@Override
@@ -118,7 +118,7 @@ public class DoubleArray extends Array<Double> {
final int endSize = this._size + other.size();
final double[] ret = new double[endSize];
System.arraycopy(_data, 0, ret, 0, this._size);
- System.arraycopy((double[])other.get(), 0, ret, this._size,
other.size());
+ System.arraycopy((double[]) other.get(), 0, ret, this._size,
other.size());
return new DoubleArray(ret);
}
@@ -302,6 +302,14 @@ public class DoubleArray extends Array<Double> {
return new StringArray(ret);
}
+ @Override
+ public Array<Character> changeTypeCharacter() {
+ char[] ret = new char[size()];
+ for(int i = 0; i < size(); i++)
+ ret[i] = CharArray.parseChar(get(i).toString());
+ return new CharArray(ret);
+ }
+
@Override
public void fill(String value) {
fill(parseDouble(value));
@@ -309,8 +317,7 @@ public class DoubleArray extends Array<Double> {
@Override
public void fill(Double value) {
- for(int i = 0; i < _size; i++)
- _data[i] = value;
+ Arrays.fill(_data, value);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
index 5c740310e8..3f51991504 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java
@@ -60,10 +60,9 @@ public class FloatArray extends Array<Float> {
_data[index] = (float) value;
}
-
@Override
- public void set(int index, String value){
- set(index,parseFloat(value) );
+ public void set(int index, String value) {
+ set(index, parseFloat(value));
}
@Override
@@ -118,8 +117,7 @@ public class FloatArray extends Array<Float> {
final int endSize = this._size + other.size();
final float[] ret = new float[endSize];
System.arraycopy(_data, 0, ret, 0, this._size);
-
- System.arraycopy((float[])other.get(), 0, ret, this._size,
other.size());
+ System.arraycopy((float[]) other.get(), 0, ret, this._size,
other.size());
return new FloatArray(ret);
}
@@ -151,7 +149,7 @@ public class FloatArray extends Array<Float> {
public void reset(int size) {
if(_data.length < size || _data.length > 2 * size)
_data = new float[size];
- else
+ else
for(int i = 0; i < size; i++)
_data[i] = 0;
_size = size;
@@ -257,6 +255,14 @@ public class FloatArray extends Array<Float> {
return new StringArray(ret);
}
+ @Override
+ public Array<Character> changeTypeCharacter() {
+ char[] ret = new char[size()];
+ for(int i = 0; i < size(); i++)
+ ret[i] = CharArray.parseChar(get(i).toString());
+ return new CharArray(ret);
+ }
+
@Override
public void fill(String value) {
fill(parseFloat(value));
@@ -264,12 +270,11 @@ public class FloatArray extends Array<Float> {
@Override
public void fill(Float value) {
- for(int i = 0; i < _size; i++)
- _data[i] = value;
+ Arrays.fill(_data, value);
}
@Override
- public double getAsDouble(int i){
+ public double getAsDouble(int i) {
return _data[i];
}
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java
index a94987ea15..9e01bebcbf 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java
@@ -252,6 +252,14 @@ public class IntegerArray extends Array<Integer> {
return new StringArray(ret);
}
+ @Override
+ public Array<Character> changeTypeCharacter() {
+ char[] ret = new char[size()];
+ for(int i = 0; i < size(); i++)
+ ret[i] = get(i).toString().charAt(0);
+ return new CharArray(ret);
+ }
+
@Override
public void fill(String value) {
fill(parseInt(value));
@@ -259,8 +267,7 @@ public class IntegerArray extends Array<Integer> {
@Override
public void fill(Integer value) {
- for(int i = 0; i < _size; i++)
- _data[i] = value;
+ Arrays.fill(_data, value);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
index 4dbdc91d32..f17df67a80 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java
@@ -262,8 +262,7 @@ public class LongArray extends Array<Long> {
@Override
public void fill(Long value) {
- for(int i = 0; i < _size; i++)
- _data[i] = value;
+ Arrays.fill(_data, value);
}
@Override
@@ -285,6 +284,14 @@ public class LongArray extends Array<Long> {
}
}
+ @Override
+ public Array<Character> changeTypeCharacter() {
+ char[] ret = new char[size()];
+ for(int i = 0; i < size(); i++)
+ ret[i] = get(i).toString().charAt(0);
+ return new CharArray(ret);
+ }
+
@Override
public boolean isShallowSerialize() {
return true;
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
index 254d4f87ec..9f74fa9a73 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
@@ -264,7 +264,7 @@ public class StringArray extends Array<String> {
@Override
public long getExactSerializedSize() {
- long si = 1 + 8; // byte identifier
+ long si = 1 + 8; // byte identifier and long size
for(String s : _data)
si += IOUtilFunctions.getUTFSize(s);
return si;
@@ -440,10 +440,17 @@ public class StringArray extends Array<String> {
return new Pair<>(minLength, maxLength);
}
+ @Override
+ public Array<Character> changeTypeCharacter() {
+ char[] ret = new char[size()];
+ for(int i = 0; i < size(); i++)
+ ret[i] = _data[i].charAt(0);
+ return new CharArray(ret);
+ }
+
@Override
public void fill(String value) {
- for(int i = 0; i < _size; i++)
- _data[i] = value;
+ Arrays.fill(_data, value);
materializedSize = -1;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
index 97130c5999..73cab0d72a 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
@@ -35,7 +35,8 @@ import org.apache.sysds.runtime.util.UtilFunctions;
public interface FrameUtil {
public static final Log LOG =
LogFactory.getLog(FrameUtil.class.getName());
- public static final Pattern booleanPattern =
Pattern.compile("([tT]((rue)|(RUE))?|[fF]((alse)|(ALSE))?|0\\.0+|1\\.0+|0|1)");
+ public static final Pattern booleanPattern = Pattern
+
.compile("([tT]((rue)|(RUE))?|[fF]((alse)|(ALSE))?|0\\.0+|1\\.0+|0|1)");
public static final Pattern integerFloatPattern =
Pattern.compile("[-+]?\\d+(\\.0+)?");
public static final Pattern floatPattern =
Pattern.compile("[-+]?[0-9]*\\.?[0-9]*([eE][-+]?[0-9]+)?");
@@ -72,9 +73,9 @@ public interface FrameUtil {
return ValueType.INT64;
}
- private static ValueType isIntType(final String val, final int len) {
+ public static ValueType isIntType(final String val, final int len) {
if(len <= 22) {
- if(simpleIntMatch(val, len)){
+ if(simpleIntMatch(val, len)) {
if(len < 8)
return ValueType.INT32;
return intType(Long.parseLong(val));
@@ -88,7 +89,7 @@ public interface FrameUtil {
return null;
}
- private static ValueType isFloatType(final String val, final int len) {
+ public static ValueType isFloatType(final String val, final int len) {
if(len <= 25 && floatPattern.matcher(val).matches()) {
// return isFloatType(v);
@@ -128,6 +129,7 @@ public interface FrameUtil {
switch(minType) {
case UNKNOWN:
case BOOLEAN:
+ case CHARACTER:
if(isBooleanType(val, len) != null)
return ValueType.BOOLEAN;
case UINT8:
@@ -141,6 +143,8 @@ public interface FrameUtil {
r = isFloatType(val, len);
if(r != null)
return r;
+ if(len == 1)
+ return ValueType.CHARACTER;
case STRING:
return ValueType.STRING;
default:
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index dd080cdfac..5cb08bfcfd 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -45,6 +45,7 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.frame.data.columns.CharArray;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.Pair;
@@ -489,6 +490,7 @@ public class UtilFunctions {
case INT64: return Long.parseLong(in);
case FP64: return Double.parseDouble(in);
case FP32: return Float.parseFloat(in);
+ case CHARACTER: return CharArray.parseChar(in);
default: throw new RuntimeException("Unsupported value
type: "+vt);
}
}
@@ -536,6 +538,24 @@ public class UtilFunctions {
}
}
+ public static char objectToCharacter(ValueType vt, Object in){
+ if(in == null)
+ return 0;
+ switch(vt) {
+ case FP64:
+ case FP32:
+ case INT64:
+ case INT32:
+ return in.toString().charAt(0);
+ case BOOLEAN:
+ return ((Boolean) in) ? '1' : '0';
+ case STRING:
+ return !((String) in).isEmpty() ?
((String)in).charAt(0) : 0;
+ default:
+ throw new DMLRuntimeException("Unsupported
value type: " + vt);
+ }
+ }
+
public static int objectToInteger(ValueType vt, Object in) {
if(in == null)
return 0;
diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java
b/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java
index 5d542215b2..34b4d688fa 100644
--- a/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java
+++ b/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java
@@ -27,6 +27,8 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.frame.data.FrameBlock;
@@ -40,7 +42,9 @@ import org.junit.runners.Parameterized.Parameters;
@RunWith(value = Parameterized.class)
public class FrameTest {
- public FrameBlock f;
+ protected static final Log LOG =
LogFactory.getLog(FrameTest.class.getName());
+
+ public final FrameBlock f;
@Parameters
public static Collection<Object[]> data() {
@@ -162,7 +166,7 @@ public class FrameTest {
for(int r = 0; r < 10; r++)
for(int c = 0; c < f.getNumColumns(); c++) {
String v = ff.get(r, c).toString();
- assertTrue(v, v.equals("0") || v.equals("0.0")
|| v.equals("false"));
+ assertTrue(v, v.equals("0") || v.equals("0.0")
|| v.equals("false") || v.equals((char)0 +""));
}
for(int r = 0; r < f.getNumRows(); r++)
for(int c = 0; c < f.getNumColumns(); c++)
@@ -177,7 +181,9 @@ public class FrameTest {
for(int r = 0; r < 240; r++)
for(int c = 0; c < f.getNumColumns(); c++) {
String v = ff.get(r, c).toString();
- assertTrue(v, v.equals("0") || v.equals("0.0")
|| v.equals("false"));
+ // LOG.error(v);
+ // LOG.error((int)v.charAt(0));
+ assertTrue(v, v.equals("0") || v.equals("0.0")
|| v.equals("false") || v.equals((char)0 +""));
}
for(int r = 0; r < f.getNumRows(); r++)
for(int c = 0; c < f.getNumColumns(); c++)
@@ -189,7 +195,7 @@ public class FrameTest {
Iterator<Object[]> it = IteratorFactory.getObjectRowIterator(f);
- for(int r = 0; r < f.getNumRows(); r++){
+ for(int r = 0; r < f.getNumRows(); r++) {
Object[] row = it.next();
for(int c = 0; c < f.getNumColumns(); c++)
assertEquals(f.get(r, c).toString(),
row[c].toString());
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
index 53d29c053f..46dcd15124 100644
---
a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
+++
b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
@@ -29,9 +29,11 @@ import java.util.BitSet;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.frame.data.columns.ArrayFactory;
import org.apache.sysds.runtime.frame.data.columns.BitSetArray;
import org.apache.sysds.runtime.frame.data.columns.BooleanArray;
+import org.apache.sysds.runtime.frame.data.columns.CharArray;
import org.apache.sysds.runtime.frame.data.columns.IntegerArray;
import org.apache.sysds.runtime.frame.data.columns.LongArray;
import org.apache.sysds.runtime.frame.data.columns.StringArray;
@@ -487,6 +489,46 @@ public class CustomArrayTests {
assertEquals(BitSetArray.longToBits(-1),
"1111111111111111111111111111111111111111111111111111111111111111");
}
+ @Test
+ public void charSet() {
+ CharArray a = ArrayFactory.create(new char[2]);
+ // a numeric string is parsed, so "1.0" is stored as the char value 1, not '1'
+ a.set(0, "1.0");
+ assertEquals(Character.valueOf((char) 1), a.get(0));
+ }
+
+ @Test(expected = DMLRuntimeException.class)
+ public void charSet_invalid() {
+ // a non-integral numeric string is expected to be rejected
+ CharArray a = ArrayFactory.create(new char[2]);
+ a.set(0, "1.01");
+ }
+
+ @Test(expected = DMLRuntimeException.class)
+ public void charSet_invalid_2() {
+ // a multi-character, non-numeric string is expected to be rejected
+ CharArray a = ArrayFactory.create(new char[2]);
+ a.set(0, "aa");
+ }
+
+ @Test
+ public void charSetDouble() {
+ CharArray a = ArrayFactory.create(new char[2]);
+ // doubles are stored by numeric value: 1.0 becomes the char value 1
+ a.set(0, 1.0d);
+ assertEquals(Character.valueOf((char) 1), a.get(0));
+ }
+
+ @Test
+ public void charSetDouble_2() {
+ CharArray a = ArrayFactory.create(new char[2]);
+ a.set(0, 0.0d);
+ assertEquals(Character.valueOf((char) 0), a.get(0));
+ }
+
+ @Test
+ public void charSetDouble_3() {
+ CharArray a = ArrayFactory.create(new char[2]);
+ a.set(0, 10.0d);
+ assertEquals(Character.valueOf((char) 10), a.get(0));
+ }
+
public static BitSetArray createTrueBitArray(int length) {
BitSet init = new BitSet();
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayConstantTests.java
b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayConstantTests.java
index 36d37605d1..ca707b7156 100644
---
a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayConstantTests.java
+++
b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayConstantTests.java
@@ -71,8 +71,13 @@ public class FrameArrayConstantTests {
public void testConstruction() {
try {
Array<?> a = ArrayFactory.allocate(t, nRow, "0");
- for(int i = 0; i < nRow; i++)
- assertEquals(a.getAsDouble(i), 0.0,
0.0000000001);
+ if(a.getValueType() == ValueType.CHARACTER)
+
+ for(int i = 0; i < nRow; i++)
+ assertEquals(a.getAsDouble(i), 48.0,
0.0000000001);
+ else
+ for(int i = 0; i < nRow; i++)
+ assertEquals(a.getAsDouble(i), 0.0,
0.0000000001);
}
catch(Exception e) {
e.printStackTrace();
@@ -84,7 +89,7 @@ public class FrameArrayConstantTests {
public void testConstruction_default() {
try {
Array<?> a = ArrayFactory.allocate(t, nRow);
- if(t != ValueType.STRING)
+ if(t != ValueType.STRING && t != ValueType.CHARACTER)
for(int i = 0; i < nRow; i++)
assertEquals(a.getAsDouble(i), 0.0,
0.0000000001);
}
@@ -111,7 +116,7 @@ public class FrameArrayConstantTests {
public void testConstruction_null() {
try {
Array<?> a = ArrayFactory.allocate(t, nRow, null);
- if(t != ValueType.STRING)
+ if(t != ValueType.STRING && t != ValueType.CHARACTER)
for(int i = 0; i < nRow; i++)
assertEquals(a.getAsDouble(i), 0.0,
0.0000000001);
}
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java
b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java
index 6f81a68f4c..8ddf4d8551 100644
---
a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java
+++
b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java
@@ -93,6 +93,10 @@ public class FrameArrayTests {
tests.add(new Object[]
{ArrayFactory.create(generateRandomTrueFalseString(80, 221)),
FrameArrayType.STRING});
tests.add(new Object[]
{ArrayFactory.create(generateRandomTrueFalseString(150, 221)),
FrameArrayType.STRING});
+ tests.add(new Object[] {ArrayFactory.create(new char[]
{0, 0, 0, 0, 1, 1, 1}), FrameArrayType.CHARACTER});
+ tests.add(new Object[] {ArrayFactory.create(new char[]
{'t', 't', 'f', 'f', 'T'}), FrameArrayType.CHARACTER});
+ tests.add(new Object[] {ArrayFactory.create(new char[]
{'0', '2', '3', '4', '9'}), FrameArrayType.CHARACTER});
+ tests.add(new Object[]
{ArrayFactory.create(generateRandom01chars(150, 221)),
FrameArrayType.CHARACTER});
// Long to int
tests.add(new Object[] {ArrayFactory.create(new long[]
{3214, 424, 13, 22, 111, 134}), FrameArrayType.INT64});
}
@@ -191,6 +195,11 @@ public class FrameArrayTests {
changeType(ValueType.BOOLEAN);
}
+ @Test
+ public void changeTypeCharacter() {
+ // delegates to changeType(ValueType) with CHARACTER as the target value type
+ changeType(ValueType.CHARACTER);
+ }
+
public void changeType(ValueType t) {
try {
Array<?> r = a.changeType(t);
@@ -275,6 +284,9 @@ public class FrameArrayTests {
case STRING:
x = (String[]) a.get();
return;
+ case CHARACTER:
+ x = (char[]) a.get();
+ return;
default:
throw new NotImplementedException();
}
@@ -328,6 +340,10 @@ public class FrameArrayTests {
case STRING:
((Array<String>) aa).set(start, end,
(Array<String>) a, off);
break;
+ case CHARACTER:
+
+ ((Array<Character>) aa).set(start, end,
(Array<Character>) a, off);
+ break;
default:
throw new NotImplementedException();
}
@@ -371,6 +387,9 @@ public class FrameArrayTests {
case STRING:
other =
ArrayFactory.create(generateRandomString(otherSize, seed));
break;
+ case CHARACTER:
+ other =
ArrayFactory.create(generateRandomChar(otherSize, seed));
+ break;
default:
throw new NotImplementedException();
}
@@ -396,6 +415,9 @@ public class FrameArrayTests {
case STRING:
((Array<String>) aa).set(start, end,
(Array<String>) other);
break;
+ case CHARACTER:
+ ((Array<Character>) aa).set(start, end,
(Array<Character>) other);
+ break;
default:
throw new NotImplementedException();
}
@@ -427,7 +449,6 @@ public class FrameArrayTests {
((Array<Integer>) a).set(0, vi);
assertEquals(((Array<Integer>) a).get(0), vi);
return;
-
case INT64:
Long vl = 1324L;
((Array<Long>) a).set(0, vl);
@@ -435,17 +456,19 @@ public class FrameArrayTests {
return;
case BOOLEAN:
case BITSET:
-
Boolean vb = true;
((Array<Boolean>) a).set(0, vb);
assertEquals(((Array<Boolean>) a).get(0), vb);
return;
case STRING:
-
String vs = "1324L";
a.set(0, vs);
assertEquals(((Array<String>) a).get(0), vs);
-
+ return;
+ case CHARACTER:
+ Character c = '~';
+ ((Array<Character>) a).set(0, c);
+ assertEquals(((Array<Character>) a).get(0), c);
return;
default:
throw new NotImplementedException();
@@ -477,6 +500,9 @@ public class FrameArrayTests {
case STRING:
assertEquals(((Array<String>) a).get(0),
Double.toString(vd));
return;
+ case CHARACTER:
+ assertEquals((int) ((Array<Character>)
a).get(0), 1);
+ return;
default:
throw new NotImplementedException();
}
@@ -507,6 +533,9 @@ public class FrameArrayTests {
case STRING:
assertEquals(((Array<String>) a).get(0),
Double.toString(vd));
return;
+ case CHARACTER:
+ assertEquals(((Array<Character>) a).get(0),
Character.valueOf((char) 0));
+ return;
default:
throw new NotImplementedException();
}
@@ -671,6 +700,16 @@ public class FrameArrayTests {
aa.append(vi8s);
assertEquals((int) aa.get(aa.size() - 1), vi8);
break;
+ case CHARACTER:
+ char vc = '@';
+ String vci = vc + "";
+ aa.append(vci);
+ assertEquals((char) aa.get(aa.size() - 1), vc);
+ vc = (char) 42;
+ vci = vc + "";
+ aa.append(vci);
+ assertEquals((char) aa.get(aa.size() - 1), vc);
+ break;
case UNKNOWN:
default:
throw new DMLRuntimeException("Invalid type");
@@ -704,6 +743,9 @@ public class FrameArrayTests {
case UINT8:
assertEquals((int) aa.get(aa.size() - 1), 0);
break;
+ case CHARACTER:
+ assertEquals((char) aa.get(aa.size() - 1), 0);
+ break;
case UNKNOWN:
default:
throw new DMLRuntimeException("Invalid type");
@@ -741,6 +783,9 @@ public class FrameArrayTests {
case UINT8:
assertEquals((int) aa.get(aa.size() -
1), 0);
break;
+ case CHARACTER:
+ assertEquals((char) aa.get(aa.size() -
1), 0);
+ break;
case UNKNOWN:
default:
throw new DMLRuntimeException("Invalid
type");
@@ -778,6 +823,9 @@ public class FrameArrayTests {
case STRING:
((Array<String>)
aa).setNz((Array<String>) a);
break;
+ case CHARACTER:
+ ((Array<Character>)
aa).setNz((Array<Character>) a);
+ break;
case UNKNOWN:
default:
throw new DMLRuntimeException("Invalid
type");
@@ -916,6 +964,8 @@ public class FrameArrayTests {
return
ArrayFactory.create(generateRandomFloat(size, seed));
case FP64:
return
ArrayFactory.create(generateRandomDouble(size, seed));
+ case CHARACTER:
+ return
ArrayFactory.create(generateRandomChar(size, seed));
default:
throw new DMLRuntimeException("Unsupported
value type: " + t);
@@ -946,6 +996,14 @@ public class FrameArrayTests {
return ret;
}
+ /**
+ * Generate a deterministic pseudo-random array of the char values 0 and 1.
+ *
+ * @param size Number of chars to generate
+ * @param seed Seed for the random generator
+ * @return An array where each entry is either (char) 0 or (char) 1
+ */
+ public static char[] generateRandom01chars(int size, int seed) {
+ Random r = new Random(seed);
+ char[] ret = new char[size];
+ // nextInt(2) yields 0 or 1 (the bound is exclusive); nextInt(1) would always yield 0
+ for(int i = 0; i < size; i++)
+ ret[i] = (char) r.nextInt(2);
+ return ret;
+ }
+
public static String[] generateRandomTrueFalseString(int size, int
seed) {
Random r = new Random(seed);
String[] ret = new String[size];
@@ -970,6 +1028,14 @@ public class FrameArrayTests {
return ret;
}
+ protected static char[] generateRandomChar(int size, int seed) {
+ Random r = new Random(seed);
+ char[] ret = new char[size];
+ for(int i = 0; i < size; i++)
+ ret[i] = (char) r.nextInt((int) Character.MAX_VALUE);
+ return ret;
+ }
+
protected static int[] generateRandomInt8(int size, int seed) {
Random r = new Random(seed);
int[] ret = new int[size];
diff --git
a/src/test/java/org/apache/sysds/test/functions/frame/FrameDropInvalidTypeTest.java
b/src/test/java/org/apache/sysds/test/functions/frame/FrameDropInvalidTypeTest.java
index f0095377ea..5189b20459 100644
---
a/src/test/java/org/apache/sysds/test/functions/frame/FrameDropInvalidTypeTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/frame/FrameDropInvalidTypeTest.java
@@ -208,6 +208,7 @@ public class FrameDropInvalidTypeTest extends
AutomatedTestBase {
writer.writeFrameToHDFS(frame2, input("M"), 1,
schema.length);
runTest(null);
FrameBlock frameout = readDMLFrameFromHDFS("B",
FileFormat.BINARY);
+ LOG.error(frameout);
//read output data and compare results
ArrayList<Object> data = new ArrayList<>();
for (int i = 0; i < frameout.getNumRows(); i++)
diff --git
a/src/test/java/org/apache/sysds/test/functions/io/proto/FrameReaderWriterProtoTest.java
b/src/test/java/org/apache/sysds/test/functions/io/proto/FrameReaderWriterProtoTest.java
index 8bbcd1fcff..5dbf8b2d7c 100644
---
a/src/test/java/org/apache/sysds/test/functions/io/proto/FrameReaderWriterProtoTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/io/proto/FrameReaderWriterProtoTest.java
@@ -19,6 +19,8 @@
package org.apache.sysds.test.functions.io.proto;
+import static org.junit.Assert.fail;
+
import java.io.IOException;
import java.util.Random;
@@ -66,16 +68,23 @@ public class FrameReaderWriterProtoTest {
}
+ /**
+ * Write a random frame of the given dimensions with the proto writer, read it
+ * back with the proto reader, and compare both frames as string frames. Any
+ * exception is printed and turned into a JUnit failure.
+ */
public void testWriteReadFrameBlockWith(int rows, int cols) throws
IOException {
- final Random random = new Random(SEED);
- Types.ValueType[] schema = TestUtils.generateRandomSchema(cols,
random);
- FrameBlock expectedFrame =
TestUtils.generateRandomFrameBlock(rows, schema, random);
-
- frameWriterProto.writeFrameToHDFS(expectedFrame,
FILENAME_SINGLE, rows, cols);
- FrameBlock actualFrame =
frameReaderProto.readFrameFromHDFS(FILENAME_SINGLE, schema, rows, cols);
-
- String[][] expected =
DataConverter.convertToStringFrame(expectedFrame);
- String[][] actual =
DataConverter.convertToStringFrame(actualFrame);
+ try{
- TestUtils.compareFrames(expected, actual, rows, cols);
+ final Random random = new Random(SEED);
+ Types.ValueType[] schema =
TestUtils.generateRandomSchema(cols, random);
+ FrameBlock expectedFrame =
TestUtils.generateRandomFrameBlock(rows, schema, random);
+
+ frameWriterProto.writeFrameToHDFS(expectedFrame,
FILENAME_SINGLE, rows, cols);
+ FrameBlock actualFrame =
frameReaderProto.readFrameFromHDFS(FILENAME_SINGLE, schema, rows, cols);
+
+ String[][] expected =
DataConverter.convertToStringFrame(expectedFrame);
+ String[][] actual =
DataConverter.convertToStringFrame(actualFrame);
+
+ TestUtils.compareFrames(expected, actual, rows, cols);
+ }
+ // NOTE(review): fail(e.getMessage()) puts only the message in the JUnit report;
+ // the stack trace goes to stderr via printStackTrace. Consider rethrowing with
+ // the exception as cause so the test report keeps it.
+ catch(Exception e){
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
}
}