This is an automated email from the ASF dual-hosted git repository.

ebenizzy pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/burr.git


The following commit(s) were added to refs/heads/main by this push:
     new ac782937 Pool based async persister (#535)
ac782937 is described below

commit ac7829370172f681f9a648bbf6fe196b1e52bf8e
Author: gamarin2 <[email protected]>
AuthorDate: Mon Jul 21 06:09:13 2025 +0200

    Pool based async persister (#535)
    
    * pool based asyncpg persister
    
    * docs
    
    * remove ruff formatting to avoid diff
    
    * fix docstring
    
    * hook fixes
    
    ---------
    
    Co-authored-by: Gautier MARIN <[email protected]>
---
 burr/core/parallelism.py                  |   2 +-
 burr/integrations/persisters/b_asyncpg.py | 322 +++++++++++++++++++++---------
 docs/concepts/actions.rst                 |  16 ++
 docs/concepts/parallelism.rst             | 102 +++++++++-
 docs/concepts/sync-vs-async.rst           |  25 +++
 5 files changed, 367 insertions(+), 100 deletions(-)

diff --git a/burr/core/parallelism.py b/burr/core/parallelism.py
index 91c8f30a..25c9e6f2 100644
--- a/burr/core/parallelism.py
+++ b/burr/core/parallelism.py
@@ -875,4 +875,4 @@ def map_reduce_action(
     """Experimental API for creating a map-reduce action easily. We'll be 
improving this."""
     return PassThroughMapActionsAndStates(
         action=action, state=state, reducer=reducer, reads=reads, 
writes=writes, inputs=inputs
-    )
+    )
\ No newline at end of file
diff --git a/burr/integrations/persisters/b_asyncpg.py 
b/burr/integrations/persisters/b_asyncpg.py
index d6f788f4..592d6c29 100644
--- a/burr/integrations/persisters/b_asyncpg.py
+++ b/burr/integrations/persisters/b_asyncpg.py
@@ -1,20 +1,27 @@
+import json
+import logging
+from typing import Literal, Optional, ClassVar
+from typing import Any
+from burr.common.types import BaseCopyable
+from burr.core import persistence, state
 from burr.integrations import base
 
+
 try:
     import asyncpg
 except ImportError as e:
     base.require_plugin(e, "asyncpg")
 
-import json
-import logging
-from typing import Literal, Optional
+try:
+    from typing import Self
+except ImportError:
+    Self = Any
 
-from burr.core import persistence, state
 
 logger = logging.getLogger(__name__)
 
 
-class AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
+class AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister, 
BaseCopyable):
     """Class for async PostgreSQL persistence of state.
 
     .. warning::
@@ -24,6 +31,11 @@ class 
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
         We suggest to use the persister either as a context manager through 
the ``async with`` clause or
         using the method ``.cleanup()``.
 
+    .. warning::
+        If you intend to use parallelism features or need to share this 
persister across multiple tasks,
+        you should initialize it with a connection pool (set ``use_pool=True`` 
in ``from_values``).
+        Direct connections cannot be shared across different tasks and may 
cause errors in concurrent scenarios.
+
     .. note::
         The implementation relies on the popular asyncpg library: 
https://github.com/MagicStack/asyncpg
 
@@ -35,10 +47,10 @@ class 
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
 
     .. code:: bash
 
-        docker run --name local-psql \  # container name
-                   -v local_psql_data:/SOME/FILE_PATH/ \  # mounting a volume 
for data persistence
-                   -p 54320:5432 \  # port mapping
-                   -e POSTGRES_PASSWORD=my_password \  # superuser password
+        docker run --name local-psql \\\\  # container name
+                   -v local_psql_data:/SOME/FILE_PATH/ \\\\  # mounting a 
volume for data persistence
+                   -p 54320:5432 \\\\  # port mapping
+                   -e POSTGRES_PASSWORD=my_password \\\\  # superuser password
                    -d postgres  # database name
 
     Then you should be able to create the class like this:
@@ -48,11 +60,35 @@ class 
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
         p = await AsyncPostgreSQLPersister.from_values("postgres", "postgres", 
"my_password",
                                            "localhost", 54320, 
table_name="burr_state")
 
-
     """
 
     PARTITION_KEY_DEFAULT = ""
 
+    # Class variable to hold the connection pool
+    _pool: ClassVar[Optional[asyncpg.Pool]] = None
+
+    @classmethod
+    async def create_pool(
+        cls,
+        user: str,
+        password: str,
+        database: str,
+        host: str,
+        port: int,
+        **pool_kwargs,
+    ) -> asyncpg.Pool:
+        """Creates a connection pool that can be shared across persisters."""
+        if cls._pool is None:
+            cls._pool = await asyncpg.create_pool(
+                user=user,
+                password=password,
+                database=database,
+                host=host,
+                port=port,
+                **pool_kwargs,
+            )
+        return cls._pool
+
     @classmethod
     async def from_config(cls, config: dict) -> "AsyncPostgreSQLPersister":
         """Creates a new instance of the PostgreSQLPersister from a 
configuration dictionary."""
@@ -67,6 +103,8 @@ class 
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
         host: str,
         port: int,
         table_name: str = "burr_state",
+        use_pool: bool = False,
+        **pool_kwargs,
     ) -> "AsyncPostgreSQLPersister":
         """Builds a new instance of the PostgreSQLPersister from the provided 
values.
 
@@ -76,55 +114,119 @@ class 
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
         :param host: the host of the PostgreSQL database.
         :param port: the port of the PostgreSQL database.
         :param table_name:  the table name to store things under.
+        :param use_pool: whether to use a connection pool (True) or a direct 
connection (False)
+        :param pool_kwargs: additional kwargs to pass to the pool creation
         """
