Repository: systemml Updated Branches: refs/heads/master 87cc5ee67 -> 4de8d684f
[SYSTEMML-1729] Reduced garbage collection overhead of the codegen compiler. This patch addresses unnecessary garbage collection overhead induced by the codegen compiler. In detail, this includes (1) hash computation w/o array allocations, (2) in-place visit status maintenance during CSE, (3) data node rename w/o node replacement, (4) better collection handling, including in-place filtering, and (5) better string replacement w/o regex evaluation. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/4de8d684 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/4de8d684 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/4de8d684 Branch: refs/heads/master Commit: 4de8d684f8a1a19eeaccfe31483a070dfba206ac Parents: 87cc5ee Author: Matthias Boehm <[email protected]> Authored: Wed Jun 21 20:45:19 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed Jun 21 20:53:06 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/SpoofCompiler.java | 12 +- .../apache/sysml/hops/codegen/cplan/CNode.java | 38 ++--- .../sysml/hops/codegen/cplan/CNodeBinary.java | 20 ++- .../sysml/hops/codegen/cplan/CNodeCell.java | 28 ++-- .../sysml/hops/codegen/cplan/CNodeData.java | 16 ++- .../sysml/hops/codegen/cplan/CNodeMultiAgg.java | 18 +-- .../hops/codegen/cplan/CNodeOuterProduct.java | 37 +++-- .../sysml/hops/codegen/cplan/CNodeRow.java | 28 ++-- .../sysml/hops/codegen/cplan/CNodeTernary.java | 16 +-- .../sysml/hops/codegen/cplan/CNodeTpl.java | 137 +++++-------------- .../sysml/hops/codegen/cplan/CNodeUnary.java | 24 ++-- .../hops/codegen/template/CPlanCSERewriter.java | 34 ++--- .../hops/codegen/template/CPlanMemoTable.java | 14 +- .../hops/codegen/template/PlanSelection.java | 4 +- .../template/PlanSelectionFuseCostBased.java | 22 +-- .../hops/codegen/template/TemplateBase.java | 2 +- .../hops/codegen/template/TemplateCell.java | 12 +- 
.../hops/codegen/template/TemplateMultiAgg.java | 8 +- .../hops/codegen/template/TemplateRow.java | 12 +- .../sysml/runtime/util/UtilFunctions.java | 4 + 20 files changed, 200 insertions(+), 286 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index 4881f6c..fc3ecde 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -22,6 +22,7 @@ package org.apache.sysml.hops.codegen; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -319,7 +320,8 @@ public class SpoofCompiler if( root == null ) return root; - return optimize(new ArrayList<Hop>(Arrays.asList(root)), recompile).get(0); + return optimize(new ArrayList<Hop>( + Collections.singleton(root)), recompile).get(0); } /** @@ -679,11 +681,9 @@ public class SpoofCompiler //update input hops (order-preserving) HashSet<Long> inputHopIDs = tpl.getInputHopIDs(false); - ArrayList<Hop> tmp = new ArrayList<Hop>(); - for( Hop input : inHops ) - if( inputHopIDs.contains(input.getHopID()) ) - tmp.add(input); - inHops = tmp.toArray(new Hop[0]); + inHops = Arrays.stream(inHops) + .filter(p -> inputHopIDs.contains(p.getHopID())) + .toArray(Hop[]::new); cplans2.put(e.getKey(), new Pair<Hop[],CNodeTpl>(inHops, tpl)); //remove invalid plans with column indexing on main input http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java index 2f2c3e7..efe468e 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java @@ -20,10 +20,10 @@ package org.apache.sysml.hops.codegen.cplan; import java.util.ArrayList; -import java.util.Arrays; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; +import org.apache.sysml.runtime.util.UtilFunctions; public abstract class CNode { @@ -59,12 +59,23 @@ public abstract class CNode return _inputs; } + public boolean isGenerated() { + return _generated; + } + + public void resetGenerated() { + if( isGenerated() ) + for( CNode cn : _inputs ) + cn.resetGenerated(); + _generated = false; + } + public String createVarname() { _genVar = "TMP"+_seqVar.getNextID(); return _genVar; } - protected String getCurrentVarName() { + protected static String getCurrentVarName() { return "TMP"+(_seqVar.getCurrentID()-1); } @@ -76,13 +87,6 @@ public abstract class CNode return getVarname(); } - public void resetGenerated() { - if( _generated ) - for( CNode cn : _inputs ) - cn.resetGenerated(); - _generated = false; - } - public void resetHash() { _hash = 0; } @@ -163,21 +167,19 @@ public abstract class CNode @Override public int hashCode() { if( _hash == 0 ) { - int numIn = _inputs.size(); - int[] tmp = new int[numIn + 3]; //include inputs, partitioned by matrices and scalars to increase //reuse in case of interleaved inputs (see CNodeTpl.renameInputs) - int pos = 0; + int h = 1; for( CNode c : _inputs ) if( c.getDataType()==DataType.MATRIX ) - tmp[pos++] = c.hashCode(); + h = UtilFunctions.intHashCode(h, c.hashCode()); for( CNode c : _inputs ) if( c.getDataType()!=DataType.MATRIX ) - tmp[pos++] = c.hashCode(); - tmp[numIn+0] = 
(_output!=null)?_output.hashCode():0; - tmp[numIn+1] = (_dataType!=null)?_dataType.hashCode():0; - tmp[numIn+2] = Boolean.valueOf(_literal).hashCode(); - _hash = Arrays.hashCode(tmp); + h = UtilFunctions.intHashCode(h, c.hashCode()); + h = UtilFunctions.intHashCode(h, (_output!=null)?_output.hashCode():0); + h = UtilFunctions.intHashCode(h, (_dataType!=null)?_dataType.hashCode():0); + h = UtilFunctions.intHashCode(h, Boolean.hashCode(_literal)); + _hash = h; } return _hash; } http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java index 7ed2408..6e72ae1 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java @@ -19,11 +19,10 @@ package org.apache.sysml.hops.codegen.cplan; -import java.util.Arrays; - import org.apache.commons.lang.StringUtils; import org.apache.sysml.hops.codegen.template.TemplateUtils; import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.util.UtilFunctions; public class CNodeBinary extends CNode @@ -234,7 +233,7 @@ public class CNodeBinary extends CNode @Override public String codegen(boolean sparse) { - if( _generated ) + if( isGenerated() ) return ""; StringBuilder sb = new StringBuilder(); @@ -251,19 +250,19 @@ public class CNodeBinary extends CNode && _inputs.get(1).getDataType().isMatrix()); String var = createVarname(); String tmp = _type.getTemplate(lsparse, scalarVector); - tmp = tmp.replaceAll("%TMP%", var); + tmp = tmp.replace("%TMP%", var); //replace input references and start indexes for( int j=1; j<=2; j++ ) { String varj = _inputs.get(j-1).getVarname(); //replace sparse and dense inputs - tmp = 
tmp.replaceAll("%IN"+j+"v%", varj+"vals"); - tmp = tmp.replaceAll("%IN"+j+"i%", varj+"ix"); - tmp = tmp.replaceAll("%IN"+j+"%", varj ); + tmp = tmp.replace("%IN"+j+"v%", varj+"vals"); + tmp = tmp.replace("%IN"+j+"i%", varj+"ix"); + tmp = tmp.replace("%IN"+j+"%", varj ); //replace start position of main input - tmp = tmp.replaceAll("%POS"+j+"%", (_inputs.get(j-1) instanceof CNodeData + tmp = tmp.replace("%POS"+j+"%", (_inputs.get(j-1) instanceof CNodeData && _inputs.get(j-1).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" : TemplateUtils.isMatrix(_inputs.get(j-1)) ? "rowIndex*len" : "0" : "0"); } @@ -431,9 +430,8 @@ public class CNodeBinary extends CNode @Override public int hashCode() { if( _hash == 0 ) { - int h1 = super.hashCode(); - int h2 = _type.hashCode(); - _hash = Arrays.hashCode(new int[]{h1,h2}); + _hash = UtilFunctions.intHashCode( + super.hashCode(), _type.hashCode()); } return _hash; } http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java index aec2600..062e9a0 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java @@ -20,11 +20,11 @@ package org.apache.sysml.hops.codegen.cplan; import java.util.ArrayList; -import java.util.Arrays; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType; import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; +import org.apache.sysml.runtime.util.UtilFunctions; public class CNodeCell extends CNodeTpl { @@ -102,7 +102,7 @@ public class CNodeCell extends CNodeTpl @Override public void renameInputs() { - rReplaceDataNode(_output, _inputs.get(0), "a"); + 
rRenameDataNode(_output, _inputs.get(0), "a"); renameInputs(_inputs, 1); } @@ -114,16 +114,16 @@ public class CNodeCell extends CNodeTpl String tmpDense = _output.codegen(false); _output.resetGenerated(); - tmp = tmp.replaceAll("%TMP%", createVarname()); - tmp = tmp.replaceAll("%BODY_dense%", tmpDense); + tmp = tmp.replace("%TMP%", createVarname()); + tmp = tmp.replace("%BODY_dense%", tmpDense); //return last TMP - tmp = tmp.replaceAll("%OUT%", getCurrentVarName()); + tmp = tmp.replace("%OUT%", getCurrentVarName()); //replace meta data information - tmp = tmp.replaceAll("%TYPE%", getCellType().name()); - tmp = tmp.replaceAll("%AGG_OP%", (_aggOp!=null) ? "AggOp."+_aggOp.name() : "null" ); - tmp = tmp.replaceAll("%SPARSE_SAFE%", String.valueOf(isSparseSafe())); + tmp = tmp.replace("%TYPE%", getCellType().name()); + tmp = tmp.replace("%AGG_OP%", (_aggOp!=null) ? "AggOp."+_aggOp.name() : "null" ); + tmp = tmp.replace("%SPARSE_SAFE%", String.valueOf(isSparseSafe())); return tmp; } @@ -157,13 +157,13 @@ public class CNodeCell extends CNodeTpl @Override public int hashCode() { if( _hash == 0 ) { - int h1 = super.hashCode(); - int h2 = _type.hashCode(); - int h3 = (_aggOp!=null) ? _aggOp.hashCode() : 0; - int h4 = Boolean.valueOf(_sparseSafe).hashCode(); - int h5 = Boolean.valueOf(_requiresCastdtm).hashCode(); + int h = super.hashCode(); + h = UtilFunctions.intHashCode(h, _type.hashCode()); + h = UtilFunctions.intHashCode(h, (_aggOp!=null) ? 
_aggOp.hashCode() : 0); + h = UtilFunctions.intHashCode(h, Boolean.hashCode(_sparseSafe)); + h = UtilFunctions.intHashCode(h, Boolean.hashCode(_requiresCastdtm)); //note: _multipleConsumers irrelevant for plan comparison - _hash = Arrays.hashCode(new int[]{h1,h2,h3,h4,h5}); + _hash = h; } return _hash; } http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java index e7fee75..ce343cf 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java @@ -19,15 +19,14 @@ package org.apache.sysml.hops.codegen.cplan; -import java.util.Arrays; - import org.apache.sysml.hops.Hop; import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.util.UtilFunctions; public class CNodeData extends CNode { - protected final String _name; protected final long _hopID; + protected String _name; private boolean _strictEquals; public CNodeData(Hop hop) { @@ -69,6 +68,10 @@ public class CNodeData extends CNode return _hopID; } + public void setName(String name) { + _name = name; + } + public void setStrictEquals(boolean flag) { _strictEquals = flag; _hash = 0; @@ -92,10 +95,9 @@ public class CNodeData extends CNode @Override public int hashCode() { if( _hash == 0 ) { - int h1 = super.hashCode(); - int h2 = (isLiteral() || !_strictEquals) ? - _name.hashCode() : Long.hashCode(_hopID); - _hash = Arrays.hashCode(new int[]{h1,h2}); + _hash = UtilFunctions.intHashCode( + super.hashCode(), (isLiteral() || !_strictEquals) ? 
+ _name.hashCode() : Long.hashCode(_hopID)); } return _hash; } http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java index 2ffe72c..b6b3a80 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java @@ -26,6 +26,7 @@ import org.apache.commons.collections.CollectionUtils; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType; +import org.apache.sysml.runtime.util.UtilFunctions; public class CNodeMultiAgg extends CNodeTpl { @@ -90,7 +91,7 @@ public class CNodeMultiAgg extends CNodeTpl @Override public void renameInputs() { - rReplaceDataNode(_outputs, _inputs.get(0), "a"); // input matrix + rRenameDataNode(_outputs, _inputs.get(0), "a"); // input matrix renameInputs(_outputs, _inputs, 1); } @@ -119,8 +120,8 @@ public class CNodeMultiAgg extends CNodeTpl } //replace class name and body - tmp = tmp.replaceAll("%TMP%", createVarname()); - tmp = tmp.replaceAll("%BODY_dense%", sb.toString()); + tmp = tmp.replace("%TMP%", createVarname()); + tmp = tmp.replace("%BODY_dense%", sb.toString()); //replace meta data information String aggList = ""; @@ -128,7 +129,7 @@ public class CNodeMultiAgg extends CNodeTpl aggList += !aggList.isEmpty() ? 
"," : ""; aggList += "AggOp."+aggOp.name(); } - tmp = tmp.replaceAll("%AGG_OP%", aggList); + tmp = tmp.replace("%AGG_OP%", aggList); return tmp; } @@ -153,13 +154,12 @@ public class CNodeMultiAgg extends CNodeTpl @Override public int hashCode() { if( _hash == 0 ) { - int[] tmp = new int[2*_outputs.size()+1]; - tmp[0] = super.hashCode(); + int h = super.hashCode(); for( int i=0; i<_outputs.size(); i++ ) { - tmp[1+2*i] = _outputs.get(i).hashCode(); - tmp[1+2*i+1] = _aggOps.get(i).hashCode(); + h = UtilFunctions.intHashCode(h, UtilFunctions.intHashCode( + _outputs.get(i).hashCode(), _aggOps.get(i).hashCode())); } - _hash = Arrays.hashCode(tmp); + _hash = h; } return _hash; } http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeOuterProduct.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeOuterProduct.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeOuterProduct.java index a96decb..d6a1d34 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeOuterProduct.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeOuterProduct.java @@ -20,10 +20,10 @@ package org.apache.sysml.hops.codegen.cplan; import java.util.ArrayList; -import java.util.Arrays; import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType; import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType; +import org.apache.sysml.runtime.util.UtilFunctions; public class CNodeOuterProduct extends CNodeTpl @@ -57,9 +57,9 @@ public class CNodeOuterProduct extends CNodeTpl @Override public void renameInputs() { - rReplaceDataNode(_output, _inputs.get(0), "a"); - rReplaceDataNode(_output, _inputs.get(1), "a1"); // u - rReplaceDataNode(_output, _inputs.get(2), "a2"); // v + rRenameDataNode(_output, _inputs.get(0), "a"); + rRenameDataNode(_output, _inputs.get(1), "a1"); // u + rRenameDataNode(_output, 
_inputs.get(2), "a2"); // v renameInputs(_inputs, 3); } @@ -72,25 +72,25 @@ public class CNodeOuterProduct extends CNodeTpl String tmpDense = _output.codegen(false); _output.resetGenerated(); - tmp = tmp.replaceAll("%TMP%", createVarname()); + tmp = tmp.replace("%TMP%", createVarname()); if(_type == OutProdType.LEFT_OUTER_PRODUCT || _type == OutProdType.RIGHT_OUTER_PRODUCT) { - tmp = tmp.replaceAll("%BODY_dense%", tmpDense); - tmp = tmp.replaceAll("%OUT%", "c"); - tmp = tmp.replaceAll("%BODY_cellwise%", ""); - tmp = tmp.replaceAll("%OUT_cellwise%", "0"); + tmp = tmp.replace("%BODY_dense%", tmpDense); + tmp = tmp.replace("%OUT%", "c"); + tmp = tmp.replace("%BODY_cellwise%", ""); + tmp = tmp.replace("%OUT_cellwise%", "0"); } else { - tmp = tmp.replaceAll("%BODY_dense%", ""); - tmp = tmp.replaceAll("%BODY_cellwise%", tmpDense); - tmp = tmp.replaceAll("%OUT_cellwise%", getCurrentVarName()); + tmp = tmp.replace("%BODY_dense%", ""); + tmp = tmp.replace("%BODY_cellwise%", tmpDense); + tmp = tmp.replace("%OUT_cellwise%", getCurrentVarName()); } //replace size information - tmp = tmp.replaceAll("%LEN%", "k"); + tmp = tmp.replace("%LEN%", "k"); - tmp = tmp.replaceAll("%POSOUT%", "ci"); + tmp = tmp.replace("%POSOUT%", "ci"); - tmp = tmp.replaceAll("%TYPE%", _type.toString()); + tmp = tmp.replace("%TYPE%", _type.toString()); return tmp; } @@ -143,10 +143,9 @@ public class CNodeOuterProduct extends CNodeTpl @Override public int hashCode() { if( _hash == 0 ) { - int h1 = super.hashCode(); - int h2 = _type.hashCode(); - int h3 = Boolean.valueOf(_transposeOutput).hashCode(); - _hash = Arrays.hashCode(new int[]{h1,h2,h3}); + int h = UtilFunctions.intHashCode(super.hashCode(), _type.hashCode()); + h = UtilFunctions.intHashCode(h, Boolean.hashCode(_transposeOutput)); + _hash = h; } return _hash; } http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java index 546bf60..caf379b 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java @@ -20,12 +20,12 @@ package org.apache.sysml.hops.codegen.cplan; import java.util.ArrayList; -import java.util.Arrays; import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType; import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; import org.apache.sysml.hops.codegen.template.TemplateUtils; import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType; +import org.apache.sysml.runtime.util.UtilFunctions; public class CNodeRow extends CNodeTpl { @@ -78,7 +78,7 @@ public class CNodeRow extends CNodeTpl @Override public void renameInputs() { - rReplaceDataNode(_output, _inputs.get(0), "a"); // input matrix + rRenameDataNode(_output, _inputs.get(0), "a"); // input matrix renameInputs(_inputs, 1); } @@ -93,22 +93,22 @@ public class CNodeRow extends CNodeTpl _output.resetGenerated(); String tmpSparse = _output.codegen(true) + getOutputStatement(_output.getVarname()); - tmp = tmp.replaceAll("%TMP%", createVarname()); - tmp = tmp.replaceAll("%BODY_dense%", tmpDense); - tmp = tmp.replaceAll("%BODY_sparse%", tmpSparse); + tmp = tmp.replace("%TMP%", createVarname()); + tmp = tmp.replace("%BODY_dense%", tmpDense); + tmp = tmp.replace("%BODY_sparse%", tmpSparse); //replace outputs - tmp = tmp.replaceAll("%OUT%", "c"); - tmp = tmp.replaceAll("%POSOUT%", "0"); + tmp = tmp.replace("%OUT%", "c"); + tmp = tmp.replace("%POSOUT%", "0"); //replace size information - tmp = tmp.replaceAll("%LEN%", "len"); + tmp = tmp.replace("%LEN%", "len"); //replace colvector information and number of vector intermediates - tmp = tmp.replaceAll("%TYPE%", _type.name()); - tmp = tmp.replaceAll("%CBIND0%", 
String.valueOf( + tmp = tmp.replace("%TYPE%", _type.name()); + tmp = tmp.replace("%CBIND0%", String.valueOf( TemplateUtils.isUnary(_output, UnaryType.CBIND0))); - tmp = tmp.replaceAll("%VECT_MEM%", String.valueOf(_numVectors)); + tmp = tmp.replace("%VECT_MEM%", String.valueOf(_numVectors)); return tmp; } @@ -153,10 +153,8 @@ public class CNodeRow extends CNodeTpl @Override public int hashCode() { if( _hash == 0 ) { - int h1 = super.hashCode(); - int h2 = _type.hashCode(); - int h3 = _numVectors; - _hash = Arrays.hashCode(new int[]{h1,h2,h3}); + int h = UtilFunctions.intHashCode(super.hashCode(), _type.hashCode()); + _hash = UtilFunctions.intHashCode(h, Integer.hashCode(_numVectors)); } return _hash; } http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java index 0aee40a..9a4b103 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java @@ -19,9 +19,8 @@ package org.apache.sysml.hops.codegen.cplan; -import java.util.Arrays; - import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.util.UtilFunctions; public class CNodeTernary extends CNode @@ -80,7 +79,7 @@ public class CNodeTernary extends CNode @Override public String codegen(boolean sparse) { - if( _generated ) + if( isGenerated() ) return ""; StringBuilder sb = new StringBuilder(); @@ -93,13 +92,13 @@ public class CNodeTernary extends CNode //generate binary operation String var = createVarname(); String tmp = _type.getTemplate(sparse); - tmp = tmp.replaceAll("%TMP%", var); + tmp = tmp.replace("%TMP%", var); for( int j=1; j<=3; j++ ) { String varj = _inputs.get(j-1).getVarname(); //replace sparse 
and dense inputs - tmp = tmp.replaceAll("%IN"+j+"v%", + tmp = tmp.replace("%IN"+j+"v%", varj+(varj.startsWith("b")?"":"vals") ); - tmp = tmp.replaceAll("%IN"+j+"%", varj ); + tmp = tmp.replace("%IN"+j+"%", varj ); } sb.append(tmp); @@ -140,9 +139,8 @@ public class CNodeTernary extends CNode @Override public int hashCode() { if( _hash == 0 ) { - int h1 = super.hashCode(); - int h2 = _type.hashCode(); - _hash = Arrays.hashCode(new int[]{h1,h2}); + _hash = UtilFunctions.intHashCode( + super.hashCode(), _type.hashCode()); } return _hash; } http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java index 35ab7f8..8d61588 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java @@ -26,9 +26,6 @@ import java.util.HashSet; import java.util.List; import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType; -import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; -import org.apache.sysml.hops.codegen.template.TemplateUtils; -import org.apache.sysml.parser.Expression.DataType; public abstract class CNodeTpl extends CNode implements Cloneable { @@ -70,6 +67,11 @@ public abstract class CNodeTpl extends CNode implements Cloneable getOutput().resetVisitStatus(); } + public static void resetVisitStatus(List<CNode> outputs) { + for( CNode output : outputs ) + output.resetVisitStatus(); + } + public String codegen() { return codegen(false); } @@ -87,127 +89,56 @@ public abstract class CNodeTpl extends CNode implements Cloneable } protected void renameInputs(List<CNode> outputs, ArrayList<CNode> inputs, int startIndex) { - //create map of hopID to data nodes with new names, used for CSE - 
HashMap<Long, CNode> nodes = new HashMap<Long, CNode>(); + //create map of hopID to new names used for code generation + HashMap<Long, String> newNames = new HashMap<Long, String>(); for(int i=startIndex, sPos=0, mPos=0; i < inputs.size(); i++) { CNode cnode = inputs.get(i); if( cnode instanceof CNodeData && ((CNodeData)cnode).isLiteral() ) continue; - CNodeData cdata = (CNodeData)cnode; - if( cdata.getDataType() == DataType.SCALAR || ( cdata.getNumCols() == 0 && cdata.getNumRows() == 0) ) - nodes.put(cdata.getHopID(), new CNodeData(cdata, "scalars["+ mPos++ +"]")); - else - nodes.put(cdata.getHopID(), new CNodeData(cdata, "b["+ sPos++ +"]")); + newNames.put(((CNodeData)cnode).getHopID(), cnode.getDataType().isScalar() ? + "scalars["+ mPos++ +"]" : "b["+ sPos++ +"]"); } //single pass to replace all names + resetVisitStatus(outputs); for( CNode output : outputs ) - rReplaceDataNode(output, nodes, new HashMap<Long, CNode>()); + rRenameDataNodes(output, newNames); } - protected void rReplaceDataNode( CNode root, CNode input, String newName ) { - if( !(input instanceof CNodeData) ) - return; - - //create temporary name mapping - HashMap<Long, CNode> names = new HashMap<Long, CNode>(); - CNodeData tmp = (CNodeData)input; - names.put(tmp.getHopID(), new CNodeData(tmp, newName)); - - rReplaceDataNode(root, names, new HashMap<Long,CNode>()); + protected void rRenameDataNode( CNode root, CNode input, String newName ) { + rRenameDataNode(Collections.singletonList(root), input, newName); } - protected void rReplaceDataNode( ArrayList<CNode> roots, CNode input, String newName ) { + protected void rRenameDataNode( List<CNode> roots, CNode input, String newName ) { if( !(input instanceof CNodeData) ) return; - + //create temporary name mapping - HashMap<Long, CNode> names = new HashMap<Long, CNode>(); - CNodeData tmp = (CNodeData)input; - names.put(tmp.getHopID(), new CNodeData(tmp, newName)); + HashMap<Long, String> newNames = new HashMap<Long, String>(); + 
newNames.put(((CNodeData)input).getHopID(), newName); + //single pass to replace all names + resetVisitStatus(roots); for( CNode root : roots ) - rReplaceDataNode(root, names, new HashMap<Long,CNode>()); + rRenameDataNodes(root, newNames); } - /** - * Recursively searches for data nodes and replaces them if found. - * - * @param node current node in recursive descend - * @param dnodes prepared data nodes, identified by own hop id - * @param lnodes memoized lookup nodes, identified by data node hop id - */ - protected void rReplaceDataNode( CNode node, HashMap<Long, CNode> dnodes, HashMap<Long, CNode> lnodes ) - { - for( int i=0; i<node._inputs.size(); i++ ) { - //recursively process children - rReplaceDataNode(node._inputs.get(i), dnodes, lnodes); - - //replace leaf data node - if( node._inputs.get(i) instanceof CNodeData ) { - CNodeData tmp = (CNodeData)node._inputs.get(i); - if( dnodes.containsKey(tmp.getHopID()) ) - node._inputs.set(i, dnodes.get(tmp.getHopID())); - } - - //replace lookup on top of leaf data node (for CSE only) - if( node._inputs.get(i) instanceof CNodeUnary - && node._inputs.get(i)._inputs.get(0) instanceof CNodeData - && (((CNodeUnary)node._inputs.get(i)).getType()==UnaryType.LOOKUP_R - || ((CNodeUnary)node._inputs.get(i)).getType()==UnaryType.LOOKUP_RC)) { - CNodeData tmp = (CNodeData)node._inputs.get(i)._inputs.get(0); - if( !lnodes.containsKey(tmp.getHopID()) ) - lnodes.put(tmp.getHopID(), node._inputs.get(i)); - else - node._inputs.set(i, lnodes.get(tmp.getHopID())); - } - } - } - - public void rReplaceDataNode( CNode node, long hopID, CNode newNode ) - { - for( int i=0; i<node._inputs.size(); i++ ) { - //replace leaf node - if( node._inputs.get(i) instanceof CNodeData ) { - CNodeData tmp = (CNodeData)node._inputs.get(i); - if( tmp.getHopID() == hopID ) - node._inputs.set(i, newNode); - } - //recursively process children - rReplaceDataNode(node._inputs.get(i), hopID, newNode); - - //remove unnecessary lookups - if( node._inputs.get(i) 
instanceof CNodeUnary - && (((CNodeUnary)node._inputs.get(i)).getType()==UnaryType.LOOKUP_R - || ((CNodeUnary)node._inputs.get(i)).getType()==UnaryType.LOOKUP_RC) - && node._inputs.get(i)._inputs.get(0).getDataType()==DataType.SCALAR) - node._inputs.set(i, node._inputs.get(i)._inputs.get(0)); + protected void rRenameDataNodes( CNode node, HashMap<Long, String> newNames ) { + if( node.isVisited() ) + return; + + //recursively process children + for( CNode c : node.getInput() ) + rRenameDataNodes(c, newNames); + + //rename data node + if( node instanceof CNodeData ) { + CNodeData dnode = (CNodeData) node; + if( newNames.containsKey(dnode.getHopID()) ) + dnode.setName(newNames.get(dnode.getHopID())); } - } - - public void rInsertLookupNode( CNode node, long hopID, HashMap<Long, CNode> memo, UnaryType lookupType ) - { - for( int i=0; i<node._inputs.size(); i++ ) { - //recursively process children - rInsertLookupNode(node._inputs.get(i), hopID, memo, lookupType); - //replace leaf node - if( node._inputs.get(i) instanceof CNodeData ) { - CNodeData tmp = (CNodeData)node._inputs.get(i); - if( tmp.getHopID() == hopID ) { - //use memo structure to retain DAG structure - CNode lookup = memo.get(hopID); - if( lookup == null && !TemplateUtils.isLookup(node) ) { - lookup = new CNodeUnary(tmp, lookupType); - memo.put(hopID, lookup); - } - else if( TemplateUtils.isLookup(node) ) - ((CNodeUnary)node).setType(lookupType); - else - node._inputs.set(i, lookup); - } - } - } + node.setVisited(); } public void rReorderCommutativeBinaryOps(CNode node, long mainHopID) { http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java index 4c60824..c560878 100644 --- 
a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java @@ -19,10 +19,9 @@ package org.apache.sysml.hops.codegen.cplan; -import java.util.Arrays; - import org.apache.commons.lang.StringUtils; import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.util.UtilFunctions; public class CNodeUnary extends CNode @@ -158,9 +157,9 @@ public class CNodeUnary extends CNode @Override public String codegen(boolean sparse) { - if( _generated ) + if( isGenerated() ) return ""; - + StringBuilder sb = new StringBuilder(); //generate children @@ -170,21 +169,21 @@ public class CNodeUnary extends CNode boolean lsparse = sparse && (_inputs.get(0) instanceof CNodeData); String var = createVarname(); String tmp = _type.getTemplate(lsparse); - tmp = tmp.replaceAll("%TMP%", var); + tmp = tmp.replace("%TMP%", var); String varj = _inputs.get(0).getVarname(); //replace sparse and dense inputs - tmp = tmp.replaceAll("%IN1v%", varj+"vals"); - tmp = tmp.replaceAll("%IN1i%", varj+"ix"); - tmp = tmp.replaceAll("%IN1%", varj ); + tmp = tmp.replace("%IN1v%", varj+"vals"); + tmp = tmp.replace("%IN1i%", varj+"ix"); + tmp = tmp.replace("%IN1%", varj ); //replace start position of main input String spos = (!varj.startsWith("b") && _inputs.get(0) instanceof CNodeData && _inputs.get(0).getDataType().isMatrix()) ? 
varj+"i" : "0"; - tmp = tmp.replaceAll("%POS1%", spos); - tmp = tmp.replaceAll("%POS2%", spos); + tmp = tmp.replace("%POS1%", spos); + tmp = tmp.replace("%POS2%", spos); sb.append(tmp); @@ -280,9 +279,8 @@ public class CNodeUnary extends CNode @Override public int hashCode() { if( _hash == 0 ) { - int h1 = super.hashCode(); - int h2 = _type.hashCode(); - _hash = Arrays.hashCode(new int[]{h1,h2}); + _hash = UtilFunctions.intHashCode( + super.hashCode(), _type.hashCode()); } return _hash; } http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/template/CPlanCSERewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanCSERewriter.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanCSERewriter.java index 95f0ed7..9916c0f 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanCSERewriter.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanCSERewriter.java @@ -19,9 +19,8 @@ package org.apache.sysml.hops.codegen.template; -import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import org.apache.sysml.hops.codegen.cplan.CNode; @@ -41,30 +40,31 @@ public class CPlanCSERewriter List<CNode> outputs = (tpl instanceof CNodeMultiAgg) ? 
((CNodeMultiAgg)tpl).getOutputs() : - Arrays.asList(tpl.getOutput()); + Collections.singletonList(tpl.getOutput()); //step 1: set data nodes to strict comparison - HashSet<Long> memo = new HashSet<Long>(); + tpl.resetVisitStatusOutputs(); for( CNode out : outputs ) - rSetStrictDataNodeComparision(out, memo, true); + rSetStrictDataNodeComparision(out, true); //step 2: perform common subexpression elimination HashMap<CNode,CNode> cseSet = new HashMap<CNode,CNode>(); - memo.clear(); + tpl.resetVisitStatusOutputs(); for( CNode out : outputs ) - rEliminateCommonSubexpression(out, cseSet, memo); + rEliminateCommonSubexpression(out, cseSet); //step 3: reset data nodes to imprecise comparison - memo.clear(); + tpl.resetVisitStatusOutputs(); for( CNode out : outputs ) - rSetStrictDataNodeComparision(out, memo, true); + rSetStrictDataNodeComparision(out, true); + tpl.resetVisitStatusOutputs(); return tpl; } - private void rEliminateCommonSubexpression(CNode current, HashMap<CNode,CNode> cseSet, HashSet<Long> memo) { + private void rEliminateCommonSubexpression(CNode current, HashMap<CNode,CNode> cseSet) { //avoid redundant re-evaluation - if( memo.contains(current.getID()) ) + if( current.isVisited() ) return; //replace input with existing common subexpression @@ -76,25 +76,25 @@ public class CPlanCSERewriter //process inputs recursively for( CNode input : current.getInput() ) - rEliminateCommonSubexpression(input, cseSet, memo); + rEliminateCommonSubexpression(input, cseSet); //process node itself cseSet.put(current, current); - memo.add(current.getID()); + current.setVisited(); } - private void rSetStrictDataNodeComparision(CNode current, HashSet<Long> memo, boolean flag) { + private void rSetStrictDataNodeComparision(CNode current, boolean flag) { //avoid redundant re-evaluation - if( memo.contains(current.getID()) ) + if( current.isVisited() ) return; //process inputs recursively and node itself for( CNode input : current.getInput() ) { - 
rSetStrictDataNodeComparision(input, memo, flag); + rSetStrictDataNodeComparision(input, flag); input.resetHash(); } if( current instanceof CNodeData ) ((CNodeData)current).setStrictEquals(flag); - memo.add(current.getID()); + current.setVisited(); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java index 19d5a30..6982470 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java @@ -28,6 +28,7 @@ import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map.Entry; +import java.util.Set; import java.util.stream.Collectors; import org.apache.commons.logging.Log; @@ -63,8 +64,8 @@ public class CPlanMemoTable } public boolean contains(long hopID, TemplateType type) { - return contains(hopID) && get(hopID).stream() - .filter(p -> p.type==type).findAny().isPresent(); + return contains(hopID) && get(hopID) + .stream().anyMatch(p -> p.type==type); } public int countEntries(long hopID) { @@ -111,9 +112,9 @@ public class CPlanMemoTable _plans.get(hop.getHopID()).addAll(P.plans); } - public void remove(Hop hop, HashSet<MemoTableEntry> blackList) { - _plans.put(hop.getHopID(), _plans.get(hop.getHopID()).stream() - .filter(p -> !blackList.contains(p)).collect(Collectors.toList())); + public void remove(Hop hop, Set<MemoTableEntry> blackList) { + _plans.get(hop.getHopID()) + .removeIf(p -> blackList.contains(p)); } public void setDistinct(long hopID, List<MemoTableEntry> plans) { @@ -167,8 +168,7 @@ public class CPlanMemoTable while( iter.hasNext() ) { Entry<Long, List<MemoTableEntry>> e = iter.next(); if( 
!ix.contains(e.getKey()) ) { - e.setValue(e.getValue().stream().filter( - p -> p.hasPlanRef()).collect(Collectors.toList())); + e.getValue().removeIf(p -> !p.hasPlanRef()); if( e.getValue().isEmpty() ) iter.remove(); } http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java index 80ff725..f8a12fd 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java @@ -74,11 +74,11 @@ public abstract class PlanSelection return _bestPlans; } - public boolean isVisited(long hopID, TemplateType type) { + protected boolean isVisited(long hopID, TemplateType type) { return _visited.contains(new VisitMark(hopID, type)); } - public void setVisited(long hopID, TemplateType type) { + protected void setVisited(long hopID, TemplateType type) { _visited.add(new VisitMark(hopID, type)); } http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java index e3435e5..742f4d6 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java @@ -22,6 +22,7 @@ package org.apache.sysml.hops.codegen.template; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import 
java.util.Comparator; import java.util.HashMap; import java.util.Map.Entry; @@ -191,16 +192,9 @@ public class PlanSelectionFuseCostBased extends PlanSelection visited, partition, ret); //remove special-case materialization points - Iterator<Long> iter = ret.iterator(); - while(iter.hasNext()) { - Long hopID = iter.next(); - //remove root nodes w/ multiple consumers - if( roots.contains(hopID) ) - iter.remove(); - //remove tsmm input if consumed in partition - else if( HopRewriteUtils.isTsmmInput(memo._hopRefs.get(hopID))) - iter.remove(); - } + //(root nodes w/ multiple consumers, tsmm input if consumed in partition) + ret.removeIf(hopID -> roots.contains(hopID) + || HopRewriteUtils.isTsmmInput(memo._hopRefs.get(hopID))); return ret; } @@ -284,11 +278,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection Hop.resetVisitStatus(roots); //remove operators with assigned multi-agg plans - Iterator<Long> iter = fullAggs.iterator(); - while( iter.hasNext() ) { - if( memo.contains(iter.next(), TemplateType.MultiAggTpl) ) - iter.remove(); - } + fullAggs.removeIf(p -> memo.contains(p, TemplateType.MultiAggTpl)); //check applicability for further analysis if( fullAggs.size() <= 1 ) @@ -471,7 +461,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection MemoTableEntry me2 = entries.get(1); MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2); if( rmEntry != null ) { - memo.remove(memo._hopRefs.get(hopID), new HashSet<MemoTableEntry>(Arrays.asList(rmEntry))); + memo.remove(memo._hopRefs.get(hopID), Collections.singleton(rmEntry)); memo._plansBlacklist.remove(rmEntry.input(rmEntry.getPlanRefIndex())); if( LOG.isTraceEnabled() ) LOG.trace("Removed dominated outer product memo table entry: " + rmEntry); http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java index 8ed52f6..f5527f5 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java @@ -42,7 +42,7 @@ public abstract class TemplateBase OPEN, } - protected TemplateType _type = null; + protected final TemplateType _type; protected boolean _closed = false; protected TemplateBase(TemplateType type) { http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java index 26d477d..e94d9a5 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java @@ -23,8 +23,6 @@ import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; -import java.util.List; -import java.util.stream.Collectors; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; @@ -122,9 +120,9 @@ public class TemplateCell extends TemplateBase //reorder inputs (ensure matrices/vectors come first) and prune literals //note: we order by number of cells and subsequently sparsity to ensure //that sparse inputs are used as the main input w/o unnecessary conversion - List<Hop> sinHops = inHops.stream() + Hop[] sinHops = inHops.stream() .filter(h -> !(h.getDataType().isScalar() && tmp.get(h.getHopID()).isLiteral())) - .sorted(new HopInputComparator()).collect(Collectors.toList()); + .sorted(new HopInputComparator()).toArray(Hop[]::new); //construct template node ArrayList<CNode> inputs = new 
ArrayList<CNode>(); @@ -134,13 +132,13 @@ public class TemplateCell extends TemplateBase CNodeCell tpl = new CNodeCell(inputs, output); tpl.setCellType(TemplateUtils.getCellType(hop)); tpl.setAggOp(TemplateUtils.getAggOp(hop)); - tpl.setSparseSafe((HopRewriteUtils.isBinary(hop, OpOp2.MULT) && hop.getInput().contains(sinHops.get(0))) - || (HopRewriteUtils.isBinary(hop, OpOp2.DIV) && hop.getInput().get(0) == sinHops.get(0))); + tpl.setSparseSafe((HopRewriteUtils.isBinary(hop, OpOp2.MULT) && hop.getInput().contains(sinHops[0])) + || (HopRewriteUtils.isBinary(hop, OpOp2.DIV) && hop.getInput().get(0) == sinHops[0])); tpl.setRequiresCastDtm(hop instanceof AggBinaryOp); tpl.setBeginLine(hop.getBeginLine()); // return cplan instance - return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl); + return new Pair<Hop[],CNodeTpl>(sinHops, tpl); } protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java index a75e07f..2d53e7c 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java @@ -22,8 +22,6 @@ package org.apache.sysml.hops.codegen.template; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; -import java.util.List; -import java.util.stream.Collectors; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.AggOp; @@ -88,9 +86,9 @@ public class TemplateMultiAgg extends TemplateCell //reorder inputs (ensure matrices/vectors come first) and prune literals //note: we order by 
number of cells and subsequently sparsity to ensure //that sparse inputs are used as the main input w/o unnecessary conversion - List<Hop> sinHops = inHops.stream() + Hop[] sinHops = inHops.stream() .filter(h -> !(h.getDataType().isScalar() && tmp.get(h.getHopID()).isLiteral())) - .sorted(new HopInputComparator()).collect(Collectors.toList()); + .sorted(new HopInputComparator()).toArray(Hop[]::new); //construct template node ArrayList<CNode> inputs = new ArrayList<CNode>(); @@ -113,6 +111,6 @@ public class TemplateMultiAgg extends TemplateCell tpl.setBeginLine(hop.getBeginLine()); // return cplan instance - return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl); + return new Pair<Hop[],CNodeTpl>(sinHops, tpl); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index 71091cf..0a1a651 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -23,8 +23,6 @@ import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; -import java.util.List; -import java.util.stream.Collectors; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; @@ -135,9 +133,9 @@ public class TemplateRow extends TemplateBase hop.resetVisitStatus(); //reorder inputs (ensure matrix is first input, and other inputs ordered by size) - List<Hop> sinHops = inHops.stream() + Hop[] sinHops = inHops.stream() .filter(h -> !(h.getDataType().isScalar() && tmp.get(h.getHopID()).isLiteral())) - .sorted(new HopInputComparator(inHops2.get("X"))).collect(Collectors.toList()); + .sorted(new 
HopInputComparator(inHops2.get("X"))).toArray(Hop[]::new); //construct template node ArrayList<CNode> inputs = new ArrayList<CNode>(); @@ -145,15 +143,15 @@ public class TemplateRow extends TemplateBase inputs.add(tmp.get(in.getHopID())); CNode output = tmp.get(hop.getHopID()); CNodeRow tpl = new CNodeRow(inputs, output); - tpl.setRowType(TemplateUtils.getRowType(hop, sinHops.get(0))); + tpl.setRowType(TemplateUtils.getRowType(hop, sinHops[0])); tpl.setNumVectorIntermediates(TemplateUtils .countVectorIntermediates(output, new HashSet<Long>())); tpl.getOutput().resetVisitStatus(); - tpl.rReorderCommutativeBinaryOps(tpl.getOutput(), sinHops.get(0).getHopID()); + tpl.rReorderCommutativeBinaryOps(tpl.getOutput(), sinHops[0].getHopID()); tpl.setBeginLine(hop.getBeginLine()); // return cplan instance - return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl); + return new Pair<Hop[],CNodeTpl>(sinHops, tpl); } private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) http://git-wip-us.apache.org/repos/asf/systemml/blob/4de8d684/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java index 8a62476..42863a8 100644 --- a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java +++ b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java @@ -43,6 +43,10 @@ public class UtilFunctions public static final long ADD_PRIME1 = 99991; public static final int DIVIDE_PRIME = 1405695061; + public static int intHashCode(int key1, int key2) { + return 31 * (31 + key1) + key2; + } + public static int longHashCode(long key1) { return (int)(key1^(key1>>>32)); }
