This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

The following commit(s) were added to refs/heads/main by this push:
     new 4b2d83e915 [SYSTEMDS-3861] Fix redundant transposes due to multi-level rewrites
4b2d83e915 is described below

commit 4b2d83e915c40b7433580cddfec68aa8c440ba05
Author:     aarna <aarnatya...@gmail.com>
AuthorDate: Fri Apr 18 12:43:04 2025 +0200

    [SYSTEMDS-3861] Fix redundant transposes due to multi-level rewrites
    
    Closes #2249.
---
 .../java/org/apache/sysds/hops/AggBinaryOp.java    | 487 ++++++++++-----------
 .../hops/fedplanner/FederatedMemoTablePrinter.java |  19 +
 .../functions/rewrite/RewriteTransposeTest.java    |  86 ++++
 .../functions/rewrite/RewriteTransposeCase1.R      |  32 ++
 .../functions/rewrite/RewriteTransposeCase1.dml    |  27 ++
 .../functions/rewrite/RewriteTransposeCase2.R      |  32 ++
 .../functions/rewrite/RewriteTransposeCase2.dml    |  28 ++
 .../functions/rewrite/RewriteTransposeCase3.R      |  33 ++
 .../functions/rewrite/RewriteTransposeCase3.dml    |  28 ++
 9 files changed, 519 insertions(+), 253 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 2cf651f189..5f9c6b41b3 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -43,6 +43,7 @@ import org.apache.sysds.lops.MatMultCP;
 import org.apache.sysds.lops.PMMJ;
 import org.apache.sysds.lops.PMapMult;
 import org.apache.sysds.lops.Transform;
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -65,7 +66,7 @@ public class AggBinaryOp extends MultiThreadedHop {
     public static final double MAPMULT_MEM_MULTIPLIER = 1.0;
     public static MMultMethod FORCED_MMULT_METHOD = null;
 
-    public enum MMultMethod { 
+    public enum MMultMethod {
         CPMM, //cross-product matrix multiplication (mr)
         RMM, //replication matrix multiplication (mr)
         MAPMM_L, //map-side matrix-matrix multiplication using distributed cache (mr/sp)
@@ -78,27 +79,27 @@ public class AggBinaryOp extends MultiThreadedHop {
         ZIPMM, //zip matrix multiplication (sp)
         MM //in-memory matrix multiplication (cp)
     }
-    
-    public enum SparkAggType{
+
+    public enum SparkAggType {
         NONE,
         SINGLE_BLOCK,
         MULTI_BLOCK,
     }
-    
+
     private OpOp2 innerOp;
     private AggOp outerOp;
     private MMultMethod _method = null;
-    
+
     //hints set by previous to operator selection
     private boolean _hasLeftPMInput = false; //left input is permutation matrix
-    
+
     private AggBinaryOp() {
         //default constructor for clone
     }
-    
+
     public AggBinaryOp(String l, DataType dt, ValueType vt, OpOp2 innOp,
-        AggOp outOp, Hop in1, Hop in2) {
+            AggOp outOp, Hop in1, Hop in2) {
         super(l, dt, vt);
         innerOp = innOp;
         outerOp = outOp;
@@ -106,7 +107,7 @@ public class AggBinaryOp extends MultiThreadedHop {
         getInput().add(1, in2);
         in1.getParent().add(this);
         in2.getParent().add(this);
-        
+
         //compute unknown dims and nnz
         refreshSizeInformation();
     }
@@ -114,30 +115,30 @@ public class AggBinaryOp extends MultiThreadedHop {
     public void setHasLeftPMInput(boolean flag) {
         _hasLeftPMInput = flag;
     }
-    
-    public boolean hasLeftPMInput(){
+
+    public boolean hasLeftPMInput() {
         return _hasLeftPMInput;
     }
 
-    public MMultMethod getMMultMethod(){
+    public MMultMethod getMMultMethod() {
         return _method;
     }
-    
+
     @Override
     public boolean isGPUEnabled() {
-        if(!DMLScript.USE_ACCELERATOR)
+        if (!DMLScript.USE_ACCELERATOR)
             return false;
-        
+
         Hop input1 = getInput().get(0);
         Hop input2
= getInput().get(1); //matrix mult operation selection part 2 (specific pattern) MMTSJType mmtsj = checkTransposeSelf(); //determine tsmm pattern ChainType chain = checkMapMultChain(); //determine mmchain pattern - - _method = optFindMMultMethodCP ( input1.getDim1(), input1.getDim2(), - input2.getDim1(), input2.getDim2(), mmtsj, chain, _hasLeftPMInput ); - switch( _method ){ - case TSMM: + + _method = optFindMMultMethodCP(input1.getDim1(), input1.getDim2(), + input2.getDim1(), input2.getDim2(), mmtsj, chain, _hasLeftPMInput); + switch (_method) { + case TSMM: //return false; // TODO: Disabling any fused transa optimization in 1.0 release. return true; case MAPMM_CHAIN: @@ -150,50 +151,47 @@ public class AggBinaryOp extends MultiThreadedHop { throw new RuntimeException("Unsupported method:" + _method); } } - + /** * NOTE: overestimated mem in case of transpose-identity matmult, but 3/2 at worst - * and existing mem estimate advantageous in terms of consistency hops/lops, - * and some special cases internally materialize the transpose for better cache locality + * and existing mem estimate advantageous in terms of consistency hops/lops, + * and some special cases internally materialize the transpose for better cache locality */ @Override - public Lop constructLops() - { + public Lop constructLops() { //return already created lops - if( getLops() != null ) + if (getLops() != null) return getLops(); - + //construct matrix mult lops (currently only supported aggbinary) - if ( isMatrixMultiply() ) - { + if (isMatrixMultiply()) { Hop input1 = getInput().get(0); Hop input2 = getInput().get(1); - + //matrix mult operation selection part 1 (CP vs MR vs Spark) ExecType et = optFindExecType(); - + //matrix mult operation selection part 2 (specific pattern) MMTSJType mmtsj = checkTransposeSelf(); //determine tsmm pattern ChainType chain = checkMapMultChain(); //determine mmchain pattern - if(mmtsj == MMTSJType.LEFT && input2.isCompressedOutput()){ + if (mmtsj == MMTSJType.LEFT && input2.isCompressedOutput()) { // if tsmm and input is compressed. (using input2, since input1 is transposed and therefore not compressed.) 
et = ExecType.CP; } - if( et == ExecType.CP || et == ExecType.GPU || et == ExecType.FED ) - { + if (et == ExecType.CP || et == ExecType.GPU || et == ExecType.FED) { //matrix mult operation selection part 3 (CP type) - _method = optFindMMultMethodCP ( input1.getDim1(), input1.getDim2(), - input2.getDim1(), input2.getDim2(), mmtsj, chain, _hasLeftPMInput ); - + _method = optFindMMultMethodCP(input1.getDim1(), input1.getDim2(), + input2.getDim1(), input2.getDim2(), mmtsj, chain, _hasLeftPMInput); + //dispatch CP lops construction - switch( _method ){ - case TSMM: - constructCPLopsTSMM( mmtsj, et ); + switch (_method) { + case TSMM: + constructCPLopsTSMM(mmtsj, et); break; case MAPMM_CHAIN: - constructCPLopsMMChain( chain ); + constructCPLopsMMChain(chain); break; case PMM: constructCPLopsPMM(); @@ -204,53 +202,49 @@ public class AggBinaryOp extends MultiThreadedHop { default: throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing CP lops."); } - } - else if( et == ExecType.SPARK ) - { + } else if (et == ExecType.SPARK) { //matrix mult operation selection part 3 (SPARK type) boolean tmmRewrite = HopRewriteUtils.isTransposeOperation(input1); - _method = optFindMMultMethodSpark ( + _method = optFindMMultMethodSpark( input1.getDim1(), input1.getDim2(), input1.getBlocksize(), input1.getNnz(), input2.getDim1(), input2.getDim2(), input2.getBlocksize(), input2.getNnz(), - mmtsj, chain, _hasLeftPMInput, tmmRewrite ); + mmtsj, chain, _hasLeftPMInput, tmmRewrite); //dispatch SPARK lops construction - switch( _method ) - { + switch (_method) { case TSMM: - case TSMM2: - constructSparkLopsTSMM( mmtsj, _method==MMultMethod.TSMM2 ); + case TSMM2: + constructSparkLopsTSMM(mmtsj, _method == MMultMethod.TSMM2); break; case MAPMM_L: case MAPMM_R: - constructSparkLopsMapMM( _method ); + constructSparkLopsMapMM(_method); break; case MAPMM_CHAIN: - constructSparkLopsMapMMChain( chain ); + constructSparkLopsMapMMChain(chain); break; case PMAPMM: constructSparkLopsPMapMM(); break; - case CPMM: + case CPMM: constructSparkLopsCPMM(); break; - case RMM: + case RMM: constructSparkLopsRMM(); break; case PMM: - constructSparkLopsPMM(); + constructSparkLopsPMM(); break; case ZIPMM: constructSparkLopsZIPMM(); break; - + default: - throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing SPARK lops."); + throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing SPARK lops."); } } - } - else + } else throw new HopsException(this.printErrorLocation() + "Invalid operation in AggBinary Hop, aggBin(" + innerOp + "," + outerOp + ") while constructing lops."); - + //add reblock/checkpoint lops if necessary constructAndSetLopsDataFlowProperties(); @@ -260,30 +254,28 @@ public class AggBinaryOp extends MultiThreadedHop { @Override public String getOpString() { //ba - binary aggregate, for consistency with runtime - return "ba(" + outerOp.toString() + innerOp.toString()+")"; + return "ba(" + outerOp.toString() + innerOp.toString() + ")"; } - + @Override - public void computeMemEstimate(MemoTable memo) - { + public void computeMemEstimate(MemoTable memo) { //extension of default compute memory estimate in order to //account for smaller tsmm memory requirements. 
super.computeMemEstimate(memo); - + //tsmm left is guaranteed to require only X but not t(X), while //tsmm right might have additional requirements to transpose X if sparse //NOTE: as a heuristic this correction is only applied if not a column vector because //most other vector operations require memory for at least two vectors (we aim for //consistency in order to prevent anomalies in parfor opt leading to small degree of par) MMTSJType mmtsj = checkTransposeSelf(); - if( mmtsj.isLeft() && getInput().get(1).dimsKnown() && getInput().get(1).getDim2()>1 ) { + if (mmtsj.isLeft() && getInput().get(1).dimsKnown() && getInput().get(1).getDim2() > 1) { _memEstimate = _memEstimate - getInput().get(0)._outputMemEstimate; } } @Override - protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) - { + protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) { //NOTES: // * The estimate for transpose-self is the same as for normal matrix multiplications // because (1) this decouples the decision of TSMM over default MM and (2) some cases @@ -314,10 +306,9 @@ public class AggBinaryOp extends MultiThreadedHop { return ret; } - + @Override - protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz ) - { + protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) { double ret = 0; if (isGPUEnabled()) { @@ -327,277 +318,254 @@ public class AggBinaryOp extends MultiThreadedHop { double in2Sparsity = OptimizerUtils.getSparsity(in2.getDim1(), in2.getDim2(), in2.getNnz()); boolean in1Sparse = in1Sparsity < MatrixBlock.SPARSITY_TURN_POINT; boolean in2Sparse = in2Sparsity < MatrixBlock.SPARSITY_TURN_POINT; - if(in1Sparse && !in2Sparse) { + if (in1Sparse && !in2Sparse) { // Only in sparse-dense cases, we need additional memory budget for GPU ret += OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0); } } //account for potential final dense-sparse transformation (worst-case sparse representation) - if( dim2 >= 2 && nnz != 0 ) //vectors always dense + if (dim2 >= 2 && nnz != 0) //vectors always dense ret += MatrixBlock.estimateSizeSparseInMemory(dim1, dim2, - MatrixBlock.SPARSITY_TURN_POINT - UtilFunctions.DOUBLE_EPS); - + MatrixBlock.SPARSITY_TURN_POINT - UtilFunctions.DOUBLE_EPS); + return ret; } - + @Override - protected DataCharacteristics inferOutputCharacteristics( MemoTable memo ) - { + protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) { DataCharacteristics[] dc = memo.getAllInputStats(getInput()); DataCharacteristics ret = null; - if( dc[0].rowsKnown() && dc[1].colsKnown() ) { + if (dc[0].rowsKnown() && dc[1].colsKnown()) { ret = new MatrixCharacteristics(dc[0].getRows(), dc[1].getCols()); - double sp1 = (dc[0].getNonZeros()>0) ? OptimizerUtils.getSparsity(dc[0].getRows(), dc[0].getCols(), dc[0].getNonZeros()) : 1.0; - double sp2 = (dc[1].getNonZeros()>0) ? OptimizerUtils.getSparsity(dc[1].getRows(), dc[1].getCols(), dc[1].getNonZeros()) : 1.0; - ret.setNonZeros((long)(ret.getLength() * OptimizerUtils.getMatMultSparsity(sp1, sp2, ret.getRows(), dc[0].getCols(), ret.getCols(), true))); + double sp1 = (dc[0].getNonZeros() > 0) ? OptimizerUtils.getSparsity(dc[0].getRows(), dc[0].getCols(), dc[0].getNonZeros()) : 1.0; + double sp2 = (dc[1].getNonZeros() > 0) ? 
OptimizerUtils.getSparsity(dc[1].getRows(), dc[1].getCols(), dc[1].getNonZeros()) : 1.0; + ret.setNonZeros((long) (ret.getLength() * OptimizerUtils.getMatMultSparsity(sp1, sp2, ret.getRows(), dc[0].getCols(), ret.getCols(), true))); } return ret; } - + public boolean isMatrixMultiply() { - return ( this.innerOp == OpOp2.MULT && this.outerOp == AggOp.SUM ); + return (this.innerOp == OpOp2.MULT && this.outerOp == AggOp.SUM); } - + private boolean isOuterProduct() { - return ( getInput().get(0).isVector() && getInput().get(1).isVector() ) - && ( getInput().get(0).getDim1() == 1 && getInput().get(0).getDim1() > 1 - && getInput().get(1).getDim1() > 1 && getInput().get(1).getDim2() == 1 ); + return (getInput().get(0).isVector() && getInput().get(1).isVector()) + && (getInput().get(0).getDim1() == 1 && getInput().get(0).getDim1() > 1 + && getInput().get(1).getDim1() > 1 && getInput().get(1).getDim2() == 1); } - + @Override public boolean isMultiThreadedOpType() { return isMatrixMultiply(); } - + @Override - public boolean allowsAllExecTypes() - { + public boolean allowsAllExecTypes() { return true; } - + @Override - protected ExecType optFindExecType(boolean transitive) - { + protected ExecType optFindExecType(boolean transitive) { checkAndSetForcedPlatform(); - - if( _etypeForced != null ) { + + if (_etypeForced != null) { setExecType(_etypeForced); - } - else - { - if ( OptimizerUtils.isMemoryBasedOptLevel() ) { + } else { + if (OptimizerUtils.isMemoryBasedOptLevel()) { setExecType(findExecTypeByMemEstimate()); } // choose CP if the dimensions of both inputs are below Hops.CPThreshold // OR if it is vector-vector inner product - else if ( (getInput().get(0).areDimsBelowThreshold() && getInput().get(1).areDimsBelowThreshold()) - || (getInput().get(0).isVector() && getInput().get(1).isVector() && !isOuterProduct()) ) - { + else if ((getInput().get(0).areDimsBelowThreshold() && getInput().get(1).areDimsBelowThreshold()) + || (getInput().get(0).isVector() && getInput().get(1).isVector() && !isOuterProduct())) { setExecType(ExecType.CP); - } - else - { + } else { setExecType(ExecType.SPARK); } - + //check for valid CP mmchain, send invalid memory requirements to remote - if( _etype == ExecType.CP - && checkMapMultChain() != ChainType.NONE - && OptimizerUtils.getLocalMemBudget() < - getInput().get(0).getInput().get(0).getOutputMemEstimate() ) { + if (_etype == ExecType.CP + && checkMapMultChain() != ChainType.NONE + && OptimizerUtils.getLocalMemBudget() < + getInput().get(0).getInput().get(0).getOutputMemEstimate()) { setExecType(ExecType.SPARK); } - + //check for valid CP dimensions and matrix size checkAndSetInvalidCPDimsAndSize(); } - + //spark-specific decision refinement (execute binary aggregate w/ left or right spark input and //single parent also in spark because it's likely cheap and reduces data transfer) MMTSJType mmtsj = checkTransposeSelf(); //determine tsmm pattern - if( transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP - && ((!mmtsj.isLeft() && isApplicableForTransitiveSparkExecType(true)) - || ( !mmtsj.isRight() && isApplicableForTransitiveSparkExecType(false))) ) - { + if (transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP + && ((!mmtsj.isLeft() && isApplicableForTransitiveSparkExecType(true)) + || (!mmtsj.isRight() && isApplicableForTransitiveSparkExecType(false)))) { //pull binary aggregate into spark setExecType(ExecType.SPARK); } //mark for recompile (forever) setRequiresRecompileIfNecessary(); - + return _etype; } - - private boolean 
isApplicableForTransitiveSparkExecType(boolean left) - { + + private boolean isApplicableForTransitiveSparkExecType(boolean left) { int index = left ? 0 : 1; - return !(getInput(index) instanceof DataOp && ((DataOp)getInput(index)).requiresCheckpoint()) - && (!HopRewriteUtils.isTransposeOperation(getInput(index)) + return !(getInput(index) instanceof DataOp && ((DataOp) getInput(index)).requiresCheckpoint()) + && (!HopRewriteUtils.isTransposeOperation(getInput(index)) || (left && !isLeftTransposeRewriteApplicable(true))) - && getInput(index).getParent().size()==1 //bagg is only parent - && !getInput(index).areDimsBelowThreshold() - && (getInput(index).optFindExecType() == ExecType.SPARK - || (getInput(index) instanceof DataOp && ((DataOp)getInput(index)).hasOnlyRDD())) - && getInput(index).getOutputMemEstimate()>getOutputMemEstimate(); + && getInput(index).getParent().size() == 1 //bagg is only parent + && !getInput(index).areDimsBelowThreshold() + && (getInput(index).optFindExecType() == ExecType.SPARK + || (getInput(index) instanceof DataOp && ((DataOp) getInput(index)).hasOnlyRDD())) + && getInput(index).getOutputMemEstimate() > getOutputMemEstimate(); } - + /** * TSMM: Determine if XtX pattern applies for this aggbinary and if yes - * which type. - * + * which type. + * * @return MMTSJType */ - public MMTSJType checkTransposeSelf() - { + public MMTSJType checkTransposeSelf() { MMTSJType ret = MMTSJType.NONE; - + Hop in1 = getInput().get(0); Hop in2 = getInput().get(1); - - if( HopRewriteUtils.isTransposeOperation(in1) - && in1.getInput().get(0) == in2 ) - { + + if (HopRewriteUtils.isTransposeOperation(in1) + && in1.getInput().get(0) == in2) { ret = MMTSJType.LEFT; } - - if( HopRewriteUtils.isTransposeOperation(in2) - && in2.getInput().get(0) == in1 ) - { + + if (HopRewriteUtils.isTransposeOperation(in2) + && in2.getInput().get(0) == in1) { ret = MMTSJType.RIGHT; } - + return ret; } /** - * MapMultChain: Determine if XtwXv/XtXv pattern applies for this aggbinary - * and if yes which type. - * + * MapMultChain: Determine if XtwXv/XtXv pattern applies for this aggbinary + * and if yes which type. 
+ * * @return ChainType */ - public ChainType checkMapMultChain() - { + public ChainType checkMapMultChain() { ChainType chainType = ChainType.NONE; - + Hop in1 = getInput().get(0); Hop in2 = getInput().get(1); - + //check for transpose left input (both chain types) - if( HopRewriteUtils.isTransposeOperation(in1) ) - { + if (HopRewriteUtils.isTransposeOperation(in1)) { Hop X = in1.getInput().get(0); - + //check mapmultchain patterns //t(X)%*%(w*(X%*%v)) - if( in2 instanceof BinaryOp && ((BinaryOp)in2).getOp()==OpOp2.MULT ) - { + if (in2 instanceof BinaryOp && ((BinaryOp) in2).getOp() == OpOp2.MULT) { Hop in3b = in2.getInput().get(1); - if( in3b instanceof AggBinaryOp ) - { + if (in3b instanceof AggBinaryOp) { Hop in4 = in3b.getInput().get(0); - if( X == in4 ) //common input + if (X == in4) //common input chainType = ChainType.XtwXv; } } //t(X)%*%((X%*%v)-y) - else if( in2 instanceof BinaryOp && ((BinaryOp)in2).getOp()==OpOp2.MINUS ) - { + else if (in2 instanceof BinaryOp && ((BinaryOp) in2).getOp() == OpOp2.MINUS) { Hop in3a = in2.getInput().get(0); - Hop in3b = in2.getInput().get(1); - if( in3a instanceof AggBinaryOp && in3b.getDataType()==DataType.MATRIX ) - { + Hop in3b = in2.getInput().get(1); + if (in3a instanceof AggBinaryOp && in3b.getDataType() == DataType.MATRIX) { Hop in4 = in3a.getInput().get(0); - if( X == in4 ) //common input + if (X == in4) //common input chainType = ChainType.XtXvy; } } //t(X)%*%(X%*%v) - else if( in2 instanceof AggBinaryOp ) - { + else if (in2 instanceof AggBinaryOp) { Hop in3 = in2.getInput().get(0); - if( X == in3 ) //common input + if (X == in3) //common input chainType = ChainType.XtXv; } } - + return chainType; } - + ////////////////////////// // CP Lops generation - ///////////////////////// - - private void constructCPLopsTSMM( MMTSJType mmtsj, ExecType et ) { + + /// ////////////////////// + + private void constructCPLopsTSMM(MMTSJType mmtsj, ExecType et) { int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); - Lop matmultCP = new MMTSJ(getInput().get(mmtsj.isLeft()?1:0).constructLops(), - getDataType(), getValueType(), et, mmtsj, false, k); + Lop matmultCP = new MMTSJ(getInput().get(mmtsj.isLeft() ? 1 : 0).constructLops(), + getDataType(), getValueType(), et, mmtsj, false, k); matmultCP.getOutputParameters().setDimensions(getDim1(), getDim2(), getBlocksize(), getNnz()); - setLineNumbers( matmultCP ); + setLineNumbers(matmultCP); setLops(matmultCP); } - private void constructCPLopsMMChain( ChainType chain ) - { + private void constructCPLopsMMChain(ChainType chain) { MapMultChain mapmmchain = null; - if( chain == ChainType.XtXv ) { + if (chain == ChainType.XtXv) { Hop hX = getInput().get(0).getInput().get(0); Hop hv = getInput().get(1).getInput().get(1); - mapmmchain = new MapMultChain( hX.constructLops(), hv.constructLops(), getDataType(), getValueType(), ExecType.CP); - } - else { //ChainType.XtwXv / ChainType.XtwXvy + mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), getDataType(), getValueType(), ExecType.CP); + } else { //ChainType.XtwXv / ChainType.XtwXvy int wix = (chain == ChainType.XtwXv) ? 0 : 1; int vix = (chain == ChainType.XtwXv) ? 
1 : 0; Hop hX = getInput().get(0).getInput().get(0); Hop hw = getInput().get(1).getInput().get(wix); Hop hv = getInput().get(1).getInput().get(vix).getInput().get(1); - mapmmchain = new MapMultChain( hX.constructLops(), hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(), ExecType.CP); + mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(), ExecType.CP); } - + //set degree of parallelism int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); - mapmmchain.setNumThreads( k ); - + mapmmchain.setNumThreads(k); + //set basic lop properties setOutputDimensions(mapmmchain); setLineNumbers(mapmmchain); setLops(mapmmchain); } - + /** * NOTE: exists for consistency since removeEmtpy might be scheduled to MR - * but matrix mult on small output might be scheduled to CP. Hence, we + * but matrix mult on small output might be scheduled to CP. Hence, we * need to handle directly passed selection vectors in CP as well. */ - private void constructCPLopsPMM() - { + private void constructCPLopsPMM() { Hop pmInput = getInput().get(0); Hop rightInput = getInput().get(1); - + Hop nrow = HopRewriteUtils.createValueHop(pmInput, true); //NROW nrow.setBlocksize(0); nrow.setForcedExecType(ExecType.CP); HopRewriteUtils.copyLineNumbers(this, nrow); Lop lnrow = nrow.constructLops(); - + PMMJ pmm = new PMMJ(pmInput.constructLops(), rightInput.constructLops(), lnrow, getDataType(), getValueType(), false, false, ExecType.CP); - + //set degree of parallelism int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); pmm.setNumThreads(k); - + pmm.getOutputParameters().setDimensions(getDim1(), getDim2(), getBlocksize(), getNnz()); setLineNumbers(pmm); - + setLops(pmm); - + HopRewriteUtils.removeChildReference(pmInput, nrow); } - private void constructCPLopsMM(ExecType et) - { + private void constructCPLopsMM(ExecType et) { Lop matmultCP = null; String cla = ConfigurationManager.getDMLConfig().getTextValue("sysds.compressed.linalg"); if (et == ExecType.GPU) { @@ -610,72 +578,85 @@ public class AggBinaryOp extends MultiThreadedHop { boolean leftTrans = false; // HopRewriteUtils.isTransposeOperation(h1); boolean rightTrans = false; // HopRewriteUtils.isTransposeOperation(h2); Lop left = !leftTrans ? h1.constructLops() : - h1.getInput().get(0).constructLops(); + h1.getInput().get(0).constructLops(); Lop right = !rightTrans ? h2.constructLops() : - h2.getInput().get(0).constructLops(); + h2.getInput().get(0).constructLops(); matmultCP = new MatMultCP(left, right, getDataType(), getValueType(), et, leftTrans, rightTrans); setOutputDimensions(matmultCP); - } - else if (cla.equals("true") || cla.equals("cost")){ + } else if (cla.equals("true") || cla.equals("cost")) { Hop h1 = getInput().get(0); Hop h2 = getInput().get(1); int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); boolean leftTrans = HopRewriteUtils.isTransposeOperation(h1); - boolean rightTrans = HopRewriteUtils.isTransposeOperation(h2); + boolean rightTrans = HopRewriteUtils.isTransposeOperation(h2); Lop left = !leftTrans ? h1.constructLops() : - h1.getInput().get(0).constructLops(); + h1.getInput().get(0).constructLops(); Lop right = !rightTrans ? 
h2.constructLops() :
-                h2.getInput().get(0).constructLops();
+                    h2.getInput().get(0).constructLops();
             matmultCP = new MatMultCP(left, right, getDataType(), getValueType(), et, k, leftTrans, rightTrans);
-        }
-        else {
-            if( isLeftTransposeRewriteApplicable(true) ) {
+        } else {
+            if (isLeftTransposeRewriteApplicable(true)) {
                 matmultCP = constructCPLopsMMWithLeftTransposeRewrite(et);
-            }
-            else {
+            } else {
                 int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                 matmultCP = new MatMultCP(getInput().get(0).constructLops(),
-                    getInput().get(1).constructLops(), getDataType(), getValueType(), et, k);
+                        getInput().get(1).constructLops(), getDataType(), getValueType(), et, k);
                 updateLopFedOut(matmultCP);
             }
             setOutputDimensions(matmultCP);
         }
-        
+
         setLineNumbers(matmultCP);
         setLops(matmultCP);
     }
 
-    private Lop constructCPLopsMMWithLeftTransposeRewrite(ExecType et)
-    {
-        Hop X = getInput().get(0).getInput().get(0); //guaranteed to exists
+    private Lop constructCPLopsMMWithLeftTransposeRewrite(ExecType et) {
+        Hop X = getInput().get(0).getInput().get(0); // guaranteed to exist
         Hop Y = getInput().get(1);
         int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
-        
+
+        //Check if X is already a transpose operation
+        boolean isXTransposed = X instanceof ReorgOp && ((ReorgOp)X).getOp() == ReOrgOp.TRANS;
+        Hop actualX = isXTransposed ? X.getInput().get(0) : X;
+
+        //Check if Y is a transpose operation
+        boolean isYTransposed = Y instanceof ReorgOp && ((ReorgOp)Y).getOp() == ReOrgOp.TRANS;
+        Hop actualY = isYTransposed ? Y.getInput().get(0) : Y;
+
+        //Handle Y or actualY for transpose
+        Lop yLop = isYTransposed ? actualY.constructLops() : Y.constructLops();
+        ExecType inputReorgExecType = (Y.hasFederatedOutput()) ? ExecType.FED : ExecType.CP;
+
         //right vector transpose
-        Lop lY = Y.constructLops();
-        ExecType inputReorgExecType = ( Y.hasFederatedOutput() ) ? ExecType.FED : ExecType.CP;
-        Lop tY = (lY instanceof Transform && ((Transform)lY).getOp()==ReOrgOp.TRANS ) ?
-            lY.getInputs().get(0) : //if input is already a transpose, avoid redundant transpose ops
-            new Transform(lY, ReOrgOp.TRANS, getDataType(), getValueType(), inputReorgExecType, k);
-        tY.getOutputParameters().setDimensions(Y.getDim2(), Y.getDim1(), getBlocksize(), Y.getNnz());
+        Lop tY = (yLop instanceof Transform && ((Transform)yLop).getOp() == ReOrgOp.TRANS) ?
+            yLop.getInputs().get(0) : //if input is already a transpose, avoid redundant transpose ops
+            new Transform(yLop, ReOrgOp.TRANS, getDataType(), getValueType(), inputReorgExecType, k);
+
+        //Set dimensions for tY
+        long tYRows = isYTransposed ? actualY.getDim1() : Y.getDim2();
+        long tYCols = isYTransposed ? actualY.getDim2() : Y.getDim1();
+        tY.getOutputParameters().setDimensions(tYRows, tYCols, getBlocksize(), Y.getNnz());
         setLineNumbers(tY);
         if (Y.hasFederatedOutput())
             updateLopFedOut(tY);
-        
+
+        //Construct X lops for matrix multiplication
+        Lop xLop = isXTransposed ? actualX.constructLops() : X.constructLops();
+
         //matrix mult
-        Lop mult = new MatMultCP(tY, X.constructLops(), getDataType(), getValueType(), et, k); //CP or FED
-        mult.getOutputParameters().setDimensions(Y.getDim2(), X.getDim2(), getBlocksize(), getNnz());
+        Lop mult = new MatMultCP(tY, xLop, getDataType(), getValueType(), et, k);
+        mult.getOutputParameters().setDimensions(tYRows, isXTransposed ? actualX.getDim1() : X.getDim2(), getBlocksize(), getNnz());
         mult.setFederatedOutput(_federatedOutput);
         setLineNumbers(mult);
 
         //result transpose (dimensions set outside)
-        ExecType outTransposeExecType = ( _federatedOutput == FederatedOutput.FOUT ) ?
-            ExecType.FED : ExecType.CP;
+        ExecType outTransposeExecType = (_federatedOutput == FederatedOutput.FOUT) ?
+            ExecType.FED : ExecType.CP;
         Lop out = new Transform(mult, ReOrgOp.TRANS, getDataType(), getValueType(), outTransposeExecType, k);
         return out;
     }
-    
+
     //////////////////////////
     // Spark Lops generation
     /////////////////////////
@@ -718,25 +699,25 @@
     {
         Hop X = getInput().get(0).getInput().get(0); //guaranteed to exist
         Hop Y = getInput().get(1);
-        
+
         //right vector transpose
         Lop tY = new Transform(Y.constructLops(), ReOrgOp.TRANS, getDataType(), getValueType(), ExecType.CP);
         tY.getOutputParameters().setDimensions(Y.getDim2(), Y.getDim1(), getBlocksize(), Y.getNnz());
         setLineNumbers(tY);
-        
+
         //matrix mult spark
-        boolean needAgg = requiresAggregation(MMultMethod.MAPMM_R); 
+        boolean needAgg = requiresAggregation(MMultMethod.MAPMM_R);
         SparkAggType aggtype = getSparkMMAggregationType(needAgg);
-        _outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this); 
-        
-        Lop mult = new MapMult( tY, X.constructLops(), getDataType(), getValueType(), 
-            false, false, _outputEmptyBlocks, aggtype); 
+        _outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
+
+        Lop mult = new MapMult( tY, X.constructLops(), getDataType(), getValueType(),
+            false, false, _outputEmptyBlocks, aggtype);
         mult.getOutputParameters().setDimensions(Y.getDim2(), X.getDim2(), getBlocksize(), getNnz());
         setLineNumbers(mult);
-        
+
         //result transpose (dimensions set outside)
         Lop out = new Transform(mult, ReOrgOp.TRANS, getDataType(), getValueType(), ExecType.CP);
-        
+
         return out;
     }
@@ -892,13 +873,13 @@
         setLineNumbers( zipmm );
         setLops(zipmm);
     }
-    
+
     /**
      * Determines if the rewrite t(X)%*%Y -> t(t(Y)%*%X) is applicable
     * and cost effective. Whenever X is a wide matrix and Y is a vector
     * this has huge impact, because the transpose of X would dominate
    * the entire operation costs.
- * + * * @param CP true if CP * @return true if left transpose rewrite applicable */ @@ -910,38 +891,38 @@ public class AggBinaryOp extends MultiThreadedHop { { return false; } - + boolean ret = false; Hop h1 = getInput().get(0); Hop h2 = getInput().get(1); - + //check for known dimensions and cost for t(X) vs t(v) + t(tvX) //(for both CP/MR, we explicitly check that new transposes fit in memory, //even a ba in CP does not imply that both transposes can be executed in CP) - if( CP ) //in-memory ba + if( CP ) //in-memory ba { if( HopRewriteUtils.isTransposeOperation(h1) ) { long m = h1.getDim1(); long cd = h1.getDim2(); long n = h2.getDim2(); - + //check for known dimensions (necessary condition for subsequent checks) - ret = (m>0 && cd>0 && n>0); - - //check operation memory with changed transpose (this is important if we have + ret = (m>0 && cd>0 && n>0); + + //check operation memory with changed transpose (this is important if we have //e.g., t(X) %*% v, where X is sparse and tX fits in memory but X does not double memX = h1.getInput().get(0).getOutputMemEstimate(); double memtv = OptimizerUtils.estimateSizeExactSparsity(n, cd, 1.0); double memtXv = OptimizerUtils.estimateSizeExactSparsity(n, m, 1.0); double newMemEstimate = memtv + memX + memtXv; ret &= ( newMemEstimate < OptimizerUtils.getLocalMemBudget() ); - + //check for cost benefit of t(X) vs t(v) + t(tvX) and memory of additional transpose ops ret &= ( m*cd > (cd*n + m*n) && - 2 * OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) < OptimizerUtils.getLocalMemBudget() && - 2 * OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) < OptimizerUtils.getLocalMemBudget() ); - + 2 * OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) < OptimizerUtils.getLocalMemBudget() && + 2 * OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) < OptimizerUtils.getLocalMemBudget() ); + //update operation memory estimate (e.g., for parfor optimizer) if( ret ) _memEstimate = newMemEstimate; @@ -955,14 +936,14 @@ public class AggBinaryOp extends MultiThreadedHop { long n = h2.getDim2(); //note: output size constraint for mapmult already checked by optfindmmultmethod if( m>0 && cd>0 && n>0 && (m*cd > (cd*n + m*n)) && - 2 * OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) < OptimizerUtils.getLocalMemBudget() && - 2 * OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) < OptimizerUtils.getLocalMemBudget() ) + 2 * OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) < OptimizerUtils.getLocalMemBudget() && + 2 * OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) < OptimizerUtils.getLocalMemBudget() ) { ret = true; } } } - + return ret; } diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java index 2841256607..05e8d171b7 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + package org.apache.sysds.hops.fedplanner; import org.apache.commons.lang3.tuple.Pair; diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteTransposeTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteTransposeTest.java new file mode 100644 index 0000000000..ac28b12caf --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteTransposeTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; +import org.junit.Test; +import java.util.HashMap; + +public class RewriteTransposeTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "RewriteTransposeCase1"; // t(X)%*%Y + private final static String TEST_NAME2 = "RewriteTransposeCase2"; // X=t(A); t(X)%*%Y + private final static String TEST_NAME3 = "RewriteTransposeCase3"; // Y=t(A); t(X)%*%Y + + private final static String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteTransposeTest.class.getSimpleName() + "/"; + + private static final double eps = 1e-9; + + @Override + public void setUp() { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION=false; + + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"R"})); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[]{"R"})); + } + + @Test + public void testTransposeRewrite1CP() { + runTransposeRewriteTest(TEST_NAME1, false); + } + + @Test + public void testTransposeRewrite2CP() { + runTransposeRewriteTest(TEST_NAME2, true); + } + + @Test + public void testTransposeRewrite3CP() { + runTransposeRewriteTest(TEST_NAME3, false); + } + + private void runTransposeRewriteTest(String testname, boolean expectedMerge) { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = 
SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + + programArgs = new String[]{"-explain", "-stats", "-args", output("R")}; + + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(expectedDir()); + + runTest(true, false, null, -1); + runRScript(true); + + HashMap<MatrixValue.CellIndex, Double> dmlOutput = readDMLMatrixFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> rOutput = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlOutput, rOutput, eps, "Stat-DML", "Stat-R"); + + Assert.assertTrue(Statistics.getCPHeavyHitterCount("r'") <= 2); + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase1.R b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.R new file mode 100644 index 0000000000..5b0e19dca2 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.R @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +library("Matrix") +library("matrixStats") + +X <- matrix(seq(1, 20), nrow=4, ncol=5, byrow=TRUE) +Y <- matrix(seq(1, 12), nrow=4, ncol=3, byrow=TRUE) + +R <- t(t(Y)%*%X) + +writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase1.dml b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.dml new file mode 100644 index 0000000000..83cfb65dc6 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X = matrix(seq(1, 20), rows=4, cols=5); +Y = matrix(seq(1, 12), rows=4, cols=3); + +R = t(X)%*%Y; + +write(R, $1); \ No newline at end of file diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase2.R b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.R new file mode 100644 index 0000000000..fea8c26669 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.R @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +library("Matrix") +library("matrixStats") +A = matrix(seq(1, 20), nrow=5, ncol=4, byrow=TRUE) +Y = matrix(seq(1, 12), nrow=4, ncol=3, byrow=TRUE) +X = t(A) + +R <- t(t(Y)%*%X) + +writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase2.dml b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.dml new file mode 100644 index 0000000000..cb9332423b --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = matrix(seq(1, 20), rows=5, cols=4); +Y = matrix(seq(1, 12), rows=4, cols=3); +X = t(A); + +R = t(X) %*% Y; + +write(R, $1); \ No newline at end of file diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase3.R b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.R new file mode 100644 index 0000000000..2bdd22f674 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +library("Matrix") +library("matrixStats") + +X <- matrix(seq(1, 20), nrow=4, ncol=5, byrow=TRUE) +A <- matrix(seq(1, 12), nrow=3, ncol=4, byrow=TRUE) +Y <- t(A) + +R <- t(t(Y)%*%X) + +writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase3.dml b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.dml new file mode 100644 index 0000000000..2e26920aed --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1, 20), rows=4, cols=5); +A = matrix(seq(1, 12), rows=3, cols=4); +Y = t(A); + +R = t(X) %*% Y; + +write(R, $1); \ No newline at end of file
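
For reference, the three test scripts above pin down one rewrite in three shapes: the left-transpose rewrite t(X) %*% Y -> t(t(Y) %*% X) (applied only when deemed cost-effective) must not leave a redundant transpose behind when an operand is itself already a transpose. A minimal self-contained DML sketch of the three patterns, with illustrative variable names that are not part of the commit:

# Case 1: plain left-transpose matmult
X1 = matrix(seq(1, 20), rows=4, cols=5);
Y1 = matrix(seq(1, 12), rows=4, cols=3);
R1 = t(X1) %*% Y1;    # candidate for t(t(Y1) %*% X1)

# Case 2: the left operand is itself a transpose, so t(X2) == A2
# and no physical transpose of A2 should remain
A2 = matrix(seq(1, 20), rows=5, cols=4);
Y2 = matrix(seq(1, 12), rows=4, cols=3);
X2 = t(A2);
R2 = t(X2) %*% Y2;

# Case 3: the right operand is itself a transpose; the rewrite
# must not stack a redundant t(t(A3)) on the right-hand side
X3 = matrix(seq(1, 20), rows=4, cols=5);
A3 = matrix(seq(1, 12), rows=3, cols=4);
Y3 = t(A3);
R3 = t(X3) %*% Y3;

print(sum(R1) + sum(R2) + sum(R3));

With -stats, each case should surface at most two r' (transpose) instructions among the heavy hitters, which is what RewriteTransposeTest asserts via Statistics.getCPHeavyHitterCount("r'") <= 2.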