[SYSTEMML-1964] Extended codegen outer template and reworked close types

This patch makes a major change to the OFMC conditions of the codegen outer template in order to increase its applicability, which is crucial for exploiting sparsity in algorithms such as ALS-CG. To guarantee correctness, this patch also cleans up the close types used during candidate exploration and consolidates the redundant evaluation of valid entry points during candidate selection.
Furthermore, this patch also improves the code generation of sparse binary nodes and outer templates with neq 0 on the main input. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/2ca2d8aa Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/2ca2d8aa Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/2ca2d8aa Branch: refs/heads/master Commit: 2ca2d8aa73c4c0463a52d7f299320fc9b3865aea Parents: 586f822 Author: Matthias Boehm <mboe...@gmail.com> Authored: Mon Oct 16 13:44:40 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Mon Oct 16 15:38:58 2017 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/hops/OptimizerUtils.java | 5 ++- .../sysml/hops/codegen/SpoofCompiler.java | 41 ++++++++++++++++++- .../sysml/hops/codegen/cplan/CNodeBinary.java | 12 +++--- .../sysml/hops/codegen/opt/PlanAnalyzer.java | 3 +- .../sysml/hops/codegen/opt/PlanSelection.java | 19 +-------- .../codegen/opt/PlanSelectionFuseCostBased.java | 2 +- .../opt/PlanSelectionFuseCostBasedV2.java | 2 +- .../opt/PlanSelectionFuseNoRedundancy.java | 2 +- .../hops/codegen/template/CPlanMemoTable.java | 42 +++++++++++++------- .../hops/codegen/template/TemplateBase.java | 31 ++++++++++----- .../hops/codegen/template/TemplateCell.java | 11 +++-- .../hops/codegen/template/TemplateMultiAgg.java | 6 +-- .../codegen/template/TemplateOuterProduct.java | 20 ++++++---- .../hops/codegen/template/TemplateRow.java | 8 ++-- .../hops/codegen/template/TemplateUtils.java | 23 ++++++----- .../sysml/hops/rewrite/HopRewriteUtils.java | 12 +++++- .../functions/codegen/MiscPatternTest.java | 41 +++++++++++++++++-- .../scripts/functions/codegen/miscPattern3.R | 34 ++++++++++++++++ .../scripts/functions/codegen/miscPattern3.dml | 34 ++++++++++++++++ .../scripts/functions/codegen/miscPattern4.R | 35 ++++++++++++++++ .../scripts/functions/codegen/miscPattern4.dml | 35 
++++++++++++++++ 21 files changed, 328 insertions(+), 90 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/OptimizerUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java index e44e439..5d831e5 100644 --- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java @@ -1017,7 +1017,10 @@ public class OptimizerUtils ||(op==OpOp2.LESS && val==0) ||(op==OpOp2.NOTEQUAL && val==0) ||(op==OpOp2.EQUAL && val!=0) - ||(op==OpOp2.MINUS && val==0)); + ||(op==OpOp2.MINUS && val==0) + ||(op==OpOp2.PLUS && val==0) + ||(op==OpOp2.MAX && val<=0) + ||(op==OpOp2.MIN && val>=0)); } public static double getBinaryOpSparsityConditionalSparseSafe( double sp1, OpOp2 op, LiteralOp lit ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index 0e5e194..5ff90fb 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -39,6 +39,7 @@ import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.conf.DMLConfig; import org.apache.sysml.hops.codegen.cplan.CNode; +import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; import org.apache.sysml.hops.codegen.cplan.CNodeCell; import org.apache.sysml.hops.codegen.cplan.CNodeData; import org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg; @@ -52,6 +53,7 @@ import 
org.apache.sysml.hops.codegen.opt.PlanSelectionFuseCostBased; import org.apache.sysml.hops.codegen.opt.PlanSelectionFuseCostBasedV2; import org.apache.sysml.hops.codegen.opt.PlanSelectionFuseNoRedundancy; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; +import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; import org.apache.sysml.hops.codegen.template.TemplateBase; import org.apache.sysml.hops.codegen.template.TemplateBase.CloseType; import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; @@ -66,6 +68,7 @@ import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.hops.rewrite.ProgramRewriteStatus; @@ -535,8 +538,7 @@ public class SpoofCompiler CloseType ccode = tpl.close(hop); if( ccode == CloseType.CLOSED_INVALID ) iter.remove(); - else if( ccode == CloseType.CLOSED_VALID ) - me.closed = true; + me.ctype = ccode; } } @@ -721,6 +723,10 @@ public class SpoofCompiler else rFindAndRemoveLookup(tpl.getOutput(), in1, !(tpl instanceof CNodeRow)); + //remove unnecessary neq 0 on main input of outer template + if( tpl instanceof CNodeOuterProduct ) + rFindAndRemoveBinaryMS(tpl.getOutput(), in1, BinType.NOTEQUAL, "0", "1"); + //remove invalid row templates (e.g., unsatisfied blocksize constraint) if( tpl instanceof CNodeRow ) { //check for invalid row cplan over column vector @@ -800,6 +806,37 @@ public class SpoofCompiler } } + @SuppressWarnings("unused") + private static void rFindAndRemoveUnary(CNode node, CNodeData mainInput, UnaryType type) { + for( int i=0; i<node.getInput().size(); i++ ) { + CNode tmp = node.getInput().get(i); + if( TemplateUtils.isUnary(tmp, type) && tmp.getInput().get(0) instanceof CNodeData + && 
((CNodeData)tmp.getInput().get(0)).getHopID()==mainInput.getHopID() ) + { + node.getInput().set(i, tmp.getInput().get(0)); + } + else + rFindAndRemoveUnary(tmp, mainInput, type); + } + } + + private static void rFindAndRemoveBinaryMS(CNode node, CNodeData mainInput, BinType type, String lit, String replace) { + for( int i=0; i<node.getInput().size(); i++ ) { + CNode tmp = node.getInput().get(i); + if( TemplateUtils.isBinary(tmp, type) && tmp.getInput().get(1).isLiteral() + && tmp.getInput().get(1).getVarname().equals(lit) + && tmp.getInput().get(0) instanceof CNodeData + && ((CNodeData)tmp.getInput().get(0)).getHopID()==mainInput.getHopID() ) + { + CNodeData cnode = new CNodeData(new LiteralOp(replace)); + cnode.setLiteral(true); + node.getInput().set(i, cnode); + } + else + rFindAndRemoveBinaryMS(tmp, mainInput, type, lit, replace); + } + } + private static boolean rHasLookupRC1(CNode node, CNodeData mainInput, boolean includeRC1) { boolean ret = false; for( int i=0; i<node.getInput().size() && !ret; i++ ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java index cac8ab8..d188afd 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java @@ -70,13 +70,13 @@ public class CNodeBinary extends CNode public String getTemplate(boolean sparseLhs, boolean sparseRhs, boolean scalarVector, boolean scalarInput) { switch (this) { - case DOT_PRODUCT: + case DOT_PRODUCT: return sparseLhs ? 
" double %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n"; - case VECT_MATRIXMULT: + case VECT_MATRIXMULT: return sparseLhs ? " double[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : " double[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n"; - case VECT_OUTERMULT_ADD: + case VECT_OUTERMULT_ADD: return sparseLhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n"; @@ -110,7 +110,7 @@ public class CNodeBinary extends CNode case VECT_PLUS_SCALAR: case VECT_POW_SCALAR: case VECT_MIN_SCALAR: - case VECT_MAX_SCALAR: + case VECT_MAX_SCALAR: case VECT_EQUAL_SCALAR: case VECT_NOTEQUAL_SCALAR: case VECT_LESS_SCALAR: @@ -119,7 +119,7 @@ public class CNodeBinary extends CNode case VECT_GREATEREQUAL_SCALAR: { String vectName = getVectorPrimitiveName(); if( scalarVector ) - return sparseLhs ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : + return sparseRhs ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS2%, %LEN%);\n"; else return sparseLhs ? 
" double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : @@ -274,7 +274,7 @@ public class CNodeBinary extends CNode boolean lsparseLhs = sparse && _inputs.get(0) instanceof CNodeData && _inputs.get(0).getVarname().startsWith("a"); boolean lsparseRhs = sparse && _inputs.get(1) instanceof CNodeData - && _inputs.get(1).getVarname().startsWith("a"); + && _inputs.get(1).getVarname().startsWith("a"); boolean scalarInput = _inputs.get(0).getDataType().isScalar(); boolean scalarVector = (_inputs.get(0).getDataType().isScalar() && _inputs.get(1).getDataType().isMatrix()); http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java index db1ee4d..9910814 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java @@ -266,8 +266,7 @@ public class PlanAnalyzer long[] refs = memo.getAllRefs(hopID); for( int i=0; i<3; i++ ) { if( refs[i] < 0 ) continue; - List<TemplateType> tmp = memo.getDistinctTemplateTypes(hopID, i); - + List<TemplateType> tmp = memo.getDistinctTemplateTypes(hopID, i, true); if( memo.containsNotIn(refs[i], tmp, true, true) ) ret.add(new InterestingPoint(DecisionType.TEMPLATE_CHANGE, hopID, refs[i])); } http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java index 369bf75..5242211 100644 --- 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java @@ -29,7 +29,6 @@ import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.codegen.template.CPlanMemoTable; import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; -import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.runtime.util.UtilFunctions; public abstract class PlanSelection @@ -51,22 +50,6 @@ public abstract class PlanSelection */ public abstract void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots); - /** - * Determines if the given partial fusion plan is a valid entry point - * of a fused operator. - * - * @param me memo table entry - * @param hop current hop - * @return true if entry is valid as top-level plan - */ - public static boolean isValid(MemoTableEntry me, Hop hop) { - return (me.type == TemplateType.CELL) - || (me.type == TemplateType.MAGG) - || (me.type == TemplateType.ROW && !HopRewriteUtils.isTransposeOperation(hop)) - || (me.type == TemplateType.OUTER - && (me.closed || HopRewriteUtils.isBinaryMatrixMatrixOperation(hop))); - } - protected void addBestPlan(long hopID, MemoTableEntry me) { if( me == null ) return; if( !_bestPlans.containsKey(hopID) ) @@ -108,7 +91,7 @@ public abstract class PlanSelection if( memo.contains(current.getHopID()) ) { if( currentType == null ) { best = memo.get(current.getHopID()).stream() - .filter(p -> isValid(p, current)) + .filter(p -> p.isValid()) .min(BASE_COMPARE).orElse(null); } else { http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java index d01ffe2..521ef61 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java @@ -559,7 +559,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection if( memo.contains(current.getHopID()) ) { if( currentType == null ) { best = memo.get(current.getHopID()).stream() - .filter(p -> isValid(p, current)) + .filter(p -> p.isValid()) .filter(p -> hasNoRefToMaterialization(p, M, plan)) .min(new BasicPlanComparator()).orElse(null); opened = true; http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 4b214d0..d2ed3ac 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -773,7 +773,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //use streams, lambda expressions, etc to avoid unnecessary overhead if( currentType == null ) { for( MemoTableEntry me : memo.get(currentHopId) ) - best = isValid(me, current) + best = me.isValid() && hasNoRefToMatPoint(currentHopId, me, matPoints, plan) && BasicPlanComparator.icompare(me, best)<0 ? 
me : best; opened = true; http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java index 2fc90d7..fe3789f 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java @@ -85,7 +85,7 @@ public class PlanSelectionFuseNoRedundancy extends PlanSelection if( memo.contains(current.getHopID()) ) { if( currentType == null ) { best = memo.get(current.getHopID()).stream() - .filter(p -> isValid(p, current)) + .filter(p -> p.isValid()) .min(new BasicPlanComparator()).orElse(null); } else { http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java index 30672f3..882cde2 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java @@ -39,6 +39,7 @@ import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.codegen.SpoofCompiler; import org.apache.sysml.hops.codegen.opt.InterestingPoint; import org.apache.sysml.hops.codegen.opt.PlanSelection; +import org.apache.sysml.hops.codegen.template.TemplateBase.CloseType; import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; import org.apache.sysml.runtime.util.UtilFunctions; @@ -91,7 +92,7 @@ public class CPlanMemoTable return contains(hopID, 
type[0]); Set<TemplateType> probe = UtilFunctions.asSet(type); return contains(hopID) && get(hopID).stream() - .anyMatch(p -> (!checkClose||!p.closed) && probe.contains(p.type)); + .anyMatch(p -> (!checkClose||!p.isClosed()) && probe.contains(p.type)); } public boolean containsNotIn(long hopID, Collection<TemplateType> types, @@ -99,7 +100,7 @@ public class CPlanMemoTable return contains(hopID) && get(hopID).stream() .anyMatch(p -> (!checkChildRefs || p.hasPlanRef()) && (!excludeCell || p.type!=TemplateType.CELL) - && !types.contains(p.type)); + && p.isValid() && !types.contains(p.type)); } public int countEntries(long hopID) { @@ -176,7 +177,7 @@ public class CPlanMemoTable setDistinct(hopID, _plans.get(hopID)); //prune closed templates without group references - _plans.get(hopID).removeIf(p -> p.closed && !p.hasPlanRef()); + _plans.get(hopID).removeIf(p -> p.isClosed() && !p.hasPlanRef()); //prune dominated plans (e.g., opened plan subsumed by fused plan //if single consumer of input; however this only applies to fusion @@ -268,16 +269,21 @@ public class CPlanMemoTable return Collections.emptyList(); //return distinct entries wrt type and closed attributes return _plans.get(hopID).stream() - .map(p -> TemplateUtils.createTemplate(p.type, p.closed)) + .map(p -> TemplateUtils.createTemplate(p.type, p.ctype)) .distinct().collect(Collectors.toList()); } public List<TemplateType> getDistinctTemplateTypes(long hopID, int refAt) { + return getDistinctTemplateTypes(hopID, refAt, false); + } + + public List<TemplateType> getDistinctTemplateTypes(long hopID, int refAt, boolean exclInvalOuter) { if(!contains(hopID)) return Collections.emptyList(); //return distinct template types with reference at given position return _plans.get(hopID).stream() - .filter(p -> p.isPlanRef(refAt)) + .filter(p -> p.isPlanRef(refAt) && (!exclInvalOuter + || p.type!=TemplateType.OUTER || p.isValid()) ) .map(p -> p.type) //extract type .distinct().collect(Collectors.toList()); } @@ -289,7 +295,7 
@@ public class CPlanMemoTable //single plan per type, get plan w/ best rank in preferred order //but ensure that the plans valid as a top-level plan - return tmp.stream().filter(p -> PlanSelection.isValid(p, _hopRefs.get(hopID))) + return tmp.stream().filter(p -> p.isValid()) .min(Comparator.comparing(p -> p.type.getRank())).orElse(null); } @@ -319,7 +325,7 @@ public class CPlanMemoTable for( MemoTableEntry me : get(hopID) ) for( int i=0; i<3; i++ ) if( me.isPlanRef(i) ) - refs[i] |= me.input(i); + refs[i] = me.input(i); return refs; } @@ -357,17 +363,23 @@ public class CPlanMemoTable public final long input2; public final long input3; public final int size; - public boolean closed = false; + public CloseType ctype; public MemoTableEntry(TemplateType t, long in1, long in2, long in3, int inlen) { - this(t, in1, in2, in3, inlen, false); + this(t, in1, in2, in3, inlen, CloseType.OPEN_VALID); } - public MemoTableEntry(TemplateType t, long in1, long in2, long in3, int inlen, boolean close) { + public MemoTableEntry(TemplateType t, long in1, long in2, long in3, int inlen, CloseType close) { type = t; input1 = in1; input2 = in2; input3 = in3; size = inlen; - closed = close; + ctype = close; + } + public boolean isClosed() { + return ctype.isClosed(); + } + public boolean isValid() { + return ctype.isValid(); } public boolean isPlanRef(int index) { return (index==0 && input1 >=0) @@ -404,7 +416,7 @@ public class CPlanMemoTable h = UtilFunctions.intHashCode(h, Long.hashCode(input2)); h = UtilFunctions.intHashCode(h, Long.hashCode(input3)); h = UtilFunctions.intHashCode(h, size); - h = UtilFunctions.intHashCode(h, Boolean.hashCode(closed)); + h = UtilFunctions.intHashCode(h, ctype.ordinal()); return h; } @Override @@ -414,7 +426,7 @@ public class CPlanMemoTable MemoTableEntry that = (MemoTableEntry)obj; return type == that.type && input1 == that.input1 && input2 == that.input2 && input3 == that.input3 - && size == that.size && closed == that.closed; + && size == that.size 
&& ctype == that.ctype; } @Override public String toString() { @@ -426,6 +438,8 @@ public class CPlanMemoTable sb.append(","); sb.append(input(i)); } + if( !isValid() ) + sb.append(", x"); sb.append(")"); return sb.toString(); } @@ -439,7 +453,7 @@ public class CPlanMemoTable int pos = (c != null) ? hop.getInput().indexOf(c) : -1; int size = (hop instanceof IndexingOp) ? 1 : hop.getInput().size(); plans.add(new MemoTableEntry(tpl.getType(), (pos==0)?c.getHopID():-1, - (pos==1)?c.getHopID():-1, (pos==2)?c.getHopID():-1, size, tpl.isClosed())); + (pos==1)?c.getHopID():-1, (pos==2)?c.getHopID():-1, size, tpl.getCType())); } public void crossProduct(int pos, Long... refs) { http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java index b42eecf..9d4ff9b 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java @@ -42,35 +42,46 @@ public abstract class TemplateBase } public enum CloseType { - CLOSED_VALID, - CLOSED_INVALID, - OPEN, + CLOSED_VALID, //no further fusion, valid entry point + CLOSED_INVALID, //no further fusion, invalid entry point (to be discarded) + OPEN_VALID, //further fusion allowed, valid entry point + OPEN_INVALID; //further fusion allowed, but invalid entry point + public boolean isClosed() { + return (this == CLOSED_VALID || this == CloseType.CLOSED_INVALID); + } + public boolean isValid() { + return (this == CLOSED_VALID || this == OPEN_VALID); + } } protected final TemplateType _type; - protected final boolean _closed; + protected final CloseType _ctype; protected TemplateBase(TemplateType type) { - this(type, false); + this(type, 
CloseType.OPEN_VALID); } - protected TemplateBase(TemplateType type, boolean closed) { + protected TemplateBase(TemplateType type, CloseType ctype) { _type = type; - _closed = closed; + _ctype = ctype; } public TemplateType getType() { return _type; } + public CloseType getCType() { + return _ctype; + } + public boolean isClosed() { - return _closed; + return _ctype.isClosed(); } @Override public int hashCode() { return UtilFunctions.intHashCode( - _type.ordinal(), Boolean.hashCode(_closed)); + _type.ordinal(), _ctype.ordinal()); } @Override @@ -79,7 +90,7 @@ public abstract class TemplateBase return false; TemplateBase that = (TemplateBase)obj; return _type == that._type - && _closed == that._closed; + && _ctype == that._ctype; } ///////////////////////////////////////////// http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java index 2b29ce2..c9b0734 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java @@ -63,15 +63,14 @@ public class TemplateCell extends TemplateBase super(TemplateType.CELL); } - public TemplateCell(boolean closed) { - super(TemplateType.CELL, closed); + public TemplateCell(CloseType ctype) { + super(TemplateType.CELL, ctype); } - public TemplateCell(TemplateType type, boolean closed) { - super(type, closed); + public TemplateCell(TemplateType type, CloseType ctype) { + super(type, ctype); } - @Override public boolean open(Hop hop) { return hop.dimsKnown() && isValidOperation(hop) @@ -108,7 +107,7 @@ public class TemplateCell extends TemplateBase else if( hop instanceof AggUnaryOp || hop instanceof AggBinaryOp ) return 
CloseType.CLOSED_INVALID; else - return CloseType.OPEN; + return CloseType.OPEN_VALID; } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java index ebd6078..740604c 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java @@ -42,11 +42,11 @@ import org.apache.sysml.runtime.matrix.data.Pair; public class TemplateMultiAgg extends TemplateCell { public TemplateMultiAgg() { - super(TemplateType.MAGG, false); + super(TemplateType.MAGG, CloseType.OPEN_VALID); } - public TemplateMultiAgg(boolean closed) { - super(TemplateType.MAGG, closed); + public TemplateMultiAgg(CloseType ctype) { + super(TemplateType.MAGG, ctype); } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java index 904fbb3..f3880b1 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java @@ -51,8 +51,8 @@ public class TemplateOuterProduct extends TemplateBase { super(TemplateType.OUTER); } - public TemplateOuterProduct(boolean closed) { - super(TemplateType.OUTER, closed); + public TemplateOuterProduct(CloseType ctype) { + super(TemplateType.OUTER, ctype); } @Override @@ -68,7 +68,7 @@ public class TemplateOuterProduct 
extends TemplateBase { &&((hop instanceof UnaryOp && TemplateUtils.isOperationSupported(hop)) || (hop instanceof BinaryOp && TemplateUtils.isOperationSupported(hop) && (TemplateUtils.isBinaryMatrixColVector(hop) || HopRewriteUtils.isBinaryMatrixScalarOperation(hop) - || (HopRewriteUtils.isBinaryMatrixMatrixOperation(hop) && HopRewriteUtils.isBinary(hop, OpOp2.MULT, OpOp2.DIV)) )) + || (HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)) )) || (HopRewriteUtils.isTransposeOperation(hop) && input instanceof AggBinaryOp && !HopRewriteUtils.isOuterProductLikeMM(input)) || (hop instanceof AggBinaryOp && !HopRewriteUtils.isOuterProductLikeMM(hop) @@ -89,18 +89,24 @@ public class TemplateOuterProduct extends TemplateBase { @Override public CloseType close(Hop hop) { // close on second matrix multiply (after open) or unary aggregate - if( hop instanceof AggUnaryOp && HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(0)) + if( (hop instanceof AggUnaryOp && (HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(0)) + || !HopRewriteUtils.isBinarySparseSafe(hop.getInput().get(0)))) || (hop instanceof AggBinaryOp && (HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(0)) - || HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(1)))) ) - return CloseType.CLOSED_INVALID; + || HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(1)) + || (!HopRewriteUtils.isOuterProductLikeMM(hop) + && !HopRewriteUtils.isBinarySparseSafe(HopRewriteUtils.getLargestInput(hop))))) ) + return CloseType.CLOSED_INVALID; else if( (hop instanceof AggUnaryOp) || (hop instanceof AggBinaryOp && !HopRewriteUtils.isOuterProductLikeMM(hop) && !HopRewriteUtils.isTransposeOperation(hop.getParent().get(0))) || (HopRewriteUtils.isTransposeOperation(hop) && hop.getInput().get(0) instanceof AggBinaryOp && !HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(0)) )) return CloseType.CLOSED_VALID; + else if( HopRewriteUtils.isBinaryMatrixMatrixOperation(hop) + && HopRewriteUtils.isBinary(hop, 
OpOp2.MULT, OpOp2.DIV) ) + return CloseType.OPEN_VALID; else - return CloseType.OPEN; + return CloseType.OPEN_INVALID; } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index 5f14d6b..64014da 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -69,8 +69,8 @@ public class TemplateRow extends TemplateBase super(TemplateType.ROW); } - public TemplateRow(boolean closed) { - super(TemplateType.ROW, closed); + public TemplateRow(CloseType ctype) { + super(TemplateType.ROW, ctype); } @Override @@ -136,8 +136,10 @@ public class TemplateRow extends TemplateBase if( (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.Row) || (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))) return CloseType.CLOSED_VALID; + else if( HopRewriteUtils.isTransposeOperation(hop) ) + return CloseType.OPEN_INVALID; else - return CloseType.OPEN; + return CloseType.OPEN_VALID; } private static boolean isValidBinaryOperation(Hop hop) { http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 4dc0bf2..497dae0 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -42,6 +42,7 @@ import 
org.apache.sysml.hops.codegen.cplan.CNodeTernary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; +import org.apache.sysml.hops.codegen.template.TemplateBase.CloseType; import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType; @@ -143,27 +144,27 @@ public class TemplateUtils } public static TemplateBase createTemplate(TemplateType type) { - return createTemplate(type, false); + return createTemplate(type, CloseType.OPEN_VALID); } - public static TemplateBase createTemplate(TemplateType type, boolean closed) { + public static TemplateBase createTemplate(TemplateType type, CloseType ctype) { TemplateBase tpl = null; switch( type ) { - case CELL: tpl = new TemplateCell(closed); break; - case ROW: tpl = new TemplateRow(closed); break; - case MAGG: tpl = new TemplateMultiAgg(closed); break; - case OUTER: tpl = new TemplateOuterProduct(closed); break; + case CELL: tpl = new TemplateCell(ctype); break; + case ROW: tpl = new TemplateRow(ctype); break; + case MAGG: tpl = new TemplateMultiAgg(ctype); break; + case OUTER: tpl = new TemplateOuterProduct(ctype); break; } return tpl; } - public static TemplateBase[] createCompatibleTemplates(TemplateType type, boolean closed) { + public static TemplateBase[] createCompatibleTemplates(TemplateType type, CloseType ctype) { TemplateBase[] tpl = null; switch( type ) { - case CELL: tpl = new TemplateBase[]{new TemplateCell(closed), new TemplateRow(closed)}; break; - case ROW: tpl = new TemplateBase[]{new TemplateRow(closed)}; break; - case MAGG: tpl = new TemplateBase[]{new TemplateMultiAgg(closed)}; break; - case OUTER: tpl = new TemplateBase[]{new TemplateOuterProduct(closed)}; break; + case CELL: tpl = new TemplateBase[]{new TemplateCell(ctype), new 
TemplateRow(ctype)}; break; + case ROW: tpl = new TemplateBase[]{new TemplateRow(ctype)}; break; + case MAGG: tpl = new TemplateBase[]{new TemplateMultiAgg(ctype)}; break; + case OUTER: tpl = new TemplateBase[]{new TemplateOuterProduct(ctype)}; break; } return tpl; } http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index d96d1e4..b0f46b7 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -294,6 +294,16 @@ public class HopRewriteUtils return null; } + public static Hop getLargestInput(Hop hop) { + Hop max = null; long maxSize = -1; + for(Hop in : hop.getInput()) + if(in.getLength() > maxSize) { + max = in; + maxSize = in.getLength(); + } + return max; + } + public static Hop createDataGenOp( Hop input, double value ) throws HopsException { @@ -854,7 +864,7 @@ public class HopRewriteUtils public static boolean isBinarySparseSafe(Hop hop) { if( !(hop instanceof BinaryOp) ) return false; - if( isBinary(hop, OpOp2.MULT) ) + if( isBinary(hop, OpOp2.MULT, OpOp2.DIV) ) return true; BinaryOp bop = (BinaryOp) hop; Hop lit = bop.getInput().get(0) instanceof LiteralOp ? 
bop.getInput().get(0) : http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/test/java/org/apache/sysml/test/integration/functions/codegen/MiscPatternTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/MiscPatternTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/MiscPatternTest.java index 75c28eb..b2dfc10 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/MiscPatternTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/MiscPatternTest.java @@ -38,6 +38,8 @@ public class MiscPatternTest extends AutomatedTestBase private static final String TEST_NAME = "miscPattern"; private static final String TEST_NAME1 = TEST_NAME+"1"; //Y + (X * U%*%t(V)) overlapping cell-outer private static final String TEST_NAME2 = TEST_NAME+"2"; //multi-agg w/ large common subexpression + private static final String TEST_NAME3 = TEST_NAME+"3"; //sum((X!=0) * (U %*% t(V) - X)^2) + private static final String TEST_NAME4 = TEST_NAME+"4"; //((X!=0) * (U %*% t(V) - X)) %*% V + Y overlapping row-outer private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + MiscPatternTest.class.getSimpleName() + "/"; @@ -49,11 +51,11 @@ public class MiscPatternTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=2; i++) + for(int i=1; i<=4; i++) addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) ); } - @Test + @Test public void testCodegenMiscRewrite1CP() { testCodegenIntegration( TEST_NAME1, true, ExecType.CP ); } @@ -68,7 +70,7 @@ public class MiscPatternTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME1, false, ExecType.SPARK ); } - @Test + @Test public void testCodegenMiscRewrite2CP() { 
testCodegenIntegration( TEST_NAME2, true, ExecType.CP ); } @@ -83,6 +85,36 @@ public class MiscPatternTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME2, false, ExecType.SPARK ); } + @Test + public void testCodegenMiscRewrite3CP() { + testCodegenIntegration( TEST_NAME3, true, ExecType.CP ); + } + + @Test + public void testCodegenMisc3CP() { + testCodegenIntegration( TEST_NAME3, false, ExecType.CP ); + } + + @Test + public void testCodegenMisc3SP() { + testCodegenIntegration( TEST_NAME3, false, ExecType.SPARK ); + } + + @Test + public void testCodegenMiscRewrite4CP() { + testCodegenIntegration( TEST_NAME4, true, ExecType.CP ); + } + + @Test + public void testCodegenMisc4CP() { + testCodegenIntegration( TEST_NAME4, false, ExecType.CP ); + } + + @Test + public void testCodegenMisc4SP() { + testCodegenIntegration( TEST_NAME4, false, ExecType.SPARK ); + } + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; @@ -128,6 +160,9 @@ public class MiscPatternTest extends AutomatedTestBase else if( testname.equals(TEST_NAME2) ) Assert.assertTrue(!heavyHittersContainsSubString("spoof", 2) && !heavyHittersContainsSubString("sp_spoof", 2)); + else if( testname.equals(TEST_NAME3) || testname.equals(TEST_NAME4) ) + Assert.assertTrue(heavyHittersContainsSubString("spoofOP", "sp+spoofOP") + && !heavyHittersContainsSubString("ba+*")); } finally { rtplatform = platformOld; http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/test/scripts/functions/codegen/miscPattern3.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/miscPattern3.R b/src/test/scripts/functions/codegen/miscPattern3.R new file mode 100644 index 0000000..e04ac73 --- /dev/null +++ b/src/test/scripts/functions/codegen/miscPattern3.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# 
Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(1, 1100, 2200); +U = matrix(3, 1100, 10); +V = matrix(4, 2200, 10) +X[4:900,3:1000] = matrix(0, 897, 998); + +R1 = sum((X!=0) * (U %*% t(V) - X)^2) +R2 = as.matrix(R1); + +writeMM(as(R2, "CsparseMatrix"), paste(args[1], "S", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/test/scripts/functions/codegen/miscPattern3.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/miscPattern3.dml b/src/test/scripts/functions/codegen/miscPattern3.dml new file mode 100644 index 0000000..593bd8a --- /dev/null +++ b/src/test/scripts/functions/codegen/miscPattern3.dml @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(1, 1100, 2200); +U = matrix(3, 1100, 10); +V = matrix(4, 2200, 10) +X[4:900,3:1000] = matrix(0, 897, 998); + +while(FALSE){} + +R1 = sum((X!=0) * (U %*% t(V) - X)^2) + +while(FALSE){} + +R2 = as.matrix(R1); +write(R2, $1) http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/test/scripts/functions/codegen/miscPattern4.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/miscPattern4.R b/src/test/scripts/functions/codegen/miscPattern4.R new file mode 100644 index 0000000..b8ea2e7 --- /dev/null +++ b/src/test/scripts/functions/codegen/miscPattern4.R @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(1, 1100, 2200); +Y = matrix(2, 1100, 10); +U = matrix(3, 1100, 10); +V = matrix(4, 2200, 10) +X[4:900,3:1000] = matrix(0, 897, 998); + +R1 = ((X!=0) * (U %*% t(V) - X)) %*% V + Y; +R2 = as.matrix(sum(R1)); + +writeMM(as(R2, "CsparseMatrix"), paste(args[1], "S", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca2d8aa/src/test/scripts/functions/codegen/miscPattern4.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/miscPattern4.dml b/src/test/scripts/functions/codegen/miscPattern4.dml new file mode 100644 index 0000000..c4a93d3 --- /dev/null +++ b/src/test/scripts/functions/codegen/miscPattern4.dml @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(1, 1100, 2200); +Y = matrix(2, 1100, 10); +U = matrix(3, 1100, 10); +V = matrix(4, 2200, 10) +X[4:900,3:1000] = matrix(0, 897, 998); + +while(FALSE){} + +R1 = ((X!=0) * (U %*% t(V) - X)) %*% V + Y; + +while(FALSE){} + +R2 = as.matrix(sum(R1)); +write(R2, $1)