This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new b7480917b5 [SYSTEMDS-3333] Add a DP optimization for matrix chains
with transposes
b7480917b5 is described below
commit b7480917b5178b1f566f1c5aa68cfddaeb5e4f80
Author: Elman Jahangiri <[email protected]>
AuthorDate: Mon May 4 15:45:15 2026 +0200
[SYSTEMDS-3333] Add a DP optimization for matrix chains with transposes
This adds a new HOP rewrite rule,
RewriteMatrixMultChainWithTransOptimization.java, that finds the optimal
execution plan for matrix multiplication chains containing transposes.
Previously, these chains were optimized with a simple heuristic that
always pushes transposes down via t(A %*% B) -> t(B) %*% t(A), which
yields a suboptimal plan in some cases, especially for large
matrices.
An example is R = t(A %*% B) %*% C with dimensions A = [16, 23],
B = [23, 22], C = [16, 34]. The old rewrite computes it as
(t(B) %*% t(A)) %*% C, with costs
t(B): 23*22 + t(A): 16*23 + t(B) %*% t(A): 22*23*16 + [...] %*% C: 22*16*34
= 20938 FLOPs.
The optimal plan is simply t(A %*% B) %*% C, with costs
A %*% B: 16*23*22 + t(A %*% B): 16*22 + [...] %*% C: 22*16*34
= 20416 FLOPs.
The difference grows with larger matrix dimensions.
To solve this, we apply a DP algorithm. For every subchain, its memo
table holds two plans:
- the cheapest plan that produces the subchain itself, and
- the cheapest plan that produces its transpose.
This lets the DP decide whether an algebraic transpose pushdown or an
explicit transpose operation is cheaper.
This also includes 24 automated DML test cases that assert intermediate
HOP dimensions to validate optimal parenthesization and transpose
placement.
Closes #2465.
---
pom.xml | 1 -
.../java/org/apache/sysds/hops/OptimizerUtils.java | 5 +
.../apache/sysds/hops/rewrite/ProgramRewriter.java | 3 +
...ewriteMatrixMultChainWithTransOptimization.java | 457 +++++++++++++++++++++
.../rewrite/RewriteMatrixChainDPTest.java | 325 +++++++++++++++
.../scripts/functions/rewrite/mmchain/test1.dml | 41 ++
.../scripts/functions/rewrite/mmchain/test10.dml | 33 ++
.../scripts/functions/rewrite/mmchain/test11.dml | 33 ++
.../scripts/functions/rewrite/mmchain/test12.dml | 36 ++
.../scripts/functions/rewrite/mmchain/test13.dml | 35 ++
.../scripts/functions/rewrite/mmchain/test14.dml | 37 ++
.../scripts/functions/rewrite/mmchain/test15.dml | 36 ++
.../scripts/functions/rewrite/mmchain/test16.dml | 36 ++
.../scripts/functions/rewrite/mmchain/test17.dml | 35 ++
.../scripts/functions/rewrite/mmchain/test18.dml | 33 ++
.../scripts/functions/rewrite/mmchain/test19.dml | 33 ++
.../scripts/functions/rewrite/mmchain/test2.dml | 42 ++
.../scripts/functions/rewrite/mmchain/test20.dml | 32 ++
.../scripts/functions/rewrite/mmchain/test21.dml | 33 ++
.../scripts/functions/rewrite/mmchain/test22.dml | 39 ++
.../scripts/functions/rewrite/mmchain/test23.dml | 36 ++
.../scripts/functions/rewrite/mmchain/test24.dml | 32 ++
.../scripts/functions/rewrite/mmchain/test3.dml | 42 ++
.../scripts/functions/rewrite/mmchain/test4.dml | 34 ++
.../scripts/functions/rewrite/mmchain/test5.dml | 38 ++
.../scripts/functions/rewrite/mmchain/test6.dml | 33 ++
.../scripts/functions/rewrite/mmchain/test7.dml | 36 ++
.../scripts/functions/rewrite/mmchain/test8.dml | 33 ++
.../scripts/functions/rewrite/mmchain/test9.dml | 36 ++
29 files changed, 1644 insertions(+), 1 deletion(-)
diff --git a/pom.xml b/pom.xml
index 732fb76082..a70d89501c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1577,6 +1577,5 @@
<artifactId>fastdoubleparser</artifactId>
<version>0.9.0</version>
</dependency>
-
</dependencies>
</project>
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index a02a550350..04850cf863 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -204,6 +204,11 @@ public class OptimizerUtils
* ALLOW_SUM_PRODUCT_REWRITES.
*/
public static boolean ALLOW_ADVANCED_MMCHAIN_REWRITES = false;
+
+ /**
+ * Enables RewriteMatrixMultChainWithTransOptimization, a DP-based (DPSize-inspired)
+ * matrix multiplication chain rewrite that jointly optimizes the parenthesization
+ * and the placement of transposes within the chain. Disabled by default.
+ */
+ public static boolean ALLOW_NEW_MMCHAIN_REWRITE = false;
/**
* Enables a specific hop dag rewrite that splits hop dags after csv
persistent reads with
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 49c85191d2..efc3de5a65 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -136,6 +136,9 @@ public class ProgramRewriter{
if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 )
_dagRuleSet.add( new
RewriteElementwiseMultChainOptimization()); //dependency: cse
}
+ if( OptimizerUtils.ALLOW_NEW_MMCHAIN_REWRITE ) {
+ _dagRuleSet.add( new
RewriteMatrixMultChainWithTransOptimization() );
+ }
if(OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES){
_dagRuleSet.add( new
RewriteMatrixMultChainOptimizationTranspose() ); //dependency: cse
_dagRuleSet.add( new
RewriteMatrixMultChainOptimizationSparse() ); //dependency: cse
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainWithTransOptimization.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainWithTransOptimization.java
new file mode 100644
index 0000000000..d866fe9343
--- /dev/null
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainWithTransOptimization.java
@@ -0,0 +1,457 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.HopsException;
+
+/**
+ * Advanced matrix multiplication chain optimizer using dynamic programming.
+ * <p>
+ * This rewrite optimizes matrix multiplication chains by jointly exploring the
+ * parenthesization and the transpose identity t(A %*% B) = t(B) %*% t(A).
+ * For every subchain, the DP tracks two plans, both minimal in estimated FLOPs:
+ * the cheapest plan producing the subchain result, and the cheapest plan producing
+ * its transpose. Physical transposes are only inserted where they reduce the total cost.
+ * <p>
+ * In contrast to RewriteMatrixMultChainOptimization, which relinks the existing
+ * operators, this rule builds a completely new HOP sub-DAG for the chain and
+ * swaps it in for the original chain root.
+ */
+public class RewriteMatrixMultChainWithTransOptimization extends HopRewriteRule {
+
+	@Override
+	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state)
+	{
+		if( roots == null )
+			return null;
+
+		// Find the optimal order for the chains below each root; a root that is itself
+		// a chain root is replaced by the newly built chain
+		for( int i = 0; i < roots.size(); i++ )
+			roots.set(i, rule_OptimizeMMChains(roots.get(i), state));
+
+		return roots;
+	}
+
+	@Override
+	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state)
+	{
+		if( root == null )
+			return null;
+
+		// Find the optimal order for the chains below the root (the root itself may be replaced)
+		return rule_OptimizeMMChains(root, state);
+	}
+
+	/**
+	 * Recursively traverses the HOP DAG to identify and optimize matrix multiplication chains.
+	 * A chain starts at a matrix multiply (%*%) or at a transpose t() whose single input is a
+	 * matrix multiply. Multiplies with a left permutation-matrix input are skipped.
+	 * After a chain has been handled, the traversal continues at the chain's operands, so that
+	 * chains nested below them (e.g., inside element-wise operations) are optimized as well.
+	 *
+	 * @param hop   the current high-level operator
+	 * @param state the rewrite status
+	 * @return the hop that now takes the place of {@code hop} in the DAG: either
+	 *         {@code hop} itself or the root of a newly built chain
+	 */
+	private Hop rule_OptimizeMMChains(Hop hop, ProgramRewriteStatus state) {
+		if( hop.isVisited() )
+			return hop;
+
+		boolean isMatrixMult = isChainableMatrixMult(hop);
+		boolean isTranspose = HopRewriteUtils.isTransposeOperation(hop)
+			&& isChainableMatrixMult(hop.getInput().get(0));
+
+		Hop current = hop;
+		List<Hop> next = hop.getInput();
+		if( isMatrixMult || isTranspose ) {
+			// The chain operands form the frontier below the (possibly rebuilt) chain.
+			// The chain-internal hops are marked visited, so descend into the operands directly.
+			List<Hop> operands = new ArrayList<>();
+			current = prepAndOptimizeMMChain(hop, state, operands);
+			next = operands;
+		}
+		current.setVisited();
+
+		// copy to an array: the recursion may replace or relink hops in these lists
+		for( Hop in : next.toArray(new Hop[0]) )
+			rule_OptimizeMMChains(in, state);
+
+		return current;
+	}
+
+	/**
+	 * Checks whether a hop is a matrix multiply that may participate in a chain,
+	 * i.e., a matrix multiply without a left permutation-matrix input.
+	 */
+	private static boolean isChainableMatrixMult(Hop hop) {
+		return HopRewriteUtils.isMatrixMultiply(hop) && !((AggBinaryOp) hop).hasLeftPMInput();
+	}
+
+	/**
+	 * Flattens the chain rooted at {@code hop} into a list of operands plus a parallel list
+	 * of transpose flags, and optimizes the chain if it has more than two operands.
+	 * Operand i contributes t(operand) if its flag is set, and the product of all flagged
+	 * operands equals the value of {@code hop}. For a root t(A %*% B) the chain is t(B), t(A).
+	 * <p>
+	 * An intermediate multiply (optionally wrapped in a transpose) is absorbed into the chain
+	 * only if all of its consumers are part of the chain. Intermediates shared with other
+	 * operators are kept as operands and are not recomputed.
+	 *
+	 * @param hop      chain root: a matrix multiply, or a transpose of a matrix multiply
+	 * @param state    the rewrite status
+	 * @param operands output list that receives the final chain operands
+	 * @return the root of the rebuilt chain, or {@code hop} if nothing was rewritten
+	 */
+	private Hop prepAndOptimizeMMChain(Hop hop, ProgramRewriteStatus state, List<Hop> operands) {
+		if( LOG.isTraceEnabled() ) {
+			LOG.trace("MM Chain Optimization for HOP: (" + hop.getClass().getSimpleName()
+				+ ", " + hop.getHopID() + ", " + hop.getName() + ")");
+		}
+
+		List<Hop> mmChain = new ArrayList<>();
+		List<Boolean> isTransposedChain = new ArrayList<>();
+
+		boolean isRootTranspose = HopRewriteUtils.isTransposeOperation(hop);
+		// for a root t(X), the chain is built from the inputs of the multiply X
+		Hop currentRoot = isRootTranspose ? hop.getInput().get(0) : hop;
+
+		// Hops absorbed into the chain, in top-down order. Parent links from these hops are
+		// chain-internal and do not count as external consumers.
+		List<Hop> internal = new ArrayList<>();
+		internal.add(hop);
+		if( isRootTranspose )
+			internal.add(currentRoot);
+
+		// Seed the chain with the root's inputs. By t(A %*% B) = t(B) %*% t(A), a root
+		// transpose reverses the operands and flags them as transposed.
+		mmChain.add(currentRoot);
+		isTransposedChain.add(isRootTranspose);
+		expandAt(0, currentRoot, isRootTranspose, mmChain, isTransposedChain);
+
+		int i = 0;
+		while( i < mmChain.size() ) {
+			Hop operand = mmChain.get(i);
+			boolean hasTranspose = HopRewriteUtils.isTransposeOperation(operand);
+			Hop mult = hasTranspose ? operand.getInput().get(0) : operand;
+
+			if( !isExpandable(operand, mult, hasTranspose, internal) ) {
+				i++; // keep the operand as a leaf of the chain
+				continue;
+			}
+
+			// absorb the (possibly transposed) multiply; position i is re-examined afterwards
+			boolean transposed = isTransposedChain.get(i);
+			if( hasTranspose ) {
+				operand.setVisited();
+				internal.add(operand);
+				transposed = !transposed;
+			}
+			mult.setVisited();
+			internal.add(mult);
+			expandAt(i, mult, transposed, mmChain, isTransposedChain);
+		}
+		operands.addAll(mmChain);
+
+		// only chains of more than two operands have alternative orderings
+		if( mmChain.size() <= 2 )
+			return hop;
+
+		Hop newRoot = optimizeMMChain(hop, mmChain, isTransposedChain, isRootTranspose);
+		if( newRoot != hop ) {
+			// Unlink absorbed hops that lost all consumers, so the operands keep no stale
+			// parent references. Top-down order releases nested hops in turn.
+			for( Hop h : internal )
+				if( h.getParent().isEmpty() )
+					HopRewriteUtils.removeAllChildReferences(h);
+		}
+		return newRoot;
+	}
+
+	/**
+	 * Replaces the chain entry at {@code pos} with the two inputs of {@code mult}.
+	 * The inputs are inserted in reverse order if {@code transposed}, since
+	 * t(L %*% R) = t(R) %*% t(L). Both new entries inherit the given flag.
+	 */
+	private static void expandAt(int pos, Hop mult, boolean transposed, List<Hop> chain, List<Boolean> flags) {
+		List<Hop> in = mult.getInput();
+		chain.set(pos, in.get(transposed ? 1 : 0));
+		chain.add(pos + 1, in.get(transposed ? 0 : 1));
+		flags.set(pos, transposed);
+		flags.add(pos + 1, transposed);
+	}
+
+	/**
+	 * Checks whether a chain operand can be absorbed into the chain. The operand must be a
+	 * chainable matrix multiply, optionally wrapped in a transpose, and the multiply and the
+	 * transpose must have no consumers other than hops already absorbed into the chain.
+	 */
+	private static boolean isExpandable(Hop operand, Hop mult, boolean hasTranspose, List<Hop> internal) {
+		if( !isChainableMatrixMult(mult) )
+			return false;
+		// consumers of the multiply, other than its own wrapping transpose
+		boolean externalMult = mult.getParent().stream()
+			.anyMatch(p -> p != operand && !containsRef(internal, p));
+		// consumers of the wrapping transpose itself
+		boolean externalTrans = hasTranspose && operand.getParent().stream()
+			.anyMatch(p -> !containsRef(internal, p));
+		return !externalMult && !externalTrans;
+	}
+
+	/** Identity-based membership test. */
+	private static boolean containsRef(List<Hop> hops, Hop hop) {
+		return hops.stream().anyMatch(h -> h == hop);
+	}
+
+	/**
+	 * Optimizes a flattened chain. If successful, it replaces {@code hop} in all of its
+	 * parents with the root of a newly built chain and removes {@code hop}'s child references.
+	 *
+	 * @param hop               the original chain root
+	 * @param mmChain           chain operands in multiplication order
+	 * @param isTransposedChain per-operand flag: the chain uses t(operand) instead of operand
+	 * @param isRootTranspose   whether {@code hop} is a transpose. This is already encoded in
+	 *                          the operand order and flags; it is retained for compatibility.
+	 * @return the new chain root, or {@code hop} if some dimensions are unknown
+	 */
+	protected Hop optimizeMMChain(Hop hop, List<Hop> mmChain, List<Boolean> isTransposedChain, boolean isRootTranspose) {
+		// Step 2: construct the dimensions array
+		double[] dimsArray = new double[mmChain.size() + 1];
+		if( !getDimsArray(hop, mmChain, isTransposedChain, dimsArray) )
+			return hop; // unknown dimensions: keep the original chain
+
+		// Step 3: find the optimal plan via dynamic programming
+		MemoTable memo = mmChainDP(dimsArray, isTransposedChain);
+
+		if( LOG.isTraceEnabled() )
+			LOG.trace("Optimal MM Chain: ");
+		// Step 4: build the new tree. The product of the flagged operands already equals the
+		// value of hop (a root transpose is encoded by reversed operands and flags). Hence we
+		// build the plan for the chain itself, never the plan for its transpose.
+		Hop newRoot = mmChainBuildTree(0, mmChain.size() - 1, mmChain, memo, false, hop);
+
+		// swap the parents' pointers to the new tree
+		if( newRoot != hop ) {
+			for( Hop parent : new ArrayList<>(hop.getParent()) )
+				HopRewriteUtils.replaceChildReference(parent, hop, newRoot);
+			HopRewriteUtils.removeAllChildReferences(hop);
+		}
+		return newRoot;
+	}
+
+	/**
+	 * Runs the dynamic program over the chain dimensions and transpose flags.
+	 * Operand i has dimensions dimArray[i] x dimArray[i+1], after applying its flag.
+	 * For every subchain [i, j], the memo table holds the cheapest plan producing the
+	 * subchain ("normal") and the cheapest plan producing its transpose ("transposed").
+	 */
+	private static MemoTable mmChainDP(double[] dimArray, List<Boolean> isTransposeChain) {
+		int size = isTransposeChain.size();
+		MemoTable memo = new MemoTable(size);
+
+		// 1) Single operands: an unflagged leaf is free as is, and costs rows*cols to transpose.
+		// A flagged leaf (chain uses t(leaf)) is free in its transposed orientation.
+		for( int i = 0; i < size; i++ ) {
+			double transposeCost = dimArray[i] * dimArray[i + 1];
+			boolean flagged = isTransposeChain.get(i);
+			memo.setNormal(i, i, new Plan(flagged ? transposeCost : 0, -1, flagged));
+			memo.setTransposed(i, i, new Plan(flagged ? 0 : transposeCost, -1, !flagged));
+		}
+
+		// 2) Subchains of increasing length
+		for( int len = 2; len <= size; len++ ) {
+			for( int i = 0; i + len <= size; i++ ) {
+				int j = i + len - 1;
+				double outRows = dimArray[i];
+				double outCols = dimArray[j + 1];
+				Plan normal = new Plan();
+				Plan transposed = new Plan();
+
+				// split points: [i,k] x [k+1,j] (normal) or t([k+1,j]) x t([i,k]) (transposed)
+				for( int k = i; k < j; k++ ) {
+					// both orientations perform the same multiply: outRows x dim[k+1] x outCols
+					double mmCost = outRows * dimArray[k + 1] * outCols;
+					double costNormal = memo.getNormal(i, k).cost + memo.getNormal(k + 1, j).cost + mmCost;
+					if( costNormal < normal.cost )
+						normal.set(costNormal, k, false);
+					double costTrans = memo.getTransposed(i, k).cost + memo.getTransposed(k + 1, j).cost + mmCost;
+					if( costTrans < transposed.cost )
+						transposed.set(costTrans, k, false);
+				}
+
+				// Alternatively, compute the opposite orientation and transpose its result.
+				// Both checks cannot succeed together (transposeCost > 0), so at most one of
+				// the two plans has withTranspose set, which bounds the tree construction.
+				double transposeCost = outRows * outCols;
+				if( normal.cost + transposeCost < transposed.cost )
+					transposed.set(normal.cost + transposeCost, normal.splitIndex, true);
+				if( transposed.cost + transposeCost < normal.cost )
+					normal.set(transposed.cost + transposeCost, transposed.splitIndex, true);
+
+				memo.setNormal(i, j, normal);
+				memo.setTransposed(i, j, transposed);
+			}
+		}
+		return memo;
+	}
+
+	/**
+	 * Recursively builds the HOP tree for subchain [i, j] from the memo table.
+	 *
+	 * @param isTransposed whether to build t(subchain) instead of the subchain itself
+	 * @param rootHop      original chain root, source of exec type and block size
+	 */
+	private Hop mmChainBuildTree(int i, int j, List<Hop> mmChain, MemoTable memo, boolean isTransposed, Hop rootHop) {
+		Plan plan = isTransposed ? memo.getTransposed(i, j) : memo.getNormal(i, j);
+
+		// single operand: apply a physical transpose if the plan requires one
+		if( i == j ) {
+			Hop leaf = mmChain.get(i);
+			return plan.withTranspose ? initNewHop(HopRewriteUtils.createTranspose(leaf), rootHop) : leaf;
+		}
+		// cheaper to compute the opposite orientation and transpose its result
+		if( plan.withTranspose ) {
+			Hop child = mmChainBuildTree(i, j, mmChain, memo, !isTransposed, rootHop);
+			return initNewHop(HopRewriteUtils.createTranspose(child), rootHop);
+		}
+
+		int k = plan.splitIndex;
+		Hop left, right;
+		if( isTransposed ) {
+			// t(L %*% R) = t(R) %*% t(L)
+			left = mmChainBuildTree(k + 1, j, mmChain, memo, true, rootHop);
+			right = mmChainBuildTree(i, k, mmChain, memo, true, rootHop);
+		}
+		else {
+			left = mmChainBuildTree(i, k, mmChain, memo, false, rootHop);
+			right = mmChainBuildTree(k + 1, j, mmChain, memo, false, rootHop);
+		}
+		return initNewHop(HopRewriteUtils.createMatrixMultiply(left, right), rootHop);
+	}
+
+	/**
+	 * Initializes a newly created chain hop. Copies exec type and block size from the chain
+	 * root, refreshes the size information, and marks the hop visited.
+	 */
+	private static Hop initNewHop(Hop h, Hop rootHop) {
+		h.setExecType(rootHop.getExecType());
+		h.refreshSizeInformation();
+		h.setBlocksize(rootHop.getBlocksize());
+		h.setVisited();
+		return h;
+	}
+
+	/**
+	 * Obtains the dimensions of all chain operands and fills the dimsArray.
+	 * Operand i has dimensions dimsArray[i] x dimsArray[i+1], after applying its transpose flag.
+	 * If some dimension is unknown, it returns false and the rewrite should be ended
+	 * without modifications.
+	 *
+	 * @param hop              chain root, used for error locations
+	 * @param chain            list of chain operands
+	 * @param isTransposeChain parallel list of flags indicating that the chain uses t(operand)
+	 * @param dimsArray        output dimension array of length chain.size() + 1
+	 * @return true if all dimensions are known
+	 * @throws HopsException if adjacent operands have incompatible dimensions
+	 */
+	protected static boolean getDimsArray(Hop hop, List<Hop> chain, List<Boolean> isTransposeChain, double[] dimsArray) {
+		for( int i = 0; i < chain.size(); i++ ) {
+			Hop leaf = chain.get(i);
+			long rows = leaf.getDim1();
+			long cols = leaf.getDim2();
+			if( rows <= 0 || cols <= 0 )
+				return false;
+
+			// the chain uses t(leaf): swap its dimensions
+			if( isTransposeChain.get(i) ) {
+				long tmp = rows;
+				rows = cols;
+				cols = tmp;
+			}
+
+			if( i == 0 )
+				dimsArray[0] = rows;
+			else if( dimsArray[i] != rows )
+				throw new HopsException(hop.printErrorLocation()
+					+ "Hops::optimizeMMChain() : Matrix Dimension Mismatch: " + dimsArray[i] + " != " + rows);
+			dimsArray[i + 1] = cols;
+		}
+		return true;
+	}
+
+	/**
+	 * A blueprint object tracking the cheapest cost and split point for a sub-problem.
+	 */
+	private static final class Plan {
+		/** estimated FLOPs; Double.MAX_VALUE means no plan has been found yet */
+		double cost = Double.MAX_VALUE;
+		/** split point k, i.e. [i,k] x [k+1,j]; -1 for single operands */
+		int splitIndex = -1;
+		/** whether the plan transposes the result of the opposite orientation (or of the leaf) */
+		boolean withTranspose;
+
+		Plan() {}
+
+		Plan(double cost, int splitIndex, boolean withTranspose) {
+			set(cost, splitIndex, withTranspose);
+		}
+
+		void set(double cost, int splitIndex, boolean withTranspose) {
+			this.cost = cost;
+			this.splitIndex = splitIndex;
+			this.withTranspose = withTranspose;
+		}
+	}
+
+	/**
+	 * Memo table of the DP with two plan matrices: one for the subchain result and
+	 * one for its transpose.
+	 */
+	private static class MemoTable {
+		private final Plan[][] normalPlans;
+		private final Plan[][] transposedPlans;
+
+		public MemoTable(int size) {
+			normalPlans = new Plan[size][size];
+			transposedPlans = new Plan[size][size];
+		}
+
+		public Plan getNormal(int i, int j) { return normalPlans[i][j]; }
+		public Plan getTransposed(int i, int j) { return transposedPlans[i][j]; }
+		public void setNormal(int i, int j, Plan plan) { normalPlans[i][j] = plan; }
+		public void setTransposed(int i, int j, Plan plan) { transposedPlans[i][j] = plan; }
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixChainDPTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixChainDPTest.java
new file mode 100644
index 0000000000..66d224af77
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixChainDPTest.java
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+import java.io.ByteArrayOutputStream;
+import java.io.PrintStream;
+
+public class RewriteMatrixChainDPTest extends AutomatedTestBase {
+
+	private static final String TEST_DIR = "functions/rewrite/mmchain/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteMatrixChainDPTest.class.getSimpleName() + "/";
+
+	private static final String[] TEST_CASES = {
+		"test1", "test2", "test3", "test4", "test5", "test6", "test7",
+		"test8", "test9", "test10", "test11", "test12", "test13",
+		"test14", "test15", "test16", "test17", "test18", "test19",
+		"test20", "test21", "test22", "test23", "test24"
+	};
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		for( String testName : TEST_CASES )
+			addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {"R"}));
+	}
+
+	@Test
+	public void testMatrixChainDP_Test1() { runTestMatrixChainDP(TEST_CASES[0]); }
+
+	@Test
+	public void testMatrixChainDP_Test2() { runTestMatrixChainDP(TEST_CASES[1]); }
+
+	@Test
+	public void testMatrixChainDP_Test3() { runTestMatrixChainDP(TEST_CASES[2]); }
+
+	@Test
+	public void testMatrixChainDP_Test4() { runTestMatrixChainDP(TEST_CASES[3]); }
+
+	@Test
+	public void testMatrixChainDP_Test5() { runTestMatrixChainDP(TEST_CASES[4]); }
+
+	@Test
+	public void testMatrixChainDP_Test6() { runTestMatrixChainDP(TEST_CASES[5]); }
+
+	@Test
+	public void testMatrixChainDP_Test7() { runTestMatrixChainDP(TEST_CASES[6]); }
+
+	@Test
+	public void testMatrixChainDP_Test8() { runTestMatrixChainDP(TEST_CASES[7]); }
+
+	@Test
+	public void testMatrixChainDP_Test9() { runTestMatrixChainDP(TEST_CASES[8]); }
+
+	@Test
+	public void testMatrixChainDP_Test10() { runTestMatrixChainDP(TEST_CASES[9]); }
+
+	@Test
+	public void testMatrixChainDP_Test11() { runTestMatrixChainDP(TEST_CASES[10]); }
+
+	@Test
+	public void testMatrixChainDP_Test12() { runTestMatrixChainDP(TEST_CASES[11]); }
+
+	@Test
+	public void testMatrixChainDP_Test13() { runTestMatrixChainDP(TEST_CASES[12]); }
+
+	@Test
+	public void testMatrixChainDP_Test14() { runTestMatrixChainDP(TEST_CASES[13]); }
+
+	@Test
+	public void testMatrixChainDP_Test15() { runTestMatrixChainDP(TEST_CASES[14]); }
+
+	@Test
+	public void testMatrixChainDP_Test16() { runTestMatrixChainDP(TEST_CASES[15]); }
+
+	@Test
+	public void testMatrixChainDP_Test17() { runTestMatrixChainDP(TEST_CASES[16]); }
+
+	@Test
+	public void testMatrixChainDP_Test18() { runTestMatrixChainDP(TEST_CASES[17]); }
+
+	@Test
+	public void testMatrixChainDP_Test19() { runTestMatrixChainDP(TEST_CASES[18]); }
+
+	@Test
+	public void testMatrixChainDP_Test20() { runTestMatrixChainDP(TEST_CASES[19]); }
+
+	@Test
+	public void testMatrixChainDP_Test21() { runTestMatrixChainDP(TEST_CASES[20]); }
+
+	@Test
+	public void testMatrixChainDP_Test22() { runTestMatrixChainDP(TEST_CASES[21]); }
+
+	@Test
+	public void testMatrixChainDP_Test23() { runTestMatrixChainDP(TEST_CASES[22]); }
+
+	@Test
+	public void testMatrixChainDP_Test24() { runTestMatrixChainDP(TEST_CASES[23]); }
+
+	/**
+	 * Runs the given DML script with the DP matrix-chain rewrite enabled. It then checks the
+	 * intermediate dimensions in the HOP explain output against the optimal plan that is
+	 * documented in each script. All modified optimizer flags and the runtime platform are
+	 * restored afterwards.
+	 */
+	private void runTestMatrixChainDP(String testName) {
+		// resolve expectations first: an unmapped test name must fail, not pass vacuously
+		String[][] expected = getExpectedDims(testName);
+
+		ExecMode platformOld = rtplatform;
+		boolean rewritesOld = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		boolean advancedMMChainOld = OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES;
+		boolean newMMChainOld = OptimizerUtils.ALLOW_NEW_MMCHAIN_REWRITE;
+
+		try {
+			rtplatform = ExecMode.SINGLE_NODE;
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = true;
+			OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES = true;
+			OptimizerUtils.ALLOW_NEW_MMCHAIN_REWRITE = true;
+
+			TestConfiguration config = getTestConfiguration(testName);
+			loadTestConfiguration(config);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testName + ".dml";
+			programArgs = new String[] {"-explain", "hops", "-stats", "-args", output("R")};
+
+			// capture stdout to inspect the printed HOP DAG
+			PrintStream originalOut = System.out;
+			ByteArrayOutputStream bos = new ByteArrayOutputStream();
+			System.setOut(new PrintStream(bos));
+			try {
+				// Execute the DML script
+				runTest(true, false, null, -1);
+			}
+			finally {
+				System.setOut(originalOut);
+			}
+
+			String output = bos.toString();
+			System.out.println("Output for " + testName + ":\n" + output);
+
+			// Intermediate dimensions that must (expected[0]) or must not (expected[1])
+			// appear in the HOP plan. These are prefix matches such as "[4,1".
+			for( String dims : expected[0] )
+				Assert.assertTrue(testName + ": expected intermediate " + dims + " in HOP plan",
+					output.contains(dims));
+			for( String dims : expected[1] )
+				Assert.assertFalse(testName + ": unexpected intermediate " + dims + " in HOP plan",
+					output.contains(dims));
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewritesOld;
+			OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES = advancedMMChainOld;
+			OptimizerUtils.ALLOW_NEW_MMCHAIN_REWRITE = newMMChainOld;
+			rtplatform = platformOld;
+			Recompiler.reinitRecompiler();
+		}
+	}
+
+	/**
+	 * Returns the intermediate dimensions of each script's optimal plan, as "[rows,cols"
+	 * prefixes of the HOP explain output. Index [0] holds dimensions that must occur;
+	 * index [1] holds dimensions that must not occur.
+	 *
+	 * @throws IllegalArgumentException if no expectations are defined for the test
+	 */
+	private static String[][] getExpectedDims(String testName) {
+		return switch( testName ) {
+			case "test1" -> new String[][] {{"[4,1"}, {"[2,4"}};
+			case "test2" -> new String[][] {{"[10,5", "[50,5"}, {}};
+			case "test3" -> new String[][] {{"[30,2", "[2,5"}, {}};
+			case "test4" -> new String[][] {{"[4,5", "[4,9", "[4,8"}, {}};
+			case "test5" -> new String[][] {{"[8,3", "[3,4"}, {"[4,8", "[8,8"}};
+			case "test6" -> new String[][] {{"[8,9", "[6,2", "[9,8", "[2,2", "[8,2"}, {}};
+			case "test7" -> new String[][] {{"[1,1000"}, {"[1000,2"}};
+			case "test8" -> new String[][] {{"[30,2", "[2,30", "[2,4"}, {}};
+			case "test9" -> new String[][] {{"[10,2", "[3,2", "[2,10", "[2,30"}, {"[3,10"}};
+			case "test10" -> new String[][] {{"[2,55", "[35,2", "[2,3", "[3,2"}, {}};
+			case "test11" -> new String[][] {{"[3,55", "[3,23", "[35,3", "[23,3"}, {}};
+			case "test12" -> new String[][] {{"[3,43", "[3,12", "[3,23", "[23,3", "[33,3"}, {}};
+			case "test13" -> new String[][] {{"[14,12", "[12,14", "[12,16", "[16,12"}, {"[13,14"}};
+			case "test14" -> new String[][] {{"[9,12", "[12,9", "[13,9", "[16,9", "[9,16", "[9,14"}, {}};
+			case "test15" -> new String[][] {{"[12,12", "[12,13", "[13,16", "[13,14"}, {}};
+			case "test16" -> new String[][] {{"[13,16", "[13,14", "[22,13", "[18,13", "[12,13"}, {"[16,22"}};
+			case "test17" -> new String[][] {{"[16,22", "[22,16", "[443,16", "[124,16", "[124,34"}, {"[23,16"}};
+			case "test18" -> new String[][] {{"[16,22", "[22,16", "[33,22", "[33,16"}, {"[23,16"}};
+			case "test19" -> new String[][] {{"[2,6", "[2,4", "[3,2"}, {}};
+			case "test20" -> new String[][] {{"[10,30", "[10,5", "[50,5"}, {}};
+			case "test21" -> new String[][] {{"[5,3", "[40,3", "[6,3", "[3,6", "[3,50"}, {"[20,6"}};
+			case "test22" -> new String[][] {{"[23,34", "[15,34", "[15,25", "[15,18", "[18,15", "[15,24", "[15,16"}, {}};
+			case "test23" -> new String[][] {{"[10,5", "[20,5", "[10,6", "[1000,6"}, {}};
+			case "test24" -> new String[][] {{"[9,5", "[4,9", "[5,4", "[6,5"}, {}};
+			default -> throw new IllegalArgumentException("No expected HOP dimensions defined for " + testName);
+		};
+	}
+}
diff --git a/src/test/scripts/functions/rewrite/mmchain/test1.dml
b/src/test/scripts/functions/rewrite/mmchain/test1.dml
new file mode 100644
index 0000000000..bee945b7b0
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test1.dml
@@ -0,0 +1,41 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# A = 4x2
+# B = 2x1
+# C = 1x4
+#
+# Cost (A%*%B)%*%C = (4*2*1) + (4*1*4) = 8 + 16 = 24 -> OPTIMAL PLAN
+#
+# Cost A%*%(B%*%C) = (2*1*4) + (4*2*4) = 8 + 32 = 40
+
+
+# initialize matrices with random values
+A = rand(rows=4, cols=2)
+B = rand(rows=2, cols=1)
+C = rand(rows=1, cols=4)
+
+# operation chain
+D = A %*% B %*% C
+
+
+write(D, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test10.dml
b/src/test/scripts/functions/rewrite/mmchain/test10.dml
new file mode 100644
index 0000000000..314864c4e9
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test10.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=3, cols=3)
+B = rand(rows=3, cols=35)
+C = rand(rows=55, cols=2)
+D = rand(rows=2, cols=35)
+E = rand(rows=55, cols=3)
+
+# optimal plan: (A %*% (B %*% t(D))) %*% (t(C) %*% E)
+
+R = A %*% B %*% t(C %*% D) %*% E
+
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test11.dml
b/src/test/scripts/functions/rewrite/mmchain/test11.dml
new file mode 100644
index 0000000000..e5f1c893f7
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test11.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=3, cols=3)
+B = rand(rows=3, cols=35)
+C = rand(rows=55, cols=23)
+D = rand(rows=23, cols=35)
+E = rand(rows=55, cols=3)
+
+# A %*% t( (t(E) %*% C) %*% (D %*% t(B)) ) -> optimal plan
+
+R = A %*% B %*% t(C %*% D) %*% E
+
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test12.dml
b/src/test/scripts/functions/rewrite/mmchain/test12.dml
new file mode 100644
index 0000000000..98a42426ae
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test12.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=61, cols=33)
+B = rand(rows=33, cols=23)
+C = rand(rows=43, cols=12)
+D = rand(rows=12, cols=23)
+E = rand(rows=43, cols=3)
+
+# A %*% (B %*% t( (t(E) %*% C) %*% D )) -> optimal plan
+
+R = A %*% B %*% t(C %*% D) %*% E
+
+write(R, $1)
+
+
+
diff --git a/src/test/scripts/functions/rewrite/mmchain/test13.dml
b/src/test/scripts/functions/rewrite/mmchain/test13.dml
new file mode 100644
index 0000000000..a0abd3feb0
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test13.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=16, cols=15)
+B = rand(rows=15, cols=12)
+C = rand(rows=14, cols=13)
+D = rand(rows=13, cols=12)
+E = rand(rows=14, cols=16)
+
+# (A %*% B) %*% (t(C %*% D) %*% E) -> optimal plan
+
+R = A %*% B %*% t(C %*% D) %*% E
+
+write(R, $1)
+
+
diff --git a/src/test/scripts/functions/rewrite/mmchain/test14.dml
b/src/test/scripts/functions/rewrite/mmchain/test14.dml
new file mode 100644
index 0000000000..fdbde9e45a
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test14.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=9, cols=18)
+B = rand(rows=18, cols=12)
+C = rand(rows=16, cols=13)
+D = rand(rows=13, cols=12)
+E = rand(rows=16, cols=14)
+
+# t( C %*% (D %*% t(A %*% B) ) ) %*% E -> optimal plan
+# 9 * 18 * 12 + 9 * 12 + 13 * 12 * 9 + 16*13*9 + 16*9 + 9 * 16 * 14 = 7488 FLOPs
+
+R = A %*% B %*% t(C %*% D) %*% E
+
+write(R, $1)
+
+
+
diff --git a/src/test/scripts/functions/rewrite/mmchain/test15.dml
b/src/test/scripts/functions/rewrite/mmchain/test15.dml
new file mode 100644
index 0000000000..5c44487cf7
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test15.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=12, cols=18)
+B = rand(rows=18, cols=12)
+C = rand(rows=16, cols=13)
+D = rand(rows=13, cols=12)
+E = rand(rows=16, cols=14)
+
+# ((A %*% B) %*% t(D)) %*% (t(C) %*% E) -> optimal plan
+
+R = A %*% B %*% t(C %*% D) %*% E
+
+write(R, $1)
+
+
+
diff --git a/src/test/scripts/functions/rewrite/mmchain/test16.dml
b/src/test/scripts/functions/rewrite/mmchain/test16.dml
new file mode 100644
index 0000000000..39dc2a18f1
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test16.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=12, cols=18)
+B = rand(rows=18, cols=22)
+C = rand(rows=16, cols=13)
+D = rand(rows=13, cols=22)
+E = rand(rows=16, cols=14)
+
+# (A %*% (B %*% t(D))) %*% (t(C) %*% E) -> optimal plan
+
+R = A %*% B %*% t(C %*% D) %*% E
+
+write(R, $1)
+
+
+
diff --git a/src/test/scripts/functions/rewrite/mmchain/test17.dml
b/src/test/scripts/functions/rewrite/mmchain/test17.dml
new file mode 100644
index 0000000000..d9f6614640
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test17.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=124, cols=443)
+B = rand(rows=443, cols=22)
+C = rand(rows=16, cols=23)
+D = rand(rows=23, cols=22)
+E = rand(rows=16, cols=34)
+
+# (A %*% (B %*% t(C %*% D))) %*% E -> optimal plan
+
+
+R = A %*% B %*% t(C %*% D) %*% E
+
+write(R, $1)
+
diff --git a/src/test/scripts/functions/rewrite/mmchain/test18.dml
b/src/test/scripts/functions/rewrite/mmchain/test18.dml
new file mode 100644
index 0000000000..7fe8671bcf
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test18.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=33, cols=443)
+B = rand(rows=443, cols=22)
+C = rand(rows=16, cols=23)
+D = rand(rows=23, cols=22)
+E = rand(rows=16, cols=34)
+
+# ((A %*% B) %*% t(C %*% D)) %*% E -> optimal plan
+
+R = A %*% B %*% t(C %*% D) %*% E
+
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test19.dml
b/src/test/scripts/functions/rewrite/mmchain/test19.dml
new file mode 100644
index 0000000000..858d9b451c
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test19.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=3, cols=4)
+B = rand(rows=4, cols=2)
+C = rand(rows=6, cols=2)
+D = rand(rows=2, cols=2)
+E = rand(rows=6, cols=4)
+
+# ((A %*% B) %*% t(D)) %*% (t(C) %*% E) -> optimal plan
+
+R = A %*% B %*% t(C %*% D) %*% E
+
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test2.dml
b/src/test/scripts/functions/rewrite/mmchain/test2.dml
new file mode 100644
index 0000000000..b38685e9a9
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test2.dml
@@ -0,0 +1,42 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+
+# A = 50x10
+# B = 10x40
+# C = 40x5
+# D = 5x60
+#
+# 1. ((A%*%B)%*%C)%*%D = (50*10*40) + (50*40*5) + (50*5*60) = 20000 + 10000 + 15000 = 45,000
+# 2. (A%*%(B%*%C))%*%D = (10*40*5) + (50*10*5) + (50*5*60) = 2000 + 2500 + 15000 = 19,500 -> OPTIMAL PLAN
+# 3. (A%*%B)%*%(C%*%D) = (50*10*40) + (40*5*60) + (50*40*60) = 20000 + 12000 + 120000 = 152,000
+# 4. A%*%((B%*%C)%*%D) = (10*40*5) + (10*5*60) + (50*10*60) = 2000 + 3000 + 30000 = 35,000
+# 5. A%*%(B%*%(C%*%D)) = (40*5*60) + (10*40*60) + (50*10*60) = 12000 + 24000 + 30000 = 66,000
+
+A = rand(rows=50, cols=10)
+B = rand(rows=10, cols=40)
+C = rand(rows=40, cols=5)
+D = rand(rows=5, cols=60)
+
+E = A %*% B %*% C %*% D
+
+write(E, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test20.dml
b/src/test/scripts/functions/rewrite/mmchain/test20.dml
new file mode 100644
index 0000000000..fa0b55dd05
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test20.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=50, cols=10)
+B = rand(rows=30, cols=10)
+C = rand(rows=30, cols=5)
+D = rand(rows=20, cols=5)
+
+# (A %*% (t(B) %*% C) ) %*% t(D) -> optimal Plan
+
+R = A %*% t(B) %*% C %*% t(D)
+
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test21.dml
b/src/test/scripts/functions/rewrite/mmchain/test21.dml
new file mode 100644
index 0000000000..59277b1666
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test21.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=40, cols=5)
+B = rand(rows=6, cols=20)
+C = rand(rows=20, cols=3)
+D = rand(rows=3, cols=5)
+E = rand(rows=6, cols=50)
+
+R = A %*% t(B %*% C %*% D) %*% E
+
+# OPTIMAL PLAN: (A %*% t(D)) %*% (t(B %*% C) %*% E)
+
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test22.dml
b/src/test/scripts/functions/rewrite/mmchain/test22.dml
new file mode 100644
index 0000000000..ad1f6fbcdf
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test22.dml
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=34, cols=25)
+B = rand(rows=25, cols=18)
+C = rand(rows=34, cols=23)
+D = rand(rows=24, cols=15)
+E = rand(rows=15, cols=23)
+F = rand(rows=24, cols=16)
+
+
+# t(((E %*% t(C)) %*% A) %*% B) %*% (t(D) %*% F) -> optimal plan
+# 34*23 + 15*23*34 + 15 * 34 * 25 + 15*25*18 + 18*15 + 24*15 + 15*24*16 + 18*15*16 = 42722 FLOPs
+
+res = t(A %*% B) %*% C %*% t(D %*% E) %*% F
+
+write(res, $1)
+
+
+
diff --git a/src/test/scripts/functions/rewrite/mmchain/test23.dml
b/src/test/scripts/functions/rewrite/mmchain/test23.dml
new file mode 100644
index 0000000000..36672fe572
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test23.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=10, cols=5)
+B = rand(rows=5, cols=6)
+C = rand(rows=1000, cols=10)
+D = rand(rows=10, cols=20)
+E = rand(rows=20, cols=10)
+F = rand(rows=10, cols=10)
+
+# t( C %*% ( ( D %*% ( E %*% ( F %*% A ) ) ) %*% B ) ) -> optimal plan
+# - costs: 10*10*5 + 20*10*5 + 10*20*5 + 10*5*6 + 1000*10*6 + 1000*6 = 68800 FLOPs
+
+R = t(A %*% B) %*% t(C %*% D %*% E %*% F)
+
+
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test24.dml
b/src/test/scripts/functions/rewrite/mmchain/test24.dml
new file mode 100644
index 0000000000..4c075e83f4
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test24.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=6, cols=9)
+B = rand(rows=9, cols=5)
+C = rand(rows=4, cols=9)
+D = rand(rows=9, cols=5)
+
+# R = (A %*% B) %*% (t(D) %*% t(C)) -> optimal plan
+
+R = A %*% B %*% t(C %*% D)
+
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test3.dml
b/src/test/scripts/functions/rewrite/mmchain/test3.dml
new file mode 100644
index 0000000000..c2d7d2883d
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test3.dml
@@ -0,0 +1,42 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# A = 30x50
+# B = 50x2
+# C = 2x30
+# D = 30x5
+#
+# 1. ((A%*%B)%*%C)%*%D = (30*50*2) + (30*2*30) + (30*30*5) = 9300
+# 2. (A%*%(B%*%C))%*%D = (50*2*30) + (30*50*30) + (30*30*5) = 52500
+# 3. (A%*%B)%*%(C%*%D) = (30*50*2) + (2*30*5) + (30*2*5) = 3600 -> OPTIMAL PLAN
+# 4. A%*%((B%*%C)%*%D) = (50*2*30) + (50*30*5) + (30*50*5) = 18000
+# 5. A%*%(B%*%(C%*%D)) = (2*30*5) + (50*2*5) + (30*50*5) = 8300
+
+A = rand(rows=30, cols=50)
+B = rand(rows=50, cols=2)
+C = rand(rows=2, cols=30)
+D = rand(rows=30, cols=5)
+
+E = A %*% B %*% C %*% D
+
+
+write(E, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test4.dml
b/src/test/scripts/functions/rewrite/mmchain/test4.dml
new file mode 100644
index 0000000000..6e231a7bc4
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test4.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=5, cols=9)
+B = rand(rows=9, cols=8)
+C = rand(rows=5, cols=4)
+
+D = t(A %*% B) %*% C
+
+# optimal plan:
+# t( (t(C) %*% A) %*% B ) = 5*4 + 4*5*9 + 4*9*8 + 4*8 = 520
+
+
+write(D, $1)
+
diff --git a/src/test/scripts/functions/rewrite/mmchain/test5.dml
b/src/test/scripts/functions/rewrite/mmchain/test5.dml
new file mode 100644
index 0000000000..4596f46d53
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test5.dml
@@ -0,0 +1,38 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# A = 8x3
+# B = 3x8
+# C = 8x4
+#
+# 1. t(A %*% B) %*% C = (8*3*8) + (8*8) + (8*8*4) = 512
+# 2. (t(B) %*% t(A)) %*% C = (3*8) + (8*3) + (8*3*8) + (8*8*4) = 496
+# 3. t(B) %*% (t(A) %*% C) = (8*3) + (3*8*4) + (3*8) + (8*3*4) = 240 -> OPTIMAL PLAN
+
+
+A = rand(rows=8, cols=3)
+B = rand(rows=3, cols=8)
+C = rand(rows=8, cols=4)
+
+D = t(A %*% B) %*% C
+
+write(D, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test6.dml
b/src/test/scripts/functions/rewrite/mmchain/test6.dml
new file mode 100644
index 0000000000..14b1fc3d0a
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test6.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=2, cols=6)
+B = rand(rows=2, cols=6)
+C = rand(rows=2, cols=8)
+D = rand(rows=9, cols=8)
+E = rand(rows=9, cols=2)
+
+# (A %*% t(B)) %*% (C %*% (t(D) %*% E)) -> optimal plan
+
+R = A %*% t(B) %*% C %*% t(D) %*% E
+
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test7.dml
b/src/test/scripts/functions/rewrite/mmchain/test7.dml
new file mode 100644
index 0000000000..83d481ae05
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test7.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# A = 1000x1
+# B = 1x2
+# C = 2x1000
+# NOTE(review): with transpose cost rows*cols, t(B %*% C) %*% t(A) = 1*2*1000 + 1*1000 + 1000*1
+# + 1000*1*1000 = 1004000 FLOPs vs. 2002000 FLOPs for the plan below -- confirm the expected plan.
+# R = t(A %*% (B %*% C)) -> optimal plan
+
+A = rand(rows=1000, cols=1)
+B = rand(rows=1, cols=2)
+C = rand(rows=2, cols=1000)
+
+R = t(A %*% B %*% C)
+
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test8.dml b/src/test/scripts/functions/rewrite/mmchain/test8.dml
new file mode 100644
index 0000000000..e0ff113fd3
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test8.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=30, cols=5)
+B = rand(rows=5, cols=2)
+C = rand(rows=30, cols=6)
+D = rand(rows=6, cols=2)
+E = rand(rows=30, cols=4)
+
+# (A %*% B) %*% (t(C %*% D) %*% E) -> 30*6*2 + 30*2 + 2*30*4 + 30*5*2 + 30*2*4 = 1200 FLOPs
+
+R = A %*% B %*% t(C %*% D) %*% E
+
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/mmchain/test9.dml b/src/test/scripts/functions/rewrite/mmchain/test9.dml
new file mode 100644
index 0000000000..4bfb8b2c43
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/mmchain/test9.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rand(rows=3, cols=3)
+B = rand(rows=3, cols=2)
+C = rand(rows=10, cols=20)
+D = rand(rows=20, cols=2)
+E = rand(rows=10, cols=30)
+
+# ((A %*% B) %*% t(C %*% D)) %*% E -> 3*3*2 + 10*20*2 + 10*2 + 3*2*10 + 3*10*30 = 1398 FLOPs
+# (A %*% B) %*% (t(C %*% D) %*% E) -> 3*3*2 + 10*20*2 + 10*2 + 2*10*30 + 3*2*30 = 1218 FLOPs -> optimal plan
+
+
+R = A %*% B %*% t(C %*% D) %*% E
+
+write(R, $1)
+