Repository: systemml Updated Branches: refs/heads/master 97018d4e6 -> a62b65c8f
[SYSTEMML-2350] Fix missing support for lists in as.scalar casts This patch fixes the missing support for list inputs in as.scalar casts, which is necessary as a means to index scalars out of unnamed or named lists because the list indexing itself still returns a list of one element. Furthermore, this also improves the error handling of the related runtime instruction. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a62b65c8 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a62b65c8 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a62b65c8 Branch: refs/heads/master Commit: a62b65c8f61ed8cf0b009732f8cbdb7c8eda95e9 Parents: 97018d4 Author: Matthias Boehm <[email protected]> Authored: Wed May 30 13:40:09 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed May 30 13:40:09 2018 -0700 ---------------------------------------------------------------------- .../sysml/parser/BuiltinFunctionExpression.java | 14 ++++++++------ .../java/org/apache/sysml/parser/Expression.java | 2 +- .../instructions/cp/VariableCPInstruction.java | 19 ++++++++++++++----- .../mnist_lenet_paramserv_minimum_version.dml | 4 ++-- 4 files changed, 25 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java index 0e949d0..ea51bd1 100644 --- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java @@ -496,18 +496,20 @@ public class BuiltinFunctionExpression extends DataIdentifier output.setValueType(id.getValueType()); break; - case CAST_AS_SCALAR: checkNumParameters(1); - checkMatrixFrameParam(getFirstExpr()); - if (( getFirstExpr().getOutput().getDim1() != -1 && getFirstExpr().getOutput().getDim1() !=1) || ( getFirstExpr().getOutput().getDim2() != -1 && getFirstExpr().getOutput().getDim2() !=1)) { - raiseValidateError("dimension mismatch while casting matrix to scalar: dim1: " + getFirstExpr().getOutput().getDim1() + " dim2 " + getFirstExpr().getOutput().getDim2(), - conditional, LanguageErrorCodes.INVALID_PARAMETERS); + checkDataTypeParam(getFirstExpr(), + DataType.MATRIX, DataType.FRAME, DataType.LIST); + if (( getFirstExpr().getOutput().getDim1() != -1 && getFirstExpr().getOutput().getDim1() !=1) + || ( getFirstExpr().getOutput().getDim2() != -1 && getFirstExpr().getOutput().getDim2() !=1)) { + raiseValidateError("dimension mismatch while casting matrix to scalar: dim1: " + getFirstExpr().getOutput().getDim1() + + " dim2 " + getFirstExpr().getOutput().getDim2(), conditional, LanguageErrorCodes.INVALID_PARAMETERS); } output.setDataType(DataType.SCALAR); output.setDimensions(0, 0); output.setBlockDimensions (0, 0); - output.setValueType(id.getValueType()); + output.setValueType((id.getValueType()!=ValueType.UNKNOWN) ? + id.getValueType() : ValueType.DOUBLE); break; case CAST_AS_MATRIX: checkNumParameters(1); http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/main/java/org/apache/sysml/parser/Expression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java index fd3f855..9a6ea64 100644 --- a/src/main/java/org/apache/sysml/parser/Expression.java +++ b/src/main/java/org/apache/sysml/parser/Expression.java @@ -194,7 +194,7 @@ public abstract class Expression implements ParseInfo public boolean isScalar() { return (this == SCALAR); } - public boolean isComposite() { + public boolean isList() { return (this == LIST); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java index 5786e87..b46f4df 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java @@ -553,7 +553,7 @@ public class VariableCPInstruction extends CPInstruction { break; case CastAsScalarVariable: //castAsScalarVariable - if( getInput1().getDataType()==DataType.FRAME ) { + if( getInput1().getDataType().isFrame() ) { FrameBlock fBlock = ec.getFrameInput(getInput1().getName()); if( fBlock.getNumRows()!=1 || fBlock.getNumColumns()!=1 ) throw new DMLRuntimeException("Dimension mismatch - unable to cast frame '"+getInput1().getName()+"' of dimension ("+fBlock.getNumRows()+" x "+fBlock.getNumColumns()+") to scalar."); @@ -562,7 +562,7 @@ public class VariableCPInstruction extends CPInstruction { ec.setScalarOutput(output.getName(), ScalarObjectFactory.createScalarObject(fBlock.getSchema()[0], value)); } - else { //assume DataType.MATRIX otherwise + else if( getInput1().getDataType().isMatrix() ) { MatrixBlock mBlock = ec.getMatrixInput(getInput1().getName(), getExtendedOpcode()); if( mBlock.getNumRows()!=1 || mBlock.getNumColumns()!=1 ) throw new DMLRuntimeException("Dimension mismatch - unable to cast matrix '"+getInput1().getName()+"' of dimension ("+mBlock.getNumRows()+" x "+mBlock.getNumColumns()+") to scalar."); @@ -570,21 +570,30 @@ public class VariableCPInstruction extends CPInstruction { ec.releaseMatrixInput(getInput1().getName(), getExtendedOpcode()); ec.setScalarOutput(output.getName(), new DoubleObject(value)); } + else if( getInput1().getDataType().isList() ) { + //TODO handling of cleanup status, potentially new object + ListObject list = (ListObject)ec.getVariable(getInput1().getName()); + ec.setVariable(output.getName(), list.slice(0)); + } + else { + throw new DMLRuntimeException("Unsupported data type " + + "in as.scalar(): "+getInput1().getDataType().name()); + } break; case CastAsMatrixVariable:{ - if( getInput1().getDataType()==DataType.FRAME ) { + if( getInput1().getDataType().isFrame() ) { FrameBlock fin = ec.getFrameInput(getInput1().getName()); MatrixBlock out = DataConverter.convertToMatrixBlock(fin); ec.releaseFrameInput(getInput1().getName()); ec.setMatrixOutput(output.getName(), out, getExtendedOpcode()); } - else if( getInput1().getDataType()==DataType.SCALAR ) { + else if( getInput1().getDataType().isScalar() ) { ScalarObject scalarInput = ec.getScalarInput( getInput1().getName(), getInput1().getValueType(), getInput1().isLiteral()); MatrixBlock out = new MatrixBlock(scalarInput.getDoubleValue()); ec.setMatrixOutput(output.getName(), out, getExtendedOpcode()); } - else if( getInput1().getDataType()==DataType.LIST ) { + else if( getInput1().getDataType().isList() ) { //TODO handling of cleanup status, potentially new object ListObject list = (ListObject)ec.getVariable(getInput1().getName()); ec.setVariable(output.getName(), list.slice(0)); http://git-wip-us.apache.org/repos/asf/systemml/blob/a62b65c8/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml index 2ef7411..8811c36 100644 --- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml +++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml @@ -234,8 +234,8 @@ aggregation = function(list[unknown] model, vb2 = as.matrix(model["vb2"]) vb3 = as.matrix(model["vb3"]) vb4 = as.matrix(model["vb4"]) - lr = 0.01 - mu = 0.9 + lr = as.scalar(hyperparams['lr']); + mu = as.scalar(hyperparams['mu']); # Optimize with SGD w/ Nesterov momentum [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
