Repository: systemml Updated Branches: refs/heads/master 1b3dff06b -> 85e3a9631
New rewrite rule for chains of element-wise multiply. Placed rewrite rule after Common Subexpression Elimination. Included helper method in HopRewriteUtils. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7d578838 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7d578838 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7d578838 Branch: refs/heads/master Commit: 7d578838cc291a1adb6229bae01f7c9428b6f858 Parents: c434208 Author: Dylan Hutchison <[email protected]> Authored: Thu Jun 8 18:17:36 2017 -0700 Committer: Dylan Hutchison <[email protected]> Committed: Sun Jun 18 17:43:13 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/rewrite/HopRewriteUtils.java | 17 +- .../sysml/hops/rewrite/ProgramRewriter.java | 1 + .../apache/sysml/hops/rewrite/RewriteEMult.java | 186 +++++++++++++++++++ .../org/apache/sysml/parser/Expression.java | 1 + 4 files changed, 204 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index cf6081b..4d23cb9 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -241,7 +241,22 @@ public class HopRewriteUtils parent.getInput().add( pos, child ); child.getParent().add( parent ); } - + + /** + * Rewire every parent of {@code old} to consume {@code replacement} instead, and return the replacement. + * If {@code old} has no parents (it is a root), nothing is rewired; callers use the returned Hop in place of {@code old}. + * @param old the Hop being replaced
+ * @param replacement the Hop that takes the place of {@code old} + * @return {@code replacement} + */ + public static Hop replaceHop(final Hop old, final Hop replacement) { + final boolean hasParents = !old.getParent().isEmpty(); + if (hasParents) // a parentless Hop is a root: there are no parent references to rewire + HopRewriteUtils.rewireAllParentChildReferences(old, replacement); + return replacement; + } + + public static void rewireAllParentChildReferences( Hop hold, Hop hnew ) { ArrayList<Hop> parents = new ArrayList<Hop>(hold.getParent()); for( Hop lparent : parents ) http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index 0e65f3f..8573dd7 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -96,6 +96,7 @@ public class ProgramRewriter _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); + _dagRuleSet.add( new RewriteEMult() ); //dependency: cse if( OptimizerUtils.ALLOW_CONSTANT_FOLDING ) _dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java new file mode 100644 index 0000000..47c32a9 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or
more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.rewrite; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.SortedMap; +import java.util.TreeMap; + +import org.apache.sysml.hops.BinaryOp; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.LiteralOp; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Multiset; + +/** + * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule. + * + * Rewrite a chain of element-wise multiply hops that contain identical elements. + * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power. + * + * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain, + * since foreign parents may rely on the individual results. + * + * The rewritten chain multiplies its distinct leaves in a deterministic order: by data type, then by Hop ID.
+ */ +public class RewriteEMult extends HopRewriteRule { + @Override + public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException { + if( roots == null ) + return null; + + for( int i=0; i<roots.size(); i++ ) { + Hop h = roots.get(i); + roots.set(i, rule_RewriteEMult(h)); + } + return roots; + } + + @Override + public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException { + if( root == null ) + return null; + return rule_RewriteEMult(root); + } + + private static boolean isBinaryMult(final Hop hop) { + return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT; + } + + /** + * If root heads an EMult chain that repeats a leaf and has no foreign parents, replace the chain; + * otherwise recurse into root's inputs. Returns the Hop that now stands in place of root. + */ + private static Hop rule_RewriteEMult(final Hop root) { + if (root.isVisited()) + return root; + root.setVisited(); + + final ArrayList<Hop> rootInputs = root.getInput(); + + // 1. Find immediate subtree of EMults. + if (isBinaryMult(root)) { + final Hop left = rootInputs.get(0), right = rootInputs.get(1); + final BinaryOp r = (BinaryOp)root; + final Set<BinaryOp> emults = new HashSet<>(); + final Multiset<Hop> leaves = HashMultiset.create(); + findEMultsAndLeaves(r, emults, leaves); + // 2. Ensure it is profitable to do a rewrite. + if (isOptimizable(leaves)) { + // 3. Check for foreign parents. + // A foreign parent is a parent of some EMult that is not in the set. + // Foreign parents destroy correctness of this rewrite. + final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) && + (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right)); + if (okay) { + // 4. Recurse into the leaves so that chains nested below them are optimized as well. + // A leaf is never an EMult, so this call always returns the leaf itself (it is not replaced). + for (final Hop leaf : leaves.elementSet()) + rule_RewriteEMult(leaf); + + // 5. Construct replacement EMults for the leaves + final Hop replacement = constructReplacement(emults, leaves); + + // 6. Replace root with replacement + return HopRewriteUtils.replaceHop(root, replacement); + } + } + } + + // This rewrite is not applicable to the current root. + // Try the root's children.
+ for (int i = 0; i < rootInputs.size(); i++) { + final Hop input = rootInputs.get(i); + final Hop newInput = rule_RewriteEMult(input); + rootInputs.set(i, newInput); + } + return root; + } + + /** + * Build a left-deep EMult tree over the distinct leaves, each raised to its multiplicity, ordered by compareByDataType. + * Only the discarded EMults are unlinked from each leaf's parent list; any other parent of a leaf keeps its link. + */ + private static Hop constructReplacement(final Set<BinaryOp> emults, final Multiset<Hop> leaves) { + // Sort by data type, then Hop ID, so that distinct leaves of the same data type are all kept + final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType); + for (final Multiset.Entry<Hop> entry : leaves.entrySet()) { + final Hop h = entry.getElement(); + // unlink only the EMults we are throwing away; the leaf may also be used outside this chain + h.getParent().removeAll(emults); + sorted.put(h, entry.getCount()); + } + // sorted contains all distinct leaves in compareByDataType order, detached from the discarded EMults + + // Construct left-deep EMult tree + Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator(); + Hop first = constructPower(iterator.next()); + + for (int i = 1; i < sorted.size(); i++) { + final Hop second = constructPower(iterator.next()); + first = HopRewriteUtils.createBinary(first, second, Hop.OpOp2.MULT); + } + return first; + } + + /** Return the entry's Hop if its count is 1, otherwise a new element-wise POW of that Hop by its count. */ + private static Hop constructPower(Map.Entry<Hop, Integer> entry) { + final Hop hop = entry.getKey(); + final int cnt = entry.getValue(); + assert(cnt >= 1); + if (cnt == 1) + return hop; + return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW); + } + + /** + * Order Hops by data type (declaration order of the DataType enum), then by Hop ID. + * The Hop ID tie-break is required: a TreeMap treats comparator ties as the same key, so comparing + * by data type alone would merge distinct leaves of the same type (e.g. A and B in A*B*B). + */ + private static final Comparator<Hop> compareByDataType = (a, b) -> { + final int byType = a.getDataType().compareTo(b.getDataType()); + return byType != 0 ? byType : Long.compare(a.getHopID(), b.getHopID()); + }; + + /** + * Return false if child, or any EMult below it in the chain, has more than one parent and one of them is not in emults. + */ + private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) { + final ArrayList<Hop> parents = child.getParent(); + if (parents.size() > 1) + for (final Hop parent : parents) + //noinspection SuspiciousMethodCalls + if (!emults.contains(parent)) + return false; + // child does not have foreign parents + + final ArrayList<Hop> inputs = child.getInput(); + final Hop left = inputs.get(0), right = inputs.get(1); + return (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) && + (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right)); + } +
+ /** + * Add root and every BinaryOp MULT reachable from it through MULT inputs to {@code emults}, + * and add every non-MULT input of those MULTs (a leaf) to {@code leaves}, once per occurrence. + */ + private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) { + // Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality. + emults.add(root); + + final ArrayList<Hop> inputs = root.getInput(); + final Hop left = inputs.get(0), right = inputs.get(1); + + if (isBinaryMult(left)) findEMultsAndLeaves((BinaryOp) left, emults, leaves); + else leaves.add(left); + + if (isBinaryMult(right)) findEMultsAndLeaves((BinaryOp) right, emults, leaves); + else leaves.add(right); + } + + /** Only optimize a subtree of EMults if at least one leaf occurs more than once. */ + private static boolean isOptimizable(final Multiset<Hop> set) { + for (Multiset.Entry<Hop> hopEntry : set.entrySet()) { + if (hopEntry.getCount() > 1) + return true; + } + return false; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/parser/Expression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java index 9ee3fba..b944e29 100644 --- a/src/main/java/org/apache/sysml/parser/Expression.java +++ b/src/main/java/org/apache/sysml/parser/Expression.java @@ -162,6 +162,7 @@ public abstract class Expression * Data types (matrix, scalar, frame, object, unknown). */ public enum DataType { + // Careful: the order of these enums is significant! See RewriteEMult.compareByDataType MATRIX, SCALAR, FRAME, OBJECT, UNKNOWN; public boolean isMatrix() {
