This is an automated email from the ASF dual-hosted git repository.
aglinxinyuan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/texera.git
The following commit(s) were added to refs/heads/main by this push:
new d3d68efaf7 chore: simplify State as a JSON dictionary (#4488)
d3d68efaf7 is described below
commit d3d68efaf76dcfcc5c7faa293918f137b3755e09
Author: Xinyuan Lin <[email protected]>
AuthorDate: Thu Apr 30 23:10:24 2026 -0700
chore: simplify State as a JSON dictionary (#4488)
### What changes were proposed in this PR?
This PR simplifies State into a plain dictionary/map with a shared
JSON-based serialization format for Python and Scala/Java.
- replace the old custom State object semantics with a
dictionary/map-based design
- use JSON serialization/deserialization for cross-language state
transport
- update Python and Scala/Java state access/transport code to use the
new format
- support handling multiple state messages in the same stream
### Any related issues, documentation, discussions?
Closes #4487
### How was this PR tested?
- added a Python test for processing multiple states in
test_main_loop.py
### Was this PR authored or co-authored using generative AI tooling?
Generated by Codex.
---------
Signed-off-by: Xinyuan Lin <[email protected]>
Co-authored-by: Chen Li <[email protected]>
Co-authored-by: Xiaozhen Liu <[email protected]>
---
amber/src/main/python/core/models/state.py | 90 +++++++-------
amber/src/main/python/core/models/test_state.py | 101 ++++++++++++++++
.../main/python/core/runnables/data_processor.py | 2 +
.../main/python/core/runnables/network_receiver.py | 4 +-
.../main/python/core/runnables/network_sender.py | 13 +-
.../main/python/core/runnables/test_main_loop.py | 82 +++++++++----
.../python/core/runnables/test_network_receiver.py | 39 +++++-
.../pythonworker/PythonProxyServer.scala | 2 +-
.../org/apache/texera/amber/core/state/State.scala | 91 ++++++++++----
.../apache/texera/amber/core/state/StateSpec.scala | 131 +++++++++++++++++++++
.../amber/operator/ifStatement/IfOpExec.scala | 3 +-
11 files changed, 456 insertions(+), 102 deletions(-)
diff --git a/amber/src/main/python/core/models/state.py b/amber/src/main/python/core/models/state.py
index feb35f2e27..003aaa212a 100644
--- a/amber/src/main/python/core/models/state.py
+++ b/amber/src/main/python/core/models/state.py
@@ -15,58 +15,60 @@
# specific language governing permissions and limitations
# under the License.
-from dataclasses import dataclass
-from pandas import DataFrame
-from pyarrow import Table
-from typing import Optional
+import base64
+import json
+from typing import Any
-from .schema import Schema, AttributeType
-from .schema.attribute_type import FROM_PYOBJECT_MAPPING
+from .schema import Schema
+from .tuple import Tuple
-@dataclass
-class State:
- def __init__(self, table: Optional[Table] = None):
- self.schema = Schema()
- if table is not None:
- self.__dict__.update(table.to_pandas().iloc[0].to_dict())
- self.schema = Schema(table.schema)
+class State(dict):
+ CONTENT = "content"
+ SCHEMA = Schema(raw_schema={CONTENT: "STRING"})
- def add(
- self, key: str, value: any, value_type: Optional[AttributeType] = None
- ) -> None:
- self.__dict__[key] = value
- if value_type is not None:
- self.schema.add(key, value_type)
- elif key != "schema":
- self.schema.add(key, FROM_PYOBJECT_MAPPING[type(value)])
+ def to_json(self) -> str:
+ return json.dumps(_to_json_value(self), separators=(",", ":"))
- def get(self, key: str) -> any:
- return self.__dict__[key]
+ def to_tuple(self) -> Tuple:
+ return Tuple({State.CONTENT: self.to_json()}, schema=State.SCHEMA)
- def to_table(self) -> Table:
- return Table.from_pandas(
- df=DataFrame([self.__dict__]),
- schema=self.schema.as_arrow_schema(),
- )
+ @classmethod
+ def from_json(cls, payload: str) -> "State":
+ return cls(_from_json_value(json.loads(payload)))
- def __setattr__(self, key: str, value: any) -> None:
- self.add(key, value)
+ @classmethod
+ def from_tuple(cls, row: Tuple) -> "State":
+ return cls.from_json(row[cls.CONTENT])
- def __setitem__(self, key: str, value: any) -> None:
- self.add(key, value)
- def __getitem__(self, key: str) -> any:
- return self.get(key)
+_TYPE_MARKER = "__texera_type__"
+_PAYLOAD_MARKER = "payload"
+_BYTES_TYPE = "bytes"
- def __str__(self) -> str:
- content = ", ".join(
- [
- repr(key) + ": " + repr(value)
- for key, value in self.__dict__.items()
- if key != "schema"
- ]
- )
- return f"State[{content}]"
- __repr__ = __str__
+def _to_json_value(value: Any) -> Any:
+ if value is None or isinstance(value, (bool, int, float, str)):
+ return value
+ if isinstance(value, bytes):
+ return {
+ _TYPE_MARKER: _BYTES_TYPE,
+ _PAYLOAD_MARKER: base64.b64encode(value).decode("ascii"),
+ }
+ if isinstance(value, dict):
+ return {str(key): _to_json_value(inner) for key, inner in value.items()}
+ if isinstance(value, (list, tuple)):
+ return [_to_json_value(inner) for inner in value]
+ raise TypeError(
+ f"State value of type {type(value).__name__} is not JSON serializable"
+ )
+
+
+def _from_json_value(value: Any) -> Any:
+ if isinstance(value, list):
+ return [_from_json_value(inner) for inner in value]
+ if isinstance(value, dict):
+ if value.get(_TYPE_MARKER) == _BYTES_TYPE:
+ return base64.b64decode(value[_PAYLOAD_MARKER])
+ return {key: _from_json_value(inner) for key, inner in value.items()}
+ return value
diff --git a/amber/src/main/python/core/models/test_state.py b/amber/src/main/python/core/models/test_state.py
new file mode 100644
index 0000000000..aef2297130
--- /dev/null
+++ b/amber/src/main/python/core/models/test_state.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+from core.models.state import State
+
+
+class TestState:
+ def test_state_subclasses_dict(self):
+ state = State({"a": 1})
+ assert isinstance(state, dict)
+ assert state["a"] == 1
+ assert State() == {}
+
+ def test_class_attributes(self):
+ assert State.CONTENT == "content"
+ assert State.SCHEMA.get_attr_names() == ["content"]
+
+ def test_json_round_trip_primitives(self):
+ original = State(
+ {
+ "string": "hello",
+ "int": 42,
+ "float": 3.14,
+ "bool_true": True,
+ "bool_false": False,
+ "none_value": None,
+ }
+ )
+ decoded = State.from_json(original.to_json())
+ assert decoded == original
+
+ def test_json_round_trip_empty(self):
+ assert State.from_json(State().to_json()) == State()
+
+ def test_json_round_trip_bytes(self):
+ original = State({"payload": b"\x00\x01\x02\xff"})
+ decoded = State.from_json(original.to_json())
+ assert decoded["payload"] == b"\x00\x01\x02\xff"
+ assert isinstance(decoded["payload"], bytes)
+
+ def test_json_round_trip_nested_dict(self):
+ original = State({"outer": {"inner": {"value": 1}}})
+ decoded = State.from_json(original.to_json())
+ assert decoded == original
+
+ def test_json_round_trip_list_of_mixed_values(self):
+ original = State({"items": [1, "two", 3.0, True, None]})
+ decoded = State.from_json(original.to_json())
+ assert decoded == original
+
+ def test_json_round_trip_bytes_inside_list_and_nested_dict(self):
+ original = State(
+ {
+ "blobs": [b"first", b"second"],
+ "nested": {"sub_blob": b"inside"},
+ }
+ )
+ decoded = State.from_json(original.to_json())
+ assert decoded["blobs"] == [b"first", b"second"]
+ assert decoded["nested"]["sub_blob"] == b"inside"
+
+ def test_to_json_rejects_non_serializable_value(self):
+ class Custom:
+ pass
+
+ with pytest.raises(TypeError):
+ State({"bad": Custom()}).to_json()
+
+ def test_tuple_round_trip(self):
+ original = State({"loop_counter": 3, "label": "outer", "blob": b"\x01\x02"})
+ decoded = State.from_tuple(original.to_tuple())
+ assert decoded == original
+
+ def test_to_tuple_uses_state_schema(self):
+ tuple_ = State({"x": 1}).to_tuple()
+ # Single STRING column whose value is the JSON serialization.
+ assert tuple_[State.CONTENT] == '{"x":1}'
+
+ def test_nested_dict_decodes_to_plain_dict(self):
+ # Top-level returns a State; nested dicts come back as plain dict.
+ # This is intentional -- only the outermost mapping is wrapped.
+ decoded = State.from_json('{"outer":{"inner":1}}')
+ assert isinstance(decoded, State)
+ assert isinstance(decoded["outer"], dict)
+ assert not isinstance(decoded["outer"], State)
diff --git a/amber/src/main/python/core/runnables/data_processor.py b/amber/src/main/python/core/runnables/data_processor.py
index 35d2a75d1d..276a1669f5 100644
--- a/amber/src/main/python/core/runnables/data_processor.py
+++ b/amber/src/main/python/core/runnables/data_processor.py
@@ -168,6 +168,8 @@ class DataProcessor(Runnable, Stoppable):
"""
Set the output state after processing by the executor.
"""
+ if output_state is not None and not isinstance(output_state, State):
+ output_state = State(output_state)
self._context.state_processing_manager.current_output_state = output_state
def _switch_context(self) -> None:
diff --git a/amber/src/main/python/core/runnables/network_receiver.py b/amber/src/main/python/core/runnables/network_receiver.py
index 5ab857c1c3..659cd65c78 100644
--- a/amber/src/main/python/core/runnables/network_receiver.py
+++ b/amber/src/main/python/core/runnables/network_receiver.py
@@ -32,6 +32,7 @@ from core.architecture.handlers.actorcommand.credit_update_handler import (
)
from core.models import (
DataFrame,
+ State,
StateFrame,
)
from core.models.internal_queue import (
@@ -40,7 +41,6 @@ from core.models.internal_queue import (
InternalQueue,
ECMElement,
)
-from core.models.state import State
from core.proxy import ProxyServer
from core.util import Stoppable, get_one_of
from core.util.runnable.runnable import Runnable
@@ -96,7 +96,7 @@ class NetworkReceiver(Runnable, Stoppable):
"Data",
lambda _: DataFrame(table),
"State",
- lambda _: StateFrame(State(table)),
+ lambda _: StateFrame(State.from_json(table[State.CONTENT][0].as_py())),
"ECM",
lambda _: EmbeddedControlMessage().parse(table["payload"][0].as_py()),
)
diff --git a/amber/src/main/python/core/runnables/network_sender.py b/amber/src/main/python/core/runnables/network_sender.py
index 9595433fb7..d8e3889ac1 100644
--- a/amber/src/main/python/core/runnables/network_sender.py
+++ b/amber/src/main/python/core/runnables/network_sender.py
@@ -20,7 +20,7 @@ from loguru import logger
from overrides import overrides
from typing import Optional
-from core.models import DataPayload, InternalQueue, DataFrame, StateFrame, State
+from core.models import DataPayload, InternalQueue, DataFrame, State, StateFrame
from core.models.internal_queue import (
InternalQueueElement,
DataElement,
@@ -98,13 +98,10 @@ class NetworkSender(StoppableQueueBlockingRunnable):
data_header = PythonDataHeader(tag=to, payload_type="Data")
self._proxy_client.send_data(bytes(data_header), data_payload.frame)
elif isinstance(data_payload, StateFrame):
- data_header = PythonDataHeader(
- tag=to, payload_type=data_payload.frame.__class__.__name__
- )
- table = (
- data_payload.frame.to_table()
- if isinstance(data_payload.frame, State)
- else None
+ data_header = PythonDataHeader(tag=to, payload_type="State")
+ table = pa.Table.from_pydict(
+ {State.CONTENT: [data_payload.frame.to_json()]},
+ schema=State.SCHEMA.as_arrow_schema(),
)
self._proxy_client.send_data(bytes(data_header), table)
else:
diff --git a/amber/src/main/python/core/runnables/test_main_loop.py b/amber/src/main/python/core/runnables/test_main_loop.py
index cc6969d964..c9daa633f5 100644
--- a/amber/src/main/python/core/runnables/test_main_loop.py
+++ b/amber/src/main/python/core/runnables/test_main_loop.py
@@ -166,8 +166,7 @@ class TestMainLoop:
def mock_state_data_elements(self, mock_data_input_channel):
elements = []
for value in (1, 2, 3, 4):
- state = State()
- state.add("value", value)
+ state = State({"value": value})
elements.append(
DataElement(
tag=mock_data_input_channel,
@@ -189,19 +188,16 @@ class TestMainLoop:
@staticmethod
def process_state(state: State, port: int) -> State:
- new_state = State()
- for key, value in state.__dict__.items():
- if key != "schema":
- new_state.add(key, value)
- new_state.add("processed_marker", "executed")
- new_state.add("port", port)
+ new_state = State(
+ {key: value for key, value in state.items() if key != "schema"}
+ )
+ new_state["processed_marker"] = "executed"
+ new_state["port"] = port
return new_state
@staticmethod
def produce_state_on_finish(port: int) -> State:
- finish_state = State()
- finish_state.add("finish_marker", "produce_state_on_finish_ran")
- return finish_state
+ return State({"finish_marker": "produce_state_on_finish_ran"})
@staticmethod
def on_finish(port):
@@ -1131,6 +1127,57 @@ class TestMainLoop:
),
)
+ @pytest.mark.timeout(2)
+ def test_process_state_can_emit_consecutive_states(
+ self,
+ main_loop,
+ output_queue,
+ mock_data_output_channel,
+ monkeypatch,
+ ):
+ class DummyExecutor:
+ @staticmethod
+ def process_state(state, port: int):
+ return State({"value": state["value"] + 1, "port": port})
+
+ main_loop.context.executor_manager.executor = DummyExecutor()
+ monkeypatch.setattr(main_loop, "_check_and_process_control", lambda: None)
+ monkeypatch.setattr(
+ main_loop.context.output_manager,
+ "emit_state",
+ lambda state: [(mock_data_output_channel.to_worker_id, StateFrame(state))],
+ )
+
+ def fake_switch_context():
+ current_input_state = (
+ main_loop.context.state_processing_manager.current_input_state
+ )
+ if current_input_state is not None:
+ main_loop.context.state_processing_manager.current_output_state = (
+ DummyExecutor.process_state(current_input_state, 0)
+ )
+
+ monkeypatch.setattr(main_loop, "_switch_context", fake_switch_context)
+
+ first_state = State({"value": 1})
+ second_state = State({"value": 41})
+
+ main_loop._process_state(first_state)
+ main_loop._process_state(second_state)
+
+ first_output: DataElement = output_queue.get()
+ second_output: DataElement = output_queue.get()
+
+ assert first_output.tag == mock_data_output_channel
+ assert isinstance(first_output.payload, StateFrame)
+ assert first_output.payload.frame["value"] == 2
+ assert first_output.payload.frame["port"] == 0
+
+ assert second_output.tag == mock_data_output_channel
+ assert isinstance(second_output.payload, StateFrame)
+ assert second_output.payload.frame["value"] == 42
+ assert second_output.payload.frame["port"] == 0
+
@pytest.mark.timeout(5)
def test_main_loop_thread_can_align_ecm(
self,
@@ -1301,10 +1348,7 @@ class TestMainLoop:
class DummyExecutor:
@staticmethod
def process_state(state: State, port: int) -> State:
- output_state = State()
- output_state.add("value", state["value"] + 1)
- output_state.add("port", port)
- return output_state
+ return State({"value": state["value"] + 1, "port": port})
main_loop.context.executor_manager.executor = DummyExecutor()
monkeypatch.setattr(main_loop, "_check_and_process_control", lambda: None)
@@ -1325,10 +1369,8 @@ class TestMainLoop:
monkeypatch.setattr(main_loop, "_switch_context", fake_switch_context)
- first_state = State()
- first_state.add("value", 1)
- second_state = State()
- second_state.add("value", 41)
+ first_state = State({"value": 1})
+ second_state = State({"value": 41})
main_loop._process_state(first_state)
main_loop._process_state(second_state)
@@ -1443,7 +1485,7 @@ class TestMainLoop:
f"{type(end_channel_state_output.payload).__name__}"
)
end_channel_state = end_channel_state_output.payload.frame
- assert "finish_marker" in end_channel_state.__dict__, (
+ assert "finish_marker" in end_channel_state, (
f"EndChannel emission should be the finish-marker state from "
f"produce_state_on_finish, got {end_channel_state!r}"
)
diff --git a/amber/src/main/python/core/runnables/test_network_receiver.py b/amber/src/main/python/core/runnables/test_network_receiver.py
index 2cc2541f2d..bf890e4a2f 100644
--- a/amber/src/main/python/core/runnables/test_network_receiver.py
+++ b/amber/src/main/python/core/runnables/test_network_receiver.py
@@ -25,7 +25,8 @@ from core.models.internal_queue import (
DataElement,
ECMElement,
)
-from core.models.payload import DataFrame
+from core.models.payload import DataFrame, StateFrame
+from core.models.state import State
from core.proxy import ProxyClient
from core.runnables.network_receiver import NetworkReceiver
from core.runnables.network_sender import NetworkSender
@@ -139,6 +140,42 @@ class TestNetworkReceiver:
assert len(element.payload.frame) == len(data_payload.frame)
assert element.tag == channel_id
+ @pytest.mark.timeout(10)
+ def test_network_receiver_can_receive_consecutive_state_messages(
+ self,
+ output_queue,
+ input_queue,
+ network_receiver,
+ network_sender_thread,
+ ):
+ network_sender_thread.start()
+ worker_id = ActorVirtualIdentity(name="test")
+ channel_id = ChannelIdentity(worker_id, worker_id, False)
+
+ input_queue.put(
+ DataElement(
+ tag=channel_id,
+ payload=StateFrame(State({"loop_counter": 0, "i": 1})),
+ )
+ )
+ input_queue.put(
+ DataElement(
+ tag=channel_id,
+ payload=StateFrame(State({"loop_counter": 1, "i": 2})),
+ )
+ )
+
+ first_element: DataElement = output_queue.get()
+ second_element: DataElement = output_queue.get()
+
+ assert isinstance(first_element.payload, StateFrame)
+ assert first_element.payload.frame == {"loop_counter": 0, "i": 1}
+ assert first_element.tag == channel_id
+
+ assert isinstance(second_element.payload, StateFrame)
+ assert second_element.payload.frame == {"loop_counter": 1, "i": 2}
+ assert second_element.tag == channel_id
+
@pytest.mark.timeout(10)
def test_network_receiver_can_receive_control_messages(
self,
diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyServer.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyServer.scala
index c904e436bc..2ff866365b 100644
--- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyServer.scala
+++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyServer.scala
@@ -128,7 +128,7 @@ private class AmberProducer(
dataHeader.payloadType match {
case "State" =>
assert(root.getRowCount == 1)
- outputPort.sendTo(to, StateFrame(State(Some(ArrowUtils.getTexeraTuple(0, root)))))
+ outputPort.sendTo(to, StateFrame(State.fromTuple(ArrowUtils.getTexeraTuple(0, root))))
case "ECM" =>
assert(root.getRowCount == 1)
outputPort.sendTo(
diff --git a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala
index 2b3465473b..ba146f1d57 100644
--- a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala
+++ b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala
@@ -19,36 +19,77 @@
package org.apache.texera.amber.core.state
+import com.fasterxml.jackson.databind.JsonNode
import org.apache.texera.amber.core.tuple.{Attribute, AttributeType, Schema, Tuple}
+import org.apache.texera.amber.util.JSONUtils.objectMapper
-import scala.collection.mutable
+import java.util.Base64
+import scala.jdk.CollectionConverters.IteratorHasAsScala
-final case class State(tuple: Option[Tuple] = None) {
- val data: mutable.Map[String, (AttributeType, Any)] = mutable.LinkedHashMap()
- if (tuple.isDefined) {
- tuple.get.getSchema.getAttributes.foreach { attribute =>
- add(attribute.getName, tuple.get.getField(attribute.getName), attribute.getType)
- }
- }
+final case class State(values: Map[String, Any]) {
+
+ def toJson: String =
+ objectMapper.writeValueAsString(State.toJsonValue(values))
- def add(key: String, value: Any, valueType: AttributeType): Unit =
- data.put(key, (valueType, value))
+ def toTuple: Tuple =
+ Tuple.builder(State.schema).addSequentially(Array(toJson)).build()
+}
- def get(key: String): Any = data(key)._2
+object State {
+ private val Content = "content"
+ private val BytesTypeMarker = "__texera_type__"
+ private val BytesValue = "bytes"
+ private val PayloadMarker = "payload"
- def apply(key: String): Any = get(key)
+ val schema: Schema = new Schema(
+ new Attribute(Content, AttributeType.STRING)
+ )
- def toTuple: Tuple =
- Tuple
- .builder(
- Schema(data.map {
- case (name, (attrType, _)) =>
- new Attribute(name, attrType)
- }.toList)
- )
- .addSequentially(data.values.map(_._2).toArray)
- .build()
-
- override def toString: String =
- data.map { case (key, (_, value)) => s"$key: $value" }.mkString(", ")
+ def fromJson(payload: String): State =
+ State(
+ objectMapper
+ .readTree(payload)
+ .fields()
+ .asScala
+ .map(entry => entry.getKey -> fromJsonValue(entry.getValue))
+ .toMap
+ )
+
+ def fromTuple(row: Tuple): State = fromJson(row.getField[String](Content))
+
+ private def toJsonValue(value: Any): Any =
+ value match {
+ case null => null
+ case bytes: Array[Byte] =>
+ Map(BytesTypeMarker -> BytesValue, PayloadMarker -> Base64.getEncoder.encodeToString(bytes))
+ case map: Map[?, ?] =>
+ map.iterator.map { case (k, v) => k -> toJsonValue(v) }.toMap
+ case iterable: Iterable[_] =>
+ iterable.map(toJsonValue).toList
+ case other => other
+ }
+
+ private def fromJsonValue(node: JsonNode): Any = {
+ if (node == null || node.isNull) {
+ null
+ } else if (node.isObject) {
+ val fields = node.fields().asScala.map(entry => entry.getKey -> entry.getValue).toMap
+ fields.get(BytesTypeMarker) match {
+ case Some(typeNode) if typeNode.isTextual && typeNode.asText() == BytesValue =>
+ Base64.getDecoder.decode(fields(PayloadMarker).asText())
+ case _ =>
+ fields.view.mapValues(fromJsonValue).toMap
+ }
+ } else if (node.isArray) {
+ node.elements().asScala.map(fromJsonValue).toList
+ } else if (node.isBoolean) {
+ node.asBoolean()
+ } else if (node.isIntegralNumber) {
+ node.longValue()
+ } else if (node.isFloatingPointNumber) {
+ node.doubleValue()
+ } else {
+ node.asText()
+ }
+ }
}
diff --git a/common/workflow-core/src/test/scala/org/apache/texera/amber/core/state/StateSpec.scala b/common/workflow-core/src/test/scala/org/apache/texera/amber/core/state/StateSpec.scala
new file mode 100644
index 0000000000..976a585e31
--- /dev/null
+++ b/common/workflow-core/src/test/scala/org/apache/texera/amber/core/state/StateSpec.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.texera.amber.core.state
+
+import org.scalatest.flatspec.AnyFlatSpec
+
+class StateSpec extends AnyFlatSpec {
+
+ "State" should "json-round-trip an empty state" in {
+ val original = State(Map.empty)
+ assert(State.fromJson(original.toJson) == original)
+ }
+
+ it should "json-round-trip primitive values" in {
+ val original = State(
+ Map(
+ "string" -> "hello",
+ "long" -> 42L,
+ "double" -> 3.14,
+ "bool_true" -> true,
+ "bool_false" -> false
+ )
+ )
+ val decoded = State.fromJson(original.toJson)
+ assert(decoded.values("string") == "hello")
+ assert(decoded.values("long") == 42L)
+ assert(decoded.values("double") == 3.14)
+ assert(decoded.values("bool_true") == true)
+ assert(decoded.values("bool_false") == false)
+ }
+
+ it should "drop null entries during JSON serialization" in {
+ // The shared `objectMapper` is configured with `Include.NON_NULL`, so
+ // null values are stripped before they hit the wire. Document the
+ // behavior here so callers know they cannot transport an explicit null
+ // through a State -- Python's serializer keeps nulls but Scala does not.
+ val original = State(Map("present" -> "value", "absent" -> null))
+ val decoded = State.fromJson(original.toJson)
+ assert(decoded.values.keySet == Set("present"))
+ assert(decoded.values("present") == "value")
+ }
+
+ it should "json-round-trip byte arrays via the bytes type marker" in {
+ val payload = Array[Byte](0, 1, 2, -1)
+ val original = State(Map("payload" -> payload))
+ val decoded = State.fromJson(original.toJson)
+ val decodedBytes = decoded.values("payload").asInstanceOf[Array[Byte]]
+ assert(decodedBytes.sameElements(payload))
+ }
+
+ it should "json-round-trip nested maps" in {
+ val original = State(Map("outer" -> Map("inner" -> Map("value" -> 1L))))
+ val decoded = State.fromJson(original.toJson)
+ assert(decoded == original)
+ }
+
+ it should "json-round-trip lists of mixed values" in {
+ val original = State(Map("items" -> List(1L, "two", 3.0, true, null)))
+ val decoded = State.fromJson(original.toJson)
+ assert(decoded == original)
+ }
+
+ it should "json-round-trip byte arrays nested inside lists and maps" in {
+ val original = State(
+ Map(
+ "blobs" -> List(Array[Byte](1, 2), Array[Byte](3, 4)),
+ "nested" -> Map("sub_blob" -> Array[Byte](5, 6))
+ )
+ )
+ val decoded = State.fromJson(original.toJson)
+ val blobs = decoded.values("blobs").asInstanceOf[List[Array[Byte]]]
+ assert(blobs.head.sameElements(Array[Byte](1, 2)))
+ assert(blobs(1).sameElements(Array[Byte](3, 4)))
+ val subBlob = decoded.values
+ .apply("nested")
+ .asInstanceOf[Map[String, Any]]("sub_blob")
+ .asInstanceOf[Array[Byte]]
+ assert(subBlob.sameElements(Array[Byte](5, 6)))
+ }
+
+ it should "tuple-round-trip" in {
+ val original = State(
+ Map(
+ "loop_counter" -> 3L,
+ "label" -> "outer",
+ "blob" -> Array[Byte](1, 2)
+ )
+ )
+ val decoded = State.fromTuple(original.toTuple)
+ assert(decoded.values("loop_counter") == 3L)
+ assert(decoded.values("label") == "outer")
+ assert(
+ decoded.values("blob").asInstanceOf[Array[Byte]].sameElements(Array[Byte](1, 2))
+ )
+ }
+
+ it should "produce a tuple whose payload is the JSON serialization" in {
+ val tuple = State(Map("x" -> 1L)).toTuple
+ assert(tuple.getSchema == State.schema)
+ assert(tuple.getField[String]("content") == """{"x":1}""")
+ }
+
+ it should "decode a payload encoded by the Python serializer" in {
+ // Wire-format compatibility check: the bytes-marker keys and the
+ // single-row "content" column must match what core/models/state.py
+ // emits, otherwise cross-language transport breaks.
+ val pythonEmitted = """{"i":2,"blob":{"__texera_type__":"bytes","payload":"AQID"}}"""
+ val decoded = State.fromJson(pythonEmitted)
+ assert(decoded.values("i") == 2L)
+ assert(
+ decoded.values("blob").asInstanceOf[Array[Byte]].sameElements(Array[Byte](1, 2, 3))
+ )
+ }
+}
diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/ifStatement/IfOpExec.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/ifStatement/IfOpExec.scala
index 462bdd0969..4634ad1c18 100644
--- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/ifStatement/IfOpExec.scala
+++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/ifStatement/IfOpExec.scala
@@ -34,7 +34,8 @@ class IfOpExec(descString: String) extends OperatorExecutor {
//It can accept any value that can be converted to a boolean. For example, Int 1 will be converted to true.
override def processState(state: State, port: Int): Option[State] = {
outputPort =
- if (state.get(desc.conditionName).asInstanceOf[Boolean]) PortIdentity(1) else PortIdentity()
+ if (state.values(desc.conditionName).asInstanceOf[Boolean]) PortIdentity(1)
+ else PortIdentity()
Some(state)
}