This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new c4166647bb [SYSTEMDS-2864] Fix opcode merge conflicts and lineage bug
c4166647bb is described below

commit c4166647bb952df54625510015a7fa32bd4d20fb
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Feb 1 14:08:24 2025 +0100

    [SYSTEMDS-2864] Fix opcode merge conflicts and lineage bug
    
    * revert the bad merge of the previous ctable modification
    * fix the handling of opcodes in the lineage program reconstruction
---
 .../instructions/cp/CtableCPInstruction.java       | 29 ++++++++++++++--------
 .../runtime/lineage/LineageRecomputeUtils.java     |  4 +--
 2 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
index 55d98481b6..4f508cd5b8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
@@ -31,6 +31,7 @@ import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.matrix.data.CTableMap;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.util.DataConverter;
 import org.apache.sysds.runtime.util.LongLongDoubleHashMap.EntryType;
@@ -40,21 +41,23 @@ public class CtableCPInstruction extends 
ComputationCPInstruction {
        private final CPOperand _outDim2;
        private final boolean _isExpand;
        private final boolean _ignoreZeros;
+       private final int _k;
 
        private CtableCPInstruction(CPOperand in1, CPOperand in2, CPOperand 
in3, CPOperand out,
                        String outputDim1, boolean dim1Literal, String 
outputDim2, boolean dim2Literal, boolean isExpand,
-                       boolean ignoreZeros, String opcode, String istr) {
+                       boolean ignoreZeros, String opcode, String istr, int k) 
{
                super(CPType.Ctable, null, in1, in2, in3, out, opcode, istr);
                _outDim1 = new CPOperand(outputDim1, ValueType.FP64, 
DataType.SCALAR, dim1Literal);
                _outDim2 = new CPOperand(outputDim2, ValueType.FP64, 
DataType.SCALAR, dim2Literal);
                _isExpand = isExpand;
                _ignoreZeros = ignoreZeros;
+               _k = k;
        }
 
        public static CtableCPInstruction parseInstruction(String inst)
        {
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(inst);
-               InstructionUtils.checkNumFields ( parts, 7 );
+               InstructionUtils.checkNumFields ( parts, 8 );
                
                String opcode = parts[0];
                
@@ -76,8 +79,12 @@ public class CtableCPInstruction extends 
ComputationCPInstruction {
                CPOperand out = new CPOperand(parts[6]);
                boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
 
+               int k = Integer.parseInt(parts[8]);
+               
                // ctable does not require any operator, so we simply pass-in a 
dummy operator with null functionobject
-               return new CtableCPInstruction(in1, in2, in3, out, 
dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]), dim2Fields[0], 
Boolean.parseBoolean(dim2Fields[1]), isExpand, ignoreZeros, opcode, inst);
+               return new CtableCPInstruction(in1, in2, in3, out, 
+                       dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]), 
dim2Fields[0],
+                       Boolean.parseBoolean(dim2Fields[1]), isExpand, 
ignoreZeros, opcode, inst, k);
        }
 
        private Ctable.OperationTypes findCtableOperation() {
@@ -89,8 +96,8 @@ public class CtableCPInstruction extends 
ComputationCPInstruction {
        
        @Override
        public void processInstruction(ExecutionContext ec) {
-               MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
-               MatrixBlock matBlock2=null, wtBlock=null;
+               MatrixBlock matBlock1 = !_isExpand ? ec.getMatrixInput(input1): 
null;
+               MatrixBlock matBlock2 = null, wtBlock=null;
                double cst1, cst2;
                
                CTableMap resultMap = new CTableMap(EntryType.INT);
@@ -111,9 +118,6 @@ public class CtableCPInstruction extends 
ComputationCPInstruction {
                        if( !sparse )
                                resultBlock = new MatrixBlock((int)outputDim1, 
(int)outputDim2, false); 
                }
-               if( _isExpand ){
-                       resultBlock = new MatrixBlock( matBlock1.getNumRows(), 
Integer.MAX_VALUE, true );
-               }
                
                switch(ctableOp) {
                        case CTABLE_TRANSFORM: //(VECTOR)
@@ -130,10 +134,13 @@ public class CtableCPInstruction extends 
ComputationCPInstruction {
                                break;
                        case CTABLE_EXPAND_SCALAR_WEIGHT: //(VECTOR)
                                // F = ctable(seq,A) or F = ctable(seq,B,1)
+                               // ignore first argument
+                               if(input1.getDataType() == DataType.MATRIX){
+                                       LOG.warn("rewrite for table expand not 
activated please fix");
+                               }
                                matBlock2 = ec.getMatrixInput(input2.getName());
                                cst1 = 
ec.getScalarInput(input3).getDoubleValue();
-                               // only resultBlock.rlen known, 
resultBlock.clen set in operation
-                               matBlock1.ctableSeqOperations(matBlock2, cst1, 
resultBlock);
+                               resultBlock = 
LibMatrixReorg.fusedSeqRexpand(matBlock2.getNumRows(), matBlock2, cst1, 
resultBlock, true, _k);
                                break;
                        case CTABLE_TRANSFORM_HISTOGRAM: //(VECTOR)
                                // F=ctable(A,1) or F = ctable(A,1,1)
@@ -152,7 +159,7 @@ public class CtableCPInstruction extends 
ComputationCPInstruction {
                                throw new DMLRuntimeException("Encountered an 
invalid ctable operation ("+ctableOp+") while executing instruction: " + 
this.toString());
                }
                
-               if(input1.getDataType() == DataType.MATRIX)
+               if(input1.getDataType() == DataType.MATRIX && ctableOp != 
Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT)
                        ec.releaseMatrixInput(input1.getName());
                if(input2.getDataType() == DataType.MATRIX)
                        ec.releaseMatrixInput(input2.getName());
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
index c3bd095aca..e64c742888 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
@@ -88,7 +88,7 @@ public class LineageRecomputeUtils {
        public static Data parseNComputeLineageTrace(String mainTrace) {
                if (DEBUG)
                        System.out.println(mainTrace);
-
+               
                // Separate the global trace and the dedup patches
                String[] patches = 
LineageParser.separateMainAndDedupPatches(mainTrace);
                LineageItem root = LineageParser.parseLineageTrace(patches[0]); 
//global trace
@@ -307,7 +307,7 @@ public class LineageRecomputeUtils {
                                break;
                        }
                        case Instruction: {
-                               CPType ctype = 
InstructionUtils.getCPTypeByOpcode(item.getOpcode());
+                               CPType ctype = 
Opcodes.getCPTypeByOpcode(item.getOpcode());
                                SPType stype = 
InstructionUtils.getSPTypeByOpcode(item.getOpcode());
                                
                                if (ctype != null) {

Reply via email to