This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 7440ef7ff1 [SYSTEMDS-3892] Initial out-of-core base instruction and parser 7440ef7ff1 is described below commit 7440ef7ff129542af8811854df9d8debec802dd1 Author: Jessica Priebe <jessica.pri...@web.de> AuthorDate: Sat Jul 12 01:36:58 2025 +0200 [SYSTEMDS-3892] Initial out-of-core base instruction and parser Closes #2289. --- src/main/java/org/apache/sysds/api/DMLOptions.java | 6 ++ src/main/java/org/apache/sysds/api/DMLScript.java | 3 + src/main/java/org/apache/sysds/common/Types.java | 4 +- src/main/java/org/apache/sysds/hops/Hop.java | 3 + .../sysds/runtime/instructions/Instruction.java | 6 +- .../runtime/instructions/InstructionParser.java | 5 + .../runtime/instructions/InstructionUtils.java | 5 + .../runtime/instructions/OOCInstructionParser.java | 56 +++++++++++ .../runtime/instructions/ooc/OOCInstruction.java | 85 +++++++++++++++++ .../functions/ooc/SumScalarMultiplicationTest.java | 103 +++++++++++++++++++++ .../functions/ooc/SumScalarMultiplication.dml | 24 +++++ 11 files changed, 298 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java index 763ac7b938..97d5f54a4a 100644 --- a/src/main/java/org/apache/sysds/api/DMLOptions.java +++ b/src/main/java/org/apache/sysds/api/DMLOptions.java @@ -65,6 +65,7 @@ public class DMLOptions { public ExecMode execMode = OptimizerUtils.getDefaultExecutionMode(); // Execution mode standalone, MR, Spark or a hybrid public boolean gpu = false; // Whether to use the GPU public boolean forceGPU = false; // Whether to ignore memory & estimates and always use the GPU + public boolean ooc = false; // Whether to use the OOC backend public boolean debug = false; // to go into debug mode to be able to step through a program public String filePath = null; // path to script public String script = null; // the script itself @@ -109,6 +110,7 @@ public class DMLOptions { ", 
execMode=" + execMode + ", gpu=" + gpu + ", forceGPU=" + forceGPU + + ", ooc=" + ooc + ", debug=" + debug + ", filePath='" + filePath + '\'' + ", script='" + script + '\'' + @@ -182,6 +184,7 @@ public class DMLOptions { } } } + dmlOptions.ooc = line.hasOption("ooc"); if (line.hasOption("exec")){ String execMode = line.getOptionValue("exec"); if (execMode.equalsIgnoreCase("singlenode")) dmlOptions.execMode = ExecMode.SINGLE_NODE; @@ -388,6 +391,8 @@ public class DMLOptions { Option gpuOpt = OptionBuilder.withArgName("force") .withDescription("uses CUDA instructions when reasonable; set <force> option to skip conservative memory estimates and use GPU wherever possible; default off") .hasOptionalArg().create("gpu"); + Option oocOpt = OptionBuilder.withDescription("uses OOC backend") + .create("ooc"); Option debugOpt = OptionBuilder.withDescription("runs in debug mode; default off") .create("debug"); Option pythonOpt = OptionBuilder @@ -441,6 +446,7 @@ public class DMLOptions { options.addOption(explainOpt); options.addOption(execOpt); options.addOption(gpuOpt); + options.addOption(oocOpt); options.addOption(debugOpt); options.addOption(lineageOpt); options.addOption(fedOpt); diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index d6853891e2..2bc8d3b816 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -147,6 +147,8 @@ public class DMLScript public static boolean FORCE_ACCELERATOR = DMLOptions.defaultOptions.forceGPU; // Enable synchronizing GPU after every instruction public static boolean SYNCHRONIZE_GPU = true; + // Set OOC backend + public static boolean USE_OOC = DMLOptions.defaultOptions.ooc; // Enable eager CUDA free on rmvar public static boolean EAGER_CUDA_FREE = false; @@ -266,6 +268,7 @@ public class DMLScript JMLC_MEM_STATISTICS = dmlOptions.memStats; USE_ACCELERATOR = dmlOptions.gpu; FORCE_ACCELERATOR = 
dmlOptions.forceGPU; + USE_OOC = dmlOptions.ooc; EXPLAIN = dmlOptions.explainType; EXEC_MODE = dmlOptions.execMode; LINEAGE = dmlOptions.lineage; diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index e69ad375b2..c5ad9ded2b 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -32,7 +32,7 @@ public interface Types { * Execution mode for entire script. This setting specify which {@link ExecType}s are allowed. */ public enum ExecMode { - /** Execute all operations in {@link ExecType#CP} and if available {@link ExecType#GPU} */ + /** Execute all operations in {@link ExecType#CP}, {@link ExecType#OOC} and if available {@link ExecType#GPU} */ SINGLE_NODE, /** * The default and encouraged ExecMode. Execute operations while leveraging all available options: @@ -58,6 +58,8 @@ public interface Types { GPU, /** FED: indicate that the instruction should be executed as a Federated instruction */ FED, + /** Out of Core: indicate that the operation should be executed out of core. 
*/ + OOC, /** invalid is used for debugging or if it is undecided where the current instruction should be executed */ INVALID } diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index b32a1a74aa..68e5bc94c0 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -263,6 +263,9 @@ public abstract class Hop implements ParseInfo { if(_etypeForced != ExecType.CP && _etypeForced != ExecType.GPU) _etypeForced = ExecType.CP; } + else if (DMLScript.USE_OOC){ + _etypeForced = ExecType.OOC; + } else { // enabled with -exec singlenode option _etypeForced = ExecType.CP; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java index 50238aadd8..6d27df34ed 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java @@ -35,7 +35,8 @@ public abstract class Instruction BREAKPOINT, SPARK, GPU, - FEDERATED + FEDERATED, + OUT_OF_CORE } protected static final Log LOG = LogFactory.getLog(Instruction.class.getName()); @@ -53,6 +54,7 @@ public abstract class Instruction public static final String SP_INST_PREFIX = "sp_"; public static final String GPU_INST_PREFIX = "gpu_"; public static final String FEDERATED_INST_PREFIX = "fed_"; + public static final String OOC_INST_PREFIX = "ooc_"; //basic instruction meta data protected String instString = null; @@ -184,6 +186,8 @@ public abstract class Instruction extendedOpcode = GPU_INST_PREFIX + getOpcode(); else if( getType() == IType.FEDERATED) extendedOpcode = FEDERATED_INST_PREFIX + getOpcode(); + else if( getType() == IType.OUT_OF_CORE) + extendedOpcode = OOC_INST_PREFIX + getOpcode(); else extendedOpcode = getOpcode(); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionParser.java index fbe7c1d757..85ab05cf34 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionParser.java @@ -53,6 +53,11 @@ public class InstructionParser if( fedtype == null ) throw new DMLRuntimeException("Unknown FEDERATED instruction: " + str); return FEDInstructionParser.parseSingleInstruction (fedtype, str); + case OOC: + InstructionType ooctype = InstructionUtils.getOOCType(str); + if( ooctype == null ) + throw new DMLRuntimeException("Unknown OOC instruction: " + str); + return OOCInstructionParser.parseSingleInstruction (ooctype, str); default: throw new DMLRuntimeException("Unknown execution type in instruction: " + str); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java index 3c1cf9d775..e244e9cd27 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java @@ -281,6 +281,11 @@ public class InstructionUtils { return Opcodes.getTypeByOpcode(op, Types.ExecType.FED); } + public static InstructionType getOOCType(String str) { + String op = getOpCode(str); + return Opcodes.getTypeByOpcode(op, Types.ExecType.OOC); + } + public static boolean isBuiltinFunction( String opcode ) { Builtin.BuiltinCode bfc = Builtin.String2BuiltinCode.get(opcode); return (bfc != null); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java new file mode 100644 index 0000000000..191976f094 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.InstructionType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; + +public class OOCInstructionParser extends InstructionParser { + protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName()); + + public static OOCInstruction parseSingleInstruction(String str) { + if(str == null || str.isEmpty()) + return null; + InstructionType ooctype = InstructionUtils.getOOCType(str); + if(ooctype == null) + throw new DMLRuntimeException("Unable derive ooctype for instruction: " + str); + OOCInstruction oocinst = parseSingleInstruction(ooctype, str); + if(oocinst == null) + throw new DMLRuntimeException("Unable to parse instruction: " + str); + return oocinst; + } + + public static OOCInstruction parseSingleInstruction(InstructionType ooctype, String str) { + if(str == null || str.isEmpty()) + return null; + switch(ooctype) { + + // TODO: + case AggregateUnary: + case Binary: + + default: + throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype); + } + } +} diff 
--git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java new file mode 100644 index 0000000000..83cc972135 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.matrix.operators.Operator; + +public abstract class OOCInstruction extends Instruction { + protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName()); + + public enum OOCType { + AggregateUnary, Binary + } + + protected final OOCInstruction.OOCType _ooctype; + protected final boolean _requiresLabelUpdate; + + protected OOCInstruction(OOCInstruction.OOCType type, String opcode, String istr) { + this(type, null, opcode, istr); + } + + protected OOCInstruction(OOCInstruction.OOCType type, Operator op, String opcode, String istr) { + super(op); + _ooctype = type; + instString = istr; + instOpcode = opcode; + + _requiresLabelUpdate = super.requiresLabelUpdate(); + } + + @Override + public IType getType() { + return IType.OUT_OF_CORE; + } + + public OOCInstruction.OOCType getOOCInstructionType() { + return _ooctype; + } + + @Override + public boolean requiresLabelUpdate() { + return _requiresLabelUpdate; + } + + @Override + public String getGraphString() { + return getOpcode(); + } + + @Override + public Instruction preprocessInstruction(ExecutionContext ec) { + // TODO + return super.preprocessInstruction(ec); + } + + @Override + public abstract void processInstruction(ExecutionContext ec); + + @Override + public void postprocessInstruction(ExecutionContext ec) { + if(DMLScript.LINEAGE_DEBUGGER) + ec.maintainLineageDebuggerInfo(this); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java new file mode 100644 index 0000000000..d9d42c913b --- /dev/null +++ 
b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; + +import java.util.HashMap; + +public class SumScalarMultiplicationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "SumScalarMultiplication"; + private static final String TEST_DIR = "functions/ooc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + SumScalarMultiplicationTest.class.getSimpleName() + "/"; + private static final String INPUT_NAME = "X"; + private static final String OUTPUT_NAME = "res"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME); + addTestConfiguration(TEST_NAME, config); + } + + /** + * Test the sum of scalar multiplication, "sum(X*7)", with OOC backend. + */ + @Test + @Ignore + public void testSumScalarMult() { + + Types.ExecMode platformOld = rtplatform; + rtplatform = Types.ExecMode.SINGLE_NODE; + + try { + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME), output(OUTPUT_NAME)}; + + int rows = 3; + int cols = 4; + double sparsity = 0.8; + + double[][] X = getRandomMatrix(rows, cols, -1, 1, sparsity, 7); + writeInputMatrixWithMTD(INPUT_NAME, X, true); + + runTest(true, false, null, -1); + + HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir(OUTPUT_NAME); + // only one entry + Double result = dmlfile.get(new MatrixValue.CellIndex(1, 1)); + + double expected = 0.0; + for(int i = 0; i < rows; i++) { + for(int j = 0; j < cols; j++) { + expected += X[i][j] * 7; + } + } + + Assert.assertEquals(expected, result, 1e-10); + + String prefix = Instruction.OOC_INST_PREFIX; + + boolean usedOOCMult = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT); + Assert.assertTrue("OOC wasn't used for MULT", usedOOCMult); + + boolean usedOOCSum = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.UAKP); + Assert.assertTrue("OOC wasn't used for SUM", usedOOCSum); + + } + finally { + // reset + rtplatform = platformOld; + } + } +} diff --git a/src/test/scripts/functions/ooc/SumScalarMultiplication.dml b/src/test/scripts/functions/ooc/SumScalarMultiplication.dml new file mode 100644 index 0000000000..d8cf9c8494 --- /dev/null +++ b/src/test/scripts/functions/ooc/SumScalarMultiplication.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more 
contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = read($1); +res = as.matrix(sum(X*7)) +write(res, $2);