This is an automated email from the ASF dual-hosted git repository.

yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 3588d195335 Add spark mapstate (#31669)
3588d195335 is described below

commit 3588d195335fa3dc06b002e5e468baa27e79f8fa
Author: twosom <[email protected]>
AuthorDate: Fri Jun 28 23:08:52 2024 +0900

    Add spark mapstate (#31669)
    
    * add isEmpty test in testMap
    
    * add map state in spark runner
    
    * update comment on SparkStateInternalsTest
    
    * modify isEmpty test to assertFalse / assertTrue
---
 .../beam/runners/core/StateInternalsTest.java      |   6 +
 .../spark/stateful/SparkStateInternals.java        | 148 ++++++++++++++++++++-
 .../spark/stateful/SparkStateInternalsTest.java    |  10 +-
 3 files changed, 152 insertions(+), 12 deletions(-)

diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java
index a4cd504eee7..e15249969f2 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java
@@ -386,10 +386,16 @@ public abstract class StateInternalsTest {
         value.entries().readLater().read(),
         containsInAnyOrder(MapEntry.of("B", 2), MapEntry.of("D", 4), 
MapEntry.of("E", 5)));
 
+    // isEmpty
+    assertFalse(value.isEmpty().read());
+
     // clear
     value.clear();
     assertThat(value.entries().read(), Matchers.emptyIterable());
     assertThat(underTest.state(NAMESPACE_1, STRING_MAP_ADDR), equalTo(value));
+
+    // isEmpty
+    assertTrue(value.isEmpty().read());
   }
 
   @Test
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
index 3ad955e78aa..731cadb89f0 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
@@ -19,7 +19,11 @@ package org.apache.beam.runners.spark.stateful;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
 import org.apache.beam.runners.core.StateInternals;
 import org.apache.beam.runners.core.StateNamespace;
 import org.apache.beam.runners.core.StateTag;
@@ -28,12 +32,14 @@ import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.InstantCoder;
 import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.coders.MapCoder;
 import org.apache.beam.sdk.state.BagState;
 import org.apache.beam.sdk.state.CombiningState;
 import org.apache.beam.sdk.state.MapState;
 import org.apache.beam.sdk.state.MultimapState;
 import org.apache.beam.sdk.state.OrderedListState;
 import org.apache.beam.sdk.state.ReadableState;
+import org.apache.beam.sdk.state.ReadableStates;
 import org.apache.beam.sdk.state.SetState;
 import org.apache.beam.sdk.state.State;
 import org.apache.beam.sdk.state.StateContext;
@@ -44,6 +50,7 @@ import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
 import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
 import org.apache.beam.sdk.util.CombineFnUtil;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Instant;
@@ -119,11 +126,10 @@ class SparkStateInternals<K> implements StateInternals {
 
     @Override
     public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
-        StateTag<MapState<KeyT, ValueT>> spec,
+        StateTag<MapState<KeyT, ValueT>> address,
         Coder<KeyT> mapKeyCoder,
         Coder<ValueT> mapValueCoder) {
-      throw new UnsupportedOperationException(
-          String.format("%s is not supported", 
MapState.class.getSimpleName()));
+      // The whole map is kept as one state value, encoded with a MapCoder over the key/value coders.
+      return new SparkMapState<>(namespace, address, MapCoder.of(mapKeyCoder, 
mapValueCoder));
     }
 
     @Override
@@ -359,6 +365,142 @@ class SparkStateInternals<K> implements StateInternals {
     }
   }
 
