This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 61afba5d0f [SYSTEMDS-3888] Fix size propagation over unique operations 61afba5d0f is described below commit 61afba5d0ffe8e1cbf9f3e4b956d57f0bc3997b4 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Fri Jun 6 12:25:38 2025 +0200 [SYSTEMDS-3888] Fix size propagation over unique operations This patch fixes the incorrect size propagation of unique, which led to incorrect results when its output dimensions were used in subsequent operations. Thanks to Chi-Hsin Huang for catching this bug. Furthermore, this patch also includes minor updates for code quality (removed unused imports, annotated unused functions). --- .../java/org/apache/sysds/hops/AggUnaryOp.java | 36 +++++++++++++++++----- .../sysds/hops/estim/EstimatorLayeredGraph.java | 8 +++-- .../RewriteQuantizationFusedCompression.java | 2 -- .../ParameterizedBuiltinFunctionExpression.java | 20 ++++++------ .../test/functions/misc/SizePropagationTest.java | 15 +++++++++ .../functions/misc/SizePropagationUnique.dml | 28 +++++++++++++++++ 6 files changed, 88 insertions(+), 21 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java index 954114a0a4..0b2d62bbe3 100644 --- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java @@ -323,10 +323,18 @@ public class AggUnaryOp extends MultiThreadedHop DataCharacteristics ret = null; Hop input = getInput().get(0); DataCharacteristics dc = memo.getAllInputStats(input); - if( _direction == Direction.Col && dc.colsKnown() ) - ret = new MatrixCharacteristics(1, dc.getCols(), -1, -1); - else if( _direction == Direction.Row && dc.rowsKnown() ) - ret = new MatrixCharacteristics(dc.getRows(), 1, -1, -1); + if( _op == AggOp.UNIQUE ) { + if( _direction == Direction.RowCol && dc.rowsKnown() ) + ret = new MatrixCharacteristics(dc.getRows(), 1, -1, -1); + else + ret = new
MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, -1); + } + else { + if( _direction == Direction.Col && dc.colsKnown() ) + ret = new MatrixCharacteristics(1, dc.getCols(), -1, -1); + else if( _direction == Direction.Row && dc.rowsKnown() ) + ret = new MatrixCharacteristics(dc.getRows(), 1, -1, -1); + } return ret; } @@ -648,9 +656,23 @@ public class AggUnaryOp extends MultiThreadedHop @Override public void refreshSizeInformation() { - if (getDataType() != DataType.SCALAR) - { - Hop input = getInput().get(0); + Hop input = getInput().get(0); + if( _op == AggOp.UNIQUE ) { + if ( _direction == Direction.Col ) { + setDim1(-1); //unknown num unique + setDim2(input.getDim2()); + } + else if ( _direction == Direction.Row ) { + setDim1(input.getDim1()); + setDim2(-1); //unknown num unique + } + else { + setDim1(-1); + setDim2(1); + } + } + //general case: all other unary aggregations + else if (getDataType() != DataType.SCALAR) { if ( _direction == Direction.Col ) //colwise computations { setDim1(1); diff --git a/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java b/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java index f997db6503..1fbdb1fd46 100644 --- a/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java +++ b/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java @@ -57,9 +57,9 @@ public class EstimatorLayeredGraph extends SparsityEstimator { @Override public DataCharacteristics estim(MMNode root) { - List<MatrixBlock> leafs = getMatrices(root, new ArrayList<>()); - List<OpCode> ops = getOps(root, new ArrayList<>()); - List<LayeredGraph> LGs = new ArrayList<>(); + //List<MatrixBlock> leafs = getMatrices(root, new ArrayList<>()); + //List<OpCode> ops = getOps(root, new ArrayList<>()); + //List<LayeredGraph> LGs = new ArrayList<>(); LayeredGraph ret = traverse(root); long nnz = ret.estimateNnz(); return root.setDataCharacteristics(new MatrixCharacteristics( @@ -125,6 +125,7 @@ public class 
EstimatorLayeredGraph extends SparsityEstimator { } } + @SuppressWarnings("unused") private List<MatrixBlock> getMatrices(MMNode node, List<MatrixBlock> leafs) { //NOTE: this extraction is only correct and efficient for chains, no DAGs if( node.isLeaf() ) @@ -136,6 +137,7 @@ public class EstimatorLayeredGraph extends SparsityEstimator { return leafs; } + @SuppressWarnings("unused") private List<OpCode> getOps(MMNode node, List<OpCode> ops) { //NOTE: this extraction is only correct and efficient for chains, no DAGs if(node.isLeaf()) { diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java index f29d1dce81..1ff5e086ce 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java @@ -27,8 +27,6 @@ import java.util.Map.Entry; import org.apache.sysds.common.Types.OpOp1; import org.apache.sysds.common.Types.OpOp2; import org.apache.sysds.hops.UnaryOp; -import org.apache.sysds.runtime.instructions.cp.DoubleObject; -import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.hops.BinaryOp; import org.apache.sysds.common.Types.DataType; diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java index 4ee92e783b..314440628e 100644 --- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java @@ -562,24 +562,26 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier private void validateUniqueAggregationDirection(Identifier dataId, DataIdentifier output) { HashMap<String, Expression> varParams = getVarParams(); + String inputDirection = 
Types.Direction.RowCol.toString(); if (varParams.containsKey("dir")) { - String inputDirectionString = varParams.get("dir").toString().toUpperCase(); - + inputDirection = varParams.get("dir").toString().toUpperCase(); // unrecognized value for "dir" parameter - if (!inputDirectionString.equals(Types.Direction.Row.toString()) - && !inputDirectionString.equals(Types.Direction.Col.toString()) - && !inputDirectionString.equals(Types.Direction.RowCol.toString())) { - raiseValidateError("Invalid argument: " + inputDirectionString + " is not recognized"); + if (!inputDirection.equals(Types.Direction.Row.toString()) + && !inputDirection.equals(Types.Direction.Col.toString()) + && !inputDirection.equals(Types.Direction.RowCol.toString())) { + raiseValidateError("Invalid argument: " + inputDirection + " is not recognized"); } } - // rc/r/c -> unique return value is the same as the input in the worst case // default to dir="rc" output.setDataType(DataType.MATRIX); - output.setDimensions(dataId.getDim1(), dataId.getDim2()); + output.setDimensions( + inputDirection.equals(Types.Direction.Row.toString()) ? dataId.getDim1() : -1, + inputDirection.equals(Types.Direction.Col.toString()) ? dataId.getDim2() : + inputDirection.equals(Types.Direction.RowCol.toString()) ? 
1 : -1); output.setBlocksize(dataId.getBlocksize()); output.setValueType(ValueType.FP64); - output.setNnz(dataId.getNnz()); + output.setNnz(-1); } private void checkStringParam(boolean optional, String fname, String pname, boolean conditional) { diff --git a/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java b/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java index 4b4a76aa19..9d9fa59bc9 100644 --- a/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java @@ -27,6 +27,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; import org.junit.Assert; import java.util.HashMap; @@ -38,6 +39,7 @@ public class SizePropagationTest extends AutomatedTestBase private static final String TEST_NAME3 = "SizePropagationLoopIx2"; private static final String TEST_NAME4 = "SizePropagationLoopIx3"; private static final String TEST_NAME5 = "SizePropagationLoopIx4"; + private static final String TEST_NAME6 = "SizePropagationUnique"; private static final String TEST_DIR = "functions/misc/"; private static final String TEST_CLASS_DIR = TEST_DIR + SizePropagationTest.class.getSimpleName() + "/"; @@ -52,6 +54,7 @@ public class SizePropagationTest extends AutomatedTestBase addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) ); addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) ); addTestConfiguration( TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "R" }) ); } @Test @@ -104,6 +107,16 @@ public class 
SizePropagationTest extends AutomatedTestBase testSizePropagation( TEST_NAME5, true, N ); } + @Test + public void testSizePropagationUnique1() { + testSizePropagation( TEST_NAME6, false, 10 ); + } + + @Test + public void testSizePropagationUnique2() { + testSizePropagation( TEST_NAME6, false, 10 ); + } + private void testSizePropagation( String testname, boolean rewrites, int expect ) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; ExecMode oldPlatform = rtplatform; @@ -122,6 +135,8 @@ public class SizePropagationTest extends AutomatedTestBase runTest(true, false, null, -1); HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R"); Assert.assertEquals(Double.valueOf(expect), dmlfile.get(new CellIndex(1,1))); + if( testname.equals(TEST_NAME6) ) + Assert.assertEquals(0, Statistics.getNoOfCompiledSPInst()); } finally { OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; diff --git a/src/test/scripts/functions/misc/SizePropagationUnique.dml b/src/test/scripts/functions/misc/SizePropagationUnique.dml new file mode 100644 index 0000000000..803cb949ed --- /dev/null +++ b/src/test/scripts/functions/misc/SizePropagationUnique.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = matrix("1 2 3 4 5 6 7", rows=7,cols=1) +B = matrix("4 5 6 7 8 9 10", rows=7,cols=1) +C = rbind(A,B) +D = unique(C) +n = nrow(D); +R = as.matrix(n); +write(R, $2);