This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b2d83e915 [SYSTEMDS-3861] Fix redundant transposes due to multi-level rewrites
4b2d83e915 is described below

commit 4b2d83e915c40b7433580cddfec68aa8c440ba05
Author: aarna <aarnatya...@gmail.com>
AuthorDate: Fri Apr 18 12:43:04 2025 +0200

    [SYSTEMDS-3861] Fix redundant transposes due to multi-level rewrites
    
    Closes #2249.
---
 .../java/org/apache/sysds/hops/AggBinaryOp.java    | 487 ++++++++++-----------
 .../hops/fedplanner/FederatedMemoTablePrinter.java |  19 +
 .../functions/rewrite/RewriteTransposeTest.java    |  86 ++++
 .../functions/rewrite/RewriteTransposeCase1.R      |  32 ++
 .../functions/rewrite/RewriteTransposeCase1.dml    |  27 ++
 .../functions/rewrite/RewriteTransposeCase2.R      |  32 ++
 .../functions/rewrite/RewriteTransposeCase2.dml    |  28 ++
 .../functions/rewrite/RewriteTransposeCase3.R      |  33 ++
 .../functions/rewrite/RewriteTransposeCase3.dml    |  28 ++
 9 files changed, 519 insertions(+), 253 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java 
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 2cf651f189..5f9c6b41b3 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -43,6 +43,7 @@ import org.apache.sysds.lops.MatMultCP;
 import org.apache.sysds.lops.PMMJ;
 import org.apache.sysds.lops.PMapMult;
 import org.apache.sysds.lops.Transform;
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -65,7 +66,7 @@ public class AggBinaryOp extends MultiThreadedHop {
        public static final double MAPMULT_MEM_MULTIPLIER = 1.0;
        public static MMultMethod FORCED_MMULT_METHOD = null;
 
-       public enum MMultMethod { 
+       public enum MMultMethod {
                CPMM,     //cross-product matrix multiplication (mr)
                RMM,      //replication matrix multiplication (mr)
                MAPMM_L,  //map-side matrix-matrix multiplication using 
distributed cache (mr/sp)
@@ -78,27 +79,27 @@ public class AggBinaryOp extends MultiThreadedHop {
                ZIPMM,    //zip matrix multiplication (sp)
                MM        //in-memory matrix multiplication (cp)
        }
-       
-       public enum SparkAggType{
+
+       public enum SparkAggType {
                NONE,
                SINGLE_BLOCK,
                MULTI_BLOCK,
        }
-       
+
        private OpOp2 innerOp;
        private AggOp outerOp;
 
        private MMultMethod _method = null;
-       
+
        //hints set by previous to operator selection
        private boolean _hasLeftPMInput = false; //left input is permutation 
matrix
-       
+
        private AggBinaryOp() {
                //default constructor for clone
        }
-       
+
        public AggBinaryOp(String l, DataType dt, ValueType vt, OpOp2 innOp,
-                       AggOp outOp, Hop in1, Hop in2) {
+                                          AggOp outOp, Hop in1, Hop in2) {
                super(l, dt, vt);
                innerOp = innOp;
                outerOp = outOp;
@@ -106,7 +107,7 @@ public class AggBinaryOp extends MultiThreadedHop {
                getInput().add(1, in2);
                in1.getParent().add(this);
                in2.getParent().add(this);
-               
+
                //compute unknown dims and nnz
                refreshSizeInformation();
        }
@@ -114,30 +115,30 @@ public class AggBinaryOp extends MultiThreadedHop {
        public void setHasLeftPMInput(boolean flag) {
                _hasLeftPMInput = flag;
        }
-       
-       public boolean hasLeftPMInput(){
+
+       public boolean hasLeftPMInput() {
                return _hasLeftPMInput;
        }
 
-       public MMultMethod getMMultMethod(){
+       public MMultMethod getMMultMethod() {
                return _method;
        }
-       
+
        @Override
        public boolean isGPUEnabled() {
-               if(!DMLScript.USE_ACCELERATOR)
+               if (!DMLScript.USE_ACCELERATOR)
                        return false;
-               
+
                Hop input1 = getInput().get(0);
                Hop input2 = getInput().get(1);
                //matrix mult operation selection part 2 (specific pattern)
                MMTSJType mmtsj = checkTransposeSelf(); //determine tsmm pattern
                ChainType chain = checkMapMultChain(); //determine mmchain 
pattern
-               
-               _method = optFindMMultMethodCP ( input1.getDim1(), 
input1.getDim2(),   
-                             input2.getDim1(), input2.getDim2(), mmtsj, chain, 
_hasLeftPMInput );
-               switch( _method ){
-                       case TSMM: 
+
+               _method = optFindMMultMethodCP(input1.getDim1(), 
input1.getDim2(),
+                               input2.getDim1(), input2.getDim2(), mmtsj, 
chain, _hasLeftPMInput);
+               switch (_method) {
+                       case TSMM:
                                //return false; // TODO: Disabling any fused 
transa optimization in 1.0 release.
                                return true;
                        case MAPMM_CHAIN:
@@ -150,50 +151,47 @@ public class AggBinaryOp extends MultiThreadedHop {
                                throw new RuntimeException("Unsupported 
method:" + _method);
                }
        }
-       
+
        /**
         * NOTE: overestimated mem in case of transpose-identity matmult, but 
3/2 at worst
-        *       and existing mem estimate advantageous in terms of consistency 
hops/lops,
-        *       and some special cases internally materialize the transpose 
for better cache locality  
+        * and existing mem estimate advantageous in terms of consistency 
hops/lops,
+        * and some special cases internally materialize the transpose for 
better cache locality
         */
        @Override
-       public Lop constructLops() 
-       {
+       public Lop constructLops() {
                //return already created lops
-               if( getLops() != null )
+               if (getLops() != null)
                        return getLops();
-       
+
                //construct matrix mult lops (currently only supported 
aggbinary)
-               if ( isMatrixMultiply() ) 
-               {
+               if (isMatrixMultiply()) {
                        Hop input1 = getInput().get(0);
                        Hop input2 = getInput().get(1);
-                       
+
                        //matrix mult operation selection part 1 (CP vs MR vs 
Spark)
                        ExecType et = optFindExecType();
-                       
+
                        //matrix mult operation selection part 2 (specific 
pattern)
                        MMTSJType mmtsj = checkTransposeSelf(); //determine 
tsmm pattern
                        ChainType chain = checkMapMultChain(); //determine 
mmchain pattern
 
-                       if(mmtsj == MMTSJType.LEFT && 
input2.isCompressedOutput()){
+                       if (mmtsj == MMTSJType.LEFT && 
input2.isCompressedOutput()) {
                                // if tsmm and input is compressed. (using 
input2, since input1 is transposed and therefore not compressed.)
                                et = ExecType.CP;
                        }
 
-                       if( et == ExecType.CP || et == ExecType.GPU || et == 
ExecType.FED )
-                       {
+                       if (et == ExecType.CP || et == ExecType.GPU || et == 
ExecType.FED) {
                                //matrix mult operation selection part 3 (CP 
type)
-                               _method = optFindMMultMethodCP ( 
input1.getDim1(), input1.getDim2(),   
-                                                     input2.getDim1(), 
input2.getDim2(), mmtsj, chain, _hasLeftPMInput );
-                               
+                               _method = 
optFindMMultMethodCP(input1.getDim1(), input1.getDim2(),
+                                               input2.getDim1(), 
input2.getDim2(), mmtsj, chain, _hasLeftPMInput);
+
                                //dispatch CP lops construction 
-                               switch( _method ){
-                                       case TSMM: 
-                                               constructCPLopsTSMM( mmtsj, et 
);
+                               switch (_method) {
+                                       case TSMM:
+                                               constructCPLopsTSMM(mmtsj, et);
                                                break;
                                        case MAPMM_CHAIN:
-                                               constructCPLopsMMChain( chain );
+                                               constructCPLopsMMChain(chain);
                                                break;
                                        case PMM:
                                                constructCPLopsPMM();
@@ -204,53 +202,49 @@ public class AggBinaryOp extends MultiThreadedHop {
                                        default:
                                                throw new 
HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + 
_method + ") while constructing CP lops.");
                                }
-                       }
-                       else if( et == ExecType.SPARK ) 
-                       {
+                       } else if (et == ExecType.SPARK) {
                                //matrix mult operation selection part 3 (SPARK 
type)
                                boolean tmmRewrite = 
HopRewriteUtils.isTransposeOperation(input1);
-                               _method = optFindMMultMethodSpark ( 
+                               _method = optFindMMultMethodSpark(
                                                input1.getDim1(), 
input1.getDim2(), input1.getBlocksize(), input1.getNnz(),
                                                input2.getDim1(), 
input2.getDim2(), input2.getBlocksize(), input2.getNnz(),
-                                               mmtsj, chain, _hasLeftPMInput, 
tmmRewrite );
+                                               mmtsj, chain, _hasLeftPMInput, 
tmmRewrite);
                                //dispatch SPARK lops construction
-                               switch( _method )
-                               {
+                               switch (_method) {
                                        case TSMM:
-                                       case TSMM2:     
-                                               constructSparkLopsTSMM( mmtsj, 
_method==MMultMethod.TSMM2 );
+                                       case TSMM2:
+                                               constructSparkLopsTSMM(mmtsj, 
_method == MMultMethod.TSMM2);
                                                break;
                                        case MAPMM_L:
                                        case MAPMM_R:
-                                               constructSparkLopsMapMM( 
_method );
+                                               
constructSparkLopsMapMM(_method);
                                                break;
                                        case MAPMM_CHAIN:
-                                               constructSparkLopsMapMMChain( 
chain );
+                                               
constructSparkLopsMapMMChain(chain);
                                                break;
                                        case PMAPMM:
                                                constructSparkLopsPMapMM();
                                                break;
-                                       case CPMM:      
+                                       case CPMM:
                                                constructSparkLopsCPMM();
                                                break;
-                                       case RMM:       
+                                       case RMM:
                                                constructSparkLopsRMM();
                                                break;
                                        case PMM:
-                                               constructSparkLopsPMM(); 
+                                               constructSparkLopsPMM();
                                                break;
                                        case ZIPMM:
                                                constructSparkLopsZIPMM();
                                                break;
-                                               
+
                                        default:
-                                               throw new 
HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + 
_method + ") while constructing SPARK lops.");     
+                                               throw new 
HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + 
_method + ") while constructing SPARK lops.");
                                }
                        }
-               } 
-               else
+               } else
                        throw new HopsException(this.printErrorLocation() + 
"Invalid operation in AggBinary Hop, aggBin(" + innerOp + "," + outerOp + ") 
while constructing lops.");
-               
+
                //add reblock/checkpoint lops if necessary
                constructAndSetLopsDataFlowProperties();
 
@@ -260,30 +254,28 @@ public class AggBinaryOp extends MultiThreadedHop {
        @Override
        public String getOpString() {
                //ba - binary aggregate, for consistency with runtime 
-               return "ba(" + outerOp.toString() + innerOp.toString()+")";
+               return "ba(" + outerOp.toString() + innerOp.toString() + ")";
        }
-       
+
        @Override
-       public void computeMemEstimate(MemoTable memo) 
-       {
+       public void computeMemEstimate(MemoTable memo) {
                //extension of default compute memory estimate in order to 
                //account for smaller tsmm memory requirements.
                super.computeMemEstimate(memo);
-               
+
                //tsmm left is guaranteed to require only X but not t(X), while
                //tsmm right might have additional requirements to transpose X 
if sparse
                //NOTE: as a heuristic this correction is only applied if not a 
column vector because
                //most other vector operations require memory for at least two 
vectors (we aim for 
                //consistency in order to prevent anomalies in parfor opt 
leading to small degree of par)
                MMTSJType mmtsj = checkTransposeSelf();
-               if( mmtsj.isLeft() && getInput().get(1).dimsKnown() && 
getInput().get(1).getDim2()>1 ) {
+               if (mmtsj.isLeft() && getInput().get(1).dimsKnown() && 
getInput().get(1).getDim2() > 1) {
                        _memEstimate = _memEstimate - 
getInput().get(0)._outputMemEstimate;
                }
        }
 
        @Override
-       protected double computeOutputMemEstimate( long dim1, long dim2, long 
nnz )
-       {               
+       protected double computeOutputMemEstimate(long dim1, long dim2, long 
nnz) {
                //NOTES:  
                // * The estimate for transpose-self is the same as for normal 
matrix multiplications
                //   because (1) this decouples the decision of TSMM over 
default MM and (2) some cases
@@ -314,10 +306,9 @@ public class AggBinaryOp extends MultiThreadedHop {
 
                return ret;
        }
-       
+
        @Override
-       protected double computeIntermediateMemEstimate( long dim1, long dim2, 
long nnz )
-       {
+       protected double computeIntermediateMemEstimate(long dim1, long dim2, 
long nnz) {
                double ret = 0;
 
                if (isGPUEnabled()) {
@@ -327,277 +318,254 @@ public class AggBinaryOp extends MultiThreadedHop {
                        double in2Sparsity = 
OptimizerUtils.getSparsity(in2.getDim1(), in2.getDim2(), in2.getNnz());
                        boolean in1Sparse = in1Sparsity < 
MatrixBlock.SPARSITY_TURN_POINT;
                        boolean in2Sparse = in2Sparsity < 
MatrixBlock.SPARSITY_TURN_POINT;
-                       if(in1Sparse && !in2Sparse) {
+                       if (in1Sparse && !in2Sparse) {
                                // Only in sparse-dense cases, we need 
additional memory budget for GPU
                                ret += 
OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
                        }
                }
 
                //account for potential final dense-sparse transformation 
(worst-case sparse representation)
-               if( dim2 >= 2 && nnz != 0 ) //vectors always dense
+               if (dim2 >= 2 && nnz != 0) //vectors always dense
                        ret += MatrixBlock.estimateSizeSparseInMemory(dim1, 
dim2,
-                               MatrixBlock.SPARSITY_TURN_POINT - 
UtilFunctions.DOUBLE_EPS);
-               
+                                       MatrixBlock.SPARSITY_TURN_POINT - 
UtilFunctions.DOUBLE_EPS);
+
                return ret;
        }
-       
+
        @Override
-       protected DataCharacteristics inferOutputCharacteristics( MemoTable 
memo )
-       {
+       protected DataCharacteristics inferOutputCharacteristics(MemoTable 
memo) {
                DataCharacteristics[] dc = memo.getAllInputStats(getInput());
                DataCharacteristics ret = null;
-               if( dc[0].rowsKnown() && dc[1].colsKnown() ) {
+               if (dc[0].rowsKnown() && dc[1].colsKnown()) {
                        ret = new MatrixCharacteristics(dc[0].getRows(), 
dc[1].getCols());
-                       double sp1 = (dc[0].getNonZeros()>0) ? 
OptimizerUtils.getSparsity(dc[0].getRows(), dc[0].getCols(), 
dc[0].getNonZeros()) : 1.0;
-                       double sp2 = (dc[1].getNonZeros()>0) ? 
OptimizerUtils.getSparsity(dc[1].getRows(), dc[1].getCols(), 
dc[1].getNonZeros()) : 1.0;
-                       ret.setNonZeros((long)(ret.getLength() * 
OptimizerUtils.getMatMultSparsity(sp1, sp2, ret.getRows(), dc[0].getCols(), 
ret.getCols(), true)));
+                       double sp1 = (dc[0].getNonZeros() > 0) ? 
OptimizerUtils.getSparsity(dc[0].getRows(), dc[0].getCols(), 
dc[0].getNonZeros()) : 1.0;
+                       double sp2 = (dc[1].getNonZeros() > 0) ? 
OptimizerUtils.getSparsity(dc[1].getRows(), dc[1].getCols(), 
dc[1].getNonZeros()) : 1.0;
+                       ret.setNonZeros((long) (ret.getLength() * 
OptimizerUtils.getMatMultSparsity(sp1, sp2, ret.getRows(), dc[0].getCols(), 
ret.getCols(), true)));
                }
                return ret;
        }
-       
+
 
        public boolean isMatrixMultiply() {
-               return ( this.innerOp == OpOp2.MULT && this.outerOp == 
AggOp.SUM );
+               return (this.innerOp == OpOp2.MULT && this.outerOp == 
AggOp.SUM);
        }
-       
+
        private boolean isOuterProduct() {
-               return ( getInput().get(0).isVector() && 
getInput().get(1).isVector() )
-                       && ( getInput().get(0).getDim1() == 1 && 
getInput().get(0).getDim1() > 1
-                                       && getInput().get(1).getDim1() > 1 && 
getInput().get(1).getDim2() == 1 );
+               return (getInput().get(0).isVector() && 
getInput().get(1).isVector())
+                               && (getInput().get(0).getDim1() == 1 && 
getInput().get(0).getDim1() > 1
+                               && getInput().get(1).getDim1() > 1 && 
getInput().get(1).getDim2() == 1);
        }
-       
+
        @Override
        public boolean isMultiThreadedOpType() {
                return isMatrixMultiply();
        }
-       
+
        @Override
-       public boolean allowsAllExecTypes()
-       {
+       public boolean allowsAllExecTypes() {
                return true;
        }
-       
+
        @Override
-       protected ExecType optFindExecType(boolean transitive)
-       {
+       protected ExecType optFindExecType(boolean transitive) {
                checkAndSetForcedPlatform();
-               
-               if( _etypeForced != null ) {
+
+               if (_etypeForced != null) {
                        setExecType(_etypeForced);
-               }
-               else 
-               {
-                       if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
+               } else {
+                       if (OptimizerUtils.isMemoryBasedOptLevel()) {
                                setExecType(findExecTypeByMemEstimate());
                        }
                        // choose CP if the dimensions of both inputs are below 
Hops.CPThreshold 
                        // OR if it is vector-vector inner product
-                       else if ( (getInput().get(0).areDimsBelowThreshold() && 
getInput().get(1).areDimsBelowThreshold())
-                                               || 
(getInput().get(0).isVector() && getInput().get(1).isVector() && 
!isOuterProduct()) )
-                       {
+                       else if ((getInput().get(0).areDimsBelowThreshold() && 
getInput().get(1).areDimsBelowThreshold())
+                                       || (getInput().get(0).isVector() && 
getInput().get(1).isVector() && !isOuterProduct())) {
                                setExecType(ExecType.CP);
-                       }
-                       else
-                       {
+                       } else {
                                setExecType(ExecType.SPARK);
                        }
-                       
+
                        //check for valid CP mmchain, send invalid memory 
requirements to remote
-                       if( _etype == ExecType.CP
-                               && checkMapMultChain() != ChainType.NONE
-                               && OptimizerUtils.getLocalMemBudget() < 
-                               
getInput().get(0).getInput().get(0).getOutputMemEstimate() ) {
+                       if (_etype == ExecType.CP
+                                       && checkMapMultChain() != ChainType.NONE
+                                       && OptimizerUtils.getLocalMemBudget() <
+                                       
getInput().get(0).getInput().get(0).getOutputMemEstimate()) {
                                setExecType(ExecType.SPARK);
                        }
-                       
+
                        //check for valid CP dimensions and matrix size
                        checkAndSetInvalidCPDimsAndSize();
                }
-               
+
                //spark-specific decision refinement (execute binary aggregate 
w/ left or right spark input and 
                //single parent also in spark because it's likely cheap and 
reduces data transfer)
                MMTSJType mmtsj = checkTransposeSelf(); //determine tsmm pattern
-               if( transitive && _etype == ExecType.CP && _etypeForced != 
ExecType.CP 
-                       && ((!mmtsj.isLeft() && 
isApplicableForTransitiveSparkExecType(true))
-                       || ( !mmtsj.isRight() && 
isApplicableForTransitiveSparkExecType(false))) )
-               {
+               if (transitive && _etype == ExecType.CP && _etypeForced != 
ExecType.CP
+                               && ((!mmtsj.isLeft() && 
isApplicableForTransitiveSparkExecType(true))
+                               || (!mmtsj.isRight() && 
isApplicableForTransitiveSparkExecType(false)))) {
                        //pull binary aggregate into spark 
                        setExecType(ExecType.SPARK);
                }
 
                //mark for recompile (forever)
                setRequiresRecompileIfNecessary();
-               
+
                return _etype;
        }
-       
-       private boolean isApplicableForTransitiveSparkExecType(boolean left) 
-       {
+
+       private boolean isApplicableForTransitiveSparkExecType(boolean left) {
                int index = left ? 0 : 1;
-               return !(getInput(index) instanceof DataOp && 
((DataOp)getInput(index)).requiresCheckpoint())
-                       && 
(!HopRewriteUtils.isTransposeOperation(getInput(index))
+               return !(getInput(index) instanceof DataOp && ((DataOp) 
getInput(index)).requiresCheckpoint())
+                               && 
(!HopRewriteUtils.isTransposeOperation(getInput(index))
                                || (left && 
!isLeftTransposeRewriteApplicable(true)))
-                       && getInput(index).getParent().size()==1 //bagg is only 
parent
-                       && !getInput(index).areDimsBelowThreshold() 
-                       && (getInput(index).optFindExecType() == ExecType.SPARK
-                               || (getInput(index) instanceof DataOp && 
((DataOp)getInput(index)).hasOnlyRDD()))
-                       && 
getInput(index).getOutputMemEstimate()>getOutputMemEstimate();
+                               && getInput(index).getParent().size() == 1 
//bagg is only parent
+                               && !getInput(index).areDimsBelowThreshold()
+                               && (getInput(index).optFindExecType() == 
ExecType.SPARK
+                               || (getInput(index) instanceof DataOp && 
((DataOp) getInput(index)).hasOnlyRDD()))
+                               && getInput(index).getOutputMemEstimate() > 
getOutputMemEstimate();
        }
-       
+
        /**
         * TSMM: Determine if XtX pattern applies for this aggbinary and if yes
-        * which type. 
-        * 
+        * which type.
+        *
         * @return MMTSJType
         */
-       public MMTSJType checkTransposeSelf()
-       {
+       public MMTSJType checkTransposeSelf() {
                MMTSJType ret = MMTSJType.NONE;
-               
+
                Hop in1 = getInput().get(0);
                Hop in2 = getInput().get(1);
-               
-               if( HopRewriteUtils.isTransposeOperation(in1)
-                       && in1.getInput().get(0) == in2 )
-               {
+
+               if (HopRewriteUtils.isTransposeOperation(in1)
+                               && in1.getInput().get(0) == in2) {
                        ret = MMTSJType.LEFT;
                }
-               
-               if( HopRewriteUtils.isTransposeOperation(in2) 
-                       && in2.getInput().get(0) == in1 )
-               {
+
+               if (HopRewriteUtils.isTransposeOperation(in2)
+                               && in2.getInput().get(0) == in1) {
                        ret = MMTSJType.RIGHT;
                }
-               
+
                return ret;
        }
 
        /**
-        * MapMultChain: Determine if XtwXv/XtXv pattern applies for this 
aggbinary 
-        * and if yes which type. 
-        * 
+        * MapMultChain: Determine if XtwXv/XtXv pattern applies for this 
aggbinary
+        * and if yes which type.
+        *
         * @return ChainType
         */
-       public ChainType checkMapMultChain()
-       {
+       public ChainType checkMapMultChain() {
                ChainType chainType = ChainType.NONE;
-               
+
                Hop in1 = getInput().get(0);
                Hop in2 = getInput().get(1);
-               
+
                //check for transpose left input (both chain types)
-               if( HopRewriteUtils.isTransposeOperation(in1) )
-               {
+               if (HopRewriteUtils.isTransposeOperation(in1)) {
                        Hop X = in1.getInput().get(0);
-                               
+
                        //check mapmultchain patterns
                        //t(X)%*%(w*(X%*%v))
-                       if( in2 instanceof BinaryOp && 
((BinaryOp)in2).getOp()==OpOp2.MULT )
-                       {
+                       if (in2 instanceof BinaryOp && ((BinaryOp) in2).getOp() 
== OpOp2.MULT) {
                                Hop in3b = in2.getInput().get(1);
-                               if( in3b instanceof AggBinaryOp )
-                               {
+                               if (in3b instanceof AggBinaryOp) {
                                        Hop in4 = in3b.getInput().get(0);
-                                       if( X == in4 ) //common input
+                                       if (X == in4) //common input
                                                chainType = ChainType.XtwXv;
                                }
                        }
                        //t(X)%*%((X%*%v)-y)
-                       else if( in2 instanceof BinaryOp && 
((BinaryOp)in2).getOp()==OpOp2.MINUS )
-                       {
+                       else if (in2 instanceof BinaryOp && ((BinaryOp) 
in2).getOp() == OpOp2.MINUS) {
                                Hop in3a = in2.getInput().get(0);
-                               Hop in3b = in2.getInput().get(1);               
                
-                               if( in3a instanceof AggBinaryOp && 
in3b.getDataType()==DataType.MATRIX )
-                               {
+                               Hop in3b = in2.getInput().get(1);
+                               if (in3a instanceof AggBinaryOp && 
in3b.getDataType() == DataType.MATRIX) {
                                        Hop in4 = in3a.getInput().get(0);
-                                       if( X == in4 ) //common input
+                                       if (X == in4) //common input
                                                chainType = ChainType.XtXvy;
                                }
                        }
                        //t(X)%*%(X%*%v)
-                       else if( in2 instanceof AggBinaryOp )
-                       {
+                       else if (in2 instanceof AggBinaryOp) {
                                Hop in3 = in2.getInput().get(0);
-                               if( X == in3 ) //common input
+                               if (X == in3) //common input
                                        chainType = ChainType.XtXv;
                        }
                }
-               
+
                return chainType;
        }
-       
+
        //////////////////////////
        // CP Lops generation
-       /////////////////////////
-       
-       private void constructCPLopsTSMM( MMTSJType mmtsj, ExecType et ) {
+
+       /// //////////////////////
+
+       private void constructCPLopsTSMM(MMTSJType mmtsj, ExecType et) {
                int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
-               Lop matmultCP = new 
MMTSJ(getInput().get(mmtsj.isLeft()?1:0).constructLops(),
-                       getDataType(), getValueType(), et, mmtsj, false, k);
+               Lop matmultCP = new MMTSJ(getInput().get(mmtsj.isLeft() ? 1 : 
0).constructLops(),
+                               getDataType(), getValueType(), et, mmtsj, 
false, k);
                matmultCP.getOutputParameters().setDimensions(getDim1(), 
getDim2(), getBlocksize(), getNnz());
-               setLineNumbers( matmultCP );
+               setLineNumbers(matmultCP);
                setLops(matmultCP);
        }
 
-       private void constructCPLopsMMChain( ChainType chain )
-       {
+       private void constructCPLopsMMChain(ChainType chain) {
                MapMultChain mapmmchain = null;
-               if( chain == ChainType.XtXv ) {
+               if (chain == ChainType.XtXv) {
                        Hop hX = getInput().get(0).getInput().get(0);
                        Hop hv = getInput().get(1).getInput().get(1);
-                       mapmmchain = new MapMultChain( hX.constructLops(), 
hv.constructLops(), getDataType(), getValueType(), ExecType.CP);
-               }
-               else { //ChainType.XtwXv / ChainType.XtwXvy
+                       mapmmchain = new MapMultChain(hX.constructLops(), 
hv.constructLops(), getDataType(), getValueType(), ExecType.CP);
+               } else { //ChainType.XtwXv / ChainType.XtwXvy
                        int wix = (chain == ChainType.XtwXv) ? 0 : 1;
                        int vix = (chain == ChainType.XtwXv) ? 1 : 0;
                        Hop hX = getInput().get(0).getInput().get(0);
                        Hop hw = getInput().get(1).getInput().get(wix);
                        Hop hv = 
getInput().get(1).getInput().get(vix).getInput().get(1);
-                       mapmmchain = new MapMultChain( hX.constructLops(), 
hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(), 
ExecType.CP);
+                       mapmmchain = new MapMultChain(hX.constructLops(), 
hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(), 
ExecType.CP);
                }
-               
+
                //set degree of parallelism
                int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
-               mapmmchain.setNumThreads( k );
-               
+               mapmmchain.setNumThreads(k);
+
                //set basic lop properties
                setOutputDimensions(mapmmchain);
                setLineNumbers(mapmmchain);
                setLops(mapmmchain);
        }
-       
+
        /**
         * NOTE: exists for consistency since removeEmtpy might be scheduled to 
MR
-        * but matrix mult on small output might be scheduled to CP. Hence, we 
+        * but matrix mult on small output might be scheduled to CP. Hence, we
         * need to handle directly passed selection vectors in CP as well.
         */
-       private void constructCPLopsPMM() 
-       {
+       private void constructCPLopsPMM() {
                Hop pmInput = getInput().get(0);
                Hop rightInput = getInput().get(1);
-               
+
                Hop nrow = HopRewriteUtils.createValueHop(pmInput, true); //NROW
                nrow.setBlocksize(0);
                nrow.setForcedExecType(ExecType.CP);
                HopRewriteUtils.copyLineNumbers(this, nrow);
                Lop lnrow = nrow.constructLops();
-               
+
                PMMJ pmm = new PMMJ(pmInput.constructLops(), 
rightInput.constructLops(), lnrow, getDataType(), getValueType(), false, false, 
ExecType.CP);
-               
+
                //set degree of parallelism
                int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                pmm.setNumThreads(k);
-               
+
                pmm.getOutputParameters().setDimensions(getDim1(), getDim2(), 
getBlocksize(), getNnz());
                setLineNumbers(pmm);
-               
+
                setLops(pmm);
-               
+
                HopRewriteUtils.removeChildReference(pmInput, nrow);
        }
 
-       private void constructCPLopsMM(ExecType et) 
-       {
+       private void constructCPLopsMM(ExecType et) {
                Lop matmultCP = null;
                String cla = 
ConfigurationManager.getDMLConfig().getTextValue("sysds.compressed.linalg");
                if (et == ExecType.GPU) {
@@ -610,72 +578,85 @@ public class AggBinaryOp extends MultiThreadedHop {
                        boolean leftTrans = false; // 
HopRewriteUtils.isTransposeOperation(h1);
                        boolean rightTrans = false; // 
HopRewriteUtils.isTransposeOperation(h2);
                        Lop left = !leftTrans ? h1.constructLops() :
-                               h1.getInput().get(0).constructLops();
+                                       h1.getInput().get(0).constructLops();
                        Lop right = !rightTrans ? h2.constructLops() :
-                               h2.getInput().get(0).constructLops();
+                                       h2.getInput().get(0).constructLops();
                        matmultCP = new MatMultCP(left, right, getDataType(), 
getValueType(), et, leftTrans, rightTrans);
                        setOutputDimensions(matmultCP);
-               }
-               else if (cla.equals("true") || cla.equals("cost")){
+               } else if (cla.equals("true") || cla.equals("cost")) {
                        Hop h1 = getInput().get(0);
                        Hop h2 = getInput().get(1);
                        int k = 
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                        boolean leftTrans = 
HopRewriteUtils.isTransposeOperation(h1);
-                       boolean rightTrans =  
HopRewriteUtils.isTransposeOperation(h2);
+                       boolean rightTrans = 
HopRewriteUtils.isTransposeOperation(h2);
                        Lop left = !leftTrans ? h1.constructLops() :
-                               h1.getInput().get(0).constructLops();
+                                       h1.getInput().get(0).constructLops();
                        Lop right = !rightTrans ? h2.constructLops() :
-                               h2.getInput().get(0).constructLops();
+                                       h2.getInput().get(0).constructLops();
                        matmultCP = new MatMultCP(left, right, getDataType(), 
getValueType(), et, k, leftTrans, rightTrans);
-               }
-               else {
-                       if( isLeftTransposeRewriteApplicable(true) ) {
+               } else {
+                       if (isLeftTransposeRewriteApplicable(true)) {
                                matmultCP = 
constructCPLopsMMWithLeftTransposeRewrite(et);
-                       }
-                       else { 
+                       } else {
                                int k = 
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                                matmultCP = new 
MatMultCP(getInput().get(0).constructLops(),
-                                       getInput().get(1).constructLops(), 
getDataType(), getValueType(), et, k);
+                                               
getInput().get(1).constructLops(), getDataType(), getValueType(), et, k);
                                updateLopFedOut(matmultCP);
                        }
                        setOutputDimensions(matmultCP);
                }
-               
+
                setLineNumbers(matmultCP);
                setLops(matmultCP);
        }
 
-       private Lop constructCPLopsMMWithLeftTransposeRewrite(ExecType et) 
-       {
-               Hop X = getInput().get(0).getInput().get(0); //guaranteed to exists
+       /**
+        * Constructs t(X)%*%Y as t(t(Y)%*%X), or directly as A%*%Y if X=t(A); output dimensions are set by the caller.
+        */
+       private Lop constructCPLopsMMWithLeftTransposeRewrite(ExecType et) {
+               Hop X = getInput().get(0).getInput().get(0); // guaranteed to exist
                Hop Y = getInput().get(1);
                int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
-               
+
+               //special case X=t(A): t(X)%*%Y = A%*%Y, hence all transposes are redundant
+               if( HopRewriteUtils.isTransposeOperation(X) ) {
+                       Lop mm = new MatMultCP(X.getInput().get(0).constructLops(),
+                               Y.constructLops(), getDataType(), getValueType(), et, k);
+                       updateLopFedOut(mm);
+                       return mm; //dimensions and line numbers are set by the caller
+               }
+
                //right vector transpose
-               Lop lY = Y.constructLops();
-               ExecType inputReorgExecType = ( Y.hasFederatedOutput() ) ? ExecType.FED : ExecType.CP;
-               Lop tY = (lY instanceof Transform && ((Transform)lY).getOp()==ReOrgOp.TRANS ) ?
-                               lY.getInputs().get(0) : //if input is already a transpose, avoid redundant transpose ops
-                               new Transform(lY, ReOrgOp.TRANS, getDataType(), getValueType(), inputReorgExecType, k);
-               tY.getOutputParameters().setDimensions(Y.getDim2(), Y.getDim1(), getBlocksize(), Y.getNnz());
+               //(if Y=t(B), then t(Y)=B and no additional transpose is needed)
+               ExecType inputReorgExecType = (Y.hasFederatedOutput()) ? ExecType.FED : ExecType.CP;
+               Lop tY;
+               if( HopRewriteUtils.isTransposeOperation(Y) )
+                       tY = Y.getInput().get(0).constructLops();
+               else {
+                       Lop lY = Y.constructLops();
+                       tY = (lY instanceof Transform && ((Transform)lY).getOp() == ReOrgOp.TRANS) ?
+                               lY.getInputs().get(0) : //if input is already a transpose, avoid redundant transpose ops
+                               new Transform(lY, ReOrgOp.TRANS, getDataType(), getValueType(), inputReorgExecType, k);
+               }
+               tY.getOutputParameters().setDimensions(Y.getDim2(), Y.getDim1(), getBlocksize(), Y.getNnz());
                setLineNumbers(tY);
                if (Y.hasFederatedOutput())
                        updateLopFedOut(tY);
-               
+
                //matrix mult
                Lop mult = new MatMultCP(tY, X.constructLops(), getDataType(), getValueType(), et, k); //CP or FED
                mult.getOutputParameters().setDimensions(Y.getDim2(), X.getDim2(), getBlocksize(), getNnz());
                mult.setFederatedOutput(_federatedOutput);
                setLineNumbers(mult);
 
                //result transpose (dimensions set outside)
-               ExecType outTransposeExecType = ( _federatedOutput == FederatedOutput.FOUT ) ?
-                       ExecType.FED : ExecType.CP;
+               ExecType outTransposeExecType = (_federatedOutput == FederatedOutput.FOUT) ?
+                               ExecType.FED : ExecType.CP;
                Lop out = new Transform(mult, ReOrgOp.TRANS, getDataType(), getValueType(), outTransposeExecType, k);
 
                return out;
        }
-       
+
        //////////////////////////
        // Spark Lops generation
        /////////////////////////
@@ -718,25 +699,25 @@ public class AggBinaryOp extends MultiThreadedHop {
        {
                Hop X = getInput().get(0).getInput().get(0); //guaranteed to 
exists
                Hop Y = getInput().get(1);
-               
+
                //right vector transpose
                Lop tY = new Transform(Y.constructLops(), ReOrgOp.TRANS, 
getDataType(), getValueType(), ExecType.CP);
                tY.getOutputParameters().setDimensions(Y.getDim2(), 
Y.getDim1(), getBlocksize(), Y.getNnz());
                setLineNumbers(tY);
-               
+
                //matrix mult spark
-               boolean needAgg = requiresAggregation(MMultMethod.MAPMM_R); 
+               boolean needAgg = requiresAggregation(MMultMethod.MAPMM_R);
                SparkAggType aggtype = getSparkMMAggregationType(needAgg);
-               _outputEmptyBlocks = 
!OptimizerUtils.allowsToFilterEmptyBlockOutputs(this); 
-               
-               Lop mult = new MapMult( tY, X.constructLops(), getDataType(), 
getValueType(), 
-                                     false, false, _outputEmptyBlocks, 
aggtype);       
+               _outputEmptyBlocks = 
!OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
+
+               Lop mult = new MapMult( tY, X.constructLops(), getDataType(), 
getValueType(),
+                                     false, false, _outputEmptyBlocks, 
aggtype);
                mult.getOutputParameters().setDimensions(Y.getDim2(), 
X.getDim2(), getBlocksize(), getNnz());
                setLineNumbers(mult);
-               
+
                //result transpose (dimensions set outside)
                Lop out = new Transform(mult, ReOrgOp.TRANS, getDataType(), 
getValueType(), ExecType.CP);
-               
+
                return out;
        }
 
@@ -892,13 +873,13 @@ public class AggBinaryOp extends MultiThreadedHop {
                setLineNumbers( zipmm );
                setLops(zipmm);
        }
-                       
+
        /**
         * Determines if the rewrite t(X)%*%Y -> t(t(Y)%*%X) is applicable
         * and cost effective. Whenever X is a wide matrix and Y is a vector
         * this has huge impact, because the transpose of X would dominate
         * the entire operation costs.
-        * 
+        *
         * @param CP true if CP
         * @return true if left transpose rewrite applicable
         */
@@ -910,38 +891,38 @@ public class AggBinaryOp extends MultiThreadedHop {
                {
                        return false;
                }
-               
+
                boolean ret = false;
                Hop h1 = getInput().get(0);
                Hop h2 = getInput().get(1);
-               
+
                //check for known dimensions and cost for t(X) vs t(v) + t(tvX)
                //(for both CP/MR, we explicitly check that new transposes fit 
in memory,
                //even a ba in CP does not imply that both transposes can be 
executed in CP)
-               if( CP ) //in-memory ba 
+               if( CP ) //in-memory ba
                {
                        if( HopRewriteUtils.isTransposeOperation(h1) )
                        {
                                long m = h1.getDim1();
                                long cd = h1.getDim2();
                                long n = h2.getDim2();
-                               
+
                                //check for known dimensions (necessary 
condition for subsequent checks)
-                               ret = (m>0 && cd>0 && n>0); 
-                               
-                               //check operation memory with changed transpose 
(this is important if we have 
+                               ret = (m>0 && cd>0 && n>0);
+
+                               //check operation memory with changed transpose 
(this is important if we have
                                //e.g., t(X) %*% v, where X is sparse and tX 
fits in memory but X does not
                                double memX = 
h1.getInput().get(0).getOutputMemEstimate();
                                double memtv = 
OptimizerUtils.estimateSizeExactSparsity(n, cd, 1.0);
                                double memtXv = 
OptimizerUtils.estimateSizeExactSparsity(n, m, 1.0);
                                double newMemEstimate = memtv + memX + memtXv;
                                ret &= ( newMemEstimate < 
OptimizerUtils.getLocalMemBudget() );
-                               
+
                                //check for cost benefit of t(X) vs t(v) + 
t(tvX) and memory of additional transpose ops
                                ret &= ( m*cd > (cd*n + m*n) &&
-                                       2 * 
OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) < 
OptimizerUtils.getLocalMemBudget() &&
-                                       2 * 
OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) < 
OptimizerUtils.getLocalMemBudget() ); 
-                               
+                                               2 * 
OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) < 
OptimizerUtils.getLocalMemBudget() &&
+                                               2 * 
OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) < 
OptimizerUtils.getLocalMemBudget() );
+
                                //update operation memory estimate (e.g., for 
parfor optimizer)
                                if( ret )
                                        _memEstimate = newMemEstimate;
@@ -955,14 +936,14 @@ public class AggBinaryOp extends MultiThreadedHop {
                                long n = h2.getDim2();
                                //note: output size constraint for mapmult 
already checked by optfindmmultmethod
                                if( m>0 && cd>0 && n>0 && (m*cd > (cd*n + m*n)) 
&&
-                                       2 * 
OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) <  
OptimizerUtils.getLocalMemBudget() &&
-                                       2 * 
OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) <  
OptimizerUtils.getLocalMemBudget() ) 
+                                               2 * 
OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) <  
OptimizerUtils.getLocalMemBudget() &&
+                                               2 * 
OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) <  
OptimizerUtils.getLocalMemBudget() )
                                {
                                        ret = true;
                                }
                        }
                }
-               
+
                return ret;
        }
 
diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
index 2841256607..05e8d171b7 100644
--- 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
+++ 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
@@ -1,3 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
 package org.apache.sysds.hops.fedplanner;
 
 import org.apache.commons.lang3.tuple.Pair;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteTransposeTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteTransposeTest.java
new file mode 100644
index 0000000000..ac28b12caf
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteTransposeTest.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+import java.util.HashMap;
+
+public class RewriteTransposeTest extends AutomatedTestBase {
+       private final static String TEST_NAME1 = "RewriteTransposeCase1"; // t(X)%*%Y
+       private final static String TEST_NAME2 = "RewriteTransposeCase2"; // X=t(A); t(X)%*%Y
+       private final static String TEST_NAME3 = "RewriteTransposeCase3"; // Y=t(A); t(X)%*%Y
+
+       private final static String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + RewriteTransposeTest.class.getSimpleName() + "/";
+
+       private static final double eps = 1e-9;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"}));
+               addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"R"}));
+               addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[]{"R"}));
+       }
+
+       @Test
+       public void testTransposeRewrite1CP() {
+               runTransposeRewriteTest(TEST_NAME1, false);
+       }
+
+       @Test
+       public void testTransposeRewrite2CP() {
+               runTransposeRewriteTest(TEST_NAME2, true);
+       }
+
+       @Test
+       public void testTransposeRewrite3CP() {
+               runTransposeRewriteTest(TEST_NAME3, false);
+       }
+
+       private void runTransposeRewriteTest(String testname, boolean expectedMerge) {
+               //disable algebraic simplification for this test only; the flag is global
+               //static state, so restore it afterwards even if the test fails
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               try {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false;
+                       TestConfiguration config = getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+
+                       programArgs = new String[]{"-explain", "-stats", "-args", output("R")};
+
+                       fullRScriptName = HOME + testname + ".R";
+                       rCmd = getRCmd(expectedDir());
+
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       HashMap<MatrixValue.CellIndex, Double> dmlOutput = readDMLMatrixFromOutputDir("R");
+                       HashMap<MatrixValue.CellIndex, Double> rOutput = readRMatrixFromExpectedDir("R");
+                       TestUtils.compareMatrices(dmlOutput, rOutput, eps, "Stat-DML", "Stat-R");
+
+                       Assert.assertTrue(Statistics.getCPHeavyHitterCount("r'") <= 2);
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase1.R 
b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.R
new file mode 100644
index 0000000000..5b0e19dca2
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+library("Matrix")
+library("matrixStats")
+
+X <- matrix(seq(1, 400), nrow=4, ncol=100, byrow=TRUE)
+Y <- matrix(seq(1, 4), nrow=4, ncol=1, byrow=TRUE)
+
+R <- t(t(Y)%*%X)
+
+writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep=""));
diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase1.dml 
b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.dml
new file mode 100644
index 0000000000..83cfb65dc6
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(1, 400), rows=4, cols=100);
+Y = matrix(seq(1, 4), rows=4, cols=1);
+
+R = t(X)%*%Y;
+
+write(R, $1);
\ No newline at end of file
diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase2.R 
b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.R
new file mode 100644
index 0000000000..fea8c26669
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+library("Matrix")
+library("matrixStats")
+A = matrix(seq(1, 400), nrow=100, ncol=4, byrow=TRUE)
+Y = matrix(seq(1, 4), nrow=4, ncol=1, byrow=TRUE)
+X = t(A)
+
+R <- t(t(Y)%*%X)
+
+writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep=""));
diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase2.dml 
b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.dml
new file mode 100644
index 0000000000..cb9332423b
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = matrix(seq(1, 400), rows=100, cols=4);
+Y = matrix(seq(1, 4), rows=4, cols=1);
+X = t(A);
+
+R = t(X) %*% Y;
+
+write(R, $1);
\ No newline at end of file
diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase3.R 
b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.R
new file mode 100644
index 0000000000..2bdd22f674
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+library("Matrix")
+library("matrixStats")
+
+X <- matrix(seq(1, 400), nrow=4, ncol=100, byrow=TRUE)
+A <- matrix(seq(1, 4), nrow=1, ncol=4, byrow=TRUE)
+Y <- t(A)
+
+R <- t(t(Y)%*%X)
+
+writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep=""));
diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase3.dml 
b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.dml
new file mode 100644
index 0000000000..2e26920aed
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(1, 400), rows=4, cols=100);
+A = matrix(seq(1, 4), rows=1, cols=4);
+Y = t(A);
+
+R = t(X) %*% Y;
+
+write(R, $1);
\ No newline at end of file


Reply via email to