Repository: systemml
Updated Branches:
  refs/heads/master 97018d4e6 -> a62b65c8f


[SYSTEMML-2350] Fix missing support for lists in as.scalar casts

This patch adds the missing support for list inputs in as.scalar casts.
This is needed to extract scalars from unnamed or named lists, because
list indexing itself still returns a list of one element. Furthermore,
this patch improves the error handling of the related runtime
instruction, which now raises an explicit error for unsupported input
data types.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a62b65c8
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a62b65c8
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a62b65c8

Branch: refs/heads/master
Commit: a62b65c8f61ed8cf0b009732f8cbdb7c8eda95e9
Parents: 97018d4
Author: Matthias Boehm <[email protected]>
Authored: Wed May 30 13:40:09 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Wed May 30 13:40:09 2018 -0700

----------------------------------------------------------------------
 .../sysml/parser/BuiltinFunctionExpression.java  | 14 ++++++++------
 .../java/org/apache/sysml/parser/Expression.java |  2 +-
 .../instructions/cp/VariableCPInstruction.java   | 19 ++++++++++++++-----
 .../mnist_lenet_paramserv_minimum_version.dml    |  4 ++--
 4 files changed, 25 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index 0e949d0..ea51bd1 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -496,18 +496,20 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        output.setValueType(id.getValueType());
                        
                        break;
-                       
                case CAST_AS_SCALAR:
                        checkNumParameters(1);
