This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 61afba5d0f [SYSTEMDS-3888] Fix size propagation over unique operations
61afba5d0f is described below

commit 61afba5d0ffe8e1cbf9f3e4b956d57f0bc3997b4
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Fri Jun 6 12:25:38 2025 +0200

    [SYSTEMDS-3888] Fix size propagation over unique operations
    
    This patch fixes the incorrect size propagation of unique(), whose
    output previously inherited the input dimensions; this led to
    incorrect results whenever the output dimensions were used in
    subsequent operations. Thanks to Chi-Hsin Huang for catching this bug.
    
    Furthermore, this patch includes minor code-quality updates
    (removed unused imports, commented out unused locals, and annotated
    unused functions).
---
 .../java/org/apache/sysds/hops/AggUnaryOp.java     | 36 +++++++++++++++++-----
 .../sysds/hops/estim/EstimatorLayeredGraph.java    |  8 +++--
 .../RewriteQuantizationFusedCompression.java       |  2 --
 .../ParameterizedBuiltinFunctionExpression.java    | 20 ++++++------
 .../test/functions/misc/SizePropagationTest.java   | 15 +++++++++
 .../functions/misc/SizePropagationUnique.dml       | 28 +++++++++++++++++
 6 files changed, 88 insertions(+), 21 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java 
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 954114a0a4..0b2d62bbe3 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -323,10 +323,18 @@ public class AggUnaryOp extends MultiThreadedHop
                DataCharacteristics ret = null;
                Hop input = getInput().get(0);
                DataCharacteristics dc = memo.getAllInputStats(input);
-               if( _direction == Direction.Col && dc.colsKnown() )
-                       ret = new MatrixCharacteristics(1, dc.getCols(), -1, 
-1);
-               else if( _direction == Direction.Row && dc.rowsKnown() )
-                       ret = new MatrixCharacteristics(dc.getRows(), 1, -1, 
-1);
+               if( _op == AggOp.UNIQUE ) {
+                       if( _direction == Direction.RowCol && dc.rowsKnown() )
+                               ret = new MatrixCharacteristics(dc.getRows(), 
1, -1, -1);
+                       else
+                               ret = new MatrixCharacteristics(dc.getRows(), 
dc.getCols(), -1, -1);
+               }
+               else {
+                       if( _direction == Direction.Col && dc.colsKnown() )
+                               ret = new MatrixCharacteristics(1, 
dc.getCols(), -1, -1);
+                       else if( _direction == Direction.Row && dc.rowsKnown() )
+                               ret = new MatrixCharacteristics(dc.getRows(), 
1, -1, -1);
+               }
                return ret;
        }
        
