This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 61bb660145 [SYSTEMDS-3938] Fix einsum codestyle and code quality
61bb660145 is described below
commit 61bb660145167f391c6525604a5103e6b0ddd011
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Dec 31 07:20:34 2025 +0100
[SYSTEMDS-3938] Fix einsum codestyle and code quality
---
dev/checkstyle/suppressions.xml | 13 --
.../org/apache/sysds/runtime/einsum/EOpNode.java | 28 +--
.../apache/sysds/runtime/einsum/EOpNodeBinary.java | 238 ++++++++++-----------
.../apache/sysds/runtime/einsum/EOpNodeData.java | 23 +-
.../apache/sysds/runtime/einsum/EOpNodeFuse.java | 96 +++++----
.../apache/sysds/runtime/einsum/EOpNodeUnary.java | 9 +-
.../apache/sysds/runtime/einsum/EinsumContext.java | 52 ++---
.../runtime/einsum/EinsumEquationValidator.java | 208 +++++++++---------
.../sysds/runtime/einsum/EinsumSpoofRowwise.java | 31 ++-
.../instructions/cp/EinsumCPInstruction.java | 151 ++++++-------
.../sysds/test/functions/einsum/EinsumTest.java | 82 +++----
11 files changed, 450 insertions(+), 481 deletions(-)
diff --git a/dev/checkstyle/suppressions.xml b/dev/checkstyle/suppressions.xml
index 642bf3beba..9f17ccdd8b 100644
--- a/dev/checkstyle/suppressions.xml
+++ b/dev/checkstyle/suppressions.xml
@@ -40,7 +40,6 @@
<suppress checks="AvoidStarImportCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]resource[\\/]enumeration[\\/]GridBasedEnumerator\.java$"/>
<suppress checks="AvoidStarImportCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]resource[\\/]enumeration[\\/]InterestBasedEnumerator\.java$"/>
<suppress checks="AvoidStarImportCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]resource[\\/]enumeration[\\/]PruneBasedEnumerator\.java$"/>
- <suppress checks="AvoidStarImportCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]cp[\\/]EinsumCPInstruction\.java$"/>
<suppress checks="AvoidStarImportCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]gpu[\\/]context[\\/]CSRPointer\.java$"/>
<suppress checks="AvoidStarImportCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]io[\\/]ReaderCOGParallel\.java$"/>
<suppress checks="AvoidStarImportCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]io[\\/]ReaderCOG\.java$"/>
@@ -103,8 +102,6 @@
<suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]hops[\\/]fedplanner[\\/]FederatedPlannerLogger\.java$"/>
<suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]protobuf[\\/]SysdsProtos\.java$"/>
<suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]compress[\\/]lib[\\/]CLALibBinaryCellOp\.java$"/>
- <suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EinsumContext\.java$"/>
- <suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EinsumEquationValidator\.java$"/>
<suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]gpu[\\/]DnnGPUInstruction\.java$"/>
<suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]gpu[\\/]context[\\/]CudaMemoryAllocator\.java$"/>
<suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]gpu[\\/]context[\\/]GPUContextPool\.java$"/>
@@ -130,14 +127,7 @@
<suppress checks="RegexpMultilineCheck"
files=".*src[\\/]test[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]test[\\/]functions[\\/]transform[\\/]TransformFrameEncodeBagOfWords\.java$"/>
<suppress checks="RegexpMultilineCheck"
files=".*src[\\/]test[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]test[\\/]gpu[\\/]cudaSupportFunctions[\\/]CudaCublasGeamTest\.java$"/>
<suppress checks="RegexpMultilineCheck"
files=".*src[\\/]test[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]test[\\/]gpu[\\/]cudaSupportFunctions[\\/]CudaCusparseCsrGemmTest\.java$"/>
- <suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EOpNodeBinary\.java$"/>
- <suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EOpNodeData\.java$"/>
- <suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EOpNodeFuse\.java$"/>
- <suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EOpNodeUnary\.java$"/>
- <suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EOpNode\.java$"/>
- <suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]cp[\\/]EinsumCPInstruction\.java$"/>
<suppress checks="RegexpMultilineCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]matrix[\\/]data[\\/]LibMatrixMult\.java$"/>
- <suppress checks="RegexpMultilineCheck"
files=".*src[\\/]test[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]test[\\/]functions[\\/]einsum[\\/]EinsumTest\.java$"/>
<!-- LocalVariableNameCheck -->
<suppress checks="LocalVariableNameCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]api[\\/]PythonDMLScript\.java$"/>
@@ -193,7 +183,6 @@
<suppress checks="MethodNameCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]compress[\\/]lib[\\/]CLALibRightMultBy\.java$"/>
<suppress checks="MethodNameCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]controlprogram[\\/]federated[\\/]FederatedSSLUtil\.java$"/>
<suppress checks="MethodNameCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]controlprogram[\\/]paramserv[\\/]ParamServer\.java$"/>
- <suppress checks="MethodNameCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]cp[\\/]EinsumCPInstruction\.java$"/>
<suppress checks="MethodNameCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]spark[\\/]UnaryMatrixSPInstruction\.java$"/>
<suppress checks="MethodNameCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]io[\\/]hdf5[\\/]H5\.java$"/>
<suppress checks="MethodNameCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]iogen[\\/]ReaderMappingIndex\.java$"/>
@@ -358,6 +347,4 @@
<suppress checks="MethodNameCheck"
files=".*src[\\/]test[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]test[\\/]gpu[\\/]nn[\\/]DNNOperationsGPUTest\.java$"/>
<suppress checks="MethodNameCheck"
files=".*src[\\/]test[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]test[\\/]usertest[\\/]UserInterfaceTest\.java$"/>
<suppress checks="MethodNameCheck"
files=".*src[\\/]test[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]test[\\/]usertest[\\/]pythonapi[\\/]StartupTest\.java$"/>
- <suppress checks="MethodNameCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EOpNodeBinary\.java$"/>
- <suppress checks="MethodNameCheck"
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EinsumSpoofRowwise\.java$"/>
</suppressions>
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java
b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java
index e205c2721a..665e5ae558 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java
@@ -27,26 +27,26 @@ import java.util.Arrays;
import java.util.List;
public abstract class EOpNode {
- public Character c1;
- public Character c2;
+ public Character c1;
+ public Character c2;
public Integer dim1;
public Integer dim2;
- public EOpNode(Character c1, Character c2, Integer dim1, Integer dim2) {
- this.c1 = c1;
- this.c2 = c2;
+ public EOpNode(Character c1, Character c2, Integer dim1, Integer dim2) {
+ this.c1 = c1;
+ this.c2 = c2;
this.dim1 = dim1;
this.dim2 = dim2;
- }
+ }
- public String getOutputString() {
- if(c1 == null) return "''";
- if(c2 == null) return c1.toString();
- return c1.toString() + c2.toString();
- }
+ public String getOutputString() {
+ if(c1 == null) return "''";
+ if(c2 == null) return c1.toString();
+ return c1.toString() + c2.toString();
+ }
public abstract List<EOpNode> getChildren();
public String[] recursivePrintString(){
- ArrayList<String[]> inpStrings = new ArrayList<>();
+ List<String[]> inpStrings = new ArrayList<>();
for (EOpNode node : getChildren()) {
inpStrings.add(node.recursivePrintString());
}
@@ -61,8 +61,8 @@ public abstract class EOpNode {
return res;
};
- public abstract MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs,
int numOfThreads, Log LOG);
+ public abstract MatrixBlock computeEOpNode(List<MatrixBlock> inputs,
int numOfThreads, Log LOG);
- public abstract EOpNode reorderChildrenAndOptimize(EOpNode parent,
Character outChar1, Character outChar2);
+ public abstract EOpNode reorderChildrenAndOptimize(EOpNode parent,
Character outChar1, Character outChar2);
}
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java
b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java
index d291712169..071fb6706a 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java
@@ -22,16 +22,13 @@ package org.apache.sysds.runtime.einsum;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.lang3.tuple.Triple;
import org.apache.commons.logging.Log;
-import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
import org.apache.sysds.runtime.functionobjects.Multiply;
-import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
@@ -47,14 +44,14 @@ import static
org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensur
public class EOpNodeBinary extends EOpNode {
public enum EBinaryOperand { // upper case: char remains, lower case:
summed (reduced) dimension
- ////// mm: //////
- Ba_aC, // -> BC
- aB_Ca, // -> CB
- Ba_Ca, // -> BC
- aB_aC, // -> BC
+ ////// mm: //////
+ Ba_aC, // -> BC
+ aB_Ca, // -> CB
+ Ba_Ca, // -> BC
+ aB_aC, // -> BC
- ////// element-wise multiplications and sums //////
- aB_aB,// elemwise and colsum -> B
+ ////// element-wise multiplications and sums //////
+ aB_aB,// elemwise and colsum -> B
Ab_Ab, // elemwise and rowsum ->A
Ab_bA, // elemwise, either colsum or rowsum -> A
aB_Ba,
@@ -63,23 +60,23 @@ public class EOpNodeBinary extends EOpNode {
aB_a,// -> B
Ab_b, // -> A
- ////// elementwise, no summations: //////
- A_A,// v-elemwise -> A
- AB_AB,// M-M elemwise -> AB
- AB_BA, // M-M.T elemwise -> AB
- AB_A, // M-v colwise -> BA!?
+ ////// elementwise, no summations: //////
+ A_A,// v-elemwise -> A
+ AB_AB,// M-M elemwise -> AB
+ AB_BA, // M-M.T elemwise -> AB
+ AB_A, // M-v colwise -> BA!?
AB_B, // M-v rowwise -> AB
- ////// other //////
+ ////// other //////
a_a,// dot ->
- A_B, // outer mult -> AB
- A_scalar, // v-scalar
- AB_scalar, // m-scalar
- scalar_scalar
- }
- public EOpNode left;
- public EOpNode right;
- public EBinaryOperand operand;
+ A_B, // outer mult -> AB
+ A_scalar, // v-scalar
+ AB_scalar, // m-scalar
+ scalar_scalar
+ }
+ public EOpNode left;
+ public EOpNode right;
+ public EBinaryOperand operand;
private boolean transposeResult;
public EOpNodeBinary(EOpNode left, EOpNode right, EBinaryOperand
operand){
super(null,null,null, null);
@@ -178,137 +175,124 @@ public class EOpNodeBinary extends EOpNode {
}
@Override
- public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int
numThreads, Log LOG) {
- EOpNodeBinary bin = this;
- MatrixBlock left = this.left.computeEOpNode(inputs, numThreads, LOG);
- MatrixBlock right = this.right.computeEOpNode(inputs, numThreads, LOG);
+ public MatrixBlock computeEOpNode(List<MatrixBlock> inputs, int
numThreads, Log LOG) {
+ EOpNodeBinary bin = this;
+ MatrixBlock left = this.left.computeEOpNode(inputs, numThreads,
LOG);
+ MatrixBlock right = this.right.computeEOpNode(inputs,
numThreads, LOG);
- AggregateOperator agg = new AggregateOperator(0,
Plus.getPlusFnObject());
+ //AggregateOperator agg = new AggregateOperator(0,
Plus.getPlusFnObject());
- MatrixBlock res;
+ MatrixBlock res;
- switch (bin.operand){
- case AB_AB -> {
- res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
- }
- case A_A -> {
- ensureMatrixBlockColumnVector(left);
- ensureMatrixBlockColumnVector(right);
- res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
- }
- case a_a -> {
- ensureMatrixBlockColumnVector(left);
- ensureMatrixBlockColumnVector(right);
+ switch (bin.operand){
+ case AB_AB -> {
+ res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
+ }
+ case A_A -> {
+ ensureMatrixBlockColumnVector(left);
+ ensureMatrixBlockColumnVector(right);
+ res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
+ }
+ case a_a -> {
+ ensureMatrixBlockColumnVector(left);
+ ensureMatrixBlockColumnVector(right);
res = new MatrixBlock(0.0);
res.allocateDenseBlock();
res.getDenseBlockValues()[0] =
LibMatrixMult.dotProduct(left.getDenseBlockValues(),
right.getDenseBlockValues(), 0,0 , left.getNumRows());
- }
- case Ab_Ab -> {
+ }
+ case Ab_Ab -> {
res =
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left,
right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new
ArrayList<>(),
null, numThreads);
}
- case aB_aB -> {
+ case aB_aB -> {
res =
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left,
right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new
ArrayList<>(),
null, numThreads);
- }
- case ab_ab -> {
+ }
+ case ab_ab -> {
res =
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left,
right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new
ArrayList<>(),
null, numThreads);
}
- case ab_ba -> {
+ case ab_ba -> {
res =
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left),
List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
null, numThreads);
}
- case Ab_bA -> {
+ case Ab_bA -> {
res =
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left),
List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
null, numThreads);
}
- case aB_Ba -> {
+ case aB_Ba -> {
res =
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left),
List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
null, numThreads);
}
- case AB_BA -> {
- ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
- right = right.reorgOperations(transpose, new MatrixBlock(), 0,
0, 0);
- res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
- }
- case Ba_aC -> {
- res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(),
numThreads);
- }
- case aB_Ca -> {
- res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(),
numThreads);
- }
- case Ba_Ca -> {
- ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
- right = right.reorgOperations(transpose, new MatrixBlock(), 0,
0, 0);
- res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(),
numThreads);
- }
- case aB_aC -> {
- if(false &&
LibMatrixMult.isSkinnyRightHandSide(left.getNumRows(), left.getNumColumns(),
right.getNumRows(), right.getNumColumns(), false)){
- res = new MatrixBlock(left.getNumColumns(),
right.getNumColumns(),false);
- res.allocateDenseBlock();
- double[] m1 = left.getDenseBlock().values(0);
- double[] m2 = right.getDenseBlock().values(0);
- double[] c = res.getDenseBlock().values(0);
- int alen = left.getNumColumns();
- int blen = right.getNumColumns();
- for(int i =0;i<left.getNumRows();i++){
-
LibSpoofPrimitives.vectOuterMultAdd(m1,m2,c,i*alen,i*blen, 0,alen,blen);
- }
- }else {
- ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
- left = left.reorgOperations(transpose, new MatrixBlock(),
0, 0, 0);
- res = LibMatrixMult.matrixMult(left, right, new
MatrixBlock(), numThreads);
- }
- }
- case A_scalar, AB_scalar -> {
- res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new
ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock());
- }
- case AB_B -> {
- ensureMatrixBlockRowVector(right);
- res = left.binaryOperations(new
BinaryOperator(Multiply.getMultiplyFnObject()), right);
- }
- case Ab_b -> {
+ case AB_BA -> {
+ ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+ right = right.reorgOperations(transpose, new
MatrixBlock(), 0, 0, 0);
+ res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
+ }
+ case Ba_aC -> {
+ res = LibMatrixMult.matrixMult(left,right, new
MatrixBlock(), numThreads);
+ }
+ case aB_Ca -> {
+ res = LibMatrixMult.matrixMult(right,left, new
MatrixBlock(), numThreads);
+ }
+ case Ba_Ca -> {
+ ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+ right = right.reorgOperations(transpose, new
MatrixBlock(), 0, 0, 0);
+ res = LibMatrixMult.matrixMult(left,right, new
MatrixBlock(), numThreads);
+ }
+ case aB_aC -> {
+ ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+ left = left.reorgOperations(transpose, new
MatrixBlock(), 0, 0, 0);
+ res = LibMatrixMult.matrixMult(left, right, new
MatrixBlock(), numThreads);
+ }
+ case A_scalar, AB_scalar -> {
+ res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new
ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock());
+ }
+ case AB_B -> {
+ ensureMatrixBlockRowVector(right);
+ res = left.binaryOperations(new
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+ }
+ case Ab_b -> {
res =
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left),
new ArrayList<>(), List.of(right), new ArrayList<>(), new ArrayList<>(),
null, numThreads);
}
- case AB_A -> {
- ensureMatrixBlockColumnVector(right);
- res = left.binaryOperations(new
BinaryOperator(Multiply.getMultiplyFnObject()), right);
- }
- case aB_a -> {
+ case AB_A -> {
+ ensureMatrixBlockColumnVector(right);
+ res = left.binaryOperations(new
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+ }
+ case aB_a -> {
res =
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left),
new ArrayList<>(), new ArrayList<>(), List.of(right), new ArrayList<>(),
null, numThreads);
}
- case A_B -> {
- ensureMatrixBlockColumnVector(left);
- ensureMatrixBlockRowVector(right);
- res = left.binaryOperations(new
BinaryOperator(Multiply.getMultiplyFnObject()), right);
- }
- case scalar_scalar -> {
- return new MatrixBlock(left.get(0,0)*right.get(0,0));
- }
- default -> {
- throw new IllegalArgumentException("Unexpected value: " +
bin.operand.toString());
- }
+ case A_B -> {
+ ensureMatrixBlockColumnVector(left);
+ ensureMatrixBlockRowVector(right);
+ res = left.binaryOperations(new
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+ }
+ case scalar_scalar -> {
+ return new
MatrixBlock(left.get(0,0)*right.get(0,0));
+ }
+ default -> {
+ throw new IllegalArgumentException("Unexpected
value: " + bin.operand.toString());
+ }
- }
+ }
if(transposeResult){
ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
res = res.reorgOperations(transpose, new MatrixBlock(),
0, 0, 0);
}
if(c2 == null) ensureMatrixBlockColumnVector(res);
- return res;
- }
+ return res;
+ }
- @Override
- public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character
outChar1, Character outChar2) {
- if (this.operand ==EBinaryOperand.aB_aC){
- if(this.right.c2 == outChar1) { // result is CB so Swap aB and aC
+ @Override
+ public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character
outChar1, Character outChar2) {
+ if (this.operand ==EBinaryOperand.aB_aC){
+ if(this.right.c2 == outChar1) { // result is CB so Swap
aB and aC
var tmpLeft = left; left = right; right =
tmpLeft;
- var tmpC1 = c1; c1 = c2; c2 =
tmpC1;
- var tmpDim1 = dim1; dim1 = dim2; dim2 =
tmpDim1;
- }
+ var tmpC1 = c1; c1 = c2; c2 =
tmpC1;
+ var tmpDim1 = dim1; dim1 = dim2; dim2 =
tmpDim1;
+ }
if(EinsumCPInstruction.FUSE_OUTER_MULTIPLY && left
instanceof EOpNodeFuse fuse && fuse.einsumRewriteType ==
EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB
&&
(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK || ((fuse.dim1
* fuse.dim2 *(fuse.ABs.size()+fuse.BAs.size())) + (right.dim1*right.dim2)) * 8
> 6 * 1024 * 1024)
&&
LibMatrixMult.isSkinnyRightHandSide(left.dim1, left.dim2, right.dim1,
right.dim2, false)) {
@@ -319,16 +303,16 @@ public class EOpNodeBinary extends EOpNode {
return fuse;
}
- left = left.reorderChildrenAndOptimize(this, left.c2, left.c1); //
maybe can be reordered
- if(left.c2 == right.c1) { // check if change happened:
- this.operand = EBinaryOperand.Ba_aC;
- }
+ left = left.reorderChildrenAndOptimize(this, left.c2,
left.c1); // maybe can be reordered
+ if(left.c2 == right.c1) { // check if change happened:
+ this.operand = EBinaryOperand.Ba_aC;
+ }
right = right.reorderChildrenAndOptimize(this,
right.c1, right.c2);
- }else if (this.operand ==EBinaryOperand.Ba_Ca){
+ }else if (this.operand ==EBinaryOperand.Ba_Ca){
if(this.right.c1 == outChar1) { // result is CB so Swap
Ba and Ca
var tmpLeft = left; left = right; right =
tmpLeft;
- var tmpC1 = c1; c1 = c2; c2 =
tmpC1;
- var tmpDim1 = dim1; dim1 = dim2; dim2 =
tmpDim1;
+ var tmpC1 = c1; c1 = c2; c2 =
tmpC1;
+ var tmpDim1 = dim1; dim1 = dim2; dim2 =
tmpDim1;
}
right = right.reorderChildrenAndOptimize(this,
right.c2, right.c1); // maybe can be reordered
@@ -341,10 +325,10 @@ public class EOpNodeBinary extends EOpNode {
right = right.reorderChildrenAndOptimize(this,
right.c1, right.c2);
}
return this;
- }
+ }
// used in the old approach
- public static Triple<Integer, EBinaryOperand, Pair<Character,
Character>> TryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap<Character,
Integer> charToSizeMap, HashMap<Character, Integer> charToOccurences, Character
outChar1, Character outChar2){
+ public static Triple<Integer, EBinaryOperand, Pair<Character,
Character>> tryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap<Character,
Integer> charToSizeMap, HashMap<Character, Integer> charToOccurences, Character
outChar1, Character outChar2){
Predicate<Character> cannotBeSummed = (c) ->
c == outChar1 || c == outChar2 ||
charToOccurences.get(c) > 2;
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java
b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java
index fd710d19d1..4906323f8a 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java
@@ -22,15 +22,14 @@ package org.apache.sysds.runtime.einsum;
import org.apache.commons.logging.Log;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import java.util.ArrayList;
import java.util.List;
public class EOpNodeData extends EOpNode {
- public int matrixIdx;
- public EOpNodeData(Character c1, Character c2, Integer dim1, Integer dim2,
int matrixIdx){
- super(c1,c2,dim1,dim2);
- this.matrixIdx = matrixIdx;
- }
+ public int matrixIdx;
+ public EOpNodeData(Character c1, Character c2, Integer dim1, Integer
dim2, int matrixIdx){
+ super(c1,c2,dim1,dim2);
+ this.matrixIdx = matrixIdx;
+ }
@Override
public List<EOpNode> getChildren() {
@@ -40,13 +39,13 @@ public class EOpNodeData extends EOpNode {
public String toString() {
return this.getClass().getSimpleName()+" ("+matrixIdx+")
"+getOutputString();
}
- @Override
- public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int
numOfThreads, Log LOG) {
- return inputs.get(matrixIdx);
- }
+ @Override
+ public MatrixBlock computeEOpNode(List<MatrixBlock> inputs, int
numOfThreads, Log LOG) {
+ return inputs.get(matrixIdx);
+ }
- @Override
+ @Override
public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character
outChar1, Character outChar2) {
return this;
- }
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java
b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java
index a999ce68f9..5accf93bcc 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java
@@ -38,6 +38,8 @@ import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
+import java.util.Map;
+import java.util.Set;
import java.util.function.Function;
import static
org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector;
@@ -48,20 +50,20 @@ public class EOpNodeFuse extends EOpNode {
private EOpNode scalar = null;
public enum EinsumRewriteType{
- // B -> row*vec, A -> row*scalar
- AB_BA_B_A__AB,
+ // B -> row*vec, A -> row*scalar
+ AB_BA_B_A__AB,
AB_BA_A__B,
- AB_BA_B_A__A,
- AB_BA_B_A__,
+ AB_BA_B_A__A,
+ AB_BA_B_A__,
- // scalar from row(AB).dot(B) multiplied by row(AZ)
- AB_BA_B_A_AZ__Z,
+ // scalar from row(AB).dot(B) multiplied by row(AZ)
+ AB_BA_B_A_AZ__Z,
- // AZ: last step is outer matrix multiplication using vector Z
- AB_BA_A_AZ__BZ, AB_BA_A_AZ__ZB,
- }
+ // AZ: last step is outer matrix multiplication using vector Z
+ AB_BA_A_AZ__BZ, AB_BA_A_AZ__ZB,
+ }
- public EinsumRewriteType einsumRewriteType;
+ public EinsumRewriteType einsumRewriteType;
public List<EOpNode> ABs;
public List<EOpNode> BAs;
public List<EOpNode> Bs;
@@ -78,15 +80,15 @@ public class EOpNodeFuse extends EOpNode {
if (scalar != null) all.add(scalar);
return all;
};
- private EOpNodeFuse(Character c1, Character c2, Integer dim1, Integer
dim2, EinsumRewriteType einsumRewriteType, List<EOpNode> ABs, List<EOpNode>
BAs, List<EOpNode> Bs, List<EOpNode> As, List<EOpNode> AZs) {
- super(c1,c2, dim1, dim2);
- this.einsumRewriteType = einsumRewriteType;
+ private EOpNodeFuse(Character c1, Character c2, Integer dim1, Integer
dim2, EinsumRewriteType einsumRewriteType, List<EOpNode> ABs, List<EOpNode>
BAs, List<EOpNode> Bs, List<EOpNode> As, List<EOpNode> AZs) {
+ super(c1,c2, dim1, dim2);
+ this.einsumRewriteType = einsumRewriteType;
this.ABs = ABs;
this.BAs = BAs;
this.Bs = Bs;
this.As = As;
this.AZs = AZs;
- }
+ }
public EOpNodeFuse(EinsumRewriteType einsumRewriteType, List<EOpNode>
ABs, List<EOpNode> BAs, List<EOpNode> Bs, List<EOpNode> As, List<EOpNode> AZs,
List<Pair<List<EOpNode>, List<EOpNode>>> AXsAndXs) {
super(null,null,null, null);
switch(einsumRewriteType) {
@@ -137,12 +139,13 @@ public class EOpNodeFuse extends EOpNode {
throw new
RuntimeException("EOpNodeFuse.addScalarAsIntermediate: scalar is undefined for
type "+einsumRewriteType.toString());
}
- public static List<EOpNodeFuse> findFuseOps(ArrayList<EOpNode> operands,
Character outChar1, Character outChar2,
- HashMap<Character, Integer> charToSize, HashMap<Character,
Integer> charToOccurences, ArrayList<EOpNode> ret) {
- ArrayList<EOpNodeFuse> result = new ArrayList<>();
- HashSet<String> matricesChars = new HashSet<>();
- HashMap<Character, HashSet<String>>
matricesCharsStartingWithChar = new HashMap<>();
- HashMap<String, ArrayList<EOpNode>> charsToMatrices = new
HashMap<>();
+ public static List<EOpNodeFuse> findFuseOps(List<EOpNode> operands,
Character outChar1, Character outChar2,
+ Map<Character, Integer> charToSize, Map<Character, Integer>
charToOccurences, List<EOpNode> ret)
+ {
+ List<EOpNodeFuse> result = new ArrayList<>();
+ Set<String> matricesChars = new HashSet<>();
+ Map<Character, HashSet<String>> matricesCharsStartingWithChar =
new HashMap<>();
+ Map<String, ArrayList<EOpNode>> charsToMatrices = new
HashMap<>();
for(EOpNode operand1 : operands) {
String k;
@@ -336,18 +339,19 @@ public class EOpNodeFuse extends EOpNode {
result.add(e);
}
- for(EOpNode n : operands) {
- if(!usedOperands.contains(n)){
- ret.add(n);
- } else {
+ for(EOpNode n : operands) {
+ if(!usedOperands.contains(n)){
+ ret.add(n);
+ } else {
charToOccurences.put(n.c1,
charToOccurences.get(n.c1) - 1);
if(charToOccurences.get(n.c2)!= null)
charToOccurences.put(n.c2,
charToOccurences.get(n.c2)-1);
}
- }
+ }
- return result;
- }
+ return result;
+ }
+ @SuppressWarnings("unused")
public static MatrixBlock compute(EinsumRewriteType rewriteType,
List<MatrixBlock> ABsInput, List<MatrixBlock> mbBAs, List<MatrixBlock> mbBs,
List<MatrixBlock> mbAs, List<MatrixBlock> mbAZs,
Double scalar, int numThreads){
boolean isResultAB =rewriteType ==
EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB;
@@ -358,7 +362,7 @@ public class EOpNodeFuse extends EOpNode {
boolean isResultBZ =rewriteType ==
EOpNodeFuse.EinsumRewriteType.AB_BA_A_AZ__BZ;
boolean isResultZB =rewriteType ==
EOpNodeFuse.EinsumRewriteType.AB_BA_A_AZ__ZB;
- ArrayList<MatrixBlock> mbABs = new ArrayList<>(ABsInput);
+ List<MatrixBlock> mbABs = new ArrayList<>(ABsInput);
int bSize = mbABs.get(0).getNumColumns();
int aSize = mbABs.get(0).getNumRows();
if (!mbBAs.isEmpty()) {
@@ -424,19 +428,19 @@ public class EOpNodeFuse extends EOpNode {
return out;
}
- @Override
- public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int
numThreads, Log LOG) {
+ @Override
+ public MatrixBlock computeEOpNode(List<MatrixBlock> inputs, int
numThreads, Log LOG) {
final Function<EOpNode, MatrixBlock> eOpNodeToMatrixBlock = n
-> n.computeEOpNode(inputs, numThreads, LOG);
- ArrayList<MatrixBlock> mbABs = new
ArrayList<>(ABs.stream().map(eOpNodeToMatrixBlock).toList());
+ List<MatrixBlock> mbABs = new
ArrayList<>(ABs.stream().map(eOpNodeToMatrixBlock).toList());
List<MatrixBlock> mbBAs =
BAs.stream().map(eOpNodeToMatrixBlock).toList();
List<MatrixBlock> mbBs =
Bs.stream().map(eOpNodeToMatrixBlock).toList();
List<MatrixBlock> mbAs =
As.stream().map(eOpNodeToMatrixBlock).toList();
List<MatrixBlock> mbAZs =
AZs.stream().map(eOpNodeToMatrixBlock).toList();
Double scalar = this.scalar == null ? null :
this.scalar.computeEOpNode(inputs, numThreads, LOG).get(0,0);
return EOpNodeFuse.compute(this.einsumRewriteType, mbABs,
mbBAs, mbBs, mbAs, mbAZs , scalar, numThreads);
- }
+ }
- @Override
+ @Override
public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character
outChar1, Character outChar2) {
ABs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1,
n.c2));
BAs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1,
n.c2));
@@ -444,18 +448,18 @@ public class EOpNodeFuse extends EOpNode {
Bs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1,
n.c2));
AZs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1,
n.c2));
return this;
- }
-
- private static @NotNull List<MatrixBlock>
multiplyVectorsIntoOne(List<MatrixBlock> mbs, int size) {
- MatrixBlock mb = new MatrixBlock(mbs.get(0).getNumRows(),
mbs.get(0).getNumColumns(), false);
- mb.allocateDenseBlock();
- for(int i = 1; i< mbs.size(); i++) { // multiply Bs
- if(i==1)
-
LibMatrixMult.vectMultiplyWrite(mbs.get(0).getDenseBlock().values(0),
mbs.get(1).getDenseBlock().values(0), mb.getDenseBlock().values(0),0,0,0, size);
- else
-
LibMatrixMult.vectMultiply(mbs.get(i).getDenseBlock().values(0),mb.getDenseBlock().values(0),0,0,
size);
- }
- return List.of(mb);
- }
+ }
+
+ private static @NotNull List<MatrixBlock>
multiplyVectorsIntoOne(List<MatrixBlock> mbs, int size) {
+ MatrixBlock mb = new MatrixBlock(mbs.get(0).getNumRows(),
mbs.get(0).getNumColumns(), false);
+ mb.allocateDenseBlock();
+ for(int i = 1; i< mbs.size(); i++) { // multiply Bs
+ if(i==1)
+
LibMatrixMult.vectMultiplyWrite(mbs.get(0).getDenseBlock().values(0),
mbs.get(1).getDenseBlock().values(0), mb.getDenseBlock().values(0),0,0,0, size);
+ else
+
LibMatrixMult.vectMultiply(mbs.get(i).getDenseBlock().values(0),mb.getDenseBlock().values(0),0,0,
size);
+ }
+ return List.of(mb);
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java
b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java
index e46e7ec104..918f1dd3b2 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java
@@ -33,7 +33,6 @@ import
org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
-import java.util.ArrayList;
import java.util.List;
public class EOpNodeUnary extends EOpNode {
@@ -47,7 +46,7 @@ public class EOpNodeUnary extends EOpNode {
super(c1, c2, dim1, dim2);
this.child = child;
this.eUnaryOperand = eUnaryOperand;
- }
+ }
@Override
public List<EOpNode> getChildren() {
@@ -58,8 +57,8 @@ public class EOpNodeUnary extends EOpNode {
return this.getClass().getSimpleName()+"
("+eUnaryOperand.toString()+") "+this.getOutputString();
}
- @Override
- public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int
numOfThreads, Log LOG) {
+ @Override
+ public MatrixBlock computeEOpNode(List<MatrixBlock> inputs, int
numOfThreads, Log LOG) {
MatrixBlock mb = child.computeEOpNode(inputs, numOfThreads,
LOG);
return switch(eUnaryOperand) {
case DIAG->{
@@ -97,7 +96,7 @@ public class EOpNodeUnary extends EOpNode {
};
}
- @Override
+ @Override
public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character
outChar1, Character outChar2) {
return this;
}
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
index 55692d0109..16c67e5399 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
@@ -24,31 +24,33 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
public class EinsumContext {
- public Integer outRows;
- public Integer outCols;
- public Character outChar1;
- public Character outChar2;
- public HashMap<Character, Integer> charToDimensionSize;
- public String equationString;
- public ArrayList<String> newEquationStringInputsSplit;
- public HashMap<Character, Integer> characterAppearanceCount;
-
- private EinsumContext(){};
- public static EinsumContext getEinsumContext(String eqStr,
ArrayList<MatrixBlock> inputs){
- EinsumContext res = new EinsumContext();
-
- res.equationString = eqStr;
- HashMap<Character, Integer> charToDimensionSize = new
HashMap<>();
- HashMap<Character, Integer> characterAppearanceCount = new HashMap<>();
- ArrayList<String> newEquationStringSplit = new ArrayList<>();
+ public Integer outRows;
+ public Integer outCols;
+ public Character outChar1;
+ public Character outChar2;
+ public Map<Character, Integer> charToDimensionSize;
+ public String equationString;
+ public List<String> newEquationStringInputsSplit;
+ public Map<Character, Integer> characterAppearanceCount;
+
+ private EinsumContext(){};
+ public static EinsumContext getEinsumContext(String eqStr,
List<MatrixBlock> inputs){
+ EinsumContext res = new EinsumContext();
+
+ res.equationString = eqStr;
+ Map<Character, Integer> charToDimensionSize = new HashMap<>();
+ Map<Character, Integer> characterAppearanceCount = new
HashMap<>();
+ List<String> newEquationStringSplit = new ArrayList<>();
Character outChar1 = null;
Character outChar2 = null;
- Iterator<MatrixBlock> it = inputs.iterator();
- MatrixBlock curArr = it.next();
- int i = 0;
+ Iterator<MatrixBlock> it = inputs.iterator();
+ MatrixBlock curArr = it.next();
+ int i = 0;
char c = eqStr.charAt(i);
for(i = 0; i < eqStr.length(); i++) {
@@ -110,11 +112,11 @@ public class EinsumContext {
res.outRows=(outChar1 == null ? 1 :
charToDimensionSize.get(outChar1));
res.outCols=(outChar2 == null ? 1 :
charToDimensionSize.get(outChar2));
- res.outChar1 = outChar1;
- res.outChar2 = outChar2;
- res.newEquationStringInputsSplit = newEquationStringSplit;
+ res.outChar1 = outChar1;
+ res.outChar2 = outChar2;
+ res.newEquationStringInputsSplit = newEquationStringSplit;
res.characterAppearanceCount = characterAppearanceCount;
res.charToDimensionSize = charToDimensionSize;
- return res;
- }
+ return res;
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
index 7fdce50d3b..417a1b760b 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
@@ -32,117 +32,117 @@ import java.util.List;
public class EinsumEquationValidator {
- public static <HopOrIdentifier extends ParseInfo> Triple<Long, Long,
Types.DataType> validateEinsumEquationAndReturnDimensions(String
equationString, List<HopOrIdentifier> expressionsOrIdentifiers) throws
LanguageException {
- String[] eqStringParts = equationString.split("->"); // length 2 if
"...->..." , length 1 if "...->"
- boolean isResultScalar = eqStringParts.length == 1;
+ public static <HopOrIdentifier extends ParseInfo> Triple<Long, Long,
Types.DataType> validateEinsumEquationAndReturnDimensions(String
equationString, List<HopOrIdentifier> expressionsOrIdentifiers) throws
LanguageException {
+ String[] eqStringParts = equationString.split("->"); // length
2 if "...->..." , length 1 if "...->"
+ boolean isResultScalar = eqStringParts.length == 1;
- if(expressionsOrIdentifiers == null)
- throw new RuntimeException("Einsum: called
validateEinsumAndReturnDimensions with null list");
+ if(expressionsOrIdentifiers == null)
+ throw new RuntimeException("Einsum: called
validateEinsumAndReturnDimensions with null list");
- HashMap<Character, Long> charToDimensionSize = new HashMap<>();
- Iterator<HopOrIdentifier> it = expressionsOrIdentifiers.iterator();
- HopOrIdentifier currArr = it.next();
- int arrSizeIterator = 0;
- int numberOfMatrices = 1;
- for (int i = 0; i < eqStringParts[0].length(); i++) {
- char c = equationString.charAt(i);
- if(c==' ') continue;
- if(c==','){
- if(!it.hasNext())
- throw new LanguageException("Einsum: Provided less
operands than specified in equation str");
- currArr = it.next();
- arrSizeIterator = 0;
- numberOfMatrices++;
- } else{
- long thisCharDimension = getThisCharDimension(currArr,
arrSizeIterator);
- if (charToDimensionSize.containsKey(c)){
- if (charToDimensionSize.get(c) != thisCharDimension)
- throw new LanguageException("Einsum: Character '" + c
+ "' expected to be dim " + charToDimensionSize.get(c) + ", but found " +
thisCharDimension);
- }else{
- charToDimensionSize.put(c, thisCharDimension);
- }
- arrSizeIterator++;
- }
- }
- if (expressionsOrIdentifiers.size() - 1 > numberOfMatrices)
- throw new LanguageException("Einsum: Provided more operands than
specified in equation str");
+ HashMap<Character, Long> charToDimensionSize = new HashMap<>();
+ Iterator<HopOrIdentifier> it =
expressionsOrIdentifiers.iterator();
+ HopOrIdentifier currArr = it.next();
+ int arrSizeIterator = 0;
+ int numberOfMatrices = 1;
+ for (int i = 0; i < eqStringParts[0].length(); i++) {
+ char c = equationString.charAt(i);
+ if(c==' ') continue;
+ if(c==','){
+ if(!it.hasNext())
+ throw new LanguageException("Einsum:
Provided less operands than specified in equation str");
+ currArr = it.next();
+ arrSizeIterator = 0;
+ numberOfMatrices++;
+ } else{
+ long thisCharDimension =
getThisCharDimension(currArr, arrSizeIterator);
+ if (charToDimensionSize.containsKey(c)){
+ if (charToDimensionSize.get(c) !=
thisCharDimension)
+ throw new
LanguageException("Einsum: Character '" + c + "' expected to be dim " +
charToDimensionSize.get(c) + ", but found " + thisCharDimension);
+ }else{
+ charToDimensionSize.put(c,
thisCharDimension);
+ }
+ arrSizeIterator++;
+ }
+ }
+ if (expressionsOrIdentifiers.size() - 1 > numberOfMatrices)
+ throw new LanguageException("Einsum: Provided more
operands than specified in equation str");
- if (isResultScalar)
- return Triple.of(-1l,-1l, Types.DataType.SCALAR);
+ if (isResultScalar)
+ return Triple.of(-1l,-1l, Types.DataType.SCALAR);
- int numberOfOutDimensions = 0;
- Character dim1Char = null;
- long dim1 = 1;
- long dim2 = 1;
- for (int i = 0; i < eqStringParts[1].length(); i++) {
- char c = eqStringParts[1].charAt(i);
- if (c == ' ') continue;
- if (numberOfOutDimensions == 0) {
- dim1Char = c;
- if(!charToDimensionSize.containsKey(c))
- throw new LanguageException("Einsum: Output dimension
'"+c+"' not present in input operands");
- dim1 = charToDimensionSize.get(c);
- } else {
- if(c==dim1Char) throw new LanguageException("Einsum: output
character "+c+" provided multiple times");
- if(!charToDimensionSize.containsKey(c))
- throw new LanguageException("Einsum: Output dimension
'"+c+"' not present in input operands");
- dim2 = charToDimensionSize.get(c);
- }
- numberOfOutDimensions++;
- }
- if (numberOfOutDimensions > 2) {
- throw new LanguageException("Einsum: output matrices with with no.
dims > 2 not supported");
- } else {
- return Triple.of(dim1, dim2, Types.DataType.MATRIX);
- }
- }
+ int numberOfOutDimensions = 0;
+ Character dim1Char = null;
+ long dim1 = 1;
+ long dim2 = 1;
+ for (int i = 0; i < eqStringParts[1].length(); i++) {
+ char c = eqStringParts[1].charAt(i);
+ if (c == ' ') continue;
+ if (numberOfOutDimensions == 0) {
+ dim1Char = c;
+ if(!charToDimensionSize.containsKey(c))
+ throw new LanguageException("Einsum:
Output dimension '"+c+"' not present in input operands");
+ dim1 = charToDimensionSize.get(c);
+ } else {
+ if(c==dim1Char) throw new
LanguageException("Einsum: output character "+c+" provided multiple times");
+ if(!charToDimensionSize.containsKey(c))
+ throw new LanguageException("Einsum:
Output dimension '"+c+"' not present in input operands");
+ dim2 = charToDimensionSize.get(c);
+ }
+ numberOfOutDimensions++;
+ }
+ if (numberOfOutDimensions > 2) {
+ throw new LanguageException("Einsum: output matrices
with with no. dims > 2 not supported");
+ } else {
+ return Triple.of(dim1, dim2, Types.DataType.MATRIX);
+ }
+ }
- public static Types.DataType validateEinsumEquationNoDimensions(String
equationString, int numberOfMatrixInputs) throws LanguageException {
- String[] eqStringParts = equationString.split("->"); // length 2 if
"...->..." , length 1 if "...->"
- boolean isResultScalar = eqStringParts.length == 1;
+ public static Types.DataType validateEinsumEquationNoDimensions(String
equationString, int numberOfMatrixInputs) throws LanguageException {
+ String[] eqStringParts = equationString.split("->"); // length
2 if "...->..." , length 1 if "...->"
+ boolean isResultScalar = eqStringParts.length == 1;
- int numberOfMatrices = 1;
- for (int i = 0; i < eqStringParts[0].length(); i++) {
- char c = eqStringParts[0].charAt(i);
- if(c == ' ') continue;
- if(c == ',')
- numberOfMatrices++;
- }
- if(numberOfMatrixInputs != numberOfMatrices){
- throw new LanguageException("Einsum: Invalid number of
parameters, given: " + numberOfMatrixInputs + ", expected: " +
numberOfMatrices);
- }
+ int numberOfMatrices = 1;
+ for (int i = 0; i < eqStringParts[0].length(); i++) {
+ char c = eqStringParts[0].charAt(i);
+ if(c == ' ') continue;
+ if(c == ',')
+ numberOfMatrices++;
+ }
+ if(numberOfMatrixInputs != numberOfMatrices){
+ throw new LanguageException("Einsum: Invalid number of
parameters, given: " + numberOfMatrixInputs + ", expected: " +
numberOfMatrices);
+ }
- if(isResultScalar){
- return Types.DataType.SCALAR;
- }else {
- int numberOfDimensions = 0;
- Character dim1Char = null;
- for (int i = 0; i < eqStringParts[1].length(); i++) {
- char c = eqStringParts[i].charAt(i);
- if(c == ' ') continue;
- numberOfDimensions++;
- if (numberOfDimensions == 1 && c == dim1Char)
- throw new LanguageException("Einsum: output character
"+c+" provided multiple times");
- dim1Char = c;
- }
+ if(isResultScalar){
+ return Types.DataType.SCALAR;
+ }else {
+ int numberOfDimensions = 0;
+ Character dim1Char = null;
+ for (int i = 0; i < eqStringParts[1].length(); i++) {
+ char c = eqStringParts[i].charAt(i);
+ if(c == ' ') continue;
+ numberOfDimensions++;
+ if (numberOfDimensions == 1 && c == dim1Char)
+ throw new LanguageException("Einsum:
output character "+c+" provided multiple times");
+ dim1Char = c;
+ }
- if (numberOfDimensions > 2) {
- throw new LanguageException("Einsum: output matrices with with
no. dims > 2 not supported");
- } else {
- return Types.DataType.MATRIX;
- }
- }
- }
+ if (numberOfDimensions > 2) {
+ throw new LanguageException("Einsum: output
matrices with with no. dims > 2 not supported");
+ } else {
+ return Types.DataType.MATRIX;
+ }
+ }
+ }
- private static <HopOrIdentifier extends ParseInfo> long
getThisCharDimension(HopOrIdentifier currArr, int arrSizeIterator) {
- long thisCharDimension;
- if(currArr instanceof Hop){
- thisCharDimension = arrSizeIterator == 0 ? ((Hop)
currArr).getDim1() : ((Hop) currArr).getDim2();
- } else if(currArr instanceof Identifier){
- thisCharDimension = arrSizeIterator == 0 ? ((Identifier)
currArr).getDim1() : ((Identifier) currArr).getDim2();
- } else {
- throw new RuntimeException("validateEinsumAndReturnDimensions
called with expressions that are not Hop or Identifier: "+ currArr == null ?
"null" : currArr.getClass().toString());
- }
- return thisCharDimension;
- }
+ private static <HopOrIdentifier extends ParseInfo> long
getThisCharDimension(HopOrIdentifier currArr, int arrSizeIterator) {
+ long thisCharDimension;
+ if(currArr instanceof Hop){
+ thisCharDimension = arrSizeIterator == 0 ? ((Hop)
currArr).getDim1() : ((Hop) currArr).getDim2();
+ } else if(currArr instanceof Identifier){
+ thisCharDimension = arrSizeIterator == 0 ?
((Identifier) currArr).getDim1() : ((Identifier) currArr).getDim2();
+ } else {
+ throw new
RuntimeException("validateEinsumAndReturnDimensions called with expressions
that are not Hop or Identifier: "+ currArr == null ? "null" :
currArr.getClass().toString());
+ }
+ return thisCharDimension;
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java
b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java
index 8b3c5544e6..40b73ea399 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java
@@ -27,6 +27,8 @@ import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
public final class EinsumSpoofRowwise extends SpoofRowwise {
+ private static final long serialVersionUID = -5957679254041639561L;
+
private final int _ABCount;
private final boolean _Bsupplied;
private final int _ACount;
@@ -51,24 +53,23 @@ public final class EinsumSpoofRowwise extends SpoofRowwise {
int rix) {
switch(_EinsumRewriteType) {
case AB_BA_B_A__AB -> {
- genexec_AB(a, ai, b, null, c, ci, len, grix,
rix);
+ genexecAB(a, ai, b, null, c, ci, len, grix,
rix);
if(scalars.length != 0) {
LibMatrixMult.vectMultiplyWrite(scalars[0], c, c, ci, ci, len); }
}
case AB_BA_A__B -> {
- genexec_B(a, ai, b, null, c, ci, len, grix,
rix);
+ genexecB(a, ai, b, null, c, ci, len, grix, rix);
}
case AB_BA_B_A__A -> {
- //
HARDCODEDgenexec_A_or_(a,ai,b,scalars,c,ci,len,grix,rix);
- genexec_A_or_(a, ai, b, null, c, ci, len, grix,
rix);
+ genexecAor(a, ai, b, null, c, ci, len, grix,
rix);
if(scalars.length != 0) { c[rix] *= scalars[0];
}
}
case AB_BA_B_A__ -> {
- genexec_A_or_(a, ai, b, null, c, ci, len, grix,
rix);
+ genexecAor(a, ai, b, null, c, ci, len, grix,
rix);
if(scalars.length != 0) { c[0] *= scalars[0]; }
}
case AB_BA_B_A_AZ__Z -> {
double[] temp = {0};
- genexec_A_or_(a, ai, b, null, temp, 0, len,
grix, rix);
+ genexecAor(a, ai, b, null, temp, 0, len, grix,
rix);
if(scalars.length != 0) { temp[0] *=
scalars[0]; }
if(_AZCount > 1) {
double[] temp2 = new double[_ZSize];
@@ -85,7 +86,7 @@ public final class EinsumSpoofRowwise extends SpoofRowwise {
}
case AB_BA_A_AZ__BZ -> {
double[] temp = new double[len];
- genexec_B(a, ai, b, null, temp, 0, len, grix,
rix);
+ genexecB(a, ai, b, null, temp, 0, len, grix,
rix);
if(scalars.length != 0) {
LibMatrixMult.vectMultiplyWrite(scalars[0], temp, temp, 0, 0, len);
}
@@ -104,7 +105,7 @@ public final class EinsumSpoofRowwise extends SpoofRowwise {
}
case AB_BA_A_AZ__ZB -> {
double[] temp = new double[len];
- genexec_B(a, ai, b, null, temp, 0, len, grix,
rix);
+ genexecB(a, ai, b, null, temp, 0, len, grix,
rix);
if(scalars.length != 0) {
LibMatrixMult.vectMultiplyWrite(scalars[0], temp, temp, 0, 0, len);
}
@@ -125,7 +126,7 @@ public final class EinsumSpoofRowwise extends SpoofRowwise {
}
}
- private void genexec_AB(double[] a, int ai, SideInput[] b, double[]
scalars, double[] c, int ci, int len, long grix,
+ private void genexecAB(double[] a, int ai, SideInput[] b, double[]
scalars, double[] c, int ci, int len, long grix,
int rix) {
int bi = 0;
double[] TMP1 = null;
@@ -164,7 +165,7 @@ public final class EinsumSpoofRowwise extends SpoofRowwise {
}
}
- private void genexec_B(double[] a, int ai, SideInput[] b, double[]
scalars, double[] c, int ci, int len, long grix,
+ private void genexecB(double[] a, int ai, SideInput[] b, double[]
scalars, double[] c, int ci, int len, long grix,
int rix) {
int bi = 0;
double[] TMP1 = null;
@@ -190,7 +191,7 @@ public final class EinsumSpoofRowwise extends SpoofRowwise {
}
}
- private void genexec_A_or_(double[] a, int ai, SideInput[] b, double[]
scalars, double[] c, int ci, int len, long grix, int rix) {
+ private void genexecAor(double[] a, int ai, SideInput[] b, double[]
scalars, double[] c, int ci, int len, long grix, int rix) {
int bi = 0;
double[] TMP1 = null;
double TMP2 = 0;
@@ -217,14 +218,6 @@ public final class EinsumSpoofRowwise extends SpoofRowwise
{
else c[0] += TMP2;
}
- private void HARDCODEDgenexec_A_or_(double[] a, int ai, SideInput[] b,
double[] scalars, double[] c, int ci,
- int len, long grix, int rix) {
- double[] TMP1 = LibSpoofPrimitives.vectMultWrite(a,
b[0].values(rix), ai, ai, len);
- double TMP2 = LibSpoofPrimitives.dotProduct(TMP1,
b[1].values(0), 0, ai, len);
- TMP2 *= b[2].values(0)[rix];
- c[rix] = TMP2;
- }
-
protected void genexec(double[] avals, int[] aix, int ai, SideInput[]
b, double[] scalars, double[] c, int ci,
int alen, int len, long grix, int rix) {
throw new RuntimeException("Sparse fused einsum not
implemented");
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
index b8af61d35b..2b1074c80c 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
@@ -31,24 +31,40 @@ import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeCell;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
-import org.apache.sysds.runtime.codegen.*;
+import org.apache.sysds.runtime.codegen.CodegenUtils;
+import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.einsum.*;
+import org.apache.sysds.runtime.einsum.EOpNode;
+import org.apache.sysds.runtime.einsum.EOpNodeBinary;
import org.apache.sysds.runtime.einsum.EOpNodeBinary.EBinaryOperand;
-import org.apache.sysds.runtime.functionobjects.*;
+import org.apache.sysds.runtime.einsum.EOpNodeData;
+import org.apache.sysds.runtime.einsum.EOpNodeFuse;
+import org.apache.sysds.runtime.einsum.EOpNodeUnary;
+import org.apache.sysds.runtime.einsum.EinsumContext;
+import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.utils.Explain;
-import java.util.*;
import static org.apache.sysds.api.DMLScript.EXPLAIN;
import static
org.apache.sysds.hops.rewrite.RewriteMatrixMultChainOptimization.mmChainDP;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
- public static final boolean FORCE_CELL_TPL = false; // naive looped
solution
+ public static final boolean FORCE_CELL_TPL = false; // naive looped
solution
public static final boolean FUSE_OUTER_MULTIPLY = true;
public static final boolean FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK
= true;
@@ -63,7 +79,7 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
public EinsumCPInstruction(Operator op, String opcode, String istr,
CPOperand out, CPOperand... inputs)
{
super(op, opcode, istr, out, inputs);
- _numThreads = OptimizerUtils.getConstrainedNumThreads(-1)/2;
+ _numThreads = OptimizerUtils.getConstrainedNumThreads(-1)/2;
_in = inputs;
this.eqStr = inputs[0].getName();
if (PRINT_TRACE)
@@ -80,9 +96,9 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
if(mb instanceof CompressedMatrixBlock){
mb = ((CompressedMatrixBlock)
mb).getUncompressed("Spoof instruction");
}
- if(mb.getNumRows() == 1){
- ensureMatrixBlockColumnVector(mb);
- }
+ if(mb.getNumRows() == 1){
+ ensureMatrixBlockColumnVector(mb);
+ }
inputs.add(mb);
}
}
@@ -93,11 +109,11 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
if( LOG.isTraceEnabled() ) LOG.trace("output: "+resultString +"
"+einc.outRows+"x"+einc.outCols);
- ArrayList<String> inputsChars =
einc.newEquationStringInputsSplit;
+ List<String> inputsChars = einc.newEquationStringInputsSplit;
if(LOG.isTraceEnabled())
LOG.trace(String.join(",",einc.newEquationStringInputsSplit));
- ArrayList<EOpNode> eOpNodes = new
ArrayList<>(inputsChars.size());
- ArrayList<EOpNode> eOpNodesScalars = new
ArrayList<>(inputsChars.size()); // computed separately and not included into
plan until it is already created
+ List<EOpNode> eOpNodes = new ArrayList<>(inputsChars.size());
+ List<EOpNode> eOpNodesScalars = new
ArrayList<>(inputsChars.size()); // computed separately and not included into
plan until it is already created
//make all vetors col vectors
for(int i = 0; i < inputs.size(); i++){
@@ -106,7 +122,7 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
addSumDimensionsDiagonalsAndScalars(einc, inputsChars,
eOpNodes, eOpNodesScalars, einc.charToDimensionSize);
- HashMap<Character, Integer> characterToOccurences =
einc.characterAppearanceCount;
+ Map<Character, Integer> characterToOccurences =
einc.characterAppearanceCount;
for (int i = 0; i < inputsChars.size(); i++) {
if (inputsChars.get(i) == null) continue;
@@ -118,37 +134,16 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
eOpNodes.add(n);
}
- ArrayList<EOpNode> ret = new ArrayList<>();
+ List<EOpNode> ret = new ArrayList<>();
addVectorMultiplies(eOpNodes,
eOpNodesScalars,characterToOccurences, einc.outChar1, einc.outChar2, ret);
eOpNodes = ret;
List<EOpNode> plan;
ArrayList<MatrixBlock> remainingMatrices;
- if(!FORCE_CELL_TPL) {
- if(true){
- plan = generateGreedyPlan(eOpNodes,
eOpNodesScalars, einc.charToDimensionSize, characterToOccurences,
einc.outChar1, einc.outChar2);
- }else { // old way: try to do fusion first and then
rest in binary fashion cost based
- List<EOpNodeFuse> fuseOps;
- do {
- ret = new ArrayList<>();
- fuseOps =
EOpNodeFuse.findFuseOps(eOpNodes, einc.outChar1, einc.outChar2,
einc.charToDimensionSize, characterToOccurences, ret);
-
- if(!fuseOps.isEmpty()) {
- for (EOpNodeFuse fuseOp :
fuseOps) {
- if (fuseOp.c1 == null) {
-
eOpNodesScalars.add(fuseOp);
- continue;
- }
- ret.add(fuseOp);
- }
- eOpNodes = ret;
- }
- } while(eOpNodes.size() > 1 &&
!fuseOps.isEmpty());
- Pair<Integer, List<EOpNode>> costAndPlan =
generateBinaryPlanCostBased(0, eOpNodes, einc.charToDimensionSize,
characterToOccurences,
- einc.outChar1, einc.outChar2);
- plan = costAndPlan.getRight();
- }
+ if(!FORCE_CELL_TPL) {
+ plan = generateGreedyPlan(eOpNodes, eOpNodesScalars,
+ einc.charToDimensionSize,
characterToOccurences, einc.outChar1, einc.outChar2);
if(!eOpNodesScalars.isEmpty()){
EOpNode l = eOpNodesScalars.get(0);
for(int i = 1; i < eOpNodesScalars.size(); i++){
@@ -194,7 +189,7 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
}
remainingMatrices = executePlan(plan, inputs);
- }else{
+ }else{
plan = eOpNodes;
remainingMatrices = inputs;
}
@@ -208,7 +203,7 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
ec.setMatrixOutput(output.getName(),
remainingMatrices.get(0));
}
else if(resNode.c1 == einc.outChar2 &&
resNode.c2 == einc.outChar1){
- if( LOG.isTraceEnabled()) LOG.trace("Transposing the final
result");
+ if( LOG.isTraceEnabled())
LOG.trace("Transposing the final result");
ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
MatrixBlock resM =
remainingMatrices.get(0).reorgOperations(transpose, new MatrixBlock(),0,0,0);
@@ -233,7 +228,7 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
}else{
// use cell template with loops for remaining
ArrayList<MatrixBlock> mbs = remainingMatrices;
- ArrayList<String> chars = new ArrayList<>();
+ List<String> chars = new ArrayList<>();
for (int i = 0; i < plan.size(); i++) {
String s;
@@ -243,7 +238,7 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
chars.add(s);
}
- ArrayList<Character> summingChars = new ArrayList<>();
+ List<Character> summingChars = new ArrayList<>();
for (Character c : characterToOccurences.keySet()) {
if (c != einc.outChar1 && c != einc.outChar2)
summingChars.add(c);
}
@@ -281,7 +276,7 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
return new
EOpNodeBinary(addToNode,scalar,EBinaryOperand.AB_scalar);
}
- private static Pair<Integer, EOpNode>
addScalarToPlanFindMinCost(EOpNode plan, HashMap<Character, Integer>
charToSizeMap) {
+ private static Pair<Integer, EOpNode>
addScalarToPlanFindMinCost(EOpNode plan, Map<Character, Integer> charToSizeMap)
{
int thisSize = 0;
if(plan.c1 != null) thisSize += charToSizeMap.get(plan.c1);
if(plan.c2 != null) thisSize += charToSizeMap.get(plan.c2);
@@ -315,8 +310,10 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
return Pair.of(cost, plan);
}
- private static void addVectorMultiplies(ArrayList<EOpNode> eOpNodes,
ArrayList<EOpNode> eOpNodesScalars,HashMap<Character, Integer>
charToOccurences, Character outChar1, Character outChar2,ArrayList<EOpNode>
ret) {
- HashMap<Character, ArrayList<EOpNode>> vectorCharacterToIndices
= new HashMap<>();
+ private static void addVectorMultiplies(List<EOpNode> eOpNodes,
List<EOpNode> eOpNodesScalars,
+ Map<Character, Integer> charToOccurences, Character outChar1,
Character outChar2, List<EOpNode> ret)
+ {
+ Map<Character, List<EOpNode>> vectorCharacterToIndices = new
HashMap<>();
for (int i = 0; i < eOpNodes.size(); i++) {
if (eOpNodes.get(i).c2 == null) {
if
(vectorCharacterToIndices.containsKey(eOpNodes.get(i).c1))
@@ -325,9 +322,9 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
vectorCharacterToIndices.put(eOpNodes.get(i).c1, new
ArrayList<>(Collections.singletonList(eOpNodes.get(i))));
}
}
- HashSet<EOpNode> usedNodes = new HashSet<>();
+ Set<EOpNode> usedNodes = new HashSet<>();
for(Character c : vectorCharacterToIndices.keySet()){
- ArrayList<EOpNode> nodes =
vectorCharacterToIndices.get(c);
+ List<EOpNode> nodes = vectorCharacterToIndices.get(c);
if(nodes.size()==1) continue;
EOpNode left = nodes.get(0);
@@ -358,9 +355,9 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
}
}
- private void addSumDimensionsDiagonalsAndScalars(EinsumContext einc,
ArrayList<String> inputStrings,
- ArrayList<EOpNode> eOpNodes, ArrayList<EOpNode> eOpNodesScalars,
- HashMap<Character, Integer> charToDimensionSize) {
+ private void addSumDimensionsDiagonalsAndScalars(EinsumContext einc,
List<String> inputStrings,
+ List<EOpNode> eOpNodes, List<EOpNode> eOpNodesScalars,
Map<Character, Integer> charToDimensionSize)
+ {
for(int i = 0; i< inputStrings.size(); i++){
String s = inputStrings.get(i);
if (s.isEmpty()){
@@ -414,9 +411,10 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
}
}
- private static List<EOpNode> generateGreedyPlan(ArrayList<EOpNode>
eOpNodes,
- ArrayList<EOpNode> eOpNodesScalars, HashMap<Character, Integer>
charToSizeMap, HashMap<Character, Integer> charToOccurences, Character
outChar1, Character outChar2) {
- ArrayList<EOpNode> ret;
+ private static List<EOpNode> generateGreedyPlan(List<EOpNode> eOpNodes,
List<EOpNode> eOpNodesScalars,
+ Map<Character, Integer> charToSizeMap, Map<Character, Integer>
charToOccurences, Character outChar1, Character outChar2)
+ {
+ List<EOpNode> ret;
int lastNumOfOperands = -1;
while(lastNumOfOperands != eOpNodes.size() && eOpNodes.size() >
1){
lastNumOfOperands = eOpNodes.size();
@@ -443,7 +441,7 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
eOpNodes = ret;
ret = new ArrayList<>();
- ArrayList<List<EOpNode>> matrixMultiplies =
findMatrixMultiplicationChains(eOpNodes, outChar1, outChar2, charToOccurences,
+ List<List<EOpNode>> matrixMultiplies =
findMatrixMultiplicationChains(eOpNodes, outChar1, outChar2, charToOccurences,
ret);
for(List<EOpNode> list : matrixMultiplies) {
@@ -456,7 +454,7 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
return eOpNodes;
}
- private static void reverseMMChainIfBeneficial(ArrayList<EOpNode>
mmChain){ // possibly check the cost instead of number of transposes
+ private static void reverseMMChainIfBeneficial(List<EOpNode> mmChain){
// possibly check the cost instead of number of transposes
char c1 = mmChain.get(0).c1;
char c2 = mmChain.get(0).c2;
int noTransposes = 0;
@@ -466,7 +464,6 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
}
noTransposes++;
if(c2 == mmChain.get(i).c2){
- c1 = c1;
c2 = mmChain.get(i).c1;
}
else if(c1 == mmChain.get(i).c1) {
@@ -481,10 +478,10 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
Collections.reverse(mmChain);
}
}
- private static EOpNodeBinary optimizeMMChain(List<EOpNode> mmChainL,
HashMap<Character, Integer> charToSizeMap) {
- ArrayList<EOpNode> mmChain = new ArrayList<>(mmChainL);
+ private static EOpNodeBinary optimizeMMChain(List<EOpNode> mmChainL,
Map<Character, Integer> charToSizeMap) {
+ List<EOpNode> mmChain = new ArrayList<>(mmChainL);
reverseMMChainIfBeneficial(mmChain);
- ArrayList<Pair<Integer, Integer>> dimensions = new
ArrayList<>();
+ List<Pair<Integer, Integer>> dimensions = new ArrayList<>();
for(int i = 0; i < mmChain.size()-1; i++){
EOpNode n1 = mmChain.get(i);
@@ -516,7 +513,7 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
return EOpNodeBinary.combineMatrixMultiply(left, right);
}
- private static void getDimsArray( ArrayList<Pair<Integer, Integer>>
chain, double[] dimsArray )
+ private static void getDimsArray( List<Pair<Integer, Integer>> chain,
double[] dimsArray )
{
for( int i = 0; i < chain.size(); i++ ) {
if (i == 0) {
@@ -539,11 +536,12 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
}
}
}
- private static ArrayList<List<EOpNode>>
findMatrixMultiplicationChains(ArrayList<EOpNode> inpOperands, Character
outChar1, Character outChar2, HashMap<Character, Integer> charToOccurences,
- ArrayList<EOpNode> ret) {
- HashSet<Character> charactersThatCanBeContracted = new
HashSet<>();
- HashMap<Character, ArrayList<EOpNode>> characterToNodes = new
HashMap<>();
- ArrayList<EOpNode> operandsTodo = new ArrayList<>();
+ private static List<List<EOpNode>>
findMatrixMultiplicationChains(List<EOpNode> inpOperands,
+ Character outChar1, Character outChar2, Map<Character, Integer>
charToOccurences, List<EOpNode> ret)
+ {
+ Set<Character> charactersThatCanBeContracted = new HashSet<>();
+ Map<Character, ArrayList<EOpNode>> characterToNodes = new
HashMap<>();
+ List<EOpNode> operandsTodo = new ArrayList<>();
for(EOpNode op : inpOperands) {
if(op.c2 == null || op.c1 == null) continue;
@@ -563,9 +561,9 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
}
if (todo) operandsTodo.add(op);
}
- ArrayList<List<EOpNode>> res = new ArrayList<>();
+ List<List<EOpNode>> res = new ArrayList<>();
- HashSet<EOpNode> doneNodes = new HashSet<>();
+ Set<EOpNode> doneNodes = new HashSet<>();
for(int i = 0; i < operandsTodo.size(); i++){
EOpNode iterateNode = operandsTodo.get(i);
@@ -624,6 +622,7 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
}
// old way, DFS finds all paths
+ @SuppressWarnings("unused")
private Pair<Integer, List<EOpNode>> generateBinaryPlanCostBased(int
cost, ArrayList<EOpNode> operands, HashMap<Character, Integer> charToSizeMap,
HashMap<Character, Integer> charToOccurences, Character outChar1, Character
outChar2) {
Integer minCost = cost;
List<EOpNode> minNodes = operands;
@@ -632,7 +631,7 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
boolean swap = (operands.get(0).c2 == null &&
operands.get(1).c2 != null) || operands.get(0).c1 == null;
EOpNode n1 = operands.get(!swap ? 0 : 1);
EOpNode n2 = operands.get(!swap ? 1 : 0);
- Triple<Integer, EBinaryOperand, Pair<Character,
Character>> t = EOpNodeBinary.TryCombineAndCost(n1, n2, charToSizeMap,
charToOccurences, outChar1, outChar2);
+ Triple<Integer, EBinaryOperand, Pair<Character,
Character>> t = EOpNodeBinary.tryCombineAndCost(n1, n2, charToSizeMap,
charToOccurences, outChar1, outChar2);
if (t != null) {
EOpNodeBinary newNode = new EOpNodeBinary(n1,
n2, t.getMiddle());
int thisCost = cost + t.getLeft();
@@ -650,7 +649,7 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
EOpNode n1 = operands.get(!swap ? i : j);
EOpNode n2 = operands.get(!swap ? j : i);
- Triple<Integer, EBinaryOperand, Pair<Character,
Character>> t = EOpNodeBinary.TryCombineAndCost(n1, n2, charToSizeMap,
charToOccurences, outChar1, outChar2);
+ Triple<Integer, EBinaryOperand, Pair<Character,
Character>> t = EOpNodeBinary.tryCombineAndCost(n1, n2, charToSizeMap,
charToOccurences, outChar1, outChar2);
if (t != null){
EOpNodeBinary newNode = new
EOpNodeBinary(n1, n2, t.getMiddle());
int thisCost = cost + t.getLeft();
@@ -687,10 +686,10 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
return Pair.of(minCost, minNodes);
}
- private ArrayList<MatrixBlock> executePlan(List<EOpNode> plan,
ArrayList<MatrixBlock> inputs) {
+ private ArrayList<MatrixBlock> executePlan(List<EOpNode> plan,
List<MatrixBlock> inputs) {
ArrayList<MatrixBlock> res = new ArrayList<>(plan.size());
for(EOpNode p : plan){
- res.add(p.computeEOpNode(inputs, _numThreads, LOG));
+ res.add(p.computeEOpNode(inputs, _numThreads, LOG));
}
return res;
}
@@ -708,7 +707,7 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
mb.getDenseBlock().resetNoFill(mb.getNumRows(),1);
}
}
- public static void ensureMatrixBlockRowVector(MatrixBlock mb){
+ public static void ensureMatrixBlockRowVector(MatrixBlock mb){
if(mb.getNumRows() > 1){
mb.setNumColumns(mb.getNumRows());
mb.setNumRows(1);
@@ -722,10 +721,12 @@ public class EinsumCPInstruction extends
BuiltinNaryCPInstruction {
}
}
- private MatrixBlock computeCellSummation(ArrayList<MatrixBlock> inputs,
List<String> inputsChars, String resultString,
-
HashMap<Character, Integer>
charToDimensionSizeInt, List<Character> summingChars, int outRows, int outCols){
+ private MatrixBlock computeCellSummation(ArrayList<MatrixBlock> inputs,
List<String> inputsChars,
+ String resultString, Map<Character, Integer>
charToDimensionSizeInt,
+ List<Character> summingChars, int outRows, int outCols)
+ {
ArrayList<CNode> dummyIn = new ArrayList<>();
- dummyIn.add(new CNodeData(new LiteralOp(0), 0, 0, DataType.SCALAR));
+ dummyIn.add(new CNodeData(new LiteralOp(0), 0, 0,
DataType.SCALAR));
CNodeCell cnode = new CNodeCell(dummyIn, null);
StringBuilder sb = new StringBuilder();
diff --git
a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java
b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java
index 97e1eb83bf..f86da127d6 100644
--- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java
@@ -46,7 +46,7 @@ import java.util.Map;
@RunWith(Parameterized.class)
public class EinsumTest extends AutomatedTestBase
{
- final private static List<Config> TEST_CONFIGS = List.of(
+ final private static List<Config> TEST_CONFIGS = List.of(
new Config("ij,jk->ik", List.of(shape(5, 6), shape(6, 5))), //
mm
new Config("ji,jk->ik", List.of(shape(6, 5), shape(6, 10))),
new Config("ji,kj->ik", List.of(shape(6, 5), shape(10, 6))),
@@ -78,23 +78,23 @@ public class EinsumTest extends AutomatedTestBase
new Config("ij,i->i", List.of(shape(10, 5), shape(10))),
new Config("ij,i->j", List.of(shape(10, 5), shape(10))),
- new Config("i,i->", List.of(shape(5), shape(5))), // dot
- new Config("i,j->", List.of(shape(5), shape(80))), // sum
- new Config("i,j->ij", List.of(shape(5), shape(80))), //
outer vect mult
- new Config("i,j->ji", List.of(shape(5), shape(80))), //
outer vect mult
+ new Config("i,i->", List.of(shape(5), shape(5))), // dot
+ new Config("i,j->", List.of(shape(5), shape(80))), // sum
+ new Config("i,j->ij", List.of(shape(5), shape(80))), //
outer vect mult
+ new Config("i,j->ji", List.of(shape(5), shape(80))), //
outer vect mult
- new Config("ij->", List.of(shape(10, 5))), // sum
- new Config("i->", List.of(shape(10))), // sum
- new Config("ij->i", List.of(shape(10, 5))), // sum(1)
- new Config("ij->j", List.of(shape(10, 5))), // sum(0)
- new Config("ij->ji", List.of(shape(10, 5))), // T
- new Config("ij->ij", List.of(shape(10, 5))),
- new Config("i->i", List.of(shape(10))),
- new Config("ii->i", List.of(shape(10, 10))), // Diag
- new Config("ii->", List.of(shape(10, 10))), // Trace
- new Config("ii,i->i", List.of(shape(10, 10),shape(10))), //
Diag*vec
+ new Config("ij->", List.of(shape(10, 5))), // sum
+ new Config("i->", List.of(shape(10))), // sum
+ new Config("ij->i", List.of(shape(10, 5))), // sum(1)
+ new Config("ij->j", List.of(shape(10, 5))), // sum(0)
+ new Config("ij->ji", List.of(shape(10, 5))), // T
+ new Config("ij->ij", List.of(shape(10, 5))),
+ new Config("i->i", List.of(shape(10))),
+ new Config("ii->i", List.of(shape(10, 10))), // Diag
+ new Config("ii->", List.of(shape(10, 10))), // Trace
+ new Config("ii,i->i", List.of(shape(10, 10),shape(10))), //
Diag*vec
- new Config("ab,cd->ba", List.of(shape( 6, 10), shape(6,
5))), // sum cd to scalar and multiply ab
+ new Config("ab,cd->ba", List.of(shape( 6, 10), shape(6, 5))),
// sum cd to scalar and multiply ab
new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (the
cell tpl fallback)
List.of(shape(5, 6), shape(5, 3), shape(5, 10),
shape(6, 3), shape(10, 6), shape(10, 3))),
@@ -137,7 +137,7 @@ public class EinsumTest extends AutomatedTestBase
new Config("ab,ab,a,ag,gz->bz", List.of(shape(10, 5),
shape(10, 5),shape(10),shape(10,200),shape(200,7)))
,new Config("ab,ab,a,ag,gz->bz", List.of(shape(10, 5),
shape(10, 5),shape(10),shape(10,20),shape(20,7)))
,new Config("ab,ab,bc,bc->bc", List.of(shape(10, 5),
shape(10, 5),shape(5,20),shape(5,20)))
- );
+ );
private final int id;
private final String einsumStr;
private final File dmlFile;
@@ -312,36 +312,36 @@ public class EinsumTest extends AutomatedTestBase
}
private static class Config {
- public List<Double> factors;
+ public List<Double> factors;
String einsumStr;
List<int[]> shapes;
- Config(String einsum, List<int[]> shapes) {
- this(einsum,shapes,null);
- }
- Config(String einsum, Map<Character, Integer> charToSize){
- this(einsum, charToSize, null);
- }
-
- Config(String einsum, Map<Character, Integer> charToSize, List<Double>
factors) {
- this.einsumStr = einsum;
- String leftPart = einsum.split("->")[0];
- List<int[]> shapes = new ArrayList<>();
- for(String op :
Arrays.stream(leftPart.split(",")).map(x->x.trim()).toList()){
- if (op.length() == 1) {
- shapes.add(new int[]{charToSize.get(op.charAt(0))});
- }else{
- shapes.add(new
int[]{charToSize.get(op.charAt(0)),charToSize.get(op.charAt(1))});
- }
-
- }
- this.shapes = shapes;
- this.factors = factors;
- }
+ Config(String einsum, List<int[]> shapes) {
+ this(einsum,shapes,null);
+ }
+ Config(String einsum, Map<Character, Integer> charToSize){
+ this(einsum, charToSize, null);
+ }
+
+ Config(String einsum, Map<Character, Integer> charToSize,
List<Double> factors) {
+ this.einsumStr = einsum;
+ String leftPart = einsum.split("->")[0];
+ List<int[]> shapes = new ArrayList<>();
+ for(String op :
Arrays.stream(leftPart.split(",")).map(x->x.trim()).toList()){
+ if (op.length() == 1) {
+ shapes.add(new
int[]{charToSize.get(op.charAt(0))});
+ }else{
+ shapes.add(new
int[]{charToSize.get(op.charAt(0)),charToSize.get(op.charAt(1))});
+ }
+
+ }
+ this.shapes = shapes;
+ this.factors = factors;
+ }
Config(String einsum, List<int[]> shapes, List<Double> factors)
{
this.einsumStr = einsum;
this.shapes = shapes;
- this.factors = factors;
+ this.factors = factors;
}
}