+  /**
+   * {@link MapState} that keeps all entries in a single {@code Map} value for this namespace and
+   * address, encoded with the coder passed to the constructor.
+   *
+   * <p>Mutations read the stored map, change a mutable copy and write the whole map back. The
+   * returned {@link ReadableState}s call {@code readValue()} on every {@code read()}.
+   */
+  private final class SparkMapState<MapKeyT, MapValueT>
+      extends AbstractState<Map<MapKeyT, MapValueT>> implements MapState<MapKeyT, MapValueT> {
+
+    private SparkMapState(
+        StateNamespace namespace,
+        StateTag<? extends State> address,
+        Coder<Map<MapKeyT, MapValueT>> coder) {
+      super(namespace, address, coder);
+    }
+
+    /**
+     * Returns a mutable copy of the stored map, or a new empty map when nothing is stored.
+     *
+     * <p>The map returned by {@code readValue()} is not assumed to be mutable: a coder may decode
+     * an empty map as an immutable instance, which would make an in-place {@code put} fail.
+     */
+    private Map<MapKeyT, MapValueT> readMutableMap() {
+      Map<MapKeyT, MapValueT> stored = readValue();
+      return stored == null ? new HashMap<>() : new HashMap<>(stored);
+    }
+
+    @Override
+    public ReadableState<MapValueT> get(MapKeyT key) {
+      return getOrDefault(key, null);
+    }
+
+    /** Reads the value for {@code key}, falling back to {@code defaultValue} when absent. */
+    @Override
+    public ReadableState<MapValueT> getOrDefault(MapKeyT key, @Nullable MapValueT defaultValue) {
+      return new ReadableState<MapValueT>() {
+        @Override
+        public MapValueT read() {
+          Map<MapKeyT, MapValueT> sparkMapState = readValue();
+          if (sparkMapState == null) {
+            return defaultValue;
+          }
+          return sparkMapState.getOrDefault(key, defaultValue);
+        }
+
+        @Override
+        public ReadableState<MapValueT> readLater() {
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public void put(MapKeyT key, MapValueT value) {
+      Map<MapKeyT, MapValueT> sparkMapState = readMutableMap();
+      sparkMapState.put(key, value);
+      writeValue(sparkMapState);
+    }
+
+    /**
+     * Associates {@code key} with {@code mappingFunction.apply(key)} if it has no value yet.
+     *
+     * <p>The mapping function is applied eagerly. As with {@link Map#computeIfAbsent}, a {@code
+     * null} result is not stored (it could not be encoded by the value coder anyway).
+     *
+     * @return a state holding the value associated with {@code key} before this call, or {@code
+     *     null} if there was none
+     */
+    @Override
+    public ReadableState<MapValueT> computeIfAbsent(
+        MapKeyT key, Function<? super MapKeyT, ? extends MapValueT> mappingFunction) {
+      // Works on a copy so that a missing state (readValue() == null) does not cause an NPE.
+      Map<MapKeyT, MapValueT> sparkMapState = readMutableMap();
+      MapValueT current = sparkMapState.get(key);
+      if (current == null) {
+        MapValueT computed = mappingFunction.apply(key);
+        if (computed != null) {
+          sparkMapState.put(key, computed);
+          writeValue(sparkMapState);
+        }
+      }
+      return ReadableStates.immediate(current);
+    }
+
+    /**
+     * Removes {@code key}. Does nothing if no state is stored or the key is absent, and clears the
+     * underlying state once the last entry is removed so that no empty map is left behind.
+     */
+    @Override
+    public void remove(MapKeyT key) {
+      Map<MapKeyT, MapValueT> stored = readValue();
+      if (stored == null || !stored.containsKey(key)) {
+        return;
+      }
+      Map<MapKeyT, MapValueT> sparkMapState = new HashMap<>(stored);
+      sparkMapState.remove(key);
+      if (sparkMapState.isEmpty()) {
+        clear();
+      } else {
+        writeValue(sparkMapState);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<MapKeyT>> keys() {
+      return new ReadableState<Iterable<MapKeyT>>() {
+        @Override
+        public Iterable<MapKeyT> read() {
+          Map<MapKeyT, MapValueT> sparkMapState = readValue();
+          if (sparkMapState == null) {
+            return Collections.emptyList();
+          }
+          return sparkMapState.keySet();
+        }
+
+        @Override
+        public ReadableState<Iterable<MapKeyT>> readLater() {
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Iterable<MapValueT>> values() {
+      return new ReadableState<Iterable<MapValueT>>() {
+        @Override
+        public Iterable<MapValueT> read() {
+          // Read once; the previous version decoded the map a second time.
+          Map<MapKeyT, MapValueT> sparkMapState = readValue();
+          if (sparkMapState == null) {
+            return Collections.emptyList();
+          }
+          return ImmutableList.copyOf(sparkMapState.values());
+        }
+
+        @Override
+        public ReadableState<Iterable<MapValueT>> readLater() {
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Iterable<Map.Entry<MapKeyT, MapValueT>>> entries() {
+      return new ReadableState<Iterable<Map.Entry<MapKeyT, MapValueT>>>() {
+        @Override
+        public Iterable<Map.Entry<MapKeyT, MapValueT>> read() {
+          Map<MapKeyT, MapValueT> sparkMapState = readValue();
+          if (sparkMapState == null) {
+            return Collections.emptyList();
+          }
+          return sparkMapState.entrySet();
+        }
+
+        @Override
+        public ReadableState<Iterable<Map.Entry<MapKeyT, MapValueT>>> readLater() {
+          return this;
+        }
+      };
+    }
+
+    /**
+     * Reports whether the map has no entries. Checks the stored map's contents rather than only
+     * the presence of a state-table entry, so a stored empty map also counts as empty.
+     */
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return new ReadableState<Boolean>() {
+        @Override
+        public Boolean read() {
+          Map<MapKeyT, MapValueT> sparkMapState = readValue();
+          return sparkMapState == null || sparkMapState.isEmpty();
+        }
+
+        @Override
+        public ReadableState<Boolean> readLater() {
+          return this;
+        }
+      };
+    }
+  }
+
   private final class SparkBagState<T> extends AbstractState<List<T>> 
implements BagState<T> {
     private SparkBagState(StateNamespace namespace, StateTag<BagState<T>> 
address, Coder<T> coder) {
       super(namespace, address, ListCoder.of(coder));
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java
index 6118f96fcc3..f6f2b8d6df6 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java
@@ -25,7 +25,7 @@ import org.junit.runners.JUnit4;
 
 /**
  * Tests for {@link SparkStateInternals}. This is based on {@link 
StateInternalsTest}. Ignore set
- * and map tests.
+ * tests.
  */
 @RunWith(JUnit4.class)
 public class SparkStateInternalsTest extends StateInternalsTest {
@@ -51,15 +51,7 @@ public class SparkStateInternalsTest extends StateInternalsTest {
   @Ignore
   public void testMergeSetIntoNewNamespace() {}
 
-  @Override
-  @Ignore
-  public void testMap() {}
-
   @Override
   @Ignore
   public void testSetReadable() {}
-
-  @Override
-  @Ignore
-  public void testMapReadable() {}
 }

Reply via email to