-        connection = await asyncpg.connect(
-            user=user, password=password, database=db_name, host=host, 
port=port
-        )
-        return cls(connection, table_name)
+        if use_pool:
+            pool = await cls.create_pool(
+                user=user,
+                password=password,
+                database=db_name,
+                host=host,
+                port=port,
+                **pool_kwargs,
+            )
+            return cls(connection=None, pool=pool, table_name=table_name)
+        else:
+            # Original behavior - direct connection
+            connection = await asyncpg.connect(
+                user=user, password=password, database=db_name, host=host, 
port=port
+            )
+            return cls(connection=connection, table_name=table_name)
 
-    def __init__(self, connection, table_name: str = "burr_state", 
serde_kwargs: dict = None):
+    def __init__(
+        self,
+        connection=None,
+        pool=None,
+        table_name: str = "burr_state",
+        serde_kwargs: dict = None,
+    ):
         """Constructor
 
-        :param connection: the connection to the PostgreSQL database.
+        :param connection: the connection to the PostgreSQL database (optional 
if pool is provided)
+        :param pool: a connection pool to use instead of a direct connection 
(optional if connection is provided)
         :param table_name:  the table name to store things under.
+        :param serde_kwargs: kwargs for state serialization/deserialization
         """
+        if connection is None and pool is None:
+            raise ValueError("Either connection or pool must be provided")
+
         self.table_name = table_name
         self.connection = connection
+        self.pool = pool
         self.serde_kwargs = serde_kwargs or {}
         self._initialized = False
 
+    def copy(self) -> "Self":
+        """Creates a copy of this persister.
+
+        If using a pool, returns a new persister that will acquire its own 
connection from the pool.
+        If using a direct connection, just returns a new persister with the 
same connection (won't work for async parallelism)
+        """
+        if self.pool is not None:
+            return AsyncPostgreSQLPersister(
+                connection=None,
+                pool=self.pool,
+                table_name=self.table_name,
+                serde_kwargs=self.serde_kwargs,
+            )
+        else:
+            return AsyncPostgreSQLPersister(
+                connection=self.connection,
+                table_name=self.table_name,
+                serde_kwargs=self.serde_kwargs,
+            )
+
     async def __aenter__(self):
         return self
 
     async def __aexit__(self, exc_type, exc_value, traceback):
