This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 61bb660145 [SYSTEMDS-3938] Fix einsum codestyle and code quality
61bb660145 is described below

commit 61bb660145167f391c6525604a5103e6b0ddd011
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Dec 31 07:20:34 2025 +0100

    [SYSTEMDS-3938] Fix einsum codestyle and code quality
---
 dev/checkstyle/suppressions.xml                    |  13 --
 .../org/apache/sysds/runtime/einsum/EOpNode.java   |  28 +--
 .../apache/sysds/runtime/einsum/EOpNodeBinary.java | 238 ++++++++++-----------
 .../apache/sysds/runtime/einsum/EOpNodeData.java   |  23 +-
 .../apache/sysds/runtime/einsum/EOpNodeFuse.java   |  96 +++++----
 .../apache/sysds/runtime/einsum/EOpNodeUnary.java  |   9 +-
 .../apache/sysds/runtime/einsum/EinsumContext.java |  52 ++---
 .../runtime/einsum/EinsumEquationValidator.java    | 208 +++++++++---------
 .../sysds/runtime/einsum/EinsumSpoofRowwise.java   |  31 ++-
 .../instructions/cp/EinsumCPInstruction.java       | 151 ++++++-------
 .../sysds/test/functions/einsum/EinsumTest.java    |  82 +++----
 11 files changed, 450 insertions(+), 481 deletions(-)

diff --git a/dev/checkstyle/suppressions.xml b/dev/checkstyle/suppressions.xml
index 642bf3beba..9f17ccdd8b 100644
--- a/dev/checkstyle/suppressions.xml
+++ b/dev/checkstyle/suppressions.xml
@@ -40,7 +40,6 @@
     <suppress checks="AvoidStarImportCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]resource[\\/]enumeration[\\/]GridBasedEnumerator\.java$"/>
     <suppress checks="AvoidStarImportCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]resource[\\/]enumeration[\\/]InterestBasedEnumerator\.java$"/>
     <suppress checks="AvoidStarImportCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]resource[\\/]enumeration[\\/]PruneBasedEnumerator\.java$"/>
-    <suppress checks="AvoidStarImportCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]cp[\\/]EinsumCPInstruction\.java$"/>
     <suppress checks="AvoidStarImportCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]gpu[\\/]context[\\/]CSRPointer\.java$"/>
     <suppress checks="AvoidStarImportCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]io[\\/]ReaderCOGParallel\.java$"/>
     <suppress checks="AvoidStarImportCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]io[\\/]ReaderCOG\.java$"/>
@@ -103,8 +102,6 @@
     <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]hops[\\/]fedplanner[\\/]FederatedPlannerLogger\.java$"/>
     <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]protobuf[\\/]SysdsProtos\.java$"/>
     <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]compress[\\/]lib[\\/]CLALibBinaryCellOp\.java$"/>
-    <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EinsumContext\.java$"/>
-    <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EinsumEquationValidator\.java$"/>
     <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]gpu[\\/]DnnGPUInstruction\.java$"/>
     <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]gpu[\\/]context[\\/]CudaMemoryAllocator\.java$"/>
     <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]gpu[\\/]context[\\/]GPUContextPool\.java$"/>
@@ -130,14 +127,7 @@
     <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]test[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]test[\\/]functions[\\/]transform[\\/]TransformFrameEncodeBagOfWords\.java$"/>
     <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]test[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]test[\\/]gpu[\\/]cudaSupportFunctions[\\/]CudaCublasGeamTest\.java$"/>
     <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]test[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]test[\\/]gpu[\\/]cudaSupportFunctions[\\/]CudaCusparseCsrGemmTest\.java$"/>
-    <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EOpNodeBinary\.java$"/>
-    <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EOpNodeData\.java$"/>
-    <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EOpNodeFuse\.java$"/>
-    <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EOpNodeUnary\.java$"/>
-    <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EOpNode\.java$"/>
-    <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]cp[\\/]EinsumCPInstruction\.java$"/>
     <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]matrix[\\/]data[\\/]LibMatrixMult\.java$"/>
-    <suppress checks="RegexpMultilineCheck" 
files=".*src[\\/]test[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]test[\\/]functions[\\/]einsum[\\/]EinsumTest\.java$"/>
 
     <!-- LocalVariableNameCheck -->
     <suppress checks="LocalVariableNameCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]api[\\/]PythonDMLScript\.java$"/>
@@ -193,7 +183,6 @@
     <suppress checks="MethodNameCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]compress[\\/]lib[\\/]CLALibRightMultBy\.java$"/>
     <suppress checks="MethodNameCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]controlprogram[\\/]federated[\\/]FederatedSSLUtil\.java$"/>
     <suppress checks="MethodNameCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]controlprogram[\\/]paramserv[\\/]ParamServer\.java$"/>
-    <suppress checks="MethodNameCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]cp[\\/]EinsumCPInstruction\.java$"/>
     <suppress checks="MethodNameCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]instructions[\\/]spark[\\/]UnaryMatrixSPInstruction\.java$"/>
     <suppress checks="MethodNameCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]io[\\/]hdf5[\\/]H5\.java$"/>
     <suppress checks="MethodNameCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]iogen[\\/]ReaderMappingIndex\.java$"/>
@@ -358,6 +347,4 @@
     <suppress checks="MethodNameCheck" 
files=".*src[\\/]test[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]test[\\/]gpu[\\/]nn[\\/]DNNOperationsGPUTest\.java$"/>
     <suppress checks="MethodNameCheck" 
files=".*src[\\/]test[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]test[\\/]usertest[\\/]UserInterfaceTest\.java$"/>
     <suppress checks="MethodNameCheck" 
files=".*src[\\/]test[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]test[\\/]usertest[\\/]pythonapi[\\/]StartupTest\.java$"/>
-    <suppress checks="MethodNameCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EOpNodeBinary\.java$"/>
-    <suppress checks="MethodNameCheck" 
files=".*src[\\/]main[\\/]java[\\/]org[\\/]apache[\\/]sysds[\\/]runtime[\\/]einsum[\\/]EinsumSpoofRowwise\.java$"/>
 </suppressions>
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java
index e205c2721a..665e5ae558 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java
@@ -27,26 +27,26 @@ import java.util.Arrays;
 import java.util.List;
 
 public abstract class EOpNode {
-    public Character c1;
-    public Character c2;
+       public Character c1;
+       public Character c2;
        public Integer dim1;
        public Integer dim2;
-    public EOpNode(Character c1, Character c2, Integer dim1, Integer dim2) {
-        this.c1 = c1;
-        this.c2 = c2;
+       public EOpNode(Character c1, Character c2, Integer dim1, Integer dim2) {
+               this.c1 = c1;
+               this.c2 = c2;
                this.dim1 = dim1;
                this.dim2 = dim2;
-    }
+       }
 
-    public String getOutputString() {
-        if(c1 == null) return "''";
-        if(c2 == null) return c1.toString();
-        return c1.toString() + c2.toString();
-    }
+       public String getOutputString() {
+               if(c1 == null) return "''";
+               if(c2 == null) return c1.toString();
+               return c1.toString() + c2.toString();
+       }
        public abstract List<EOpNode> getChildren();
 
        public String[] recursivePrintString(){
-               ArrayList<String[]> inpStrings = new ArrayList<>();
+               List<String[]> inpStrings = new ArrayList<>();
                for (EOpNode node : getChildren()) {
                        inpStrings.add(node.recursivePrintString());
                }
@@ -61,8 +61,8 @@ public abstract class EOpNode {
                return res;
        };
 
-    public abstract MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, 
int numOfThreads, Log LOG);
+       public abstract MatrixBlock computeEOpNode(List<MatrixBlock> inputs, 
int numOfThreads, Log LOG);
 
-    public abstract EOpNode reorderChildrenAndOptimize(EOpNode parent, 
Character outChar1, Character outChar2);
+       public abstract EOpNode reorderChildrenAndOptimize(EOpNode parent, 
Character outChar1, Character outChar2);
 }
 
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java
index d291712169..071fb6706a 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java
@@ -22,16 +22,13 @@ package org.apache.sysds.runtime.einsum;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.commons.lang3.tuple.Triple;
 import org.apache.commons.logging.Log;
-import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
 import org.apache.sysds.runtime.functionobjects.Multiply;
