[SYSTEMML-2405] Support for as.matrix over lists of scalars

This patch adds a convenience feature for creating matrices out of lists
of scalars, along with the necessary compiler/runtime extensions.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/fff0aa46
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/fff0aa46
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/fff0aa46

Branch: refs/heads/master
Commit: fff0aa469dc41fdd73b7c364095d596c6be9dd65
Parents: 3705e78
Author: Matthias Boehm <[email protected]>
Authored: Sat Jun 16 19:21:01 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jun 16 19:21:01 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/UnaryOp.java     | 11 +++----
 .../hops/recompile/LiteralReplacement.java      | 12 ++++----
 .../runtime/instructions/cp/ListObject.java     |  4 +++
 .../instructions/cp/VariableCPInstruction.java  | 20 ++++++++++++-
 .../functions/misc/ListAndStructTest.java       | 12 ++++++++
 src/test/scripts/functions/misc/ListAsMatrix.R  | 31 ++++++++++++++++++++
 .../scripts/functions/misc/ListAsMatrix.dml     | 26 ++++++++++++++++
 7 files changed, 105 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/fff0aa46/src/main/java/org/apache/sysml/hops/UnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java 
b/src/main/java/org/apache/sysml/hops/UnaryOp.java
index fedc557..19b6ed3 100644
--- a/src/main/java/org/apache/sysml/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java
@@ -684,16 +684,18 @@ public class UnaryOp extends MultiThreadedHop
        @Override
        public void refreshSizeInformation()
        {
+               Hop input = getInput().get(0);
                if ( getDataType() == DataType.SCALAR )  {
                        //do nothing always known
                }
                else if( (_op == OpOp1.CAST_AS_MATRIX || _op == 
OpOp1.CAST_AS_FRAME
-                       || _op == OpOp1.CAST_AS_SCALAR) && 
getInput().get(0).getDataType()==DataType.LIST ){
-                       setDim1( -1 );
-                       setDim2( -1 );
+                       || _op == OpOp1.CAST_AS_SCALAR) && 
input.getDataType()==DataType.LIST ){
+                       //handle two cases of list of scalars or list of single 
matrix
+                       setDim1( input.getLength() > 1 ? input.getLength() : -1 
);
+                       setDim2( input.getLength() > 1 ? 1 : -1 );
                }
                else if( (_op == OpOp1.CAST_AS_MATRIX || _op == 
OpOp1.CAST_AS_FRAME)
-                       && getInput().get(0).getDataType()==DataType.SCALAR )
+                       && input.getDataType()==DataType.SCALAR )
                {
                        //prevent propagating 0 from scalar (which would be 
interpreted as unknown)
                        setDim1( 1 );
@@ -703,7 +705,6 @@ public class UnaryOp extends MultiThreadedHop
                {
                        // If output is a Matrix then this operation is of type 
(B = op(A))
                        // Dimensions of B are same as that of A, and sparsity 
may/maynot change
-                       Hop input = getInput().get(0);
                        setDim1( input.getDim1() );
                        setDim2( input.getDim2() );
                        // cosh(0)=cos(0)=1, acos(0)=1.5707963267948966

http://git-wip-us.apache.org/repos/asf/systemml/blob/fff0aa46/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java 
b/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java
index 7c5014c..d6bac40 100644
--- a/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java
+++ b/src/main/java/org/apache/sysml/hops/recompile/LiteralReplacement.java
@@ -328,7 +328,7 @@ public class LiteralReplacement
                                long clval = getIntValueDataLiteral(cl, vars);
                                long cuval = getIntValueDataLiteral(cu, vars);
 
-                               MatrixObject mo = (MatrixObject) 
vars.get(data.getName());      
+                               MatrixObject mo = (MatrixObject) 
vars.get(data.getName());
                                
                                //get the dimension information from the matrix 
object because the hop
                                //dimensions might not have been updated during 
recompile
@@ -356,10 +356,12 @@ public class LiteralReplacement
                        if( in.getDataType() == DataType.LIST
                                && HopRewriteUtils.isData(in, 
DataOpTypes.TRANSIENTREAD) ) {
                                ListObject list = 
(ListObject)vars.get(in.getName());
-                               String varname = 
Dag.getNextUniqueVarname(DataType.MATRIX);
-                               MatrixObject mo = (MatrixObject) list.slice(0);
-                               vars.put(varname, mo);
-                               ret = 
HopRewriteUtils.createTransientRead(varname, c);
+                               if( list.getLength() == 1 ) {
+                                       String varname = 
Dag.getNextUniqueVarname(DataType.MATRIX);
+                                       MatrixObject mo = (MatrixObject) 
list.slice(0);
+                                       vars.put(varname, mo);
+                                       ret = 
HopRewriteUtils.createTransientRead(varname, c);
+                               }
                        }
                }
                return ret;

