This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 7ebf913c17 [SYSTEMDS-3868] Fix missing function hoisting from if predicates
7ebf913c17 is described below

commit 7ebf913c17518190a82b216696b0a08c93ba2892
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Fri Apr 25 17:42:25 2025 +0200

    [SYSTEMDS-3868] Fix missing function hoisting from if predicates
    
    This patch adds the missing hoisting of DML function calls
    (which always need to bind to variables) out of basic if
    predicates, for convenience and to prevent unexpected
    errors. Furthermore, this patch simplifies the existing
    DML-bodied ampute() builtin by using this feature directly
    in its if predicates and by calling the existing sigmoid()
    builtin instead of a custom one.
---
 scripts/builtin/ampute.dml                                | 13 +++----------
 src/main/java/org/apache/sysds/parser/StatementBlock.java |  9 ++++++++-
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/scripts/builtin/ampute.dml b/scripts/builtin/ampute.dml
index 691e5b48e2..90557789dd 100644
--- a/scripts/builtin/ampute.dml
+++ b/scripts/builtin/ampute.dml
@@ -184,8 +184,7 @@ return (Matrix[Double] freq, Matrix[Double] patterns, 
Matrix[Double] weights) {
 u_handleDefaults = function(Matrix[Double] freq, Matrix[Double] patterns, 
Matrix[Double] weights, String mech, Integer numFeatures)
 return (Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights) {
   # Patterns: Default is a quadratic matrix wherein pattern i amputes feature 
i.
-  empty = u_isEmpty(patterns)
-  if (empty) { # FIX ME
+  if (u_isEmpty(patterns)) {
     patterns = matrix(1, rows=numFeatures, cols=numFeatures) - diag(matrix(1, 
rows=numFeatures, cols=1))
   }
 
@@ -205,8 +204,7 @@ return (Matrix[Double] freq, Matrix[Double] patterns, 
Matrix[Double] weights) {
   }
 
   # Frequencies: Uniform by default.
-  empty = u_isEmpty(freq) # FIX ME
-  if (empty) {
+  if (u_isEmpty(freq)) {
     freq = matrix(1 / numPatterns, rows=numPatterns, cols=1)
   }
 }
@@ -282,7 +280,7 @@ return (Matrix[Double] probsArray) {
   while (counter < maxIter & (is.na(currentProb) | abs(currentProb - prop) >= 
epsilon)) {
     counter += 1
     shift = lowerRange + (upperRange - lowerRange) / 2
-    probsArray = u_sigmoid(zScores + shift) # Calculates Right-Sigmoid 
probability (R implementation's default).
+    probsArray = sigmoid(zScores + shift) # Calculates Right-Sigmoid 
probability (R implementation's default).
     currentProb = mean(probsArray)
     if (currentProb - prop > 0) {
       upperRange = shift
@@ -293,11 +291,6 @@ return (Matrix[Double] probsArray) {
   }
 }
 
-u_sigmoid = function(Matrix[Double] X)
-return (Matrix[Double] sigmoided) {
-  sigmoided = 1 / (1 + exp(-X))
-}
-
 u_getBounds = function(Matrix[Double] numPerGroup, Integer groupSize, Integer 
patternNum)
 return(Integer start, Integer end) {
   if (patternNum == 1) {
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java 
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index b81a603e7c..2e62cc7f2e 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -503,7 +503,12 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                else if (current instanceof IfStatementBlock) {
                        IfStatementBlock isb = (IfStatementBlock) current;
                        IfStatement istmt = (IfStatement)isb.getStatement(0);
-                       //TODO handle predicates
+                       //handle predicate 
+                       ArrayList<Statement> tmpPred = new ArrayList<>();
+                       istmt.getConditionalPredicate().setPredicate(
+                               rHoistFunctionCallsFromExpressions(
+                                       
istmt.getConditionalPredicate().getPredicate(), false, tmpPred, prog));
+                       //handle if and else body
                        ArrayList<StatementBlock> tmp = new ArrayList<>();
                        for (StatementBlock sb : istmt.getIfBody())
                                
tmp.addAll(rHoistFunctionCallsFromExpressions(sb, prog));
@@ -514,6 +519,8 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                                        
tmp2.addAll(rHoistFunctionCallsFromExpressions(sb, prog));
                                istmt.setElseBody(tmp2);
                        }
+                       if( !tmpPred.isEmpty() )
+                               return createStatementBlocks(current, tmpPred);
                }
                else if (current instanceof ForStatementBlock) { //incl parfor
                        ForStatementBlock fsb = (ForStatementBlock) current;

Reply via email to