-        await self.connection.close()
+        await self.cleanup()
         return False
 
+    async def _get_connection(self):
+        """Gets a connection - either the dedicated one or one from the 
pool."""
+        if self.pool is not None:
+            return await self.pool.acquire(), True
+        elif self.connection is not None:
+            return self.connection, False
+        else:
+            raise ValueError("No connection or pool available")
+
+    async def _release_connection(self, connection, acquired):
+        """Releases a connection back to the pool if it was acquired."""
+        if acquired and self.pool is not None:
+            await self.pool.release(connection)
+
     def set_serde_kwargs(self, serde_kwargs: dict):
         """Sets the serde_kwargs for the persister."""
         self.serde_kwargs = serde_kwargs
 
     async def create_table(self, table_name: str):
         """Helper function to create the table where things are stored."""
-        async with self.connection.transaction():
-            await self.connection.execute(
-                f"""
-                CREATE TABLE IF NOT EXISTS {table_name} (
-                    partition_key TEXT DEFAULT '{self.PARTITION_KEY_DEFAULT}',
-                    app_id TEXT NOT NULL,
-                    sequence_id INTEGER NOT NULL,
-                    position TEXT NOT NULL,
-                    status TEXT NOT NULL,
-                    state JSONB NOT NULL,
-                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                    PRIMARY KEY (partition_key, app_id, sequence_id, position)
-                )"""
-            )
-            await self.connection.execute(
-                f"""
-                CREATE INDEX IF NOT EXISTS {table_name}_created_at_index ON 
{table_name} (created_at);
-            """
-            )
+        conn, acquired = await self._get_connection()
+        try:
+            async with conn.transaction():
+                await conn.execute(
+                    f"""
+                    CREATE TABLE IF NOT EXISTS {table_name} (
+                        partition_key TEXT DEFAULT 
'{self.PARTITION_KEY_DEFAULT}',
+                        app_id TEXT NOT NULL,
+                        sequence_id INTEGER NOT NULL,
+                        position TEXT NOT NULL,
+                        status TEXT NOT NULL,
+                        state JSONB NOT NULL,
+                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                        PRIMARY KEY (partition_key, app_id, sequence_id, 
position)
+                    )"""
+                )
+                await conn.execute(
+                    f"""
+                    CREATE INDEX IF NOT EXISTS {table_name}_created_at_index 
ON {table_name} (created_at);
+                """
+                )
+        finally:
+            await self._release_connection(conn, acquired)
 
     async def initialize(self):
         """Creates the table"""
@@ -139,23 +241,35 @@ class 
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
         if self._initialized:
             return True
 
-        query = "SELECT EXISTS (SELECT FROM information_schema.tables WHERE 
table_name = $1)"
-        self._initialized = await self.connection.fetchval(query, 
self.table_name, column=0)
-        return self._initialized
+        conn, acquired = await self._get_connection()
+        try:
+            query = "SELECT EXISTS (SELECT FROM information_schema.tables 
WHERE table_name = $1)"
+            self._initialized = await conn.fetchval(query, self.table_name, 
column=0)
+            return self._initialized
+        finally:
+            await self._release_connection(conn, acquired)
 
     async def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
         """Lists the app_ids for a given partition_key."""
