This is an automated email from the ASF dual-hosted git repository.

janniklinde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new a3400ee59a [SYSTEMDS-3891] Add OOC cbind
a3400ee59a is described below

commit a3400ee59ac45953aa20f83084e63aa9d9b340de
Author: jessicapriebe <[email protected]>
AuthorDate: Tue Mar 10 12:20:56 2026 +0100

    [SYSTEMDS-3891] Add OOC cbind
    
    Closes #2437
---
 .../runtime/instructions/OOCInstructionParser.java |   3 +
 .../instructions/ooc/AppendOOCInstruction.java     | 205 +++++++++++++++++++++
 .../runtime/instructions/ooc/CachingStream.java    |   3 +-
 .../runtime/instructions/ooc/OOCInstruction.java   |   2 +-
 .../instructions/ooc/SubscribableTaskQueue.java    |  15 +-
 .../apache/sysds/runtime/io/WriterBinaryBlock.java |  11 +-
 .../runtime/io/WriterBinaryBlockParallel.java      |  10 +
 .../sysds/runtime/ooc/stream/MergedOOCStream.java  |  11 ++
 .../apache/sysds/runtime/ooc/util/OOCUtils.java    |  10 +
 .../apache/sysds/test/functions/ooc/CBindTest.java | 156 ++++++++++++++++
 src/test/scripts/functions/ooc/CBindTest.dml       |  26 +++
 11 files changed, 445 insertions(+), 7 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index 21816e90ad..affda5910d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -42,6 +42,7 @@ import 