-                       checkMatrixFrameParam(getFirstExpr());
-                       if (( getFirstExpr().getOutput().getDim1() != -1 && 
getFirstExpr().getOutput().getDim1() !=1) || ( 
getFirstExpr().getOutput().getDim2() != -1 && 
getFirstExpr().getOutput().getDim2() !=1)) {
-                               raiseValidateError("dimension mismatch while 
casting matrix to scalar: dim1: " + getFirstExpr().getOutput().getDim1() +  " 
dim2 " + getFirstExpr().getOutput().getDim2(), 
-                                         conditional, 
LanguageErrorCodes.INVALID_PARAMETERS);
+                       checkDataTypeParam(getFirstExpr(),
+                               DataType.MATRIX, DataType.FRAME, DataType.LIST);
+                       if (( getFirstExpr().getOutput().getDim1() != -1 && 
getFirstExpr().getOutput().getDim1() !=1)
+                               || ( getFirstExpr().getOutput().getDim2() != -1 
&& getFirstExpr().getOutput().getDim2() !=1)) {
+                               raiseValidateError("dimension mismatch while 
casting matrix to scalar: dim1: " + getFirstExpr().getOutput().getDim1() 
+                                       +  " dim2 " + 
getFirstExpr().getOutput().getDim2(), conditional, 
LanguageErrorCodes.INVALID_PARAMETERS);
                        }
                        output.setDataType(DataType.SCALAR);
                        output.setDimensions(0, 0);
                        output.setBlockDimensions (0, 0);
-                       output.setValueType(id.getValueType());
+                       
output.setValueType((id.getValueType()!=ValueType.UNKNOWN) ?
+                               id.getValueType() : ValueType.DOUBLE);
                        break;
                case CAST_AS_MATRIX:
                        checkNumParameters(1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java 
b/src/main/java/org/apache/sysml/parser/Expression.java
index fd3f855..9a6ea64 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -194,7 +194,7 @@ public abstract class Expression implements ParseInfo
                public boolean isScalar() {
                        return (this == SCALAR);
                }
-               public boolean isComposite() {
+               public boolean isList() {
                        return (this == LIST);
                }
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
index 5786e87..b46f4df 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
@@ -553,7 +553,7 @@ public class VariableCPInstruction extends CPInstruction {
                        break;
                        
                case CastAsScalarVariable: //castAsScalarVariable
-                       if( getInput1().getDataType()==DataType.FRAME ) {
+                       if( getInput1().getDataType().isFrame() ) {
                                FrameBlock fBlock = 
ec.getFrameInput(getInput1().getName());
                                if( fBlock.getNumRows()!=1 || 
fBlock.getNumColumns()!=1 )
                                        throw new 
DMLRuntimeException("Dimension mismatch - unable to cast frame 
'"+getInput1().getName()+"' of dimension ("+fBlock.getNumRows()+" x 
"+fBlock.getNumColumns()+") to scalar.");
@@ -562,7 +562,7 @@ public class VariableCPInstruction extends CPInstruction {
                                ec.setScalarOutput(output.getName(), 
                                                
ScalarObjectFactory.createScalarObject(fBlock.getSchema()[0], value));
                        }
-                       else { //assume DataType.MATRIX otherwise
+                       else if( getInput1().getDataType().isMatrix() ) {
                                MatrixBlock mBlock = 
ec.getMatrixInput(getInput1().getName(), getExtendedOpcode());
                                if( mBlock.getNumRows()!=1 || 
mBlock.getNumColumns()!=1 )
                                        throw new 
DMLRuntimeException("Dimension mismatch - unable to cast matrix 
'"+getInput1().getName()+"' of dimension ("+mBlock.getNumRows()+" x 
"+mBlock.getNumColumns()+") to scalar.");
@@ -570,21 +570,30 @@ public class VariableCPInstruction extends CPInstruction {
                                ec.releaseMatrixInput(getInput1().getName(), 
getExtendedOpcode());
                                ec.setScalarOutput(output.getName(), new 
DoubleObject(value));
                        }
+                       else if( getInput1().getDataType().isList() ) {
+                               //TODO handling of cleanup status, potentially 
new object
+                               ListObject list = 
(ListObject)ec.getVariable(getInput1().getName());
+                               ec.setVariable(output.getName(), list.slice(0));
+                       }
+                       else {
+                               throw new DMLRuntimeException("Unsupported data 
type "
+                                       + "in as.scalar(): 
"+getInput1().getDataType().name());
+                       }
                        break;
                case CastAsMatrixVariable:{
-                       if( getInput1().getDataType()==DataType.FRAME ) {
+                       if( getInput1().getDataType().isFrame() ) {
                                FrameBlock fin = 
ec.getFrameInput(getInput1().getName());
                                MatrixBlock out = 
DataConverter.convertToMatrixBlock(fin);
                                ec.releaseFrameInput(getInput1().getName());
                                ec.setMatrixOutput(output.getName(), out, 
getExtendedOpcode());
                        }
-                       else if( getInput1().getDataType()==DataType.SCALAR ) {
+                       else if( getInput1().getDataType().isScalar() ) {
                                ScalarObject scalarInput = ec.getScalarInput(
                                        getInput1().getName(), 
getInput1().getValueType(), getInput1().isLiteral());
                                MatrixBlock out = new 
MatrixBlock(scalarInput.getDoubleValue());
                                ec.setMatrixOutput(output.getName(), out, 
getExtendedOpcode());
                        }
-                       else if( getInput1().getDataType()==DataType.LIST ) {
+                       else if( getInput1().getDataType().isList() ) {
                                //TODO handling of cleanup status, potentially 
new object
                                ListObject list = 
(ListObject)ec.getVariable(getInput1().getName());
                                ec.setVariable(output.getName(), list.slice(0));

http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
----------------------------------------------------------------------
diff --git 
a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
 
b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
index 2ef7411..8811c36 100644
--- 
a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
+++ 
b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
@@ -234,8 +234,8 @@ aggregation = function(list[unknown] model,
      vb2 = as.matrix(model["vb2"])
      vb3 = as.matrix(model["vb3"])
      vb4 = as.matrix(model["vb4"])
-     lr = 0.01
-     mu = 0.9
+     lr = as.scalar(hyperparams['lr']);
+     mu = as.scalar(hyperparams['mu']);
 
      # Optimize with SGD w/ Nesterov momentum
      [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)

Reply via email to