-import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
@@ -47,14 +44,14 @@ import static 
org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensur
 public class EOpNodeBinary extends EOpNode {
 
        public enum EBinaryOperand { // upper case: char remains, lower case: 
summed (reduced) dimension
-        ////// mm:   //////
-        Ba_aC, // -> BC
-        aB_Ca, // -> CB
-        Ba_Ca, // -> BC
-        aB_aC, // -> BC
+               ////// mm:   //////
+               Ba_aC, // -> BC
+               aB_Ca, // -> CB
+               Ba_Ca, // -> BC
+               aB_aC, // -> BC
 
-        ////// element-wise multiplications and sums //////
-        aB_aB,// elemwise and colsum -> B
+               ////// element-wise multiplications and sums //////
+               aB_aB,// elemwise and colsum -> B
                Ab_Ab, // elemwise and rowsum ->A
                Ab_bA, // elemwise, either colsum or rowsum -> A
                aB_Ba,
@@ -63,23 +60,23 @@ public class EOpNodeBinary extends EOpNode {
                aB_a,// -> B
                Ab_b, // -> A
 
-        ////// elementwise, no summations:   //////
-        A_A,// v-elemwise -> A
-        AB_AB,// M-M elemwise -> AB
-        AB_BA, // M-M.T elemwise -> AB
-        AB_A, // M-v colwise -> BA!?
+               ////// elementwise, no summations:   //////
+               A_A,// v-elemwise -> A
+               AB_AB,// M-M elemwise -> AB
+               AB_BA, // M-M.T elemwise -> AB
+               AB_A, // M-v colwise -> BA!?
                AB_B, // M-v rowwise -> AB
 
-        ////// other   //////
+               ////// other   //////
                a_a,// dot ->
-        A_B, // outer mult -> AB
-        A_scalar, // v-scalar
-        AB_scalar, // m-scalar
-        scalar_scalar
-    }
-    public EOpNode left;
-    public EOpNode right;
-    public EBinaryOperand operand;
+               A_B, // outer mult -> AB
+               A_scalar, // v-scalar
+               AB_scalar, // m-scalar
+               scalar_scalar
+       }
+       public EOpNode left;
+       public EOpNode right;
+       public EBinaryOperand operand;
        private boolean transposeResult;
        public EOpNodeBinary(EOpNode left, EOpNode right, EBinaryOperand 
operand){
                super(null,null,null, null);
@@ -178,137 +175,124 @@ public class EOpNodeBinary extends EOpNode {
        }
 
        @Override
-    public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int 
numThreads, Log LOG) {
-        EOpNodeBinary bin = this;
-        MatrixBlock left = this.left.computeEOpNode(inputs, numThreads, LOG);
-        MatrixBlock right = this.right.computeEOpNode(inputs, numThreads, LOG);
+       public MatrixBlock computeEOpNode(List<MatrixBlock> inputs, int 
numThreads, Log LOG) {
+               EOpNodeBinary bin = this;
+               MatrixBlock left = this.left.computeEOpNode(inputs, numThreads, 
LOG);
+               MatrixBlock right = this.right.computeEOpNode(inputs, 
numThreads, LOG);
 
-        AggregateOperator agg = new AggregateOperator(0, 
Plus.getPlusFnObject());
+               //AggregateOperator agg = new AggregateOperator(0, 
Plus.getPlusFnObject());
 
-        MatrixBlock res;
+               MatrixBlock res;
 
-        switch (bin.operand){
-            case AB_AB -> {
-                res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
-            }
-            case A_A -> {
-                ensureMatrixBlockColumnVector(left);
-                ensureMatrixBlockColumnVector(right);
-                res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
-            }
-            case a_a -> {
-                ensureMatrixBlockColumnVector(left);
-                ensureMatrixBlockColumnVector(right);
+               switch (bin.operand){
+                       case AB_AB -> {
+                               res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
+                       }
+                       case A_A -> {
+                               ensureMatrixBlockColumnVector(left);
+                               ensureMatrixBlockColumnVector(right);
+                               res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
+                       }
+                       case a_a -> {
+                               ensureMatrixBlockColumnVector(left);
+                               ensureMatrixBlockColumnVector(right);
                                res = new MatrixBlock(0.0);
                                res.allocateDenseBlock();
                                res.getDenseBlockValues()[0] = 
LibMatrixMult.dotProduct(left.getDenseBlockValues(), 
right.getDenseBlockValues(), 0,0 , left.getNumRows());
-            }
-            case Ab_Ab -> {
+                       }
+                       case Ab_Ab -> {
                                res = 
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left, 
right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new 
ArrayList<>(),
                                        null, numThreads);
                        }
-            case aB_aB -> {
+                       case aB_aB -> {
                                res = 
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left, 
right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new 
ArrayList<>(),
                                        null, numThreads);
-            }
-            case ab_ab -> {
+                       }
+                       case ab_ab -> {
                                res = 
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left, 
right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new 
ArrayList<>(),
                                        null, numThreads);
                        }
-            case ab_ba -> {
+                       case ab_ba -> {
                                res = 
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left), 
List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
                                        null, numThreads);
                        }
-            case Ab_bA -> {
+                       case Ab_bA -> {
                                res = 
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), 
List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
                                        null, numThreads);
                        }
-            case aB_Ba -> {
+                       case aB_Ba -> {
                                res = 
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), 
List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
                                        null, numThreads);
                        }
-            case AB_BA -> {
-                ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
-                right = right.reorgOperations(transpose, new MatrixBlock(), 0, 
0, 0);
-                res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
-            }
-            case Ba_aC -> {
-                res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), 
numThreads);
-            }
-            case aB_Ca -> {
-                res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(), 
numThreads);
-            }
-            case Ba_Ca -> {
-                ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
-                right = right.reorgOperations(transpose, new MatrixBlock(), 0, 
0, 0);
-                res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), 
numThreads);
-            }
-            case aB_aC -> {
-                if(false && 
LibMatrixMult.isSkinnyRightHandSide(left.getNumRows(), left.getNumColumns(), 
right.getNumRows(), right.getNumColumns(), false)){
-                    res = new MatrixBlock(left.getNumColumns(), 
right.getNumColumns(),false);
-                    res.allocateDenseBlock();
-                    double[] m1 = left.getDenseBlock().values(0);
-                    double[] m2 = right.getDenseBlock().values(0);
-                    double[] c = res.getDenseBlock().values(0);
-                    int alen = left.getNumColumns();
-                    int blen = right.getNumColumns();
-                    for(int i =0;i<left.getNumRows();i++){
-                        
LibSpoofPrimitives.vectOuterMultAdd(m1,m2,c,i*alen,i*blen, 0,alen,blen);
-                    }
-                }else {
-                    ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
-                    left = left.reorgOperations(transpose, new MatrixBlock(), 
0, 0, 0);
-                    res = LibMatrixMult.matrixMult(left, right, new 
MatrixBlock(), numThreads);
-                }
-            }
-            case A_scalar, AB_scalar -> {
-                res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new 
ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock());
-            }
-            case AB_B -> {
-                ensureMatrixBlockRowVector(right);
-                res = left.binaryOperations(new 
BinaryOperator(Multiply.getMultiplyFnObject()), right);
-            }
-            case Ab_b -> {
+                       case AB_BA -> {
+                               ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+                               right = right.reorgOperations(transpose, new 
MatrixBlock(), 0, 0, 0);
+                               res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
+                       }
+                       case Ba_aC -> {
+                               res = LibMatrixMult.matrixMult(left,right, new 
MatrixBlock(), numThreads);
+                       }
+                       case aB_Ca -> {
+                               res = LibMatrixMult.matrixMult(right,left, new 
MatrixBlock(), numThreads);
+                       }
+                       case Ba_Ca -> {
+                               ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+                               right = right.reorgOperations(transpose, new 
MatrixBlock(), 0, 0, 0);
+                               res = LibMatrixMult.matrixMult(left,right, new 
MatrixBlock(), numThreads);
+                       }
+                       case aB_aC -> {
+                               ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+                               left = left.reorgOperations(transpose, new 
MatrixBlock(), 0, 0, 0);
+                               res = LibMatrixMult.matrixMult(left, right, new 
MatrixBlock(), numThreads);
+                       }
+                       case A_scalar, AB_scalar -> {
+                               res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new 
ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock());
+                       }
+                       case AB_B -> {
+                               ensureMatrixBlockRowVector(right);
+                               res = left.binaryOperations(new 
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+                       }
+                       case Ab_b -> {
                                res = 
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), 
new ArrayList<>(), List.of(right), new ArrayList<>(), new ArrayList<>(),
                                        null, numThreads);
                        }
-            case AB_A -> {
-                ensureMatrixBlockColumnVector(right);
-                res = left.binaryOperations(new 
BinaryOperator(Multiply.getMultiplyFnObject()), right);
-            }
-            case aB_a -> {
+                       case AB_A -> {
+                               ensureMatrixBlockColumnVector(right);
+                               res = left.binaryOperations(new 
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+                       }
+                       case aB_a -> {
                                res = 
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), 
new ArrayList<>(), new ArrayList<>(), List.of(right), new ArrayList<>(),
                                        null, numThreads);
                        }
-            case A_B -> {
-                ensureMatrixBlockColumnVector(left);
-                ensureMatrixBlockRowVector(right);
-                res = left.binaryOperations(new 
BinaryOperator(Multiply.getMultiplyFnObject()), right);
-            }
-            case scalar_scalar -> {
-                return new MatrixBlock(left.get(0,0)*right.get(0,0));
-            }
-            default -> {
-                throw new IllegalArgumentException("Unexpected value: " + 
bin.operand.toString());
-            }
+                       case A_B -> {
+                               ensureMatrixBlockColumnVector(left);
+                               ensureMatrixBlockRowVector(right);
+                               res = left.binaryOperations(new 
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+                       }
+                       case scalar_scalar -> {
+                               return new 
MatrixBlock(left.get(0,0)*right.get(0,0));
+                       }
+                       default -> {
+                               throw new IllegalArgumentException("Unexpected 
value: " + bin.operand.toString());
+                       }
 
-        }
+               }
                if(transposeResult){
                        ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
                        res = res.reorgOperations(transpose, new MatrixBlock(), 
0, 0, 0);
                }
                if(c2 == null) ensureMatrixBlockColumnVector(res);
-        return res;
-    }
+               return res;
+       }
 
