This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 0c7e0468e9 [SYSTEMDS-3904] New OOC matrix-vector multiplication 0c7e0468e9 is described below commit 0c7e0468e9a9ad43ebc8a80d16d42e71f03be2e0 Author: Janardhan Pulivarthi <j...@protonmail.com> AuthorDate: Sun Aug 10 10:35:00 2025 +0200 [SYSTEMDS-3904] New OOC matrix-vector multiplication Closes #2305. --- .../java/org/apache/sysds/hops/AggBinaryOp.java | 8 ++ .../runtime/instructions/OOCInstructionParser.java | 4 + .../ooc/MatrixVectorBinaryOOCInstruction.java | 142 +++++++++++++++++++++ .../runtime/instructions/ooc/OOCInstruction.java | 2 +- .../ooc/MatrixVectorBinaryMultiplicationTest.java | 132 +++++++++++++++++++ .../functions/ooc/MatrixVectorMultiplication.dml | 29 +++++ 6 files changed, 316 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java index 0be3143206..fb20cc41d0 100644 --- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java @@ -240,6 +240,14 @@ public class AggBinaryOp extends MultiThreadedHop { default: throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing SPARK lops."); } + } else if (et == ExecType.OOC) { + Lop in1 = getInput().get(0).constructLops(); + Lop in2 = getInput().get(1).constructLops(); + MatMultCP matmult = new MatMultCP(in1, in2, getDataType(), getValueType(), + et, OptimizerUtils.getConstrainedNumThreads(_maxNumThreads)); + setOutputDimensions(matmult); + setLineNumbers(matmult); + setLops(matmult); } } else throw new HopsException(this.printErrorLocation() + "Invalid operation in AggBinary Hop, aggBin(" + innerOp + "," + outerOp + ") while constructing lops."); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index 
a744b5d813..9b1165b819 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -28,6 +28,7 @@ import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction; public class OOCInstructionParser extends InstructionParser { protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName()); @@ -56,6 +57,9 @@ public class OOCInstructionParser extends InstructionParser { return UnaryOOCInstruction.parseInstruction(str); case Binary: return BinaryOOCInstruction.parseInstruction(str); + case AggregateBinary: + case MAPMM: + return MatrixVectorBinaryOOCInstruction.parseInstruction(str); default: throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java new file mode 100644 index 0000000000..a36dc7c885 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutorService; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.util.CommonThreadPool; + +public class MatrixVectorBinaryOOCInstruction extends ComputationOOCInstruction { + + + protected MatrixVectorBinaryOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) { + super(type, op, in1, in2, out, opcode, istr); + } + + 
public static MatrixVectorBinaryOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(parts, 4); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); // the larger matrix (streamed) + CPOperand in2 = new CPOperand(parts[2]); // the small vector (in-memory) + CPOperand out = new CPOperand(parts[3]); + + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + AggregateBinaryOperator ba = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); + + return new MatrixVectorBinaryOOCInstruction(OOCType.MAPMM, ba, in1, in2, out, opcode, str); + } + + @Override + public void processInstruction( ExecutionContext ec ) { + // 1. Identify the inputs + MatrixObject min = ec.getMatrixObject(input1); // big matrix + MatrixBlock vin = ec.getMatrixObject(input2) + .acquireReadAndRelease(); // in-memory vector + + // 2. Pre-partition the in-memory vector into a hashmap + HashMap<Long, MatrixBlock> partitionedVector = new HashMap<>(); + int blksize = vin.getDataCharacteristics().getBlocksize(); + if (blksize < 0) + blksize = ConfigurationManager.getBlocksize(); + for (int i=0; i<vin.getNumRows(); i+=blksize) { + long key = (long) (i/blksize) + 1; // the key starts at 1 + int end_row = Math.min(i + blksize, vin.getNumRows()); + MatrixBlock vectorSlice = vin.slice(i, end_row - 1); + partitionedVector.put(key, vectorSlice); + } + + LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle(); + LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>(); + BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); + ec.getMatrixObject(output).setStreamHandle(qOut); + + ExecutorService pool = CommonThreadPool.get(); + try { + // Core logic: background thread + pool.submit(() -> { + IndexedMatrixValue tmp = null; + try { + HashMap<Long, MatrixBlock> partialResults = new HashMap<>(); + while((tmp = 
qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue(); + long rowIndex = tmp.getIndexes().getRowIndex(); + long colIndex = tmp.getIndexes().getColumnIndex(); + MatrixBlock vectorSlice = partitionedVector.get(colIndex); + + // Now, call the operation with the correct, specific operator. + MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations( + matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr); + + // for single column block, no aggregation needed + if( min.getNumColumns() <= min.getBlocksize() ) { + qOut.enqueueTask(new IndexedMatrixValue(tmp.getIndexes(), partialResult)); + } + else { + MatrixBlock currAgg = partialResults.get(rowIndex); + if (currAgg == null) + partialResults.put(rowIndex, partialResult); + else + currAgg.binaryOperationsInPlace(plus, partialResult); + } + } + + // emit aggregated blocks + if( min.getNumColumns() > min.getBlocksize() ) { + for (Map.Entry<Long, MatrixBlock> entry : partialResults.entrySet()) { + MatrixIndexes outIndexes = new MatrixIndexes(entry.getKey(), 1L); + qOut.enqueueTask(new IndexedMatrixValue(outIndexes, entry.getValue())); + } + } + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + finally { + qOut.closeInput(); + } + }); + } catch (Exception e) { + throw new DMLRuntimeException(e); + } + finally { + pool.shutdown(); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index db3d2da8b1..d3c2dfcbd7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -30,7 +30,7 @@ public abstract class OOCInstruction extends Instruction { protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName()); public enum OOCType { - Reblock, AggregateUnary, Binary, 
Unary + Reblock, AggregateUnary, Binary, Unary, MAPMM, AggregateBinary } protected final OOCInstruction.OOCType _ooctype; diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/MatrixVectorBinaryMultiplicationTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/MatrixVectorBinaryMultiplicationTest.java new file mode 100644 index 0000000000..de4e7e9912 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/MatrixVectorBinaryMultiplicationTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; + +public class MatrixVectorBinaryMultiplicationTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "MatrixVectorMultiplication"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + MatrixVectorBinaryMultiplicationTest.class.getSimpleName() + "/"; + private final static double eps = 1e-10; + private static final String INPUT_NAME = "X"; + private static final String INPUT_NAME2 = "v"; + private static final String OUTPUT_NAME = "res"; + + private final static int rows = 5000; + private final static int cols_wide = 2000; + private final static int cols_skinny = 500; + + private final static double sparsity1 = 0.7; + private final static double sparsity2 = 0.1; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void testMVBinaryMultiplication1() { + runMatrixVectorMultiplicationTest(cols_wide, false); + } + + @Test + public void testMVBinaryMultiplication2() { + runMatrixVectorMultiplicationTest(cols_skinny, false); + } + + private void runMatrixVectorMultiplicationTest(int cols, boolean sparse ) + { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try 
+ { + getAndLoadTestConfiguration(TEST_NAME1); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[]{"-explain", "-stats", "-ooc", + "-args", input(INPUT_NAME), input(INPUT_NAME2), output(OUTPUT_NAME)}; + + // 1. Generate the data in-memory as MatrixBlock objects + double[][] A_data = getRandomMatrix(rows, cols, 0, 1, sparse?sparsity2:sparsity1, 10); + double[][] x_data = getRandomMatrix(cols, 1, 0, 1, 1.0, 10); + + // 2. Convert the double arrays to MatrixBlock objects + MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data); + MatrixBlock x_mb = DataConverter.convertToMatrixBlock(x_data); + + // 3. Create a binary matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + + // 4. Write matrix A to a binary SequenceFile + writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, cols, 1000, A_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY); + + // 5. 
Write vector x to a binary SequenceFile + writer.writeMatrixToHDFS(x_mb, input(INPUT_NAME2), cols, 1, 1000, x_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME2 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(cols, 1, 1000, x_mb.getNonZeros()), Types.FileFormat.BINARY); + + boolean exceptionExpected = false; + runTest(true, exceptionExpected, null, -1); + + double[][] C1 = readMatrix(output(OUTPUT_NAME), Types.FileFormat.BINARY, rows, cols, 1000, 1000); + double result = 0.0; + for(int i = 0; i < rows; i++) { // verify the results with Java + double expected = 0.0; + for(int j = 0; j < cols; j++) { + expected += A_mb.get(i, j) * x_mb.get(j,0); + } + result = C1[i][0]; + Assert.assertEquals(expected, result, eps); + } + } + catch (IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } + + private static double[][] readMatrix(String fname, Types.FileFormat fmt, long rows, long cols, int brows, int bcols ) + throws IOException + { + MatrixBlock mb = DataConverter.readMatrixFromHDFS(fname, fmt, rows, cols, brows, bcols); + double[][] C = DataConverter.convertToDoubleMatrix(mb); + return C; + } +} diff --git a/src/test/scripts/functions/ooc/MatrixVectorMultiplication.dml b/src/test/scripts/functions/ooc/MatrixVectorMultiplication.dml new file mode 100644 index 0000000000..c72db07780 --- /dev/null +++ b/src/test/scripts/functions/ooc/MatrixVectorMultiplication.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read input matrix and vector from command line args +X = read($1); +v = read($2); + +# Operation under test +res = X %*% v; + +write(res, $3, format="binary")