This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new b8d373a889 [SYSTEMDS-3853] Fix ampute outer broadcasting and error handling
b8d373a889 is described below

commit b8d373a889963ca2845ced0f8d717a3d26295186
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Wed Apr 16 13:42:04 2025 +0200

    [SYSTEMDS-3853] Fix ampute outer broadcasting and error handling
    
    This patch fixes invalid left-hand-side (and combined left- and
    right-hand-side) broadcasting in the new ampute builtin function, which
    now computes the amputation mask via outer(..., "*") and no longer needs
    the `while (FALSE) {}` placeholders. The BinaryOp hop now also raises a
    proper error that tells script developers broadcasting is only supported
    from the right-hand side.
---
 scripts/builtin/ampute.dml                        | 7 +++----
 src/main/java/org/apache/sysds/hops/BinaryOp.java | 9 +++++++++
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/scripts/builtin/ampute.dml b/scripts/builtin/ampute.dml
index 7d96136b7c..691e5b48e2 100644
--- a/scripts/builtin/ampute.dml
+++ b/scripts/builtin/ampute.dml
@@ -72,8 +72,8 @@ m_ampute = function(Matrix[Double] X,
 
       # 4. Use probabilities to ampute pattern candidates:
       random = rand(rows=groupSize, cols=1, min=0, max=1, pdf="uniform", 
seed=seed)
-      amputeds = (random <= probs) * (1 - patterns[patternNum]) # Obtains 
matrix with 1's at indices to ampute.
-      while (FALSE) {} # FIX ME
+      # Obtains matrix with 1's at indices to ampute.
+      amputeds = outer((random <= probs), (1 - patterns[patternNum]), "*")
       groupSamples = groupSamples + replace(target=amputeds, pattern=1, 
replacement=NaN)
 
       # 5. Update output matrix:
@@ -241,7 +241,6 @@ return (Matrix[Double] groupAssignments, Matrix[Double] 
groupCounts) {
 
     for (i in 1:numGroups) {
       assigned = (random >= cumSum[i]) & (random < cumSum[i + 1])
-      while (FALSE) {} # FIX ME
       groupCounts[i] = sum(assigned)
       groupAssignments = groupAssignments + i * assigned
     }
@@ -308,4 +307,4 @@ return(Integer start, Integer end) {
     start = sum(numPerGroup[1:(patternNum - 1), ]) + 1
   }
   end = start + groupSize - 1
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java 
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 8d2b00c1aa..bbcb8b121b 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -1092,6 +1092,15 @@ public class BinaryOp extends MultiThreadedHop {
                                        }
                                        else //GENERAL CASE
                                        {
+                                               //check correct broadcasting 
dimensions
+                                               if( (input1.getDim1()==1 && 
input2.getDim1() > 1)
+                                                       || (input1.getDim2()==1 
&& input2.getDim2() > 1) )
+                                               {
+                                                       throw new 
HopsException("Invalid binary broadcasting from left: "
+                                                               + 
input1.getDataCharacteristics()+" "+getOp().name()+" "
+                                                               
+input2.getDataCharacteristics());
+                                               }
+                                               
                                                ldim1 = (input1.rowsKnown()) ? 
input1.getDim1()
                                                        : 
((input2.getDim1()>1)?input2.getDim1():-1);
                                                ldim2 = (input1.colsKnown()) ? 
input1.getDim2() 

Reply via email to