This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new a6d8bc0975 [MINOR] minor cleanups and optimizations to CLA MM
primitives
a6d8bc0975 is described below
commit a6d8bc09752043f34f9593590f77112c8abfd449
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Mon Feb 3 18:17:07 2025 +0100
[MINOR] minor cleanups and optimizations to CLA MM primitives
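This commit includes a specialized decompressing matrix multiplication (MM)
for DDC column groups whose dictionary is an identity matrix. It also adds a
decompressing right-MM path for column groups with many distinct values,
reuses a per-thread buffer when decompressing the MM-chain intermediate, and
falls back to uncompressed TSMM when there are at least as many column groups
as columns.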
Closes #2210
---
.../runtime/compress/colgroup/ColGroupDDC.java | 26 ++-
.../sysds/runtime/compress/lib/CLALibMMChain.java | 59 +++++-
.../runtime/compress/lib/CLALibRightMultBy.java | 211 +++++++++++++--------
.../sysds/runtime/compress/lib/CLALibTSMM.java | 48 ++---
4 files changed, 230 insertions(+), 114 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
index 86ebb4400e..c1b9c65f22 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
@@ -601,6 +601,30 @@ public class ColGroupDDC extends APreAgg implements
IMapToDataGroup {
@Override
public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret,
int rl, int ru, int nRows, int crl, int cru) {
+ if(_dict instanceof IdentityDictionary)
+ identityRightDecompressingMult(right, ret, rl, ru, crl,
cru);
+ else
+ defaultRightDecompressingMult(right, ret, rl, ru, crl,
cru);
+ }
+
+ private void identityRightDecompressingMult(MatrixBlock right,
MatrixBlock ret, int rl, int ru, int crl, int cru) {
+ final double[] b = right.getDenseBlockValues();
+ final double[] c = ret.getDenseBlockValues();
+ final int jd = right.getNumColumns();
+ final int vLen = 8;
+ final int lenJ = cru - crl;
+ final int end = cru - (lenJ % vLen);
+ for(int i = rl; i < ru; i++) {
+ int k = _data.getIndex(i);
+ final int offOut = i * jd + crl;
+ final double aa = 1;
+ final int k_right = _colIndexes.get(k);
+ vectMM(aa, b, c, end, jd, crl, cru, offOut, k_right,
vLen);
+
+ }
+ }
+
+ private void defaultRightDecompressingMult(MatrixBlock right,
MatrixBlock ret, int rl, int ru, int crl, int cru) {
final double[] a = _dict.getValues();
final double[] b = right.getDenseBlockValues();
final double[] c = ret.getDenseBlockValues();
@@ -930,8 +954,6 @@ public class ColGroupDDC extends APreAgg implements
IMapToDataGroup {
}
}
-
-
private void leftMMIdentityPreAggregateDenseSingleRow(double[] values,
int pos, double[] values2, int pos2, int cl,
int cu) {
IdentityDictionary a = (IdentityDictionary) _dict;
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
index 6207460d3d..d82d58e323 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
@@ -34,6 +34,7 @@ import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.utils.stats.Timing;
/**
* Support compressed MM chain operation to fuse the following cases :
@@ -53,6 +54,9 @@ import
org.apache.sysds.runtime.matrix.operators.BinaryOperator;
public final class CLALibMMChain {
static final Log LOG = LogFactory.getLog(CLALibMMChain.class.getName());
+ /** Reusable cache intermediate double array for temporary
decompression */
+ private static ThreadLocal<double[]> cacheIntermediate = null;
+
private CLALibMMChain() {
// private constructor
}
@@ -87,20 +91,31 @@ public final class CLALibMMChain {
public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock
v, MatrixBlock w, MatrixBlock out,
ChainType ctype, int k) {
+ Timing t = new Timing();
if(x.isEmpty())
return returnEmpty(x, out);
// Morph the columns to efficient types for the operation.
x = filterColGroups(x);
+ double preFilterTime = t.stop();
// Allow overlapping intermediate if the intermediate is
guaranteed not to be overlapping.
final boolean allowOverlap = x.getColGroups().size() == 1 &&
isOverlappingAllowed();
// Right hand side multiplication
- MatrixBlock tmp = CLALibRightMultBy.rightMultByMatrix(x, v,
null, k, allowOverlap);
+ MatrixBlock tmp = CLALibRightMultBy.rightMultByMatrix(x, v,
null, k, true);
+
+ double rmmTime = t.stop();
- if(ctype == ChainType.XtwXv) // Multiply intermediate with
vector if needed
+ if(ctype == ChainType.XtwXv) { // Multiply intermediate with
vector if needed
tmp = binaryMultW(tmp, w, k);
+ }
+
+ if(!allowOverlap && tmp instanceof CompressedMatrixBlock) {
+ tmp = decompressIntermediate((CompressedMatrixBlock)
tmp, k);
+ }
+
+ double decompressTime = t.stop();
if(tmp instanceof CompressedMatrixBlock)
// Compressed Compressed Matrix Multiplication
@@ -109,12 +124,50 @@ public final class CLALibMMChain {
// LMM with Compressed - uncompressed multiplication.
CLALibLeftMultBy.leftMultByMatrixTransposed(x, tmp,
out, k);
+ double lmmTime = t.stop();
if(out.getNumColumns() != 1) // transpose the output to make it
a row output if needed
out = LibMatrixReorg.transposeInPlace(out, k);
+ if(LOG.isDebugEnabled()) {
+ StringBuilder sb = new StringBuilder("\n");
+ sb.append("\nPreFilter Time : " + preFilterTime);
+ sb.append("\nChain RMM : " + rmmTime);
+ sb.append("\nChain RMM Decompress: " + decompressTime);
+ sb.append("\nChain LMM : " + lmmTime);
+ sb.append("\nChain Transpose : " + t.stop());
+ LOG.debug(sb.toString());
+ }
+
return out;
}
+ private static MatrixBlock decompressIntermediate(CompressedMatrixBlock
tmp, int k) {
+ // cacheIntermediate
+ final int rows = tmp.getNumRows();
+ final int cols = tmp.getNumColumns();
+ final int nCells = rows * cols;
+ final double[] tmpArr;
+ if(cacheIntermediate == null) {
+ tmpArr = new double[nCells];
+ cacheIntermediate = new ThreadLocal<>();
+ cacheIntermediate.set(tmpArr);
+ }
+ else {
+ double[] cachedArr = cacheIntermediate.get();
+ if(cachedArr == null || cachedArr.length < nCells) {
+ tmpArr = new double[nCells];
+ cacheIntermediate.set(tmpArr);
+ }
+ else {
+ tmpArr = cachedArr;
+ }
+ }
+
+ final MatrixBlock tmpV = new MatrixBlock(tmp.getNumRows(),
tmp.getNumColumns(), tmpArr);
+ CLALibDecompress.decompressTo((CompressedMatrixBlock) tmp,
tmpV, 0, 0, k, false, true);
+ return tmpV;
+ }
+
private static boolean isOverlappingAllowed() {
return
ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.COMPRESSED_OVERLAPPING);
}
@@ -146,6 +199,8 @@ public final class CLALibMMChain {
final List<AColGroup> groups = x.getColGroups();
final boolean shouldFilter =
CLALibUtils.shouldPreFilter(groups);
if(shouldFilter) {
+ if(CLALibUtils.alreadyPreFiltered(groups,
x.getNumColumns()))
+ return x;
final int nCol = x.getNumColumns();
final double[] constV = new double[nCol];
final List<AColGroup> filteredGroups =
CLALibUtils.filterGroups(groups, constV);
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
index 966051cd8b..f14d6833d9 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
@@ -22,7 +22,6 @@ package org.apache.sysds.runtime.compress.lib;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
@@ -30,23 +29,20 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
-import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
-import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;
public final class CLALibRightMultBy {
private static final Log LOG =
LogFactory.getLog(CLALibRightMultBy.class.getName());
- private CLALibRightMultBy(){
+ private CLALibRightMultBy() {
// private constructor
}
@@ -59,42 +55,104 @@ public final class CLALibRightMultBy {
public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1,
MatrixBlock m2, MatrixBlock ret, int k,
boolean allowOverlap) {
- final int rr = m1.getNumRows();
- final int rc = m2.getNumColumns();
+ try {
+ final int rr = m1.getNumRows();
+ final int rc = m2.getNumColumns();
- if(m1.isEmpty() || m2.isEmpty()) {
- LOG.trace("Empty right multiply");
- if(ret == null)
- ret = new MatrixBlock(rr, rc, 0);
- else
- ret.reset(rr, rc, 0);
- return ret;
+ if(m1.isEmpty() || m2.isEmpty()) {
+ LOG.trace("Empty right multiply");
+ if(ret == null)
+ ret = new MatrixBlock(rr, rc, 0);
+ else
+ ret.reset(rr, rc, 0);
+ return ret;
+ }
+ else {
+ if(m2 instanceof CompressedMatrixBlock)
+ m2 = ((CompressedMatrixBlock)
m2).getUncompressed("Uncompressed right side of right MM", k);
+
+ if(betterIfDecompressed(m1)) {
+ // perform uncompressed multiplication.
+ return decompressingMatrixMult(m1, m2,
k);
+ }
+
+ if(!allowOverlap) {
+ LOG.trace("Overlapping output not
allowed in call to Right MM");
+ return RMM(m1, m2, k);
+ }
+
+ final CompressedMatrixBlock retC =
RMMOverlapping(m1, m2, k);
+
+ if(retC.isEmpty())
+ return retC;
+ else {
+ if(retC.isOverlapping())
+ retC.setNonZeros((long) rr *
rc); // set non zeros to fully dense in case of overlapping.
+ else
+ retC.recomputeNonZeros(k); //
recompute if non overlapping compressed out.
+ return retC;
+ }
+ }
}
- else {
- if(m2 instanceof CompressedMatrixBlock)
- m2 = ((CompressedMatrixBlock)
m2).getUncompressed("Uncompressed right side of right MM", k);
+ catch(Exception e) {
+ throw new RuntimeException("Failed Right MM", e);
+ }
+ }
- if(!allowOverlap) {
- LOG.trace("Overlapping output not allowed in
call to Right MM");
- return RMM(m1, m2, k);
+ private static MatrixBlock
decompressingMatrixMult(CompressedMatrixBlock m1, MatrixBlock m2, int k)
+ throws Exception {
+ final ExecutorService pool = CommonThreadPool.get(k);
+ try {
+ final int rl = m1.getNumRows();
+ final int cr = m2.getNumColumns();
+ // final int rr = m2.getNumRows(); // shared dim
+ final MatrixBlock ret = new MatrixBlock(rl, cr, false);
+ ret.allocateBlock();
+
+ // MatrixBlock m1uc = m1.decompress(k);
+ final List<Future<Long>> tasks = new ArrayList<>();
+ final List<AColGroup> groups = m1.getColGroups();
+ final int blkI = Math.max((int) Math.ceil((double) rl /
k), 16);
+ final int blkJ = blkI > 16 ? cr : Math.max((cr / k),
512); // make it a multiplicative of 8.
+ for(int i = 0; i < rl; i += blkI) {
+ final int startI = i;
+ final int endI = Math.min(i + blkI, rl);
+ for(int j = 0; j < cr; j += blkJ) {
+ final int startJ = j;
+ final int endJ = Math.min(j + blkJ, cr);
+ tasks.add(pool.submit(() -> {
+ for(AColGroup g : groups)
+
g.rightDecompressingMult(m2, ret, startI, endI, rl, startJ, endJ);
+ return
ret.recomputeNonZeros(startI, endI - 1, startJ, endJ - 1);
+ }));
+ }
}
+ long nnz = 0;
+ for(Future<Long> t : tasks)
+ nnz += t.get();
- final CompressedMatrixBlock retC = RMMOverlapping(m1,
m2, k);
+ ret.setNonZeros(nnz);
+ ret.examSparsity();
+ return ret;
+ }
+ finally {
+ pool.shutdown();
+ }
- if(retC.isEmpty())
- return retC;
- else {
- if(retC.isOverlapping())
- retC.setNonZeros((long) rr * rc); //
set non zeros to fully dense in case of overlapping.
- else
- retC.recomputeNonZeros(); // recompute
if non overlapping compressed out.
- return retC;
+ }
+
+ private static boolean betterIfDecompressed(CompressedMatrixBlock m) {
+ for(AColGroup g : m.getColGroups()) {
+ if(!(g instanceof ColGroupUncompressed) &&
g.getNumValues() * 2 >= m.getNumRows()) {
+ return true;
}
}
-
+ return false;
}
- private static CompressedMatrixBlock
RMMOverlapping(CompressedMatrixBlock m1, MatrixBlock that, int k) {
+ private static CompressedMatrixBlock
RMMOverlapping(CompressedMatrixBlock m1, MatrixBlock that, int k)
+ throws Exception {
+
final int rl = m1.getNumRows();
final int cr = that.getNumColumns();
final int rr = that.getNumRows(); // shared dim
@@ -103,13 +161,19 @@ public final class CLALibRightMultBy {
final CompressedMatrixBlock ret = new CompressedMatrixBlock(rl,
cr);
final boolean shouldFilter =
CLALibUtils.shouldPreFilter(colGroups);
+ final double[] constV;
+ final List<AColGroup> filteredGroups;
- double[] constV = shouldFilter ? new double[rr] : null;
- final List<AColGroup> filteredGroups =
CLALibUtils.filterGroups(colGroups, constV);
- if(colGroups == filteredGroups)
+ if(shouldFilter) {
+ constV = new double[rr];
+ filteredGroups = CLALibUtils.filterGroups(colGroups,
constV);
+ }
+ else {
+ filteredGroups = colGroups;
constV = null;
+ }
- if(k == 1)
+ if(k == 1 || filteredGroups.size() == 1)
RMMSingle(filteredGroups, that, retCg);
else
RMMParallel(filteredGroups, that, retCg, k);
@@ -117,7 +181,7 @@ public final class CLALibRightMultBy {
if(constV != null) {
final MatrixBlock cb = new MatrixBlock(1,
constV.length, constV);
final MatrixBlock cbRet = new MatrixBlock(1,
that.getNumColumns(), false);
- LibMatrixMult.matrixMult(cb, that, cbRet);
+ LibMatrixMult.matrixMult(cb, that, cbRet); // mm on row
vector left.
if(!cbRet.isEmpty())
addConstant(cbRet, retCg);
}
@@ -133,35 +197,18 @@ public final class CLALibRightMultBy {
}
private static void addConstant(MatrixBlock constantRow,
List<AColGroup> out) {
- final int nCol = constantRow.getNumColumns();
- int bestCandidate = -1;
- int bestCandidateValuesSize = Integer.MAX_VALUE;
- for(int i = 0; i < out.size(); i++) {
- AColGroup g = out.get(i);
- if(g instanceof ColGroupDDC && g.getNumCols() == nCol
&& g.getNumValues() < bestCandidateValuesSize)
- bestCandidate = i;
- }
-
constantRow.sparseToDense();
-
- if(bestCandidate != -1) {
- AColGroup bc = out.get(bestCandidate);
- out.remove(bestCandidate);
- AColGroup ng = bc.binaryRowOpRight(new
BinaryOperator(Plus.getPlusFnObject(), 1),
- constantRow.getDenseBlockValues(), true);
- out.add(ng);
- }
- else
-
out.add(ColGroupConst.create(constantRow.getDenseBlockValues()));
+
out.add(ColGroupConst.create(constantRow.getDenseBlockValues()));
}
- private static MatrixBlock RMM(CompressedMatrixBlock m1, MatrixBlock
that, int k) {
+ private static MatrixBlock RMM(CompressedMatrixBlock m1, MatrixBlock
that, int k) throws Exception {
+
+ // Timing t = new Timing();
// this version returns a decompressed result.
final int rl = m1.getNumRows();
final int cr = that.getNumColumns();
final int rr = that.getNumRows(); // shared dim
final List<AColGroup> colGroups = m1.getColGroups();
- final List<AColGroup> retCg = new ArrayList<>();
final boolean shouldFilter =
CLALibUtils.shouldPreFilter(colGroups);
@@ -169,11 +216,25 @@ public final class CLALibRightMultBy {
MatrixBlock ret = new MatrixBlock(rl, cr, false);
final Future<MatrixBlock> f = ret.allocateBlockAsync();
- double[] constV = shouldFilter ? new double[rr] : null;
- final List<AColGroup> filteredGroups =
CLALibUtils.filterGroups(colGroups, constV);
- if(colGroups == filteredGroups)
+ double[] constV;
+ final List<AColGroup> filteredGroups;
+
+ if(shouldFilter) {
+ if(CLALibUtils.alreadyPreFiltered(colGroups, cr)) {
+ filteredGroups = new
ArrayList<>(colGroups.size() - 1);
+ constV =
CLALibUtils.filterGroupsAndSplitPreAggOneConst(colGroups, filteredGroups);
+ }
+ else {
+ constV = new double[rr];
+ filteredGroups =
CLALibUtils.filterGroups(colGroups, constV);
+ }
+ }
+ else {
+ filteredGroups = colGroups;
constV = null;
+ }
+ final List<AColGroup> retCg = new
ArrayList<>(filteredGroups.size());
if(k == 1)
RMMSingle(filteredGroups, that, retCg);
else
@@ -186,21 +247,12 @@ public final class CLALibRightMultBy {
constV = mmTemp.isEmpty() ? null :
mmTemp.getDenseBlockValues();
}
- ret = asyncRet(f);
+ ret = f.get();
CLALibDecompress.decompressDense(ret, retCg, constV, 0, k,
true);
return ret;
}
- private static <T> T asyncRet(Future<T> in) {
- try {
- return in.get();
- }
- catch(Exception e) {
- throw new DMLRuntimeException(e);
- }
- }
-
private static boolean RMMSingle(List<AColGroup> filteredGroups,
MatrixBlock that, List<AColGroup> retCg) {
boolean containsNull = false;
final IColIndex allCols =
ColIndexFactory.create(that.getNumColumns());
@@ -214,7 +266,8 @@ public final class CLALibRightMultBy {
return containsNull;
}
- private static boolean RMMParallel(List<AColGroup> filteredGroups,
MatrixBlock that, List<AColGroup> retCg, int k) {
+ private static boolean RMMParallel(List<AColGroup> filteredGroups,
MatrixBlock that, List<AColGroup> retCg, int k)
+ throws Exception {
final ExecutorService pool = CommonThreadPool.get(k);
boolean containsNull = false;
try {
@@ -230,10 +283,7 @@ public final class CLALibRightMultBy {
containsNull = true;
}
}
- catch(InterruptedException | ExecutionException e) {
- throw new DMLRuntimeException(e);
- }
- finally{
+ finally {
pool.shutdown();
}
return containsNull;
@@ -253,13 +303,8 @@ public final class CLALibRightMultBy {
}
@Override
- public AColGroup call() {
- try {
- return _colGroup.rightMultByMatrix(_b,
_allCols, _k);
- }
- catch(Exception e) {
- throw new DMLRuntimeException(e);
- }
+ public AColGroup call() throws Exception {
+ return _colGroup.rightMultByMatrix(_b, _allCols, _k);
}
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java
index 5f5e63c9ac..a1d47a9b15 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java
@@ -52,8 +52,15 @@ public final class CLALibTSMM {
* @param k The parallelization degree allowed
*/
public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb,
MatrixBlock ret, int k) {
+
final List<AColGroup> groups = cmb.getColGroups();
+
final int numColumns = cmb.getNumColumns();
+ if(groups.size() >= numColumns) {
+ MatrixBlock m = cmb.getUncompressed("TSMM to many
columngroups", k);
+ LibMatrixMult.matrixMultTransposeSelf(m, ret, true, k);
+ return;
+ }
final int numRows = cmb.getNumRows();
final boolean shouldFilter =
CLALibUtils.shouldPreFilter(groups);
final boolean overlapping = cmb.isOverlapping();
@@ -63,8 +70,10 @@ public final class CLALibTSMM {
tsmmColGroups(filteredGroups, ret, numRows,
overlapping, k);
addCorrectionLayer(filteredGroups, ret, numRows,
numColumns, constV);
}
- else
+ else {
+
tsmmColGroups(groups, ret, numRows, overlapping, k);
+ }
ret.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(ret));
ret.examSparsity();
@@ -77,10 +86,7 @@ public final class CLALibTSMM {
addCorrectionLayer(constV, filteredColSum, nRows, retV);
}
- public static void addCorrectionLayer(double[] constV, double[]
correctedSum, int nRow, double[] ret) {
- outerProductUpperTriangle(constV, correctedSum, ret);
- outerProductUpperTriangleWithScaling(correctedSum, constV,
nRow, ret);
- }
+
private static void tsmmColGroups(List<AColGroup> groups, MatrixBlock
ret, int nRows, boolean overlapping, int k) {
if(k <= 1)
@@ -108,7 +114,7 @@ public final class CLALibTSMM {
}
private static void tsmmColGroupsMultiThread(List<AColGroup> groups,
MatrixBlock ret, int nRows, int k) {
- final ExecutorService pool = CommonThreadPool.get(k);
+ final ExecutorService pool = CommonThreadPool.get(k);
try {
final ArrayList<Callable<MatrixBlock>> tasks = new
ArrayList<>((groups.size() * (1 + groups.size())) / 2);
for(int i = 0; i < groups.size(); i++) {
@@ -123,31 +129,19 @@ public final class CLALibTSMM {
catch(InterruptedException | ExecutionException e) {
throw new DMLRuntimeException(e);
}
- finally{
+ finally {
pool.shutdown();
}
}
- private static void outerProductUpperTriangle(final double[]
leftRowSum, final double[] rightColumnSum,
- final double[] result) {
- for(int row = 0; row < leftRowSum.length; row++) {
- final int offOut = rightColumnSum.length * row;
- final double vLeft = leftRowSum[row];
- for(int col = row; col < rightColumnSum.length; col++) {
- result[offOut + col] += vLeft *
rightColumnSum[col];
- }
- }
- }
-
- private static void outerProductUpperTriangleWithScaling(final double[]
leftRowSum, final double[] rightColumnSum,
- final int scale, final double[] result) {
- // note this scaling is a bit different since it is
encapsulating two scalar multiplications via an addition in
- // the outer loop.
- for(int row = 0; row < leftRowSum.length; row++) {
- final int offOut = rightColumnSum.length * row;
- final double vLeft = leftRowSum[row] +
rightColumnSum[row] * scale;
- for(int col = row; col < rightColumnSum.length; col++) {
- result[offOut + col] += vLeft *
rightColumnSum[col];
+ public static void addCorrectionLayer(double[] constV, double[]
filteredColSum, int nRow, double[] ret) {
+ final int nColRow = constV.length;
+ for(int row = 0; row < nColRow; row++){
+ int offOut = nColRow * row;
+ final double v1l = constV[row];
+ final double v2l = filteredColSum[row] + constV[row] *
nRow;
+ for(int col = row; col < nColRow; col++){
+ ret[offOut + col] += v1l * filteredColSum[col]
+ v2l * constV[col];
}
}
}