[SYSTEMML-1714] Fix correctness of codegen row-wise scalar-vector ops

This patch fixes result-correctness issues in the codegen row-wise
template for scalar-vector operations where the vector is extracted
from side inputs. One example where the underlying issue led to wrong
algorithm results is LeNet over MNIST.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/23a164a8
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/23a164a8
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/23a164a8

Branch: refs/heads/master
Commit: 23a164a83c480dc78b2df9da099a5140c7572b7e
Parents: f7a18fa
Author: Matthias Boehm <[email protected]>
Authored: Sat Jun 17 19:51:30 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jun 17 19:51:30 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeBinary.java   | 32 ++++++++++++++------
 1 file changed, 22 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/23a164a8/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
index ac0ea6f..7ed2408 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
@@ -67,7 +67,7 @@ public class CNodeBinary extends CNode
                        return ssComm || vsComm || vvComm;
                }
                
-               public String getTemplate(boolean sparse) {
+               public String getTemplate(boolean sparse, boolean scalarVector) 
{
                        switch (this) {
                                case DOT_PRODUCT:   
                                        return sparse ? "    double %TMP% = 
LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, %LEN%);\n" 
:
@@ -88,8 +88,12 @@ public class CNodeBinary extends CNode
                                case VECT_GREATER_ADD:
                                case VECT_GREATEREQUAL_ADD: {
                                        String vectName = 
getVectorPrimitiveName();
-                                       return sparse ? "    
LibSpoofPrimitives.vect"+vectName+"Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, 
%POSOUT%, %LEN%);\n" : 
-                                                                       "    
LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, 
%LEN%);\n";
+                                       if( scalarVector )
+                                               return sparse ? "    
LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, 
%POSOUT%, %LEN%);\n" : 
+                                                                               
"    LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS2%, 
%POSOUT%, %LEN%);\n";
+                                       else    
+                                               return sparse ? "    
LibSpoofPrimitives.vect"+vectName+"Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, 
%POSOUT%, %LEN%);\n" : 
+                                                                               
"    LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS1%, 
%POSOUT%, %LEN%);\n";
                                }
                                
                                //vector-scalar operations
@@ -107,8 +111,12 @@ public class CNodeBinary extends CNode
                                case VECT_GREATER_SCALAR:
                                case VECT_GREATEREQUAL_SCALAR: {
                                        String vectName = 
getVectorPrimitiveName();
-                                       return sparse ? "    double[] %TMP% = 
LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, 
%LEN%);\n" : 
-                                                                       "    
double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, 
%LEN%);\n";
+                                       if( scalarVector )
+                                               return sparse ? "    double[] 
%TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %IN2i%, %POS2%, 
%LEN%);\n" : 
+                                                                               
"    double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, 
%POS2%, %LEN%);\n";
+                                       else    
+                                               return sparse ? "    double[] 
%TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, 
%LEN%);\n" : 
+                                                                               
"    double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, 
%POS1%, %LEN%);\n";
                                }
                                
                                //vector-vector operations
@@ -239,8 +247,10 @@ public class CNodeBinary extends CNode
                boolean lsparse = sparse && (_inputs.get(0) instanceof 
CNodeData 
                        && !_inputs.get(0).getVarname().startsWith("b")
                        && !_inputs.get(0).isLiteral());
+               boolean scalarVector = (_inputs.get(0).getDataType().isScalar()
+                       && _inputs.get(1).getDataType().isMatrix());
                String var = createVarname();
-               String tmp = _type.getTemplate(lsparse);
+               String tmp = _type.getTemplate(lsparse, scalarVector);
                tmp = tmp.replaceAll("%TMP%", var);
                
                //replace input references and start indexes
@@ -346,8 +356,9 @@ public class CNodeBinary extends CNode
                        case VECT_LESSEQUAL_ADD: 
                        case VECT_GREATER_ADD: 
                        case VECT_GREATEREQUAL_ADD:
-                               _rows = _inputs.get(1)._rows;
-                               _cols = _inputs.get(1)._cols;
+                               boolean vectorScalar = 
_inputs.get(1).getDataType()==DataType.SCALAR;
+                               _rows = _inputs.get(vectorScalar ? 0 : 1)._rows;
+                               _cols = _inputs.get(vectorScalar ? 0 : 1)._cols;
                                _dataType= DataType.MATRIX;
                                break;
                                
@@ -377,8 +388,9 @@ public class CNodeBinary extends CNode
                        case VECT_LESSEQUAL: 
                        case VECT_GREATER: 
                        case VECT_GREATEREQUAL: 
-                               _rows = _inputs.get(0)._rows;
-                               _cols = _inputs.get(0)._cols;
+                               boolean scalarVector = 
(_inputs.get(0).getDataType()==DataType.SCALAR);
+                               _rows = _inputs.get(scalarVector ? 1 : 0)._rows;
+                               _cols = _inputs.get(scalarVector ? 1 : 0)._cols;
                                _dataType= DataType.MATRIX;
                                break;
                                

Reply via email to