Repository: systemml
Updated Branches:
  refs/heads/master 367935fad -> c2dd05e51


[SYSTEMML-2093] Parser and language support for new += operator

This patch adds parser and language support for the new += accumulation
assignment operator. The semantics of A += B are equivalent to A = A + B,
but this operator will receive special treatment for parfor result
variables (with sum as the result aggregation function). Note that the
accumulation is currently applied only to plain (non-indexed) targets;
the translation of left-indexed targets does not yet consult the
accumulator flag. This patch also modifies a number of existing
algorithms to use the new operator, which integrates it into our test
suite.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1c7ecbf2
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1c7ecbf2
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1c7ecbf2

Branch: refs/heads/master
Commit: 1c7ecbf256c249d48aa1664db9edf726156cb014
Parents: 367935f
Author: Matthias Boehm <[email protected]>
Authored: Sat Jan 27 16:16:11 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jan 27 16:16:11 2018 -0800

----------------------------------------------------------------------
 scripts/algorithms/LinearRegCG.dml              |  6 ++--
 scripts/algorithms/l2-svm.dml                   |  5 ++--
 scripts/algorithms/m-svm.dml                    |  4 +--
 .../sysml/parser/AssignmentStatement.java       | 30 +++++++++++---------
 .../org/apache/sysml/parser/DMLTranslator.java  | 26 +++++++++++------
 .../org/apache/sysml/parser/Identifier.java     |  1 -
 .../java/org/apache/sysml/parser/dml/Dml.g4     |  1 +
 .../sysml/parser/dml/DmlPreprocessor.java       |  7 +++++
 .../sysml/parser/dml/DmlSyntacticValidator.java | 16 +++++++++++
 9 files changed, 64 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/scripts/algorithms/LinearRegCG.dml
----------------------------------------------------------------------
diff --git a/scripts/algorithms/LinearRegCG.dml b/scripts/algorithms/LinearRegCG.dml
index 08cbaef..656d261 100644
--- a/scripts/algorithms/LinearRegCG.dml
+++ b/scripts/algorithms/LinearRegCG.dml
@@ -186,10 +186,10 @@ while (i < max_iteration & norm_r2 > norm_r2_target)
         q = scale_X * q + shift_X %*% q [m_ext, ];
     }
 
-       q = q + lambda * p;
+       q += lambda * p;
        a = norm_r2 / sum (p * q);
-       beta_unscaled = beta_unscaled + a * p;
-       r = r + a * q;
+       beta_unscaled += a * p;
+       r += a * q;
        old_norm_r2 = norm_r2;
        norm_r2 = sum (r ^ 2);
        p = -r + (norm_r2 / old_norm_r2) * p;

http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/scripts/algorithms/l2-svm.dml
----------------------------------------------------------------------
diff --git a/scripts/algorithms/l2-svm.dml b/scripts/algorithms/l2-svm.dml
index 2446610..cf669b5 100644
--- a/scripts/algorithms/l2-svm.dml
+++ b/scripts/algorithms/l2-svm.dml
@@ -138,8 +138,8 @@ while(continue & iter < maxiterations)  {
   }
 
   #update weights
-  w = w + step_sz*s
-  Xw = Xw + step_sz*Xd
+  w += step_sz * s
+  Xw += step_sz * Xd
   
   out = 1 - Y * Xw
   sv = (out > 0)
@@ -159,7 +159,6 @@ while(continue & iter < maxiterations)  {
   
   continue = (step_sz*tmp >= epsilon*obj & sum(s^2) != 0);
   iter = iter + 1
-  
 }
 
 extra_model_params = matrix(0, rows=4, cols=1)

http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/scripts/algorithms/m-svm.dml
----------------------------------------------------------------------
diff --git a/scripts/algorithms/m-svm.dml b/scripts/algorithms/m-svm.dml
index 253764c..024aeda 100644
--- a/scripts/algorithms/m-svm.dml
+++ b/scripts/algorithms/m-svm.dml
@@ -148,8 +148,8 @@ parfor(iter_class in 1:num_classes){
     }
     
     #update weights
-    w_class = w_class + step_sz*s
-    Xw = Xw + step_sz*Xd
+    w_class += step_sz * s
+    Xw += step_sz * Xd
  
     out = 1 - Y_local * Xw
     sv = (out > 0)

http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/src/main/java/org/apache/sysml/parser/AssignmentStatement.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/AssignmentStatement.java b/src/main/java/org/apache/sysml/parser/AssignmentStatement.java
index 9195fbd..3525d13 100644
--- a/src/main/java/org/apache/sysml/parser/AssignmentStatement.java
+++ b/src/main/java/org/apache/sysml/parser/AssignmentStatement.java
@@ -30,6 +30,7 @@ public class AssignmentStatement extends Statement
 {
        private ArrayList<DataIdentifier> _targetList;
        private Expression _source;
+       private boolean _isAccum; //+=
         
        // rewrites statement to support function inlining (creates deep copy)
        @Override
@@ -66,18 +67,26 @@ public class AssignmentStatement extends Statement
                return _targetList.get(0);
        }
        