-        query = (
-            f"SELECT DISTINCT app_id, created_at FROM {self.table_name} "
-            "WHERE partition_key = $1 "
-            "ORDER BY created_at DESC"
-        )
-        fetched_data = await self.connection.fetch(query, partition_key)
-        app_ids = [row[0] for row in fetched_data]
-        return app_ids
+        conn, acquired = await self._get_connection()
+        try:
+            query = (
+                f"SELECT DISTINCT app_id, created_at FROM {self.table_name} "
+                "WHERE partition_key = $1 "
+                "ORDER BY created_at DESC"
+            )
+            fetched_data = await conn.fetch(query, partition_key)
+            app_ids = [row[0] for row in fetched_data]
+            return app_ids
+        finally:
+            await self._release_connection(conn, acquired)
 
     async def load(
-        self, partition_key: Optional[str], app_id: str, sequence_id: int = 
None, **kwargs
+        self,
+        partition_key: Optional[str],
+        app_id: str,
+        sequence_id: int = None,
+        **kwargs,
     ) -> Optional[persistence.PersistedStateData]:
         """Loads state for a given partition id.
 
@@ -171,47 +285,53 @@ class 
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
             partition_key = self.PARTITION_KEY_DEFAULT
         logger.debug("Loading %s, %s, %s", partition_key, app_id, sequence_id)
 
-        if app_id is None:
-            # get latest for all app_ids
-            query = (
-                f"SELECT position, state, sequence_id, app_id, created_at, 
status FROM {self.table_name} "
-                "WHERE partition_key = $1 "
-                f"ORDER BY CREATED_AT DESC LIMIT 1"
-            )
-            row = await self.connection.fetchrow(query, partition_key)
-
-        elif sequence_id is None:
-            query = (
-                f"SELECT position, state, sequence_id, app_id, created_at, 
status FROM {self.table_name} "
-                "WHERE partition_key = $1 AND app_id = $2 "
-                f"ORDER BY sequence_id DESC LIMIT 1"
-            )
-            row = await self.connection.fetchrow(query, partition_key, app_id)
-        else:
-            query = (
-                f"SELECT position, state, sequence_id, app_id, created_at, 
status FROM {self.table_name} "
-                "WHERE partition_key = $1 AND app_id = $2 AND sequence_id = $3 
"
-            )
-            row = await self.connection.fetchrow(
-                query,
-                partition_key,
-                app_id,
-                sequence_id,
-            )
-        if row is None:
-            return None
-        # converts from asyncpg str to dict
-        json_row = json.loads(row[1])
-        _state = state.State.deserialize(json_row, **self.serde_kwargs)
-        return {
-            "partition_key": partition_key,
-            "app_id": row[3],
-            "sequence_id": row[2],
-            "position": row[0],
-            "state": _state,
-            "created_at": row[4],
-            "status": row[5],
-        }
+        conn, acquired = await self._get_connection()
+        try:
+            row = None
+            if app_id is None:
+                # get latest for all app_ids
+                query = (
+                    f"SELECT position, state, sequence_id, app_id, created_at, 
status FROM {self.table_name} "
+                    "WHERE partition_key = $1 "
+                    f"ORDER BY CREATED_AT DESC LIMIT 1"
+                )
+                row = await conn.fetchrow(query, partition_key)
+            elif sequence_id is None:
+                query = (
+                    f"SELECT position, state, sequence_id, app_id, created_at, 
status FROM {self.table_name} "
+                    "WHERE partition_key = $1 AND app_id = $2 "
+                    f"ORDER BY sequence_id DESC LIMIT 1"
+                )
+                row = await conn.fetchrow(query, partition_key, app_id)
+            else:
+                query = (
+                    f"SELECT position, state, sequence_id, app_id, created_at, 
status FROM {self.table_name} "
+                    "WHERE partition_key = $1 AND app_id = $2 AND sequence_id 
= $3 "
+                )
+                row = await conn.fetchrow(
+                    query,
+                    partition_key,
+                    app_id,
+                    sequence_id,
+                )
+
+            if row is None:
+                return None
+
+            # converts from asyncpg str to dict
+            json_row = json.loads(row[1])
+            _state = state.State.deserialize(json_row, **self.serde_kwargs)
+            return {
+                "partition_key": partition_key,
+                "app_id": row[3],
+                "sequence_id": row[2],
+                "position": row[0],
+                "state": _state,
+                "created_at": row[4],
+                "status": row[5],
+            }
+        finally:
+            await self._release_connection(conn, acquired)
 
     async def save(
         self,
@@ -250,15 +370,21 @@ class 
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
             status,
         )
 
-        json_state = json.dumps(state.serialize(**self.serde_kwargs))
-        query = (
-            f"INSERT INTO {self.table_name} (partition_key, app_id, 
sequence_id, position, state, status) "
-            "VALUES ($1, $2, $3, $4, $5, $6)"
-        )
-        await self.connection.execute(
-            query, partition_key, app_id, sequence_id, position, json_state, 
status
-        )
+        conn, acquired = await self._get_connection()
+        try:
+            json_state = json.dumps(state.serialize(**self.serde_kwargs))
+            query = (
+                f"INSERT INTO {self.table_name} (partition_key, app_id, 
sequence_id, position, state, status) "
+                "VALUES ($1, $2, $3, $4, $5, $6)"
+            )
+            await conn.execute(
+                query, partition_key, app_id, sequence_id, position, 
json_state, status
+            )
+        finally:
+            await self._release_connection(conn, acquired)
 
     async def cleanup(self):
         """Closes the connection to the database."""
-        await self.connection.close()
+        if self.connection is not None:
+            await self.connection.close()
+            self.connection = None
diff --git a/docs/concepts/actions.rst b/docs/concepts/actions.rst
index 541cea83..2316fae3 100644
--- a/docs/concepts/actions.rst
+++ b/docs/concepts/actions.rst
@@ -15,6 +15,22 @@ Actions do the heavy-lifting in a workflow. They should 
contain all complex comp
 either through a class-based or function-based API. If actions implement 
``async def run`` then will be run in an
 asynchronous context (and thus require one of the async application functions).
 
+.. note::
+    When implementing asynchronous actions with ``async def run``, you must also override the ``is_async`` property
+    to return ``True``. This tells the framework to execute the action in an asynchronous context:
+    to return ``True``. This tells the framework to execute the action in an 
asynchronous context:
+
+    .. code-block:: python
+
+        class AsyncAction(Action):
+            @property
+            def is_async(self) -> bool:
+                return True
+
+            async def run(self, state: State) -> dict:
+                # Async implementation
+                ...
+
+
 Actions have two primary responsibilities:
 
 1. ``run`` -- compute a result
diff --git a/docs/concepts/parallelism.rst b/docs/concepts/parallelism.rst
index ba7fa455..b878921f 100644
--- a/docs/concepts/parallelism.rst
+++ b/docs/concepts/parallelism.rst
@@ -617,6 +617,106 @@ To do this, you would:
 3. Join them in parallel, waiting for any user-input if provided
 4. Decide after every step of the first graph whether you want to cancel the 
second graph or not -- E.G. is the user satisfied.
 
+
+Async Parallelism
+=================
+
+Burr also supports asynchronous parallelism. When working in an async context, 
you need to make a few adjustments to your parallel actions:
+
+1. Make your methods async
+--------------------------
+
+The ``action``, ``states``, ``reduce``, and other methods should be defined as async:
+
+.. code-block:: python
+
+    class AsyncMapActionsAndStatesExample(MapActionsAndStates):
+
+        async def action(self, state: State, inputs: Dict[str, Any]) -> 
AsyncGenerator[Action, None]:
+            # Yield multiple model components to run in parallel
+            for i, model_config in enumerate(self._model_configs):
+                yield 
ModelResponse(config=model_config).with_name(f"model_{i}")
+
+        async def states(self, state: State, inputs: Dict[str, Any]) -> 
AsyncGenerator[State, None]:
+            # Prepare the state with the user query
+            for prompt in [
+                "What is the meaning of life?",
+                "What is the airspeed velocity of an unladen swallow?",
+                "What is the best way to cook a steak?",
+            ]:
+                yield state.update(prompt=prompt)
+
+        async def reduce(self, state: State, states: AsyncGenerator[State, 
None]) -> State:
+            # Collect all model responses
+            all_responses = []
+            async for sub_state in states:
+                model_key = sub_state.get("model_key")
+                response = sub_state.get(model_key, [])[-1].get("content", "")
+                all_responses.append(response)
+
+            return state.update(ensemble_responses=all_responses)
+
+2. Implement the is_async property
+----------------------------------
+
+You must override the ``is_async`` property to return ``True``:
+
+.. code-block:: python
+
+    class AsyncMapActionsAndStatesExample(MapActionsAndStates):
+
+        @property
+        def is_async(self) -> bool:
+            return True
+
+        # ... other methods ...
+
+3. Use async persisters with connection pools
+---------------------------------------------
+
+When using state persistence with async parallelism, make sure to use the 
async version of persisters and initialize them with a connection pool:
+
+.. code-block:: python
+
+    from burr.integrations.persisters.b_asyncpg import AsyncPostgreSQLPersister
+
+    # Create an async persister with a connection pool
+    persister = await AsyncPostgreSQLPersister.from_values(
+        host="localhost",
+        port=5432,
+        user="postgres",
+        password="postgres",
+        db_name="burr",
+        use_pool=True  # Important for parallelism!
+    )
+
+    app = (
+        ApplicationBuilder()
+        .with_state_persister(persister)
+        .with_action(
+            async_parallel_action=AsyncMapActionsAndStatesExample(),
+        )
+        .abuild()
+    )
+
+Connection pools are crucial for handling concurrent operations. Direct 
connections cannot be shared across different tasks and may cause errors in 
concurrent scenarios.
+
+Remember to properly clean up your async persisters when you're done with them:
+
+.. code-block:: python
+
+    # Using as a context manager
+    async with await AsyncPostgreSQLPersister.from_values(..., use_pool=True) as persister:
+        # Use persister here
+
+    # Or manual cleanup
+    persister = await AsyncPostgreSQLPersister.from_values(..., use_pool=True)
+    try:
+        # Use persister here
+    finally:
+        await persister.cleanup()
+
+
 Notes
 =====
 
@@ -631,4 +731,4 @@ Things that may change:
 1. We will likely alter the executor API to be more flexible, although we will 
probably allow for use of the current executor API
 2. We will be adding guard-rails for generator-types (sync versus async)
 3. The UI is a WIP -- we have more sophisticated capabilities but are still 
polishing them
-4. Support for action-level executors
+4. Support for action-level executors
\ No newline at end of file
diff --git a/docs/concepts/sync-vs-async.rst b/docs/concepts/sync-vs-async.rst
index c24e37c5..4058eb4f 100644
--- a/docs/concepts/sync-vs-async.rst
+++ b/docs/concepts/sync-vs-async.rst
@@ -21,6 +21,31 @@ Burr gives you the ability to write synchronous (standard 
python) and asynchrono
    * :py:meth:`.run() <.Application.run()>`
    * :py:meth:`.stream_result() <.Application.stream_result()>`
 
+Checklist for Async Applications
+--------------------------------
+
+When building asynchronous applications with Burr, ensure you:
+
+1. **Use async action implementations**:
+   * Implement ``async def run`` methods in your actions
+   * Override the ``is_async`` property to return ``True`` in all async 
class-based actions
+   * Use ``await`` for all I/O operations inside your actions
+
+2. **Use async builder and application methods**:
+   * Use ``.abuild()`` instead of ``.build()``
+   * Use ``.arun()``, ``.aiterate()``, and ``.astream_result()`` instead of 
their sync counterparts
+
+3. **Use async hooks and persisters**:
+   * Implement async hooks with ``async def`` methods
+   * Use async persisters (e.g., ``AsyncPostgreSQLPersister`` instead of ``PostgreSQLPersister``)
+   * Properly clean up async resources using context managers or explicit 
cleanup calls
+
+4. **For parallel actions**:
+   * Make ``actions``, ``states``, and ``reduce`` methods async
+   * Override ``is_async`` to return ``True``
+   * Use ``AsyncGenerator`` return types
+   * Use async persisters with connection pools
+
 Comparison
 ----------
 

Reply via email to