This is an automated email from the ASF dual-hosted git repository. pierrejeambrun pushed a commit to branch v2-5-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit e2cf93305803027d462373ca356bd5badee1128e Author: Tzu-ping Chung <[email protected]> AuthorDate: Wed Jan 18 18:19:35 2023 +0800 Resolve all variables in pickled XCom iterator (#28982) (cherry picked from commit ccf53e167ea57716c76ec7ab5bd1223f0c0d47d3) --- airflow/models/xcom.py | 7 +++++- tests/conftest.py | 2 +- tests/models/test_taskinstance.py | 49 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 3b43618424..6294fa3d7f 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -731,7 +731,12 @@ class LazyXComAccess(collections.abc.Sequence): # do the same for count(), but I think it should be performant enough to # calculate only that eagerly. with self._get_bound_query() as query: - statement = query.statement.compile(query.session.get_bind()) + statement = query.statement.compile( + query.session.get_bind(), + # This inlines all the values into the SQL string to simplify + # cross-process communication as much as possible. + compile_kwargs={"literal_binds": True}, + ) return (str(statement), query.count()) def __setstate__(self, state: Any) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index d71d8eb0f0..3fb6d83489 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -291,7 +291,7 @@ def skip_if_not_marked_with_backend(selected_backend, item): if selected_backend in backend_names: return pytest.skip( - f"The test is skipped because it does not have the right backend marker " + f"The test is skipped because it does not have the right backend marker. 
" f"Only tests marked with pytest.mark.backend('{selected_backend}') are run: {item}" ) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 17ce74178d..849358cc69 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -3592,6 +3592,40 @@ class TestMappedTaskInstanceReceiveValue: assert out_lines == ["hello FOO", "goodbye FOO", "hello BAR", "goodbye BAR"] +def _get_lazy_xcom_access_expected_sql_lines() -> list[str]: + backend = os.environ.get("BACKEND") + if backend == "mssql": + return [ + "SELECT xcom.value", + "FROM xcom", + "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' " + "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.[key] = 'xxx'", + ] + elif backend == "mysql": + return [ + "SELECT xcom.value", + "FROM xcom", + "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' " + "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.`key` = 'xxx'", + ] + elif backend == "postgres": + return [ + "SELECT xcom.value", + "FROM xcom", + "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' " + "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.key = 'xxx'", + ] + elif backend == "sqlite": + return [ + "SELECT xcom.value", + "FROM xcom", + "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' " + "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.\"key\" = 'xxx'", + ] + else: + raise RuntimeError(f"unknown backend {backend!r}") + + def test_lazy_xcom_access_does_not_pickle_session(dag_maker, session): with dag_maker(session=session): EmptyOperator(task_id="t") @@ -3599,9 +3633,22 @@ def test_lazy_xcom_access_does_not_pickle_session(dag_maker, session): run: DagRun = dag_maker.create_dagrun() run.get_task_instance("t", session=session).xcom_push("xxx", 123, session=session) - original = LazyXComAccess.build_from_xcom_query(session.query(XCom)) + query = session.query(XCom.value).filter_by( + dag_id=run.dag_id, + run_id=run.run_id, + task_id="t", + map_index=-1, 
+ key="xxx", + ) + + original = LazyXComAccess.build_from_xcom_query(query) processed = pickle.loads(pickle.dumps(original)) + # After the object went through pickling, the underlying ORM query should be + # replaced by one backed by a literal SQL string with all variables bound. + sql_lines = [line.strip() for line in str(processed._query.statement.compile(None)).splitlines()] + assert sql_lines == _get_lazy_xcom_access_expected_sql_lines() + assert len(processed) == 1 assert list(processed) == [123]
