This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 12367cb9f4 [SYSTEMDS-3828] Parallel Compressed Replace
12367cb9f4 is described below
commit 12367cb9f4ba54d174779945b739a6af2a4968da
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Mon Feb 3 15:03:07 2025 +0100
[SYSTEMDS-3828] Parallel Compressed Replace
This commit adds a parallel kernel for replacing values in
compressed matrices.
Closes #2209
---
.../runtime/compress/CompressedMatrixBlock.java | 52 +++-------
.../sysds/runtime/compress/lib/CLALibReplace.java | 108 +++++++++++++++++++++
.../cp/ParameterizedBuiltinCPInstruction.java | 4 +-
.../sysds/runtime/matrix/data/MatrixBlock.java | 8 +-
.../component/compress/CompressedCustomTests.java | 15 ++-
.../component/compress/CompressedMatrixTest.java | 32 ------
.../component/compress/CompressedTestBase.java | 40 +++++++-
7 files changed, 185 insertions(+), 74 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index a05c076b36..bee86addf2 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -58,6 +58,7 @@ import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
import org.apache.sysds.runtime.compress.lib.CLALibMMChain;
import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
+import org.apache.sysds.runtime.compress.lib.CLALibReplace;
import org.apache.sysds.runtime.compress.lib.CLALibReshape;
import org.apache.sysds.runtime.compress.lib.CLALibRexpand;
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
@@ -307,7 +308,7 @@ public class CompressedMatrixBlock extends MatrixBlock {
* @return The cached decompressed matrix, if it does not exist return
null
*/
public MatrixBlock getCachedDecompressed() {
- if( allowCachingUncompressed && decompressedVersion != null) {
+ if(allowCachingUncompressed && decompressedVersion != null) {
final MatrixBlock mb = decompressedVersion.get();
if(mb != null) {
DMLCompressionStatistics.addDecompressCacheCount();
@@ -401,8 +402,8 @@ public class CompressedMatrixBlock extends MatrixBlock {
long total = baseSizeInMemory();
// take into consideration duplicate dictionaries
Set<IDictionary> dicts = new HashSet<>();
- for(AColGroup grp : _colGroups){
- if(grp instanceof ADictBasedColGroup){
+ for(AColGroup grp : _colGroups) {
+ if(grp instanceof ADictBasedColGroup) {
IDictionary dg = ((ADictBasedColGroup)
grp).getDictionary();
if(dicts.contains(dg))
total -= dg.getInMemorySize();
@@ -576,8 +577,7 @@ public class CompressedMatrixBlock extends MatrixBlock {
}
@Override
- public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock
w, MatrixBlock out, ChainType ctype,
- int k) {
+ public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock
w, MatrixBlock out, ChainType ctype, int k) {
checkMMChain(ctype, v, w);
// multi-threaded MMChain of single uncompressed ColGroup
@@ -629,27 +629,8 @@ public class CompressedMatrixBlock extends MatrixBlock {
}
@Override
- public MatrixBlock replaceOperations(MatrixValue result, double
pattern, double replacement) {
- if(Double.isInfinite(pattern)) {
- LOG.info("Ignoring replace infinite in compression
since it does not contain this value");
- return this;
- }
- else if(isOverlapping()) {
- final String message = "replaceOperations " + pattern +
" -> " + replacement;
- return
getUncompressed(message).replaceOperations(result, pattern, replacement);
- }
- else {
-
- CompressedMatrixBlock ret = new
CompressedMatrixBlock(getNumRows(), getNumColumns());
- final List<AColGroup> prev = getColGroups();
- final int colGroupsLength = prev.size();
- final List<AColGroup> retList = new
ArrayList<>(colGroupsLength);
- for(int i = 0; i < colGroupsLength; i++)
- retList.add(prev.get(i).replace(pattern,
replacement));
- ret.allocateColGroupList(retList);
- ret.recomputeNonZeros();
- return ret;
- }
+ public MatrixBlock replaceOperations(MatrixValue result, double
pattern, double replacement, int k) {
+ return CLALibReplace.replace(this, (MatrixBlock) result,
pattern, replacement, k);
}
@Override
@@ -710,10 +691,10 @@ public class CompressedMatrixBlock extends MatrixBlock {
return false;
}
}
-
+
@Override
public boolean containsValue(double pattern, int k) {
- //TODO parallel contains value
+ // TODO parallel contains value
return containsValue(pattern);
}
@@ -775,8 +756,8 @@ public class CompressedMatrixBlock extends MatrixBlock {
return false;
else if(_colGroups == null || nonZeros == 0)
return true;
- else{
- if(nonZeros == -1){
+ else {
+ if(nonZeros == -1) {
// try to use column groups
for(AColGroup g : _colGroups)
if(!g.isEmpty())
@@ -1177,8 +1158,7 @@ public class CompressedMatrixBlock extends MatrixBlock {
}
@Override
- public void appendRowToSparse(SparseBlock dest, MatrixBlock src, int i,
int rowoffset, int coloffset,
- boolean deep) {
+ public void appendRowToSparse(SparseBlock dest, MatrixBlock src, int i,
int rowoffset, int coloffset, boolean deep) {
throw new DMLCompressionException("Can't append row to
compressed Matrix");
}
@@ -1238,7 +1218,7 @@ public class CompressedMatrixBlock extends MatrixBlock {
}
@Override
- public void denseToSparse(boolean allowCSR, int k){
+ public void denseToSparse(boolean allowCSR, int k) {
// do nothing
}
@@ -1327,13 +1307,13 @@ public class CompressedMatrixBlock extends MatrixBlock {
throw new DMLCompressionException("Invalid to allocate block on
a compressed MatrixBlock");
}
- @Override
+ @Override
public MatrixBlock transpose(int k) {
return getUncompressed().transpose(k);
}
- @Override
- public MatrixBlock reshape(int rows,int cols, boolean byRow){
+ @Override
+ public MatrixBlock reshape(int rows, int cols, boolean byRow) {
return CLALibReshape.reshape(this, rows, cols, byRow);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReplace.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReplace.java
new file mode 100644
index 0000000000..d86026d663
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReplace.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.lib;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+
+public class CLALibReplace {
+ private static final Log LOG =
LogFactory.getLog(CLALibReplace.class.getName());
+
+ private CLALibReplace(){
+ // private constructor
+ }
+
+ public static MatrixBlock replace(CompressedMatrixBlock in, MatrixBlock
out, double pattern, double replacement,
+ int k) {
+ try {
+
+ if(Double.isInfinite(pattern)) {
+ LOG.info("Ignoring replace infinite in
compression since it does not contain this value");
+ return in;
+ }
+ else if(in.isOverlapping()) {
+ final String message = "replaceOperations " +
pattern + " -> " + replacement;
+ return
in.getUncompressed(message).replaceOperations(out, pattern, replacement);
+ }
+ else
+ return replaceNormal(in, out, pattern,
replacement, k);
+ }
+ catch(Exception e) {
+ throw new RuntimeException("Failed replace pattern: " +
pattern + " replacement: " + replacement, e);
+ }
+ }
+
+ private static MatrixBlock replaceNormal(CompressedMatrixBlock in,
MatrixBlock out, double pattern,
+ double replacement, int k) throws Exception {
+ CompressedMatrixBlock ret = new
CompressedMatrixBlock(in.getNumRows(), in.getNumColumns());
+ final List<AColGroup> prev = in.getColGroups();
+ final int colGroupsLength = prev.size();
+ final List<AColGroup> retList = new
ArrayList<>(colGroupsLength);
+
+ if(k <= 1)
+ replaceSingleThread(pattern, replacement, prev,
colGroupsLength, retList);
+ else
+ replaceMultiThread(pattern, replacement, k, prev,
colGroupsLength, retList);
+
+ ret.allocateColGroupList(retList);
+ if(replacement == 0) // have to recompute!
+ ret.recomputeNonZeros();
+ else if(pattern == 0) // always fully dense.
+ ret.setNonZeros(((long) in.getNumRows()) *
in.getNumColumns());
+ else // same nonzeros as input
+ ret.setNonZeros(in.getNonZeros());
+ return ret;
+ }
+
+ private static void replaceMultiThread(double pattern, double
replacement, int k, final List<AColGroup> prev,
+ final int colGroupsLength, final List<AColGroup> retList)
throws InterruptedException, ExecutionException {
+ ExecutorService pool = CommonThreadPool.get(k);
+
+ try {
+ List<Future<AColGroup>> tasks = new
ArrayList<>(colGroupsLength);
+ for(int i = 0; i < colGroupsLength; i++) {
+ final int j = i;
+ tasks.add(pool.submit(() ->
prev.get(j).replace(pattern, replacement)));
+ }
+ for(int i = 0; i < colGroupsLength; i++) {
+ retList.add(tasks.get(i).get());
+ }
+ }
+ finally {
+ pool.shutdown();
+ }
+ }
+
+ private static void replaceSingleThread(double pattern, double
replacement, final List<AColGroup> prev,
+ final int colGroupsLength, final List<AColGroup> retList) {
+ for(int i = 0; i < colGroupsLength; i++)
+ retList.add(prev.get(i).replace(pattern, replacement));
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index 2fb64b170d..119589a303 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -66,6 +66,7 @@ import org.apache.sysds.runtime.transform.tokenize.Tokenizer;
import org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;
import org.apache.sysds.runtime.util.AutoDiff;
import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction {
private static final Log LOG =
LogFactory.getLog(ParameterizedBuiltinCPInstruction.class.getName());
@@ -276,7 +277,8 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
MatrixBlock target = targetObj.acquireRead();
double pattern =
Double.parseDouble(params.get("pattern"));
double replacement =
Double.parseDouble(params.get("replacement"));
- MatrixBlock ret = target.replaceOperations(new
MatrixBlock(), pattern, replacement);
+ MatrixBlock ret = target.replaceOperations(new
MatrixBlock(), pattern, replacement,
+
InfrastructureAnalyzer.getLocalParallelism());
if( ret == target ) //shallow copy (avoid
bufferpool pollution)
ec.setVariable(output.getName(),
targetObj);
else
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index c9086778f0..057811d2db 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -5157,9 +5157,13 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
@Override
- public MatrixBlock replaceOperations(MatrixValue result, double
pattern, double replacement) {
+ public final MatrixBlock replaceOperations(MatrixValue result, double
pattern, double replacement) {
+ return replaceOperations(result, pattern, replacement, 1);
+ }
+
+ public MatrixBlock replaceOperations(MatrixValue result, double
pattern, double replacement, int k) {
MatrixBlock ret = checkType(result);
- return LibMatrixReplace.replaceOperations(this, ret, pattern,
replacement);
+ return LibMatrixReplace.replaceOperations(this, ret, pattern,
replacement, k);
}
public MatrixBlock extractTriangular(MatrixBlock ret, boolean lower,
boolean diag, boolean values) {
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java
b/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java
index 32d62fb16c..886198bb22 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java
@@ -20,6 +20,7 @@
package org.apache.sysds.test.component.compress;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -38,6 +39,7 @@ import
org.apache.sysds.runtime.compress.cost.CostEstimatorBuilder;
import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory;
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
import org.apache.sysds.runtime.compress.lib.CLALibCBind;
+import org.apache.sysds.runtime.compress.lib.CLALibReplace;
import org.apache.sysds.runtime.compress.workload.WTreeRoot;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.TestUtils;
@@ -397,9 +399,18 @@ public class CompressedCustomTests {
TestUtils.compareMatricesBitAvgDistance(m1, m2, 0, 0, "no");
}
+ @Test(expected = Exception.class)
+ public void cbindWithError() {
+ CLALibCBind.cbind(null, new MatrixBlock[] {null}, 0);
+ }
@Test(expected = Exception.class)
- public void cbindWithError(){
- CLALibCBind.cbind(null, new MatrixBlock[]{null}, 0);
+ public void replaceWithError() {
+ CLALibReplace.replace(null, null, 0, 0, 10);
+ }
+
+ @Test
+ public void replaceInf() {
+ assertNull(CLALibReplace.replace(null, null,
Double.POSITIVE_INFINITY, 0, 10));
}
}
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java
b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java
index 5de4967517..d36c6167cf 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java
@@ -329,38 +329,6 @@ public class CompressedMatrixTest extends
AbstractCompressedUnaryTests {
}
}
- @Test
- public void testReplaceNotContainedValue() {
- double v = min - 1;
- if(v != 0)
- testReplace(v);
- }
-
- @Test
- public void testReplace() {
- if(min != 0)
- testReplace(min);
- }
-
- @Test
- public void testReplaceZero() {
- testReplace(0);
- }
-
- private void testReplace(double value) {
- try {
- if(!(cmb instanceof CompressedMatrixBlock) || rows *
cols > 10000)
- return;
- ucRet = mb.replaceOperations(ucRet, value, 1425);
- MatrixBlock ret2 = cmb.replaceOperations(new
MatrixBlock(), value, 1425);
- compareResultMatrices(ucRet, ret2, 1);
- }
- catch(Exception e) {
- e.printStackTrace();
- throw new DMLRuntimeException(e);
- }
- }
-
@Test
public void testCompressedMatrixConstruction() {
try {
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
index 507a2fc663..8692f56b69 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
@@ -1173,7 +1173,7 @@ public abstract class CompressedTestBase extends TestBase
{
}
catch(AssertionError e) {
e.printStackTrace();
- fail("failed Cbind: " + cmb.toString() );
+ fail("failed Cbind: " + cmb.toString());
}
}
@@ -1299,4 +1299,42 @@ public abstract class CompressedTestBase extends
TestBase {
return new
CompressionSettingsBuilder().setSeed(compressionSeed).setMinimumSampleSize(100);
}
+ @Test
+ public void testReplaceNotContainedValue() {
+ double v = min - 1;
+ if(v != 0)
+ testReplace(v, 132);
+ }
+
+ @Test
+ public void testReplace() {
+ if(min != 0)
+ testReplace(min, 323);
+ }
+
+ @Test
+ public void testReplaceWithZero() {
+ if(min != 0)
+ testReplace(min, 0);
+ }
+
+ @Test
+ public void testReplaceZero() {
+ testReplace(0, 3232);
+ }
+
+ private void testReplace(double value, double replacements) {
+ try {
+ if(!(cmb instanceof CompressedMatrixBlock) || rows *
cols > 10000)
+ return;
+ ucRet = mb.replaceOperations(ucRet, value,
replacements, _k);
+ MatrixBlock ret2 = cmb.replaceOperations(new
MatrixBlock(), value, replacements, _k);
+ compareResultMatrices(ucRet, ret2, 1);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ throw new DMLRuntimeException(e);
+ }
+ }
+
}