This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 0bca04715c [SYSTEMDS-3790] New Federated Planner MemoTable
0bca04715c is described below

commit 0bca04715c3280d5ab748976494b39ebba46889d
Author: min-guk <[email protected]>
AuthorDate: Thu Nov 21 09:24:36 2024 +0100

    [SYSTEMDS-3790] New Federated Planner MemoTable
    
    Closes #2141.
---
 .../apache/sysds/hops/fedplanner/MemoTable.java    | 160 ++++++++++++++++++
 .../test/component/federated/MemoTableTest.java    | 186 +++++++++++++++++++++
 2 files changed, 346 insertions(+)

diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
new file mode 100644
index 0000000000..8fce06b33e
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.lang3.tuple.ImmutablePair;
+
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Map;
+
+/**
+ * A Memoization Table for managing federated plans (`FedPlan`) based on
+ * combinations of Hops and FTypes. Each combination is mapped to a list
+ * of possible execution plans, allowing for pruning and optimization.
+ */
+public class MemoTable {
+
+       // Maps combinations of Hop ID and FType to lists of FedPlans
+       private final Map<Pair<Long, FTypes.FType>, List<FedPlan>> hopMemoTable 
= new HashMap<>();
+       
+       /**
+        * Represents a federated execution plan with its cost and associated 
references.
+        */
+       public static class FedPlan {
+               @SuppressWarnings("unused")
+               private final Hop hopRef;                       // The 
associated Hop object
+               private final double cost;                      // Cost of this 
federated plan
+               @SuppressWarnings("unused")
+               private final List<Pair<Long, FType>> planRefs; // References 
to dependent plans
+
+               public FedPlan(Hop hopRef, double cost, List<Pair<Long, FType>> 
planRefs) {
+                       this.hopRef = hopRef;
+                       this.cost = cost;
+                       this.planRefs = planRefs;
+               }
+
+               public double getCost() {
+                       return cost;
+               }
+       }
+
+       /**
+        * Adds a single FedPlan to the memo table for a given Hop and FType.
+        * If the entry already exists, the new FedPlan is appended to the list.
+        *
+        * @param hop     The Hop object.
+        * @param fType   The associated FType.
+        * @param fedPlan The FedPlan to add.
+        */
+       public void addFedPlan(Hop hop, FType fType, FedPlan fedPlan) {
+               if (contains(hop, fType)) {
+                       List<FedPlan> fedPlanList = get(hop, fType);
+                       fedPlanList.add(fedPlan);
+               } else {
+                       List<FedPlan> fedPlanList = new ArrayList<>();
+                       fedPlanList.add(fedPlan);
+                       hopMemoTable.put(new ImmutablePair<>(hop.getHopID(), 
fType), fedPlanList);
+               }
+       }
+
+       /**
+        * Adds multiple FedPlans to the memo table for a given Hop and FType.
+        * If the entry already exists, the new FedPlans are appended to the 
list.
+        *
+        * @param hop    The Hop object.
+        * @param fType  The associated FType.
+        * @param newFedPlanList The list of FedPlans to add.
+        */
+       public void addFedPlanList(Hop hop, FType fType, List<FedPlan> 
fedPlanList) {
+               if (contains(hop, fType)) {
+                       List<FedPlan> prevFedPlanList = get(hop, fType);
+                       prevFedPlanList.addAll(fedPlanList);
+               } else {
+                       hopMemoTable.put(new ImmutablePair<>(hop.getHopID(), 
fType), fedPlanList);
+               }
+       }
+
+       /**
+        * Retrieves the list of FedPlans associated with a given Hop and FType.
+        *
+        * @param hop   The Hop object.
+        * @param fType The associated FType.
+        * @return The list of FedPlans, or null if no entry exists.
+        */
+       public List<FedPlan> get(Hop hop, FType fType) {
+               return hopMemoTable.get(new ImmutablePair<>(hop.getHopID(), 
fType));
+       }
+
+       /**
+        * Checks if the memo table contains an entry for a given Hop and FType.
+        *
+        * @param hop   The Hop object.
+        * @param fType The associated FType.
+        * @return True if the entry exists, false otherwise.
+        */
+       public boolean contains(Hop hop, FType fType) {
+               return hopMemoTable.containsKey(new 
ImmutablePair<>(hop.getHopID(), fType));
+       }
+
+       /**
+        * Prunes the FedPlans associated with a specific Hop and FType,
+        * keeping only the plan with the minimum cost.
+        *
+        * @param hop   The Hop object.
+        * @param fType The associated FType.
+        */
+       public void prunePlan(Hop hop, FType fType) {
+               prunePlan(hopMemoTable.get(new ImmutablePair<>(hop.getHopID(), 
fType)));
+       }
+
+       /**
+        * Prunes all entries in the memo table, retaining only the minimum-cost
+        * FedPlan for each entry.
+        */
+       public void pruneAll() {
+               for (Map.Entry<Pair<Long, FType>, List<FedPlan>> entry : 
hopMemoTable.entrySet()) {
+                       prunePlan(entry.getValue());
+               }
+       }
+
+       /**
+        * Prunes the given list of FedPlans to retain only the plan with the 
minimum cost.
+        *
+        * @param fedPlanList The list of FedPlans to prune.
+        */
+       private void prunePlan(List<FedPlan> fedPlanList) {
+               if (fedPlanList.size() > 1) {
+                       // Find the FedPlan with the minimum cost
+                       FedPlan minCostPlan = fedPlanList.stream()
+                                       .min(Comparator.comparingDouble(plan -> 
plan.cost))
+                                       .orElse(null);
+
+                       // Retain only the minimum cost plan
+                       fedPlanList.clear();
+                       fedPlanList.add(minCostPlan);
+               }
+       }
+}
diff --git a/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java b/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java
new file mode 100644
index 0000000000..e3928c1263
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.federated;
+
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.fedplanner.FTypes;
+import org.apache.sysds.hops.fedplanner.MemoTable;
+import org.apache.sysds.hops.fedplanner.MemoTable.FedPlan;
+import org.apache.commons.lang3.tuple.Pair;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.when;
+
+public class MemoTableTest {
+       
+       private MemoTable memoTable;
+       
+       @Mock
+       private Hop mockHop1;
+       
+       @Mock
+       private Hop mockHop2;
+       
+       private java.util.Random rand;
+
+       @Before
+       public void setUp() {
+               MockitoAnnotations.openMocks(this);
+               memoTable = new MemoTable();
+               
+               // Set up unique IDs for mock Hops
+               when(mockHop1.getHopID()).thenReturn(1L);
+               when(mockHop2.getHopID()).thenReturn(2L);
+               
+               // Initialize random generator with fixed seed for reproducible 
tests
+               rand = new java.util.Random(42); 
+       }
+       
+       @Test
+       public void testAddAndGetSingleFedPlan() {
+               // Initialize test data
+               List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
+               FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs);
+               
+               // Verify initial state
+               List<FedPlan> result = memoTable.get(mockHop1, 
FTypes.FType.FULL);
+               assertNull("Initial FedPlan list should be null before adding 
any plans", result);
+
+               // Add single FedPlan
+               memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan);
+               
+               // Verify after addition
+               result = memoTable.get(mockHop1, FTypes.FType.FULL);
+               assertNotNull("FedPlan list should exist after adding a plan", 
result);
+               assertEquals("FedPlan list should contain exactly one plan", 1, 
result.size());
+               assertEquals("FedPlan cost should be exactly 10.0", 10.0, 
result.get(0).getCost(), 0.001);
+       }
+       
+       @Test
+       public void testAddMultipleDuplicatedFedPlans() {
+               // Initialize test data with duplicate costs
+               List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
+               List<FedPlan> fedPlans = new ArrayList<>();
+               fedPlans.add(new FedPlan(mockHop1, 10.0, planRefs));  // Unique 
cost
+               fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs));  // First 
duplicate
+               fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs));  // Second 
duplicate
+               
+               // Add multiple plans including duplicates
+               memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, fedPlans);
+               
+               // Verify handling of duplicate plans
+               List<FedPlan> result = memoTable.get(mockHop1, 
FTypes.FType.FULL);
+               assertNotNull("FedPlan list should exist after adding multiple 
plans", result);
+               assertEquals("FedPlan list should maintain all plans including 
duplicates", 3, result.size());
+       }
+       
+       @Test
+       public void testContains() {
+               // Initialize test data
+               List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
+               FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs);
+               
+               // Verify initial state
+               assertFalse("MemoTable should not contain any entries 
initially", 
+                       memoTable.contains(mockHop1, FTypes.FType.FULL));
+               
+               // Add plan and verify presence
+               memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan);
+               
+               assertTrue("MemoTable should contain entry after adding 
FedPlan", 
+                       memoTable.contains(mockHop1, FTypes.FType.FULL));
+               assertFalse("MemoTable should not contain entries for different 
Hop", 
+                       memoTable.contains(mockHop2, FTypes.FType.FULL));
+       }
+       
+       @Test
+       public void testPrunePlanPruneAll() {
+               // Initialize base test data
+               List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
+               // Create separate FedPlan lists for independent testing of 
each Hop
+               List<FedPlan> fedPlans1 = new ArrayList<>();  // Plans for 
mockHop1
+               List<FedPlan> fedPlans2 = new ArrayList<>();  // Plans for 
mockHop2
+               
+               // Generate random cost FedPlans for both Hops
+               double minCost = Double.MAX_VALUE;
+               int size = 100;
+               for(int i = 0; i < size; i++) {
+                       double cost = rand.nextDouble() * 1000;  // Random cost 
between 0 and 1000
+                       fedPlans1.add(new FedPlan(mockHop1, cost, planRefs));
+                       fedPlans2.add(new FedPlan(mockHop2, cost, planRefs));
+                       minCost = Math.min(minCost, cost);
+               }
+               
+               // Add FedPlan lists to MemoTable
+               memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, 
fedPlans1);
+               memoTable.addFedPlanList(mockHop2, FTypes.FType.FULL, 
fedPlans2);
+               
+               // Test selective pruning on mockHop1
+               memoTable.prunePlan(mockHop1, FTypes.FType.FULL);
+               
+               // Get results for verification
+               List<FedPlan> result1 = memoTable.get(mockHop1, 
FTypes.FType.FULL);
+               List<FedPlan> result2 = memoTable.get(mockHop2, 
FTypes.FType.FULL);
+
+               // Verify selective pruning results
+               assertNotNull("Pruned mockHop1 should maintain a FedPlan list", 
result1);
+               assertEquals("Pruned mockHop1 should contain exactly one 
minimum cost plan", 1, result1.size());
+               assertEquals("Pruned mockHop1's plan should have the minimum 
cost", minCost, result1.get(0).getCost(), 0.001);
+               
+               // Verify unpruned Hop state
+               assertNotNull("Unpruned mockHop2 should maintain a FedPlan 
list", result2);
+               assertEquals("Unpruned mockHop2 should maintain all original 
plans", size, result2.size());
+
+               // Add additional plans to both Hops
+               for(int i = 0; i < size; i++) {
+                       double cost = rand.nextDouble() * 1000;
+                       memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, new 
FedPlan(mockHop1, cost, planRefs));
+                       memoTable.addFedPlan(mockHop2, FTypes.FType.FULL, new 
FedPlan(mockHop2, cost, planRefs));
+                       minCost = Math.min(minCost, cost);
+               }
+
+               // Test global pruning
+               memoTable.pruneAll();
+               
+               // Verify global pruning results
+               assertNotNull("mockHop1 should maintain a FedPlan list after 
global pruning", result1);
+               assertEquals("mockHop1 should contain exactly one minimum cost 
plan after global pruning", 
+                       1, result1.size());
+               assertEquals("mockHop1's plan should have the global minimum 
cost", 
+                       minCost, result1.get(0).getCost(), 0.001);
+
+               assertNotNull("mockHop2 should maintain a FedPlan list after 
global pruning", result2);
+               assertEquals("mockHop2 should contain exactly one minimum cost 
plan after global pruning", 
+                       1, result2.size());
+               assertEquals("mockHop2's plan should have the global minimum 
cost", 
+                       minCost, result2.get(0).getCost(), 0.001);
+       }
+}

Reply via email to