http://git-wip-us.apache.org/repos/asf/systemml/blob/fff0aa46/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
index 039f1ed..5a59d4f 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
@@ -121,6 +121,10 @@ public class ListObject extends Data {
                return _data.stream().filter(data -> data instanceof 
CacheableData)
                        .mapToLong(data -> ((CacheableData<?>) 
data).getDataSize()).sum();
        }
+       
+       public boolean checkAllDataTypes(DataType dt) {
+               return _data.stream().allMatch(d -> d.getDataType()==dt);
+       }
 
        @Override
        public String getDebugName() {

http://git-wip-us.apache.org/repos/asf/systemml/blob/fff0aa46/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
index 26fcb2e..cb8a1b4 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java
@@ -600,7 +600,25 @@ public class VariableCPInstruction extends CPInstruction {
                        else if( getInput1().getDataType().isList() ) {
                                //TODO handling of cleanup status, potentially 
new object
                                ListObject list = 
(ListObject)ec.getVariable(getInput1().getName());
-                               ec.setVariable(output.getName(), list.slice(0));
+                               if( list.getLength() > 1 ) {
+                                       if( 
!list.checkAllDataTypes(DataType.SCALAR) )
+                                               throw new 
DMLRuntimeException("as.matrix over multi-entry list only allows scalars.");
+                                       MatrixBlock out = new 
MatrixBlock(list.getLength(), 1, false);
+                                       for( int i=0; i<list.getLength(); i++ )
+                                               out.quickSetValue(i, 0, 
((ScalarObject)list.slice(i)).getDoubleValue());
+                                       ec.setMatrixOutput(output.getName(), 
out, getExtendedOpcode());
+                               }
+                               else {
+                                       //pass through matrix input or create 
1x1 matrix for scalar
+                                       Data tmp = list.slice(0);
+                                       if( tmp instanceof ScalarObject && 
tmp.getValueType()!=ValueType.STRING ) {
+                                               MatrixBlock out = new 
MatrixBlock(((ScalarObject)tmp).getDoubleValue());
+                                               
ec.setMatrixOutput(output.getName(), out, getExtendedOpcode());
+                                       } 
+                                       else {
+                                               
ec.setVariable(output.getName(), tmp);
+                                       }
+                               }
                        }
                        else {
                                throw new DMLRuntimeException("Unsupported data 
type "

http://git-wip-us.apache.org/repos/asf/systemml/blob/fff0aa46/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java
index b49f84c..5129785 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java
@@ -39,6 +39,7 @@ public class ListAndStructTest extends AutomatedTestBase
        private static final String TEST_NAME4 = "ListNamedFun";
        private static final String TEST_NAME5 = "ListUnnamedParfor";
        private static final String TEST_NAME6 = "ListNamedParfor";
+       private static final String TEST_NAME7 = "ListAsMatrix";
        
        private static final String TEST_DIR = "functions/misc/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
ListAndStructTest.class.getSimpleName() + "/";
@@ -52,6 +53,7 @@ public class ListAndStructTest extends AutomatedTestBase
                addTestConfiguration( TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME5, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME6, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME7, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] { "R" }) );
        }
        
        @Test
@@ -114,6 +116,16 @@ public class ListAndStructTest extends AutomatedTestBase
                runListStructTest(TEST_NAME6, true);
        }
        
+       @Test
+       public void testListAsMatrix() {
+               runListStructTest(TEST_NAME7, false);
+       }
+       
+       @Test
+       public void testListAsMatrixRewrites() {
+               runListStructTest(TEST_NAME7, true);
+       }
+       
        private void runListStructTest(String testname, boolean rewrites)
        {
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;

http://git-wip-us.apache.org/repos/asf/systemml/blob/fff0aa46/src/test/scripts/functions/misc/ListAsMatrix.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/ListAsMatrix.R 
b/src/test/scripts/functions/misc/ListAsMatrix.R
new file mode 100644
index 0000000..506879e
--- /dev/null
+++ b/src/test/scripts/functions/misc/ListAsMatrix.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = list(1,3,7,5,4);
+Y = as.matrix(unlist(X));
+R = as.matrix(nrow(Y) * sum(Y) + ncol(Y));
+
+writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep=""));

http://git-wip-us.apache.org/repos/asf/systemml/blob/fff0aa46/src/test/scripts/functions/misc/ListAsMatrix.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/ListAsMatrix.dml 
b/src/test/scripts/functions/misc/ListAsMatrix.dml
new file mode 100644
index 0000000..33ab8aa
--- /dev/null
+++ b/src/test/scripts/functions/misc/ListAsMatrix.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = list(1,3,7,5,4);
+Y = as.matrix(X);
+R = as.matrix(nrow(Y) * sum(Y) + ncol(Y));
+
+write(R, $1);

Reply via email to