org.apache.sysds.runtime.instructions.ooc.MMultOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.MapMMChainOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.AppendOOCInstruction;
 
 public class OOCInstructionParser extends InstructionParser {
        protected static final Log LOG = 
LogFactory.getLog(OOCInstructionParser.class.getName());
@@ -108,6 +109,8 @@ public class OOCInstructionParser extends InstructionParser 
{
                                return 
IndexingOOCInstruction.parseInstruction(str);
                        case Rand:
                                return 
DataGenOOCInstruction.parseInstruction(str);
+                       case Append:
+                               return 
AppendOOCInstruction.parseInstruction(str);
 
                        default:
                                throw new DMLRuntimeException("Invalid OOC 
Instruction Type: " + ooctype);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AppendOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AppendOOCInstruction.java
new file mode 100644
index 0000000000..2c0e24523c
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AppendOOCInstruction.java
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.functionobjects.OffsetColumnIndex;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Function;
+
+public class AppendOOCInstruction extends BinaryOOCInstruction {
+
+       public enum AppendType {
+               CBIND
+       }
+
+       protected final AppendType _type;
+
+       protected AppendOOCInstruction(Operator op, CPOperand in1, CPOperand 
in2, CPOperand out, AppendType type,
+               String opcode, String istr) {
+               super(OOCType.Append, op, in1, in2, out, opcode, istr);
+               _type = type;
+       }
+
+       public static AppendOOCInstruction parseInstruction(String str) {
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               InstructionUtils.checkNumFields(parts, 5, 4);
+
+               String opcode = parts[0];
+               CPOperand in1 = new CPOperand(parts[1]);
+               CPOperand in2 = new CPOperand(parts[2]);
+               CPOperand out = new CPOperand(parts[parts.length-2]);
+               boolean cbind = Boolean.parseBoolean(parts[parts.length-1]);
+
+               if(in1.getDataType() != Types.DataType.MATRIX || 
in2.getDataType() != Types.DataType.MATRIX || !cbind){
+                       throw new DMLRuntimeException("Only matrix-matrix cbind 
is supported");
+               }
+               AppendType type = AppendType.CBIND;
+
+               Operator op = new 
ReorgOperator(OffsetColumnIndex.getOffsetColumnIndexFnObject(-1));
+               return new AppendOOCInstruction(op, in1, in2, out, type, 
opcode, str);
+       }
+
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               MatrixObject in1 = ec.getMatrixObject(input1);
+               MatrixObject in2 = ec.getMatrixObject(input2);
+               validateInput(in1, in2);
+               if(handleZeroDims(in1, in2, ec))
+                       return;
+
+               OOCStream<IndexedMatrixValue> qIn1 = in1.getStreamHandle();
+               OOCStream<IndexedMatrixValue> qIn2 = in2.getStreamHandle();
+
+               int blksize = in1.getBlocksize();
+               int rem1 = (int) in1.getNumColumns()%blksize;
+               int rem2 = (int) in2.getNumColumns()%blksize;
+               int cblk1 = (int) 
in1.getDataCharacteristics().getNumColBlocks();
+               int cblk2 = (int) 
in2.getDataCharacteristics().getNumColBlocks();
+               int cblkRes = (int) 
Math.ceil((double)(in1.getNumColumns()+in2.getNumColumns())/blksize);
+
+               if(rem1==0){
+                       // in1 ends on a block boundary: in2 blocks are reused as-is, only their column index is offset by cblk1
+                       OOCStream<IndexedMatrixValue> out = new 
SubscribableTaskQueue<>();
+                       mapOOC(qIn2, out, imv -> new IndexedMatrixValue(
+                               new 
MatrixIndexes(imv.getIndexes().getRowIndex(), 
cblk1+imv.getIndexes().getColumnIndex()), imv.getValue()));
+
+                       
ec.getMatrixObject(output).setStreamHandle(mergeOOCStreams(List.of(qIn1, out)));
+                       return;
+               }
+
+               List<OOCStream<IndexedMatrixValue>> split1 = 
splitOOCStream(qIn1, imv -> imv.getIndexes().getColumnIndex()==cblk1? 1 : 0, 2);
+               List<OOCStream<IndexedMatrixValue>> split2 = 
splitOOCStream(qIn2, imv -> (int) imv.getIndexes().getColumnIndex()-1, cblk2);
+
+               OOCStream<IndexedMatrixValue> head = split1.get(0);
+               OOCStream<IndexedMatrixValue> lastCol = split1.get(1);
+               OOCStream<IndexedMatrixValue> firstCol = split2.get(0);
+
+               CachingStream firstColCache = new CachingStream(firstCol);
+               OOCStream<IndexedMatrixValue> firstColForCritical = 
firstColCache.getReadStream();
+               OOCStream<IndexedMatrixValue> firstColForTail = 
firstColCache.getReadStream();
+
+               SubscribableTaskQueue<IndexedMatrixValue> out = new 
SubscribableTaskQueue<>();
+               Function<IndexedMatrixValue, MatrixIndexes> rowKey = imv -> new 
MatrixIndexes(imv.getIndexes().getRowIndex(), 1);
+
+               int fullRem2 = rem2==0? blksize : rem2;
+               // pad in1's last column block with the leading columns of in2's first column block (paired via rowKey)
+               joinOOC(lastCol, firstColForCritical, out, (left, right) -> {
+                       MatrixBlock lb = (MatrixBlock) left.getValue();
+                       MatrixBlock rb = (MatrixBlock) right.getValue();
+                       int stop = cblk2==1 && blksize-rem1>fullRem2? fullRem2 
: blksize-rem1;
+                       MatrixBlock combined = cbindBlocks(lb, sliceCols(rb, 0, 
stop));
+                       return new IndexedMatrixValue(
+                               new 
MatrixIndexes(left.getIndexes().getRowIndex(), 
left.getIndexes().getColumnIndex()), combined);
+               }, rowKey);
+
+               List<OOCStream<IndexedMatrixValue>> outStreams = new 
ArrayList<>();
+               outStreams.add(head);
+               outStreams.add(out);
+
+               // shift in2: each output block is the tail of one in2 column block followed by the head of the next
+               OOCStream<IndexedMatrixValue> fst = firstColForTail;
+               OOCStream<IndexedMatrixValue> sec = null;
+               for(int i=0; i<cblk2-1; i++){
+                       out = new SubscribableTaskQueue<>();
+                       CachingStream secCachingStream = new 
CachingStream(split2.get(i+1));
+                       sec = secCachingStream.getReadStream();
+
+                       int finalI = i;
+                       joinOOC(fst, sec, out, (left, right) -> {
+                               MatrixBlock lb = (MatrixBlock) left.getValue();
+                               MatrixBlock rb = (MatrixBlock) right.getValue();
+                               int stop = finalI+2==cblk2 && 
blksize-rem1>fullRem2? fullRem2 : blksize-rem1;
+                               MatrixBlock combined = 
cbindBlocks(sliceCols(lb, blksize-rem1, blksize), sliceCols(rb, 0, stop));
+                               return new IndexedMatrixValue(
+                                       new 
MatrixIndexes(left.getIndexes().getRowIndex(), cblk1 + 
left.getIndexes().getColumnIndex()),
+                                       combined);
+                       }, rowKey);
+
+                       fst = secCachingStream.getReadStream();
+                       outStreams.add(out);
+               }
+
+               if(cblk1+cblk2==cblkRes){
+                       // no column block is saved by merging: the trailing remSize columns of in2's last block form the final output block
+                       int remSize = (rem1+rem2)%blksize;
+                       out = new SubscribableTaskQueue<>();
+                       mapOOC(fst, out, imv -> new IndexedMatrixValue(
+                               new 
MatrixIndexes(imv.getIndexes().getRowIndex(), 
cblk1+imv.getIndexes().getColumnIndex()), 
+                               sliceCols((MatrixBlock) imv.getValue(), 
fullRem2-remSize, fullRem2)));
+
+                       outStreams.add(out);
+               }
+               
ec.getMatrixObject(output).setStreamHandle(mergeOOCStreams(outStreams));
+       }
+
+       public AppendType getAppendType() {
+               return _type;
+       }
+
+       private void validateInput(MatrixObject m1, MatrixObject m2) {
+               if(_type == AppendType.CBIND && m1.getNumRows() != 
m2.getNumRows()) {
+                       throw new DMLRuntimeException(
+                               "Append-cbind is not possible for input 
matrices " + input1.getName() + " and " + input2.getName()
+                                       + " with different number of rows: " + 
m1.getNumRows() + " vs " + m2.getNumRows());
+               }
+       }
+
+       private boolean handleZeroDims(MatrixObject m1, MatrixObject m2, 
ExecutionContext ec) {
+               long rows = m1.getNumRows();
+               long cols1 = m1.getNumColumns();
+               long cols2 = m2.getNumColumns();
+               if(rows == 0 || (cols1 == 0 && cols2 == 0)) {
+                       OOCStream<IndexedMatrixValue> empty = 
createWritableStream();
+                       empty.closeInput();
+                       ec.getMatrixObject(output).setStreamHandle(empty);
+               }
+               else if(cols1 == 0) {
+                       
ec.getMatrixObject(output).setStreamHandle(m2.getStreamHandle());
+               }
+               else if(cols2 == 0) {
+                       
ec.getMatrixObject(output).setStreamHandle(m1.getStreamHandle());
+               }
+               else return false;
+
+               return true;
+       }
+
+       private static MatrixBlock sliceCols(MatrixBlock in, int colStart, int 
colEndExclusive) {
+               // slice() takes inclusive upper bounds, hence the -1 on both the row count and the exclusive column end
+               return in.slice(0, in.getNumRows()-1, colStart, 
colEndExclusive-1);
+       }
+
+       private static MatrixBlock cbindBlocks(MatrixBlock left, MatrixBlock 
right) {
+               return left.append(right, new MatrixBlock());
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
index d8eac50fcb..b3f5e57aaf 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
@@ -33,6 +33,7 @@ import org.apache.sysds.runtime.ooc.cache.OOCCacheManager;
 import org.apache.sysds.runtime.ooc.stream.SourceOOCStream;
 import org.apache.sysds.runtime.ooc.stream.message.OOCGetStreamTypeMessage;
 import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage;
+import org.apache.sysds.runtime.ooc.util.OOCUtils;
 import org.apache.sysds.runtime.util.IndexRange;
 import shaded.parquet.it.unimi.dsi.fastutil.ints.IntArrayList;
 
@@ -469,7 +470,7 @@ public class CachingStream implements 
OOCStreamable<IndexedMatrixValue> {
        private void validateBlockCountOnClose() {
                DataCharacteristics dc = _source.getDataCharacteristics();
                if (dc != null && dc.dimsKnown() && dc.getBlocksize() > 0) {
-                       long expected = dc.getNumBlocks();
+                       long expected = OOCUtils.getNumBlocks(dc);
                        if (expected >= 0 && _numBlocks != expected) {
                                throw new DMLRuntimeException("CachingStream 
block count mismatch: expected "
                                        + expected + " but saw " + _numBlocks + 
" (" + dc.getRows() + "x" + dc.getCols() + ")");
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index 8f1f6fef24..2b90a02604 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -80,7 +80,7 @@ public abstract class OOCInstruction extends Instruction {
 
        public enum OOCType {
                Reblock, Tee, Binary, Ternary, Unary, AggregateUnary, 
AggregateBinary, AggregateTernary, MAPMM, MMTSJ,
-               MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, 
ParameterizedBuiltin, Rand
+               MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, 
ParameterizedBuiltin, Rand, Append
        }
 
        protected final OOCInstruction.OOCType _ooctype;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
index e5c48decdd..ce79672831 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
@@ -25,6 +25,7 @@ import 
org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.ooc.stream.message.OOCGetStreamTypeMessage;
 import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage;
+import org.apache.sysds.runtime.ooc.util.OOCUtils;
 import org.apache.sysds.runtime.util.IndexRange;
 
 import java.util.LinkedList;
@@ -166,7 +167,7 @@ public class SubscribableTaskQueue<T> extends 
LocalTaskQueue<T> implements OOCSt
        private void validateBlockCountOnClose() {
                DataCharacteristics dc = getDataCharacteristics();
                if (dc != null && dc.dimsKnown() && dc.getBlocksize() > 0) {
-                       long expected = dc.getNumBlocks();
+                       long expected = OOCUtils.getNumBlocks(dc);
                        if (expected >= 0 && _blockCount.get() != expected) {
                                throw new DMLRuntimeException("OOCStream block 
count mismatch: expected "
                                        + expected + " but saw " + 
_blockCount.get() + " (" + dc.getRows() + "x" + dc.getCols() + ")");
@@ -180,6 +181,7 @@ public class SubscribableTaskQueue<T> extends 
LocalTaskQueue<T> implements OOCSt
                        throw new IllegalArgumentException("Cannot set 
subscriber to null");
 
                LinkedList<T> data;
+               boolean needsEos;
 
                synchronized(this) {
                        if(_subscriber != null)
@@ -189,12 +191,20 @@ public class SubscribableTaskQueue<T> extends 
LocalTaskQueue<T> implements OOCSt
                                throw _failure;
                        data = _data;
                        _data = new LinkedList<>();
+                       // If this stream was already closed with no buffered 
data, no further
+                       // onDeliveryFinished() call will happen, so emit EOS 
immediately.
+                       needsEos = _closed.get() && data.isEmpty() && 
_availableCtr.get() == 0;
+                       if(needsEos)
+                               _availableCtr.incrementAndGet(); // route 
terminal emission via onDeliveryFinished
                }
 
                for (T t : data) {
                        subscriber.accept(new SimpleQueueCallback<>(t, 
_failure));
                        onDeliveryFinished();
                }
+
+               if(needsEos)
+                       onDeliveryFinished();
        }
 
        @SuppressWarnings("unchecked")
@@ -214,6 +224,9 @@ public class SubscribableTaskQueue<T> extends 
LocalTaskQueue<T> implements OOCSt
 
        @Override
        public synchronized void propagateFailure(DMLRuntimeException re) {
+               // Ignore late failures
+               if(_closed.get() && _availableCtr.get() == 0)
+                       return;
                super.propagateFailure(re);
                Consumer<QueueCallback<T>> s = _subscriber;
                if(s != null)
diff --git a/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java 
b/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java
index 69fd386c5e..a991237329 100644
--- a/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java
@@ -97,10 +97,13 @@ public class WriterBinaryBlock extends MatrixWriter {
                FileSystem fs = IOUtilFunctions.getFileSystem(path, job);
                final Writer writer = IOUtilFunctions.getSeqWriter(path, job, 
_replication);
                try {
-                       MatrixIndexes index = new MatrixIndexes(1, 1);
-                       MatrixBlock block = new MatrixBlock((int) 
Math.max(Math.min(rlen, blen), 1),
-                               (int) Math.max(Math.min(clen, blen), 1), true);
-                       writer.append(index, block);
+                       // For 0xN or Nx0, emit a valid sequence file header 
only (no blocks).
+                       if(rlen > 0 && clen > 0) {
+                               MatrixIndexes index = new MatrixIndexes(1, 1);
+                               MatrixBlock block = new MatrixBlock((int) 
Math.max(Math.min(rlen, blen), 1),
+                                       (int) Math.max(Math.min(clen, blen), 
1), true);
+                               writer.append(index, block);
+                       }
                }
                finally {
                        IOUtilFunctions.closeSilently(writer);
diff --git 
a/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlockParallel.java 
b/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlockParallel.java
index c00e58b7fa..88f7c0a690 100644
--- a/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlockParallel.java
+++ b/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlockParallel.java
@@ -95,6 +95,16 @@ public class WriterBinaryBlockParallel extends 
WriterBinaryBlock
        public long writeMatrixFromStream(String fname, 
OOCStream<IndexedMatrixValue> stream, long rlen, long clen, int blen)
                throws IOException {
                Path path = new Path(fname);
+
+               // For empty dimensions, no stream tiles are expected but the 
output must still exist.
+               if(rlen <= 0 || clen <= 0) {
+                       while(stream.dequeue() != LocalTaskQueue.NO_MORE_TASKS) 
{
+                               // Drain any unexpected records to keep stream 
producers unblocked.
+                       }
+                       writeEmptyMatrixToHDFS(fname, rlen, clen, blen);
+                       return 0;
+               }
+
                long nnz = -1;
                DataCharacteristics dc = stream.getDataCharacteristics();
                if(dc != null)
diff --git 
a/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java 
b/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java
index 907f78795a..8ea1384217 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java
@@ -85,6 +85,17 @@ public class MergedOOCStream<T> implements OOCStream<T> {
                                                if(_failed.get())
                                                        return;
 
+                                               if(cb instanceof 
OOCStream.GroupQueueCallback<?>) {
+                                                       
OOCStream.GroupQueueCallback<T> group = (OOCStream.GroupQueueCallback<T>) cb;
+                                                       for(int i = 0; i < 
group.size(); i++) {
+                                                               
OOCStream.QueueCallback<T> sub = group.getCallback(i);
+                                                               try(sub) {
+                                                                       
_taskQueue.enqueue(sub.keepOpen());
+                                                               }
+                                                       }
+                                                       return;
+                                               }
+
                                                
_taskQueue.enqueue(cb.keepOpen());
                                        }
                                }
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/util/OOCUtils.java 
b/src/main/java/org/apache/sysds/runtime/ooc/util/OOCUtils.java
index c564748e1d..f33d17ea13 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/util/OOCUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/util/OOCUtils.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.runtime.ooc.util;
 
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.util.IndexRange;
 
 import java.util.ArrayList;
@@ -60,4 +61,13 @@ public class OOCUtils {
                                list.add(new MatrixIndexes(r, c));
                return list;
        }
+
+       public static long getNumBlocks(DataCharacteristics dc) {
+               if (dc != null && dc.dimsKnown() && dc.getBlocksize() > 0) {
+                       if(dc.getCols() == 0 || dc.getRows() == 0)
+                               return 0;
+                       return dc.getNumBlocks();
+               }
+               return -1;
+       }
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java 
b/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java
new file mode 100644
index 0000000000..3172585c6e
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.parser.LanguageException;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.ArrayList;
+
+@RunWith(Parameterized.class)
[email protected]
+public class CBindTest extends AutomatedTestBase {
+
+       private static final String TEST_NAME = "CBindTest";
+       private static final String TEST_DIR = "functions/ooc/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
CBindTest.class.getSimpleName() + "/";
+
+       private final static double eps = 1e-8;
+       private static final String INPUT_NAME_1 = "A";
+       private static final String INPUT_NAME_2 = "B";
+       private static final String OUTPUT_NAME = "res";
+
+       private final int r1;
+       private final int c1;
+       private final int r2;
+       private final int c2;
+       private final int bsize;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+       }
+
+       public CBindTest(int r1, int c1, int r2, int c2, int bsize) {
+               this.r1 = r1;
+               this.c1 = c1;
+               this.r2 = r2;
+               this.c2 = c2;
+               this.bsize = bsize;
+       }
+
+       @Parameterized.Parameters(name = "{0}x{1} {2}x{3} bsize {4}")
+       public static Iterable<Object[]> getParams() {
+               int[] rows = new int[]{1000, 2000};
+               int[] cols = new int[]{300, 700, 2300, 2700, 3000, 3300};
+               int[] bsizes = new int[]{1000};
+
+               ArrayList<Object[]> params = new ArrayList<>();
+
+               for(int row : rows) {
+                       for(int col : cols) {
+                               for(int col2 : cols) {
+                                       for(int bsize : bsizes) {
+                                               params.add(new Object[] {row, 
col, row, col2, bsize});
+                                       }
+                               }
+                       }
+               }
+
+               params.add(new Object[] {10, 1000, 20, 1000, 1000});
+               params.add(new Object[] {0, 1000, 0, 1000, 1000});
+               params.add(new Object[] {1000, 0, 1000, 1000, 1000});
+               params.add(new Object[] {1000, 1000, 1000, 0, 1000});
+               params.add(new Object[] {1000, 0, 1000, 0, 1000});
+
+               return params;
+       }
+
+       @Test
+       public void runCBindTest() {
+               Types.ExecMode platformOld = rtplatform;
+               rtplatform = Types.ExecMode.SINGLE_NODE;
+
+               try {
+                       getAndLoadTestConfiguration(TEST_NAME);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+
+                       double[][] A = TestUtils.floor(getRandomMatrix(r1, c1, 
-1, 1, 1.0, 7));
+                       double[][] B = TestUtils.floor(getRandomMatrix(r2, c2, 
-1, 1, 1.0, 13));
+
+                       MatrixWriter writer = 
MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+                       
writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(A), 
input(INPUT_NAME_1), r1, c1, bsize, r1*c1);
+                       
writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(B), 
input(INPUT_NAME_2), r2, c2, bsize, r2*c2);
+
+                       HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + 
".mtd"), Types.ValueType.FP64,
+                               new MatrixCharacteristics(r1, c1, bsize, 
r1*c1), Types.FileFormat.BINARY);
+                       HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 + 
".mtd"), Types.ValueType.FP64,
+                               new MatrixCharacteristics(r2, c2, bsize, 
r2*c2), Types.FileFormat.BINARY);
+
+
+                       programArgs = new String[] {"-explain", "-stats", 
"-ooc", "-args",
+                               input(INPUT_NAME_1), input(INPUT_NAME_2), 
output(OUTPUT_NAME)};
+
+                       if(r1 != r2){
+                               runTest(true,true, LanguageException.class,-1);
+                               return;
+                       }
+
+                       runTest(true, false, null, -1);
+                       Assert.assertTrue("OOC wasn't used for cbind",
+                               
heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.APPEND));
+
+                       // rerun without ooc flag
+                       programArgs = new String[] {"-explain", "-stats", 
"-args",
+                               input(INPUT_NAME_1), input(INPUT_NAME_2), 
output(OUTPUT_NAME + "_target")};
+                       runTest(true, false, null, -1);
+
+                       // compare results
+                       MatrixBlock ret1 = 
DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
+                               Types.FileFormat.BINARY, r1, c1+c2, bsize);
+                       MatrixBlock ret2 = 
DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"),
+                               Types.FileFormat.BINARY, r1, c1+c2, bsize);
+                       TestUtils.compareMatrices(ret1, ret2, eps);
+               }
+               catch(Exception ex) {
+                       Assert.fail(ex.getMessage());
+               }
+               finally {
+                       resetExecMode(platformOld);
+               }
+       }
+}
diff --git a/src/test/scripts/functions/ooc/CBindTest.dml 
b/src/test/scripts/functions/ooc/CBindTest.dml
new file mode 100644
index 0000000000..edfbddafc0
--- /dev/null
+++ b/src/test/scripts/functions/ooc/CBindTest.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1)
+B = read($2)
+res = cbind(A, B)
+
+write(res, $3, format="binary");

Reply via email to