[SYSTEMML-2405] Support for as.matrix over lists of scalars. This patch adds a convenience feature for creating matrices from lists of scalars, along with the necessary compiler and runtime extensions.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/fff0aa46 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/fff0aa46 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/fff0aa46 Branch: refs/heads/master Commit: fff0aa469dc41fdd73b7c364095d596c6be9dd65 Parents: 3705e78 Author: Matthias Boehm <[email protected]> Authored: Sat Jun 16 19:21:01 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jun 16 19:21:01 2018 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/UnaryOp.java | 11 +++---- .../hops/recompile/LiteralReplacement.java | 12 ++++---- .../runtime/instructions/cp/ListObject.java | 4 +++ .../instructions/cp/VariableCPInstruction.java | 20 ++++++++++++- .../functions/misc/ListAndStructTest.java | 12 ++++++++ src/test/scripts/functions/misc/ListAsMatrix.R | 31 ++++++++++++++++++++ .../scripts/functions/misc/ListAsMatrix.dml | 26 ++++++++++++++++ 7 files changed, 105 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/fff0aa46/src/main/java/org/apache/sysml/hops/UnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java b/src/main/java/org/apache/sysml/hops/UnaryOp.java index fedc557..19b6ed3 100644 --- a/src/main/java/org/apache/sysml/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java @@ -684,16 +684,18 @@ public class UnaryOp extends MultiThreadedHop @Override public void refreshSizeInformation() { + Hop input = getInput().get(0); if ( getDataType() == DataType.SCALAR ) { //do nothing always known } else if( (_op == OpOp1.CAST_AS_MATRIX || _op == OpOp1.CAST_AS_FRAME - || _op == OpOp1.CAST_AS_SCALAR) && getInput().get(0).getDataType()==DataType.LIST ){ - setDim1( -1 ); - setDim2( -1 ); 
+ || _op == OpOp1.CAST_AS_SCALAR) && input.getDataType()==DataType.LIST ){ + //handle two cases of list of scalars or list of single matrix + setDim1( input.getLength() > 1 ? input.getLength() : -1 ); + setDim2( input.getLength() > 1 ? 1 : -1 ); } else if( (_op == OpOp1.CAST_AS_MATRIX || _op == OpOp1.CAST_AS_FRAME) - && getInput().get(0).getDataType()==DataType.SCALAR ) + && input.getDataType()==DataType.SCALAR ) { //prevent propagating 0 from scalar (which would be interpreted as unknown) setDim1( 1 ); @@ -703,7 +705,6 @@ public class UnaryOp extends MultiThreadedHop { // If output is a Matrix then this operation is of type (B = op(A)) // Dimensions of B are same as that of A, and sparsity may/maynot change - Hop input = getInput().get(0); setDim1( input.getDim1() ); setDim2( input.getDim2() ); // cosh(0)=cos(0)=1, acos(0)=1.5707963267948966 http://git-wip-us.apache.org/repos/asf/systemml/blob/fff0aa46/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java b/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java index 7c5014c..d6bac40 100644 --- a/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java +++ b/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java @@ -328,7 +328,7 @@ public class LiteralReplacement long clval = getIntValueDataLiteral(cl, vars); long cuval = getIntValueDataLiteral(cu, vars); - MatrixObject mo = (MatrixObject) vars.get(data.getName()); + MatrixObject mo = (MatrixObject) vars.get(data.getName()); //get the dimension information from the matrix object because the hop //dimensions might not have been updated during recompile @@ -356,10 +356,12 @@ public class LiteralReplacement if( in.getDataType() == DataType.LIST && HopRewriteUtils.isData(in, DataOpTypes.TRANSIENTREAD) ) { ListObject list = (ListObject)vars.get(in.getName()); 
- String varname = Dag.getNextUniqueVarname(DataType.MATRIX); - MatrixObject mo = (MatrixObject) list.slice(0); - vars.put(varname, mo); - ret = HopRewriteUtils.createTransientRead(varname, c); + if( list.getLength() == 1 ) { + String varname = Dag.getNextUniqueVarname(DataType.MATRIX); + MatrixObject mo = (MatrixObject) list.slice(0); + vars.put(varname, mo); + ret = HopRewriteUtils.createTransientRead(varname, c); + } } } return ret; http://git-wip-us.apache.org/repos/asf/systemml/blob/fff0aa46/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java index 039f1ed..5a59d4f 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java @@ -121,6 +121,10 @@ public class ListObject extends Data { return _data.stream().filter(data -> data instanceof CacheableData) .mapToLong(data -> ((CacheableData<?>) data).getDataSize()).sum(); } + + public boolean checkAllDataTypes(DataType dt) { + return _data.stream().allMatch(d -> d.getDataType()==dt); + } @Override public String getDebugName() { http://git-wip-us.apache.org/repos/asf/systemml/blob/fff0aa46/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java index 26fcb2e..cb8a1b4 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java @@ -600,7 +600,25 @@ public class VariableCPInstruction 
extends CPInstruction { else if( getInput1().getDataType().isList() ) { //TODO handling of cleanup status, potentially new object ListObject list = (ListObject)ec.getVariable(getInput1().getName()); - ec.setVariable(output.getName(), list.slice(0)); + if( list.getLength() > 1 ) { + if( !list.checkAllDataTypes(DataType.SCALAR) ) + throw new DMLRuntimeException("as.matrix over multi-entry list only allows scalars."); + MatrixBlock out = new MatrixBlock(list.getLength(), 1, false); + for( int i=0; i<list.getLength(); i++ ) + out.quickSetValue(i, 0, ((ScalarObject)list.slice(i)).getDoubleValue()); + ec.setMatrixOutput(output.getName(), out, getExtendedOpcode()); + } + else { + //pass through matrix input or create 1x1 matrix for scalar + Data tmp = list.slice(0); + if( tmp instanceof ScalarObject && tmp.getValueType()!=ValueType.STRING ) { + MatrixBlock out = new MatrixBlock(((ScalarObject)tmp).getDoubleValue()); + ec.setMatrixOutput(output.getName(), out, getExtendedOpcode()); + } + else { + ec.setVariable(output.getName(), tmp); + } + } } else { throw new DMLRuntimeException("Unsupported data type " http://git-wip-us.apache.org/repos/asf/systemml/blob/fff0aa46/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java index b49f84c..5129785 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java @@ -39,6 +39,7 @@ public class ListAndStructTest extends AutomatedTestBase private static final String TEST_NAME4 = "ListNamedFun"; private static final String TEST_NAME5 = "ListUnnamedParfor"; private static final String TEST_NAME6 = "ListNamedParfor"; + private 
static final String TEST_NAME7 = "ListAsMatrix"; private static final String TEST_DIR = "functions/misc/"; private static final String TEST_CLASS_DIR = TEST_DIR + ListAndStructTest.class.getSimpleName() + "/"; @@ -52,6 +53,7 @@ public class ListAndStructTest extends AutomatedTestBase addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) ); addTestConfiguration( TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) ); addTestConfiguration( TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] { "R" }) ); } @Test @@ -114,6 +116,16 @@ public class ListAndStructTest extends AutomatedTestBase runListStructTest(TEST_NAME6, true); } + @Test + public void testListAsMatrix() { + runListStructTest(TEST_NAME7, false); + } + + @Test + public void testListAsMatrixRewrites() { + runListStructTest(TEST_NAME7, true); + } + private void runListStructTest(String testname, boolean rewrites) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; http://git-wip-us.apache.org/repos/asf/systemml/blob/fff0aa46/src/test/scripts/functions/misc/ListAsMatrix.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/ListAsMatrix.R b/src/test/scripts/functions/misc/ListAsMatrix.R new file mode 100644 index 0000000..506879e --- /dev/null +++ b/src/test/scripts/functions/misc/ListAsMatrix.R @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = list(1,3,7,5,4); +Y = as.matrix(unlist(X)); +R = as.matrix(nrow(Y) * sum(Y) + ncol(Y)); + +writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/fff0aa46/src/test/scripts/functions/misc/ListAsMatrix.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/ListAsMatrix.dml b/src/test/scripts/functions/misc/ListAsMatrix.dml new file mode 100644 index 0000000..33ab8aa --- /dev/null +++ b/src/test/scripts/functions/misc/ListAsMatrix.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = list(1,3,7,5,4); +Y = as.matrix(X); +R = as.matrix(nrow(Y) * sum(Y) + ncol(Y)); + +write(R, $1);
