Repository: incubator-systemml Updated Branches: refs/heads/master cfc73fefe -> eeb4f2708
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/eeb4f270/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java index 3b067dd..ab1f14b 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java @@ -38,7 +38,8 @@ import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; import org.apache.sysml.runtime.codegen.SpoofOperator; import org.apache.sysml.runtime.codegen.SpoofOuterProduct; import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType; -import org.apache.sysml.runtime.codegen.SpoofRowAggregate; +import org.apache.sysml.runtime.codegen.SpoofRowwise; +import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.functionobjects.Builtin; @@ -195,11 +196,36 @@ public class SpoofSPInstruction extends SPInstruction sec.setVariable(_out.getName(), new DoubleObject(tmp.getValue(0, 0))); } } - else if( _class.getSuperclass() == SpoofRowAggregate.class ) { //row aggregate operator - RowAggregateFunction fmmc = new RowAggregateFunction(_class.getName(), _classBytes, bcMatrices, scalars); - JavaPairRDD<MatrixIndexes,MatrixBlock> tmpRDD = in.mapToPair(fmmc); - MatrixBlock tmpMB = RDDAggregateUtils.sumStable(tmpRDD); - sec.setMatrixOutput(_out.getName(), tmpMB); + else if( _class.getSuperclass() == SpoofRowwise.class ) { //row aggregate operator + SpoofRowwise op = (SpoofRowwise) CodegenUtils.createInstance(_class); + RowwiseFunction fmmc = new RowwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars); + + if( op.getRowType().isColumnAgg() ) { + JavaPairRDD<MatrixIndexes,MatrixBlock> tmpRDD = in.mapToPair(fmmc); + MatrixBlock tmpMB = RDDAggregateUtils.sumStable(tmpRDD); + sec.setMatrixOutput(_out.getName(), tmpMB); + } + else //row-agg or no-agg + { + out = in.mapToPair(fmmc); + if( op.getRowType()==RowType.ROW_AGG && mcIn.getCols() > mcIn.getColsPerBlock() ) { + //TODO investigate if some other side effect of correct blocks + if( out.partitions().size() > mcIn.getNumRowBlocks() ) + out = RDDAggregateUtils.sumByKeyStable(out, (int)mcIn.getNumRowBlocks(), false); + else + out = RDDAggregateUtils.sumByKeyStable(out, false); + } + + sec.setRDDHandleForVariable(_out.getName(), out); + + //maintain lineage information for output rdd + sec.addLineageRDD(_out.getName(), _in[0].getName()); + for( String bcVar : bcVars ) + sec.addLineageBroadcast(_out.getName(), bcVar); + + //update matrix characteristics + updateOutputMatrixCharacteristics(sec, op); + } return; } else { @@ -236,7 +262,7 @@ public class SpoofSPInstruction extends SPInstruction } } - private static class RowAggregateFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> + private static class RowwiseFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> { private static final long serialVersionUID = -7926980450209760212L; @@ -244,9 +270,9 @@ public class SpoofSPInstruction extends SPInstruction private ArrayList<ScalarObject> _scalars = null; private byte[] _classBytes = null; private String _className = null; - private SpoofOperator _op = null; + private SpoofRowwise _op = null; - public RowAggregateFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) + public RowwiseFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) throws DMLRuntimeException { _className = className; @@ -262,7 +288,7 @@ public class SpoofSPInstruction extends SPInstruction //lazy load of shipped class if( _op == null ) { Class<?> loadedClass = CodegenUtils.getClass(_className, _classBytes); - _op = (SpoofOperator) CodegenUtils.createInstance(loadedClass); + _op = (SpoofRowwise) CodegenUtils.createInstance(loadedClass); } //get main input block and indexes @@ -272,7 +298,9 @@ public class SpoofSPInstruction extends SPInstruction //prepare output and execute single-threaded operator ArrayList<MatrixBlock> inputs = getVectorInputsFromBroadcast(blkIn, rowIx); - MatrixIndexes ixOut = new MatrixIndexes(1,1); + MatrixIndexes ixOut = new MatrixIndexes( + _op.getRowType().isColumnAgg() ? 1 : ixIn.getRowIndex(), + _op.getRowType()!=RowType.NO_AGG ? 1 : ixIn.getColumnIndex()); MatrixBlock blkOut = new MatrixBlock(); _op.execute(inputs, _scalars, blkOut); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/eeb4f270/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java index c83bff3..865080a 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java @@ -46,6 +46,9 @@ public class RowAggTmplTest extends AutomatedTestBase private static final String TEST_NAME8 = TEST_NAME+"8"; //colSums((X/rowSums(X))>0.7) private static final String TEST_NAME9 = TEST_NAME+"9"; //t(X) %*% (v - abs(y)) private static final String TEST_NAME10 = TEST_NAME+"10"; //Y=(X<=rowMins(X)); R=colSums((Y/rowSums(Y))); + private static final String TEST_NAME11 = TEST_NAME+"11"; //y - X %*% v + private static final String TEST_NAME12 = TEST_NAME+"12"; //Y=(X>=v); R=Y/rowSums(Y) + private static final String TEST_NAME13 = TEST_NAME+"13"; //rowSums(X)+rowSums(Y) private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/"; @@ -57,108 +60,203 @@ public class RowAggTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=10; i++) + for(int i=1; i<=13; i++) addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) ); } @Test - public void testCodegenRowAggRewrite1() { + public void testCodegenRowAggRewrite1CP() { testCodegenIntegration( TEST_NAME1, true, ExecType.CP ); } @Test - public void testCodegenRowAggRewrite2() { + public void testCodegenRowAgg1CP() { + testCodegenIntegration( TEST_NAME1, false, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg1SP() { + testCodegenIntegration( TEST_NAME1, false, ExecType.SPARK ); + } + + @Test + public void testCodegenRowAggRewrite2CP() { testCodegenIntegration( TEST_NAME2, true, ExecType.CP ); } @Test - public void testCodegenRowAggRewrite3() { + public void testCodegenRowAgg2CP() { + testCodegenIntegration( TEST_NAME2, false, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg2SP() { + testCodegenIntegration( TEST_NAME2, false, ExecType.SPARK ); + } + + @Test + public void testCodegenRowAggRewrite3CP() { testCodegenIntegration( TEST_NAME3, true, ExecType.CP ); } @Test - public void testCodegenRowAggRewrite4() { - testCodegenIntegration( TEST_NAME4, true, ExecType.CP ); + public void testCodegenRowAgg3CP() { + testCodegenIntegration( TEST_NAME3, false, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg3SP() { + testCodegenIntegration( TEST_NAME3, false, ExecType.SPARK ); + } + + @Test + public void testCodegenRowAggRewrite4CP() { + testCodegenIntegration( TEST_NAME4, true, ExecType.CP ); } @Test - public void testCodegenRowAggRewrite5() { - testCodegenIntegration( TEST_NAME5, true, ExecType.CP ); + public void testCodegenRowAgg4CP() { + testCodegenIntegration( TEST_NAME4, false, ExecType.CP ); } @Test - public void testCodegenRowAggRewrite6() { - testCodegenIntegration( TEST_NAME6, true, ExecType.CP ); + public void testCodegenRowAgg4SP() { + testCodegenIntegration( TEST_NAME4, false, ExecType.SPARK ); + } + + @Test + public void testCodegenRowAggRewrite5CP() { + testCodegenIntegration( TEST_NAME5, true, ExecType.CP ); } @Test - public void testCodegenRowAggRewrite7() { - testCodegenIntegration( TEST_NAME7, true, ExecType.CP ); + public void testCodegenRowAgg5CP() { + testCodegenIntegration( TEST_NAME5, false, ExecType.CP ); } @Test - public void testCodegenRowAggRewrite8() { - testCodegenIntegration( TEST_NAME8, true, ExecType.CP ); + public void testCodegenRowAgg5SP() { + testCodegenIntegration( TEST_NAME5, false, ExecType.SPARK ); + } + + @Test + public void testCodegenRowAggRewrite6CP() { + testCodegenIntegration( TEST_NAME6, true, ExecType.CP ); } @Test - public void testCodegenRowAggRewrite9() { - testCodegenIntegration( TEST_NAME9, true, ExecType.CP ); + public void testCodegenRowAgg6CP() { + testCodegenIntegration( TEST_NAME6, false, ExecType.CP ); } @Test - public void testCodegenRowAggRewrite10() { - testCodegenIntegration( TEST_NAME10, true, ExecType.CP ); + public void testCodegenRowAgg6SP() { + testCodegenIntegration( TEST_NAME6, false, ExecType.SPARK ); } @Test - public void testCodegenRowAgg1() { - testCodegenIntegration( TEST_NAME1, false, ExecType.CP ); + public void testCodegenRowAggRewrite7CP() { + testCodegenIntegration( TEST_NAME7, true, ExecType.CP ); } @Test - public void testCodegenRowAgg2() { - testCodegenIntegration( TEST_NAME2, false, ExecType.CP ); + public void testCodegenRowAgg7CP() { + testCodegenIntegration( TEST_NAME7, false, ExecType.CP ); } @Test - public void testCodegenRowAgg3() { - testCodegenIntegration( TEST_NAME3, false, ExecType.CP ); + public void testCodegenRowAgg7SP() { + testCodegenIntegration( TEST_NAME7, false, ExecType.SPARK ); + } + + @Test + public void testCodegenRowAggRewrite8CP() { + testCodegenIntegration( TEST_NAME8, true, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg8CP() { + testCodegenIntegration( TEST_NAME8, false, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg8SP() { + testCodegenIntegration( TEST_NAME8, false, ExecType.SPARK ); + } + + @Test + public void testCodegenRowAggRewrite9CP() { + testCodegenIntegration( TEST_NAME9, true, ExecType.CP ); } @Test - public void testCodegenRowAgg4() { - testCodegenIntegration( TEST_NAME4, false, ExecType.CP ); + public void testCodegenRowAgg9CP() { + testCodegenIntegration( TEST_NAME9, false, ExecType.CP ); } @Test - public void testCodegenRowAgg5() { - testCodegenIntegration( TEST_NAME5, false, ExecType.CP ); + public void testCodegenRowAgg9SP() { + testCodegenIntegration( TEST_NAME9, false, ExecType.SPARK ); + } + + @Test + public void testCodegenRowAggRewrite10CP() { + testCodegenIntegration( TEST_NAME10, true, ExecType.CP ); } @Test - public void testCodegenRowAgg6() { - testCodegenIntegration( TEST_NAME6, false, ExecType.CP ); + public void testCodegenRowAgg10CP() { + testCodegenIntegration( TEST_NAME10, false, ExecType.CP ); } @Test - public void testCodegenRowAgg7() { - testCodegenIntegration( TEST_NAME7, false, ExecType.CP ); + public void testCodegenRowAgg10SP() { + testCodegenIntegration( TEST_NAME10, false, ExecType.SPARK ); + } + + @Test + public void testCodegenRowAggRewrite11CP() { + testCodegenIntegration( TEST_NAME11, true, ExecType.CP ); } @Test - public void testCodegenRowAgg8() { - testCodegenIntegration( TEST_NAME8, false, ExecType.CP ); + public void testCodegenRowAgg11CP() { + testCodegenIntegration( TEST_NAME11, false, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg11SP() { + testCodegenIntegration( TEST_NAME11, false, ExecType.SPARK ); + } + + @Test + public void testCodegenRowAggRewrite12CP() { + testCodegenIntegration( TEST_NAME12, true, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg12CP() { + testCodegenIntegration( TEST_NAME12, false, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg12SP() { + testCodegenIntegration( TEST_NAME12, false, ExecType.SPARK ); + } + + @Test + public void testCodegenRowAggRewrite13CP() { + testCodegenIntegration( TEST_NAME13, true, ExecType.CP ); } @Test - public void testCodegenRowAgg9() { - testCodegenIntegration( TEST_NAME9, false, ExecType.CP ); + public void testCodegenRowAgg13CP() { + testCodegenIntegration( TEST_NAME13, false, ExecType.CP ); } @Test - public void testCodegenRowAgg10() { - testCodegenIntegration( TEST_NAME10, false, ExecType.CP ); + public void testCodegenRowAgg13SP() { + testCodegenIntegration( TEST_NAME13, false, ExecType.SPARK ); } private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/eeb4f270/src/test/scripts/functions/codegen/rowAggPattern11.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern11.R b/src/test/scripts/functions/codegen/rowAggPattern11.R new file mode 100644 index 0000000..75ca730 --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern11.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") + + +X = matrix(seq(1,1500000), 150000, 10, byrow=TRUE); +v = seq(1,10); +y = seq(1,150000); + +R = y - X %*% v; + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/eeb4f270/src/test/scripts/functions/codegen/rowAggPattern11.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern11.dml b/src/test/scripts/functions/codegen/rowAggPattern11.dml new file mode 100644 index 0000000..5a1c5b2 --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern11.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1,1500000), rows=150000, cols=10); +v = seq(1,10); +y = seq(1,150000); + +R = y - X %*% v; + +write(R, $1) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/eeb4f270/src/test/scripts/functions/codegen/rowAggPattern12.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern12.R b/src/test/scripts/functions/codegen/rowAggPattern12.R new file mode 100644 index 0000000..9c66f14 --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern12.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") + + +X = matrix(seq(1,1500), 150, 10, byrow=TRUE); +v = seq(1,150); + +Y = (X >= v); +R = Y / rowSums(Y); + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/eeb4f270/src/test/scripts/functions/codegen/rowAggPattern12.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern12.dml b/src/test/scripts/functions/codegen/rowAggPattern12.dml new file mode 100644 index 0000000..a4b6834 --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern12.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1,1500), rows=150, cols=10); +v = seq(1,150); + +Y = (X >= v); +R = Y / rowSums(Y); + +write(R, $1) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/eeb4f270/src/test/scripts/functions/codegen/rowAggPattern13.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern13.R b/src/test/scripts/functions/codegen/rowAggPattern13.R new file mode 100644 index 0000000..9bff2fe --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern13.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") + + +X = matrix(seq(1,1500), 150, 10, byrow=TRUE); +Y = matrix(seq(2,1501), 150, 10, byrow=TRUE); + +R = rowSums(X) + rowSums(Y); + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/eeb4f270/src/test/scripts/functions/codegen/rowAggPattern13.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern13.dml b/src/test/scripts/functions/codegen/rowAggPattern13.dml new file mode 100644 index 0000000..95c80df --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern13.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1,1500), rows=150, cols=10); +Y = matrix(seq(2,1501), rows=150, cols=10); + +R = rowSums(X) + rowSums(Y); + +write(R, $1)
