This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new b8d373a889 [SYSTEMDS-3853] Fix ampute outer broadcasting and error handling b8d373a889 is described below commit b8d373a889963ca2845ced0f8d717a3d26295186 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Wed Apr 16 13:42:04 2025 +0200 [SYSTEMDS-3853] Fix ampute outer broadcasting and error handling This patch fixes invalid left-hand-side broadcasting, as well as combined left- and right-hand-side broadcasting, in the new ampute builtin function. The hop now has proper error handling that tells script developers that broadcasting is only supported from the right-hand side. --- scripts/builtin/ampute.dml | 7 +++---- src/main/java/org/apache/sysds/hops/BinaryOp.java | 9 +++++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/scripts/builtin/ampute.dml b/scripts/builtin/ampute.dml index 7d96136b7c..691e5b48e2 100644 --- a/scripts/builtin/ampute.dml +++ b/scripts/builtin/ampute.dml @@ -72,8 +72,8 @@ m_ampute = function(Matrix[Double] X, # 4. Use probabilities to ampute pattern candidates: random = rand(rows=groupSize, cols=1, min=0, max=1, pdf="uniform", seed=seed) - amputeds = (random <= probs) * (1 - patterns[patternNum]) # Obtains matrix with 1's at indices to ampute. - while (FALSE) {} # FIX ME + # Obtains matrix with 1's at indices to ampute. + amputeds = outer((random <= probs), (1 - patterns[patternNum]), "*") groupSamples = groupSamples + replace(target=amputeds, pattern=1, replacement=NaN) # 5. 
Update output matrix: @@ -241,7 +241,6 @@ return (Matrix[Double] groupAssignments, Matrix[Double] groupCounts) { for (i in 1:numGroups) { assigned = (random >= cumSum[i]) & (random < cumSum[i + 1]) - while (FALSE) {} # FIX ME groupCounts[i] = sum(assigned) groupAssignments = groupAssignments + i * assigned } @@ -308,4 +307,4 @@ return(Integer start, Integer end) { start = sum(numPerGroup[1:(patternNum - 1), ]) + 1 } end = start + groupSize - 1 -} \ No newline at end of file +} diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index 8d2b00c1aa..bbcb8b121b 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -1092,6 +1092,15 @@ public class BinaryOp extends MultiThreadedHop { } else //GENERAL CASE { + //check correct broadcasting dimensions + if( (input1.getDim1()==1 && input2.getDim1() > 1) + || (input1.getDim2()==1 && input2.getDim2() > 1) ) + { + throw new HopsException("Invalid binary broadcasting from left: " + + input1.getDataCharacteristics()+" "+getOp().name()+" " + +input2.getDataCharacteristics()); + } + ldim1 = (input1.rowsKnown()) ? input1.getDim1() : ((input2.getDim1()>1)?input2.getDim1():-1); ldim2 = (input1.colsKnown()) ? input1.getDim2()