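[SYSTEMML-1714] Fix codegen row-wise correctness of scalar-vector ops. This patch fixes result-correctness issues of the codegen row-wise template for scalar-vector operations when the vector is extracted from side inputs. The generated code now takes the vector operand (and its position/index arguments) from the second input whenever the first input is a scalar. The output dimensions are likewise derived from the matrix input rather than from a fixed input position. One example where the underlying issue led to wrong algorithm results was LeNet over MNIST.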
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/23a164a8 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/23a164a8 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/23a164a8 Branch: refs/heads/master Commit: 23a164a83c480dc78b2df9da099a5140c7572b7e Parents: f7a18fa Author: Matthias Boehm <[email protected]> Authored: Sat Jun 17 19:51:30 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jun 17 19:51:30 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/cplan/CNodeBinary.java | 32 ++++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/23a164a8/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java index ac0ea6f..7ed2408 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java @@ -67,7 +67,7 @@ public class CNodeBinary extends CNode return ssComm || vsComm || vvComm; } - public String getTemplate(boolean sparse) { + public String getTemplate(boolean sparse, boolean scalarVector) { switch (this) { case DOT_PRODUCT: return sparse ? " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, %LEN%);\n" : @@ -88,8 +88,12 @@ public class CNodeBinary extends CNode case VECT_GREATER_ADD: case VECT_GREATEREQUAL_ADD: { String vectName = getVectorPrimitiveName(); - return sparse ? 
" LibSpoofPrimitives.vect"+vectName+"Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, %LEN%);\n" : - " LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n"; + if( scalarVector ) + return sparse ? " LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, %LEN%);\n" : + " LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n"; + else + return sparse ? " LibSpoofPrimitives.vect"+vectName+"Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, %LEN%);\n" : + " LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n"; } //vector-scalar operations @@ -107,8 +111,12 @@ public class CNodeBinary extends CNode case VECT_GREATER_SCALAR: case VECT_GREATEREQUAL_SCALAR: { String vectName = getVectorPrimitiveName(); - return sparse ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %LEN%);\n" : - " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %LEN%);\n"; + if( scalarVector ) + return sparse ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %IN2i%, %POS2%, %LEN%);\n" : + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS2%, %LEN%);\n"; + else + return sparse ? 
" double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %LEN%);\n" : + " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %LEN%);\n"; } //vector-vector operations @@ -239,8 +247,10 @@ public class CNodeBinary extends CNode boolean lsparse = sparse && (_inputs.get(0) instanceof CNodeData && !_inputs.get(0).getVarname().startsWith("b") && !_inputs.get(0).isLiteral()); + boolean scalarVector = (_inputs.get(0).getDataType().isScalar() + && _inputs.get(1).getDataType().isMatrix()); String var = createVarname(); - String tmp = _type.getTemplate(lsparse); + String tmp = _type.getTemplate(lsparse, scalarVector); tmp = tmp.replaceAll("%TMP%", var); //replace input references and start indexes @@ -346,8 +356,9 @@ public class CNodeBinary extends CNode case VECT_LESSEQUAL_ADD: case VECT_GREATER_ADD: case VECT_GREATEREQUAL_ADD: - _rows = _inputs.get(1)._rows; - _cols = _inputs.get(1)._cols; + boolean vectorScalar = _inputs.get(1).getDataType()==DataType.SCALAR; + _rows = _inputs.get(vectorScalar ? 0 : 1)._rows; + _cols = _inputs.get(vectorScalar ? 0 : 1)._cols; _dataType= DataType.MATRIX; break; @@ -377,8 +388,9 @@ public class CNodeBinary extends CNode case VECT_LESSEQUAL: case VECT_GREATER: case VECT_GREATEREQUAL: - _rows = _inputs.get(0)._rows; - _cols = _inputs.get(0)._cols; + boolean scalarVector = (_inputs.get(0).getDataType()==DataType.SCALAR); + _rows = _inputs.get(scalarVector ? 1 : 0)._rows; + _cols = _inputs.get(scalarVector ? 1 : 0)._cols; _dataType= DataType.MATRIX; break;