-       public ArrayList<DataIdentifier> getTargetList()
-       {
+       public ArrayList<DataIdentifier> getTargetList() {
                return _targetList;
        }
 
        public Expression getSource(){
                return _source;
        }
+       
        public void setSource(Expression s){
                _source = s;
        }
        
+       public boolean isAccumulator() {
+               return _isAccum;
+       }
+       
+       public void setAccumulator(boolean flag) {
+               _isAccum = flag;
+       }
+       
        @Override
        public boolean controlStatement() {
                // ensure that breakpoints end up in own statement block 
@@ -106,29 +115,22 @@ public class AssignmentStatement extends Statement
        @Override
        public VariableSet variablesRead() {
                VariableSet result = new VariableSet();
-               
                // add variables read by source expression
                result.addVariables(_source.variablesRead());
-               
-               // for LHS IndexedIdentifier, add variables for indexing 
expressions
-               for (int i=0; i<_targetList.size(); i++){
-                       if (_targetList.get(i) instanceof IndexedIdentifier) {
-                               IndexedIdentifier target = (IndexedIdentifier) 
_targetList.get(i);
+               // for left indexing or accumulators add targets as well
+               for (DataIdentifier target : _targetList)
+                       if (target instanceof IndexedIdentifier || _isAccum )
                                result.addVariables(target.variablesRead());
-                       }
-               }               
                return result;
        }
        
        @Override
        public  VariableSet variablesUpdated() {
                VariableSet result =  new VariableSet();
-               
                // add target to updated list
                for (DataIdentifier target : _targetList)
-                       if (target != null) {
+                       if (target != null)
                                result.addVariable(target.getName(), target);
-                       }
                return result;
        }
        
@@ -139,7 +141,7 @@ public class AssignmentStatement extends Statement
                        DataIdentifier di = _targetList.get(i);
                        sb.append(di);
                }
-               sb.append(" = ");
+               sb.append(_isAccum ? " += " : " = ");
                if (_source instanceof StringIdentifier) {
                        sb.append("\"");
                        sb.append(_source.toString());

http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index 9a47d09..faee84b 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -1264,10 +1264,21 @@ public class DMLTranslator
                                
                                        // CASE: target is regular data 
identifier
                                        if (!(target instanceof 
IndexedIdentifier)) {
-                                       
+                                               //process right hand side and 
accumulation
                                                Hop ae = 
processExpression(source, target, ids);
+                                               if( 
((AssignmentStatement)current).isAccumulator() ) {
+                                                       DataIdentifier accum = 
liveIn.getVariable(target.getName());
+                                                       if( accum == null )
+                                                               throw new 
LanguageException("Invalid accumulator assignment "
+                                                                       + "to 
non-existing variable "+target.getName()+".");
+                                                       ae = 
HopRewriteUtils.createBinary(ids.get(target.getName()), ae, OpOp2.PLUS);
+                                                       
target.setProperties(accum.getOutput());
+                                               }
+                                               else
+                                                       
target.setProperties(source.getOutput());
                                                ids.put(target.getName(), ae);
-                                               
target.setProperties(source.getOutput());
+                                               
+                                               //add transient write if needed
                                                Integer statementId = 
liveOutToTemp.get(target.getName());
                                                if ((statementId != null) && 
(statementId.intValue() == i)) {
                                                        DataOp transientwrite = 
new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, 
DataOpTypes.TRANSIENTWRITE, null);
@@ -1276,8 +1287,7 @@ public class DMLTranslator
                                                        
updatedLiveOut.addVariable(target.getName(), target);
                                                        
output.add(transientwrite);
                                                }
-                                       } // end if (!(target instanceof 
IndexedIdentifier)) {
-                                       
+                                       } 
                                        // CASE: target is indexed identifier 
(left-hand side indexed expression)
                                        else {
                                                Hop ae = 
processLeftIndexedExpression(source, (IndexedIdentifier)target, ids);
@@ -1287,12 +1297,12 @@ public class DMLTranslator
                                                // obtain origDim values BEFORE 
they are potentially updated during setProperties call
                                                //      (this is incorrect for 
LHS Indexing)
                                                long origDim1 = 
((IndexedIdentifier)target).getOrigDim1();
-                                               long origDim2 = 
((IndexedIdentifier)target).getOrigDim2();                                      
         
+                                               long origDim2 = 
((IndexedIdentifier)target).getOrigDim2();
                                                
target.setProperties(source.getOutput());
                                                
((IndexedIdentifier)target).setOriginalDimensions(origDim1, origDim2);
                                                
                                                // preserve data type matrix of 
any index identifier
-                                               // (required for scalar input 
to left indexing)                                 
+                                               // (required for scalar input 
to left indexing)
                                                if( target.getDataType() != 
DataType.MATRIX ) {
                                                        
target.setDataType(DataType.MATRIX);
                                                        
target.setValueType(ValueType.DOUBLE);
@@ -1308,10 +1318,8 @@ public class DMLTranslator
                                                        
output.add(transientwrite);
                                                }
                                        }
-                                       
-                                       
                                }