-    @Override
-    public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character 
outChar1, Character outChar2) {
-        if (this.operand ==EBinaryOperand.aB_aC){
-            if(this.right.c2 == outChar1) { // result is CB so Swap aB and aC
+       @Override
+       public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character 
outChar1, Character outChar2) {
+               if (this.operand ==EBinaryOperand.aB_aC){
+                       if(this.right.c2 == outChar1) { // result is CB so Swap 
aB and aC
                                var tmpLeft = left;  left = right;  right = 
tmpLeft;
-                               var tmpC1 = c1;       c1 = c2;         c2 = 
tmpC1;
-                               var tmpDim1 = dim1;   dim1 = dim2;     dim2 = 
tmpDim1;
-            }
+                               var tmpC1 = c1;    c1 = c2;              c2 = 
tmpC1;
+                               var tmpDim1 = dim1;   dim1 = dim2;       dim2 = 
tmpDim1;
+                       }
                        if(EinsumCPInstruction.FUSE_OUTER_MULTIPLY && left 
instanceof EOpNodeFuse fuse && fuse.einsumRewriteType == 
EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB
                                && 
(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK || ((fuse.dim1 
* fuse.dim2 *(fuse.ABs.size()+fuse.BAs.size())) + (right.dim1*right.dim2)) * 8 
> 6 * 1024 * 1024)
                                && 
LibMatrixMult.isSkinnyRightHandSide(left.dim1, left.dim2,  right.dim1, 
right.dim2, false)) {
@@ -319,16 +303,16 @@ public class EOpNodeBinary extends EOpNode {
                                return fuse;
                        }
 
-            left = left.reorderChildrenAndOptimize(this, left.c2, left.c1); // 
maybe can be reordered
-            if(left.c2 == right.c1) { // check if change happened:
-                this.operand = EBinaryOperand.Ba_aC;
-            }
+                       left = left.reorderChildrenAndOptimize(this, left.c2, 
left.c1); // maybe can be reordered
+                       if(left.c2 == right.c1) { // check if change happened:
+                               this.operand = EBinaryOperand.Ba_aC;
+                       }
                        right =  right.reorderChildrenAndOptimize(this, 
right.c1, right.c2);
-        }else if (this.operand ==EBinaryOperand.Ba_Ca){
+               }else if (this.operand ==EBinaryOperand.Ba_Ca){
                        if(this.right.c1 == outChar1) { // result is CB so Swap 
Ba and Ca
                                var tmpLeft = left;  left = right;  right = 
tmpLeft;
-                               var tmpC1 = c1;       c1 = c2;         c2 = 
tmpC1;
-                               var tmpDim1 = dim1;   dim1 = dim2;     dim2 = 
tmpDim1;
+                               var tmpC1 = c1;    c1 = c2;              c2 = 
tmpC1;
+                               var tmpDim1 = dim1;   dim1 = dim2;       dim2 = 
tmpDim1;
                        }
 
                        right = right.reorderChildrenAndOptimize(this, 
right.c2, right.c1); // maybe can be reordered
@@ -341,10 +325,10 @@ public class EOpNodeBinary extends EOpNode {
                        right = right.reorderChildrenAndOptimize(this, 
right.c1, right.c2);
                }
                return this;
-    }
+       }
 
        // used in the old approach
-       public static Triple<Integer, EBinaryOperand, Pair<Character, 
Character>> TryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap<Character, 
Integer> charToSizeMap, HashMap<Character, Integer> charToOccurences, Character 
outChar1, Character outChar2){
+       public static Triple<Integer, EBinaryOperand, Pair<Character, 
Character>> tryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap<Character, 
Integer> charToSizeMap, HashMap<Character, Integer> charToOccurences, Character 
outChar1, Character outChar2){
                Predicate<Character> cannotBeSummed = (c) ->
                        c == outChar1 || c == outChar2 || 
charToOccurences.get(c) > 2;
 
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java
index fd710d19d1..4906323f8a 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java
@@ -22,15 +22,14 @@ package org.apache.sysds.runtime.einsum;
 import org.apache.commons.logging.Log;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
-import java.util.ArrayList;
 import java.util.List;
 
 public class EOpNodeData extends EOpNode {
-    public int matrixIdx;
-    public EOpNodeData(Character c1, Character c2, Integer dim1, Integer dim2, 
int matrixIdx){
-        super(c1,c2,dim1,dim2);
-        this.matrixIdx = matrixIdx;
-    }
+       public int matrixIdx;
+       public EOpNodeData(Character c1, Character c2, Integer dim1, Integer 
dim2, int matrixIdx){
+               super(c1,c2,dim1,dim2);
+               this.matrixIdx = matrixIdx;
+       }
 
        @Override
        public List<EOpNode> getChildren() {
@@ -40,13 +39,13 @@ public class EOpNodeData extends EOpNode {
        public String toString() {
                return this.getClass().getSimpleName()+" ("+matrixIdx+") 
"+getOutputString();
        }
-    @Override
-    public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int 
numOfThreads, Log LOG) {
-        return inputs.get(matrixIdx);
-    }
+       @Override
+       public MatrixBlock computeEOpNode(List<MatrixBlock> inputs, int 
numOfThreads, Log LOG) {
+               return inputs.get(matrixIdx);
+       }
 
-    @Override
+       @Override
        public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character 
outChar1, Character outChar2) {
                return this;
-    }
+       }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java 
b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java
index a999ce68f9..5accf93bcc 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java
@@ -38,6 +38,8 @@ import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
 import java.util.function.Function;
 
 import static 
org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector;
@@ -48,20 +50,20 @@ public class EOpNodeFuse extends EOpNode {
        private EOpNode scalar = null;
 
        public enum EinsumRewriteType{
-        // B -> row*vec, A -> row*scalar
-        AB_BA_B_A__AB,
+               // B -> row*vec, A -> row*scalar
+               AB_BA_B_A__AB,
                AB_BA_A__B,
-        AB_BA_B_A__A,
-        AB_BA_B_A__,
+               AB_BA_B_A__A,
+               AB_BA_B_A__,
 
-        // scalar from row(AB).dot(B) multiplied by row(AZ)
-        AB_BA_B_A_AZ__Z,
+               // scalar from row(AB).dot(B) multiplied by row(AZ)
+               AB_BA_B_A_AZ__Z,
 
-        // AZ: last step is outer matrix multiplication using vector Z
-        AB_BA_A_AZ__BZ, AB_BA_A_AZ__ZB,
-    }
+               // AZ: last step is outer matrix multiplication using vector Z
+               AB_BA_A_AZ__BZ, AB_BA_A_AZ__ZB,
+       }
 
-    public EinsumRewriteType einsumRewriteType;
+       public EinsumRewriteType einsumRewriteType;
        public List<EOpNode> ABs;
        public List<EOpNode> BAs;
        public List<EOpNode> Bs;
@@ -78,15 +80,15 @@ public class EOpNodeFuse extends EOpNode {
                if (scalar != null) all.add(scalar);
                return all;
        };
-    private EOpNodeFuse(Character c1, Character c2, Integer dim1, Integer 
dim2, EinsumRewriteType einsumRewriteType, List<EOpNode> ABs, List<EOpNode> 
BAs, List<EOpNode> Bs, List<EOpNode> As, List<EOpNode> AZs) {
-        super(c1,c2, dim1, dim2);
-        this.einsumRewriteType = einsumRewriteType;
+       private EOpNodeFuse(Character c1, Character c2, Integer dim1, Integer 
dim2, EinsumRewriteType einsumRewriteType, List<EOpNode> ABs, List<EOpNode> 
BAs, List<EOpNode> Bs, List<EOpNode> As, List<EOpNode> AZs) {
+               super(c1,c2, dim1, dim2);
+               this.einsumRewriteType = einsumRewriteType;
                this.ABs = ABs;
                this.BAs = BAs;
                this.Bs = Bs;
                this.As = As;
                this.AZs = AZs;
-    }
+       }
        public EOpNodeFuse(EinsumRewriteType einsumRewriteType, List<EOpNode> 
ABs, List<EOpNode> BAs, List<EOpNode> Bs, List<EOpNode> As, List<EOpNode> AZs, 
List<Pair<List<EOpNode>, List<EOpNode>>> AXsAndXs) {
                super(null,null,null, null);
                switch(einsumRewriteType) {
@@ -137,12 +139,13 @@ public class EOpNodeFuse extends EOpNode {
                        throw new 
RuntimeException("EOpNodeFuse.addScalarAsIntermediate: scalar is undefined for 
type "+einsumRewriteType.toString());
        }
 
-    public static List<EOpNodeFuse> findFuseOps(ArrayList<EOpNode> operands, 
Character outChar1, Character outChar2,
-               HashMap<Character, Integer> charToSize, HashMap<Character, 
Integer> charToOccurences, ArrayList<EOpNode> ret) {
-               ArrayList<EOpNodeFuse> result = new ArrayList<>();
-               HashSet<String> matricesChars = new HashSet<>();
-               HashMap<Character, HashSet<String>> 
matricesCharsStartingWithChar = new HashMap<>();
-               HashMap<String, ArrayList<EOpNode>> charsToMatrices = new 
HashMap<>();
+       public static List<EOpNodeFuse> findFuseOps(List<EOpNode> operands, 
Character outChar1, Character outChar2,
+               Map<Character, Integer> charToSize, Map<Character, Integer> 
charToOccurences, List<EOpNode> ret)
+       {
+               List<EOpNodeFuse> result = new ArrayList<>();
+               Set<String> matricesChars = new HashSet<>();
+               Map<Character, HashSet<String>> matricesCharsStartingWithChar = 
new HashMap<>();
+               Map<String, ArrayList<EOpNode>> charsToMatrices = new 
HashMap<>();
 
                for(EOpNode operand1 : operands) {
                        String k;
@@ -336,18 +339,19 @@ public class EOpNodeFuse extends EOpNode {
                        result.add(e);
                }
 
-        for(EOpNode n : operands) {
-            if(!usedOperands.contains(n)){
-                ret.add(n);
-            } else {
+               for(EOpNode n : operands) {
+                       if(!usedOperands.contains(n)){
+                               ret.add(n);
+                       } else {
                                charToOccurences.put(n.c1, 
charToOccurences.get(n.c1) - 1);
                                if(charToOccurences.get(n.c2)!= null)
                                        charToOccurences.put(n.c2, 
charToOccurences.get(n.c2)-1);
                        }
-        }
+               }
 
-        return result;
-    }
+               return result;
+       }
+       @SuppressWarnings("unused")
        public static MatrixBlock compute(EinsumRewriteType rewriteType, 
List<MatrixBlock> ABsInput, List<MatrixBlock> mbBAs, List<MatrixBlock> mbBs, 
List<MatrixBlock> mbAs, List<MatrixBlock> mbAZs,
                Double scalar, int numThreads){
                boolean isResultAB =rewriteType  == 
EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB;
@@ -358,7 +362,7 @@ public class EOpNodeFuse extends EOpNode {
                boolean isResultBZ =rewriteType  == 
EOpNodeFuse.EinsumRewriteType.AB_BA_A_AZ__BZ;
                boolean isResultZB =rewriteType  == 
EOpNodeFuse.EinsumRewriteType.AB_BA_A_AZ__ZB;
 
-               ArrayList<MatrixBlock> mbABs = new ArrayList<>(ABsInput);
+               List<MatrixBlock> mbABs = new ArrayList<>(ABsInput);
                int bSize = mbABs.get(0).getNumColumns();
                int aSize = mbABs.get(0).getNumRows();
                if (!mbBAs.isEmpty()) {
@@ -424,19 +428,19 @@ public class EOpNodeFuse extends EOpNode {
 
                return out;
        }
-    @Override
-    public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int 
numThreads, Log LOG) {
+       @Override
+       public MatrixBlock computeEOpNode(List<MatrixBlock> inputs, int 
numThreads, Log LOG) {
                final Function<EOpNode, MatrixBlock> eOpNodeToMatrixBlock =  n 
-> n.computeEOpNode(inputs, numThreads, LOG);
-        ArrayList<MatrixBlock> mbABs = new 
ArrayList<>(ABs.stream().map(eOpNodeToMatrixBlock).toList());
+               List<MatrixBlock> mbABs = new 
ArrayList<>(ABs.stream().map(eOpNodeToMatrixBlock).toList());
                List<MatrixBlock> mbBAs = 
BAs.stream().map(eOpNodeToMatrixBlock).toList();
                List<MatrixBlock> mbBs =  
Bs.stream().map(eOpNodeToMatrixBlock).toList();
                List<MatrixBlock> mbAs = 
As.stream().map(eOpNodeToMatrixBlock).toList();
                List<MatrixBlock> mbAZs = 
AZs.stream().map(eOpNodeToMatrixBlock).toList();
                Double scalar = this.scalar == null ? null : 
this.scalar.computeEOpNode(inputs, numThreads, LOG).get(0,0);
                return EOpNodeFuse.compute(this.einsumRewriteType, mbABs, 
mbBAs, mbBs, mbAs, mbAZs , scalar, numThreads);
-    }
+       }
 
-    @Override
+       @Override
        public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character 
outChar1, Character outChar2) {
                ABs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, 
n.c2));
                BAs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, 
n.c2));
@@ -444,18 +448,18 @@ public class EOpNodeFuse extends EOpNode {
                Bs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, 
n.c2));
                AZs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, 
n.c2));
                return this;
-    }
-
-    private static @NotNull List<MatrixBlock> 
multiplyVectorsIntoOne(List<MatrixBlock> mbs, int size) {
-        MatrixBlock mb = new MatrixBlock(mbs.get(0).getNumRows(), 
mbs.get(0).getNumColumns(), false);
-        mb.allocateDenseBlock();
-        for(int i = 1; i< mbs.size(); i++) { // multiply Bs
-            if(i==1)
-                
LibMatrixMult.vectMultiplyWrite(mbs.get(0).getDenseBlock().values(0), 
mbs.get(1).getDenseBlock().values(0), mb.getDenseBlock().values(0),0,0,0, size);
-            else
-                
LibMatrixMult.vectMultiply(mbs.get(i).getDenseBlock().values(0),mb.getDenseBlock().values(0),0,0,
 size);
-        }
-        return List.of(mb);
-    }
+       }
+
+       private static @NotNull List<MatrixBlock> 
multiplyVectorsIntoOne(List<MatrixBlock> mbs, int size) {
+               MatrixBlock mb = new MatrixBlock(mbs.get(0).getNumRows(), 
mbs.get(0).getNumColumns(), false);
+               mb.allocateDenseBlock();
+               for(int i = 1; i< mbs.size(); i++) { // multiply Bs
+                       if(i==1)
+                               
LibMatrixMult.vectMultiplyWrite(mbs.get(0).getDenseBlock().values(0), 
mbs.get(1).getDenseBlock().values(0), mb.getDenseBlock().values(0),0,0,0, size);
+                       else
+                               
LibMatrixMult.vectMultiply(mbs.get(i).getDenseBlock().values(0),mb.getDenseBlock().values(0),0,0,
 size);
+               }
+               return List.of(mb);
+       }
 }
 
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java 
b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java
index e46e7ec104..918f1dd3b2 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java
@@ -33,7 +33,6 @@ import 
org.apache.sysds.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 
-import java.util.ArrayList;
 import java.util.List;
 
 public class EOpNodeUnary extends EOpNode {
@@ -47,7 +46,7 @@ public class EOpNodeUnary extends EOpNode {
                super(c1, c2, dim1, dim2);
                this.child = child;
                this.eUnaryOperand = eUnaryOperand;
-    }
+       }
 
        @Override
        public List<EOpNode> getChildren() {
@@ -58,8 +57,8 @@ public class EOpNodeUnary extends EOpNode {
                return this.getClass().getSimpleName()+" 
("+eUnaryOperand.toString()+") "+this.getOutputString();
        }
 
-    @Override
-    public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int 
numOfThreads, Log LOG) {
+       @Override
+       public MatrixBlock computeEOpNode(List<MatrixBlock> inputs, int 
numOfThreads, Log LOG) {
                MatrixBlock mb = child.computeEOpNode(inputs, numOfThreads, 
LOG);
                return switch(eUnaryOperand) {
                        case DIAG->{
@@ -97,7 +96,7 @@ public class EOpNodeUnary extends EOpNode {
                };
        }
 
-    @Override
+       @Override
        public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character 
outChar1, Character outChar2) {
                return this;
        }
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java 
b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
index 55692d0109..16c67e5399 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
@@ -24,31 +24,33 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
 
 public class EinsumContext {
-    public Integer outRows;
-    public Integer outCols;
-    public Character outChar1;
-    public Character outChar2;
-    public HashMap<Character, Integer> charToDimensionSize;
-    public String equationString;
-    public ArrayList<String> newEquationStringInputsSplit;
-       public HashMap<Character, Integer> characterAppearanceCount;
-
-    private EinsumContext(){};
-    public static EinsumContext getEinsumContext(String eqStr, 
ArrayList<MatrixBlock> inputs){
-        EinsumContext res = new EinsumContext();
-
-        res.equationString = eqStr;
-               HashMap<Character, Integer> charToDimensionSize = new 
HashMap<>();
-        HashMap<Character, Integer> characterAppearanceCount = new HashMap<>();
-               ArrayList<String> newEquationStringSplit = new ArrayList<>();
+       public Integer outRows;
+       public Integer outCols;
+       public Character outChar1;
+       public Character outChar2;
+       public Map<Character, Integer> charToDimensionSize;
+       public String equationString;
+       public List<String> newEquationStringInputsSplit;
+       public Map<Character, Integer> characterAppearanceCount;
+
+       private EinsumContext(){};
+       public static EinsumContext getEinsumContext(String eqStr, 
List<MatrixBlock> inputs){
+               EinsumContext res = new EinsumContext();
+
+               res.equationString = eqStr;
+               Map<Character, Integer> charToDimensionSize = new HashMap<>();
+               Map<Character, Integer> characterAppearanceCount = new 
HashMap<>();
+               List<String> newEquationStringSplit = new ArrayList<>();
                Character outChar1 = null;
                Character outChar2 = null;
 
-        Iterator<MatrixBlock> it = inputs.iterator();
-        MatrixBlock curArr = it.next();
-        int i = 0;
+               Iterator<MatrixBlock> it = inputs.iterator();
+               MatrixBlock curArr = it.next();
+               int i = 0;
 
                char c = eqStr.charAt(i);
                for(i = 0; i < eqStr.length(); i++) {
@@ -110,11 +112,11 @@ public class EinsumContext {
                res.outRows=(outChar1 == null ? 1 : 
charToDimensionSize.get(outChar1));
                res.outCols=(outChar2 == null ? 1 : 
charToDimensionSize.get(outChar2));
 
-        res.outChar1 = outChar1;
-        res.outChar2 = outChar2;
-        res.newEquationStringInputsSplit = newEquationStringSplit;
+               res.outChar1 = outChar1;
+               res.outChar2 = outChar2;
+               res.newEquationStringInputsSplit = newEquationStringSplit;
                res.characterAppearanceCount = characterAppearanceCount;
                res.charToDimensionSize = charToDimensionSize;
-        return res;
-    }
+               return res;
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java 
b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
index 7fdce50d3b..417a1b760b 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
@@ -32,117 +32,117 @@ import java.util.List;
 
 public class EinsumEquationValidator {
 
-    public static <HopOrIdentifier extends ParseInfo> Triple<Long, Long, 
Types.DataType> validateEinsumEquationAndReturnDimensions(String 
equationString, List<HopOrIdentifier> expressionsOrIdentifiers) throws 
LanguageException {
-        String[] eqStringParts = equationString.split("->"); // length 2 if 
"...->..." , length 1 if "...->"
-        boolean isResultScalar = eqStringParts.length == 1;
+       public static <HopOrIdentifier extends ParseInfo> Triple<Long, Long, 
Types.DataType> validateEinsumEquationAndReturnDimensions(String 
equationString, List<HopOrIdentifier> expressionsOrIdentifiers) throws 
LanguageException {
+               String[] eqStringParts = equationString.split("->"); // length 
2 if "...->..." , length 1 if "...->"
+               boolean isResultScalar = eqStringParts.length == 1;
 
-        if(expressionsOrIdentifiers == null)
-            throw new RuntimeException("Einsum: called 
validateEinsumAndReturnDimensions with null list");
+               if(expressionsOrIdentifiers == null)
+                       throw new RuntimeException("Einsum: called 
validateEinsumAndReturnDimensions with null list");
 
-        HashMap<Character, Long> charToDimensionSize = new HashMap<>();
-        Iterator<HopOrIdentifier> it = expressionsOrIdentifiers.iterator();
-        HopOrIdentifier currArr = it.next();
-        int arrSizeIterator = 0;
-        int numberOfMatrices = 1;
-        for (int i = 0; i < eqStringParts[0].length(); i++) {
-            char c = equationString.charAt(i);
-            if(c==' ') continue;
-            if(c==','){
-                if(!it.hasNext())
-                    throw new LanguageException("Einsum: Provided less 
operands than specified in equation str");
-                currArr = it.next();
-                arrSizeIterator = 0;
-                numberOfMatrices++;
-            } else{
-                long thisCharDimension = getThisCharDimension(currArr, 
arrSizeIterator);
-                if (charToDimensionSize.containsKey(c)){
-                    if (charToDimensionSize.get(c) != thisCharDimension)
-                        throw new LanguageException("Einsum: Character '" + c 
+ "' expected to be dim " + charToDimensionSize.get(c) + ", but found " + 
thisCharDimension);
-                }else{
-                    charToDimensionSize.put(c, thisCharDimension);
-                }
-                arrSizeIterator++;
-            }
-        }
-        if (expressionsOrIdentifiers.size() - 1 > numberOfMatrices)
-            throw new LanguageException("Einsum: Provided more operands than 
specified in equation str");
+               HashMap<Character, Long> charToDimensionSize = new HashMap<>();
+               Iterator<HopOrIdentifier> it = 
expressionsOrIdentifiers.iterator();
+               HopOrIdentifier currArr = it.next();
+               int arrSizeIterator = 0;
+               int numberOfMatrices = 1;
+               for (int i = 0; i < eqStringParts[0].length(); i++) {
+                       char c = equationString.charAt(i);
+                       if(c==' ') continue;
+                       if(c==','){
+                               if(!it.hasNext())
+                                       throw new LanguageException("Einsum: 
Provided less operands than specified in equation str");
+                               currArr = it.next();
+                               arrSizeIterator = 0;
+                               numberOfMatrices++;
+                       } else{
+                               long thisCharDimension = 
getThisCharDimension(currArr, arrSizeIterator);
+                               if (charToDimensionSize.containsKey(c)){
+                                       if (charToDimensionSize.get(c) != 
thisCharDimension)
+                                               throw new 
LanguageException("Einsum: Character '" + c + "' expected to be dim " + 
charToDimensionSize.get(c) + ", but found " + thisCharDimension);
+                               }else{
+                                       charToDimensionSize.put(c, 
thisCharDimension);
+                               }
+                               arrSizeIterator++;
+                       }
+               }
+               if (expressionsOrIdentifiers.size() - 1 > numberOfMatrices)
+                       throw new LanguageException("Einsum: Provided more 
operands than specified in equation str");
 
-        if (isResultScalar)
-            return Triple.of(-1l,-1l, Types.DataType.SCALAR);
+               if (isResultScalar)
+                       return Triple.of(-1l,-1l, Types.DataType.SCALAR);
 
-        int numberOfOutDimensions = 0;
-        Character dim1Char = null;
-        long dim1 = 1;
-        long dim2 = 1;
-        for (int i = 0; i < eqStringParts[1].length(); i++) {
-            char c = eqStringParts[1].charAt(i);
-            if (c == ' ') continue;
-            if (numberOfOutDimensions == 0) {
-                dim1Char = c;
-                if(!charToDimensionSize.containsKey(c))
-                    throw new LanguageException("Einsum: Output dimension 
'"+c+"' not present in input operands");
-                dim1 = charToDimensionSize.get(c);
-            } else {
-                if(c==dim1Char) throw new LanguageException("Einsum: output 
character "+c+" provided multiple times");
-                if(!charToDimensionSize.containsKey(c))
-                    throw new LanguageException("Einsum: Output dimension 
'"+c+"' not present in input operands");
-                dim2 = charToDimensionSize.get(c);
-            }
-            numberOfOutDimensions++;
-        }
-        if (numberOfOutDimensions > 2) {
-            throw new LanguageException("Einsum: output matrices with with no. 
dims > 2 not supported");
-        } else {
-            return Triple.of(dim1, dim2, Types.DataType.MATRIX);
-        }
-    }
+               int numberOfOutDimensions = 0;
+               Character dim1Char = null;
+               long dim1 = 1;
+               long dim2 = 1;
+               for (int i = 0; i < eqStringParts[1].length(); i++) {
+                       char c = eqStringParts[1].charAt(i);
+                       if (c == ' ') continue;
+                       if (numberOfOutDimensions == 0) {
+                               dim1Char = c;
+                               if(!charToDimensionSize.containsKey(c))
+                                       throw new LanguageException("Einsum: 
Output dimension '"+c+"' not present in input operands");
+                               dim1 = charToDimensionSize.get(c);
+                       } else {
+                               if(c==dim1Char) throw new 
LanguageException("Einsum: output character "+c+" provided multiple times");
+                               if(!charToDimensionSize.containsKey(c))
+                                       throw new LanguageException("Einsum: 
Output dimension '"+c+"' not present in input operands");
+                               dim2 = charToDimensionSize.get(c);
+                       }
+                       numberOfOutDimensions++;
+               }
+               if (numberOfOutDimensions > 2) {
+                       throw new LanguageException("Einsum: output matrices 
with with no. dims > 2 not supported");
+               } else {
+                       return Triple.of(dim1, dim2, Types.DataType.MATRIX);
+               }
+       }
 
-    public static Types.DataType validateEinsumEquationNoDimensions(String 
equationString, int numberOfMatrixInputs) throws LanguageException {
-        String[] eqStringParts = equationString.split("->"); // length 2 if 
"...->..." , length 1 if "...->"
-        boolean isResultScalar = eqStringParts.length == 1;
+       public static Types.DataType validateEinsumEquationNoDimensions(String 
equationString, int numberOfMatrixInputs) throws LanguageException {
+               String[] eqStringParts = equationString.split("->"); // length 
2 if "...->..." , length 1 if "...->"
+               boolean isResultScalar = eqStringParts.length == 1;
 
-        int numberOfMatrices = 1;
-        for (int i = 0; i < eqStringParts[0].length(); i++) {
-            char c = eqStringParts[0].charAt(i);
-            if(c == ' ') continue;
-            if(c == ',')
-                numberOfMatrices++;
-        }
-        if(numberOfMatrixInputs != numberOfMatrices){
-            throw  new LanguageException("Einsum: Invalid number of 
parameters, given: " + numberOfMatrixInputs + ", expected: " + 
numberOfMatrices);
-        }
+               int numberOfMatrices = 1;
+               for (int i = 0; i < eqStringParts[0].length(); i++) {
+                       char c = eqStringParts[0].charAt(i);
+                       if(c == ' ') continue;
+                       if(c == ',')
+                               numberOfMatrices++;
+               }
+               if(numberOfMatrixInputs != numberOfMatrices){
+                       throw  new LanguageException("Einsum: Invalid number of 
parameters, given: " + numberOfMatrixInputs + ", expected: " + 
numberOfMatrices);
+               }
 
-        if(isResultScalar){
-            return Types.DataType.SCALAR;
-        }else {
-            int numberOfDimensions = 0;
-            Character dim1Char = null;
-            for (int i = 0; i < eqStringParts[1].length(); i++) {
-                char c = eqStringParts[i].charAt(i);
-                if(c == ' ') continue;
-                numberOfDimensions++;
-                if (numberOfDimensions == 1 && c == dim1Char)
-                    throw new LanguageException("Einsum: output character 
"+c+" provided multiple times");
-                dim1Char = c;
-            }
+               if(isResultScalar){
+                       return Types.DataType.SCALAR;
+               }else {
+                       int numberOfDimensions = 0;
+                       Character dim1Char = null;
+                       for (int i = 0; i < eqStringParts[1].length(); i++) {
+                               char c = eqStringParts[i].charAt(i);
+                               if(c == ' ') continue;
+                               numberOfDimensions++;
+                               if (numberOfDimensions == 1 && c == dim1Char)
+                                       throw new LanguageException("Einsum: 
output character "+c+" provided multiple times");
+                               dim1Char = c;
+                       }
 
-            if (numberOfDimensions > 2) {
-                throw new LanguageException("Einsum: output matrices with with 
no. dims > 2 not supported");
-            } else {
-                return Types.DataType.MATRIX;
-            }
-        }
-    }
+                       if (numberOfDimensions > 2) {
+                               throw new LanguageException("Einsum: output 
matrices with with no. dims > 2 not supported");
+                       } else {
+                               return Types.DataType.MATRIX;
+                       }
+               }
+       }
 
-    private static <HopOrIdentifier extends ParseInfo> long 
getThisCharDimension(HopOrIdentifier currArr, int arrSizeIterator) {
-        long thisCharDimension;
-        if(currArr instanceof Hop){
-            thisCharDimension = arrSizeIterator == 0 ? ((Hop) 
currArr).getDim1()  : ((Hop) currArr).getDim2();
-        } else if(currArr instanceof Identifier){
-            thisCharDimension = arrSizeIterator == 0 ? ((Identifier) 
currArr).getDim1()  : ((Identifier) currArr).getDim2();
-        } else {
-            throw new RuntimeException("validateEinsumAndReturnDimensions 
called with expressions that are not Hop or Identifier: "+ currArr == null ? 
"null" : currArr.getClass().toString());
-        }
-        return thisCharDimension;
-    }
+       private static <HopOrIdentifier extends ParseInfo> long 
getThisCharDimension(HopOrIdentifier currArr, int arrSizeIterator) {
+               long thisCharDimension;
+               if(currArr instanceof Hop){
+                       thisCharDimension = arrSizeIterator == 0 ? ((Hop) 
currArr).getDim1()  : ((Hop) currArr).getDim2();
+               } else if(currArr instanceof Identifier){
+                       thisCharDimension = arrSizeIterator == 0 ? 
((Identifier) currArr).getDim1()  : ((Identifier) currArr).getDim2();
+               } else {
+                       throw new 
RuntimeException("validateEinsumAndReturnDimensions called with expressions 
that are not Hop or Identifier: "+ currArr == null ? "null" : 
currArr.getClass().toString());
+               }
+               return thisCharDimension;
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java 
b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java
index 8b3c5544e6..40b73ea399 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java
@@ -27,6 +27,8 @@ import org.apache.sysds.runtime.codegen.SpoofRowwise;
 import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
 
 public final class EinsumSpoofRowwise extends SpoofRowwise {
+       private static final long serialVersionUID = -5957679254041639561L;
+       
        private final int _ABCount;
        private final boolean _Bsupplied;
        private final int _ACount;
@@ -51,24 +53,23 @@ public final class EinsumSpoofRowwise extends SpoofRowwise {
                int rix) {
                switch(_EinsumRewriteType) {
                        case AB_BA_B_A__AB -> {
-                               genexec_AB(a, ai, b, null, c, ci, len, grix, 
rix);
+                               genexecAB(a, ai, b, null, c, ci, len, grix, 
rix);
                                if(scalars.length != 0) { 
LibMatrixMult.vectMultiplyWrite(scalars[0], c, c, ci, ci, len); }
                        }
                        case AB_BA_A__B -> {
-                               genexec_B(a, ai, b, null, c, ci, len, grix, 
rix);
+                               genexecB(a, ai, b, null, c, ci, len, grix, rix);
                        }
                        case AB_BA_B_A__A -> {
-                               //                              
HARDCODEDgenexec_A_or_(a,ai,b,scalars,c,ci,len,grix,rix);
-                               genexec_A_or_(a, ai, b, null, c, ci, len, grix, 
rix);
+                               genexecAor(a, ai, b, null, c, ci, len, grix, 
rix);
                                if(scalars.length != 0) { c[rix] *= scalars[0]; 
}
                        }
                        case AB_BA_B_A__ -> {
-                               genexec_A_or_(a, ai, b, null, c, ci, len, grix, 
rix);
+                               genexecAor(a, ai, b, null, c, ci, len, grix, 
rix);
                                if(scalars.length != 0) { c[0] *= scalars[0]; }
                        }
                        case AB_BA_B_A_AZ__Z -> {
                                double[] temp = {0};
-                               genexec_A_or_(a, ai, b, null, temp, 0, len, 
grix, rix);
+                               genexecAor(a, ai, b, null, temp, 0, len, grix, 
rix);
                                if(scalars.length != 0) { temp[0] *= 
scalars[0]; }
                                if(_AZCount > 1) {
                                        double[] temp2 = new double[_ZSize];
@@ -85,7 +86,7 @@ public final class EinsumSpoofRowwise extends SpoofRowwise {
                        }
                        case AB_BA_A_AZ__BZ -> {
                                double[] temp = new double[len];
-                               genexec_B(a, ai, b, null, temp, 0, len, grix, 
rix);
+                               genexecB(a, ai, b, null, temp, 0, len, grix, 
rix);
                                if(scalars.length != 0) {
                                        
LibMatrixMult.vectMultiplyWrite(scalars[0], temp, temp, 0, 0, len);
                                }
@@ -104,7 +105,7 @@ public final class EinsumSpoofRowwise extends SpoofRowwise {
                        }
                        case AB_BA_A_AZ__ZB -> {
                                double[] temp = new double[len];
-                               genexec_B(a, ai, b, null, temp, 0, len, grix, 
rix);
+                               genexecB(a, ai, b, null, temp, 0, len, grix, 
rix);
                                if(scalars.length != 0) {
                                        
LibMatrixMult.vectMultiplyWrite(scalars[0], temp, temp, 0, 0, len);
                                }
@@ -125,7 +126,7 @@ public final class EinsumSpoofRowwise extends SpoofRowwise {
                }
        }
 
-       private void genexec_AB(double[] a, int ai, SideInput[] b, double[] 
scalars, double[] c, int ci, int len, long grix,
+       private void genexecAB(double[] a, int ai, SideInput[] b, double[] 
scalars, double[] c, int ci, int len, long grix,
                int rix) {
                int bi = 0;
                double[] TMP1 = null;
@@ -164,7 +165,7 @@ public final class EinsumSpoofRowwise extends SpoofRowwise {
                }
        }
 
-       private void genexec_B(double[] a, int ai, SideInput[] b, double[] 
scalars, double[] c, int ci, int len, long grix,
+       private void genexecB(double[] a, int ai, SideInput[] b, double[] 
scalars, double[] c, int ci, int len, long grix,
                int rix) {
                int bi = 0;
                double[] TMP1 = null;
@@ -190,7 +191,7 @@ public final class EinsumSpoofRowwise extends SpoofRowwise {
                }
        }
 
-       private void genexec_A_or_(double[] a, int ai, SideInput[] b, double[] 
scalars, double[] c, int ci, int len, long grix, int rix) {
+       private void genexecAor(double[] a, int ai, SideInput[] b, double[] 
scalars, double[] c, int ci, int len, long grix, int rix) {
                int bi = 0;
                double[] TMP1 = null;
                double TMP2 = 0;
@@ -217,14 +218,6 @@ public final class EinsumSpoofRowwise extends SpoofRowwise {
                else c[0] += TMP2;
        }
 
-       private void HARDCODEDgenexec_A_or_(double[] a, int ai, SideInput[] b, 
double[] scalars, double[] c, int ci,
-               int len, long grix, int rix) {
-               double[] TMP1 = LibSpoofPrimitives.vectMultWrite(a, 
b[0].values(rix), ai, ai, len);
-               double TMP2 = LibSpoofPrimitives.dotProduct(TMP1, 
b[1].values(0), 0, ai, len);
-               TMP2 *= b[2].values(0)[rix];
-               c[rix] = TMP2;
-       }
-
        protected void genexec(double[] avals, int[] aix, int ai, SideInput[] 
b, double[] scalars, double[] c, int ci,
                int alen, int len, long grix, int rix) {
                throw new RuntimeException("Sparse fused einsum not 
implemented");
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
index b8af61d35b..2b1074c80c 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
@@ -31,24 +31,40 @@ import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.codegen.cplan.CNode;
 import org.apache.sysds.hops.codegen.cplan.CNodeCell;
 import org.apache.sysds.hops.codegen.cplan.CNodeData;
-import org.apache.sysds.runtime.codegen.*;
+import org.apache.sysds.runtime.codegen.CodegenUtils;
+import org.apache.sysds.runtime.codegen.SpoofOperator;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.einsum.*;
+import org.apache.sysds.runtime.einsum.EOpNode;
+import org.apache.sysds.runtime.einsum.EOpNodeBinary;
 import org.apache.sysds.runtime.einsum.EOpNodeBinary.EBinaryOperand;
-import org.apache.sysds.runtime.functionobjects.*;
+import org.apache.sysds.runtime.einsum.EOpNodeData;
+import org.apache.sysds.runtime.einsum.EOpNodeFuse;
+import org.apache.sysds.runtime.einsum.EOpNodeUnary;
+import org.apache.sysds.runtime.einsum.EinsumContext;
+import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.utils.Explain;
 
-import java.util.*;
 
 import static org.apache.sysds.api.DMLScript.EXPLAIN;
 import static 
org.apache.sysds.hops.rewrite.RewriteMatrixMultChainOptimization.mmChainDP;
 
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
 public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
-    public static final boolean FORCE_CELL_TPL = false; // naive looped 
solution
+       public static final boolean FORCE_CELL_TPL = false; // naive looped 
solution
 
        public static final boolean FUSE_OUTER_MULTIPLY = true;
        public static final boolean FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK 
= true;
@@ -63,7 +79,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
        public EinsumCPInstruction(Operator op, String opcode, String istr, 
CPOperand out, CPOperand... inputs)
        {
                super(op, opcode, istr, out, inputs);
-        _numThreads = OptimizerUtils.getConstrainedNumThreads(-1)/2;
+               _numThreads = OptimizerUtils.getConstrainedNumThreads(-1)/2;
                _in = inputs;
                this.eqStr = inputs[0].getName();
                if (PRINT_TRACE)
@@ -80,9 +96,9 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                                if(mb instanceof CompressedMatrixBlock){
                                        mb = ((CompressedMatrixBlock) 
mb).getUncompressed("Spoof instruction");
                                }
-                if(mb.getNumRows() == 1){
-                    ensureMatrixBlockColumnVector(mb);
-                }
+                               if(mb.getNumRows() == 1){
+                                       ensureMatrixBlockColumnVector(mb);
+                               }
                                inputs.add(mb);
                        }
                }
@@ -93,11 +109,11 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
 
                if( LOG.isTraceEnabled() ) LOG.trace("output: "+resultString +" 
"+einc.outRows+"x"+einc.outCols);
 
-               ArrayList<String> inputsChars = 
einc.newEquationStringInputsSplit;
+               List<String> inputsChars = einc.newEquationStringInputsSplit;
 
                if(LOG.isTraceEnabled()) 
LOG.trace(String.join(",",einc.newEquationStringInputsSplit));
-               ArrayList<EOpNode> eOpNodes = new 
ArrayList<>(inputsChars.size());
-               ArrayList<EOpNode> eOpNodesScalars = new 
ArrayList<>(inputsChars.size()); // computed separately and not included into 
plan until it is already created
+               List<EOpNode> eOpNodes = new ArrayList<>(inputsChars.size());
+               List<EOpNode> eOpNodesScalars = new 
ArrayList<>(inputsChars.size()); // computed separately and not included into 
plan until it is already created
 
                //make all vetors col vectors
                for(int i = 0; i < inputs.size(); i++){
@@ -106,7 +122,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
 
                addSumDimensionsDiagonalsAndScalars(einc, inputsChars, 
eOpNodes, eOpNodesScalars, einc.charToDimensionSize);
 
-               HashMap<Character, Integer> characterToOccurences = 
einc.characterAppearanceCount;
+               Map<Character, Integer> characterToOccurences = 
einc.characterAppearanceCount;
 
                for (int i = 0; i < inputsChars.size(); i++) {
                        if (inputsChars.get(i) == null) continue;
@@ -118,37 +134,16 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                        eOpNodes.add(n);
                }
 
-               ArrayList<EOpNode> ret = new ArrayList<>();
+               List<EOpNode> ret = new ArrayList<>();
                addVectorMultiplies(eOpNodes, 
eOpNodesScalars,characterToOccurences, einc.outChar1, einc.outChar2, ret);
                eOpNodes = ret;
 
                List<EOpNode> plan;
                ArrayList<MatrixBlock> remainingMatrices;
 
-        if(!FORCE_CELL_TPL) {
-                       if(true){
-                               plan = generateGreedyPlan(eOpNodes, 
eOpNodesScalars, einc.charToDimensionSize, characterToOccurences, 
einc.outChar1, einc.outChar2);
-                       }else { // old way: try to do fusion first and then 
rest in binary fashion cost based
-                               List<EOpNodeFuse> fuseOps;
-                               do {
-                                       ret = new ArrayList<>();
-                                       fuseOps = 
EOpNodeFuse.findFuseOps(eOpNodes, einc.outChar1, einc.outChar2, 
einc.charToDimensionSize, characterToOccurences, ret);
-
-                                       if(!fuseOps.isEmpty()) {
-                                               for (EOpNodeFuse fuseOp : 
fuseOps) {
-                                                       if (fuseOp.c1 == null) {
-                                                               
eOpNodesScalars.add(fuseOp);
-                                                               continue;
-                                                       }
-                                                       ret.add(fuseOp);
-                                               }
-                                               eOpNodes = ret;
-                                       }
-                               } while(eOpNodes.size() > 1 && 
!fuseOps.isEmpty());
-                               Pair<Integer, List<EOpNode>> costAndPlan = 
generateBinaryPlanCostBased(0, eOpNodes, einc.charToDimensionSize, 
characterToOccurences,
-                                       einc.outChar1, einc.outChar2);
-                               plan = costAndPlan.getRight();
-                       }
+               if(!FORCE_CELL_TPL) {
+                       plan = generateGreedyPlan(eOpNodes, eOpNodesScalars, 
+                               einc.charToDimensionSize, 
characterToOccurences, einc.outChar1, einc.outChar2);
                        if(!eOpNodesScalars.isEmpty()){
                                EOpNode l = eOpNodesScalars.get(0);
                                for(int i = 1; i < eOpNodesScalars.size(); i++){
@@ -194,7 +189,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                        }
 
                        remainingMatrices = executePlan(plan, inputs);
-        }else{
+               }else{
                        plan = eOpNodes;
                        remainingMatrices = inputs;
                }
@@ -208,7 +203,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                                        ec.setMatrixOutput(output.getName(), 
remainingMatrices.get(0));
                                }
                                else if(resNode.c1 == einc.outChar2 && 
resNode.c2 == einc.outChar1){
-                    if( LOG.isTraceEnabled()) LOG.trace("Transposing the final 
result");
+                                       if( LOG.isTraceEnabled()) 
LOG.trace("Transposing the final result");
 
                                        ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
                                        MatrixBlock resM = 
remainingMatrices.get(0).reorgOperations(transpose, new MatrixBlock(),0,0,0);
@@ -233,7 +228,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                }else{
                        // use cell template with loops for remaining
                        ArrayList<MatrixBlock> mbs = remainingMatrices;
-                       ArrayList<String> chars = new ArrayList<>();
+                       List<String> chars = new ArrayList<>();
 
                        for (int i = 0; i < plan.size(); i++) {
                                String s;
@@ -243,7 +238,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                                chars.add(s);
                        }
 
-                       ArrayList<Character> summingChars = new ArrayList<>();
+                       List<Character> summingChars = new ArrayList<>();
                        for (Character c : characterToOccurences.keySet()) {
                                if (c != einc.outChar1 && c != einc.outChar2) 
summingChars.add(c);
                        }
@@ -281,7 +276,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                return new 
EOpNodeBinary(addToNode,scalar,EBinaryOperand.AB_scalar);
        }
 
-       private static Pair<Integer, EOpNode> 
addScalarToPlanFindMinCost(EOpNode plan, HashMap<Character, Integer> 
charToSizeMap) {
+       private static Pair<Integer, EOpNode> 
addScalarToPlanFindMinCost(EOpNode plan, Map<Character, Integer> charToSizeMap) 
{
                int thisSize = 0;
                if(plan.c1 != null) thisSize += charToSizeMap.get(plan.c1);
                if(plan.c2 != null) thisSize += charToSizeMap.get(plan.c2);
@@ -315,8 +310,10 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                return Pair.of(cost, plan);
        }
 
-       private static void addVectorMultiplies(ArrayList<EOpNode> eOpNodes, 
ArrayList<EOpNode> eOpNodesScalars,HashMap<Character, Integer> 
charToOccurences, Character outChar1, Character outChar2,ArrayList<EOpNode> 
ret) {
-               HashMap<Character, ArrayList<EOpNode>> vectorCharacterToIndices 
= new HashMap<>();
+       private static void addVectorMultiplies(List<EOpNode> eOpNodes, 
List<EOpNode> eOpNodesScalars, 
+               Map<Character, Integer> charToOccurences, Character outChar1, 
Character outChar2, List<EOpNode> ret) 
+       {
+               Map<Character, List<EOpNode>> vectorCharacterToIndices = new 
HashMap<>();
                for (int i = 0; i < eOpNodes.size(); i++) {
                        if (eOpNodes.get(i).c2 == null) {
                                if 
(vectorCharacterToIndices.containsKey(eOpNodes.get(i).c1))
@@ -325,9 +322,9 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                                        
vectorCharacterToIndices.put(eOpNodes.get(i).c1, new 
ArrayList<>(Collections.singletonList(eOpNodes.get(i))));
                        }
                }
-               HashSet<EOpNode> usedNodes = new HashSet<>();
+               Set<EOpNode> usedNodes = new HashSet<>();
                for(Character c : vectorCharacterToIndices.keySet()){
-                       ArrayList<EOpNode> nodes = 
vectorCharacterToIndices.get(c);
+                       List<EOpNode> nodes = vectorCharacterToIndices.get(c);
 
                        if(nodes.size()==1) continue;
                        EOpNode left = nodes.get(0);
@@ -358,9 +355,9 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                }
        }
 
-       private void addSumDimensionsDiagonalsAndScalars(EinsumContext einc, 
ArrayList<String> inputStrings,
-               ArrayList<EOpNode> eOpNodes, ArrayList<EOpNode> eOpNodesScalars,
-               HashMap<Character, Integer> charToDimensionSize) {
+       private void addSumDimensionsDiagonalsAndScalars(EinsumContext einc, 
List<String> inputStrings,
+               List<EOpNode> eOpNodes, List<EOpNode> eOpNodesScalars, 
Map<Character, Integer> charToDimensionSize) 
+       {
                for(int i = 0; i< inputStrings.size(); i++){
                        String s = inputStrings.get(i);
                        if (s.isEmpty()){
@@ -414,9 +411,10 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                }
        }
 
-       private static List<EOpNode> generateGreedyPlan(ArrayList<EOpNode> 
eOpNodes,
-               ArrayList<EOpNode> eOpNodesScalars, HashMap<Character, Integer> 
charToSizeMap, HashMap<Character, Integer> charToOccurences, Character 
outChar1, Character outChar2) {
-               ArrayList<EOpNode> ret;
+       private static List<EOpNode> generateGreedyPlan(List<EOpNode> eOpNodes, 
List<EOpNode> eOpNodesScalars,
+               Map<Character, Integer> charToSizeMap, Map<Character, Integer> 
charToOccurences, Character outChar1, Character outChar2) 
+       {
+               List<EOpNode> ret;
                int lastNumOfOperands = -1;
                while(lastNumOfOperands != eOpNodes.size() && eOpNodes.size() > 
1){
                        lastNumOfOperands = eOpNodes.size();
@@ -443,7 +441,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                        eOpNodes = ret;
 
                        ret = new ArrayList<>();
-                       ArrayList<List<EOpNode>> matrixMultiplies = 
findMatrixMultiplicationChains(eOpNodes, outChar1, outChar2, charToOccurences,
+                       List<List<EOpNode>> matrixMultiplies = 
findMatrixMultiplicationChains(eOpNodes, outChar1, outChar2, charToOccurences,
                                ret);
 
                        for(List<EOpNode> list : matrixMultiplies) {
@@ -456,7 +454,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                return eOpNodes;
        }
 
-       private static void reverseMMChainIfBeneficial(ArrayList<EOpNode> 
mmChain){ // possibly check the cost instead of number of transposes
+       private static void reverseMMChainIfBeneficial(List<EOpNode> mmChain){ 
// possibly check the cost instead of number of transposes
                char c1 = mmChain.get(0).c1;
                char c2 = mmChain.get(0).c2;
                int noTransposes = 0;
@@ -466,7 +464,6 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                        }
                        noTransposes++;
                        if(c2 == mmChain.get(i).c2){
-                               c1 = c1;
                                c2 = mmChain.get(i).c1;
                        }
                        else if(c1 == mmChain.get(i).c1) {
@@ -481,10 +478,10 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                        Collections.reverse(mmChain);
                }
        }
-       private static EOpNodeBinary optimizeMMChain(List<EOpNode> mmChainL, 
HashMap<Character, Integer> charToSizeMap) {
-               ArrayList<EOpNode> mmChain = new ArrayList<>(mmChainL);
+       private static EOpNodeBinary optimizeMMChain(List<EOpNode> mmChainL, 
Map<Character, Integer> charToSizeMap) {
+               List<EOpNode> mmChain = new ArrayList<>(mmChainL);
                reverseMMChainIfBeneficial(mmChain);
-               ArrayList<Pair<Integer, Integer>> dimensions = new 
ArrayList<>();
+               List<Pair<Integer, Integer>> dimensions = new ArrayList<>();
 
                for(int i = 0; i < mmChain.size()-1; i++){
                        EOpNode n1 = mmChain.get(i);
@@ -516,7 +513,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                return EOpNodeBinary.combineMatrixMultiply(left, right);
        }
 
-       private static void getDimsArray( ArrayList<Pair<Integer, Integer>> 
chain, double[] dimsArray )
+       private static void getDimsArray( List<Pair<Integer, Integer>> chain, 
double[] dimsArray )
        {
                for( int i = 0; i < chain.size(); i++ ) {
                        if (i == 0) {
@@ -539,11 +536,12 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                        }
                }
        }
-       private static ArrayList<List<EOpNode>> 
findMatrixMultiplicationChains(ArrayList<EOpNode> inpOperands, Character 
outChar1, Character outChar2, HashMap<Character, Integer> charToOccurences,
-               ArrayList<EOpNode> ret) {
-               HashSet<Character> charactersThatCanBeContracted = new 
HashSet<>();
-               HashMap<Character, ArrayList<EOpNode>> characterToNodes = new 
HashMap<>();
-               ArrayList<EOpNode> operandsTodo =  new ArrayList<>();
+       private static List<List<EOpNode>> 
findMatrixMultiplicationChains(List<EOpNode> inpOperands, 
+               Character outChar1, Character outChar2, Map<Character, Integer> 
charToOccurences, List<EOpNode> ret) 
+       {
+               Set<Character> charactersThatCanBeContracted = new HashSet<>();
+               Map<Character, ArrayList<EOpNode>> characterToNodes = new 
HashMap<>();
+               List<EOpNode> operandsTodo =  new ArrayList<>();
                for(EOpNode op : inpOperands) {
                        if(op.c2 == null || op.c1 == null) continue;
 
@@ -563,9 +561,9 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                        }
                        if (todo)  operandsTodo.add(op);
                }
-               ArrayList<List<EOpNode>> res = new ArrayList<>();
+               List<List<EOpNode>> res = new ArrayList<>();
 
-               HashSet<EOpNode> doneNodes = new HashSet<>();
+               Set<EOpNode> doneNodes = new HashSet<>();
 
                for(int i = 0; i < operandsTodo.size(); i++){
                        EOpNode iterateNode = operandsTodo.get(i);
@@ -624,6 +622,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
        }
 
        // old way, DFS finds all paths
+       @SuppressWarnings("unused")
        private Pair<Integer, List<EOpNode>> generateBinaryPlanCostBased(int 
cost, ArrayList<EOpNode> operands, HashMap<Character, Integer> charToSizeMap, 
HashMap<Character, Integer> charToOccurences, Character outChar1, Character 
outChar2) {
                Integer minCost = cost;
                List<EOpNode> minNodes = operands;
@@ -632,7 +631,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                        boolean swap = (operands.get(0).c2 == null && 
operands.get(1).c2 != null) || operands.get(0).c1 == null;
                        EOpNode n1 = operands.get(!swap ? 0 : 1);
                        EOpNode n2 = operands.get(!swap ? 1 : 0);
-                       Triple<Integer, EBinaryOperand, Pair<Character, 
Character>> t = EOpNodeBinary.TryCombineAndCost(n1, n2, charToSizeMap, 
charToOccurences, outChar1, outChar2);
+                       Triple<Integer, EBinaryOperand, Pair<Character, 
Character>> t = EOpNodeBinary.tryCombineAndCost(n1, n2, charToSizeMap, 
charToOccurences, outChar1, outChar2);
                        if (t != null) {
                                EOpNodeBinary newNode = new EOpNodeBinary(n1, 
n2, t.getMiddle());
                                int thisCost = cost + t.getLeft();
@@ -650,7 +649,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                                EOpNode n1 = operands.get(!swap ? i : j);
                                EOpNode n2 = operands.get(!swap ? j : i);
 
-                               Triple<Integer, EBinaryOperand, Pair<Character, 
Character>> t = EOpNodeBinary.TryCombineAndCost(n1, n2, charToSizeMap, 
charToOccurences, outChar1, outChar2);
+                               Triple<Integer, EBinaryOperand, Pair<Character, 
Character>> t = EOpNodeBinary.tryCombineAndCost(n1, n2, charToSizeMap, 
charToOccurences, outChar1, outChar2);
                                if (t != null){
                                        EOpNodeBinary newNode = new 
EOpNodeBinary(n1, n2, t.getMiddle());
                                        int thisCost = cost + t.getLeft();
@@ -687,10 +686,10 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                return Pair.of(minCost, minNodes);
        }
 
-       private ArrayList<MatrixBlock> executePlan(List<EOpNode> plan, 
ArrayList<MatrixBlock> inputs) {
+       private ArrayList<MatrixBlock> executePlan(List<EOpNode> plan, 
List<MatrixBlock> inputs) {
                ArrayList<MatrixBlock> res = new ArrayList<>(plan.size());
                for(EOpNode p : plan){
-            res.add(p.computeEOpNode(inputs, _numThreads, LOG));
+                       res.add(p.computeEOpNode(inputs, _numThreads, LOG));
                }
                return res;
        }
@@ -708,7 +707,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                        mb.getDenseBlock().resetNoFill(mb.getNumRows(),1);
                }
        }
-    public static void ensureMatrixBlockRowVector(MatrixBlock mb){
+       public static void ensureMatrixBlockRowVector(MatrixBlock mb){
                if(mb.getNumRows() > 1){
                        mb.setNumColumns(mb.getNumRows());
                        mb.setNumRows(1);
@@ -722,10 +721,12 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
                }
        }
 
-       private MatrixBlock computeCellSummation(ArrayList<MatrixBlock> inputs, 
List<String> inputsChars, String resultString,
-                                                                               
                                   HashMap<Character, Integer> 
charToDimensionSizeInt, List<Character> summingChars, int outRows, int outCols){
+       private MatrixBlock computeCellSummation(ArrayList<MatrixBlock> inputs, 
List<String> inputsChars, 
+               String resultString, Map<Character, Integer> 
charToDimensionSizeInt, 
+               List<Character> summingChars, int outRows, int outCols)
+       {
                ArrayList<CNode> dummyIn = new ArrayList<>();
-        dummyIn.add(new CNodeData(new LiteralOp(0), 0, 0, DataType.SCALAR));
+               dummyIn.add(new CNodeData(new LiteralOp(0), 0, 0, 
DataType.SCALAR));
                CNodeCell cnode = new CNodeCell(dummyIn, null);
                StringBuilder sb = new StringBuilder();
 
diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java
index 97e1eb83bf..f86da127d6 100644
--- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java
@@ -46,7 +46,7 @@ import java.util.Map;
 @RunWith(Parameterized.class)
 public class EinsumTest extends AutomatedTestBase
 {
-    final private static List<Config> TEST_CONFIGS = List.of(
+       final private static List<Config> TEST_CONFIGS = List.of(
                new Config("ij,jk->ik", List.of(shape(5, 6), shape(6, 5))), // 
mm
                new Config("ji,jk->ik", List.of(shape(6, 5), shape(6, 10))),
                new Config("ji,kj->ik", List.of(shape(6, 5), shape(10, 6))),
@@ -78,23 +78,23 @@ public class EinsumTest extends AutomatedTestBase
                new Config("ij,i->i",   List.of(shape(10, 5), shape(10))),
                new Config("ij,i->j",   List.of(shape(10, 5), shape(10))),
 
-               new Config("i,i->",     List.of(shape(5), shape(5))), // dot
-               new Config("i,j->",     List.of(shape(5), shape(80))), // sum
-               new Config("i,j->ij",     List.of(shape(5), shape(80))), // 
outer vect mult
-               new Config("i,j->ji",     List.of(shape(5), shape(80))), // 
outer vect mult
+               new Config("i,i->",      List.of(shape(5), shape(5))), // dot
+               new Config("i,j->",      List.of(shape(5), shape(80))), // sum
+               new Config("i,j->ij",    List.of(shape(5), shape(80))), // 
outer vect mult
+               new Config("i,j->ji",    List.of(shape(5), shape(80))), // 
outer vect mult
 
-               new Config("ij->",     List.of(shape(10, 5))), // sum
-               new Config("i->",     List.of(shape(10))), // sum
-               new Config("ij->i",     List.of(shape(10, 5))), // sum(1)
-               new Config("ij->j",     List.of(shape(10, 5))), // sum(0)
-               new Config("ij->ji",     List.of(shape(10, 5))), // T
-               new Config("ij->ij",     List.of(shape(10, 5))),
-               new Config("i->i",     List.of(shape(10))),
-               new Config("ii->i",     List.of(shape(10, 10))), // Diag
-               new Config("ii->",     List.of(shape(10, 10))), // Trace
-               new Config("ii,i->i",     List.of(shape(10, 10),shape(10))), // 
Diag*vec
+               new Config("ij->",       List.of(shape(10, 5))), // sum
+               new Config("i->",        List.of(shape(10))), // sum
+               new Config("ij->i",      List.of(shape(10, 5))), // sum(1)
+               new Config("ij->j",      List.of(shape(10, 5))), // sum(0)
+               new Config("ij->ji",     List.of(shape(10, 5))), // T
+               new Config("ij->ij",     List.of(shape(10, 5))),
+               new Config("i->i",       List.of(shape(10))),
+               new Config("ii->i",      List.of(shape(10, 10))), // Diag
+               new Config("ii->",       List.of(shape(10, 10))), // Trace
+               new Config("ii,i->i",    List.of(shape(10, 10),shape(10))), // 
Diag*vec
 
-               new Config("ab,cd->ba",     List.of(shape( 6, 10), shape(6, 
5))), // sum cd to scalar and multiply ab
+               new Config("ab,cd->ba",  List.of(shape( 6, 10), shape(6, 5))), 
// sum cd to scalar and multiply ab
 
                new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (the 
cell tpl fallback)
                                List.of(shape(5, 6), shape(5, 3), shape(5, 10), 
shape(6, 3), shape(10, 6), shape(10, 3))),
@@ -137,7 +137,7 @@ public class EinsumTest extends AutomatedTestBase
                new Config("ab,ab,a,ag,gz->bz",   List.of(shape(10, 5), 
shape(10, 5),shape(10),shape(10,200),shape(200,7)))
                ,new Config("ab,ab,a,ag,gz->bz",   List.of(shape(10, 5), 
shape(10, 5),shape(10),shape(10,20),shape(20,7)))
                ,new Config("ab,ab,bc,bc->bc",   List.of(shape(10, 5), 
shape(10, 5),shape(5,20),shape(5,20)))
-    );
+       );
        private final int id;
        private final String einsumStr;
        private final File dmlFile;
@@ -312,36 +312,36 @@ public class EinsumTest extends AutomatedTestBase
        }
 
        private static class Config {
-        public List<Double> factors;
+               public List<Double> factors;
                String einsumStr;
                List<int[]> shapes;
 
-        Config(String einsum, List<int[]> shapes) {
-            this(einsum,shapes,null);
-        }
-        Config(String einsum, Map<Character, Integer> charToSize){
-            this(einsum, charToSize, null);
-        }
-
-        Config(String einsum, Map<Character, Integer> charToSize, List<Double> factors) {
-            this.einsumStr = einsum;
-            String leftPart = einsum.split("->")[0];
-            List<int[]> shapes = new ArrayList<>();
-            for(String op : Arrays.stream(leftPart.split(",")).map(x->x.trim()).toList()){
-                if (op.length() == 1) {
-                    shapes.add(new int[]{charToSize.get(op.charAt(0))});
-                }else{
-                    shapes.add(new int[]{charToSize.get(op.charAt(0)),charToSize.get(op.charAt(1))});
-                }
-
-            }
-            this.shapes = shapes;
-            this.factors = factors;
-        }
+               Config(String einsum, List<int[]> shapes) {
+                       this(einsum,shapes,null);
+               }
+               Config(String einsum, Map<Character, Integer> charToSize){
+                       this(einsum, charToSize, null);
+               }
+
+               Config(String einsum, Map<Character, Integer> charToSize, List<Double> factors) {
+                       this.einsumStr = einsum;
+                       String leftPart = einsum.split("->")[0];
+                       List<int[]> shapes = new ArrayList<>();
+                       for(String op : Arrays.stream(leftPart.split(",")).map(x->x.trim()).toList()){
+                               if (op.length() == 1) {
+                                       shapes.add(new int[]{charToSize.get(op.charAt(0))});
+                               }else{
+                                       shapes.add(new int[]{charToSize.get(op.charAt(0)),charToSize.get(op.charAt(1))});
+                               }
+
+                       }
+                       this.shapes = shapes;
+                       this.factors = factors;
+               }
                Config(String einsum, List<int[]> shapes, List<Double> factors) {
                        this.einsumStr = einsum;
                        this.shapes = shapes;
-            this.factors = factors;
+                       this.factors = factors;
                }
        }
 

Reply via email to