This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 7ebf913c17 [SYSTEMDS-3868] Fix missing function hoisting from if predicates 7ebf913c17 is described below commit 7ebf913c17518190a82b216696b0a08c93ba2892 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Fri Apr 25 17:42:25 2025 +0200 [SYSTEMDS-3868] Fix missing function hoisting from if predicates This patch adds the missing hoisting of DML function calls (which always need to bind to variables) from basic if predicates for convenience and in order to prevent unexpected errors. Furthermore, this patch simplifies the existing DML-bodied ampute() builtin by using this feature as well as calling the existing sigmoid() instead of a custom one. --- scripts/builtin/ampute.dml | 13 +++---------- src/main/java/org/apache/sysds/parser/StatementBlock.java | 9 ++++++++- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/scripts/builtin/ampute.dml b/scripts/builtin/ampute.dml index 691e5b48e2..90557789dd 100644 --- a/scripts/builtin/ampute.dml +++ b/scripts/builtin/ampute.dml @@ -184,8 +184,7 @@ return (Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights) { u_handleDefaults = function(Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights, String mech, Integer numFeatures) return (Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights) { # Patterns: Default is a quadratic matrix wherein pattern i amputes feature i. - empty = u_isEmpty(patterns) - if (empty) { # FIX ME + if (u_isEmpty(patterns)) { patterns = matrix(1, rows=numFeatures, cols=numFeatures) - diag(matrix(1, rows=numFeatures, cols=1)) } @@ -205,8 +204,7 @@ return (Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights) { } # Frequencies: Uniform by default. 
- empty = u_isEmpty(freq) # FIX ME - if (empty) { + if (u_isEmpty(freq)) { freq = matrix(1 / numPatterns, rows=numPatterns, cols=1) } } @@ -282,7 +280,7 @@ return (Matrix[Double] probsArray) { while (counter < maxIter & (is.na(currentProb) | abs(currentProb - prop) >= epsilon)) { counter += 1 shift = lowerRange + (upperRange - lowerRange) / 2 - probsArray = u_sigmoid(zScores + shift) # Calculates Right-Sigmoid probability (R implementation's default). + probsArray = sigmoid(zScores + shift) # Calculates Right-Sigmoid probability (R implementation's default). currentProb = mean(probsArray) if (currentProb - prop > 0) { upperRange = shift @@ -293,11 +291,6 @@ return (Matrix[Double] probsArray) { } } -u_sigmoid = function(Matrix[Double] X) -return (Matrix[Double] sigmoided) { - sigmoided = 1 / (1 + exp(-X)) -} - u_getBounds = function(Matrix[Double] numPerGroup, Integer groupSize, Integer patternNum) return(Integer start, Integer end) { if (patternNum == 1) { diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java b/src/main/java/org/apache/sysds/parser/StatementBlock.java index b81a603e7c..2e62cc7f2e 100644 --- a/src/main/java/org/apache/sysds/parser/StatementBlock.java +++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java @@ -503,7 +503,12 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo else if (current instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) current; IfStatement istmt = (IfStatement)isb.getStatement(0); - //TODO handle predicates + //handle predicate + ArrayList<Statement> tmpPred = new ArrayList<>(); + istmt.getConditionalPredicate().setPredicate( + rHoistFunctionCallsFromExpressions( + istmt.getConditionalPredicate().getPredicate(), false, tmpPred, prog)); + //handle if and else body ArrayList<StatementBlock> tmp = new ArrayList<>(); for (StatementBlock sb : istmt.getIfBody()) tmp.addAll(rHoistFunctionCallsFromExpressions(sb, prog)); @@ -514,6 +519,8 @@ public class 
StatementBlock extends LiveVariableAnalysis implements ParseInfo tmp2.addAll(rHoistFunctionCallsFromExpressions(sb, prog)); istmt.setElseBody(tmp2); } + if( !tmpPred.isEmpty() ) + return createStatementBlocks(current, tmpPred); } else if (current instanceof ForStatementBlock) { //incl parfor ForStatementBlock fsb = (ForStatementBlock) current;