-                               else 
+                               else
                                {
                                        //assignment, function call
                                        FunctionCallIdentifier fci = 
(FunctionCallIdentifier) source;

http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/src/main/java/org/apache/sysml/parser/Identifier.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Identifier.java b/src/main/java/org/apache/sysml/parser/Identifier.java
index 4d62f1a..b2a679d 100644
--- a/src/main/java/org/apache/sysml/parser/Identifier.java
+++ b/src/main/java/org/apache/sysml/parser/Identifier.java
@@ -66,7 +66,6 @@ public abstract class Identifier extends Expression
                _columns_in_block = i.getColumnsInBlock();
                _nnz = i.getNnz();
                _formatType = i.getFormatType();
-                               
        }
        
        public void setDimensionValueProperties(Identifier i)

http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/src/main/java/org/apache/sysml/parser/dml/Dml.g4
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/Dml.g4 b/src/main/java/org/apache/sysml/parser/dml/Dml.g4
index a3782f2..fb72ed2 100644
--- a/src/main/java/org/apache/sysml/parser/dml/Dml.g4
+++ b/src/main/java/org/apache/sysml/parser/dml/Dml.g4
@@ -66,6 +66,7 @@ statement returns [ org.apache.sysml.parser.common.StatementInfo info ]
     // AssignmentStatement
     | targetList=dataIdentifier op=('<-'|'=') 'ifdef' '(' 
commandLineParam=dataIdentifier ','  source=expression ')' ';'*   # 
IfdefAssignmentStatement
     | targetList=dataIdentifier op=('<-'|'=') source=expression ';'*   # 
AssignmentStatement
+    | targetList=dataIdentifier op='+=' source=expression ';'*   # 
AccumulatorAssignmentStatement
     // ------------------------------------------
     // We don't support block statement
     // | '{' body+=expression ';'* ( body+=expression ';'* )*  '}' # 
BlockStatement

http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java b/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
index 11d6c7f..00473c0 100644
--- a/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
+++ b/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
@@ -27,6 +27,7 @@ import org.antlr.v4.runtime.Token;
 import org.antlr.v4.runtime.tree.ErrorNode;
 import org.antlr.v4.runtime.tree.TerminalNode;
 import org.apache.sysml.parser.common.CustomErrorListener;
+import 
org.apache.sysml.parser.dml.DmlParser.AccumulatorAssignmentStatementContext;
 import org.apache.sysml.parser.dml.DmlParser.AddSubExpressionContext;
 import org.apache.sysml.parser.dml.DmlParser.AssignmentStatementContext;
 import org.apache.sysml.parser.dml.DmlParser.AtomicExpressionContext;
@@ -338,6 +339,12 @@ public class DmlPreprocessor implements DmlListener {
        public void 
exitIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) {}
 
        @Override
+       public void 
enterAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) 
{}
+
+       @Override
+       public void 
exitAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) {}
+
+       @Override
        public void enterBooleanAndExpression(BooleanAndExpressionContext ctx) 
{}
 
        @Override

http://git-wip-us.apache.org/repos/asf/systemml/blob/1c7ecbf2/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java
index 56642d9..93b670b 100644
--- a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java
+++ b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java
@@ -62,6 +62,7 @@ import org.apache.sysml.parser.common.CommonSyntacticValidator;
 import org.apache.sysml.parser.common.CustomErrorListener;
 import org.apache.sysml.parser.common.ExpressionInfo;
 import org.apache.sysml.parser.common.StatementInfo;
+import 
org.apache.sysml.parser.dml.DmlParser.AccumulatorAssignmentStatementContext;
 import org.apache.sysml.parser.dml.DmlParser.AddSubExpressionContext;
 import org.apache.sysml.parser.dml.DmlParser.AssignmentStatementContext;
 import org.apache.sysml.parser.dml.DmlParser.AtomicExpressionContext;
@@ -922,6 +923,19 @@ public class DmlSyntacticValidator extends CommonSyntacticValidator implements D
        }
 
        @Override
+       public void 
exitAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) {
+               if(ctx.targetList == null) {
+                       notifyErrorListeners("incorrect parsing for accumulator 
assignment", ctx.start);
+                       return;
+               }
+               //process as default assignment statement
+               exitAssignmentStatementHelper(ctx, ctx.targetList.getText(),
+                       ctx.targetList.dataInfo, ctx.targetList.start, 
ctx.source.info, ctx.info);
+               //mark as accumulator
+               ((AssignmentStatement)ctx.info.stmt).setAccumulator(true);
+       }
+       
+       @Override
        public void exitMatrixDataTypeCheck(MatrixDataTypeCheckContext ctx) {
                checkValidDataType(ctx.ID().getText(), ctx.start);
        }
@@ -955,6 +969,8 @@ public class DmlSyntacticValidator extends CommonSyntacticValidator implements D
 
        @Override public void 
enterIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) {}
 
+       @Override public void 
enterAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) 
{}
+       
        @Override public void 
enterConstStringIdExpression(ConstStringIdExpressionContext ctx) {}
 
        @Override public void 
enterConstTrueExpression(ConstTrueExpressionContext ctx) {}

Reply via email to