Repository: systemml Updated Branches: refs/heads/master 367935fad -> c2dd05e51
[SYSTEMML-2093] Parser and language support for new += operator. This patch adds parser and language support for the new += accumulation assignment operator. The semantics of A += B are equivalent to A = A + B, but this operator will receive special treatment for parfor result variables (with sum as result aggregation function). This patch also modifies a number of existing algorithms accordingly to integrate this operator into our test suite. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1c7ecbf2 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1c7ecbf2 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1c7ecbf2 Branch: refs/heads/master Commit: 1c7ecbf256c249d48aa1664db9edf726156cb014 Parents: 367935f Author: Matthias Boehm <[email protected]> Authored: Sat Jan 27 16:16:11 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Sat Jan 27 16:16:11 2018 -0800 ---------------------------------------------------------------------- scripts/algorithms/LinearRegCG.dml | 6 ++-- scripts/algorithms/l2-svm.dml | 5 ++-- scripts/algorithms/m-svm.dml | 4 +-- .../sysml/parser/AssignmentStatement.java | 30 +++++++++++--------- .../org/apache/sysml/parser/DMLTranslator.java | 26 +++++++++++------ .../org/apache/sysml/parser/Identifier.java | 1 - .../java/org/apache/sysml/parser/dml/Dml.g4 | 1 + .../sysml/parser/dml/DmlPreprocessor.java | 7 +++++ .../sysml/parser/dml/DmlSyntacticValidator.java | 16 +++++++++++ 9 files changed, 64 insertions(+), 32 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/scripts/algorithms/LinearRegCG.dml ---------------------------------------------------------------------- diff --git a/scripts/algorithms/LinearRegCG.dml b/scripts/algorithms/LinearRegCG.dml index 08cbaef..656d261 100644 --- a/scripts/algorithms/LinearRegCG.dml +++ 
b/scripts/algorithms/LinearRegCG.dml @@ -186,10 +186,10 @@ while (i < max_iteration & norm_r2 > norm_r2_target) q = scale_X * q + shift_X %*% q [m_ext, ]; } - q = q + lambda * p; + q += lambda * p; a = norm_r2 / sum (p * q); - beta_unscaled = beta_unscaled + a * p; - r = r + a * q; + beta_unscaled += a * p; + r += a * q; old_norm_r2 = norm_r2; norm_r2 = sum (r ^ 2); p = -r + (norm_r2 / old_norm_r2) * p; http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/scripts/algorithms/l2-svm.dml ---------------------------------------------------------------------- diff --git a/scripts/algorithms/l2-svm.dml b/scripts/algorithms/l2-svm.dml index 2446610..cf669b5 100644 --- a/scripts/algorithms/l2-svm.dml +++ b/scripts/algorithms/l2-svm.dml @@ -138,8 +138,8 @@ while(continue & iter < maxiterations) { } #update weights - w = w + step_sz*s - Xw = Xw + step_sz*Xd + w += step_sz * s + Xw += step_sz * Xd out = 1 - Y * Xw sv = (out > 0) @@ -159,7 +159,6 @@ while(continue & iter < maxiterations) { continue = (step_sz*tmp >= epsilon*obj & sum(s^2) != 0); iter = iter + 1 - } extra_model_params = matrix(0, rows=4, cols=1) http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/scripts/algorithms/m-svm.dml ---------------------------------------------------------------------- diff --git a/scripts/algorithms/m-svm.dml b/scripts/algorithms/m-svm.dml index 253764c..024aeda 100644 --- a/scripts/algorithms/m-svm.dml +++ b/scripts/algorithms/m-svm.dml @@ -148,8 +148,8 @@ parfor(iter_class in 1:num_classes){ } #update weights - w_class = w_class + step_sz*s - Xw = Xw + step_sz*Xd + w_class += step_sz * s + Xw += step_sz * Xd out = 1 - Y_local * Xw sv = (out > 0) http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/src/main/java/org/apache/sysml/parser/AssignmentStatement.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/AssignmentStatement.java 
b/src/main/java/org/apache/sysml/parser/AssignmentStatement.java index 9195fbd..3525d13 100644 --- a/src/main/java/org/apache/sysml/parser/AssignmentStatement.java +++ b/src/main/java/org/apache/sysml/parser/AssignmentStatement.java @@ -30,6 +30,7 @@ public class AssignmentStatement extends Statement { private ArrayList<DataIdentifier> _targetList; private Expression _source; + private boolean _isAccum; //+= // rewrites statement to support function inlining (creates deep copy) @Override @@ -66,18 +67,26 @@ public class AssignmentStatement extends Statement return _targetList.get(0); } - public ArrayList<DataIdentifier> getTargetList() - { + public ArrayList<DataIdentifier> getTargetList() { return _targetList; } public Expression getSource(){ return _source; } + public void setSource(Expression s){ _source = s; } + public boolean isAccumulator() { + return _isAccum; + } + + public void setAccumulator(boolean flag) { + _isAccum = flag; + } + @Override public boolean controlStatement() { // ensure that breakpoints end up in own statement block @@ -106,29 +115,22 @@ public class AssignmentStatement extends Statement @Override public VariableSet variablesRead() { VariableSet result = new VariableSet(); - // add variables read by source expression result.addVariables(_source.variablesRead()); - - // for LHS IndexedIdentifier, add variables for indexing expressions - for (int i=0; i<_targetList.size(); i++){ - if (_targetList.get(i) instanceof IndexedIdentifier) { - IndexedIdentifier target = (IndexedIdentifier) _targetList.get(i); + // for left indexing or accumulators add targets as well + for (DataIdentifier target : _targetList) + if (target instanceof IndexedIdentifier || _isAccum ) result.addVariables(target.variablesRead()); - } - } return result; } @Override public VariableSet variablesUpdated() { VariableSet result = new VariableSet(); - // add target to updated list for (DataIdentifier target : _targetList) - if (target != null) { + if (target != null) 
result.addVariable(target.getName(), target); - } return result; } @@ -139,7 +141,7 @@ public class AssignmentStatement extends Statement DataIdentifier di = _targetList.get(i); sb.append(di); } - sb.append(" = "); + sb.append(_isAccum ? " += " : " = "); if (_source instanceof StringIdentifier) { sb.append("\""); sb.append(_source.toString()); http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 9a47d09..faee84b 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -1264,10 +1264,21 @@ public class DMLTranslator // CASE: target is regular data identifier if (!(target instanceof IndexedIdentifier)) { - + //process right hand side and accumulation Hop ae = processExpression(source, target, ids); + if( ((AssignmentStatement)current).isAccumulator() ) { + DataIdentifier accum = liveIn.getVariable(target.getName()); + if( accum == null ) + throw new LanguageException("Invalid accumulator assignment " + + "to non-existing variable "+target.getName()+"."); + ae = HopRewriteUtils.createBinary(ids.get(target.getName()), ae, OpOp2.PLUS); + target.setProperties(accum.getOutput()); + } + else + target.setProperties(source.getOutput()); ids.put(target.getName(), ae); - target.setProperties(source.getOutput()); + + //add transient write if needed Integer statementId = liveOutToTemp.get(target.getName()); if ((statementId != null) && (statementId.intValue() == i)) { DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, DataOpTypes.TRANSIENTWRITE, null); @@ -1276,8 +1287,7 @@ public class DMLTranslator updatedLiveOut.addVariable(target.getName(), target); output.add(transientwrite); } - 
} // end if (!(target instanceof IndexedIdentifier)) { - + } // CASE: target is indexed identifier (left-hand side indexed expression) else { Hop ae = processLeftIndexedExpression(source, (IndexedIdentifier)target, ids); @@ -1287,12 +1297,12 @@ public class DMLTranslator // obtain origDim values BEFORE they are potentially updated during setProperties call // (this is incorrect for LHS Indexing) long origDim1 = ((IndexedIdentifier)target).getOrigDim1(); - long origDim2 = ((IndexedIdentifier)target).getOrigDim2(); + long origDim2 = ((IndexedIdentifier)target).getOrigDim2(); target.setProperties(source.getOutput()); ((IndexedIdentifier)target).setOriginalDimensions(origDim1, origDim2); // preserve data type matrix of any index identifier - // (required for scalar input to left indexing) + // (required for scalar input to left indexing) if( target.getDataType() != DataType.MATRIX ) { target.setDataType(DataType.MATRIX); target.setValueType(ValueType.DOUBLE); @@ -1308,10 +1318,8 @@ public class DMLTranslator output.add(transientwrite); } } - - } - else + else { //assignment, function call FunctionCallIdentifier fci = (FunctionCallIdentifier) source; http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/src/main/java/org/apache/sysml/parser/Identifier.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Identifier.java b/src/main/java/org/apache/sysml/parser/Identifier.java index 4d62f1a..b2a679d 100644 --- a/src/main/java/org/apache/sysml/parser/Identifier.java +++ b/src/main/java/org/apache/sysml/parser/Identifier.java @@ -66,7 +66,6 @@ public abstract class Identifier extends Expression _columns_in_block = i.getColumnsInBlock(); _nnz = i.getNnz(); _formatType = i.getFormatType(); - } public void setDimensionValueProperties(Identifier i) http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/src/main/java/org/apache/sysml/parser/dml/Dml.g4 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/Dml.g4 b/src/main/java/org/apache/sysml/parser/dml/Dml.g4 index a3782f2..fb72ed2 100644 --- a/src/main/java/org/apache/sysml/parser/dml/Dml.g4 +++ b/src/main/java/org/apache/sysml/parser/dml/Dml.g4 @@ -66,6 +66,7 @@ statement returns [ org.apache.sysml.parser.common.StatementInfo info ] // AssignmentStatement | targetList=dataIdentifier op=('<-'|'=') 'ifdef' '(' commandLineParam=dataIdentifier ',' source=expression ')' ';'* # IfdefAssignmentStatement | targetList=dataIdentifier op=('<-'|'=') source=expression ';'* # AssignmentStatement + | targetList=dataIdentifier op='+=' source=expression ';'* # AccumulatorAssignmentStatement // ------------------------------------------ // We don't support block statement // | '{' body+=expression ';'* ( body+=expression ';'* )* '}' # BlockStatement http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java b/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java index 11d6c7f..00473c0 100644 --- a/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java +++ b/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java @@ -27,6 +27,7 @@ import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ErrorNode; import org.antlr.v4.runtime.tree.TerminalNode; import org.apache.sysml.parser.common.CustomErrorListener; +import org.apache.sysml.parser.dml.DmlParser.AccumulatorAssignmentStatementContext; import org.apache.sysml.parser.dml.DmlParser.AddSubExpressionContext; import org.apache.sysml.parser.dml.DmlParser.AssignmentStatementContext; import org.apache.sysml.parser.dml.DmlParser.AtomicExpressionContext; @@ -338,6 +339,12 @@ public class DmlPreprocessor implements 
DmlListener { public void exitIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) {} @Override + public void enterAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) {} + + @Override + public void exitAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) {} + + @Override public void enterBooleanAndExpression(BooleanAndExpressionContext ctx) {} @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java index 56642d9..93b670b 100644 --- a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java +++ b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java @@ -62,6 +62,7 @@ import org.apache.sysml.parser.common.CommonSyntacticValidator; import org.apache.sysml.parser.common.CustomErrorListener; import org.apache.sysml.parser.common.ExpressionInfo; import org.apache.sysml.parser.common.StatementInfo; +import org.apache.sysml.parser.dml.DmlParser.AccumulatorAssignmentStatementContext; import org.apache.sysml.parser.dml.DmlParser.AddSubExpressionContext; import org.apache.sysml.parser.dml.DmlParser.AssignmentStatementContext; import org.apache.sysml.parser.dml.DmlParser.AtomicExpressionContext; @@ -922,6 +923,19 @@ public class DmlSyntacticValidator extends CommonSyntacticValidator implements D } @Override + public void exitAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) { + if(ctx.targetList == null) { + notifyErrorListeners("incorrect parsing for accumulator assignment", ctx.start); + return; + } + //process as default assignment statement + exitAssignmentStatementHelper(ctx, ctx.targetList.getText(), + ctx.targetList.dataInfo, ctx.targetList.start, ctx.source.info, 
ctx.info); + //mark as accumulator + ((AssignmentStatement)ctx.info.stmt).setAccumulator(true); + } + + @Override public void exitMatrixDataTypeCheck(MatrixDataTypeCheckContext ctx) { checkValidDataType(ctx.ID().getText(), ctx.start); } @@ -955,6 +969,8 @@ public class DmlSyntacticValidator extends CommonSyntacticValidator implements D @Override public void enterIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) {} + @Override public void enterAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) {} + @Override public void enterConstStringIdExpression(ConstStringIdExpressionContext ctx) {} @Override public void enterConstTrueExpression(ConstTrueExpressionContext ctx) {}
