This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new cc8eb95135 [SYSTEMDS-2830] Functional Compression
cc8eb95135 is described below
commit cc8eb951358320a984ea6798d3d252b4eed5c80c
Author: wedenigt <[email protected]>
AuthorDate: Fri Jun 24 17:42:24 2022 +0200
[SYSTEMDS-2830] Functional Compression
This commit adds a new column group class for functional compression.
Initial implementation covers a linear compression scheme.
The new colgroup supports construction from matrix,
most of the operations and tests.
Closes #1634
Closes #1645
---
.../sysds/runtime/compress/colgroup/AColGroup.java | 5 +-
.../runtime/compress/colgroup/ColGroupFactory.java | 26 +-
.../colgroup/ColGroupLinearFunctional.java | 665 +++++++++++++++++++++
.../runtime/compress/colgroup/ColGroupSizes.java | 9 +
.../compress/colgroup/ColGroupUncompressed.java | 30 +-
.../runtime/compress/colgroup/ColGroupUtils.java | 39 ++
.../colgroup/functional/LinearRegression.java | 75 +++
.../compress/estim/CompressedSizeInfoColGroup.java | 2 +
.../colgroup/ColGroupLinearFunctionalBase.java | 252 ++++++++
.../colgroup/ColGroupLinearFunctionalTest.java | 346 +++++++++++
.../compress/functional/LinearRegressionTests.java | 115 ++++
11 files changed, 1526 insertions(+), 38 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
index 45a0d62df7..557c0269b3 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
@@ -52,7 +52,7 @@ public abstract class AColGroup implements Serializable {
/** Public super types of compression ColGroups supported */
public static enum CompressionType {
- UNCOMPRESSED, RLE, OLE, DDC, CONST, EMPTY, SDC, SDCFOR, DDCFOR,
DeltaDDC
+ UNCOMPRESSED, RLE, OLE, DDC, CONST, EMPTY, SDC, SDCFOR, DDCFOR,
DeltaDDC, LinearFunctional;
}
/**
@@ -61,7 +61,8 @@ public abstract class AColGroup implements Serializable {
* Protected such that outside the ColGroup package it should be
unknown which specific subtype is used.
*/
protected static enum ColGroupType {
- UNCOMPRESSED, RLE, OLE, DDC, CONST, EMPTY, SDC, SDCSingle,
SDCSingleZeros, SDCZeros, SDCFOR, DDCFOR, DeltaDDC;
+ UNCOMPRESSED, RLE, OLE, DDC, CONST, EMPTY, SDC, SDCSingle,
SDCSingleZeros, SDCZeros, SDCFOR, DDCFOR, DeltaDDC,
+ LinearFunctional;
}
/** The ColGroup Indexes contained in the ColGroup */
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
index 8a73ec1b54..c9a8e894c7 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
@@ -43,6 +43,7 @@ import
org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import
org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
+import org.apache.sysds.runtime.compress.colgroup.functional.LinearRegression;
import
org.apache.sysds.runtime.compress.colgroup.insertionsort.AInsertionSorter;
import
org.apache.sysds.runtime.compress.colgroup.insertionsort.InsertionSorterFactory;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
@@ -129,13 +130,13 @@ public class ColGroupFactory {
}
private List<AColGroup> compress() {
- try{
+ try {
List<AColGroup> ret = compressExecute();
if(pool != null)
pool.shutdown();
- return ret;
+ return ret;
}
- catch(Exception e ){
+ catch(Exception e) {
if(pool != null)
pool.shutdown();
throw new DMLCompressionException("Compression Failed",
e);
@@ -359,6 +360,8 @@ public class ColGroupFactory {
return compressSDCFromSparseTransposedBlock(colIndexes,
nrUniqueEstimate, cg.getTupleSparsity());
else if(ct == CompressionType.DDC)
return directCompressDDC(colIndexes, cg);
+ else if(ct == CompressionType.LinearFunctional)
+ return compressLinearFunctional(colIndexes, in, cs);
else {
LOG.debug("Default slow path: " + ct + " " +
cs.transposed + " " + Arrays.toString(colIndexes));
final int numRows = cs.transposed ? in.getNumColumns()
: in.getNumRows();
@@ -445,19 +448,20 @@ public class ColGroupFactory {
if(dict == null)
// Again highly unlikely but possible.
return new ColGroupEmpty(colIndexes);
- try{
+ try {
if(extra)
d.replace(fill, map.size());
-
+
final int nUnique = map.size() + (extra ? 1 : 0);
-
+
final AMapToData resData = MapToFactory.resize(d,
nUnique);
return ColGroupDDC.create(colIndexes, nRow, dict,
resData, null);
}
- catch(Exception e ){
+ catch(Exception e) {
ReaderColumnSelection reader =
ReaderColumnSelection.createReader(in, colIndexes, cs.transposed, 0, nRow);
- throw new DMLCompressionException("direct compress DDC
Multi col failed extra:" + extra + " with reader type:" +
reader.getClass().getSimpleName(), e);
+ throw new DMLCompressionException("direct compress DDC
Multi col failed extra:" + extra + " with reader type:"
+ + reader.getClass().getSimpleName(), e);
}
}
@@ -653,6 +657,12 @@ public class ColGroupFactory {
return ColGroupSDCSingle.create(colIndexes, rlen, dict,
defaultTuple, off, null);
}
+ private static AColGroup compressLinearFunctional(int[] colIndexes,
MatrixBlock in, CompressionSettings cs) {
+ double[] coefficients = LinearRegression.regressMatrixBlock(in,
colIndexes, cs.transposed);
+ int numRows = cs.transposed ? in.getNumColumns() :
in.getNumRows();
+ return ColGroupLinearFunctional.create(colIndexes,
coefficients, numRows);
+ }
+
private static AColGroup compressDDC(int[] colIndexes, int rlen,
ABitmap ubm, CompressionSettings cs,
double tupleSparsity) {
boolean zeros = ubm.getNumOffsets() < rlen;
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java
new file mode 100644
index 0000000000..ecb516724a
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java
@@ -0,0 +1,665 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.colgroup;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
+import org.apache.sysds.runtime.compress.utils.Util;
+import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Divide;
+import org.apache.sysds.runtime.functionobjects.Minus;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.CMOperator;
+import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
+import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
+import org.apache.sysds.utils.MemoryEstimates;
+
+public class ColGroupLinearFunctional extends AColGroupCompressed {
+
+ private static final long serialVersionUID = -2811822570758221975L;
+
+ // Needed for numerical robustness when checking if a value is
contained in a column
+ private final static double CONTAINS_VALUE_THRESHOLD = 1e-6;
+
+ protected double[] _coefficents;
+
+ protected int _numRows;
+
+ /** Constructor for serialization */
+ protected ColGroupLinearFunctional() {
+ super();
+ }
+
+ /**
+ * Constructs a Linear Functional Column Group that compresses its
content using a linear functional.
+ *
+ * @param colIndices The Column indexes for the column group.
+ * @param coefficents Array where the first `colIndices.length` entries
are the intercepts and the next
+ * `colIndices.length` entries are the slopes
+ * @param numRows Number of rows encoded within this column group.
+ */
+ private ColGroupLinearFunctional(int[] colIndices, double[]
coefficents, int numRows) {
+ super(colIndices);
+ this._coefficents = coefficents;
+ this._numRows = numRows;
+ }
+
+ /**
+ * Generate a linear functional column group.
+ *
+ * @param colIndices The specific column indexes that is contained in
this column group.
+ * @param coefficents Array where the first `colIndices.length` entries
are the intercepts and the next
+ * `colIndices.length` entries are the slopes
+ * @param numRows Number of rows encoded within this column group.
+ * @return A LinearFunctional column group.
+ */
+ public static AColGroup create(int[] colIndices, double[] coefficents,
int numRows) {
+ if(coefficents.length != 2 * colIndices.length)
+ throw new DMLCompressionException("Invalid size of
values compared to columns");
+
+ boolean allSlopesConstant = true;
+ for(int j = 0; j < colIndices.length; j++) {
+ if(coefficents[colIndices.length + j] != 0) {
+ allSlopesConstant = false;
+ break;
+ }
+ }
+
+ if(allSlopesConstant) {
+ boolean allInterceptsZero = true;
+ for(int j = 0; j < colIndices.length; j++) {
+ if(coefficents[j] != 0) {
+ allInterceptsZero = false;
+ break;
+ }
+ }
+
+ if(allInterceptsZero)
+ return new ColGroupEmpty(colIndices);
+ else {
+ double[] intercepts = new
double[colIndices.length];
+ System.arraycopy(coefficents, 0, intercepts, 0,
colIndices.length);
+ return ColGroupConst.create(colIndices,
intercepts);
+ }
+ }
+ else
+ return new ColGroupLinearFunctional(colIndices,
coefficents, numRows);
+ }
+
+ public double getInterceptForColumn(int colIdx) {
+ return this._coefficents[colIdx];
+ }
+
+ public double getSlopeForColumn(int colIdx) {
+ return this._coefficents[this._colIndexes.length + colIdx];
+ }
+
+ public int getNumRows() {
+ return _numRows;
+ }
+
+ @Override
+ protected void computeRowMxx(double[] c, Builtin builtin, int rl, int
ru, double[] preAgg) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public CompressionType getCompType() {
+ return CompressionType.LinearFunctional;
+ }
+
+ @Override
+ public ColGroupType getColGroupType() {
+ return ColGroupType.LinearFunctional;
+ }
+
+ @Override
+ public double getMin() {
+ double min = Double.POSITIVE_INFINITY;
+
+ for(int col = 0; col < getNumCols(); col++) {
+ double intercept = getInterceptForColumn(col);
+ double slope = getSlopeForColumn(col);
+ if(slope >= 0 && (intercept + slope) < min) {
+ min = intercept + slope;
+ }
+ else if(slope < 0 && (intercept + _numRows * slope) <
min) {
+ min = intercept + _numRows * slope;
+ }
+ }
+
+ return min;
+ }
+
+ @Override
+ public double getMax() {
+ double max = Double.NEGATIVE_INFINITY;
+
+ for(int col = 0; col < getNumCols(); col++) {
+ double intercept = getInterceptForColumn(col);
+ double slope = getSlopeForColumn(col);
+ if(slope >= 0 && (intercept + _numRows * slope) > max) {
+ max = intercept + _numRows * slope;
+ }
+ else if(slope < 0 && (intercept + slope) > max) {
+ max = intercept + slope;
+ }
+ }
+
+ return max;
+ }
+
+	/**
+	 * Decompress the rows in the range [rl, ru) of this column group into the given dense block. The reconstructed
+	 * values are added to whatever the target cells already contain.
+	 *
+	 * @param db   Target dense block
+	 * @param rl   First row (inclusive) of this column group to decompress
+	 * @param ru   Last row (exclusive) of this column group to decompress
+	 * @param offR Row offset applied when addressing the target block
+	 * @param offC Column offset applied when addressing the target block
+	 */
+	@Override
+	public void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC) {
+		final int nCol = getNumCols();
+		final double[] accumulators = new double[nCol];
+
+		// Seed every column with the value "one row before rl", i.e. intercept + slope * rl. The first increment in
+		// the loop below then yields intercept + slope * (rl + 1), which is what getIdx(rl, j) returns. Seeding with
+		// the bare intercept is only correct for rl == 0.
+		for(int j = 0; j < nCol; j++)
+			accumulators[j] = getInterceptForColumn(j) + getSlopeForColumn(j) * rl;
+
+		int offT = rl + offR;
+		for(int row = rl; row < ru; row++, offT++) {
+			final double[] c = db.values(offT);
+			final int off = db.pos(offT) + offC;
+
+			for(int j = 0; j < nCol; j++) {
+				// advance the running value of column j by one row
+				accumulators[j] += getSlopeForColumn(j);
+				c[off + _colIndexes[j]] += accumulators[j];
+			}
+		}
+	}
+
+ @Override
+ public void decompressToSparseBlock(SparseBlock ret, int rl, int ru,
int offR, int offC) {
+ final int nCol = _colIndexes.length;
+ for(int i = rl, offT = rl + offR; i < ru; i++, offT++) {
+ for(int j = 0; j < nCol; j++)
+ ret.append(offT, _colIndexes[j] + offC,
getIdx(i, j));
+ }
+ }
+
+ @Override
+ public double getIdx(int r, int colIdx) {
+ return getInterceptForColumn(colIdx) +
getSlopeForColumn(colIdx) * (r + 1);
+ }
+
+	/**
+	 * Apply a scalar operation to all values of this column group without decompressing it.
+	 *
+	 * Plus and minus are absorbed into the intercepts while the slopes are carried over; multiply and divide are
+	 * applied to intercepts and slopes alike. Any other operation is not supported.
+	 *
+	 * @param op The scalar operator to apply
+	 * @return The column group returned by {@link #create(int[], double[], int)} for the new coefficients
+	 * @throws NotImplementedException if the operator is not plus, minus, multiply or divide
+	 */
+	@Override
+	public AColGroup scalarOperation(ScalarOperator op) {
+		final int nCol = getNumCols();
+		final double[] coefficients_new = new double[_coefficents.length];
+
+		if(op.fn instanceof Plus || op.fn instanceof Minus) {
+			// The slopes are stored in the second half of the coefficient array, so both the source and the
+			// destination offset must be nCol. Copying from offset 0 would duplicate the intercepts as slopes.
+			System.arraycopy(_coefficents, nCol, coefficients_new, nCol, nCol);
+			// absorb plus/minus into intercept
+			for(int col = 0; col < nCol; col++)
+				coefficients_new[col] = op.executeScalar(_coefficents[col]);
+
+			// NOTE(review): if the scalar is the left operand of a minus (s - X), the slopes would additionally have
+			// to be negated. Whether op places the scalar on the left cannot be told from this method -- confirm.
+			return create(_colIndexes, coefficients_new, _numRows);
+		}
+		else if(op.fn instanceof Multiply || op.fn instanceof Divide) {
+			// multiply/divide changes intercepts & slopes
+			// NOTE(review): a scalar divided by the column (s / X) is not a linear function of the row index and
+			// presumably must not take this path -- confirm how left-scalar operators reach this method.
+			for(int j = 0; j < _coefficents.length; j++)
+				coefficients_new[j] = op.executeScalar(_coefficents[j]);
+
+			return create(_colIndexes, coefficients_new, _numRows);
+		}
+		else {
+			throw new NotImplementedException();
+		}
+	}
+
+ @Override
+ public AColGroup unaryOperation(UnaryOperator op) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean
isRowSafe) {
+ return binaryRowOp(op, v, isRowSafe, true);
+ }
+
+ @Override
+ public AColGroup binaryRowOpRight(BinaryOperator op, double[] v,
boolean isRowSafe) {
+ return binaryRowOp(op, v, isRowSafe, false);
+ }
+
+	/**
+	 * Apply a binary operation between this column group and a row vector, staying in coefficient space.
+	 *
+	 * With a column given by intercept a and slope b (value a + b * i in row i) and the row vector entry v:
+	 * <ul>
+	 * <li>plus: the intercept becomes a + v (or v + a), the slope is unchanged</li>
+	 * <li>minus, vector on the right: the intercept becomes a - v, the slope is unchanged</li>
+	 * <li>minus, vector on the left: v - (a + b * i) = (v - a) + (-b) * i, so the slope is negated</li>
+	 * <li>multiply, and divide with the vector on the right: intercept and slope are both scaled</li>
+	 * <li>divide with the vector on the left: v / (a + b * i) is not linear in i and is not supported</li>
+	 * </ul>
+	 *
+	 * @param op        The binary operator to apply
+	 * @param v         The row vector, addressed through the column indexes of this group
+	 * @param isRowSafe Not used by this implementation
+	 * @param left      True if the row vector is the left operand, false if it is the right operand
+	 * @return The column group returned by {@link #create(int[], double[], int)} for the new coefficients
+	 * @throws NotImplementedException for unsupported operators and for a left-hand divide
+	 */
+	private AColGroup binaryRowOp(BinaryOperator op, double[] v, boolean isRowSafe, boolean left) {
+		final int nCol = getNumCols();
+		final double[] coefficients_new = new double[_coefficents.length];
+
+		if(op.fn instanceof Plus || op.fn instanceof Minus) {
+			// only "v - column" flips the sign of the slope, see method comment
+			final boolean negateSlope = left && op.fn instanceof Minus;
+
+			for(int col = 0; col < nCol; col++) {
+				final double rowValue = v[_colIndexes[col]];
+				final double intercept = _coefficents[col];
+				// slopes are stored in the second half of the coefficient array
+				final double slope = _coefficents[nCol + col];
+
+				// absorb plus/minus into intercept
+				coefficients_new[col] = left ? op.fn.execute(rowValue, intercept) : op.fn.execute(intercept, rowValue);
+				coefficients_new[nCol + col] = negateSlope ? -slope : slope;
+			}
+
+			return create(_colIndexes, coefficients_new, _numRows);
+		}
+		else if(op.fn instanceof Multiply || op.fn instanceof Divide) {
+			if(left && op.fn instanceof Divide)
+				// the result is not a linear function of the row index and cannot be kept in this encoding
+				throw new NotImplementedException();
+
+			// multiply/divide changes intercepts & slopes
+			for(int col = 0; col < nCol; col++) {
+				final double rowValue = v[_colIndexes[col]];
+				final double intercept = _coefficents[col];
+				final double slope = _coefficents[nCol + col];
+
+				if(left) {
+					coefficients_new[col] = op.fn.execute(rowValue, intercept);
+					coefficients_new[nCol + col] = op.fn.execute(rowValue, slope);
+				}
+				else {
+					coefficients_new[col] = op.fn.execute(intercept, rowValue);
+					coefficients_new[nCol + col] = op.fn.execute(slope, rowValue);
+				}
+			}
+
+			return create(_colIndexes, coefficients_new, _numRows);
+		}
+		else {
+			throw new NotImplementedException();
+		}
+	}
+
+ @Override
+ protected double computeMxx(double c, Builtin builtin) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ protected void computeColMxx(double[] c, Builtin builtin) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ protected void computeSum(double[] c, int nRows) {
+ for(int col = 0; col < getNumCols(); col++) {
+ double intercept = getInterceptForColumn(col);
+ double slope = getSlopeForColumn(col);
+ c[0] += nRows * (intercept + (nRows + 1) * slope / 2);
+ }
+ }
+
+ @Override
+ public void computeColSums(double[] c, int nRows) {
+ for(int col = 0; col < getNumCols(); col++) {
+ double intercept = getInterceptForColumn(col);
+ double slope = getSlopeForColumn(col);
+ c[_colIndexes[col]] += nRows * (intercept + (nRows + 1)
* slope / 2);
+ }
+ }
+
+	/**
+	 * Add the sum of the squared values of all columns of this group to c[0].
+	 *
+	 * @param c     Output array, only c[0] is updated
+	 * @param nRows Number of rows to aggregate over
+	 */
+	@Override
+	protected void computeSumSq(double[] c, int nRows) {
+		// (n + 1) * (2n + 1) is evaluated in long: the int product overflows once nRows reaches 32768.
+		final long squareWeight = (nRows + 1L) * (2L * nRows + 1);
+
+		for(int col = 0; col < getNumCols(); col++) {
+			double intercept = getInterceptForColumn(col);
+			double slope = getSlopeForColumn(col);
+			// Given the intercept and slope of a column, the sum of the squared components of the column reads
+			// \sum_{i=1}^n (intercept + slope * i)^2
+			// We get a closed form expression by expanding the binomial and using the fact that
+			// \sum_{i=1}^n i = n(n+1)/2 and \sum_{i=1}^n i^2 = n(n+1)(2n+1)/6
+
+			c[0] += nRows * (Math.pow(intercept, 2) + (nRows + 1) * slope * intercept +
+				squareWeight * Math.pow(slope, 2) / 6);
+		}
+	}
+
+	/**
+	 * Add, for every column of this group, the sum of its squared values to the output cell of that column.
+	 *
+	 * For a column with the given intercept and slope the sum reads \sum_{i=1}^n (intercept + slope * i)^2, which is
+	 * evaluated in closed form using \sum_{i=1}^n i = n(n+1)/2 and \sum_{i=1}^n i^2 = n(n+1)(2n+1)/6.
+	 *
+	 * @param c     Output array, addressed through the column indexes of this group
+	 * @param nRows Number of rows to aggregate over
+	 */
+	@Override
+	protected void computeColSumsSq(double[] c, int nRows) {
+		// (n + 1) * (2n + 1) is evaluated in long: the int product overflows once nRows reaches 32768.
+		final long squareWeight = (nRows + 1L) * (2L * nRows + 1);
+
+		for(int col = 0; col < getNumCols(); col++) {
+			double intercept = getInterceptForColumn(col);
+			double slope = getSlopeForColumn(col);
+			c[_colIndexes[col]] += nRows * (Math.pow(intercept, 2) + (nRows + 1) * slope * intercept +
+				squareWeight * Math.pow(slope, 2) / 6);
+		}
+	}
+
+ @Override
+ protected void computeRowSums(double[] c, int rl, int ru, double[]
preAgg) {
+ double intercept_sum = preAgg[0];
+ double slope_sum = preAgg[1];
+
+ for(int rix = rl; rix < ru; rix++)
+ c[rix] += intercept_sum + slope_sum * (rix + 1);
+ }
+
+ @Override
+ public int getNumValues() {
+ return 0;
+ }
+
+ @Override
+ public AColGroup rightMultByMatrix(MatrixBlock right) {
+ final int nColR = right.getNumColumns();
+ final int[] outputCols = Util.genColsIndices(nColR);
+
+ // TODO: add specialization for sparse/dense matrix blocks
+ MatrixBlock result = new MatrixBlock(_numRows, nColR, false);
+ for(int j = 0; j < nColR; j++) {
+ double bias_accum = 0.0;
+ double slope_accum = 0.0;
+
+ for(int c = 0; c < _colIndexes.length; c++) {
+ bias_accum += right.getValue(_colIndexes[c], j)
* getInterceptForColumn(c);
+ slope_accum += right.getValue(_colIndexes[c],
j) * getSlopeForColumn(c);
+ }
+
+ for(int r = 0; r < _numRows; r++) {
+ result.setValue(r, j, bias_accum + (r + 1) *
slope_accum);
+ }
+ }
+
+ // returns an uncompressed ColGroup
+ return ColGroupUncompressed.create(result, outputCols);
+ }
+
+	/**
+	 * Add the transpose-self matrix multiplication of this column group to ret. Only the cells (row, col) with
+	 * col >= row in the order of this group's columns are written.
+	 *
+	 * @param ret        Row-major output array with numColumns cells per row
+	 * @param numColumns Number of columns of the output
+	 * @param nRows      Number of rows encoded in this column group
+	 */
+	@Override
+	public void tsmm(double[] ret, int numColumns, int nRows) {
+		// runs in O(tCol^2) since dot-products take O(1) time to compute when both vectors are linearly compressed
+		final int tCol = _colIndexes.length;
+
+		// \sum_{i=1}^n i and \sum_{i=1}^n i^2. The products are formed in long/double arithmetic because the int
+		// product n(n+1)(2n+1) already overflows at roughly 1024 rows.
+		final double sumIndices = nRows * (nRows + 1L) / 2.0;
+		final double sumSquaredIndices = (double) nRows * (nRows + 1L) * (2L * nRows + 1) / 6.0;
+		for(int row = 0; row < tCol; row++) {
+			final double interceptRow = getInterceptForColumn(row);
+			final double slopeRow = getSlopeForColumn(row);
+			// the dot product with column "col" is alpha1 * intercept(col) + alpha2 * slope(col)
+			final double alpha1 = nRows * interceptRow + sumIndices * slopeRow;
+			final double alpha2 = sumIndices * interceptRow + sumSquaredIndices * slopeRow;
+			final int offRet = _colIndexes[row] * numColumns;
+			for(int col = row; col < tCol; col++) {
+				ret[offRet + _colIndexes[col]] += alpha1 * getInterceptForColumn(col) + alpha2 * getSlopeForColumn(col);
+			}
+		}
+	}
+
+ @Override
+ public void leftMultByMatrixNoPreAgg(MatrixBlock matrix, MatrixBlock
result, int rl, int ru, int cl, int cu) {
+ throw new DMLCompressionException("This method should never be
called");
+ }
+
+	/**
+	 * Compute t(lhs) %*% this and add the values to the result at the positions given by the column indexes of both
+	 * groups.
+	 *
+	 * Supported left hand sides are empty groups (nothing to add), uncompressed groups and other linear functional
+	 * groups.
+	 *
+	 * @param lhs    The left hand side column group
+	 * @param result The matrix block the product is added to
+	 * @throws NotImplementedException for all other left hand side column group types
+	 */
+	@Override
+	public void leftMultByAColGroup(AColGroup lhs, MatrixBlock result) {
+		if(lhs instanceof ColGroupEmpty) {
+			return;
+		}
+
+		MatrixBlock tmpRet = new MatrixBlock(lhs.getNumCols(), _colIndexes.length, 0);
+
+		if(lhs instanceof ColGroupUncompressed) {
+			ColGroupUncompressed lhsUC = (ColGroupUncompressed) lhs;
+			int numRowsLeft = lhsUC.getData().getNumRows();
+
+			// per lhs column: [ \sum_i x_i , \sum_i (i + 1) * x_i ], stored row-major as a (numCols x 2) matrix
+			double[] colSumsAndWeightedColSums = new double[2 * lhs.getNumCols()];
+			for(int j = 0, offTmp = 0; j < lhs.getNumCols(); j++, offTmp += 2) {
+				for(int i = 0; i < numRowsLeft; i++) {
+					final double value = lhs.getIdx(i, j);
+					colSumsAndWeightedColSums[offTmp] += value;
+					colSumsAndWeightedColSums[offTmp + 1] += (i + 1) * value;
+				}
+			}
+
+			MatrixBlock sumMatrix = new MatrixBlock(lhs.getNumCols(), 2, colSumsAndWeightedColSums);
+			MatrixBlock coefficientMatrix = new MatrixBlock(2, _colIndexes.length, _coefficents);
+
+			LibMatrixMult.matrixMult(sumMatrix, coefficientMatrix, tmpRet);
+		}
+		else if(lhs instanceof ColGroupLinearFunctional) {
+			ColGroupLinearFunctional lhsLF = (ColGroupLinearFunctional) lhs;
+
+			// \sum_{i=1}^n i and \sum_{i=1}^n i^2. The products are formed in long/double arithmetic because the
+			// int product n(n+1)(2n+1) already overflows at roughly 1024 rows.
+			final double sumIndices = _numRows * (_numRows + 1L) / 2.0;
+			final double sumSquaredIndices = (double) _numRows * (_numRows + 1L) * (2L * _numRows + 1) / 6.0;
+
+			MatrixBlock weightMatrix = new MatrixBlock(2, 2,
+				new double[] {_numRows, sumIndices, sumIndices, sumSquaredIndices});
+			// The lhs coefficients are transposed in place below. Work on a copy so that the coefficient array
+			// owned by lhs cannot be reordered.
+			// NOTE(review): presumably MatrixBlock wraps the given array without copying -- confirm.
+			MatrixBlock coefficientMatrixLhs = new MatrixBlock(2, lhsLF._colIndexes.length,
+				lhsLF._coefficents.clone());
+			MatrixBlock coefficientMatrixRhs = new MatrixBlock(2, _colIndexes.length, _coefficents);
+
+			coefficientMatrixLhs = LibMatrixReorg.transposeInPlace(coefficientMatrixLhs,
+				InfrastructureAnalyzer.getLocalParallelism());
+
+			// We simply compute a matrix multiplication chain in coefficient space, i.e.,
+			// t(L) %*% R = t(coeff(L)) %*% W %*% coeff(R)
+			// where W is a weight matrix capturing the size of the shared dimension (weightMatrix above)
+			// and coeff(X) denotes the 2 x n matrix of the m x n matrix X.
+			MatrixBlock tmp = new MatrixBlock(lhs.getNumCols(), 2, false);
+			LibMatrixMult.matrixMult(coefficientMatrixLhs, weightMatrix, tmp);
+			LibMatrixMult.matrixMult(tmp, coefficientMatrixRhs, tmpRet);
+		}
+		else if(lhs instanceof APreAgg) {
+			// TODO: implement
+			throw new NotImplementedException();
+		}
+		else {
+			throw new NotImplementedException();
+		}
+
+		ColGroupUtils.copyValuesColGroupMatrixBlocks(lhs, this, tmpRet, result);
+	}
+
+ @Override
+ public void tsmmAColGroup(AColGroup other, MatrixBlock result) {
+ throw new DMLCompressionException("Should not be called");
+ }
+
+ @Override
+ protected AColGroup sliceSingleColumn(int idx) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ protected AColGroup sliceMultiColumns(int idStart, int idEnd, int[]
outputCols) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public AColGroup copy() {
+ return this;
+ }
+
+ @Override
+ public boolean containsValue(double pattern) {
+ for(int col = 0; col < getNumCols(); col++) {
+ if(colContainsValue(col, pattern))
+ return true;
+ }
+
+ return false;
+ }
+
+	/**
+	 * Check if a single column of this group contains the given value.
+	 *
+	 * Row r (0-based) of the column holds intercept + slope * (r + 1), so the value is contained if
+	 * (pattern - intercept) / slope is, up to a numerical tolerance, an integer between 1 and the number of rows.
+	 *
+	 * @param col     Index of the column inside this column group
+	 * @param pattern The value to look for
+	 * @return True if the column contains the value
+	 */
+	public boolean colContainsValue(int col, double pattern) {
+		final double intercept = getInterceptForColumn(col);
+		final double slope = getSlopeForColumn(col);
+
+		// The bare intercept would belong to multiplier 0, which is not a row of the column. It is only contained
+		// if the column is (numerically) constant.
+		if(pattern == intercept)
+			return Math.abs(slope) < CONTAINS_VALUE_THRESHOLD;
+
+		final double div = (pattern - intercept) / slope;
+		final double nearestMultiplier = Math.rint(div);
+
+		// Written as a negated "<" so that a NaN distance (slope == 0 gives an infinite div) counts as no match.
+		if(!(Math.abs(div - nearestMultiplier) < CONTAINS_VALUE_THRESHOLD))
+			return false;
+
+		// A matching multiplier only corresponds to an actual row if it lies in [1, _numRows].
+		return nearestMultiplier >= 1 && nearestMultiplier <= _numRows;
+	}
+
+ @Override
+ public long getNumberNonZeros(int nRows) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public AColGroup replace(double pattern, double replace) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public long getExactSizeOnDisk() {
+ long ret = super.getExactSizeOnDisk();
+ ret += MemoryEstimates.doubleArrayCost(_coefficents.length);
+ ret += 4L; // _numRows
+ return ret;
+ }
+
+ @Override
+ protected void computeProduct(double[] c, int nRows) {
+ if(containsValue(0)) {
+ c[0] = 0;
+ return;
+ }
+
+ for(int col = 0; col < getNumCols(); col++) {
+ double intercept = getInterceptForColumn(col);
+ double slope = getSlopeForColumn(col);
+
+ for(int i = 0; i < nRows; i++) {
+ c[0] *= intercept + slope * (i + 1);
+ }
+ }
+ }
+
+ @Override
+ protected void computeRowProduct(double[] c, int rl, int ru, double[]
preAgg) {
+ for(int rix = rl; rix < ru; rix++) {
+ for(int col = 0; col < getNumCols(); col++) {
+ double intercept = getInterceptForColumn(col);
+ double slope = getSlopeForColumn(col);
+ c[rix] *= intercept + slope * (rix + 1);
+ }
+ }
+ }
+
+ @Override
+ protected void computeColProduct(double[] c, int nRows) {
+ for(int col = 0; col < getNumCols(); col++) {
+ if(colContainsValue(col, 0)) {
+ c[_colIndexes[col]] = 0;
+ }
+ else {
+ double intercept = getInterceptForColumn(col);
+ double slope = getSlopeForColumn(col);
+ for(int i = 0; i < nRows; i++) {
+ c[_colIndexes[col]] *= intercept +
slope * (i + 1);
+ }
+ }
+ }
+ }
+
+ @Override
+ protected double[] preAggSumRows() {
+ double intercept_sum = 0;
+ for(int col = 0; col < getNumCols(); col++)
+ intercept_sum += getInterceptForColumn(col);
+
+ double slope_sum = 0;
+ for(int col = 0; col < getNumCols(); col++)
+ slope_sum += getSlopeForColumn(col);
+
+ return new double[] {intercept_sum, slope_sum};
+ }
+
+ @Override
+ protected double[] preAggSumSqRows() {
+ return null;
+ }
+
+ @Override
+ protected double[] preAggProductRows() {
+ return null;
+ }
+
+ @Override
+ protected double[] preAggBuiltinRows(Builtin builtin) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public long estimateInMemorySize() {
+ return
ColGroupSizes.estimateInMemorySizeLinearFunctional(getNumCols());
+ }
+
+ @Override
+ public CM_COV_Object centralMoment(CMOperator op, int nRows) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int
nRows) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public double getCost(ComputationCostEstimator e, int nRows) {
+ LOG.warn("Cost calculation for LinearFunctional ColGroup is not
precise");
+ final int nCols = getNumCols();
+ // We store 2 tuples in this column group, namely intercepts
and slopes
+ return e.getCost(nRows, nRows, nCols, 2, 1.0);
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(super.toString());
+ sb.append(String.format("\n%15s", " Intercepts: " +
Arrays.toString(getIntercepts())));
+ sb.append(String.format("\n%15s", " Slopes: " +
Arrays.toString(getSlopes())));
+ return sb.toString();
+ }
+
+ public double[] getIntercepts() {
+ double[] intercepts = new double[getNumCols()];
+ for(int col = 0; col < getNumCols(); col++)
+ intercepts[col] = getInterceptForColumn(col);
+
+ return intercepts;
+ }
+
+ public double[] getSlopes() {
+ double[] slopes = new double[getNumCols()];
+ for(int col = 0; col < getNumCols(); col++)
+ slopes[col] = getSlopeForColumn(col);
+
+ return slopes;
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSizes.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSizes.java
index dc69ae7e59..5f273e9242 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSizes.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSizes.java
@@ -107,4 +107,13 @@ public final class ColGroupSizes {
size += MatrixBlock.estimateSizeInMemory(nrRows, nrColumns,
(nrColumns > 1) ? sparsity : 1);
return size;
}
+
+ public static long estimateInMemorySizeLinearFunctional(int nrColumns) {
+ long size = 0;
+ // Since the Object is a col group the overhead from the Memory
Size group is added
+ size += estimateInMemorySizeGroup(nrColumns);
+ size += MemoryEstimates.doubleArrayCost(2L * nrColumns); //
coefficients; per column, we store 2 doubles (slope & intercept)
+ size += 4; // _numRows
+ return size;
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
index 862ee09bb1..0d78b95dc1 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
@@ -351,7 +351,7 @@ public class ColGroupUncompressed extends AColGroup {
}
}
- // @Override
+// @Override
public void leftMultByMatrix(MatrixBlock matrix, MatrixBlock result,
int rl, int ru) {
final MatrixBlock tmpRet = new MatrixBlock(ru - rl,
_data.getNumColumns(), false);
@@ -560,33 +560,7 @@ public class ColGroupUncompressed extends AColGroup {
LibMatrixMult.matrixMult(transposed,
this._data, tmpRet);
}
- final double[] resV = result.getDenseBlockValues();
- if(tmpRet.isEmpty())
- return;
- else if(tmpRet.isInSparseFormat()) {
- SparseBlock sb = tmpRet.getSparseBlock();
- for(int row = 0; row < lhs._colIndexes.length;
row++) {
- if(sb.isEmpty(row))
- continue;
- final int apos = sb.pos(row);
- final int alen = sb.size(row) + apos;
- final int[] aix = sb.indexes(row);
- final double[] avals = sb.values(row);
- final int offRes = lhs._colIndexes[row]
* result.getNumColumns();
- for(int col = apos; col < alen; col++)
- resV[offRes +
_colIndexes[aix[col]]] += avals[col];
- }
- }
- else {
- double[] tmpRetV = tmpRet.getDenseBlockValues();
- for(int row = 0; row < lhs._colIndexes.length;
row++) {
- final int offRes = lhs._colIndexes[row]
* result.getNumColumns();
- final int offTmp =
lhs._colIndexes.length * row;
- for(int col = 0; col <
_colIndexes.length; col++) {
- resV[offRes + _colIndexes[col]]
+= tmpRetV[offTmp + col];
- }
- }
- }
+ ColGroupUtils.copyValuesColGroupMatrixBlocks(lhs, this,
tmpRet, result);
}
else if(lhs instanceof APreAgg) {
// throw new NotImplementedException();
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java
index f33d2dee29..55b7be3243 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java
@@ -19,7 +19,9 @@
package org.apache.sysds.runtime.compress.colgroup;
+import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
public class ColGroupUtils {
@@ -62,4 +64,41 @@ public class ColGroupUtils {
return ret;
}
+	/**
+	 * Accumulate the cells of a compact intermediate product into the full-size result.
+	 *
+	 * Row r of tmpResult is added to result row lhs._colIndexes[r], and column c of tmpResult to result column
+	 * rhs._colIndexes[c]. Cells are added to (+=) rather than overwritten, and nothing is written when tmpResult
+	 * is empty. The target is accessed through result.getDenseBlockValues().
+	 *
+	 * @param lhs       Left ColumnGroup
+	 * @param rhs       Right ColumnGroup
+	 * @param tmpResult The matrix block to move values from
+	 * @param result    The result matrix block to move values to
+	 */
+	protected final static void copyValuesColGroupMatrixBlocks(AColGroup lhs, AColGroup rhs, MatrixBlock tmpResult,
+		MatrixBlock result) {
+		final double[] target = result.getDenseBlockValues();
+		if(tmpResult.isEmpty())
+			return;
+
+		final int[] targetRows = lhs._colIndexes;
+		final int[] targetCols = rhs._colIndexes;
+		final int resultWidth = result.getNumColumns();
+
+		if(tmpResult.isInSparseFormat()) {
+			final SparseBlock sparse = tmpResult.getSparseBlock();
+			for(int r = 0; r < targetRows.length; r++) {
+				if(!sparse.isEmpty(r)) {
+					final int begin = sparse.pos(r);
+					final int end = begin + sparse.size(r);
+					final int[] cols = sparse.indexes(r);
+					final double[] vals = sparse.values(r);
+					final int rowOffset = targetRows[r] * resultWidth;
+					for(int k = begin; k < end; k++)
+						target[rowOffset + targetCols[cols[k]]] += vals[k];
+				}
+			}
+			return;
+		}
+
+		// dense intermediate: row-major with one row per lhs column and one column per rhs column
+		final double[] source = tmpResult.getDenseBlockValues();
+		final int sourceRows = lhs.getNumCols();
+		final int sourceCols = rhs.getNumCols();
+		for(int r = 0, sourceOffset = 0; r < sourceRows; r++, sourceOffset += sourceCols) {
+			final int rowOffset = targetRows[r] * resultWidth;
+			for(int c = 0; c < sourceCols; c++)
+				target[rowOffset + targetCols[c]] += source[sourceOffset + c];
+		}
+	}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/functional/LinearRegression.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/functional/LinearRegression.java
new file mode 100644
index 0000000000..42a7cb7a51
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/functional/LinearRegression.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.colgroup.functional;
+
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.compress.readers.ReaderColumnSelection;
+import org.apache.sysds.runtime.compress.utils.DblArray;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+public class LinearRegression {
+
+	/**
+	 * Fit an independent least-squares line y = beta0 + beta1 * x to each selected column, where x is the 1-based
+	 * row position (1, 2, ..., nRows).
+	 *
+	 * Because x is always 1..nRows, its mean ((nRows + 1) / 2) and its sum of squared deviations
+	 * ((nRows^3 - nRows) / 12) are computed in closed form; only the column sums and the position-weighted column
+	 * sums are read from the data.
+	 *
+	 * @param rawBlock   The matrix block to read the column values from
+	 * @param colIndexes The columns to fit, at least one
+	 * @param transposed If true the rows of rawBlock are read as the columns of the data
+	 * @return An array of length 2 * colIndexes.length: first all intercepts (beta0), then all slopes (beta1)
+	 * @throws DMLCompressionException If there are fewer than 2 rows or no column is given
+	 */
+	public static double[] regressMatrixBlock(MatrixBlock rawBlock, int[] colIndexes, boolean transposed) {
+		final int numRows = transposed ? rawBlock.getNumColumns() : rawBlock.getNumRows();
+
+		if(numRows <= 1)
+			throw new DMLCompressionException("At least 2 data points are required to fit a linear function.");
+		if(colIndexes.length < 1)
+			throw new DMLCompressionException("At least 1 column must be specified for compression.");
+
+		final int numCols = colIndexes.length;
+		final double sxx = (Math.pow(numRows, 3) - numRows) / 12;
+		final double xMean = (double) (numRows + 1) / 2;
+
+		final double[] sumY = new double[numCols];
+		final double[] sumXY = new double[numCols];
+
+		if(numCols != 1) {
+			final ReaderColumnSelection reader = ReaderColumnSelection.createReader(rawBlock, colIndexes, transposed);
+			for(DblArray cells = reader.nextRow(); cells != null; cells = reader.nextRow()) {
+				// the row position is taken from the reader rather than counted locally
+				final int x = reader.getCurrentRowIndex() + 1;
+				final double[] y = cells.getData();
+				for(int c = 0; c < numCols; c++) {
+					sumY[c] += y[c];
+					sumXY[c] += x * y[c];
+				}
+			}
+		}
+		else {
+			final int col = colIndexes[0];
+			for(int r = 0; r < numRows; r++) {
+				final double y = transposed ? rawBlock.getValue(col, r) : rawBlock.getValue(r, col);
+				sumY[0] += y;
+				sumXY[0] += (r + 1) * y;
+			}
+		}
+
+		// the first `colIndexes.length` entries represent the intercepts (beta0)
+		// the second `colIndexes.length` entries represent the slopes (beta1)
+		final double[] coefficients = new double[2 * numCols];
+		for(int c = 0; c < numCols; c++) {
+			final double slope = (-xMean * sumY[c] + sumXY[c]) / sxx;
+			coefficients[numCols + c] = slope;
+			coefficients[c] = (sumY[c] / numRows) - slope * xMean;
+		}
+
+		return coefficients;
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
index 7c330c1a1e..3b0667bd19 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
@@ -189,6 +189,8 @@ public class CompressedSizeInfoColGroup {
private static long getCompressionSize(int numCols, CompressionType ct,
EstimationFactors fact) {
int nv;
switch(ct) {
+ case LinearFunctional:
+ return
ColGroupSizes.estimateInMemorySizeLinearFunctional(numCols);
case DeltaDDC: // TODO add proper extraction
case DDC:
nv = fact.numVals + (fact.numOffs <
fact.numRows ? 1 : 0);
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupLinearFunctionalBase.java
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupLinearFunctionalBase.java
new file mode 100644
index 0000000000..96311c9111
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupLinearFunctionalBase.java
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress.colgroup;
+
+import static org.junit.Assert.fail;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.EnumSet;
+import java.util.Random;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.CompressionSettings;
+import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupLinearFunctional;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorExact;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
+import org.apache.sysds.runtime.compress.utils.Util;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(value = Parameterized.class)
+public abstract class ColGroupLinearFunctionalBase {
+
+ protected static final Log LOG =
LogFactory.getLog(ColGroupLinearFunctionalBase.class.getName());
+ private final static Random random = new Random();
+ protected final AColGroup base;
+ protected final ColGroupLinearFunctional lin;
+ protected final AColGroup baseLeft;
+ protected final int nRowLeft;
+ protected final int nColLeft;
+
+ protected final int nRowRight;
+ protected final int nColRight;
+
+ protected final AColGroup cgLeft;
+ protected final ColGroupUncompressed cgRight;
+ protected final int nRow;
+ protected final double tolerance;
+
+	/**
+	 * Build the parameter sets for the parameterized runner.
+	 *
+	 * @return One Object[] per test configuration, in the argument order of the constructor
+	 */
+	@Parameters
+	public static Collection<Object[]> data() {
+		final ArrayList<Object[]> tests = new ArrayList<>();
+
+		try {
+			addLinCases(tests);
+		}
+		catch(Exception e) {
+			// keep the root cause attached to the failure instead of only printing it to stderr
+			throw new AssertionError("failed constructing tests", e);
+		}
+
+		return tests;
+	}
+
+	/**
+	 * Store the fixtures of one parameterized run after checking that their dimensions fit together.
+	 *
+	 * @param base      Reference column group holding the same values as lin
+	 * @param lin       Linearly compressed column group under test
+	 * @param baseLeft  Reference left operand
+	 * @param cgLeft    Left operand used together with lin
+	 * @param nRowLeft  Number of rows of the left operand, must equal lin.getNumRows()
+	 * @param nColLeft  Number of columns of the left operand
+	 * @param nRowRight Number of rows of the right operand
+	 * @param nColRight Number of columns of the right operand
+	 * @param cgRight   Right operand
+	 * @param tolerance Absolute tolerance used when comparing doubles
+	 */
+	public ColGroupLinearFunctionalBase(AColGroup base, ColGroupLinearFunctional lin, AColGroup baseLeft,
+		AColGroup cgLeft, int nRowLeft, int nColLeft, int nRowRight, int nColRight, ColGroupUncompressed cgRight,
+		double tolerance) {
+		if(lin.getNumCols() != base.getNumCols())
+			fail("Linearly compressed ColGroup and Base ColGroup must have same number of columns");
+
+		if(nRowLeft != lin.getNumRows())
+			fail("Transposed left ColGroup and center ColGroup (`lin`) must have compatible dimensions");
+
+		// Column indexes are 0-based, so index k requires at least k + 1 rows on the right-hand side.
+		// NOTE(review): takes the last entry as the largest index, i.e. assumes the indexes are sorted ascending.
+		final int[] colIndices = lin.getColIndices();
+		if(colIndices[colIndices.length - 1] >= nRowRight)
+			fail("Right ColGroup must have more rows than the largest column index of center ColGroup (`lin`)");
+
+		this.base = base;
+		this.lin = lin;
+		this.baseLeft = baseLeft;
+		this.nRowLeft = nRowLeft;
+		this.nColLeft = nColLeft;
+		this.nRowRight = nRowRight;
+		this.nColRight = nColRight;
+		this.cgLeft = cgLeft;
+		this.cgRight = cgRight;
+		this.tolerance = tolerance;
+		this.nRow = lin.getNumRows();
+	}
+
+	/**
+	 * Append the linear-compression test configurations to the given list.
+	 *
+	 * Each entry is produced by createInitParams; a null left data set makes the center data act as left operand.
+	 *
+	 * @param tests The list to append the parameter sets to
+	 */
+	protected static void addLinCases(ArrayList<Object[]> tests) {
+		final double[][] data = new double[][] {{1, 2, 3, 4, 5}, {-4, 2, 8, 14, 20}};
+		final double[][] dataRight = new double[][] {{1, -2, 23, 7}, {4, 11, -10, -2}};
+		final double[][] dataLeft = new double[][] {{8, 3, 7, 12, -3}, {-1, 8, 4, -2, -2}, {3, 4, 2, 0, -1}};
+		final int[] colIndexesLeft = new int[] {0, 2};
+
+		final double[][] dataLeftCompressed = new double[][] {{8, 4, 0, -4, -8}, {-1, 0, 1, 2, 3}};
+		final int[] colIndexesLeftCompressed = new int[] {0};
+
+		// uncompressed left operand
+		tests.add(
+			createInitParams(data, true, null, dataLeft, true, colIndexesLeft, false, dataRight, true, null, 0.001));
+
+		// linearly compressed left operand
+		tests.add(createInitParams(data, true, null, dataLeftCompressed, true, colIndexesLeftCompressed, true,
+			dataRight, true, null, 0.001));
+
+		// single column; this configuration used to be registered twice with identical arguments
+		tests.add(createInitParams(new double[][] {{1, 2, 3, 4, 5}}, true, null, null, true, null, true, dataRight,
+			true, null, 0.001));
+
+		// only the first two of three columns are compressed
+		tests.add(createInitParams(new double[][] {{1, 2, 3, 4, 5}, {1, 1, 1, 1, 1}, {4, 2, 4, 2, 4}}, true,
+			new int[] {0, 1}, null, true, null, true, dataRight, true, null, 0.001));
+
+		// increasing and decreasing column
+		tests.add(createInitParams(new double[][] {{1, 2, 3, 4, 5}, {-1, -2, -3, -4, -5}}, true, null, null, true,
+			null, true, dataRight, true, null, 0.001));
+
+		final double[][] randomData = generateTestMatrixLinear(80, 100, -100, 100, -1, 1, 42);
+		final double[][] randomDataLeft = generateTestMatrixLinear(80, 50, -100, 100, -1, 1, 43);
+		final double[][] randomDataRight = generateTestMatrixLinear(100, 500, -100, 100, -1, 1, 44);
+
+		tests.add(createInitParams(randomData, false, null, randomDataLeft, false, null, true, randomDataRight, true,
+			null, 0.001));
+	}
+
+	/**
+	 * Assemble the constructor arguments of one parameterized run from raw double arrays.
+	 *
+	 * A null dataLeft falls back to data, and every null column index array is replaced by the result of
+	 * Util.genColsIndices for the column count of its data set.
+	 *
+	 * @param data             Values of the center column group
+	 * @param isTransposed     If true, data holds one array per column
+	 * @param colIndexes       Columns of the center group, or null
+	 * @param dataLeft         Values of the left operand, or null to reuse data
+	 * @param transposedLeft   If true, dataLeft holds one array per column
+	 * @param colIndexesLeft   Columns of the left operand, or null
+	 * @param linCompressLeft  If true the left operand is linearly compressed, otherwise uncompressed
+	 * @param dataRight        Values of the right operand
+	 * @param transposedRight  If true, dataRight holds one array per column
+	 * @param colIndexesRight  Columns of the right operand, or null
+	 * @param tolerance        Tolerance handed through to the test
+	 * @return The arguments in constructor order
+	 */
+	protected static Object[] createInitParams(double[][] data, boolean isTransposed, int[] colIndexes,
+		double[][] dataLeft, boolean transposedLeft, int[] colIndexesLeft, boolean linCompressLeft,
+		double[][] dataRight, boolean transposedRight, int[] colIndexesRight, double tolerance) {
+		final double[][] left = dataLeft != null ? dataLeft : data;
+
+		final int nCol = isTransposed ? data.length : data[0].length;
+		final int nRowLeft = transposedLeft ? left[0].length : left.length;
+		final int nColLeft = transposedLeft ? left.length : left[0].length;
+		final int nRowRight = transposedRight ? dataRight[0].length : dataRight.length;
+		final int nColRight = transposedRight ? dataRight.length : dataRight[0].length;
+
+		final int[] cols = colIndexes != null ? colIndexes : Util.genColsIndices(nCol);
+		final int[] colsLeft = colIndexesLeft != null ? colIndexesLeft : Util.genColsIndices(nColLeft);
+		final int[] colsRight = colIndexesRight != null ? colIndexesRight : Util.genColsIndices(nColRight);
+
+		final AColGroup base = cgUncompressed(data, cols, isTransposed);
+		final AColGroup lin = cgLinCompressed(data, cols, isTransposed);
+		final AColGroup baseLeft = cgUncompressed(left, colsLeft, transposedLeft);
+		final AColGroup cgLeft;
+		if(linCompressLeft)
+			cgLeft = cgLinCompressed(left, colsLeft, transposedLeft);
+		else
+			cgLeft = cgUncompressed(left, colsLeft, transposedLeft);
+		final AColGroup cgRight = cgUncompressed(dataRight, colsRight, transposedRight);
+
+		return new Object[] {base, lin, baseLeft, cgLeft, nRowLeft, nColLeft, nRowRight, nColRight, cgRight,
+			tolerance};
+	}
+
+	/**
+	 * Create a column group of type UNCOMPRESSED from the given values.
+	 *
+	 * @param data         The values, converted to a MatrixBlock as-is
+	 * @param colIndexes   The columns to include
+	 * @param isTransposed Stored as the transposed flag of the compression settings
+	 * @return The first column group produced by the factory
+	 */
+	protected static AColGroup cgUncompressed(double[][] data, int[] colIndexes, boolean isTransposed) {
+		return createColGroup(DataConverter.convertToMatrixBlock(data), colIndexes, isTransposed,
+			AColGroup.CompressionType.UNCOMPRESSED);
+	}
+
+	/**
+	 * Create a LinearFunctional column group over all columns of the given values.
+	 *
+	 * @param data         The values
+	 * @param isTransposed If true, data holds one array per column
+	 * @return The first column group produced by the factory
+	 */
+	protected static AColGroup cgLinCompressed(double[][] data, boolean isTransposed) {
+		final int numCols;
+		if(isTransposed)
+			numCols = data.length;
+		else
+			numCols = data[0].length;
+		final int[] allColumns = Util.genColsIndices(numCols);
+		return cgLinCompressed(data, allColumns, isTransposed);
+	}
+
+	/**
+	 * Create a column group with LinearFunctional as the only valid compression type.
+	 *
+	 * @param data         The values, converted to a MatrixBlock as-is
+	 * @param colIndexes   The columns to include
+	 * @param isTransposed Stored as the transposed flag of the compression settings
+	 * @return The first column group produced by the factory
+	 */
+	protected static AColGroup cgLinCompressed(double[][] data, int[] colIndexes, boolean isTransposed) {
+		return createColGroup(DataConverter.convertToMatrixBlock(data), colIndexes, isTransposed,
+			AColGroup.CompressionType.LinearFunctional);
+	}
+
+	/**
+	 * Compress the selected columns of a matrix block with exactly one allowed compression type.
+	 *
+	 * The settings use a sampling ratio of 1.0 and the exact size estimator; compression runs with parallelism 1.
+	 *
+	 * @param mbt          The matrix block to compress
+	 * @param colIndexes   The columns that form the column group
+	 * @param isTransposed Value assigned to the transposed flag of the settings
+	 * @param cgType       The only compression type marked as valid
+	 * @return The first column group returned by ColGroupFactory.compressColGroups
+	 */
+	public static AColGroup createColGroup(MatrixBlock mbt, int[] colIndexes, boolean isTransposed,
+		AColGroup.CompressionType cgType) {
+		final CompressionSettingsBuilder builder = new CompressionSettingsBuilder().setSamplingRatio(1.0)
+			.setValidCompressions(EnumSet.of(cgType));
+		final CompressionSettings settings = builder.create();
+		settings.transposed = isTransposed;
+
+		final CompressedSizeEstimatorExact estimator = new CompressedSizeEstimatorExact(mbt, settings);
+		final CompressedSizeInfoColGroup groupInfo = estimator.getColGroupInfo(colIndexes);
+		final CompressedSizeInfo sizeInfo = new CompressedSizeInfo(groupInfo);
+		return ColGroupFactory.compressColGroups(mbt, sizeInfo, settings, 1).get(0);
+	}
+
+	/**
+	 * Generate the values intercept + slope * x for x = 1, 2, ..., length.
+	 *
+	 * @param intercept The value of the line at x = 0
+	 * @param slope     The increase per step
+	 * @param length    The number of values to generate
+	 * @return The generated column
+	 */
+	public static double[] generateLinearColumn(double intercept, double slope, int length) {
+		final double[] column = new double[length];
+		int x = 1;
+		while(x <= length) {
+			column[x - 1] = intercept + slope * x;
+			x++;
+		}
+		return column;
+	}
+
+	/**
+	 * Generate a rows x cols matrix in which every column is a linear function of the row position, with
+	 * pseudo-random intercept and slope per column.
+	 *
+	 * @param rows         Number of rows
+	 * @param cols         Number of columns
+	 * @param minIntercept Lower bound of the intercepts
+	 * @param maxIntercept Upper bound of the intercepts
+	 * @param minSlope     Lower bound of the slopes
+	 * @param maxSlope     Upper bound of the slopes
+	 * @param seed         Seed of the generator, equal seeds give equal matrices
+	 * @return The generated matrix indexed as [row][column]
+	 */
+	public static double[][] generateTestMatrixLinear(int rows, int cols, double minIntercept, double maxIntercept,
+		double minSlope, double maxSlope, long seed) {
+		final double[][] interceptsAndSlopes = generateRandomInterceptsSlopes(cols, minIntercept, maxIntercept,
+			minSlope, maxSlope, seed);
+		final double[] intercepts = interceptsAndSlopes[0];
+		final double[] slopes = interceptsAndSlopes[1];
+		return generateTestMatrixLinearColumns(rows, cols, intercepts, slopes);
+	}
+
+	/**
+	 * Draw one intercept and one slope per column, uniformly from the given ranges.
+	 *
+	 * The values are drawn alternately (intercept, slope, intercept, ...) from a generator that is local to this
+	 * call, so the result depends only on the arguments and no state is shared between callers. A generator
+	 * created with the seed produces the same sequence as reseeding an existing one.
+	 *
+	 * @param cols         Number of columns
+	 * @param minIntercept Lower bound of the intercepts
+	 * @param maxIntercept Upper bound of the intercepts
+	 * @param minSlope     Lower bound of the slopes
+	 * @param maxSlope     Upper bound of the slopes
+	 * @param seed         Seed of the generator
+	 * @return A two element array: the intercepts at index 0 and the slopes at index 1
+	 */
+	public static double[][] generateRandomInterceptsSlopes(int cols, double minIntercept, double maxIntercept,
+		double minSlope, double maxSlope, long seed) {
+
+		final double[] intercepts = new double[cols];
+		final double[] slopes = new double[cols];
+
+		final Random generator = new Random(seed);
+		for(int j = 0; j < cols; j++) {
+			intercepts[j] = minIntercept + generator.nextDouble() * (maxIntercept - minIntercept);
+			slopes[j] = minSlope + generator.nextDouble() * (maxSlope - minSlope);
+		}
+
+		return new double[][] {intercepts, slopes};
+	}
+
+	/**
+	 * Generate a rows x cols matrix whose column j holds intercepts[j] + slopes[j] * x for x = 1..rows.
+	 *
+	 * Fails the test when the two coefficient arrays do not both have length cols.
+	 *
+	 * @param rows       Number of rows
+	 * @param cols       Number of columns
+	 * @param intercepts One intercept per column
+	 * @param slopes     One slope per column
+	 * @return The generated matrix indexed as [row][column]
+	 */
+	public static double[][] generateTestMatrixLinearColumns(int rows, int cols, double[] intercepts,
+		double[] slopes) {
+		final boolean lengthsMatch = intercepts.length == slopes.length && intercepts.length == cols;
+		if(!lengthsMatch)
+			fail("Intercepts and slopes array must both have length `cols`");
+
+		final double[][] matrix = new double[rows][cols];
+		for(int c = 0; c < cols; c++) {
+			final double[] column = generateLinearColumn(intercepts[c], slopes[c], rows);
+			// the column is generated contiguously and scattered into the row-major result
+			int r = 0;
+			for(double value : column)
+				matrix[r++][c] = value;
+		}
+
+		return matrix;
+	}
+
+	/**
+	 * Decompress rows 0 to nRow of a column group into a freshly allocated dense block.
+	 *
+	 * @param cg The column group to decompress
+	 * @return The dense values of the block the group was decompressed into
+	 */
+	protected double[] getValues(AColGroup cg) {
+		final MatrixBlock target = new MatrixBlock(nRow, cg.getNumCols(), false);
+		target.allocateDenseBlock();
+		cg.decompressToDenseBlock(target.getDenseBlock(), 0, nRow);
+		final double[] values = target.getDenseBlockValues();
+		return values;
+	}
+
+}
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupLinearFunctionalTest.java
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupLinearFunctionalTest.java
new file mode 100644
index 0000000000..9a37030d13
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupLinearFunctionalTest.java
@@ -0,0 +1,346 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress.colgroup;
+
+import static org.junit.Assert.fail;
+
+import java.util.Arrays;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupLinearFunctional;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
+import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.functionobjects.ReduceAll;
+import org.apache.sysds.runtime.functionobjects.ReduceCol;
+import org.apache.sysds.runtime.functionobjects.ReduceRow;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+public class ColGroupLinearFunctionalTest extends ColGroupLinearFunctionalBase
{
+ protected static final Log LOG =
LogFactory.getLog(ColGroupLinearFunctionalTest.class.getName());
+
+ /**
+  * Create one parameterized test instance; all arguments are passed unchanged to the base class constructor.
+  */
+ public ColGroupLinearFunctionalTest(AColGroup base,
ColGroupLinearFunctional lin, AColGroup baseLeft,
+ AColGroup cgLeft, int nRowLeft, int nColLeft, int nRowRight,
int nColRight, ColGroupUncompressed cgRight,
+ double tolerance) {
+ super(base, lin, baseLeft, cgLeft, nRowLeft, nColLeft,
nRowRight, nColRight, cgRight, tolerance);
+ }
+
+	/**
+	 * Every decompressed value must match between the two groups, and both groups must report that they contain
+	 * each of the reference values. The offending value and cell are part of the assertion message.
+	 */
+	@Test
+	public void testContainsValue() {
+		final double[] linValues = getValues(lin);
+		final double[] baseValues = getValues(base);
+
+		for(int i = 0; i < linValues.length; i++) {
+			Assert.assertEquals("Base ColGroup and linear ColGroup must be initialized with the same values",
+				linValues[i], baseValues[i], tolerance);
+
+			final double value = baseValues[i];
+			Assert.assertTrue("Base ColGroup does not contain value " + value + " of cell " + i,
+				base.containsValue(value));
+			Assert.assertTrue("Linear ColGroup does not contain value " + value + " of cell " + i,
+				lin.containsValue(value));
+		}
+	}
+
+	/** The transpose-self product of the linear group must equal that of the reference group. */
+	@Test
+	public void testTsmm() {
+		final int dim = lin.getNumCols();
+
+		final MatrixBlock expected = new MatrixBlock(dim, dim, false);
+		expected.allocateDenseBlock();
+		base.tsmm(expected, nRow);
+
+		final MatrixBlock actual = new MatrixBlock(dim, dim, false);
+		actual.allocateDenseBlock();
+		lin.tsmm(actual, nRow);
+
+		final double[] expectedValues = expected.getDenseBlockValues();
+		final double[] actualValues = actual.getDenseBlockValues();
+		Assert.assertArrayEquals(expectedValues, actualValues, tolerance);
+	}
+
+	/** Right multiplication by the data of cgRight must give the same values for both groups. */
+	@Test
+	public void testRightMultByMatrix() {
+		final MatrixBlock right = cgRight.getData();
+
+		final MatrixBlock expected = ((ColGroupUncompressed) base.rightMultByMatrix(right)).getData();
+		final MatrixBlock actual = ((ColGroupUncompressed) lin.rightMultByMatrix(right)).getData();
+
+		Assert.assertArrayEquals(expected.getDenseBlockValues(), actual.getDenseBlockValues(), tolerance);
+	}
+
+	/**
+	 * Run the left multiplication check for the compression type of cgLeft; any type other than
+	 * LinearFunctional and UNCOMPRESSED fails the test and is named in the failure message.
+	 */
+	@Test
+	public void testLeftMultByAColGroup() {
+		final AColGroup.CompressionType leftType = cgLeft.getCompType();
+		if(leftType == AColGroup.CompressionType.LinearFunctional)
+			leftMultByAColGroup(true);
+		else if(leftType == AColGroup.CompressionType.UNCOMPRESSED)
+			leftMultByAColGroup(false);
+		else
+			fail("CompressionType " + leftType + " not supported for leftMultByAColGroup");
+	}
+
+	/**
+	 * Compare lin.leftMultByAColGroup(cgLeft) against base.leftMultByAColGroup(baseLeft).
+	 *
+	 * @param compressedLeft Not read by this method
+	 */
+	public void leftMultByAColGroup(boolean compressedLeft) {
+		final MatrixBlock expected = new MatrixBlock(nRowLeft, nColRight, false);
+		expected.allocateDenseBlock();
+		final MatrixBlock actual = new MatrixBlock(nRowLeft, nColRight, false);
+		actual.allocateDenseBlock();
+
+		base.leftMultByAColGroup(baseLeft, expected);
+		lin.leftMultByAColGroup(cgLeft, actual);
+
+		Assert.assertArrayEquals(expected.getDenseBlockValues(), actual.getDenseBlockValues(), tolerance);
+	}
+
+	/** The per-column sums of squares of the linear group must match the reference. */
+	@Test
+	public void testColSumsSq() {
+		final int nCol = base.getNumCols();
+		final double[] expected = new double[nCol];
+		final AggregateUnaryOperator auop = new AggregateUnaryOperator(
+			new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject()), ReduceRow.getReduceRowFnObject());
+
+		if(base instanceof AColGroupCompressed) {
+			final AColGroupCompressed compressedBase = (AColGroupCompressed) base;
+			compressedBase.unaryAggregateOperations(auop, expected, nRow, 0, nRow, compressedBase.preAggRows(auop));
+		}
+		else if(base instanceof ColGroupUncompressed) {
+			// reference computed cell by cell from the uncompressed data
+			final MatrixBlock raw = ((ColGroupUncompressed) base).getData();
+			for(int c = 0; c < nCol; c++)
+				for(int r = 0; r < nRow; r++)
+					expected[c] += Math.pow(raw.getDouble(r, c), 2);
+		}
+		else {
+			fail("Base ColGroup type does not support colSumSq.");
+		}
+
+		final double[] actual = new double[lin.getNumCols()];
+		lin.unaryAggregateOperations(auop, actual, nRow, 0, nRow, lin.preAggRows(auop));
+
+		Assert.assertArrayEquals(expected, actual, tolerance);
+	}
+
+	/** The product over all cells of the linear group must match the reference, up to a relative tolerance. */
+	@Test
+	public void testProduct() {
+		final double[] expected = new double[] {1};
+		final AggregateUnaryOperator auop = new AggregateUnaryOperator(
+			new AggregateOperator(0, Multiply.getMultiplyFnObject()), ReduceAll.getReduceAllFnObject());
+
+		if(base instanceof AColGroupCompressed) {
+			final AColGroupCompressed compressedBase = (AColGroupCompressed) base;
+			compressedBase.unaryAggregateOperations(auop, expected, nRow, 0, nRow, compressedBase.preAggRows(auop));
+		}
+		else if(base instanceof ColGroupUncompressed) {
+			final MatrixBlock raw = ((ColGroupUncompressed) base).getData();
+			final int nCol = base.getNumCols();
+			// column-major traversal, the multiplication order is kept as is
+			double running = expected[0];
+			for(int c = 0; c < nCol; c++)
+				for(int r = 0; r < nRow; r++)
+					running *= raw.getDouble(r, c);
+			expected[0] = running;
+		}
+		else {
+			fail("Base ColGroup type does not support colProduct.");
+		}
+
+		final double[] actual = new double[] {1};
+		lin.unaryAggregateOperations(auop, actual, nRow, 0, nRow, lin.preAggRows(auop));
+
+		// use relative tolerance since products can get very large
+		final double scale = Math.abs(expected[0]);
+		Assert.assertEquals(expected[0], actual[0], tolerance * scale);
+	}
+
+	/** The maximum reported by the linear group must match the reference. */
+	@Test
+	public void testMax() {
+		final double expected = base.getMax();
+		final double actual = lin.getMax();
+		Assert.assertEquals(expected, actual, tolerance);
+	}
+
+	/** The minimum reported by the linear group must match the reference. */
+	@Test
+	public void testMin() {
+		final double expected = base.getMin();
+		final double actual = lin.getMin();
+		Assert.assertEquals(expected, actual, tolerance);
+	}
+
+	/**
+	 * The per-column products of the linear group must match the reference.
+	 *
+	 * The comparison uses a tolerance relative to the largest magnitude among the expected products, so that a
+	 * large negative product widens the tolerance just like a large positive one.
+	 */
+	@Test
+	public void testColProducts() {
+		final int nCol = base.getNumCols();
+		final double[] colProductsExpected = new double[nCol];
+
+		final AggregateOperator aop = new AggregateOperator(0, Multiply.getMultiplyFnObject());
+		final AggregateUnaryOperator auop = new AggregateUnaryOperator(aop, ReduceRow.getReduceRowFnObject());
+
+		if(base instanceof AColGroupCompressed) {
+			final AColGroupCompressed baseComp = (AColGroupCompressed) base;
+			baseComp.unaryAggregateOperations(auop, colProductsExpected, nRow, 0, nRow, baseComp.preAggRows(auop));
+		}
+		else if(base instanceof ColGroupUncompressed) {
+			final MatrixBlock mb = ((ColGroupUncompressed) base).getData();
+
+			for(int j = 0; j < nCol; j++) {
+				double colProduct = 1;
+				for(int i = 0; i < nRow; i++) {
+					colProduct *= mb.getDouble(i, j);
+				}
+				colProductsExpected[j] = colProduct;
+			}
+		}
+		else {
+			fail("Base ColGroup type does not support colProduct.");
+		}
+
+		// the aggregation multiplies into the given array, so it has to start at the neutral element
+		final double[] colProducts = new double[nCol];
+		Arrays.fill(colProducts, 1);
+
+		lin.unaryAggregateOperations(auop, colProducts, nRow, 0, nRow, lin.preAggRows(auop));
+
+		// use relative tolerance since column products can get very large (in either direction)
+		final double maxMagnitude = Arrays.stream(colProductsExpected).map(Math::abs).max().orElse(0);
+		Assert.assertArrayEquals(colProductsExpected, colProducts, tolerance * maxMagnitude);
+	}
+
+	/** The sum of squares over all cells of the linear group must match the reference. */
+	@Test
+	public void testSumSq() {
+		final double[] expected = new double[] {0};
+		final AggregateUnaryOperator auop = new AggregateUnaryOperator(
+			new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject()), ReduceAll.getReduceAllFnObject());
+
+		if(base instanceof AColGroupCompressed) {
+			final AColGroupCompressed compressedBase = (AColGroupCompressed) base;
+			compressedBase.unaryAggregateOperations(auop, expected, nRow, 0, nRow, compressedBase.preAggRows(auop));
+		}
+		else if(base instanceof ColGroupUncompressed) {
+			final MatrixBlock raw = ((ColGroupUncompressed) base).getData();
+			final int nCol = base.getNumCols();
+			// column-major traversal, the summation order is kept as is
+			double running = expected[0];
+			for(int c = 0; c < nCol; c++)
+				for(int r = 0; r < nRow; r++)
+					running += Math.pow(raw.getDouble(r, c), 2);
+			expected[0] = running;
+		}
+		else {
+			fail("Base ColGroup type does not support sumSq.");
+		}
+
+		final double[] actual = new double[] {0};
+		lin.unaryAggregateOperations(auop, actual, nRow, 0, nRow, lin.preAggRows(auop));
+
+		Assert.assertEquals(expected[0], actual[0], tolerance);
+	}
+
+	/** The full sum aggregated by the linear group must equal the sum of the reference column sums. */
+	@Test
+	public void testSum() {
+		final double[] referenceColSums = new double[base.getNumCols()];
+		base.computeColSums(referenceColSums, nRow);
+		final double expected = Arrays.stream(referenceColSums).sum();
+
+		final AggregateUnaryOperator auop = new AggregateUnaryOperator(
+			new AggregateOperator(0, Plus.getPlusFnObject()), ReduceAll.getReduceAllFnObject());
+		final double[] actual = new double[1];
+		lin.unaryAggregateOperations(auop, actual, nRow, 0, nRow, lin.preAggRows(auop));
+
+		Assert.assertEquals(expected, actual[0], tolerance);
+	}
+
+	/** The per-row sums of the linear group must match the reference. */
+	@Test
+	public void testRowSums() {
+		final double[] expected = new double[nRow];
+		final AggregateUnaryOperator auop = new AggregateUnaryOperator(
+			new AggregateOperator(0, Plus.getPlusFnObject()), ReduceCol.getReduceColFnObject());
+
+		if(base instanceof AColGroupCompressed) {
+			final AColGroupCompressed compressedBase = (AColGroupCompressed) base;
+			compressedBase.unaryAggregateOperations(auop, expected, nRow, 0, nRow, compressedBase.preAggRows(auop));
+		}
+		else if(base instanceof ColGroupUncompressed) {
+			// reference computed cell by cell from the uncompressed data
+			final MatrixBlock raw = ((ColGroupUncompressed) base).getData();
+			final int nCol = base.getNumCols();
+			for(int r = 0; r < nRow; r++)
+				for(int c = 0; c < nCol; c++)
+					expected[r] += raw.getDouble(r, c);
+		}
+		else {
+			fail("Base ColGroup type does not support rowSum.");
+		}
+
+		final double[] actual = new double[nRow];
+		lin.unaryAggregateOperations(auop, actual, nRow, 0, nRow, lin.preAggRows(auop));
+
+		Assert.assertArrayEquals(expected, actual, tolerance);
+	}
+
+	/** The column sums computed by the linear group must match the reference. */
+	@Test
+	public void testColSums() {
+		final int nCol = base.getNumCols();
+
+		final double[] expected = new double[nCol];
+		base.computeColSums(expected, nRow);
+
+		final double[] actual = new double[nCol];
+		lin.computeColSums(actual, nRow);
+
+		Assert.assertArrayEquals(expected, actual, tolerance);
+	}
+
+	/**
+	 * Requesting linear compression for a constant column must yield a CONST group, and for an all-zero column
+	 * an EMPTY group.
+	 */
+	@Test
+	public void testColumnGroupConstruction() {
+		final AColGroup fromOnes = cgLinCompressed(new double[][] {{1, 1, 1, 1, 1}}, true);
+		Assert.assertSame(AColGroup.CompressionType.CONST, fromOnes.getCompType());
+
+		final AColGroup fromZeros = cgLinCompressed(new double[][] {{0, 0, 0, 0, 0}}, true);
+		Assert.assertSame(AColGroup.CompressionType.EMPTY, fromZeros.getCompType());
+	}
+
+	/** Decompressing the linear group into a dense block must reproduce the reference values. */
+	@Test
+	public void testDecompressToDenseBlock() {
+		final int nCol = lin.getNumCols();
+
+		final MatrixBlock actual = new MatrixBlock(nRow, nCol, false);
+		actual.allocateDenseBlock();
+		lin.decompressToDenseBlock(actual.getDenseBlock(), 0, nRow);
+
+		final MatrixBlock reference = new MatrixBlock(nRow, nCol, false);
+		reference.allocateDenseBlock();
+		base.decompressToDenseBlock(reference.getDenseBlock(), 0, nRow);
+
+		final double[] referenceValues = reference.getDenseBlockValues();
+		final double[] actualValues = actual.getDenseBlockValues();
+		Assert.assertArrayEquals(referenceValues, actualValues, tolerance);
+	}
+
+}
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/functional/LinearRegressionTests.java
b/src/test/java/org/apache/sysds/test/component/compress/functional/LinearRegressionTests.java
new file mode 100644
index 0000000000..b3b9a12fca
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/compress/functional/LinearRegressionTests.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress.functional;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.stream.DoubleStream;
+
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.compress.colgroup.functional.LinearRegression;
+import org.apache.sysds.runtime.compress.utils.Util;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.DataConverter;
+import
org.apache.sysds.test.component.compress.colgroup.ColGroupLinearFunctionalBase;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/**
+ * Parameterized tests for {@link LinearRegression#regressMatrixBlock}.
+ *
+ * Each test case consists of an input matrix, the column indexes to regress, a transposed flag, and
+ * either the expected coefficient array or the exception the call is expected to throw.
+ */
+@RunWith(value = Parameterized.class)
+public class LinearRegressionTests {
+	/** Input values converted to a MatrixBlock before regression. */
+	protected final double[][] data;
+	/** Column indexes passed to the regression. */
+	protected final int[] colIndexes;
+	/** Transposed flag passed through to the regression. */
+	protected final boolean isTransposed;
+	/** Expected coefficients, or null if the case expects an exception. */
+	protected final double[] expectedCoefficients;
+	/** Expected exception (class and message are compared), or null if the case expects a result. */
+	protected final Exception expectedException;
+
+	/** Absolute tolerance used when comparing computed and expected coefficients. */
+	protected final double EQUALITY_TOLERANCE = 1e-4;
+
+	/**
+	 * Builds the list of parameter sets for the parameterized runner.
+	 *
+	 * @return one Object[] per test case, in constructor argument order
+	 */
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		ArrayList<Object[]> tests = new ArrayList<>();
+		try {
+			addCases(tests);
+		}
+		catch(Exception e) {
+			e.printStackTrace();
+			fail("failed constructing tests");
+		}
+
+		return tests;
+	}
+
+	/**
+	 * Creates one test instance.
+	 *
+	 * @param data                 input values
+	 * @param colIndexes           column indexes to regress
+	 * @param isTransposed         transposed flag forwarded to the regression
+	 * @param expectedCoefficients expected coefficients, or null if an exception is expected
+	 * @param expectedException    expected exception, or null if a result is expected
+	 */
+	public LinearRegressionTests(double[][] data, int[] colIndexes, boolean isTransposed,
+		double[] expectedCoefficients, Exception expectedException) {
+		this.data = data;
+		this.colIndexes = colIndexes;
+		this.isTransposed = isTransposed;
+		this.expectedCoefficients = expectedCoefficients;
+		this.expectedException = expectedException;
+	}
+
+	/**
+	 * Appends all test cases to the given list.
+	 *
+	 * @param tests list that receives one Object[] per case
+	 */
+	protected static void addCases(ArrayList<Object[]> tests) {
+		// small hand-written matrix; the expected array holds one value per selected column for the
+		// first half and one per selected column for the second half (here consistent with
+		// intercepts followed by slopes when rows are numbered from 1)
+		double[][] data = new double[][] {{1, 1, -3, 4, 5}, {2, 2, 3, 4, 5}, {3, 3, 3, 4, 5}};
+		int[] colIndexes = new int[] {0, 1, 3, 4};
+		double[] trueCoefficients = new double[] {0, 0, 4, 5, 1, 1, 0, 0};
+		tests.add(new Object[] {data, colIndexes, false, trueCoefficients, null});
+
+		// expect exception if passing columns with single data points each
+		tests.add(new Object[] {new double[][] {{1, 2, 3}}, Util.genColsIndices(1), false, null,
+			new DMLCompressionException("At least 2 data points are required to fit a linear function.")});
+
+		// expect exception if passing no colIndexes
+		tests.add(new Object[] {new double[][] {{1, 2, 3}, {2, 3, 4}}, Util.genColsIndices(0), false, null,
+			new DMLCompressionException("At least 1 column must be specified for compression.")});
+
+		// random matrix
+		int rows = 100;
+		int cols = 200;
+		// TODO: move generateRandomInterceptsSlopes in an appropriate Util class
+		double[][] randomCoefficients = ColGroupLinearFunctionalBase.generateRandomInterceptsSlopes(cols, -1000, 1000,
+			-20, 20, 42);
+		// TODO: move generateTestMatrixLinearColumns in an appropriate Util class
+		double[][] testData = ColGroupLinearFunctionalBase.generateTestMatrixLinearColumns(rows, cols,
+			randomCoefficients[0], randomCoefficients[1]);
+		tests.add(new Object[] {testData, Util.genColsIndices(cols), false,
+			DoubleStream.concat(Arrays.stream(randomCoefficients[0]), Arrays.stream(randomCoefficients[1])).toArray(),
+			null});
+	}
+
+	/**
+	 * Runs the regression and verifies the outcome against the case's expectation.
+	 *
+	 * If the case expects an exception, the thrown exception's class and message must match and the
+	 * test fails if nothing is thrown. If the case expects a result, any thrown exception fails the
+	 * test with that exception attached as cause (instead of surfacing as a NullPointerException
+	 * from dereferencing the absent expected exception).
+	 */
+	@Test
+	public void testLinearRegression() {
+		MatrixBlock mbt = DataConverter.convertToMatrixBlock(data);
+		final double[] coefficients;
+		try {
+			coefficients = LinearRegression.regressMatrixBlock(mbt, colIndexes, isTransposed);
+		}
+		catch(Exception e) {
+			// no exception was expected: report the real failure and keep its stack trace as cause
+			if(expectedException == null)
+				throw new AssertionError("Unexpected exception during linear regression: " + e, e);
+			assertEquals(expectedException.getClass(), e.getClass());
+			assertEquals(expectedException.getMessage(), e.getMessage());
+			return;
+		}
+
+		// an exception was expected, but the call returned normally
+		if(expectedException != null)
+			fail("Expected " + expectedException.getClass().getName() + " with message \""
+				+ expectedException.getMessage() + "\" but no exception was thrown");
+
+		assertArrayEquals(expectedCoefficients, coefficients, EQUALITY_TOLERANCE);
+	}
+
+}