Repository: systemml Updated Branches: refs/heads/master 3bba03184 -> 08f9e3e47
[SYSTEMML-2415] Support for left indexing on list data types This patch adds support for left indexing operations over both unnamed and named list data types. We allow right-hand-side inputs of types list and scalar with both named and position indexing expressions. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/08f9e3e4 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/08f9e3e4 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/08f9e3e4 Branch: refs/heads/master Commit: 08f9e3e47836c729e91940bcb65c90780d25649b Parents: 3bba031 Author: Matthias Boehm <[email protected]> Authored: Wed Jun 20 21:16:05 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed Jun 20 21:17:02 2018 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/hops/LeftIndexingOp.java | 3 + .../apache/sysml/parser/IndexedIdentifier.java | 9 +- .../instructions/cp/IndexingCPInstruction.java | 6 +- .../cp/ListIndexingCPInstruction.java | 42 +++--- .../runtime/instructions/cp/ListObject.java | 136 +++++++++++++------ .../functions/misc/ListAndStructTest.java | 24 ++++ src/test/scripts/functions/misc/ListNamedRix.R | 34 +++++ .../scripts/functions/misc/ListNamedRix.dml | 31 +++++ .../scripts/functions/misc/ListUnnamedRix.R | 34 +++++ .../scripts/functions/misc/ListUnnamedRix.dml | 29 ++++ 10 files changed, 280 insertions(+), 68 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/08f9e3e4/src/main/java/org/apache/sysml/hops/LeftIndexingOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/LeftIndexingOp.java b/src/main/java/org/apache/sysml/hops/LeftIndexingOp.java index ba355e2..29a95e2 100644 --- a/src/main/java/org/apache/sysml/hops/LeftIndexingOp.java +++ b/src/main/java/org/apache/sysml/hops/LeftIndexingOp.java @@ -388,6 +388,9 @@ public class LeftIndexingOp extends Hop checkAndSetInvalidCPDimsAndSize(); } + if( getInput().get(0).getDataType()==DataType.LIST ) + _etype = ExecType.CP; + //mark for recompile (forever) setRequiresRecompileIfNecessary(); http://git-wip-us.apache.org/repos/asf/systemml/blob/08f9e3e4/src/main/java/org/apache/sysml/parser/IndexedIdentifier.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/IndexedIdentifier.java b/src/main/java/org/apache/sysml/parser/IndexedIdentifier.java index a94ca94..91b982c 100644 --- a/src/main/java/org/apache/sysml/parser/IndexedIdentifier.java +++ b/src/main/java/org/apache/sysml/parser/IndexedIdentifier.java @@ -107,8 +107,7 @@ public class IndexedIdentifier extends DataIdentifier // valid lower row bound value isConst_rowLowerBound = true; } - - else if (_rowLowerBound instanceof ConstIdentifier) { + else if (_rowLowerBound instanceof ConstIdentifier && !getDataType().isList() ) { raiseValidateError("assign lower-bound row index for Indexed Identifier " + this.toString() + " the non-numeric value " + _rowLowerBound.toString(), conditional); } @@ -192,7 +191,7 @@ public class IndexedIdentifier extends DataIdentifier } isConst_rowUpperBound = true; } - else if (_rowUpperBound instanceof ConstIdentifier){ + else if (_rowUpperBound instanceof ConstIdentifier && !getDataType().isList()){ raiseValidateError("assign upper-bound row index for " + this.toString() + " the non-numeric value " + _rowUpperBound.toString(), conditional); } @@ -268,7 +267,7 @@ public class IndexedIdentifier extends DataIdentifier isConst_colLowerBound = true; } - else if (_colLowerBound instanceof ConstIdentifier) { + else if (_colLowerBound instanceof ConstIdentifier && !getDataType().isList()) { raiseValidateError("assign lower-bound column index for Indexed Identifier " + this.toString() + " the non-numeric value " + _colLowerBound.toString(), conditional); } @@ -352,7 +351,7 @@ public class IndexedIdentifier extends DataIdentifier isConst_colUpperBound = true; } - else if (_colUpperBound instanceof ConstIdentifier){ + else if (_colUpperBound instanceof ConstIdentifier && !getDataType().isList()){ raiseValidateError("assign upper-bound column index for " + this.toString() + " the non-numeric value " + _colUpperBound.toString(), conditional); } http://git-wip-us.apache.org/repos/asf/systemml/blob/08f9e3e4/src/main/java/org/apache/sysml/runtime/instructions/cp/IndexingCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/IndexingCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/IndexingCPInstruction.java index 0046713..fb4732d 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/IndexingCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/IndexingCPInstruction.java @@ -76,7 +76,7 @@ public abstract class IndexingCPInstruction extends UnaryCPInstruction { else if( in.getDataType() == DataType.LIST ) return new ListIndexingCPInstruction(in, rl, ru, cl, cu, out, opcode, str); else - throw new DMLRuntimeException("Can index only on Frames or Matrices"); + throw new DMLRuntimeException("Can index only on matrices, frames, and lists."); } else { throw new DMLRuntimeException("Invalid number of operands in instruction: " + str); @@ -96,8 +96,10 @@ public abstract class IndexingCPInstruction extends UnaryCPInstruction { return new MatrixIndexingCPInstruction(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, str); else if (lhsInput.getDataType() == DataType.FRAME) return new FrameIndexingCPInstruction(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, str); + else if( lhsInput.getDataType() == DataType.LIST ) + return new ListIndexingCPInstruction(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, str); else - throw new DMLRuntimeException("Can index only on Frames or Matrices"); + throw new DMLRuntimeException("Can index only on matrices, frames, and lists."); } else { throw new DMLRuntimeException("Invalid number of operands in instruction: " + str); http://git-wip-us.apache.org/repos/asf/systemml/blob/08f9e3e4/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java index 4890439..f41601f 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java @@ -59,27 +59,27 @@ public final class ListIndexingCPInstruction extends IndexingCPInstruction { } //left indexing else if ( opcode.equalsIgnoreCase(LeftIndex.OPCODE)) { -// FrameBlock lin = ec.getFrameInput(input1.getName()); -// FrameBlock out = null; -// -// if(input2.getDataType() == DataType.FRAME) { //FRAME<-FRAME -// FrameBlock rin = ec.getFrameInput(input2.getName()); -// out = lin.leftIndexingOperations(rin, ixrange, new FrameBlock()); -// ec.releaseFrameInput(input2.getName()); -// } -// else { //FRAME<-SCALAR -// if(!ixrange.isScalar()) -// throw new DMLRuntimeException("Invalid index range of scalar leftindexing: "+ixrange.toString()+"." ); -// ScalarObject scalar = ec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()); -// out = new FrameBlock(lin); -// out.set((int)ixrange.rowStart, (int)ixrange.colStart, scalar.getStringValue()); -// } -// -// //unpin lhs input -// ec.releaseFrameInput(input1.getName()); -// -// //unpin output -// ec.setFrameOutput(output.getName(), out); + ListObject lin = (ListObject) ec.getVariable(input1.getName()); + + //execute right indexing operation and set output + if( input2.getDataType().isList() ) { //LIST <- LIST + ListObject rin = (ListObject) ec.getVariable(input2.getName()); + if( rl.getValueType()==ValueType.STRING || ru.getValueType()==ValueType.STRING ) + ec.setVariable(output.getName(), lin.copy().set(rl.getStringValue(), ru.getStringValue(), rin)); + else + ec.setVariable(output.getName(), lin.copy().set((int)rl.getLongValue()-1, (int)ru.getLongValue()-1, rin)); + } + else if( input2.getDataType().isScalar() ) { //LIST <- SCALAR + ScalarObject scalar = ec.getScalarInput(input2); + if( rl.getValueType()==ValueType.STRING ) + ec.setVariable(output.getName(), lin.copy().set(rl.getStringValue(), scalar)); + else + ec.setVariable(output.getName(), lin.copy().set((int)rl.getLongValue()-1, scalar)); + } + else { + throw new DMLRuntimeException("Unsupported list " + + "left indexing rhs type: "+input2.getDataType().name()); + } } else throw new DMLRuntimeException("Invalid opcode (" + opcode +") encountered in ListIndexingCPInstruction."); http://git-wip-us.apache.org/repos/asf/systemml/blob/08f9e3e4/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java index 5a59d4f..d41dfc5 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java @@ -19,6 +19,7 @@ package org.apache.sysml.runtime.instructions.cp; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -30,9 +31,9 @@ import org.apache.sysml.runtime.controlprogram.caching.CacheableData; public class ListObject extends Data { private static final long serialVersionUID = 3652422061598967358L; - private final List<String> _names; private final List<Data> _data; private boolean[] _dataState = null; + private List<String> _names = null; public ListObject(List<Data> data) { super(DataType.LIST, ValueType.UNKNOWN); @@ -58,6 +59,31 @@ public class ListObject extends Data { return _data.size(); } + public List<String> getNames() { + return _names; + } + + public String getName(int ix) { + return (_names == null) ? null : _names.get(ix); + } + + public boolean isNamedList() { + return _names != null; + } + + public List<Data> getData() { + return _data; + } + + public long getDataSize() { + return _data.stream().filter(data -> data instanceof CacheableData) + .mapToLong(data -> ((CacheableData<?>) data).getDataSize()).sum(); + } + + public boolean checkAllDataTypes(DataType dt) { + return _data.stream().allMatch(d -> d.getDataType()==dt); + } + public Data slice(int ix) { return _data.get(ix); } @@ -71,59 +97,89 @@ public class ListObject extends Data { } public Data slice(String name) { - //check for existing named list - if (_names == null) - throw new DMLRuntimeException("Invalid lookup by name" + " in unnamed list: " + name + "."); - - //find position and check for existing entry - int pos = _names.indexOf(name); - if (pos < 0 || pos >= _data.size()) - throw new DMLRuntimeException("List lookup returned no entry for name='" + name + "'"); - + //lookup position by name, incl error handling + int pos = getPosForName(name); + //return existing entry return slice(pos); } public ListObject slice(String name1, String name2) { - //check for existing named list - if (_names == null) - throw new DMLRuntimeException("Invalid lookup by name" + " in unnamed list: " + name1 + ", " + name2 + "."); - - //find position and check for existing entry - int pos1 = _names.indexOf(name1); - int pos2 = _names.indexOf(name2); - if (pos1 < 0 || pos1 >= _data.size()) - throw new DMLRuntimeException("List lookup returned no entry for name='" + name1 + "'"); - if (pos2 < 0 || pos2 >= _data.size()) - throw new DMLRuntimeException("List lookup returned no entry for name='" + name2 + "'"); - + //lookup positions by name, incl error handling + int pos1 = getPosForName(name1); + int pos2 = getPosForName(name2); + //return list object return slice(pos1, pos2); } - - public List<String> getNames() { - return _names; + + public ListObject copy() { + ListObject ret = isNamedList() ? + new ListObject(new ArrayList<>(getData()), new ArrayList<>(getNames())) : + new ListObject(new ArrayList<>(getData())); + ret.setStatus(Arrays.copyOf(getStatus(), getLength())); + return ret; } - - public String getName(int ix) { - return (_names == null) ? null : _names.get(ix); + + public ListObject set(int ix, Data data) { + _data.set(ix, data); + return this; } - - public boolean isNamedList() { - return _names != null; + + public ListObject set(int ix1, int ix2, ListObject data) { + int range = ix2 - ix1 + 1; + if( range != data.getLength() || range > getLength() ) { + throw new DMLRuntimeException("List leftindexing size mismatch: length(lhs)=" + +getLength()+", range=["+ix1+":"+ix2+"], legnth(rhs)="+data.getLength()); + } + + //copy rhs list object including meta data + if( range == getLength() ) { + //overwrite all entries in left hand side + _data.clear(); _data.addAll(data.getData()); + System.arraycopy(data.getStatus(), 0, _dataState, 0, range); + if( data.isNamedList() ) + _names = new ArrayList<>(data.getNames()); + } + else { + //overwrite entries of subrange in left hand side + for( int i=ix1; i<=ix2; i++ ) { + set(i, data.slice(i-ix1)); + _dataState[i] = data._dataState[i-ix1]; + if( isNamedList() && data.isNamedList() ) + _names.set(i, data.getName(i-ix1)); + } + } + return this; } - - public List<Data> getData() { - return _data; + + public Data set(String name, Data data) { + //lookup position by name, incl error handling + int pos = getPosForName(name); + + //set entry into position + return set(pos, data); } - - public long getDataSize() { - return _data.stream().filter(data -> data instanceof CacheableData) - .mapToLong(data -> ((CacheableData<?>) data).getDataSize()).sum(); + + public ListObject set(String name1, String name2, ListObject data) { + //lookup positions by name, incl error handling + int pos1 = getPosForName(name1); + int pos2 = getPosForName(name2); + + //set list into position range + return set(pos1, pos2, data); } - public boolean checkAllDataTypes(DataType dt) { - return _data.stream().allMatch(d -> d.getDataType()==dt); + private int getPosForName(String name) { + //check for existing named list + if (_names == null) + throw new DMLRuntimeException("Invalid indexing by name" + " in unnamed list: " + name + "."); + + //find position and check for existing entry + int pos = _names.indexOf(name); + if (pos < 0 || pos >= _data.size()) + throw new DMLRuntimeException("List indexing returned no entry for name='" + name + "'"); + return pos; } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/08f9e3e4/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java index 5129785..e8e3d1a 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/ListAndStructTest.java @@ -40,6 +40,8 @@ public class ListAndStructTest extends AutomatedTestBase private static final String TEST_NAME5 = "ListUnnamedParfor"; private static final String TEST_NAME6 = "ListNamedParfor"; private static final String TEST_NAME7 = "ListAsMatrix"; + private static final String TEST_NAME8 = "ListUnnamedRix"; + private static final String TEST_NAME9 = "ListNamedRix"; private static final String TEST_DIR = "functions/misc/"; private static final String TEST_CLASS_DIR = TEST_DIR + ListAndStructTest.class.getSimpleName() + "/"; @@ -54,6 +56,8 @@ public class ListAndStructTest extends AutomatedTestBase addTestConfiguration( TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) ); addTestConfiguration( TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "R" }) ); addTestConfiguration( TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME8, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME9, new String[] { "R" }) ); } @Test @@ -126,6 +130,26 @@ public class ListAndStructTest extends AutomatedTestBase runListStructTest(TEST_NAME7, true); } + @Test + public void testListRix() { + runListStructTest(TEST_NAME8, false); + } + + @Test + public void testListRixRewrites() { + runListStructTest(TEST_NAME8, true); + } + + @Test + public void testListNamedRix() { + runListStructTest(TEST_NAME9, false); + } + + @Test + public void testListNamedRixRewrites() { + runListStructTest(TEST_NAME9, true); + } + private void runListStructTest(String testname, boolean rewrites) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; http://git-wip-us.apache.org/repos/asf/systemml/blob/08f9e3e4/src/test/scripts/functions/misc/ListNamedRix.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/ListNamedRix.R b/src/test/scripts/functions/misc/ListNamedRix.R new file mode 100644 index 0000000..a3c5b67 --- /dev/null +++ b/src/test/scripts/functions/misc/ListNamedRix.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) +options(digits=22) +library("Matrix") + +#X = list(1,3,7,5,4); +X = list(1,0,0,0,0); +X[2:4] = list(3,7,5); +X[5] = 4; +Y = as.matrix(unlist(X)); +R = as.matrix(nrow(Y) * sum(Y) + ncol(Y)); + +writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/08f9e3e4/src/test/scripts/functions/misc/ListNamedRix.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/ListNamedRix.dml b/src/test/scripts/functions/misc/ListNamedRix.dml new file mode 100644 index 0000000..2b5cabb --- /dev/null +++ b/src/test/scripts/functions/misc/ListNamedRix.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = list(a=0,b=0,c=0,d=0,e=0); +X[2:4] = list(3,7,5); +X["e"] = 4; +X[1] = list(f=2); +X["f"] = 1; + +Y = as.matrix(X); +R = as.matrix(nrow(Y) * sum(Y) + ncol(Y)); + +write(R, $1); http://git-wip-us.apache.org/repos/asf/systemml/blob/08f9e3e4/src/test/scripts/functions/misc/ListUnnamedRix.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/ListUnnamedRix.R b/src/test/scripts/functions/misc/ListUnnamedRix.R new file mode 100644 index 0000000..a3c5b67 --- /dev/null +++ b/src/test/scripts/functions/misc/ListUnnamedRix.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) +options(digits=22) +library("Matrix") + +#X = list(1,3,7,5,4); +X = list(1,0,0,0,0); +X[2:4] = list(3,7,5); +X[5] = 4; +Y = as.matrix(unlist(X)); +R = as.matrix(nrow(Y) * sum(Y) + ncol(Y)); + +writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/08f9e3e4/src/test/scripts/functions/misc/ListUnnamedRix.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/ListUnnamedRix.dml b/src/test/scripts/functions/misc/ListUnnamedRix.dml new file mode 100644 index 0000000..aa15554 --- /dev/null +++ b/src/test/scripts/functions/misc/ListUnnamedRix.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +#X = list(1,3,7,5,4); +X = list(1,0,0,0,0); +X[2:4] = list(3,7,5); +X[5] = 4; +Y = as.matrix(X); +R = as.matrix(nrow(Y) * sum(Y) + ncol(Y)); + +write(R, $1);
