This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 2bc9a7eb87 [SYSTEMDS-3490] Compressed Transform Encode
2bc9a7eb87 is described below
commit 2bc9a7eb870e1ff540408b455193ff52f7ede38c
Author: baunsgaard <[email protected]>
AuthorDate: Thu Jan 26 18:24:19 2023 +0100
[SYSTEMDS-3490] Compressed Transform Encode
Transform encode fused with compression: a compressed output is produced
directly from the frame input, depending on the transformations applied.
Initial results are very promising, with the single-threaded compressed
transform running at the same speed as our tuned multithreaded version.
This commit contains the bare minimum for the transform encode;
following commits will add more transformation pipelines.
Currently supported are recode-to-dummycode, recode, and pass-through,
all in very naive implementations.
Also included is an IdentityDictionary implementation that allows
one to specify that the compressed dictionary simply is the identity
matrix. Its allocation is very small: an object and an integer specifying
the number of rows and columns of the identity matrix.
To keep the implementation efficient for now, a materialized MatrixBlock
dictionary is cached behind a soft reference and used for all operations
that the IdentityDictionary does not yet support directly.
Closes #1772
---
src/main/java/org/apache/sysds/common/Types.java | 6 +-
src/main/java/org/apache/sysds/conf/DMLConfig.java | 4 +-
src/main/java/org/apache/sysds/hops/LiteralOp.java | 1 +
.../sysds/parser/BuiltinFunctionExpression.java | 1 +
.../runtime/compress/CompressedMatrixBlock.java | 3 +-
.../runtime/compress/colgroup/ColGroupFactory.java | 8 +-
.../compress/colgroup/dictionary/ADictionary.java | 2 +-
.../colgroup/dictionary/DictionaryFactory.java | 2 +-
.../colgroup/dictionary/IdentityDictionary.java | 584 +++++++++++++++++++++
.../apache/sysds/runtime/data/LibTensorAgg.java | 1 +
.../org/apache/sysds/runtime/data/TensorBlock.java | 2 +
.../sysds/runtime/frame/data/FrameBlock.java | 48 +-
.../sysds/runtime/frame/data/columns/Array.java | 75 ++-
.../runtime/frame/data/columns/ArrayFactory.java | 6 +-
.../runtime/frame/data/columns/OptionalArray.java | 5 +
.../runtime/frame/data/columns/StringArray.java | 32 ++
...ltiReturnParameterizedBuiltinSPInstruction.java | 2 +-
.../spark/ParameterizedBuiltinSPInstruction.java | 2 +-
.../runtime/transform/encode/ColumnEncoder.java | 5 +
.../runtime/transform/encode/ColumnEncoderBin.java | 9 +
.../transform/encode/ColumnEncoderComposite.java | 42 +-
.../transform/encode/ColumnEncoderDummycode.java | 12 +
.../transform/encode/ColumnEncoderFeatureHash.java | 9 +
.../transform/encode/ColumnEncoderPassThrough.java | 9 +
.../transform/encode/ColumnEncoderRecode.java | 58 +-
.../runtime/transform/encode/CompressedEncode.java | 194 +++++++
.../transform/encode/MultiColumnEncoder.java | 92 ++--
.../apache/sysds/runtime/util/DataConverter.java | 1 +
.../apache/sysds/runtime/util/UtilFunctions.java | 1 +
src/test/java/org/apache/sysds/test/TestUtils.java | 5 +
.../component/frame/array/CustomArrayTests.java | 2 +-
.../frame/transform/transformCompressed.java | 114 ++++
32 files changed, 1213 insertions(+), 124 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index 0573da2de3..ab81ff4e31 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -75,12 +75,12 @@ public class Types
* Value types (int, double, string, boolean, unknown).
*/
public enum ValueType {
- UINT8, // Used for parsing in UINT values from numpy.
+ UINT4, UINT8, // Used for parsing in UINT values from numpy.
FP32, FP64, INT32, INT64, BOOLEAN, STRING, UNKNOWN,
CHARACTER;
public boolean isNumeric() {
- return this == UINT8 || this == INT32 || this == INT64
|| this == FP32 || this == FP64;
+ return this == UINT8 || this == INT32 || this == INT64
|| this == FP32 || this == FP64 || this== UINT4;
}
public boolean isUnknown() {
return this == UNKNOWN;
@@ -92,6 +92,7 @@ public class Types
switch(this) {
case FP32:
case FP64: return "DOUBLE";
+ case UINT4:
case UINT8:
case INT32:
case INT64: return "INT";
@@ -107,6 +108,7 @@ public class Types
case "FP32": return FP32;
case "FP64":
case "DOUBLE": return FP64;
+ case "UINT4": return UINT4;
case "UINT8": return UINT8;
case "INT32": return INT32;
case "INT64":
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java
b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index dad670efd4..46580c294b 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -86,6 +86,7 @@ public class DMLConfig
public static final String COMPRESSED_COCODE =
"sysds.compressed.cocode";
public static final String COMPRESSED_COST_MODEL=
"sysds.compressed.costmodel";
public static final String COMPRESSED_TRANSPOSE =
"sysds.compressed.transpose";
+ public static final String COMPRESSED_TRANSFORMENCODE =
"sysds.compressed.transformencode";
public static final String NATIVE_BLAS = "sysds.native.blas";
public static final String NATIVE_BLAS_DIR =
"sysds.native.blas.directory";
public static final String DAG_LINEARIZATION =
"sysds.compile.linearization";
@@ -167,6 +168,7 @@ public class DMLConfig
_defaultVals.put(COMPRESSED_COCODE, "AUTO");
_defaultVals.put(COMPRESSED_COST_MODEL, "AUTO");
_defaultVals.put(COMPRESSED_TRANSPOSE, "auto");
+ _defaultVals.put(COMPRESSED_TRANSFORMENCODE, "false");
_defaultVals.put(DAG_LINEARIZATION,
DagLinearization.DEPTH_FIRST.name());
_defaultVals.put(CODEGEN, "false" );
_defaultVals.put(CODEGEN_API,
GeneratorAPI.JAVA.name() );
@@ -450,7 +452,7 @@ public class DMLConfig
CP_PARALLEL_OPS, CP_PARALLEL_IO, PARALLEL_ENCODE,
NATIVE_BLAS, NATIVE_BLAS_DIR,
COMPRESSED_LINALG, COMPRESSED_LOSSY,
COMPRESSED_VALID_COMPRESSIONS, COMPRESSED_OVERLAPPING,
COMPRESSED_SAMPLING_RATIO,
COMPRESSED_SOFT_REFERENCE_COUNT,
- COMPRESSED_COCODE, COMPRESSED_TRANSPOSE,
DAG_LINEARIZATION,
+ COMPRESSED_COCODE, COMPRESSED_TRANSPOSE,
COMPRESSED_TRANSFORMENCODE, DAG_LINEARIZATION,
CODEGEN, CODEGEN_API, CODEGEN_COMPILER,
CODEGEN_OPTIMIZER, CODEGEN_PLANCACHE, CODEGEN_LITERALS,
STATS_MAX_WRAP_LEN, LINEAGECACHESPILL,
COMPILERASSISTED_RW, BUFFERPOOL_LIMIT, MEMORY_MANAGER,
PRINT_GPU_MEMORY_INFO, AVAILABLE_GPUS, SYNCHRONIZE_GPU,
EAGER_CUDA_FREE, FLOATING_POINT_PRECISION,
diff --git a/src/main/java/org/apache/sysds/hops/LiteralOp.java
b/src/main/java/org/apache/sysds/hops/LiteralOp.java
index 75bc73db12..5d3f06bd66 100644
--- a/src/main/java/org/apache/sysds/hops/LiteralOp.java
+++ b/src/main/java/org/apache/sysds/hops/LiteralOp.java
@@ -246,6 +246,7 @@ public class LiteralOp extends Hop
switch( getValueType() ) {
case BOOLEAN:
return String.valueOf(value_boolean);
+ case UINT4:
case UINT8:
case INT32:
case INT64:
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index 09896c1ccf..80e9f75b80 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -615,6 +615,7 @@ public class BuiltinFunctionExpression extends
DataIdentifier
case INT64:
case INT32:
case UINT8:
+ case UINT4:
case BOOLEAN:
output.setValueType(ValueType.INT64);
break;
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index 0a0c4b8116..7f47f6de96 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -399,10 +399,11 @@ public class CompressedMatrixBlock extends MatrixBlock {
@Override
public void write(DataOutput out) throws IOException {
+ // LOG.error(this);
if(nonZeros > 0 && getExactSizeOnDisk() >
MatrixBlock.estimateSizeOnDisk(rlen, clen, nonZeros)) {
// If the size of this matrixBlock is smaller in
uncompressed format, then
// decompress and save inside an uncompressed column
group.
- MatrixBlock uncompressed = getUncompressed("for smaller
serialization");
+ MatrixBlock uncompressed = getUncompressed("smaller
serialization size");
ColGroupUncompressed cg = (ColGroupUncompressed)
ColGroupUncompressed.create(uncompressed);
allocateColGroup(cg);
nonZeros = cg.getNumberNonZeros(rlen);
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
index 9eee973a70..9154f1af25 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
@@ -646,7 +646,10 @@ public class ColGroupFactory {
// count distinct items frequencies
for(int j = apos; j < alen; j++)
- map.increment(vals[j]);
+ if(!Double.isNaN(vals[j]))
+ map.increment(vals[j]);
+ else
+ map.increment(0);
DCounts[] entries = map.extractValues();
Arrays.sort(entries, Comparator.comparing(x -> -x.count));
@@ -668,7 +671,10 @@ public class ColGroupFactory {
else {
final AMapToData mapToData =
MapToFactory.create((alen - apos), entries.length);
for(int j = apos; j < alen; j++)
+ if(!Double.isNaN(vals[j]))
mapToData.set(j - apos,
map.get(vals[j]));
+ else
+ mapToData.set(j - apos, map.get(0.0));
return ColGroupSDCZeros.create(cols, nRow,
Dictionary.create(dict), offsets, mapToData, counts);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
index dd9557dc41..0651e78b3b 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
@@ -44,7 +44,7 @@ public abstract class ADictionary implements Serializable {
protected static final Log LOG =
LogFactory.getLog(ADictionary.class.getName());
public static enum DictType {
- Delta, Dict, MatrixBlock, UInt8;
+ Delta, Dict, MatrixBlock, UInt8, Identity;
}
/**
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
index a777201d00..7437df19ad 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
@@ -41,7 +41,7 @@ public interface DictionaryFactory {
static final Log LOG =
LogFactory.getLog(DictionaryFactory.class.getName());
public enum Type {
- FP64_DICT, MATRIX_BLOCK_DICT, INT8_DICT
+ FP64_DICT, MATRIX_BLOCK_DICT, INT8_DICT, IDENTITY
}
public static ADictionary read(DataInput in) throws IOException {
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
new file mode 100644
index 0000000000..23d5c36019
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
@@ -0,0 +1,584 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.colgroup.dictionary;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.lang.ref.SoftReference;
+import java.util.Arrays;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
+import org.apache.sysds.runtime.functionobjects.ValueFunction;
+import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
+import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
+
+public class IdentityDictionary extends ADictionary {
+
+ private static final long serialVersionUID = 2535887782150955098L;
+
+ // final private MatrixBlock _data;
+ private final int nRowCol;
+
+ private SoftReference<MatrixBlockDictionary> cache = null;
+
+ /**
+ * Create a Identity matrix dictionary. It behaves as if allocated a
Sparse Matrix block but exploits that the
+ * structure is known to have certain properties.
+ *
+ * @param nRowCol the number of rows and columns in this identity
matrix.
+ */
+ public IdentityDictionary(int nRowCol) {
+ if(nRowCol <= 0)
+ throw new DMLCompressionException("Invalid Identity
Dictionary");
+ this.nRowCol = nRowCol;
+ }
+
+ @Override
+ public double[] getValues() {
+ LOG.warn("Should not call getValues on Identity Dictionary");
+
+ double[] ret = new double[nRowCol * nRowCol];
+ for(int i = 0; i < nRowCol; i++) {
+ ret[(i * nRowCol) + i] = 1;
+ }
+ return ret;
+ }
+
+ @Override
+ public double getValue(int i) {
+ final int nCol = nRowCol;
+ final int row = i / nCol;
+ if(row > nRowCol)
+ return 0;
+ final int col = i % nCol;
+ return row == col ? 1 : 0;
+ }
+
+ @Override
+ public final double getValue(int r, int c, int nCol) {
+ return r == c ? 1 : 0;
+ }
+
+ @Override
+ public long getInMemorySize() {
+ return 4 + 4 + 8; // int + padding + softReference
+ }
+
+ public static long getInMemorySize(int numberColumns) {
+ return 4 + 4 + 8;
+ }
+
+ @Override
+ public double aggregate(double init, Builtin fn) {
+ if(fn.getBuiltinCode() == BuiltinCode.MAX)
+ return fn.execute(init, 1);
+ else if(fn.getBuiltinCode() == BuiltinCode.MIN)
+ return fn.execute(init, 0);
+ else
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public double aggregateWithReference(double init, Builtin fn, double[]
reference, boolean def) {
+ return getMBDict().aggregateWithReference(init, fn, reference,
def);
+ }
+
+ @Override
+ public double[] aggregateRows(Builtin fn, int nCol) {
+ double[] ret = new double[nRowCol];
+ Arrays.fill(ret, fn.execute(1, 0));
+ return ret;
+ }
+
+ @Override
+ public double[] aggregateRowsWithDefault(Builtin fn, double[]
defaultTuple) {
+ return getMBDict().aggregateRowsWithDefault(fn, defaultTuple);
+ }
+
+ @Override
+ public double[] aggregateRowsWithReference(Builtin fn, double[]
reference) {
+ return getMBDict().aggregateRowsWithReference(fn, reference);
+ }
+
+ @Override
+ public void aggregateCols(double[] c, Builtin fn, int[] colIndexes) {
+ for(int i = 0; i < nRowCol; i++) {
+ final int idx = colIndexes[i];
+ c[idx] = fn.execute(c[idx], 0);
+ c[idx] = fn.execute(c[idx], 1);
+ }
+ }
+
+ @Override
+ public void aggregateColsWithReference(double[] c, Builtin fn, int[]
colIndexes, double[] reference, boolean def) {
+ getMBDict().aggregateColsWithReference(c, fn, colIndexes,
reference, def);
+ }
+
+ @Override
+ public ADictionary applyScalarOp(ScalarOperator op) {
+ return getMBDict().applyScalarOp(op);
+ }
+
+ @Override
+ public ADictionary applyScalarOpAndAppend(ScalarOperator op, double v0,
int nCol) {
+
+ return getMBDict().applyScalarOpAndAppend(op, v0, nCol);
+ }
+
+ @Override
+ public ADictionary applyUnaryOp(UnaryOperator op) {
+ return getMBDict().applyUnaryOp(op);
+ }
+
+ @Override
+ public ADictionary applyUnaryOpAndAppend(UnaryOperator op, double v0,
int nCol) {
+ return getMBDict().applyUnaryOpAndAppend(op, v0, nCol);
+ }
+
+ @Override
+ public ADictionary applyScalarOpWithReference(ScalarOperator op,
double[] reference, double[] newReference) {
+ return getMBDict().applyScalarOpWithReference(op, reference,
newReference);
+ }
+
+ @Override
+ public ADictionary applyUnaryOpWithReference(UnaryOperator op, double[]
reference, double[] newReference) {
+ return getMBDict().applyUnaryOpWithReference(op, reference,
newReference);
+ }
+
+ @Override
+ public ADictionary binOpLeft(BinaryOperator op, double[] v, int[]
colIndexes) {
+ return getMBDict().binOpLeft(op, v, colIndexes);
+ }
+
+ @Override
+ public ADictionary binOpLeftAndAppend(BinaryOperator op, double[] v,
int[] colIndexes) {
+ return getMBDict().binOpLeftAndAppend(op, v, colIndexes);
+ }
+
+ @Override
+ public ADictionary binOpLeftWithReference(BinaryOperator op, double[]
v, int[] colIndexes, double[] reference,
+ double[] newReference) {
+ return getMBDict().binOpLeftWithReference(op, v, colIndexes,
reference, newReference);
+
+ }
+
+ @Override
+ public ADictionary binOpRight(BinaryOperator op, double[] v, int[]
colIndexes) {
+ return getMBDict().binOpRight(op, v, colIndexes);
+ }
+
+ @Override
+ public ADictionary binOpRightAndAppend(BinaryOperator op, double[] v,
int[] colIndexes) {
+ return getMBDict().binOpRightAndAppend(op, v, colIndexes);
+ }
+
+ @Override
+ public ADictionary binOpRight(BinaryOperator op, double[] v) {
+ return getMBDict().binOpRight(op, v);
+ }
+
+ @Override
+ public ADictionary binOpRightWithReference(BinaryOperator op, double[]
v, int[] colIndexes, double[] reference,
+ double[] newReference) {
+ return getMBDict().binOpRightWithReference(op, v, colIndexes,
reference, newReference);
+ }
+
+ @Override
+ public ADictionary clone() {
+ return new IdentityDictionary(nRowCol);
+ }
+
+ @Override
+ public DictType getDictType() {
+ return DictType.Identity;
+ }
+
+ @Override
+ public int getNumberOfValues(int ncol) {
+ return nRowCol;
+ }
+
+ @Override
+ public double[] sumAllRowsToDouble(int nrColumns) {
+ double[] ret = new double[nRowCol];
+ Arrays.fill(ret, 1);
+ return ret;
+ }
+
+ @Override
+ public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) {
+ double[] ret = new double[nRowCol];
+ Arrays.fill(ret, 1);
+ for(int i = 0; i < defaultTuple.length; i++)
+ ret[i] += defaultTuple[i];
+ return ret;
+ }
+
+ @Override
+ public double[] sumAllRowsToDoubleWithReference(double[] reference) {
+ double[] ret = new double[nRowCol];
+ Arrays.fill(ret, 1);
+ for(int i = 0; i < reference.length; i++)
+ ret[i] += reference[i] * nRowCol;
+ return ret;
+ }
+
+ @Override
+ public double[] sumAllRowsToDoubleSq(int nrColumns) {
+ double[] ret = new double[nRowCol];
+ Arrays.fill(ret, 1);
+ return ret;
+ }
+
+ @Override
+ public double[] sumAllRowsToDoubleSqWithDefault(double[] defaultTuple) {
+ return
getMBDict().sumAllRowsToDoubleSqWithDefault(defaultTuple);
+ }
+
+ @Override
+ public double[] sumAllRowsToDoubleSqWithReference(double[] reference) {
+ return getMBDict().sumAllRowsToDoubleSqWithReference(reference);
+ }
+
+ @Override
+ public double[] productAllRowsToDouble(int nCol) {
+ return new double[nRowCol];
+ }
+
+ @Override
+ public double[] productAllRowsToDoubleWithDefault(double[]
defaultTuple) {
+ return new double[nRowCol];
+ }
+
+ @Override
+ public double[] productAllRowsToDoubleWithReference(double[] reference)
{
+ return
getMBDict().productAllRowsToDoubleWithReference(reference);
+ }
+
+ @Override
+ public void colSum(double[] c, int[] counts, int[] colIndexes) {
+ for(int i = 0; i < colIndexes.length; i++) {
+ // very nice...
+ final int idx = colIndexes[i];
+ c[idx] = counts[i];
+ }
+ }
+
+ @Override
+ public void colSumSq(double[] c, int[] counts, int[] colIndexes) {
+ colSum(c, counts, colIndexes);
+ }
+
+ @Override
+ public void colProduct(double[] res, int[] counts, int[] colIndexes) {
+ for(int i = 0; i < colIndexes.length; i++) {
+ res[colIndexes[i]] = 0;
+ }
+ }
+
+ @Override
+ public void colProductWithReference(double[] res, int[] counts, int[]
colIndexes, double[] reference) {
+ getMBDict().colProductWithReference(res, counts, colIndexes,
reference);
+
+ }
+
+ @Override
+ public void colSumSqWithReference(double[] c, int[] counts, int[]
colIndexes, double[] reference) {
+ getMBDict().colSumSqWithReference(c, counts, colIndexes,
reference);
+ }
+
+ @Override
+ public double sum(int[] counts, int ncol) {
+ double s = 0.0;
+ for(int v : counts)
+ s += v;
+ return s;
+ }
+
+ @Override
+ public double sumSq(int[] counts, int ncol) {
+ return sum(counts, ncol);
+ }
+
+ @Override
+ public double sumSqWithReference(int[] counts, double[] reference) {
+ return getMBDict().sumSqWithReference(counts, reference);
+ }
+
+ @Override
+ public ADictionary sliceOutColumnRange(int idxStart, int idxEnd, int
previousNumberOfColumns) {
+ return getMBDict().sliceOutColumnRange(idxStart, idxEnd,
previousNumberOfColumns);
+ }
+
+ @Override
+ public boolean containsValue(double pattern) {
+ return pattern == 0.0 || pattern == 1.0;
+ }
+
+ @Override
+ public boolean containsValueWithReference(double pattern, double[]
reference) {
+ return getMBDict().containsValueWithReference(pattern,
reference);
+ }
+
+ @Override
+ public long getNumberNonZeros(int[] counts, int nCol) {
+ return (long) sum(counts, nCol);
+ }
+
+ @Override
+ public long getNumberNonZerosWithReference(int[] counts, double[]
reference, int nRows) {
+ return getMBDict().getNumberNonZerosWithReference(counts,
reference, nRows);
+ }
+
+ @Override
+ public void addToEntry(final double[] v, final int fr, final int to,
final int nCol) {
+ getMBDict().addToEntry(v, fr, to, nCol);
+ }
+
+ @Override
+ public void addToEntry(final double[] v, final int fr, final int to,
final int nCol, int rep) {
+ getMBDict().addToEntry(v, fr, to, nCol, rep);
+ }
+
+ @Override
+ public void addToEntryVectorized(double[] v, int f1, int f2, int f3,
int f4, int f5, int f6, int f7, int f8, int t1,
+ int t2, int t3, int t4, int t5, int t6, int t7, int t8, int
nCol) {
+ getMBDict().addToEntryVectorized(v, f1, f2, f3, f4, f5, f6, f7,
f8, t1, t2, t3, t4, t5, t6, t7, t8, nCol);
+ }
+
+ @Override
+ public ADictionary subtractTuple(double[] tuple) {
+ return getMBDict().subtractTuple(tuple);
+ }
+
+ public MatrixBlockDictionary getMBDict() {
+ return getMBDict(nRowCol);
+ }
+
+ @Override
+ public MatrixBlockDictionary getMBDict(int nCol) {
+ if(cache != null) {
+ MatrixBlockDictionary r = cache.get();
+ if(r != null)
+ return r;
+ }
+ MatrixBlockDictionary ret = createMBDict();
+ cache = new SoftReference<>(ret);
+ return ret;
+ }
+
+ private MatrixBlockDictionary createMBDict() {
+ MatrixBlock identity = new MatrixBlock(nRowCol, nRowCol, true);
+ for(int i = 0; i < nRowCol; i++)
+ identity.quickSetValue(i, i, 1.0);
+
+ return new MatrixBlockDictionary(identity);
+ }
+
+ @Override
+ public String getString(int colIndexes) {
+ return "IdentityMatrix of size: " + nRowCol;
+ }
+
+ @Override
+ public String toString() {
+ return "IdentityMatrix of size: " + nRowCol;
+ }
+
+ @Override
+ public ADictionary scaleTuples(int[] scaling, int nCol) {
+ return getMBDict().scaleTuples(scaling, nCol);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeByte(DictionaryFactory.Type.IDENTITY.ordinal());
+ out.writeInt(nRowCol);
+ }
+
+ public static IdentityDictionary read(DataInput in) throws IOException {
+ return new IdentityDictionary(in.readInt());
+ }
+
+ @Override
+ public long getExactSizeOnDisk() {
+ return 1 + 4;
+ }
+
+ @Override
+ public ADictionary preaggValuesFromDense(final int numVals, final int[]
colIndexes, final int[] aggregateColumns,
+ final double[] b, final int cut) {
+ return getMBDict().preaggValuesFromDense(numVals, colIndexes,
aggregateColumns, b, cut);
+ }
+
+ @Override
+ public ADictionary replace(double pattern, double replace, int nCol) {
+ if(containsValue(pattern))
+ return getMBDict().replace(pattern, replace, nCol);
+ else
+ return this;
+ }
+
+ @Override
+ public ADictionary replaceWithReference(double pattern, double replace,
double[] reference) {
+ if(containsValueWithReference(pattern, reference))
+ return getMBDict().replaceWithReference(pattern,
replace, reference);
+ else
+ return this;
+ }
+
+ @Override
+ public void product(double[] ret, int[] counts, int nCol) {
+ getMBDict().product(ret, counts, nCol);
+ }
+
+ @Override
+ public void productWithDefault(double[] ret, int[] counts, double[]
def, int defCount) {
+ getMBDict().productWithDefault(ret, counts, def, defCount);
+ }
+
+ @Override
+ public void productWithReference(double[] ret, int[] counts, double[]
reference, int refCount) {
+ getMBDict().productWithReference(ret, counts, reference,
refCount);
+ }
+
+ @Override
+ public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn,
int[] counts, int nRows) {
+ return getMBDict().centralMoment(ret, fn, counts, nRows);
+ }
+
+ @Override
+ public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret,
ValueFunction fn, int[] counts, double def,
+ int nRows) {
+ return getMBDict().centralMomentWithDefault(ret, fn, counts,
def, nRows);
+ }
+
+ @Override
+ public CM_COV_Object centralMomentWithReference(CM_COV_Object ret,
ValueFunction fn, int[] counts, double reference,
+ int nRows) {
+ return getMBDict().centralMomentWithReference(ret, fn, counts,
reference, nRows);
+ }
+
+ @Override
+ public ADictionary rexpandCols(int max, boolean ignore, boolean cast,
int nCol) {
+ return getMBDict().rexpandCols(max, ignore, cast, nCol);
+ }
+
+ @Override
+ public ADictionary rexpandColsWithReference(int max, boolean ignore,
boolean cast, int reference) {
+ return getMBDict().rexpandColsWithReference(max, ignore, cast,
reference);
+ }
+
+ @Override
+ public double getSparsity() {
+ // non-zeros / n cells
+ // nRowCol / (nRowCol * nRowCol)
+ // simplifies to
+ return 1.0d / (double) nRowCol;
+ }
+
+ @Override
+ public void multiplyScalar(double v, double[] ret, int off, int
dictIdx, int[] cols) {
+ getMBDict().multiplyScalar(v, ret, off, dictIdx, cols);
+ }
+
+ @Override
+ protected void TSMMWithScaling(int[] counts, int[] rows, int[] cols,
MatrixBlock ret) {
+ getMBDict().TSMMWithScaling(counts, rows, cols, ret);
+ }
+
+ @Override
+ protected void MMDict(ADictionary right, int[] rowsLeft, int[]
colsRight, MatrixBlock result) {
+ getMBDict().MMDict(right, rowsLeft, colsRight, result);
+ // should replace with add to right to output cells.
+ }
+
+ @Override
+ protected void MMDictDense(double[] left, int[] rowsLeft, int[]
colsRight, MatrixBlock result) {
+ getMBDict().MMDictDense(left, rowsLeft, colsRight, result);
+ // should replace with add to right to output cells.
+ }
+
+ @Override
+ protected void MMDictSparse(SparseBlock left, int[] rowsLeft, int[]
colsRight, MatrixBlock result) {
+ getMBDict().MMDictSparse(left, rowsLeft, colsRight, result);
+ }
+
+ @Override
+ protected void TSMMToUpperTriangle(ADictionary right, int[] rowsLeft,
int[] colsRight, MatrixBlock result) {
+ getMBDict().TSMMToUpperTriangle(right, rowsLeft, colsRight,
result);
+ }
+
+ @Override
+ protected void TSMMToUpperTriangleDense(double[] left, int[] rowsLeft,
int[] colsRight, MatrixBlock result) {
+ getMBDict().TSMMToUpperTriangleDense(left, rowsLeft, colsRight,
result);
+ }
+
+ @Override
+ protected void TSMMToUpperTriangleSparse(SparseBlock left, int[]
rowsLeft, int[] colsRight, MatrixBlock result) {
+ getMBDict().TSMMToUpperTriangleSparse(left, rowsLeft,
colsRight, result);
+ }
+
+ @Override
+ protected void TSMMToUpperTriangleScaling(ADictionary right, int[]
rowsLeft, int[] colsRight, int[] scale,
+ MatrixBlock result) {
+ getMBDict().TSMMToUpperTriangleScaling(right, rowsLeft,
colsRight, scale, result);
+ }
+
+ @Override
+ protected void TSMMToUpperTriangleDenseScaling(double[] left, int[]
rowsLeft, int[] colsRight, int[] scale,
+ MatrixBlock result) {
+ getMBDict().TSMMToUpperTriangleDenseScaling(left, rowsLeft,
colsRight, scale, result);
+ }
+
+ @Override
+ protected void TSMMToUpperTriangleSparseScaling(SparseBlock left, int[]
rowsLeft, int[] colsRight, int[] scale,
+ MatrixBlock result) {
+
+ getMBDict().TSMMToUpperTriangleSparseScaling(left, rowsLeft,
colsRight, scale, result);
+ }
+
+ @Override
+ public boolean equals(ADictionary o) {
+ if(o instanceof IdentityDictionary)
+ return ((IdentityDictionary) o).nRowCol == nRowCol;
+
+ MatrixBlock mb = getMBDict().getMatrixBlock();
+ if(o instanceof MatrixBlockDictionary)
+ return mb.equals(((MatrixBlockDictionary)
o).getMatrixBlock());
+ else if(o instanceof Dictionary) {
+ if(mb.isInSparseFormat())
+ return mb.getSparseBlock().equals(((Dictionary)
o)._values, nRowCol);
+ final double[] dv = mb.getDenseBlockValues();
+ return Arrays.equals(dv, ((Dictionary) o)._values);
+ }
+
+ return false;
+ }
+
+}
diff --git a/src/main/java/org/apache/sysds/runtime/data/LibTensorAgg.java
b/src/main/java/org/apache/sysds/runtime/data/LibTensorAgg.java
index 7136bdb8c1..6ea603edac 100644
--- a/src/main/java/org/apache/sysds/runtime/data/LibTensorAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/data/LibTensorAgg.java
@@ -246,6 +246,7 @@ public class LibTensorAgg {
}
case INT64:
case INT32:
+ case UINT4:
case UINT8: {
DenseBlock a = in.getDenseBlock();
long sum = 0;
diff --git a/src/main/java/org/apache/sysds/runtime/data/TensorBlock.java
b/src/main/java/org/apache/sysds/runtime/data/TensorBlock.java
index c9f8b6c0ef..5047ee2524 100644
--- a/src/main/java/org/apache/sysds/runtime/data/TensorBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/TensorBlock.java
@@ -648,6 +648,8 @@ public class TensorBlock implements
CacheBlock<TensorBlock>, Externalizable {
long size = 8 + 1;
if (!bt.isSparse()) {
switch (bt._vt) {
+ case UINT4:
+ size += getLength() / 2 + getLength() %
2;
case UINT8:
size += 1 * getLength(); break;
case INT32:
diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
index 3e04d54551..69c3a9985f 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
@@ -71,7 +71,6 @@ import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
-import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.DMVUtils;
import org.apache.sysds.runtime.util.EMAUtils;
@@ -88,9 +87,6 @@ public class FrameBlock implements CacheBlock<FrameBlock>,
Externalizable {
/** Buffer size variable: 1M elements, size of default matrix block */
public static final int BUFFER_SIZE = 1 * 1000 * 1000;
- /** internal configuration */
- private static final boolean REUSE_RECODE_MAPS = true;
-
/** The schema of the data frame as an ordered list of value types */
private ValueType[] _schema = null;
@@ -169,18 +165,17 @@ public class FrameBlock implements
CacheBlock<FrameBlock>, Externalizable {
}
/**
- * allocate a FrameBlock with the given data arrays.
+ * allocate a FrameBlock with the given data arrays.
*
- * The data is in row major, making the first dimension number of rows.
- * second number of columns.
+ * The data is in row major, making the first dimension number of rows.
second number of columns.
*
* @param schema the schema to allocate
- * @param names The names of the column
- * @param data The data.
+ * @param names The names of the column
+ * @param data The data.
*/
public FrameBlock(ValueType[] schema, String[] names, String[][] data) {
_schema = schema;
- if(names != null){
+ if(names != null) {
_colnames = names;
if(schema.length != names.length)
throw new DMLRuntimeException("Invalid
FrameBlock construction, invalid schema and names combination");
@@ -821,9 +816,11 @@ public class FrameBlock implements CacheBlock<FrameBlock>,
Externalizable {
.map(x ->
x.getInMemorySize()).reduce(0L, Long::sum);
}).get();
pool.shutdown();
+
}
catch(InterruptedException | ExecutionException
e) {
pool.shutdown();
+ LOG.error(e);
for(Array<?> aa : _coldata)
size += aa.getInMemorySize();
}
@@ -831,6 +828,7 @@ public class FrameBlock implements CacheBlock<FrameBlock>,
Externalizable {
else {
for(Array<?> aa : _coldata)
size += aa.getInMemorySize();
+
}
}
return size;
@@ -1187,34 +1185,8 @@ public class FrameBlock implements
CacheBlock<FrameBlock>, Externalizable {
* @param col is the column # from frame data which contains Recode map
generated earlier.
* @return map of token and code for every element in the input column
of a frame containing Recode map
*/
- public HashMap<String, Long> getRecodeMap(int col) {
- // probe cache for existing map
- if(REUSE_RECODE_MAPS) {
- SoftReference<HashMap<String, Long>> tmp =
_coldata[col].getCache();
- HashMap<String, Long> map = (tmp != null) ? tmp.get() :
null;
- if(map != null)
- return map;
- }
-
- // construct recode map
- HashMap<String, Long> map = new HashMap<>();
- Array<?> ldata = _coldata[col];
- int nRow = _coldata[0].size();
- if(nRow != _nRow)
- throw new DMLRuntimeException("Invalid intermediate
size:" + nRow + " " + _nRow);
- for(int i = 0; i < getNumRows(); i++) {
- Object val = ldata.get(i);
- if(val != null) {
- String[] tmp =
ColumnEncoderRecode.splitRecodeMapEntry(val.toString());
- map.put(tmp[0], Long.parseLong(tmp[1]));
- }
- }
-
- // put created map into cache
- if(REUSE_RECODE_MAPS)
- _coldata[col].setCache(new SoftReference<>(map));
-
- return map;
+ public HashMap<Object, Long> getRecodeMap(int col) {
+ return _coldata[col].getRecodeMap();
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
index 2be9e10170..e706672e17 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.frame.data.columns;
import java.lang.ref.SoftReference;
import java.util.HashMap;
+import java.util.Iterator;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
@@ -37,9 +38,11 @@ import org.apache.sysds.runtime.matrix.data.Pair;
*/
public abstract class Array<T> implements Writable {
protected static final Log LOG =
LogFactory.getLog(Array.class.getName());
+ /** internal configuration */
+ private static final boolean REUSE_RECODE_MAPS = true;
/** A soft reference to a memorization of this arrays mapping, used in
transformEncode */
- protected SoftReference<HashMap<String, Long>> _rcdMapCache = null;
+ protected SoftReference<HashMap<T, Long>> _rcdMapCache = null;
/** The current allocated number of elements in this Array */
protected int _size;
@@ -59,7 +62,7 @@ public abstract class Array<T> implements Writable {
*
* @return The cached object
*/
- public final SoftReference<HashMap<String, Long>> getCache() {
+ public final SoftReference<HashMap<T, Long>> getCache() {
return _rcdMapCache;
}
@@ -68,10 +71,43 @@ public abstract class Array<T> implements Writable {
*
* @param m The element to cache.
*/
- public final void setCache(SoftReference<HashMap<String, Long>> m) {
+ public final void setCache(SoftReference<HashMap<T, Long>> m) {
_rcdMapCache = m;
}
+ /**
+  * Get the recode map of this array, mapping the values of the array to long codes.
+  *
+  * If REUSE_RECODE_MAPS is enabled, a previously created map that is still reachable through the soft reference
+  * cache is returned; otherwise a new map is created via createRecodeMap and, if reuse is enabled, cached.
+  *
+  * @return The recode map of this array
+  */
+ public HashMap<T, Long> getRecodeMap() {
+     // without reuse there is nothing to probe and nothing to remember
+     if(!REUSE_RECODE_MAPS)
+         return createRecodeMap();
+
+     // probe the soft reference, both the reference and its referent may be gone
+     final SoftReference<HashMap<T, Long>> ref = getCache();
+     final HashMap<T, Long> cached = (ref == null) ? null : ref.get();
+     if(cached != null)
+         return cached;
+
+     // cache miss: build the map and remember it for subsequent calls
+     final HashMap<T, Long> fresh = createRecodeMap();
+     setCache(new SoftReference<>(fresh));
+     return fresh;
+ }
+
+
+ /**
+  * Create a new recode map for this array, assigning consecutive codes starting at 0 to the distinct non-null
+  * values in order of their first occurrence.
+  *
+  * @return A newly allocated recode map
+  */
+ protected HashMap<T, Long> createRecodeMap(){
+     final HashMap<T, Long> codes = new HashMap<>();
+     long nextCode = 0;
+     for(int row = 0; row < size(); row++) {
+         final T value = get(row);
+         if(value == null)
+             continue;
+         // stored codes are never null, so a null return means the value was not seen before
+         if(codes.putIfAbsent(value, nextCode) == null)
+             nextCode++;
+     }
+     return codes;
+ }
+
+
+
/**
* Get the number of elements in the array, this does not necessarily
reflect the current allocated size.
*
@@ -306,6 +342,15 @@ public abstract class Array<T> implements Writable {
return null;
}
+ /**
+ * Analyze if the array contains null values.
+ *
+ * This base implementation always returns false; array types that can hold null values are expected to override
+ * it.
+ *
+ * @return If the array contains null.
+ */
+ public boolean containsNull(){
+ return false;
+ }
+
public Array<?> changeTypeWithNulls(ValueType t) {
final ABooleanArray nulls = getNulls();
if(nulls == null)
@@ -321,6 +366,7 @@ public abstract class Array<T> implements Writable {
return new
OptionalArray<Float>(changeTypeFloat(), nulls);
case FP64:
return new
OptionalArray<Double>(changeTypeDouble(), nulls);
+ case UINT4:
case UINT8:
throw new NotImplementedException();
case INT32:
@@ -354,6 +400,7 @@ public abstract class Array<T> implements Writable {
return changeTypeFloat();
case FP64:
return changeTypeDouble();
+ case UINT4:
case UINT8:
throw new NotImplementedException();
case INT32:
@@ -520,4 +567,26 @@ public abstract class Array<T> implements Writable {
return this.getClass().getSimpleName();
}
+
+ /**
+ * Create a new iterator over the elements of this array.
+ *
+ * @return A new ArrayIterator bound to this array
+ */
+ public ArrayIterator getIterator(){
+ return new ArrayIterator();
+ }
+
+ /**
+  * Iterator over the elements of the enclosing array that additionally exposes the index of the element most
+  * recently returned.
+  */
+ public class ArrayIterator implements Iterator<T> {
+     /** Index of the element last returned by next(), -1 before the first call. */
+     int index = -1;
+
+     /**
+      * Get the index of the element most recently returned by next().
+      *
+      * @return The current index, or -1 if next() was not called yet
+      */
+     public int getIndex(){
+         return index;
+     }
+
+     @Override
+     public boolean hasNext() {
+         return index < size()-1;
+     }
+
+     /**
+      * Advance to and return the next element.
+      *
+      * @return The next element of the array
+      * @throws java.util.NoSuchElementException if the iterator is exhausted
+      */
+     @Override
+     public T next() {
+         // honor the Iterator contract: fail explicitly, and without advancing, once exhausted
+         if(!hasNext())
+             throw new java.util.NoSuchElementException(
+                 "No element at index " + (index + 1) + " in array of size " + size());
+         return get(++index);
+     }
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
index 88c6ff2040..8af5623708 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
@@ -82,6 +82,7 @@ public interface ArrayFactory {
return Array.baseMemoryCost() + (long)
MemoryEstimates.longArrayCost(_numRows);
case FP64:
return Array.baseMemoryCost() + (long)
MemoryEstimates.doubleArrayCost(_numRows);
+ case UINT4:
case UINT8:
case INT32:
return Array.baseMemoryCost() + (long)
MemoryEstimates.intArrayCost(_numRows);
@@ -111,6 +112,7 @@ public interface ArrayFactory {
return new OptionalArray<>(new
BitSetArray(nRow), true);
else
return new OptionalArray<>(new
BooleanArray(new boolean[nRow]), true);
+ case UINT4:
case UINT8:
case INT32:
return new OptionalArray<>(new IntegerArray(new
int[nRow]), true);
@@ -140,6 +142,7 @@ public interface ArrayFactory {
switch(v) {
case BOOLEAN:
return allocateBoolean(nRow);
+ case UINT4:
case UINT8:
case INT32:
return new IntegerArray(new int[nRow]);
@@ -261,8 +264,9 @@ public interface ArrayFactory {
return FloatArray.parseFloat(s);
case FP64:
return DoubleArray.parseDouble(s);
- case INT32:
+ case UINT4:
case UINT8:
+ case INT32:
return IntegerArray.parseInt(s);
case INT64:
return LongArray.parseLong(s);
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
index 32a9c867ed..fe85c2530a 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
@@ -425,6 +425,11 @@ public class OptionalArray<T> extends Array<T> {
}
+ /**
+ * Check if this array contains null values, which is the case unless all entries of _n are true.
+ * NOTE(review): presumably _n flags the positions that hold a value -- confirm.
+ *
+ * @return If the array contains null.
+ */
+ @Override
+ public boolean containsNull(){
+ return !_n.isAllTrue();
+ }
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder(_size + 2);
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
index 250b876152..862014b39e 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
@@ -24,6 +24,7 @@ import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.BitSet;
+import java.util.HashMap;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.common.Types.ValueType;
@@ -32,6 +33,7 @@ import
org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType;
import org.apache.sysds.runtime.frame.data.lib.FrameUtil;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.matrix.data.Pair;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
import org.apache.sysds.utils.MemoryEstimates;
public class StringArray extends Array<String> {
@@ -580,6 +582,14 @@ public class StringArray extends Array<String> {
return false;
return true;
}
+
+ /**
+  * Check if any entry of the backing string array is null.
+  *
+  * @return If the array contains null.
+  */
+ @Override
+ public boolean containsNull(){
+     // scan the whole backing array and stop at the first null entry
+     for(String entry : _data)
+         if(entry == null)
+             return true;
+     return false;
+ }
@Override
public Array<String> select(int[] indices) {
@@ -604,6 +614,28 @@ public class StringArray extends Array<String> {
return _data[i] != null && !_data[i].equals("0");
}
+ /**
+  * Create the recode map for this string array.
+  *
+  * First every entry is interpreted as a serialized recode map entry, split via
+  * ColumnEncoderRecode.splitRecodeMapEntry into a token and a long code, stopping at the first null entry. If that
+  * interpretation fails for any reason, the generic recode map of the super class is created instead.
+  *
+  * @return The recode map of this array
+  */
+ @Override
+ protected HashMap<String, Long> createRecodeMap(){
+     try{
+         final HashMap<String, Long> parsed = new HashMap<>();
+         for(int row = 0; row < size(); row++) {
+             final String entry = get(row);
+             if(entry == null) // once we hit null return.
+                 break;
+             final String[] tokenAndCode = ColumnEncoderRecode.splitRecodeMapEntry(entry);
+             parsed.put(tokenAndCode[0], Long.parseLong(tokenAndCode[1]));
+         }
+         return parsed;
+     }
+     catch(Exception e){
+         // best effort: the entries are not serialized recode entries, fall back to the generic map
+         return super.createRecodeMap();
+     }
+ }
+
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder(_size * 5 + 2);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
index 6ec120c906..4a10ae0421 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
@@ -152,7 +152,7 @@ public class MultiReturnParameterizedBuiltinSPInstruction
extends ComputationSPI
MultiColumnEncoder encoder = EncoderFactory
.createEncoder(spec, colnames, fo.getSchema(),
(int) fo.getNumColumns(), meta);
mcOut.setDimension(mcIn.getRows() - ((omap != null) ?
omap.getNumRmRows() : 0),
- (int) fo.getNumColumns() +
encoder.getNumExtraCols());
+ (int) encoder.getNumOutCols());
Broadcast<MultiColumnEncoder> bmeta =
sec.getSparkContext().broadcast(encoder);
Broadcast<TfOffsetMap> bomap = (omap != null) ?
sec.getSparkContext().broadcast(omap) : null;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index f4c29ea9dd..e5b8fea07a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -502,7 +502,7 @@ public class ParameterizedBuiltinSPInstruction extends
ComputationSPInstruction
MultiColumnEncoder encoder = EncoderFactory
.createEncoder(params.get("spec"), colnames,
fo.getSchema(), (int) fo.getNumColumns(), meta);
mcOut.setDimension(mcIn.getRows() - ((omap != null) ?
omap.getNumRmRows() : 0),
- (int) fo.getNumColumns() +
encoder.getNumExtraCols());
+ (int)encoder.getNumOutCols());
Broadcast<MultiColumnEncoder> bmeta =
sec.getSparkContext().broadcast(encoder);
Broadcast<TfOffsetMap> bomap = (omap != null) ?
sec.getSparkContext().broadcast(omap) : null;
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index 5ffd10142a..610e0cc414 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -206,6 +206,11 @@ public abstract class ColumnEncoder implements Encoder,
Comparable<ColumnEncoder
// do nothing
}
+ /**
+ * Get the domain size of this encoder. This default implementation returns 1.
+ * NOTE(review): presumably the number of distinct output values/columns the encoder produces -- confirm.
+ *
+ * @return The domain size
+ */
+ public int getDomainSize(){
+ return 1;
+ }
+
+
/**
* Partial build of internal data structures (e.g., in distributed
spark operations).
*
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
index 3809af821f..b532dc04be 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
@@ -387,6 +387,15 @@ public class ColumnEncoderBin extends ColumnEncoder {
}
}
+ /** Render this encoder as its simple class name followed by its column id. */
+ @Override
+ public String toString() {
+     return getClass().getSimpleName() + ": " + _colID;
+ }
+
public enum BinMethod {
INVALID, EQUI_WIDTH, EQUI_HEIGHT
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index 1060aa2c1d..a033bfa30f 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -107,12 +107,16 @@ public class ColumnEncoderComposite extends ColumnEncoder
{
@Override
public void build(CacheBlock<?> in, Map<Integer, double[]>
equiHeightMaxs) {
- for(ColumnEncoder columnEncoder : _columnEncoders)
- if(columnEncoder instanceof ColumnEncoderBin &&
((ColumnEncoderBin) columnEncoder).getBinMethod() ==
ColumnEncoderBin.BinMethod.EQUI_HEIGHT) {
- columnEncoder.build(in,
equiHeightMaxs.get(columnEncoder.getColID()));
- } else {
- columnEncoder.build(in);
- }
+ if(equiHeightMaxs == null)
+ build(in);
+ else{
+ for(ColumnEncoder columnEncoder : _columnEncoders)
+ if(columnEncoder instanceof ColumnEncoderBin &&
((ColumnEncoderBin) columnEncoder).getBinMethod() ==
ColumnEncoderBin.BinMethod.EQUI_HEIGHT) {
+ columnEncoder.build(in,
equiHeightMaxs.get(columnEncoder.getColID()));
+ } else {
+ columnEncoder.build(in);
+ }
+ }
}
@Override
@@ -321,9 +325,7 @@ public class ColumnEncoderComposite extends ColumnEncoder {
sb.append("CompositeEncoder(").append(_columnEncoders.size()).append("):\n");
for(ColumnEncoder columnEncoder : _columnEncoders) {
sb.append("-- ");
- sb.append(columnEncoder.getClass().getSimpleName());
- sb.append(": ");
- sb.append(columnEncoder._colID);
+ sb.append(columnEncoder);
sb.append("\n");
}
return sb.toString();
@@ -410,6 +412,28 @@ public class ColumnEncoderComposite extends ColumnEncoder {
}).collect(Collectors.toSet());
}
+ /**
+  * Get the domain size of this composite, defined as the largest domain size among the contained encoders.
+  *
+  * @return The maximum domain size of the contained encoders, or 1 (the ColumnEncoder default) if this composite
+  *         contains no encoders
+  */
+ @Override
+ public int getDomainSize() {
+     // primitive stream avoids boxing, orElse avoids failing on an empty composite
+     return _columnEncoders.stream()//
+         .mapToInt(ColumnEncoder::getDomainSize).max().orElse(1);
+ }
+
+
+ /**
+  * Check if this composite consists of exactly a recode encoder followed by a dummycode encoder.
+  *
+  * @return If the composite is a recode to dummycode pipeline
+  */
+ public boolean isRecodeToDummy(){
+     if(_columnEncoders.size() != 2)
+         return false;
+     final boolean recodeFirst = _columnEncoders.get(0) instanceof ColumnEncoderRecode;
+     return recodeFirst && _columnEncoders.get(1) instanceof ColumnEncoderDummycode;
+ }
+
+ /**
+  * Check if this composite consists of exactly one encoder that is a recode encoder.
+  *
+  * @return If the composite is a pure recode
+  */
+ public boolean isRecode(){
+     if(_columnEncoders.size() != 1)
+         return false;
+     return _columnEncoders.get(0) instanceof ColumnEncoderRecode;
+ }
+
+ /**
+  * Check if this composite consists of exactly one encoder that is a pass-through encoder.
+  *
+  * @return If the composite is a pure pass-through
+  */
+ public boolean isPassThrough(){
+     if(_columnEncoders.size() != 1)
+         return false;
+     return _columnEncoders.get(0) instanceof ColumnEncoderPassThrough;
+ }
private static class ColumnCompositeUpdateDCTask implements
Callable<Object> {
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
index 970df3aaa3..f30743ff27 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
@@ -266,10 +266,22 @@ public class ColumnEncoderDummycode extends ColumnEncoder
{
return result;
}
+ @Override
public int getDomainSize() {
return _domainSize;
}
+ /** Render this encoder as its simple class name, its column id and its domain size. */
+ @Override
+ public String toString() {
+     return getClass().getSimpleName() + ": " + _colID + " --- DomainSize : " + _domainSize;
+ }
+
private static class DummycodeSparseApplyTask extends
ColumnApplyTask<ColumnEncoderDummycode> {
protected DummycodeSparseApplyTask(ColumnEncoderDummycode
encoder, MatrixBlock input,
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
index cfa69d1c4b..12e3f80b70 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
@@ -155,6 +155,15 @@ public class ColumnEncoderFeatureHash extends
ColumnEncoder {
_K = in.readLong();
}
+ /** Render this encoder as its simple class name followed by its column id. */
+ @Override
+ public String toString() {
+     return getClass().getSimpleName() + ": " + _colID;
+ }
+
public static class FeatureHashSparseApplyTask extends
ColumnApplyTask<ColumnEncoderFeatureHash>{
public FeatureHashSparseApplyTask(ColumnEncoderFeatureHash
encoder, CacheBlock<?> input,
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
index 63c27469ca..9d775a7a5f 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
@@ -145,6 +145,15 @@ public class ColumnEncoderPassThrough extends
ColumnEncoder {
// do nothing
}
+ /** Render this encoder as its simple class name followed by its column id. */
+ @Override
+ public String toString() {
+     return getClass().getSimpleName() + ": " + _colID;
+ }
+
public static class PassThroughSparseApplyTask extends
ColumnApplyTask<ColumnEncoderPassThrough>{
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
index 799ed37098..eb7e706e0c 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
@@ -47,18 +47,19 @@ public class ColumnEncoderRecode extends ColumnEncoder {
public static boolean SORT_RECODE_MAP = false;
// recode maps and custom map for partial recode maps
- private HashMap<String, Long> _rcdMap = new HashMap<>();
+ private HashMap<Object, Long> _rcdMap;
private HashSet<Object> _rcdMapPart = null;
public ColumnEncoderRecode(int colID) {
super(colID);
+ _rcdMap = new HashMap<>();
}
public ColumnEncoderRecode() {
this(-1);
}
- private ColumnEncoderRecode(int colID, HashMap<String, Long> rcdMap) {
+ protected ColumnEncoderRecode(int colID, HashMap<Object, Long> rcdMap) {
super(colID);
_rcdMap = rcdMap;
}
@@ -75,7 +76,7 @@ public class ColumnEncoderRecode extends ColumnEncoder {
return constructRecodeMapEntry(token, code, sb);
}
- private static String constructRecodeMapEntry(String token, Long code,
StringBuilder sb) {
+ private static String constructRecodeMapEntry(Object token, Long code,
StringBuilder sb) {
sb.setLength(0); // reset reused string builder
return
sb.append(token).append(Lop.DATATYPE_PREFIX).append(code.longValue()).toString();
}
@@ -93,7 +94,7 @@ public class ColumnEncoderRecode extends ColumnEncoder {
return new String[] {value.substring(0, pos),
value.substring(pos + 1)};
}
- public HashMap<String, Long> getCPRecodeMaps() {
+ public HashMap<Object, Long> getCPRecodeMaps() {
return _rcdMap;
}
@@ -105,15 +106,15 @@ public class ColumnEncoderRecode extends ColumnEncoder {
sortCPRecodeMaps(_rcdMap);
}
- private static void sortCPRecodeMaps(HashMap<String, Long> map) {
- String[] keys = map.keySet().toArray(new String[0]);
+ private static void sortCPRecodeMaps(HashMap<Object, Long> map) {
+ Object[] keys = map.keySet().toArray(new Object[0]);
Arrays.sort(keys);
map.clear();
- for(String key : keys)
+ for(Object key : keys)
putCode(map, key);
}
- private static void makeRcdMap(CacheBlock<?> in, HashMap<String, Long>
map, int colID, int startRow, int blk) {
+ private static void makeRcdMap(CacheBlock<?> in, HashMap<Object, Long>
map, int colID, int startRow, int blk) {
for(int row = startRow; row < getEndIndex(in.getNumRows(),
startRow, blk); row++){
String key = in.getString(row, colID - 1);
if(key != null && !key.isEmpty() &&
!map.containsKey(key))
@@ -124,9 +125,8 @@ public class ColumnEncoderRecode extends ColumnEncoder {
}
}
- private long lookupRCDMap(String key) {
- Long tmp = _rcdMap.get(key);
- return (tmp != null) ? tmp : -1;
+ private long lookupRCDMap(Object key) {
+ return _rcdMap.getOrDefault(key, -1L);
}
public void computeRCDMapSizeEstimate(CacheBlock<?> in, int[]
sampleIndices) {
@@ -202,7 +202,7 @@ public class ColumnEncoderRecode extends ColumnEncoder {
* @param map column map
* @param key key for the new entry
*/
- protected static void putCode(HashMap<String, Long> map, String key) {
+ protected static void putCode(HashMap<Object, Long> map, Object key) {
map.put(key, (long) (map.size() + 1));
}
@@ -270,10 +270,10 @@ public class ColumnEncoderRecode extends ColumnEncoder {
assert other._colID == _colID;
// merge together overlapping columns
ColumnEncoderRecode otherRec = (ColumnEncoderRecode) other;
- HashMap<String, Long> otherMap = otherRec._rcdMap;
+ HashMap<Object, Long> otherMap = otherRec._rcdMap;
if(otherMap != null) {
// for each column, add all non present recode values
- for(Map.Entry<String, Long> entry :
otherMap.entrySet()) {
+ for(Map.Entry<Object, Long> entry :
otherMap.entrySet()) {
if(lookupRCDMap(entry.getKey()) == -1) {
// key does not yet exist
putCode(_rcdMap, entry.getKey());
@@ -305,7 +305,7 @@ public class ColumnEncoderRecode extends ColumnEncoder {
// create compact meta data representation
StringBuilder sb = new StringBuilder(); // for reuse
int rowID = 0;
- for(Entry<String, Long> e : _rcdMap.entrySet()) {
+ for(Entry<Object, Long> e : _rcdMap.entrySet()) {
meta.set(rowID++, _colID - 1, // 1-based
constructRecodeMapEntry(e.getKey(),
e.getValue(), sb));
}
@@ -330,8 +330,9 @@ public class ColumnEncoderRecode extends ColumnEncoder {
public void writeExternal(ObjectOutput out) throws IOException {
super.writeExternal(out);
out.writeInt(_rcdMap.size());
- for(Entry<String, Long> e : _rcdMap.entrySet()) {
- out.writeUTF(e.getKey());
+
+ for(Entry<Object, Long> e : _rcdMap.entrySet()) {
+ out.writeUTF(e.getKey().toString());
out.writeLong(e.getValue());
}
}
@@ -362,10 +363,21 @@ public class ColumnEncoderRecode extends ColumnEncoder {
return Objects.hash(_rcdMap);
}
- public HashMap<String, Long> getRcdMap() {
+ public HashMap<Object, Long> getRcdMap() {
return _rcdMap;
}
+ /** Render this encoder as its simple class name, its column id and the contents of its recode map. */
+ @Override
+ public String toString() {
+     return getClass().getSimpleName() + ": " + _colID + " --- map: " + _rcdMap;
+ }
+
private static class RecodeSparseApplyTask extends
ColumnApplyTask<ColumnEncoderRecode>{
public RecodeSparseApplyTask(ColumnEncoderRecode encoder,
CacheBlock<?> input, MatrixBlock out, int outputCol) {
@@ -416,9 +428,9 @@ public class ColumnEncoderRecode extends ColumnEncoder {
}
@Override
- public HashMap<String, Long> call() throws Exception {
+ public Object call() throws Exception {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
- HashMap<String, Long> partialMap = new HashMap<>();
+ HashMap<Object, Long> partialMap = new HashMap<>();
makeRcdMap(_input, partialMap, _colID, _startRow,
_blockSize);
synchronized(_partialMaps) {
_partialMaps.put(_startRow, partialMap);
@@ -448,11 +460,11 @@ public class ColumnEncoderRecode extends ColumnEncoder {
@Override
public Object call() throws Exception {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
- HashMap<String, Long> rcdMap = _encoder.getRcdMap();
+ HashMap<Object, Long> rcdMap = _encoder.getRcdMap();
_partialMaps.forEach((start_row, map) -> {
((HashMap<?, ?>) map).forEach((k, v) -> {
- if(!rcdMap.containsKey((String) k))
- putCode(rcdMap, (String) k);
+ if(!rcdMap.containsKey(k))
+ putCode(rcdMap, k);
});
});
_encoder._rcdMap = rcdMap;
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
new file mode 100644
index 0000000000..b4bcd3c5a1
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.encode;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
+import
org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
+import
org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
+import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
+import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
+import org.apache.sysds.runtime.compress.utils.Util;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.frame.data.columns.Array;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+/**
+ * Transform encode fused with compression: encodes a FrameBlock directly into a CompressedMatrixBlock, creating one
+ * DDC column group per column encoder.
+ *
+ * Supported encoder pipelines are recode followed by dummycode, recode, and pass-through. Any other pipeline results
+ * in a NotImplementedException.
+ */
+public class CompressedEncode {
+    protected static final Log LOG = LogFactory.getLog(CompressedEncode.class.getName());
+
+    /** The encoder describing the transformations to apply */
+    private final MultiColumnEncoder enc;
+    /** The frame to encode */
+    private final FrameBlock in;
+
+    private CompressedEncode(MultiColumnEncoder enc, FrameBlock in) {
+        this.enc = enc;
+        this.in = in;
+    }
+
+    /**
+     * Encode the given frame into a compressed matrix block.
+     *
+     * @param enc The encoder specifying the transformation of each column
+     * @param in  The frame to encode
+     * @return The compressed, encoded matrix block
+     */
+    public static MatrixBlock encode(MultiColumnEncoder enc, FrameBlock in) {
+        return new CompressedEncode(enc, in).apply();
+    }
+
+    /**
+     * Encode every column into its own column group and combine the groups into a compressed matrix block.
+     *
+     * @return The compressed matrix block
+     */
+    private MatrixBlock apply() {
+        final List<ColumnEncoderComposite> encoders = enc.getColumnEncoders();
+        final List<AColGroup> groups = new ArrayList<>(encoders.size());
+        for(ColumnEncoderComposite c : encoders)
+            groups.add(encode(c));
+
+        final int cols = shiftGroups(groups);
+
+        final MatrixBlock mb = new CompressedMatrixBlock(in.getNumRows(), cols, -1, false, groups);
+        mb.recomputeNonZeros();
+        logging(mb);
+        return mb;
+    }
+
+    /**
+     * Shift the column groups to the correct column numbers.
+     *
+     * @param groups the groups to shift
+     * @return The total number of columns contained, 0 if there are no groups.
+     */
+    private int shiftGroups(List<AColGroup> groups) {
+        if(groups.isEmpty())
+            return 0;
+        // the first group already starts at column 0, every later group is offset by the columns before it
+        int cols = groups.get(0).getColIndices().length;
+        for(int i = 1; i < groups.size(); i++) {
+            groups.set(i, groups.get(i).shiftColIndices(cols));
+            cols += groups.get(i).getColIndices().length;
+        }
+        return cols;
+    }
+
+    /**
+     * Encode a single column according to its encoder pipeline.
+     *
+     * @param c The composite encoder of the column
+     * @return The column group containing the encoded column
+     * @throws NotImplementedException if the encoder pipeline is not supported
+     */
+    private AColGroup encode(ColumnEncoderComposite c) {
+        if(c.isRecodeToDummy())
+            return recodeToDummy(c);
+        else if(c.isRecode())
+            return recode(c);
+        else if(c.isPassThrough())
+            return passThrough(c);
+        else
+            throw new NotImplementedException("Not supporting : " + c);
+    }
+
+    /**
+     * Encode a recode + dummycode column as an identity dictionary with one dictionary row per distinct value.
+     *
+     * @param c The composite encoder of the column
+     * @return The DDC column group of the dummy coded column
+     */
+    private AColGroup recodeToDummy(ColumnEncoderComposite c) {
+        final int colId = c._colID;
+        final Array<?> a = in.getColumn(colId - 1); // column ids are 1-based
+        final HashMap<?, Long> map = a.getRecodeMap();
+        final int domain = map.size();
+
+        final int[] colIndexes = Util.genColsIndices(0, domain);
+        final ADictionary d = new IdentityDictionary(colIndexes.length);
+        final AMapToData m = createMappingAMapToData(a, map);
+
+        setRecodeMap(c, colId, map);
+        return ColGroupDDC.create(colIndexes, d, m, null);
+    }
+
+    /**
+     * Encode a recode column as a single column dictionary containing the values 1 to the number of distinct values.
+     *
+     * @param c The composite encoder of the column
+     * @return The DDC column group of the recoded column
+     */
+    private AColGroup recode(ColumnEncoderComposite c) {
+        final int colId = c._colID;
+        final Array<?> a = in.getColumn(colId - 1); // column ids are 1-based
+        final HashMap<?, Long> map = a.getRecodeMap();
+        final int domain = map.size();
+
+        final int[] colIndexes = new int[1];
+        // dictionary row i holds i + 1
+        final MatrixBlock incrementing = new MatrixBlock(domain, 1, false);
+        for(int i = 0; i < domain; i++)
+            incrementing.quickSetValue(i, 0, i + 1);
+
+        final ADictionary d = MatrixBlockDictionary.create(incrementing);
+        final AMapToData m = createMappingAMapToData(a, map);
+
+        setRecodeMap(c, colId, map);
+        return ColGroupDDC.create(colIndexes, d, m, null);
+    }
+
+    /**
+     * Replace the first encoder of the composite with a recode encoder that holds the given recode map.
+     *
+     * @param c     The composite encoder to update
+     * @param colId The column id of the encoder
+     * @param map   The recode map of the column
+     */
+    @SuppressWarnings("unchecked")
+    private static void setRecodeMap(ColumnEncoderComposite c, int colId, HashMap<?, Long> map) {
+        final List<ColumnEncoder> r = c.getEncoders();
+        r.set(0, new ColumnEncoderRecode(colId, (HashMap<Object, Long>) map));
+    }
+
+    /**
+     * Encode a pass-through column as a dictionary of its distinct values, with one extra slot for null when the
+     * column contains nulls.
+     *
+     * @param c The composite encoder of the column
+     * @return The DDC column group of the passed through column
+     */
+    @SuppressWarnings("unchecked")
+    private AColGroup passThrough(ColumnEncoderComposite c) {
+        final int[] colIndexes = new int[1];
+        final int colId = c._colID;
+        final Array<?> a = in.getColumn(colId - 1); // column ids are 1-based
+
+        // work on a copy: the map returned by the array may be its cached instance and must not gain a null key
+        final HashMap<Object, Long> map = new HashMap<>((HashMap<Object, Long>) a.getRecodeMap());
+        if(a.containsNull())
+            map.put(null, (long) map.size()); // nulls get the last dictionary slot
+
+        final double[] vals = new double[map.size()];
+        for(int i = 0; i < a.size(); i++)
+            vals[map.get(a.get(i)).intValue()] = a.getAsDouble(i);
+
+        final ADictionary d = Dictionary.create(vals);
+        final AMapToData m = createMappingAMapToData(a, map);
+        return ColGroupDDC.create(colIndexes, d, m, null);
+    }
+
+    /**
+     * Create the row to dictionary index mapping of a column.
+     *
+     * Non-null values are mapped to their code in the given map. Null values are mapped to the code registered for
+     * the null key if the map contains one, otherwise their entry is left untouched.
+     *
+     * @param a   The column to map
+     * @param map The mapping from values to dictionary indexes
+     * @return The mapping of all rows
+     */
+    private AMapToData createMappingAMapToData(Array<?> a, HashMap<?, Long> map) {
+        final AMapToData m = MapToFactory.create(in.getNumRows(), map.size());
+        final Long nullId = map.get(null); // only present if the caller registered a slot for nulls
+        final Array<?>.ArrayIterator it = a.getIterator();
+        while(it.hasNext()) {
+            final Object v = it.next();
+            if(v != null)
+                m.set(it.getIndex(), map.get(v).intValue());
+            else if(nullId != null)
+                m.set(it.getIndex(), nullId.intValue());
+        }
+        return m;
+    }
+
+    /**
+     * Log, on debug level, the size of the compressed output compared to its uncompressed dense and sparse sizes.
+     *
+     * @param mb The compressed matrix block
+     */
+    private void logging(MatrixBlock mb) {
+        if(LOG.isDebugEnabled()) {
+            final long dense = mb.estimateSizeDenseInMemory();
+            final long sparse = mb.estimateSizeSparseInMemory();
+            final long compressed = mb.estimateSizeInMemory();
+            LOG.debug(String.format("Uncompressed transform encode Dense size: %16d", dense));
+            LOG.debug(String.format("Uncompressed transform encode Sparse size: %16d", sparse));
+            LOG.debug(String.format("Compressed transform encode size: %16d", compressed));
+
+            // widen before dividing, integer division would truncate the ratios
+            final double ratio = (double) Math.min(dense, sparse) / compressed;
+            final double denseRatio = (double) dense / compressed;
+            LOG.debug(String.format("Compression ratio: %10.3f", ratio));
+            LOG.debug(String.format("Dense ratio: %10.3f", denseRatio));
+        }
+    }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 22190e518e..59c1a3d09f 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -42,6 +42,7 @@ import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.estim.ComEstSample;
@@ -90,11 +91,21 @@ public class MultiColumnEncoder implements Encoder {
}
public MatrixBlock encode(CacheBlock<?> in, int k) {
- MatrixBlock out;
+ return encode(in, k, false);
+ }
+
+ /**
+  * Encode the input using k = 1, optionally requesting a compressed output.
+  *
+  * Delegates to {@code encode(in, 1, compressedOut)}.
+  *
+  * @param in            The input to encode
+  * @param compressedOut Flag forwarded to the three-argument encode to request the compressed encode path
+  * @return The encoded matrix block
+  */
+ public MatrixBlock encode(CacheBlock<?> in, boolean compressedOut) {
+ return encode(in, 1, compressedOut);
+ }
+
+ public MatrixBlock encode(CacheBlock<?> in, int k, boolean
compressedOut){
+
deriveNumRowPartitions(in, k);
try {
- if(k > 1 && !MULTI_THREADED_STAGES &&
!hasLegacyEncoder()) {
- out = new MatrixBlock();
+ if(isCompressedTransformEncode(in, compressedOut))
+ return CompressedEncode.encode(this,
(FrameBlock ) in);
+ else if(k > 1 && !MULTI_THREADED_STAGES &&
!hasLegacyEncoder()) {
+ MatrixBlock out = new MatrixBlock();
DependencyThreadPool pool = new
DependencyThreadPool(k);
LOG.debug("Encoding with full DAG on " + k + "
Threads");
try {
@@ -106,6 +117,7 @@ public class MultiColumnEncoder implements Encoder {
}
pool.shutdown();
outputMatrixPostProcessing(out);
+ return out;
}
else {
LOG.debug("Encoding with staged approach on: "
+ k + " Threads");
@@ -123,16 +135,20 @@ public class MultiColumnEncoder implements Encoder {
}
// apply meta data
t0 = System.nanoTime();
- out = apply(in, k);
+ MatrixBlock out = apply(in, k);
t1 = System.nanoTime();
LOG.debug("Elapsed time for apply phase: "+
((double) t1 - t0) / 1000000 + " ms");
+ return out;
}
}
catch(Exception ex) {
LOG.error("Failed transform-encode frame with \n" +
this);
throw ex;
}
- return out;
+ }
+
+ /**
+  * Get the column encoders of this multi column encoder.
+  *
+  * @return The internal list of composite column encoders (the list itself, not a copy)
+  */
+ protected List<ColumnEncoderComposite> getEncoders() {
+ return _columnEncoders;
 }
/* TASK DETAILS:
@@ -245,21 +261,7 @@ public class MultiColumnEncoder implements Encoder {
}
public void build(CacheBlock<?> in, int k) {
- if(hasLegacyEncoder() && !(in instanceof FrameBlock))
- throw new DMLRuntimeException("LegacyEncoders do not
support non FrameBlock Inputs");
- if(!_partitionDone) //happens if this method is directly called
- deriveNumRowPartitions(in, k);
- if(k > 1) {
- buildMT(in, k);
- }
- else {
- for(ColumnEncoderComposite columnEncoder :
_columnEncoders) {
- columnEncoder.build(in);
- columnEncoder.updateAllDCEncoders();
- }
- }
- if(hasLegacyEncoder())
- legacyBuild((FrameBlock) in);
+ build(in, k, null);
}
public void build(CacheBlock<?> in, int k, Map<Integer, double[]>
equiHeightBinMaxs) {
@@ -317,7 +319,7 @@ public class MultiColumnEncoder implements Encoder {
boolean hasUDF = _columnEncoders.stream().anyMatch(e ->
e.hasEncoder(ColumnEncoderUDF.class));
for(ColumnEncoderComposite columnEncoder : _columnEncoders)
columnEncoder.updateAllDCEncoders();
- int numCols = in.getNumColumns() + getNumExtraCols();
+ int numCols = getNumOutCols();
long estNNz = (long) in.getNumRows() * (hasUDF ? numCols :
(long) in.getNumColumns());
boolean sparse =
MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) &&
!hasUDF;
MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols,
sparse, estNNz);
@@ -654,6 +656,8 @@ public class MultiColumnEncoder implements Encoder {
long t0 = System.nanoTime();
if(_meta != null)
return _meta;
+ if(meta == null)
+ meta = new FrameBlock(_columnEncoders.size(),
ValueType.STRING);
this.allocateMetaData(meta);
if (k > 1) {
try {
@@ -854,24 +858,11 @@ public class MultiColumnEncoder implements Encoder {
return getEncoderTypes(-1);
}
- public int getNumExtraCols() {
- List<ColumnEncoderDummycode> dc =
getColumnEncoders(ColumnEncoderDummycode.class);
- if(dc.isEmpty()) {
- return 0;
- }
- if(dc.stream().anyMatch(e -> e.getDomainSize() < 0)) {
- throw new DMLRuntimeException("Trying to get extra
columns when DC encoders are not ready");
- }
- return
dc.stream().map(ColumnEncoderDummycode::getDomainSize).mapToInt(i -> i).sum() -
dc.size();
- }
-
- public int getNumExtraCols(IndexRange ixRange) {
- List<ColumnEncoderDummycode> dc =
getColumnEncoders(ColumnEncoderDummycode.class).stream()
- .filter(dce ->
ixRange.inColRange(dce._colID)).collect(Collectors.toList());
- if(dc.isEmpty()) {
- return 0;
- }
- return
dc.stream().map(ColumnEncoderDummycode::getDomainSize).mapToInt(i -> i).sum() -
dc.size();
+ /**
+  * Get the number of output columns, computed as the sum of the domain sizes reported by all column encoders.
+  *
+  * @return The summed domain size over all column encoders
+  */
+ public int getNumOutCols() {
+     int total = 0;
+     for(ColumnEncoderComposite encoder : _columnEncoders)
+         total += encoder.getDomainSize();
+     return total;
 }
public <T extends ColumnEncoder> boolean containsEncoderForID(int
colID, Class<T> type) {
@@ -998,6 +989,11 @@ public class MultiColumnEncoder implements Encoder {
return hasLegacyEncoder(EncoderMVImpute.class) ||
hasLegacyEncoder(EncoderOmit.class);
}
+ /**
+  * Decide if the compressed transform encode path should be used.
+  *
+  * The compressed path is chosen when it is requested, either through the argument or through the
+  * DMLConfig.COMPRESSED_TRANSFORMENCODE configuration value, and the input is a FrameBlock and this encoder has
+  * a column offset of zero.
+  *
+  * @param in      The input that is about to be encoded
+  * @param enabled Explicit request for the compressed path; when false the configuration is consulted
+  * @return True if the compressed transform encode should be used
+  */
+ public boolean isCompressedTransformEncode(CacheBlock<?> in, boolean enabled){
+     // The configuration is only read when the caller did not request compression explicitly.
+     final boolean requested = enabled
+         || ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.COMPRESSED_TRANSFORMENCODE);
+     if(!requested)
+         return false;
+     if(!(in instanceof FrameBlock))
+         return false;
+     return _colOffset == 0;
+ }
+
public <T extends LegacyEncoder> boolean hasLegacyEncoder(Class<T>
type) {
if(type.equals(EncoderMVImpute.class))
return _legacyMVImpute != null;
@@ -1027,6 +1023,22 @@ public class MultiColumnEncoder implements Encoder {
_legacyMVImpute.shiftCols(_colOffset);
}
+ /**
+  * Textual description of this encoder: the simple class name, the legacy MV-impute encoder (printed after
+  * "Is Legacy: " as the object itself, which may be null, not as a boolean), and one line per column encoder.
+  */
+ @Override
+ public String toString() {
+     final StringBuilder out = new StringBuilder(this.getClass().getSimpleName());
+     out.append("\nIs Legacy: ").append(_legacyMVImpute);
+     out.append("\nEncoders:\n");
+
+     for(ColumnEncoderComposite encoder : _columnEncoders)
+         out.append(encoder).append("\n");
+
+     return out.toString();
+ }
+
/*
* Currently, not in use will be integrated in the future
*/
@@ -1081,7 +1093,7 @@ public class MultiColumnEncoder implements Encoder {
@Override
public Object call() throws Exception {
boolean hasUDF =
_encoder.getColumnEncoders().stream().anyMatch(e ->
e.hasEncoder(ColumnEncoderUDF.class));
- int numCols = _input.getNumColumns() +
_encoder.getNumExtraCols();
+ int numCols = _encoder.getNumOutCols();
boolean hasDC =
_encoder.getColumnEncoders(ColumnEncoderDummycode.class).size() > 0;
long estNNz = (long) _input.getNumRows() * (hasUDF ?
numCols : (long) _input.getNumColumns());
boolean sparse =
MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) &&
!hasUDF;
diff --git a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java
b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java
index b06054181d..987f85a733 100644
--- a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java
@@ -1119,6 +1119,7 @@ public class DataConverter {
sb.append(dfFormat(df, value));
break;
case UINT8:
+ case UINT4:
case INT32:
case INT64:
sb.append(tb.get(ix));
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index ea7717388d..5568603104 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -485,6 +485,7 @@ public class UtilFunctions {
switch( vt ) {
case STRING: return in;
case BOOLEAN: return Boolean.parseBoolean(in);
+ case UINT4:
case UINT8:
case INT32: return Integer.parseInt(in);
case INT64: return Long.parseLong(in);
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java
b/src/test/java/org/apache/sysds/test/TestUtils.java
index ece9c77a56..bade9ddbe8 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -842,6 +842,10 @@ public class TestUtils
}
public static void compareFrames(FrameBlock expected, FrameBlock
actual, boolean checkMeta) {
+ if(expected == null && actual == null)
+ return;
+ assertTrue("Expected frame was null pointer", expected != null);
+ assertTrue("Actual frame was null pointer", actual != null);
assertEquals("Number of columns and rows are not equivalent",
expected.getNumRows(), actual.getNumRows());
assertEquals("Number of columns and rows are not equivalent",
expected.getNumColumns(), actual.getNumColumns());
@@ -2417,6 +2421,7 @@ public class TestUtils
*/
public static Object generateRandomValueFromValueType(ValueType
valueType, Random random){
switch (valueType){
+ case UINT4: return random.nextInt(16);
case UINT8: return random.nextInt(256);
case FP32: return random.nextFloat();
case FP64: return random.nextDouble();
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
index c68332e735..cc6c3510af 100644
---
a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
+++
b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
@@ -1150,7 +1150,7 @@ public class CustomArrayTests {
@Test
public void mappingCache() {
- Array<?> a = new StringArray(new String[] {"1", null});
+ Array<String> a = new StringArray(new String[] {"1", null});
assertEquals(null, a.getCache());
a.setCache(new SoftReference<HashMap<String, Long>>(null));
assertTrue(null != a.getCache());
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/transform/transformCompressed.java
b/src/test/java/org/apache/sysds/test/component/frame/transform/transformCompressed.java
new file mode 100644
index 0000000000..343aaf05ec
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/frame/transform/transformCompressed.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.frame.transform;
+
+import static org.junit.Assert.fail;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+/**
+ * Component test comparing the compressed transform encode against the regular transform encode.
+ *
+ * Each test encodes the same frame twice with the same specification, once requesting compressed output and once
+ * through the default path, and asserts that both the resulting matrices and the resulting metadata frames are
+ * equal.
+ */
+public class transformCompressed {
+    protected static final Log LOG = LogFactory.getLog(transformCompressed.class.getName());
+
+    /** Input frame shared by all tests: generated as a single UINT4 column, then re-declared as INT32. */
+    private final FrameBlock data;
+
+    public transformCompressed() {
+        try {
+
+            // 100 rows of one UINT4 column; the last argument is presumably the random seed - TODO confirm.
+            data = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231);
+            data.setSchema(new ValueType[] {ValueType.INT32});
+        }
+        catch(Exception e) {
+            e.printStackTrace();
+            fail(e.getMessage());
+            // fail always throws; the rethrow is what lets the compiler accept the final field assignment.
+            throw e;
+        }
+    }
+
+    @Test
+    public void testRecode() {
+        test("{recode:[C1]}");
+    }
+
+    @Test
+    public void testDummyCode() {
+        test("{dummycode:[C1]}");
+    }
+
+    // @Test
+    // public void testBin() {
+    // test("{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}");
+    // }
+
+    // @Test
+    // public void testBin2() {
+    // test("{ids:true, bin:[{id:1, method:equi-width, numbins:100}]}");
+    // }
+
+    // @Test
+    // public void testBin3() {
+    // test("{ids:true, bin:[{id:1, method:equi-width, numbins:2}]}");
+    // }
+
+    // @Test
+    // public void testBin4() {
+    // test("{ids:true, bin:[{id:1, method:equi-height, numbins:2}]}");
+    // }
+
+    // @Test
+    // public void testBin5() {
+    // test("{ids:true, bin:[{id:1, method:equi-height, numbins:10}]}");
+    // }
+
+    /**
+     * Encode the shared input with the given specification through both the compressed and the regular path and
+     * assert that the outputs and the metadata are identical.
+     *
+     * @param spec The transform specification to apply
+     */
+    public void test(String spec) {
+        try {
+
+            FrameBlock meta = null;
+            MultiColumnEncoder encoderCompressed = EncoderFactory.createEncoder(spec, data.getColumnNames(),
+                data.getNumColumns(), meta);
+            MatrixBlock outCompressed = encoderCompressed.encode(data, true);
+            FrameBlock outCompressedMD = encoderCompressed.getMetaData(null);
+            MultiColumnEncoder encoderNormal = EncoderFactory.createEncoder(spec, data.getColumnNames(),
+                data.getNumColumns(), meta);
+            MatrixBlock outNormal = encoderNormal.encode(data);
+            FrameBlock outNormalMD = encoderNormal.getMetaData(null);
+
+
+            // Diagnostic dumps only: logged at debug level so that passing runs do not emit error output.
+            LOG.debug(outNormal);
+            LOG.debug(outCompressed);
+            LOG.debug(outCompressedMD);
+            LOG.debug(outNormalMD);
+
+            TestUtils.compareMatrices(outNormal, outCompressed, 0, "Not Equal after apply");
+            TestUtils.compareFrames(outNormalMD, outCompressedMD, true);
+        }
+        catch(Exception e) {
+            e.printStackTrace();
+            fail(e.getMessage());
+        }
+    }
+}