@@ -648,9 +656,23 @@ public class AggUnaryOp extends MultiThreadedHop
        @Override
        public void refreshSizeInformation()
        {
-               if (getDataType() != DataType.SCALAR)
-               {
-                       Hop input = getInput().get(0);
+               Hop input = getInput().get(0);
+               if( _op == AggOp.UNIQUE ) {
+                       if ( _direction == Direction.Col ) {
+                               setDim1(-1); //unknown num unique
+                               setDim2(input.getDim2());
+                       }
+                       else if ( _direction == Direction.Row ) {
+                               setDim1(input.getDim1());
+                               setDim2(-1); //unknown num unique
+                       }
+                       else {
+                               setDim1(-1);
+                               setDim2(1);
+                       }
+               }
+               //general case: all other unary aggregations 
+               else if (getDataType() != DataType.SCALAR) {
                        if ( _direction == Direction.Col ) //colwise 
computations
                        {
                                setDim1(1);
diff --git 
a/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java 
b/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java
index f997db6503..1fbdb1fd46 100644
--- a/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java
+++ b/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java
@@ -57,9 +57,9 @@ public class EstimatorLayeredGraph extends SparsityEstimator {
        
        @Override
        public DataCharacteristics estim(MMNode root) {
-               List<MatrixBlock> leafs = getMatrices(root, new ArrayList<>());
-               List<OpCode> ops = getOps(root, new ArrayList<>());
-               List<LayeredGraph> LGs = new ArrayList<>();
+               //List<MatrixBlock> leafs = getMatrices(root, new 
ArrayList<>());
+               //List<OpCode> ops = getOps(root, new ArrayList<>());
+               //List<LayeredGraph> LGs = new ArrayList<>();
                LayeredGraph ret = traverse(root);
                long nnz = ret.estimateNnz();
                return root.setDataCharacteristics(new MatrixCharacteristics(
@@ -125,6 +125,7 @@ public class EstimatorLayeredGraph extends 
SparsityEstimator {
                }
        }
        
+       @SuppressWarnings("unused")
        private List<MatrixBlock> getMatrices(MMNode node, List<MatrixBlock> 
leafs) {
                //NOTE: this extraction is only correct and efficient for 
chains, no DAGs
                if( node.isLeaf() )
@@ -136,6 +137,7 @@ public class EstimatorLayeredGraph extends 
SparsityEstimator {
                return leafs;
        }
 
+       @SuppressWarnings("unused")
        private List<OpCode> getOps(MMNode node, List<OpCode> ops) {
                //NOTE: this extraction is only correct and efficient for 
chains, no DAGs
                if(node.isLeaf()) {
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java
index f29d1dce81..1ff5e086ce 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java
@@ -27,8 +27,6 @@ import java.util.Map.Entry;
 import org.apache.sysds.common.Types.OpOp1;
 import org.apache.sysds.common.Types.OpOp2;
 import org.apache.sysds.hops.UnaryOp;
-import org.apache.sysds.runtime.instructions.cp.DoubleObject;
-import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.hops.BinaryOp;
 
 import org.apache.sysds.common.Types.DataType;
diff --git 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 4ee92e783b..314440628e 100644
--- 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -562,24 +562,26 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
 
        private void validateUniqueAggregationDirection(Identifier dataId, 
DataIdentifier output) {
                HashMap<String, Expression> varParams = getVarParams();
+               String inputDirection = Types.Direction.RowCol.toString();
                if (varParams.containsKey("dir")) {
-                       String inputDirectionString = 
varParams.get("dir").toString().toUpperCase();
-
+                       inputDirection = 
varParams.get("dir").toString().toUpperCase();
                        // unrecognized value for "dir" parameter
-                       if 
(!inputDirectionString.equals(Types.Direction.Row.toString())
-                                       && 
!inputDirectionString.equals(Types.Direction.Col.toString())
-                                       && 
!inputDirectionString.equals(Types.Direction.RowCol.toString())) {
-                               raiseValidateError("Invalid argument: " + 
inputDirectionString + " is not recognized");
+                       if 
(!inputDirection.equals(Types.Direction.Row.toString())
+                                       && 
!inputDirection.equals(Types.Direction.Col.toString())
+                                       && 
!inputDirection.equals(Types.Direction.RowCol.toString())) {
+                               raiseValidateError("Invalid argument: " + 
inputDirection + " is not recognized");
                        }
                }
 
-               // rc/r/c -> unique return value is the same as the input in 
the worst case
                // default to dir="rc"
                output.setDataType(DataType.MATRIX);
-               output.setDimensions(dataId.getDim1(), dataId.getDim2());
+               output.setDimensions(
+                       inputDirection.equals(Types.Direction.Row.toString()) ? 
dataId.getDim1() : -1,
+                       inputDirection.equals(Types.Direction.Col.toString()) ? 
dataId.getDim2() : 
+                               
inputDirection.equals(Types.Direction.RowCol.toString()) ? 1 : -1);
                output.setBlocksize(dataId.getBlocksize());
                output.setValueType(ValueType.FP64);
-               output.setNnz(dataId.getNnz());
+               output.setNnz(-1);
        }
 
        private void checkStringParam(boolean optional, String fname, String 
pname, boolean conditional) {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java 
b/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java
index 4b4a76aa19..9d9fa59bc9 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java
@@ -27,6 +27,7 @@ import 
org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
 import org.junit.Assert;
 
 import java.util.HashMap;
@@ -38,6 +39,7 @@ public class SizePropagationTest extends AutomatedTestBase
        private static final String TEST_NAME3 = "SizePropagationLoopIx2";
        private static final String TEST_NAME4 = "SizePropagationLoopIx3";
        private static final String TEST_NAME5 = "SizePropagationLoopIx4";
+       private static final String TEST_NAME6 = "SizePropagationUnique";
        
        private static final String TEST_DIR = "functions/misc/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
SizePropagationTest.class.getSimpleName() + "/";
@@ -52,6 +54,7 @@ public class SizePropagationTest extends AutomatedTestBase
                addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME5, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME6, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "R" }) );
        }
 
        @Test
@@ -104,6 +107,16 @@ public class SizePropagationTest extends AutomatedTestBase
                testSizePropagation( TEST_NAME5, true, N );
        }
        
+       @Test
+       public void testSizePropagationUnique1() {
+               testSizePropagation( TEST_NAME6, false, 10 ); //w/o rewrites
+       }
+       
+       @Test
+       public void testSizePropagationUnique2() {
+               testSizePropagation( TEST_NAME6, true, 10 ); //w/ rewrites
+       }
+       
        private void testSizePropagation( String testname, boolean rewrites, 
int expect ) {
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                ExecMode oldPlatform = rtplatform;
@@ -122,6 +135,8 @@ public class SizePropagationTest extends AutomatedTestBase
                        runTest(true, false, null, -1); 
                        HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("R");
                        Assert.assertEquals(Double.valueOf(expect), 
dmlfile.get(new CellIndex(1,1)));
+                       if( testname.equals(TEST_NAME6) )
+                               Assert.assertEquals(0, 
Statistics.getNoOfCompiledSPInst());
                }
                finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
diff --git a/src/test/scripts/functions/misc/SizePropagationUnique.dml 
b/src/test/scripts/functions/misc/SizePropagationUnique.dml
new file mode 100644
index 0000000000..803cb949ed
--- /dev/null
+++ b/src/test/scripts/functions/misc/SizePropagationUnique.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = matrix("1 2 3 4 5 6 7", rows=7,cols=1)  # regression test for nrow() over unique(): values 1..7
+B = matrix("4 5 6 7 8 9 10", rows=7,cols=1) # values 4..10, overlapping A in 4..7
+C = rbind(A,B)                              # 14x1 input containing duplicates
+D = unique(C)                               # default dir="rc": 10 distinct values (1..10)
+n = nrow(D);                                # must be 10, not the input row count 14
+R = as.matrix(n);
+write(R, $2);

Reply via email to