[SYSTEMML-1990] Vectorization of consecutive right-left indexing pairs This patch extends the existing indexing vectorization rewrite with a new sub-rewrite for right-left indexing pairs. For example, we now automatically rewrite the following patterns for consistent row or column indexing:
1) Consecutive indexing pairs w/ literals B[, 2] = A[, 1]; B[, 3] = A[, 2]; B[, 4] = A[, 3]; --> B[, 2:4] = A[, 1:3]; 2) Consecutive indexing pairs w/ scalar increments pos = 2; B[, pos] = A[, 1]; pos = pos + 1; B[, pos] = A[, 2]; pos = pos + 1; B[, pos] = A[, 3]; --> B[, 2:4] = A[, 1:3]; Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a9c14b02 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a9c14b02 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a9c14b02 Branch: refs/heads/master Commit: a9c14b02b333721a507b224110f8cf4a09e91e95 Parents: ac7990e Author: Matthias Boehm <[email protected]> Authored: Mon Nov 13 22:31:24 2017 -0800 Committer: Matthias Boehm <[email protected]> Committed: Tue Nov 14 11:54:29 2017 -0800 ---------------------------------------------------------------------- .../sysml/hops/codegen/SpoofCompiler.java | 2 +- .../sysml/hops/rewrite/HopRewriteUtils.java | 48 +++++++-- .../rewrite/RewriteIndexingVectorization.java | 86 +++++++++++++++ .../misc/RewriteIndexingVectorizationTest.java | 105 +++++++++++++++++++ .../misc/RewriteIndexingVectorizationCol.dml | 41 ++++++++ .../misc/RewriteIndexingVectorizationRow.dml | 41 ++++++++ .../functions/misc/ZPackageSuite.java | 1 + 7 files changed, 315 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/a9c14b02/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index 8aa673b..8683bb0 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -647,7 +647,7 @@ public class SpoofCompiler for( int i=0; 
i<roots.size(); i++ ) { Hop hnewi = (roots.get(i) instanceof AggUnaryOp) ? HopRewriteUtils.createScalarIndexing(hnew, 1, i+1) : - HopRewriteUtils.createMatrixIndexing(hnew, 1, i+1); + HopRewriteUtils.createIndexingOp(hnew, 1, i+1); HopRewriteUtils.rewireAllParentChildReferences(roots.get(i), hnewi); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/a9c14b02/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 66f4fc7..a7373fb 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -636,20 +636,32 @@ public class HopRewriteUtils } public static Hop createScalarIndexing(Hop input, long rix, long cix) { - Hop ix = createMatrixIndexing(input, rix, cix); + Hop ix = createIndexingOp(input, rix, cix); return createUnary(ix, OpOp1.CAST_AS_SCALAR); } - public static Hop createMatrixIndexing(Hop input, long rix, long cix) { + public static IndexingOp createIndexingOp(Hop input, long rix, long cix) { LiteralOp row = new LiteralOp(rix); LiteralOp col = new LiteralOp(cix); - IndexingOp ix = new IndexingOp("tmp", DataType.MATRIX, ValueType.DOUBLE, input, row, row, col, col, true, true); + return createIndexingOp(input, row, row, col, col); + } + + public static IndexingOp createIndexingOp(Hop input, Hop rl, Hop ru, Hop cl, Hop cu) { + IndexingOp ix = new IndexingOp("tmp", DataType.MATRIX, ValueType.DOUBLE, input, rl, ru, cl, cu, rl==ru, cl==cu); ix.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock()); copyLineNumbers(input, ix); ix.refreshSizeInformation(); return ix; } + public static LeftIndexingOp createLeftIndexingOp(Hop lhs, Hop rhs, Hop rl, Hop ru, Hop cl, Hop cu) { + LeftIndexingOp ix = new LeftIndexingOp("tmp", 
DataType.MATRIX, ValueType.DOUBLE, lhs, rhs, rl, ru, cl, cu, rl==ru, cl==cu); + ix.setOutputBlocksizes(lhs.getRowsInBlock(), lhs.getColsInBlock()); + copyLineNumbers(lhs, ix); + ix.refreshSizeInformation(); + return ix; + } + public static NaryOp createNary(OpOpN op, Hop... inputs) throws HopsException { Hop mainInput = inputs[0]; NaryOp nop = new NaryOp(mainInput.getName(), mainInput.getDataType(), @@ -1073,16 +1085,29 @@ public class HopRewriteUtils public static boolean isFullColumnIndexing(LeftIndexingOp hop) { return hop.isColLowerEqualsUpper() && isLiteralOfValue(hop.getInput().get(2), 1) - && isLiteralOfValue(hop.getInput().get(3), hop.getDim1()); - //TODO extend by input/output size conditions, which are currently - //invalid due to temporarily incorrect size information + && (isLiteralOfValue(hop.getInput().get(3), hop.getDim1()) + || isSizeExpressionOf(hop.getInput().get(3), hop.getInput().get(0), true)); + } + + public static boolean isFullColumnIndexing(IndexingOp hop) { + return hop.isColLowerEqualsUpper() + && isLiteralOfValue(hop.getInput().get(1), 1) + && (isLiteralOfValue(hop.getInput().get(2), hop.getDim1()) + || isSizeExpressionOf(hop.getInput().get(2), hop.getInput().get(0), true)); } public static boolean isFullRowIndexing(LeftIndexingOp hop) { return hop.isRowLowerEqualsUpper() && isLiteralOfValue(hop.getInput().get(4), 1) - && isLiteralOfValue(hop.getInput().get(5), hop.getDim2()); - //TODO extend by input/output size conditions (see above) + && (isLiteralOfValue(hop.getInput().get(5), hop.getDim2()) + || isSizeExpressionOf(hop.getInput().get(5), hop.getInput().get(0), false)); + } + + public static boolean isFullRowIndexing(IndexingOp hop) { + return hop.isRowLowerEqualsUpper() + && isLiteralOfValue(hop.getInput().get(3), 1) + && (isLiteralOfValue(hop.getInput().get(4), hop.getDim2()) + || isSizeExpressionOf(hop.getInput().get(4), hop.getInput().get(0), false)); } public static boolean isColumnRangeIndexing(IndexingOp hop) { @@ -1093,6 
+1118,13 @@ public class HopRewriteUtils && hop.getInput().get(4) instanceof LiteralOp; } + public static boolean isConsecutiveIndex(Hop index, Hop index2) { + return (index instanceof LiteralOp && index2 instanceof LiteralOp) ? + getDoubleValueSafe((LiteralOp)index2) == (getDoubleValueSafe((LiteralOp)index)+1) : + (isBinaryMatrixScalar(index2, OpOp2.PLUS, 1) && + (index2.getInput().get(0) == index || index2.getInput().get(1) == index)); + } + public static boolean isUnnecessaryRightIndexing(Hop hop) { if( !(hop instanceof IndexingOp) ) return false; http://git-wip-us.apache.org/repos/asf/systemml/blob/a9c14b02/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java index 3f116c2..e2097e2 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteIndexingVectorization.java @@ -20,10 +20,12 @@ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; +import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LeftIndexingOp; @@ -83,6 +85,7 @@ public class RewriteIndexingVectorization extends HopRewriteRule //apply indexing vectorization rewrites //MB: disabled right indexing rewrite because (1) piggybacked in MR anyway, (2) usually //not too much overhead, and (3) makes literal replacement more difficult + hi = vectorizeRightLeftIndexingChains(hi); //e.g., B[,1]=A[,1]; B[,2]=A[2]; -> B[,1:2] = A[,1:2] //vectorizeRightIndexing( hi ); //e.g., multiple rightindexing X[i,1], X[i,3] 
-> X[i,]; hi = vectorizeLeftIndexing( hi ); //e.g., multiple left indexing X[i,1], X[i,3] -> X[i,]; @@ -93,6 +96,89 @@ public class RewriteIndexingVectorization extends HopRewriteRule hop.setVisited(); } + private static Hop vectorizeRightLeftIndexingChains(Hop hi) + throws HopsException + { + //check for valid root operator + if( !(hi instanceof LeftIndexingOp + && hi.getInput().get(1) instanceof IndexingOp + && hi.getInput().get(1).getParent().size()==1) ) + return hi; + LeftIndexingOp lix0 = (LeftIndexingOp) hi; + IndexingOp rix0 = (IndexingOp) hi.getInput().get(1); + if( !(lix0.isRowLowerEqualsUpper() || lix0.isColLowerEqualsUpper()) + || lix0.isRowLowerEqualsUpper() != rix0.isRowLowerEqualsUpper() + || lix0.isColLowerEqualsUpper() != rix0.isColLowerEqualsUpper()) + return hi; + boolean row = lix0.isRowLowerEqualsUpper(); + if( !( (row ? HopRewriteUtils.isFullRowIndexing(lix0) : HopRewriteUtils.isFullColumnIndexing(lix0)) + && (row ? HopRewriteUtils.isFullRowIndexing(rix0) : HopRewriteUtils.isFullColumnIndexing(rix0))) ) + return hi; + + //determine consecutive left-right indexing chains for rows/columns + List<LeftIndexingOp> lix = new ArrayList<>(); lix.add(lix0); + List<IndexingOp> rix = new ArrayList<>(); rix.add(rix0); + LeftIndexingOp clix = lix0; + IndexingOp crix = rix0; + while( isConsecutiveLeftRightIndexing(clix, crix, clix.getInput().get(0)) + && clix.getInput().get(0).getParent().size()==1 + && clix.getInput().get(0).getInput().get(1).getParent().size()==1 ) { + clix = (LeftIndexingOp)clix.getInput().get(0); + crix = (IndexingOp)clix.getInput().get(1); + lix.add(clix); rix.add(crix); + } + + //rewrite pattern if at least two consecutive pairs + if( lix.size() >= 2 ) { + IndexingOp rixn = rix.get(rix.size()-1); + Hop rlrix = rixn.getInput().get(1); + Hop rurix = row ? HopRewriteUtils.createBinary(rlrix, new LiteralOp(rix.size()-1), OpOp2.PLUS) : rixn.getInput().get(2); + Hop clrix = rixn.getInput().get(3); + Hop curix = row ? 
rixn.getInput().get(4) : HopRewriteUtils.createBinary(clrix, new LiteralOp(rix.size()-1), OpOp2.PLUS); + IndexingOp rixNew = HopRewriteUtils.createIndexingOp(rixn.getInput().get(0), rlrix, rurix, clrix, curix); + + LeftIndexingOp lixn = lix.get(rix.size()-1); + Hop rllix = lixn.getInput().get(2); + Hop rulix = row ? HopRewriteUtils.createBinary(rllix, new LiteralOp(lix.size()-1), OpOp2.PLUS) : lixn.getInput().get(3); + Hop cllix = lixn.getInput().get(4); + Hop culix = row ? lixn.getInput().get(5) : HopRewriteUtils.createBinary(cllix, new LiteralOp(lix.size()-1), OpOp2.PLUS); + LeftIndexingOp lixNew = HopRewriteUtils.createLeftIndexingOp(lixn.getInput().get(0), rixNew, rllix, rulix, cllix, culix); + + //rewire parents and childs + HopRewriteUtils.replaceChildReference(hi.getParent().get(0), hi, lixNew); + for( int i=0; i<lix.size(); i++ ) { + HopRewriteUtils.removeAllChildReferences(lix.get(i)); + HopRewriteUtils.removeAllChildReferences(rix.get(i)); + } + + hi = lixNew; + LOG.debug("Applied vectorizeRightLeftIndexingChains (line "+hi.getBeginLine()+")"); + } + + return hi; + } + + + private static boolean isConsecutiveLeftRightIndexing(LeftIndexingOp lix, IndexingOp rix, Hop input) { + if( !(input instanceof LeftIndexingOp + && input.getInput().get(1) instanceof IndexingOp) ) + return false; + boolean row = lix.isRowLowerEqualsUpper(); + LeftIndexingOp lix2 = (LeftIndexingOp) input; + IndexingOp rix2 = (IndexingOp) input.getInput().get(1); + //check row/column access with full row/column indexing + boolean access = (row ? HopRewriteUtils.isFullRowIndexing(lix2) && HopRewriteUtils.isFullRowIndexing(rix2) : + HopRewriteUtils.isFullColumnIndexing(lix2) && HopRewriteUtils.isFullColumnIndexing(rix2)); + //check equivalent right indexing inputs + boolean rixInputs = (rix.getInput().get(0) == rix2.getInput().get(0)); + //check consecutive access + boolean consecutive = (row ? 
HopRewriteUtils.isConsecutiveIndex(lix2.getInput().get(2), lix.getInput().get(2)) + && HopRewriteUtils.isConsecutiveIndex(rix2.getInput().get(1), rix.getInput().get(1)) : + HopRewriteUtils.isConsecutiveIndex(lix2.getInput().get(4), lix.getInput().get(4)) + && HopRewriteUtils.isConsecutiveIndex(rix2.getInput().get(3), rix.getInput().get(3))); + return access && rixInputs && consecutive; + } + /** * Note: unnecessary row or column indexing then later removed via * dynamic rewrites http://git-wip-us.apache.org/repos/asf/systemml/blob/a9c14b02/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteIndexingVectorizationTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteIndexingVectorizationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteIndexingVectorizationTest.java new file mode 100644 index 0000000..9331ded --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteIndexingVectorizationTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.integration.functions.misc; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; +import org.apache.sysml.utils.Statistics; + +public class RewriteIndexingVectorizationTest extends AutomatedTestBase +{ + private static final String TEST_NAME1 = "RewriteIndexingVectorizationRow"; + private static final String TEST_NAME2 = "RewriteIndexingVectorizationCol"; + + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteIndexingVectorizationTest.class.getSimpleName() + "/"; + + private static final int dim1 = 711; + private static final int dim2 = 7; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); + } + + @Test + public void testIndexingVectorizationRowNoRewrite() { + testRewriteIndexingVectorization(TEST_NAME1, false); + } + + @Test + public void testIndexingVectorizationColNoRewrite() { + testRewriteIndexingVectorization(TEST_NAME2, false); + } + + @Test + public void testIndexingVectorizationRow() { + testRewriteIndexingVectorization(TEST_NAME1, true); + } + + @Test + public void testIndexingVectorizationCol() { + testRewriteIndexingVectorization(TEST_NAME2, true); + } + + + private void testRewriteIndexingVectorization(String testname, boolean vectorize) + { + boolean oldFlag = OptimizerUtils.ALLOW_AUTO_VECTORIZATION; + + try + { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + int rows = 
testname.equals(TEST_NAME1) ? dim2 : dim1; + int cols = testname.equals(TEST_NAME1) ? dim1 : dim2; + + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{ "-stats","-args", String.valueOf(rows), + String.valueOf(cols), output("R") }; + + OptimizerUtils.ALLOW_AUTO_VECTORIZATION = vectorize; + + runTest(true, false, null, -1); + + //compare output + double ret = readDMLMatrixFromHDFS("R").get(new CellIndex(1,1)); + Assert.assertTrue(ret == (711*5)); + + //check for applied rewrite + int expected = vectorize ? 1 : 5; + Assert.assertTrue(Statistics.getCPHeavyHitterCount("rightIndex")==expected+1); + Assert.assertTrue(Statistics.getCPHeavyHitterCount("leftIndex")==expected); + } + finally { + OptimizerUtils.ALLOW_AUTO_VECTORIZATION = oldFlag; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/a9c14b02/src/test/scripts/functions/misc/RewriteIndexingVectorizationCol.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteIndexingVectorizationCol.dml b/src/test/scripts/functions/misc/RewriteIndexingVectorizationCol.dml new file mode 100644 index 0000000..794da42 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteIndexingVectorizationCol.dml @@ -0,0 +1,41 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = matrix(1, $1, $2); +B = matrix(0, $1, 6); + +while(FALSE){} + +pos = 2; +B[,pos] = A[,1]; +pos = pos + 1; +B[,pos] = A[,2]; +pos = pos + 1; +B[,pos] = A[,3]; +pos = pos + 1; +B[,pos] = A[,4]; +pos = pos + 1; +B[,pos] = A[,5]; + +while(FALSE){} + +R = as.matrix(sum(B[,2:6])); +write(R, $3) http://git-wip-us.apache.org/repos/asf/systemml/blob/a9c14b02/src/test/scripts/functions/misc/RewriteIndexingVectorizationRow.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteIndexingVectorizationRow.dml b/src/test/scripts/functions/misc/RewriteIndexingVectorizationRow.dml new file mode 100644 index 0000000..156f593 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteIndexingVectorizationRow.dml @@ -0,0 +1,41 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = matrix(1, $1, $2); +B = matrix(0, 6, $2); + +while(FALSE){} + +pos = 2; +B[pos,] = A[1,]; +pos = pos + 1; +B[pos,] = A[2,]; +pos = pos + 1; +B[pos,] = A[3,]; +pos = pos + 1; +B[pos,] = A[4,]; +pos = pos + 1; +B[pos,] = A[5,]; + +while(FALSE){} + +R = as.matrix(sum(B[2:6,])); +write(R, $3) http://git-wip-us.apache.org/repos/asf/systemml/blob/a9c14b02/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java index ae4f820..8805500 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java @@ -59,6 +59,7 @@ import org.junit.runners.Suite; RewriteFoldRCBindTest.class, RewriteFuseBinaryOpChainTest.class, RewriteFusedRandTest.class, + RewriteIndexingVectorizationTest.class, RewriteLoopVectorization.class, RewriteMatrixMultChainOptTest.class, RewriteMergeBlocksTest.class,
