This is an automated email from the ASF dual-hosted git repository.
yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 3588d195335 Add spark mapstate (#31669)
3588d195335 is described below
commit 3588d195335fa3dc06b002e5e468baa27e79f8fa
Author: twosom <[email protected]>
AuthorDate: Fri Jun 28 23:08:52 2024 +0900
Add spark mapstate (#31669)
* add isEmpty test in testMap
* add map state in spark runner
* update comment on SparkStateInternalsTest
* modify isEmpty test to assertFalse / assertTrue
---
.../beam/runners/core/StateInternalsTest.java | 6 +
.../spark/stateful/SparkStateInternals.java | 148 ++++++++++++++++++++-
.../spark/stateful/SparkStateInternalsTest.java | 10 +-
3 files changed, 152 insertions(+), 12 deletions(-)
diff --git
a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java
b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java
index a4cd504eee7..e15249969f2 100644
---
a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java
+++
b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java
@@ -386,10 +386,16 @@ public abstract class StateInternalsTest {
value.entries().readLater().read(),
containsInAnyOrder(MapEntry.of("B", 2), MapEntry.of("D", 4),
MapEntry.of("E", 5)));
+ // isEmpty
+ assertFalse(value.isEmpty().read());
+
// clear
value.clear();
assertThat(value.entries().read(), Matchers.emptyIterable());
assertThat(underTest.state(NAMESPACE_1, STRING_MAP_ADDR), equalTo(value));
+
+ // isEmpty
+ assertTrue(value.isEmpty().read());
}
@Test
diff --git
a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
index 3ad955e78aa..731cadb89f0 100644
---
a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
+++
b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
@@ -19,7 +19,11 @@ package org.apache.beam.runners.spark.stateful;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateTag;
@@ -28,12 +32,14 @@ import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.InstantCoder;
import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.coders.MapCoder;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.MultimapState;
import org.apache.beam.sdk.state.OrderedListState;
import org.apache.beam.sdk.state.ReadableState;
+import org.apache.beam.sdk.state.ReadableStates;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.state.StateContext;
@@ -44,6 +50,7 @@ import
org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.util.CombineFnUtil;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Instant;
@@ -119,11 +126,10 @@ class SparkStateInternals<K> implements StateInternals {
@Override
public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
- StateTag<MapState<KeyT, ValueT>> spec,
+ StateTag<MapState<KeyT, ValueT>> address,
Coder<KeyT> mapKeyCoder,
Coder<ValueT> mapValueCoder) {
- throw new UnsupportedOperationException(
- String.format("%s is not supported",
MapState.class.getSimpleName()));
+ // Backed by SparkMapState, which keeps the whole map as a single state value for this
+ // (namespace, address), encoded with a MapCoder built from the key and value coders.
+ return new SparkMapState<>(namespace, address, MapCoder.of(mapKeyCoder,
mapValueCoder));
}
@Override
@@ -359,6 +365,142 @@ class SparkStateInternals<K> implements StateInternals {
}
}
+  /**
+   * {@link MapState} that keeps the whole map as a single state value for its namespace and
+   * address, encoded with the {@link Coder} passed at construction.
+   *
+   * <p>Every mutation is a read-modify-write of the entire map: the current map is obtained via
+   * {@code readValue()}, changed, and stored back via {@code writeValue(...)}. A state that has
+   * never been written (or has been cleared) reads as {@code null} and is treated as an empty map.
+   */
+  private final class SparkMapState<MapKeyT, MapValueT>
+      extends AbstractState<Map<MapKeyT, MapValueT>> implements MapState<MapKeyT, MapValueT> {
+
+    private SparkMapState(
+        StateNamespace namespace,
+        StateTag<? extends State> address,
+        Coder<Map<MapKeyT, MapValueT>> coder) {
+      super(namespace, address, coder);
+    }
+
+    /**
+     * Returns the stored map, or a fresh empty {@link HashMap} when nothing has been written yet.
+     * The result is mutated and handed back to {@code writeValue} by callers (this assumes
+     * {@code readValue()} yields a mutable map, as {@link #put} already relies on).
+     */
+    private Map<MapKeyT, MapValueT> readMapOrEmpty() {
+      Map<MapKeyT, MapValueT> stored = readValue();
+      return stored == null ? new HashMap<>() : stored;
+    }
+
+    @Override
+    public ReadableState<MapValueT> get(MapKeyT key) {
+      return getOrDefault(key, null);
+    }
+
+    @Override
+    public ReadableState<MapValueT> getOrDefault(MapKeyT key, @Nullable MapValueT defaultValue) {
+      return new ReadableState<MapValueT>() {
+        @Override
+        public MapValueT read() {
+          Map<MapKeyT, MapValueT> sparkMapState = readValue();
+          if (sparkMapState == null) {
+            return defaultValue;
+          }
+          return sparkMapState.getOrDefault(key, defaultValue);
+        }
+
+        @Override
+        public ReadableState<MapValueT> readLater() {
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public void put(MapKeyT key, MapValueT value) {
+      Map<MapKeyT, MapValueT> sparkMapState = readMapOrEmpty();
+      sparkMapState.put(key, value);
+      writeValue(sparkMapState);
+    }
+
+    /**
+     * Associates {@code key} with {@code mappingFunction.apply(key)} when it currently has no value
+     * (or maps to {@code null}).
+     *
+     * @return the value present before the call; {@code null} when the mapping function was applied
+     */
+    @Override
+    public ReadableState<MapValueT> computeIfAbsent(
+        MapKeyT key, Function<? super MapKeyT, ? extends MapValueT> mappingFunction) {
+      // An unwritten state is an empty map, so the very first call must not dereference null.
+      Map<MapKeyT, MapValueT> sparkMapState = readMapOrEmpty();
+      MapValueT current = sparkMapState.get(key);
+      if (current == null) {
+        // Write through the map already in hand rather than put(), which would re-read the state.
+        sparkMapState.put(key, mappingFunction.apply(key));
+        writeValue(sparkMapState);
+      }
+      return ReadableStates.immediate(current);
+    }
+
+    @Override
+    public void remove(MapKeyT key) {
+      Map<MapKeyT, MapValueT> sparkMapState = readValue();
+      if (sparkMapState == null) {
+        // Nothing stored yet, so there is nothing to remove; do not materialize an empty map.
+        return;
+      }
+      sparkMapState.remove(key);
+      writeValue(sparkMapState);
+    }
+
+    @Override
+    public ReadableState<Iterable<MapKeyT>> keys() {
+      return new ReadableState<Iterable<MapKeyT>>() {
+        @Override
+        public Iterable<MapKeyT> read() {
+          Map<MapKeyT, MapValueT> sparkMapState = readValue();
+          if (sparkMapState == null) {
+            return Collections.emptyList();
+          }
+          return sparkMapState.keySet();
+        }
+
+        @Override
+        public ReadableState<Iterable<MapKeyT>> readLater() {
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Iterable<MapValueT>> values() {
+      return new ReadableState<Iterable<MapValueT>>() {
+        @Override
+        public Iterable<MapValueT> read() {
+          Map<MapKeyT, MapValueT> sparkMapState = readValue();
+          if (sparkMapState == null) {
+            return Collections.emptyList();
+          }
+          // NOTE(review): ImmutableList rejects null elements, so null map values would throw here.
+          return ImmutableList.copyOf(sparkMapState.values());
+        }
+
+        @Override
+        public ReadableState<Iterable<MapValueT>> readLater() {
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Iterable<Map.Entry<MapKeyT, MapValueT>>> entries() {
+      return new ReadableState<Iterable<Map.Entry<MapKeyT, MapValueT>>>() {
+        @Override
+        public Iterable<Map.Entry<MapKeyT, MapValueT>> read() {
+          Map<MapKeyT, MapValueT> sparkMapState = readValue();
+          if (sparkMapState == null) {
+            return Collections.emptyList();
+          }
+          return sparkMapState.entrySet();
+        }
+
+        @Override
+        public ReadableState<Iterable<Map.Entry<MapKeyT, MapValueT>>> readLater() {
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return new ReadableState<Boolean>() {
+        @Override
+        public Boolean read() {
+          // Inspect the stored map itself: removing every key leaves an empty map stored, which
+          // must still read as empty, so mere presence of a state entry is not enough.
+          Map<MapKeyT, MapValueT> sparkMapState = readValue();
+          return sparkMapState == null || sparkMapState.isEmpty();
+        }
+
+        @Override
+        public ReadableState<Boolean> readLater() {
+          return this;
+        }
+      };
+    }
+  }
+
private final class SparkBagState<T> extends AbstractState<List<T>>
implements BagState<T> {
private SparkBagState(StateNamespace namespace, StateTag<BagState<T>>
address, Coder<T> coder) {
super(namespace, address, ListCoder.of(coder));
diff --git
a/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java
b/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java
index 6118f96fcc3..f6f2b8d6df6 100644
---
a/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java
+++
b/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java
@@ -25,7 +25,7 @@ import org.junit.runners.JUnit4;
/**
* Tests for {@link SparkStateInternals}. This is based on {@link
StateInternalsTest}. Ignore set
- * and map tests.
+ * tests.
*/
@RunWith(JUnit4.class)
public class SparkStateInternalsTest extends StateInternalsTest {
@@ -51,15 +51,7 @@ public class SparkStateInternalsTest extends
StateInternalsTest {
@Ignore
public void testMergeSetIntoNewNamespace() {}
- @Override
- @Ignore
- public void testMap() {}
-
@Override
@Ignore
public void testSetReadable() {}
-
- @Override
- @Ignore
- public void testMapReadable() {}
}