This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 58361b2dde [SYSTEMDS-3893] Basic out-of-core binary-read and acquire primitive 58361b2dde is described below commit 58361b2dde9ab2e361b20f27e46a352a90003c1a Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Sun Jul 13 12:46:15 2025 +0200 [SYSTEMDS-3893] Basic out-of-core binary-read and acquire primitive This patch introduces a basic integration of the out-of-core backend. For reading, we use a dedicated reblock instruction which creates a queue of blocks, spawns a thread for reading and immediately returns. In addition, we extended the acquireRead functionality to collect such streams of blocks whenever an operation requires the full matrix. Based on these foundations, we can now add other OOC operations that directly work with the input stream of blocks and either produce results or create modified output streams. --- .github/workflows/javaTests.yml | 2 +- src/main/java/org/apache/sysds/common/Opcodes.java | 24 ++-- .../java/org/apache/sysds/hops/AggUnaryOp.java | 5 +- src/main/java/org/apache/sysds/hops/BinaryOp.java | 3 + src/main/java/org/apache/sysds/hops/DataOp.java | 6 +- src/main/java/org/apache/sysds/hops/Hop.java | 4 +- .../hops/rewrite/RewriteBlockSizeAndReblock.java | 7 +- src/main/java/org/apache/sysds/lops/ReBlock.java | 4 +- .../controlprogram/caching/CacheableData.java | 22 +++- .../controlprogram/caching/FrameObject.java | 8 ++ .../controlprogram/caching/MatrixObject.java | 34 +++++- .../controlprogram/caching/TensorObject.java | 9 ++ .../runtime/instructions/OOCInstructionParser.java | 5 +- .../ooc/ComputationOOCInstruction.java | 48 ++++++++ .../runtime/instructions/ooc/OOCInstruction.java | 2 +- .../instructions/ooc/ReblockOOCInstruction.java | 123 +++++++++++++++++++++ .../org/apache/sysds/runtime/io/MatrixReader.java | 3 +- src/main/java/org/apache/sysds/utils/Explain.java | 6 +- .../functions/ooc/SumScalarMultiplicationTest.java | 53 +++++---- 19 files changed, 
318 insertions(+), 50 deletions(-) diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml index d13b187fb2..c11f00ed4f 100644 --- a/.github/workflows/javaTests.yml +++ b/.github/workflows/javaTests.yml @@ -73,7 +73,7 @@ jobs: "**.functions.builtin.part1.**", "**.functions.builtin.part2.**", "**.functions.frame.**,**.functions.indexing.**,**.functions.io.**,**.functions.iogen.**", - "**.functions.dnn.**", + "**.functions.dnn.**,**.functions.ooc.**", "**.functions.paramserv.**", "**.functions.recompile.**,**.functions.misc.**", "**.functions.mlcontext.**", diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index a4081f9292..fd5c6bfd12 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -349,7 +349,7 @@ public enum Opcodes { MAPMIN("mapmin", InstructionType.Binary), //REBLOCK Instruction Opcodes - RBLK("rblk", null, InstructionType.Reblock), + RBLK("rblk", null, InstructionType.Reblock, null, InstructionType.Reblock), CSVRBLK("csvrblk", InstructionType.CSVReblock), LIBSVMRBLK("libsvmrblk", InstructionType.LIBSVMReblock), @@ -398,24 +398,23 @@ public enum Opcodes { // Constructors Opcodes(String name, InstructionType type) { - this._name = name; - this._type = type; - this._spType=null; - this._fedType=null; + this(name, type, null, null, null); } Opcodes(String name, InstructionType type, InstructionType spType){ - this._name=name; - this._type=type; - this._spType=spType; - this._fedType=null; + this(name, type, spType, null, null); } Opcodes(String name, InstructionType type, InstructionType spType, InstructionType fedType){ + this(name, type, spType, fedType, null); + } + + Opcodes(String name, InstructionType type, InstructionType spType, InstructionType fedType, InstructionType oocType){ this._name=name; this._type=type; this._spType=spType; this._fedType=fedType; + this._oocType=oocType; } // Fields @@ 
-423,6 +422,7 @@ public enum Opcodes { private final InstructionType _type; private final InstructionType _spType; private final InstructionType _fedType; + private final InstructionType _oocType; private static final Map<String, Opcodes> _lookupMap = new HashMap<>(); @@ -451,6 +451,10 @@ public enum Opcodes { public InstructionType getFedType(){ return _fedType != null ? _fedType : _type; } + + public InstructionType getOocType(){ + return _oocType != null ? _oocType : _type; + } public static InstructionType getTypeByOpcode(String opcode, Types.ExecType type) { if (opcode == null || opcode.trim().isEmpty()) { @@ -463,6 +467,8 @@ public enum Opcodes { return (op.getSpType() != null) ? op.getSpType() : op.getType(); case FED: return (op.getFedType() != null) ? op.getFedType() : op.getType(); + case OOC: + return (op.getOocType() != null) ? op.getOocType() : op.getType(); default: return op.getType(); } diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java index 0b2d62bbe3..2f5cb53acf 100644 --- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java @@ -116,7 +116,7 @@ public class AggUnaryOp extends MultiThreadedHop ExecType et = optFindExecType(); Hop input = getInput().get(0); - if ( et == ExecType.CP || et == ExecType.GPU || et == ExecType.FED ) + if ( et == ExecType.CP || et == ExecType.GPU || et == ExecType.FED || et == ExecType.OOC ) { Lop agg1 = null; if( isTernaryAggregateRewriteApplicable() ) { @@ -409,6 +409,9 @@ public class AggUnaryOp extends MultiThreadedHop else setRequiresRecompileIfNecessary(); + if( _etype == ExecType.OOC ) //TODO + setExecType(ExecType.CP); + return _etype; } diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index a3ddb45ea6..f433931a52 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ 
b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -854,6 +854,9 @@ public class BinaryOp extends MultiThreadedHop { _etype = ExecType.CP; } + if( _etype == ExecType.OOC ) //TODO + setExecType(ExecType.CP); + //mark for recompile (forever) setRequiresRecompileIfNecessary(); diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java index 1ae8616001..eb0d1961cf 100644 --- a/src/main/java/org/apache/sysds/hops/DataOp.java +++ b/src/main/java/org/apache/sysds/hops/DataOp.java @@ -24,6 +24,7 @@ import java.util.Map.Entry; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.FileFormat; import org.apache.sysds.common.Types.OpOpData; @@ -465,6 +466,9 @@ public class DataOp extends Hop { } else //READ { + if( DMLScript.USE_OOC ) + checkAndSetForcedPlatform(); + //mark for recompile (forever) if( ConfigurationManager.isDynamicRecompilation() && !dimsKnown(true) && letype==ExecType.SPARK && (_recompileRead || _requiresCheckpoint) ) @@ -473,7 +477,7 @@ public class DataOp extends Hop { } _etype = letype; - if ( _etypeForced == ExecType.FED ) + if ( _etypeForced == ExecType.FED || _etypeForced == ExecType.OOC ) _etype = _etypeForced; } diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index 68e5bc94c0..86749d44c1 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -256,6 +256,8 @@ public abstract class Hop implements ParseInfo { { if(DMLScript.USE_ACCELERATOR && DMLScript.FORCE_ACCELERATOR && isGPUEnabled()) _etypeForced = ExecType.GPU; // enabled with -gpu force option + else if (DMLScript.USE_OOC) + _etypeForced = ExecType.OOC; else if ( DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE && _etypeForced != ExecType.FED ) { 
if(OptimizerUtils.isMemoryBasedOptLevel() && DMLScript.USE_ACCELERATOR && isGPUEnabled()) { // enabled with -exec singlenode -gpu option @@ -406,7 +408,7 @@ public abstract class Hop implements ParseInfo { private void constructAndSetReblockLopIfRequired() { //determine execution type - ExecType et = ExecType.CP; + ExecType et = DMLScript.USE_OOC ? ExecType.OOC : ExecType.CP; if( DMLScript.getGlobalExecMode() != ExecMode.SINGLE_NODE && !(getDataType()==DataType.SCALAR) ) { diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java index 4e03e02f62..4b5eaa8a9a 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java @@ -29,6 +29,7 @@ import org.apache.sysds.hops.FunctionOp; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.HopsException; import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.DataType; /** @@ -80,8 +81,12 @@ public class RewriteBlockSizeAndReblock extends HopRewriteRule { DataOp dop = (DataOp) hop; + if( DMLScript.USE_OOC && dop.getOp() == OpOpData.PERSISTENTREAD ) { + dop.setRequiresReblock(true); + dop.setBlocksize(blocksize); + } // if block size does not match - if( (dop.getDataType() == DataType.MATRIX && (dop.getBlocksize() != blocksize)) + else if( (dop.getDataType() == DataType.MATRIX && (dop.getBlocksize() != blocksize)) ||(dop.getDataType() == DataType.FRAME && OptimizerUtils.isSparkExecutionMode() && (dop.getFileFormat()==FileFormat.TEXT || dop.getFileFormat()==FileFormat.CSV)) ) { diff --git a/src/main/java/org/apache/sysds/lops/ReBlock.java b/src/main/java/org/apache/sysds/lops/ReBlock.java index 2e2c9dc2fd..d92144d63e 100644 --- a/src/main/java/org/apache/sysds/lops/ReBlock.java +++ b/src/main/java/org/apache/sysds/lops/ReBlock.java @@ 
-46,8 +46,8 @@ public class ReBlock extends Lop { _blocksize = blen; _outputEmptyBlocks = outputEmptyBlocks; - if(et == ExecType.SPARK) - lps.setProperties(inputs, ExecType.SPARK); + if(et == ExecType.SPARK || et == ExecType.OOC) + lps.setProperties(inputs, et); else throw new LopsException("Incorrect execution type for Reblock:" + et); } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java index eba22e7f15..e075b55e17 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java @@ -42,12 +42,14 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer.RPolicy; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.instructions.cp.Data; import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction; import org.apache.sysds.runtime.instructions.gpu.context.GPUContext; import org.apache.sysds.runtime.instructions.gpu.context.GPUObject; import org.apache.sysds.runtime.instructions.spark.data.BroadcastObject; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.instructions.spark.data.RDDObject; import org.apache.sysds.runtime.io.FileFormatProperties; import org.apache.sysds.runtime.io.IOUtilFunctions; @@ -210,13 +212,15 @@ public abstract class CacheableData<T extends CacheBlock<?>> extends Data private boolean _requiresLocalWrite = false; //flag if local write for read obj private boolean _isAcquireFromEmpty = false; //flag if read from status empty 
- //spark-specific handles + //backend-specific handles //note: we use the abstraction of LineageObjects for two reasons: (1) to keep track of cleanup //for lazily evaluated RDDs, and (2) as abstraction for environments that do not necessarily have spark libraries available private RDDObject _rddHandle = null; //RDD handle private BroadcastObject<T> _bcHandle = null; //Broadcast handle protected HashMap<GPUContext, GPUObject> _gpuObjects = null; //Per GPUContext object allocated on GPU - + //TODO generalize for frames + private LocalTaskQueue<IndexedMatrixValue> _streamHandle = null; + private LineageItem _lineage = null; /** @@ -460,6 +464,10 @@ public abstract class CacheableData<T extends CacheBlock<?>> extends Data public boolean hasBroadcastHandle() { return _bcHandle != null && _bcHandle.hasBackReference(); } + + public LocalTaskQueue<IndexedMatrixValue> getStreamHandle() { + return _streamHandle; + } @SuppressWarnings({ "rawtypes", "unchecked" }) public void setBroadcastHandle( BroadcastObject bc ) { @@ -490,6 +498,10 @@ public abstract class CacheableData<T extends CacheBlock<?>> extends Data public synchronized void removeGPUObject(GPUContext gCtx) { _gpuObjects.remove(gCtx); } + + public synchronized void setStreamHandle(LocalTaskQueue<IndexedMatrixValue> q) { + _streamHandle = q; + } // ********************************************* // *** *** @@ -580,6 +592,9 @@ public abstract class CacheableData<T extends CacheBlock<?>> extends Data //mark for initial local write despite read operation _requiresLocalWrite = false; } + else if( getStreamHandle() != null ) { + _data = readBlobFromStream( getStreamHandle() ); + } else if( getRDDHandle()==null || getRDDHandle().allowsShortCircuitRead() ) { if( DMLScript.STATISTICS ) CacheStatistics.incrementHDFSHits(); @@ -1099,6 +1114,9 @@ public abstract class CacheableData<T extends CacheBlock<?>> extends Data protected abstract T readBlobFromRDD(RDDObject rdd, MutableBoolean status) throws IOException; + protected 
abstract T readBlobFromStream(LocalTaskQueue<IndexedMatrixValue> stream) + throws IOException; + // Federated read protected T readBlobFromFederated(FederationMap fedMap) throws IOException { if( LOG.isDebugEnabled() ) //common if instructions keep federated outputs diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java index 582bb64dd8..56cc276cd8 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java @@ -33,7 +33,9 @@ import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.instructions.spark.data.RDDObject; import org.apache.sysds.runtime.io.FileFormatProperties; import org.apache.sysds.runtime.io.FrameReaderFactory; @@ -304,6 +306,12 @@ public class FrameObject extends CacheableData<FrameBlock> //lazy evaluation of pending transformations. 
SparkExecutionContext.writeFrameRDDtoHDFS(rdd, fname, iimd.getFileFormat()); } + + @Override + protected FrameBlock readBlobFromStream(LocalTaskQueue<IndexedMatrixValue> stream) throws IOException { + // TODO Auto-generated method stub + return null; + } @Override protected FrameBlock reconstructByLineage(LineageItem li) throws IOException { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java index f58b315e68..e9204bdaed 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java @@ -42,7 +42,9 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.instructions.spark.data.RDDObject; import org.apache.sysds.runtime.io.FileFormatProperties; import org.apache.sysds.runtime.io.ReaderWriterFederated; @@ -442,8 +444,10 @@ public class MatrixObject extends CacheableData<MatrixBlock> { // Read matrix and maintain meta data, // if the MatrixObject is federated there is nothing extra to read, and therefore only acquire read and release int blen = mc.getBlocksize() <= 0 ? ConfigurationManager.getBlocksize() : mc.getBlocksize(); - MatrixBlock newData = isFederated() ? 
acquireReadAndRelease() : DataConverter.readMatrixFromHDFS(fname, - iimd.getFileFormat(), rlen, clen, blen, mc.getNonZeros(), getFileFormatProperties()); + MatrixBlock newData = + isFederated() ? acquireReadAndRelease() : + DataConverter.readMatrixFromHDFS(fname, iimd.getFileFormat(), + rlen, clen, blen, mc.getNonZeros(), getFileFormatProperties()); if(iimd.getFileFormat() == FileFormat.CSV) { _metaData = _metaData instanceof MetaDataFormat ? new MetaDataFormat(newData.getDataCharacteristics(), @@ -518,6 +522,32 @@ public class MatrixObject extends CacheableData<MatrixBlock> { return mb; } + + + @Override + protected MatrixBlock readBlobFromStream(LocalTaskQueue<IndexedMatrixValue> stream) throws IOException { + MatrixBlock ret = new MatrixBlock((int)getNumRows(), (int)getNumColumns(), false); + IndexedMatrixValue tmp = null; + try { + int blen = getBlocksize(), lnnz = 0; + while( (tmp = stream.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS ) { + // compute row/column block offsets + final int row_offset = (int) (tmp.getIndexes().getRowIndex() - 1) * blen; + final int col_offset = (int) (tmp.getIndexes().getColumnIndex() - 1) * blen; + + // Add the values of this block into the output block. 
+ ((MatrixBlock)tmp.getValue()).putInto(ret, row_offset, col_offset, true); + + // incremental maintenance nnz + lnnz += tmp.getValue().getNonZeros(); + } + ret.setNonZeros(lnnz); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + return ret; + } @Override protected MatrixBlock readBlobFromFederated(FederationMap fedMap, long[] dims) throws IOException { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java index 8908f55d06..d39ed8c8a9 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java @@ -30,8 +30,10 @@ import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.data.TensorIndexes; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.instructions.spark.data.RDDObject; import org.apache.sysds.runtime.io.FileFormatProperties; import org.apache.sysds.runtime.lineage.LineageItem; @@ -199,6 +201,13 @@ public class TensorObject extends CacheableData<TensorBlock> { //TODO rdd write } + + @Override + protected TensorBlock readBlobFromStream(LocalTaskQueue<IndexedMatrixValue> stream) throws IOException { + // TODO Auto-generated method stub + return null; + } + @Override protected TensorBlock reconstructByLineage(LineageItem li) throws IOException { return ((TensorObject) LineageRecomputeUtils diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index 191976f094..e0f84c5bd2 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -24,6 +24,7 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.InstructionType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction; public class OOCInstructionParser extends InstructionParser { protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName()); @@ -44,7 +45,9 @@ public class OOCInstructionParser extends InstructionParser { if(str == null || str.isEmpty()) return null; switch(ooctype) { - + case Reblock: + return ReblockOOCInstruction.parseInstruction(str); + // TODO: case AggregateUnary: case Binary: diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java new file mode 100644 index 0000000000..5552017493 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.matrix.operators.Operator; + +public abstract class ComputationOOCInstruction extends OOCInstruction { + public CPOperand output; + public CPOperand input1, input2, input3; + + protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand out, String opcode, String istr) { + super(type, op, opcode, istr); + input1 = in1; + input2 = null; + input3 = null; + output = out; + } + + protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) { + super(type, op, opcode, istr); + input1 = in1; + input2 = in2; + input3 = null; + output = out; + } + + public String getOutputVariableName() { + return output.getName(); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index 83cc972135..fe73e57fd2 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -30,7 +30,7 @@ public abstract class OOCInstruction extends Instruction { protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName()); public enum OOCType { - AggregateUnary, Binary + Reblock, AggregateUnary, Binary } protected final OOCInstruction.OOCType _ooctype; diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java new file mode 100644 index 0000000000..9a7059be51 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.instructions.ooc; + +import java.util.concurrent.ExecutorService; + +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.mapred.JobConf; +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.io.IOUtilFunctions; +import org.apache.sysds.runtime.io.MatrixReader; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.util.CommonThreadPool; + +public class ReblockOOCInstruction extends ComputationOOCInstruction { + private int blen; + + private ReblockOOCInstruction(Operator op, CPOperand in, CPOperand out, + int br, int bc, String opcode, String instr) + { + super(OOCType.Reblock, op, in, out, opcode, instr); + blen = br; + blen = bc; + } + + public static ReblockOOCInstruction parseInstruction(String str) { + String parts[] = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode = parts[0]; + if(!opcode.equals(Opcodes.RBLK.toString())) + throw new DMLRuntimeException("Incorrect opcode for ReblockOOCInstruction:" + opcode); + + CPOperand in = new CPOperand(parts[1]); + CPOperand out = new CPOperand(parts[2]); + int blen=Integer.parseInt(parts[3]); + return new 
ReblockOOCInstruction(null, in, out, blen, blen, opcode, str); + } + + @Override + public void processInstruction(ExecutionContext ec) { + //set the output characteristics + MatrixObject min = ec.getMatrixObject(input1); + DataCharacteristics mc = ec.getDataCharacteristics(input1.getName()); + DataCharacteristics mcOut = ec.getDataCharacteristics(output.getName()); + mcOut.set(mc.getRows(), mc.getCols(), blen, mc.getNonZeros()); + + //get the source format from the meta data + //MetaDataFormat iimd = (MetaDataFormat) min.getMetaData(); + //TODO support other formats than binary + + //create queue, spawn thread for asynchronous reading, and return + LocalTaskQueue<IndexedMatrixValue> q = new LocalTaskQueue<IndexedMatrixValue>(); + ExecutorService pool = CommonThreadPool.get(); + try { + pool.submit(() -> readBinaryBlock(q, min.getFileName())); + } + finally { + pool.shutdown(); + } + + MatrixObject mout = ec.getMatrixObject(output); + mout.setStreamHandle(q); + } + + @SuppressWarnings("resource") + private void readBinaryBlock(LocalTaskQueue<IndexedMatrixValue> q, String fname) { + try { + //prepare file access + JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); + Path path = new Path( fname ); + FileSystem fs = IOUtilFunctions.getFileSystem(path, job); + + //check existence and non-empty file + MatrixReader.checkValidInputFile(fs, path); + + //core reading + for( Path lpath : IOUtilFunctions.getSequenceFilePaths(fs, path) ) { //1..N files + //directly read from sequence files (individual partfiles) + try( SequenceFile.Reader reader = new SequenceFile + .Reader(job, SequenceFile.Reader.file(lpath)) ) + { + MatrixIndexes key = new MatrixIndexes(); + MatrixBlock value = new MatrixBlock(); + while( reader.next(key, value) ) + q.enqueueTask(new IndexedMatrixValue(key, new MatrixBlock(value))); + } + } + q.closeInput(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + } +} diff --git 
a/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java b/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java index 893d665a24..245347e9cf 100644 --- a/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java +++ b/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java @@ -111,7 +111,7 @@ public abstract class MatrixReader return ret; } - protected static void checkValidInputFile(FileSystem fs, Path path) + public static void checkValidInputFile(FileSystem fs, Path path) throws IOException { //check non-existing file @@ -121,7 +121,6 @@ public abstract class MatrixReader //check for empty file if( HDFSTool.isFileEmpty(fs, path) ) throw new EOFException("Empty input file "+ path.toString() +"."); - } protected static void sortSparseRowsParallel(MatrixBlock dest, long rlen, int k, ExecutorService pool) diff --git a/src/main/java/org/apache/sysds/utils/Explain.java b/src/main/java/org/apache/sysds/utils/Explain.java index bcd17ef7f0..7f5fb4c06f 100644 --- a/src/main/java/org/apache/sysds/utils/Explain.java +++ b/src/main/java/org/apache/sysds/utils/Explain.java @@ -62,6 +62,7 @@ import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.cp.CPInstruction; import org.apache.sysds.runtime.instructions.fed.FEDInstruction; import org.apache.sysds.runtime.instructions.gpu.GPUInstruction; +import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; import org.apache.sysds.runtime.instructions.spark.CSVReblockSPInstruction; import org.apache.sysds.runtime.instructions.spark.CheckpointSPInstruction; import org.apache.sysds.runtime.instructions.spark.ReblockSPInstruction; @@ -837,8 +838,9 @@ public class Explain private static String explainGenericInstruction( Instruction inst, int level ) { String tmp = null; - if ( inst instanceof SPInstruction || inst instanceof CPInstruction || inst instanceof GPUInstruction || - inst instanceof FEDInstruction ) + if ( inst instanceof SPInstruction || inst instanceof 
CPInstruction + || inst instanceof GPUInstruction || inst instanceof FEDInstruction + || inst instanceof OOCInstruction) tmp = inst.toString(); if( REPLACE_SPECIAL_CHARACTERS && tmp != null){ diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java index d9d42c913b..3681b74f83 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java @@ -21,14 +21,19 @@ package org.apache.sysds.test.functions.ooc; import org.apache.sysds.common.Opcodes; import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.HDFSTool; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; -import org.apache.sysds.utils.Statistics; import org.junit.Assert; -import org.junit.Ignore; import org.junit.Test; import java.util.HashMap; @@ -52,7 +57,6 @@ public class SumScalarMultiplicationTest extends AutomatedTestBase { * Test the sum of scalar multiplication, "sum(X*7)", with OOC backend. 
*/ @Test - @Ignore public void testSumScalarMult() { Types.ExecMode platformOld = rtplatform; @@ -62,42 +66,43 @@ public class SumScalarMultiplicationTest extends AutomatedTestBase { getAndLoadTestConfiguration(TEST_NAME); String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + TEST_NAME + ".dml"; - programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME), output(OUTPUT_NAME)}; - - int rows = 3; - int cols = 4; - double sparsity = 0.8; - - double[][] X = getRandomMatrix(rows, cols, -1, 1, sparsity, 7); - writeInputMatrixWithMTD(INPUT_NAME, X, true); - + programArgs = new String[] {"-explain", "-stats", "-ooc", + "-args", input(INPUT_NAME), output(OUTPUT_NAME)}; + + int rows = 3500, cols = 4; + MatrixBlock mb = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7); + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(FileFormat.BINARY); + writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows, cols, 1000, rows*cols); + HDFSTool.writeMetaDataFile(input(INPUT_NAME+"mtd"), ValueType.FP64, + new MatrixCharacteristics(rows,cols,1000,rows*cols), FileFormat.BINARY); + runTest(true, false, null, -1); HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir(OUTPUT_NAME); - // only one entry Double result = dmlfile.get(new MatrixValue.CellIndex(1, 1)); - double expected = 0.0; for(int i = 0; i < rows; i++) { for(int j = 0; j < cols; j++) { - expected += X[i][j] * 7; + expected += mb.get(i, j) * 7; } } Assert.assertEquals(expected, result, 1e-10); String prefix = Instruction.OOC_INST_PREFIX; - - boolean usedOOCMult = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT); - Assert.assertTrue("OOC wasn't used for MULT", usedOOCMult); - - boolean usedOOCSum = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.UAKP); - Assert.assertTrue("OOC wasn't used for SUM", usedOOCSum); - + Assert.assertTrue("OOC wasn't used for RBLK", + heavyHittersContainsString(prefix + Opcodes.RBLK)); + 
+// boolean usedOOCMult = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT); +// Assert.assertTrue("OOC wasn't used for MULT", usedOOCMult); +// boolean usedOOCSum = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.UAKP); +// Assert.assertTrue("OOC wasn't used for SUM", usedOOCSum); + } + catch(Exception ex) { + Assert.fail(ex.getMessage()); } finally { - // reset - rtplatform = platformOld; + resetExecMode(platformOld); } } }