Copilot commented on code in PR #62963: URL: https://github.com/apache/airflow/pull/62963#discussion_r3071668952
########## providers/common/ai/tests/unit/common/ai/utils/test_dq_planner.py: ########## @@ -0,0 +1,1444 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook +from airflow.providers.common.ai.utils.dq_models import DQCheck, DQCheckGroup, DQPlan +from airflow.providers.common.ai.utils.dq_planner import SQLDQPlanner +from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError +from airflow.providers.common.sql.hooks.sql import DbApiHook + + +def _make_plan(*check_names: str) -> DQPlan: + """Helper: build a minimal DQPlan with one group per check.""" + groups = [ + DQCheckGroup( + group_id="numeric_aggregate", + query=f"SELECT COUNT(*) AS {name}_count FROM t", + checks=[DQCheck(check_name=name, metric_key=f"{name}_count", group_id="numeric_aggregate")], + ) + for name in check_names + ] + return DQPlan(groups=groups) + + +def _make_llm_hook(plan: DQPlan) -> MagicMock: + """Helper: mock PydanticAIHook that returns *plan* from agent.run_sync.""" + mock_usage = MagicMock(requests=1, tool_calls=0, input_tokens=100, output_tokens=50, total_tokens=150) + 
mock_result = MagicMock(spec=["output", "all_messages", "usage", "response"]) + mock_result.output = plan + mock_result.all_messages.return_value = [] + mock_result.usage.return_value = mock_usage + mock_result.response.model_name = "test-model" + mock_agent = MagicMock(spec=["run_sync"]) + mock_agent.run_sync.return_value = mock_result + mock_hook = MagicMock(spec=PydanticAIHook) + mock_hook.create_agent.return_value = mock_agent + return mock_hook + + +class TestSQLDQPlannerBuildSchema: + def test_returns_manual_schema_context_verbatim(self): + planner = SQLDQPlanner(llm_hook=MagicMock(spec=PydanticAIHook), db_hook=None) + result = planner.build_schema_context( + table_names=None, + schema_context="Table: t\nColumns: id INT", + ) + assert result == "Table: t\nColumns: id INT" + + def test_introspects_via_db_hook_when_no_manual_context(self): + mock_db_hook = MagicMock() + mock_db_hook.get_table_schema.return_value = [{"name": "id", "type": "INT"}] + + planner = SQLDQPlanner(llm_hook=MagicMock(spec=PydanticAIHook), db_hook=mock_db_hook) + result = planner.build_schema_context( + table_names=["customers"], + schema_context=None, + ) + + mock_db_hook.get_table_schema.assert_called_once_with("customers") + assert "customers" in result + assert "id INT" in result + + def test_manual_context_takes_priority_over_db_hook(self): + mock_db_hook = MagicMock() + + planner = SQLDQPlanner(llm_hook=MagicMock(spec=PydanticAIHook), db_hook=mock_db_hook) + result = planner.build_schema_context( + table_names=["t"], Review Comment: Same issue here: `mock_db_hook = MagicMock()` is created without a spec, and an unspecced `MagicMock` accepts any attribute name. Assertions such as `assert_not_called()` can therefore keep passing even after the production code renames or stops using the hook method, so the test silently stops verifying the intended behavior. Use `MagicMock(spec=DbApiHook)` (already imported in this module), or at least `MagicMock(spec=["get_table_schema"])`, so that accessing a method the real hook does not provide raises `AttributeError`. The same applies to `mock_db_hook` in `test_introspects_via_db_hook_when_no_manual_context` above. 
########## providers/common/ai/src/airflow/providers/common/ai/utils/dq_planner.py: ########## @@ -0,0 +1,908 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +SQL-based data-quality plan generation and execution. + +:class:`SQLDQPlanner` is the single entry-point for all SQL DQ logic. +It is deliberately kept separate from the operator so it can be unit-tested +without an Airflow context and later swapped for GEX/SODA planners without +touching the operator. 
+""" + +from __future__ import annotations + +import logging +from collections.abc import Iterator, Sequence +from contextlib import closing +from typing import TYPE_CHECKING, Any + +try: + from airflow.providers.common.ai.utils.sql_validation import ( + DEFAULT_ALLOWED_TYPES, + SQLSafetyError, + validate_sql as _validate_sql, + ) +except ImportError as e: + from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException + + raise AirflowOptionalProviderFeatureException(e) + +from airflow.providers.common.ai.utils.db_schema import build_schema_context, resolve_dialect +from airflow.providers.common.ai.utils.dq_models import DQCheckGroup, DQPlan, RowLevelResult, UnexpectedResult +from airflow.providers.common.ai.utils.logging import log_run_summary + +if TYPE_CHECKING: + from pydantic_ai import Agent + from pydantic_ai.messages import ModelMessage + + from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook + from airflow.providers.common.sql.config import DataSourceConfig + from airflow.providers.common.sql.datafusion.engine import DataFusionEngine + from airflow.providers.common.sql.hooks.sql import DbApiHook + +log = logging.getLogger(__name__) + +_MAX_CHECKS_PER_GROUP = 5 +# Maximum rows fetched from DB per chunk during row-level processing — avoids loading the +# entire result set into memory at once. +_ROW_LEVEL_CHUNK_SIZE = 10_000 +# Hard cap on violation samples stored per check — independent of SQL LIMIT and chunk size. +_MAX_VIOLATION_SAMPLES = 100 + +_PLANNING_SYSTEM_PROMPT = """\ +You are a data-quality SQL expert. + +Given a set of named data-quality checks and a database schema, produce a \ +DQPlan that minimises the number of SQL queries while keeping each group \ +focused and manageable. + +GROUPING STRATEGY (multi-dimensional): + Group checks by **(target_table, check_category)**. Checks on the same table + that belong to different categories MUST be in separate groups. 
+ + Allowed check_category values (assign one per check based on its description): + - null_check — null / missing value counts or percentages + - uniqueness — duplicate detection, cardinality checks + - validity — regex / format / pattern matching on string columns + - numeric_range — range, bounds, or statistical checks on numeric columns + - row_count — total row counts or existence checks + - string_format — length, encoding, whitespace, or character-set checks + - row_level — per-row or anomaly checks that evaluate individual records + + Row-level checks still follow the same grouping rule: group by (target_table, check_category="row_level"). + MAX {max_checks_per_group} CHECKS PER GROUP: + If a (table, category) pair has more than {max_checks_per_group} checks, + split them into sub-groups of at most {max_checks_per_group}. + + GROUP-ID NAMING: + Use the pattern "{{table}}_{{category}}_{{part}}". + Examples: customers_null_check_1, orders_validity_1, orders_validity_2 + + RATIONALE: + Keeping string-column checks (validity, string_format) apart from + numeric-column checks (numeric_range, null_check on numbers) produces + simpler SQL and makes failures easier to diagnose. + + CORRECT (two groups for same table, different categories): + Group customers_null_check_1: + SELECT + (COUNT(CASE WHEN email IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS null_email_pct, + (COUNT(CASE WHEN name IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS null_name_pct + FROM customers + + Group customers_validity_1: + SELECT + COUNT(CASE WHEN phone NOT LIKE '+___-___-____' THEN 1 END) AS invalid_phone_fmt + FROM customers + + WRONG (mixing null-check and regex-validity in one group): + SELECT + (COUNT(CASE WHEN email IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS null_email_pct, + COUNT(CASE WHEN phone NOT LIKE '+___-___-____' THEN 1 END) AS invalid_phone_fmt + FROM customers + +OUTPUT RULES: + 1. Each output column must be aliased to exactly the metric_key of its check. + Example: ... 
AS null_email_pct + 2. Each check_name must exactly match the key in the prompts dict. + 3. metric_key values must be valid SQL column aliases (snake_case, no spaces). + 4. Generates only SELECT queries — no INSERT, UPDATE, DELETE, DROP, or DDL. + 5. Use {dialect} syntax. + 6. Each check must appear in exactly ONE group. + 7. Each check must have a check_category from the allowed list above. + 8. Return a valid DQPlan object. No extra commentary. +""" + +_DATAFUSION_SYNTAX_SECTION = """\ + +DATAFUSION SQL SYNTAX RULES: + The target engine is Apache DataFusion. Observe these syntax differences + from standard PostgreSQL / ANSI SQL: + + 1. NO "FILTER (WHERE ...)" clause. Use CASE expressions instead: + WRONG: COUNT(*) FILTER (WHERE email IS NULL) + RIGHT: COUNT(CASE WHEN email IS NULL THEN 1 END) + + 2. Regex matching uses the tilde operator: + column ~ 'pattern' (match) + column !~ 'pattern' (no match) + Do NOT use SIMILAR TO or POSIX-style ~* (case-insensitive). + + 3. CAST syntax — prefer CAST(expr AS type) over :: shorthand. + + 4. String functions: Use CHAR_LENGTH (not LEN), SUBSTR (not SUBSTRING with FROM/FOR). + + 5. Integer division: DataFusion performs integer division for INT/INT. + Use CAST(expr AS DOUBLE) to force floating-point division. + + 6. Boolean literals: Use TRUE / FALSE (not 1 / 0). + + 7. LIMIT is supported. OFFSET is supported. FETCH FIRST is NOT supported. + + 8. NULL handling: COALESCE, NULLIF, IFNULL are all supported. + NVL and ISNULL are NOT supported. +""" + +_UNEXPECTED_QUERY_PROMPT_SECTION = """\ + +UNEXPECTED VALUE COLLECTION: + For checks whose check_category is "validity" or "string_format", also + generate an unexpected_query field on the DQCheck. 
This query must: + - SELECT the primary key column(s) and the column(s) being validated + - WHERE the row violates the check condition (the negation of the check) + - LIMIT {sample_size} + - Use {dialect} syntax + - Be a standalone SELECT (not a subquery of the group query) + + For all other categories (null_check, uniqueness, numeric_range, row_count), + set unexpected_query to null — these are aggregate checks where individual + violating rows are not meaningful. + + Example for a phone-format validity check: + unexpected_query: "SELECT id, phone FROM customers WHERE phone !~ '^\\d{{4}}-\\d{{4}}-\\d{{4}}$' LIMIT 100" +""" + +_ROW_LEVEL_PROMPT_SECTION = """ + +ROW-LEVEL CHECKS: + Some checks are marked as row_level. For these: + - Generate a SELECT that returns the primary key column(s) and the column + being validated. Do NOT aggregate. + - Set row_level = true on the DQCheck entry. + - metric_key must be the name of the column containing the value to validate + (the Python validator will read row[metric_key] for each row). + - {row_level_limit_clause} + - Place ALL row-level checks for the same table in a single group. + + Row-level check names that require this treatment: {row_level_check_names} +""" + + +class SQLDQPlanner: + """ + Generates and executes a SQL-based :class:`~airflow.providers.common.ai.utils.dq_models.DQPlan`. + + :param llm_hook: Hook used to call the LLM for plan generation. + :param db_hook: Hook used to execute generated SQL against the database. + :param dialect: SQL dialect forwarded to the LLM prompt and ``validate_sql``. + Auto-detected from *db_hook* when ``None``. + :param max_sql_retries: Maximum number of times a failing SQL group query is sent + back to the LLM for correction before the error is re-raised. Default ``2``. + :param validator_contexts: Pre-built LLM context string from + :meth:`~airflow.providers.common.ai.utils.dq_validation.ValidatorRegistry.build_llm_context`. 
+ Appended to the system prompt so the LLM knows what metric format each + custom validator expects. + :param row_validators: Mapping of ``{check_name: row_level_callable}`` for + checks that require row-by-row Python validation. When a check's name + appears here, ``execute_plan`` fetches all (or sampled) rows and applies + the callable to each value instead of reading a single aggregate scalar. + :param row_level_sample_size: Maximum number of rows to fetch for row-level + checks. ``None`` (default) performs a full scan. A positive integer + instructs the LLM to add ``LIMIT N`` to the generated SELECT. + """ + + def __init__( + self, + *, + llm_hook: PydanticAIHook, + db_hook: DbApiHook | None, + dialect: str | None = None, + max_sql_retries: int = 2, + datasource_config: DataSourceConfig | None = None, + system_prompt: str = "", + agent_params: dict[str, Any] | None = None, + collect_unexpected: bool = False, + unexpected_sample_size: int = 100, + validator_contexts: str = "", + row_validators: dict[str, Any] | None = None, + row_level_sample_size: int | None = None, + ) -> None: + self._llm_hook = llm_hook + self._db_hook = db_hook + self._datasource_config = datasource_config + self._dialect = resolve_dialect(db_hook, dialect) + # Track whether the execution target is DataFusion so the prompt can + # include DataFusion-specific syntax rules. The dialect stays None + # (generic SQL) for sqlglot validation — sqlglot has no DataFusion dialect. + self._is_datafusion = db_hook is None and datasource_config is not None + # When targeting DataFusion, use PostgreSQL dialect for sqlglot validation + # because DataFusion shares regex operators (~, !~) that the generic SQL + # parser does not recognise. 
+ self._validation_dialect: str | None = "postgres" if self._is_datafusion else self._dialect + self._max_sql_retries = max_sql_retries + self._extra_system_prompt = system_prompt + self._agent_params: dict[str, Any] = agent_params or {} + self._collect_unexpected = collect_unexpected + self._unexpected_sample_size = unexpected_sample_size + self._validator_contexts = validator_contexts + self._row_validators: dict[str, Any] = row_validators or {} + self._row_level_sample_size = row_level_sample_size + self._cached_datafusion_engine: DataFusionEngine | None = None + self._plan_agent: Agent[None, DQPlan] | None = None + self._plan_all_messages: list[ModelMessage] | None = None + + def build_schema_context( + self, + table_names: list[str] | None, + schema_context: str | None, + ) -> str: + """ + Return a schema description string for inclusion in the LLM prompt. + + Delegates to :func:`~airflow.providers.common.ai.utils.db_schema.build_schema_context`. + """ + return build_schema_context( + db_hook=self._db_hook, + table_names=table_names, + schema_context=schema_context, + datasource_config=self._datasource_config, + ) + + def generate_plan(self, prompts: dict[str, str], schema_context: str) -> DQPlan: + """ + Ask the LLM to produce a :class:`~airflow.providers.common.ai.utils.dq_models.DQPlan`. + + The LLM receives the user prompts, schema context, and planning instructions + as a structured-output call (``output_type=DQPlan``). After generation the + method verifies that the returned ``check_names`` exactly match + ``prompts.keys()``. + + :param prompts: ``{check_name: natural_language_description}`` dict. + :param schema_context: Schema description previously built via + :meth:`build_schema_context`. + :raises ValueError: If the LLM's plan does not cover every prompt key + exactly once. 
+ """ + dialect_label = self._dialect or ("DataFusion-compatible SQL" if self._is_datafusion else "SQL") + system_prompt = _PLANNING_SYSTEM_PROMPT.format( + dialect=dialect_label, max_checks_per_group=_MAX_CHECKS_PER_GROUP + ) + + if self._is_datafusion: + system_prompt += _DATAFUSION_SYNTAX_SECTION + + if self._collect_unexpected: + system_prompt += _UNEXPECTED_QUERY_PROMPT_SECTION.format( + dialect=dialect_label, sample_size=self._unexpected_sample_size + ) + + if schema_context: + system_prompt += f"\nAvailable schema:\n{schema_context}\n" + + if self._validator_contexts: + system_prompt += self._validator_contexts + + if self._row_validators: + row_level_check_names = ", ".join(sorted(self._row_validators)) + if self._row_level_sample_size is not None: + limit_clause = f"Add LIMIT {self._row_level_sample_size} to the query." + else: + limit_clause = "Do NOT add a LIMIT — return all rows." + system_prompt += _ROW_LEVEL_PROMPT_SECTION.format( + row_level_check_names=row_level_check_names, + row_level_limit_clause=limit_clause, + ) + + if self._extra_system_prompt: + system_prompt += f"\nAdditional instructions:\n{self._extra_system_prompt}\n" + + user_message = self._build_user_message(prompts) + + log.info("Using system prompt:\n%s", system_prompt) + log.info("Using user message:\n%s", user_message) + Review Comment: `generate_plan()` logs the full system prompt and user message at INFO level. The system prompt embeds the schema context, validator hints, and any extra user instructions. That makes it large, bloats task logs, and can expose sensitive schema details. Consider logging the full text at DEBUG (or behind an opt-in flag) and keeping INFO to a short summary, e.g. the number of checks and the prompt length. The plan's `plan_hash` is not available until the LLM has returned the plan, so log it after generation rather than here. The same concern applies to `LLMDataQualityOperator.execute()`, which logs the full schema context at INFO. ########## providers/common/ai/src/airflow/providers/common/ai/operators/llm_data_quality.py: ########## @@ -0,0 +1,703 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Operator for generating and executing data-quality checks from natural language using LLMs.""" + +from __future__ import annotations + +import hashlib +import json +from collections.abc import Callable, Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.ai.utils.db_schema import get_db_hook +from airflow.providers.common.ai.utils.dq_models import ( + DQCheckFailedError, + DQCheckGroup, + DQCheckResult, + DQPlan, + DQReport, + RowLevelResult, + UnexpectedResult, +) +from airflow.providers.common.ai.utils.dq_validation import default_registry +from airflow.providers.common.compat.sdk import Variable + +try: + from airflow.providers.common.ai.utils.dq_planner import SQLDQPlanner +except ImportError as e: + from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException + + raise AirflowOptionalProviderFeatureException(e) + +if TYPE_CHECKING: + from airflow.providers.common.sql.config import DataSourceConfig + from airflow.providers.common.sql.hooks.sql import DbApiHook + from airflow.sdk import Context + +_PLAN_VARIABLE_PREFIX = "dq_plan_" +_PLAN_VARIABLE_KEY_MAX_LEN = 200 # stay well under Airflow Variable key length limit + + 
+def _describe_validator(validator: Callable[[Any], bool]) -> str: + """Return a human-readable validator label for failure messages.""" + display = getattr(validator, "_validator_display", None) + if isinstance(display, str) and display: + return display + validator_name = getattr(validator, "_validator_name", None) + if isinstance(validator_name, str) and validator_name: + return validator_name + validator_name = getattr(validator, "__name__", None) + if isinstance(validator_name, str) and validator_name: + return validator_name + return repr(validator) + + +class LLMDataQualityOperator(LLMOperator): + """ + Generate and execute data-quality checks from natural language descriptions. + + Each entry in ``prompts`` describes **one** data-quality expectation. + The LLM groups related checks into optimised SQL queries, executes them + against the target database, and validates each metric against the + corresponding entry in ``validators``. The task fails if any check + does not pass, gating downstream tasks on data quality. + + Generated SQL plans are cached in Airflow + :class:`~airflow.models.variable.Variable` to avoid repeat LLM calls. + Set ``dry_run=True`` to preview the plan without executing it — the + serialised plan dict is returned without running any SQL. + Set ``require_approval=True`` to gate execution on human review via the + HITL interface: the plan is presented to the reviewer first, and SQL + checks run only after approval. ``dry_run`` and ``require_approval`` + are independent — enabling both returns the plan dict without any + approval prompt. + + :param prompts: Mapping of ``{check_name: natural_language_description}``. + Each key must be unique. Use one check per key; the operator enforces + a strict one-key → one-check mapping. + :param llm_conn_id: Connection ID for the LLM provider. + :param model_id: Model identifier (e.g. ``"openai:gpt-4o"``). + Overrides the model stored in the connection's extra field. 
+ :param system_prompt: Additional instructions appended to the planning prompt. + :param agent_params: Additional keyword arguments passed to the pydantic-ai + ``Agent`` constructor (e.g. ``retries``, ``model_settings``). + :param db_conn_id: Connection ID for the database to run checks against. + Must resolve to a :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook`. + :param table_names: Tables to include in the LLM's schema context. + :param schema_context: Manual schema description; bypasses DB introspection. + :param validators: Mapping of ``{check_name: callable}`` where each callable + receives the raw metric value and returns ``True`` (pass) or ``False`` (fail). + Keys must be a subset of ``prompts.keys()``. + Use built-in factories from + :mod:`~airflow.providers.common.ai.utils.dq_validation` or plain lambdas:: + + from airflow.providers.common.ai.utils.dq_validation import null_pct_check + + validators = { + "email_nulls": null_pct_check(max_pct=0.05), + "row_check": lambda v: v >= 1000, + } + + :param dialect: SQL dialect override (``postgres``, ``mysql``, etc.). + Auto-detected from *db_conn_id* when not set. + :param datasource_config: DataFusion datasource for object-storage schema. + :param dry_run: When ``True``, generate and cache the plan but skip execution. + Returns the serialised plan dict instead of a :class:`~airflow.providers.common.ai.utils.dq_models.DQReport`. + :param prompt_version: Optional version tag included in the plan cache key. + Bump this to invalidate cached plans when prompts change semantically + without changing their text. + :param collect_unexpected: When ``True``, the LLM generates an + ``unexpected_query`` for validity / string-format checks. + If any of those checks fail, the unexpected query is executed and + the resulting sample rows are included in the report. + :param unexpected_sample_size: Maximum number of violating rows to return + per failed check. Default ``100``. 
+ :param row_level_sample_size: Maximum number of rows to fetch per row-level + check. ``None`` (default) performs a full table scan — every row is + fetched and validated. A positive integer is passed to the LLM as a + ``LIMIT`` clause on the generated SELECT, bounding execution time and + memory usage at the cost of sampling coverage. + :param require_approval: When ``True``, the operator defers after generating + and caching the DQ plan. The plan SQL is surfaced in the HITL interface + for human review; checks run only after the reviewer approves. Inherited + from :class:`~airflow.providers.common.ai.operators.llm.LLMOperator`. + ``dry_run=True`` takes precedence — combining both flags returns the plan + dict immediately without requesting approval. + """ + + template_fields: Sequence[str] = ( + *LLMOperator.template_fields, + "prompts", + "db_conn_id", + "table_names", + "schema_context", + "prompt_version", + "collect_unexpected", + "unexpected_sample_size", + "row_level_sample_size", + ) + + def __init__( + self, + *, + prompts: dict[str, str], + db_conn_id: str | None = None, + table_names: list[str] | None = None, + schema_context: str | None = None, + validators: dict[str, Callable[[Any], bool]] | None = None, + dialect: str | None = None, + datasource_config: DataSourceConfig | None = None, + prompt_version: str | None = None, + dry_run: bool = False, + collect_unexpected: bool = False, + unexpected_sample_size: int = 100, + row_level_sample_size: int | None = None, + **kwargs: Any, + ) -> None: + kwargs.pop("output_type", None) + kwargs.setdefault("prompt", "LLMDataQualityOperator") + super().__init__(**kwargs) + + self.prompts = prompts + self.db_conn_id = db_conn_id + self.table_names = table_names + self.schema_context = schema_context + self.validators = validators or {} + self.dialect = dialect + self.datasource_config = datasource_config + self.prompt_version = prompt_version + self.dq_dry_run = dry_run + self.collect_unexpected = 
collect_unexpected + self.unexpected_sample_size = unexpected_sample_size + self.row_level_sample_size = row_level_sample_size + + self._validate_prompts() + self._validate_validator_keys() + + def execute(self, context: Context) -> dict[str, Any]: + """ + Generate the DQ plan (or load from cache), then execute or defer for approval. + + When ``dry_run=True`` the serialised plan dict is returned immediately — + no SQL is executed and no approval is requested. + When ``require_approval=True`` the task defers, presenting the plan to a + human reviewer; data-quality checks run only after the reviewer approves. + + :returns: Dict with keys ``plan``, ``passed``, and ``results``. On success + ``passed=True`` and ``results`` is a list of per-check result dicts. + For row-level checks the ``value`` entry in each result dict is itself + a dict with keys ``total``, ``invalid``, ``invalid_pct``, and + ``sample_violations`` rather than a raw scalar. + When ``dry_run=True`` ``passed=None`` and ``results=None`` — no SQL + is executed. The ``plan`` key is always present in all modes. + :raises DQCheckFailedError: If any data-quality check fails threshold validation. + :raises TaskDeferred: When ``require_approval=True``, defers for human review + before executing the checks. + """ + planner = self._build_planner() + + schema_ctx = planner.build_schema_context( + table_names=self.table_names, schema_context=self.schema_context + ) + + self.log.info("Using schema context:\n%s", schema_ctx) + + plan = self._load_or_generate_plan(planner, schema_ctx) + + if self.dq_dry_run: + self.log.info( + "dry_run=True — skipping execution. 
Plan contains %d group(s), %d check(s).", + len(plan.groups), + len(plan.check_names), + ) + for group in plan.groups: + self.log.info( + "Group: %s\nChecks: %s\nSQL Query:\n%s\n", + group.group_id, + ", ".join(c.check_name for c in group.checks), + group.query, + ) + return {"plan": plan.model_dump(), "passed": None, "results": None} + + if self.require_approval: + # Defer BEFORE execution — approval gates the SQL checks. + self.defer_for_approval( # type: ignore[misc] + context, + plan.model_dump_json(), + body=self._build_dry_run_markdown(plan), + ) + return {} # type: ignore[return-value] # pragma: no cover + + return self._run_checks_and_report(context, planner, plan) + + def _build_planner(self) -> SQLDQPlanner: + """Construct a :class:`~airflow.providers.common.ai.utils.dq_planner.SQLDQPlanner` from operator config.""" + return SQLDQPlanner( + llm_hook=self.llm_hook, + db_hook=self.db_hook, + dialect=self.dialect, + datasource_config=self.datasource_config, + system_prompt=self.system_prompt, + agent_params=self.agent_params, + collect_unexpected=self.collect_unexpected, + unexpected_sample_size=self.unexpected_sample_size, + validator_contexts=self.validator_contexts, + row_validators=self._collect_row_validators(), + row_level_sample_size=self.row_level_sample_size, + ) + + @cached_property + def validator_contexts(self) -> str: + """Return validator-specific LLM context rendered from configured validators.""" + return default_registry.build_llm_context(self.validators) + + def _run_checks_and_report( + self, + context: Context, + planner: SQLDQPlanner, + plan: DQPlan, + ) -> dict[str, Any]: + """ + Execute *plan* against the database, apply validators, and return the serialised report. + + :raises DQCheckFailedError: If any data-quality check fails. + """ + results_map = planner.execute_plan(plan) + check_results = self._validate_results(results_map, plan) + + # Collect unexpected rows for failed validity/format checks. 
+ if self.collect_unexpected: + failed_names = {r.check_name for r in check_results if not r.passed} + if failed_names: + unexpected_map = planner.execute_unexpected_queries(plan, failed_names) + self._attach_unexpected(check_results, unexpected_map) + + report = DQReport.build(check_results) + + output: dict[str, Any] = { + "plan": plan.model_dump(), + "passed": report.passed, + "results": [ + { + "check_name": r.check_name, + "metric_key": r.metric_key, + # RowLevelResult is not JSON-serialisable; convert to a plain dict. + "value": ( + { + "total": r.value.total, + "invalid": r.value.invalid, + "invalid_pct": r.value.invalid_pct, + "sample_violations": r.value.sample_violations, + "sample_size": r.value.sample_size, + } + if isinstance(r.value, RowLevelResult) + else r.value + ), + "passed": r.passed, + "failure_reason": r.failure_reason, + **( + { + "unexpected_records": r.unexpected.unexpected_records, + "unexpected_sample_size": r.unexpected.sample_size, + } + if r.unexpected + else {} + ), + } + for r in report.results + ], + } + + if not report.passed: + # Push results to XCom before failing so downstream tasks + # (e.g. with trigger_rule=all_done) can still inspect them. + context["ti"].xcom_push(key="return_value", value=output) + raise DQCheckFailedError(report.failure_summary) + + self.log.info("All %d data-quality check(s) passed.", len(report.results)) + return output + + def _build_dry_run_markdown(self, plan: DQPlan) -> str: + """ + Build a structured markdown summary of the DQ plan for the HITL review body. + + Aggregate groups and row-level groups are rendered in separate sections so + reviewers can immediately distinguish SQL-aggregate checks from per-row + validation logic. 
+ """ + aggregate_groups = [g for g in plan.groups if not any(c.row_level for c in g.checks)] + row_level_groups = [g for g in plan.groups if any(c.row_level for c in g.checks)] + + total_checks = len(plan.check_names) + agg_count = sum(len(g.checks) for g in aggregate_groups) + row_count = sum(len(g.checks) for g in row_level_groups) + + lines: list[str] = [ + "# LLM Data Quality Plan", + "", + "| | |", + "|---|---|", + f"| **Plan hash** | `{plan.plan_hash or 'N/A'}` |", + f"| **Total checks** | {total_checks} |", + f"| **Aggregate checks** | {agg_count} ({len(aggregate_groups)} group{'s' if len(aggregate_groups) != 1 else ''}) |", + f"| **Row-level checks** | {row_count} ({len(row_level_groups)} group{'s' if len(row_level_groups) != 1 else ''}) |", + "", + ] + + if aggregate_groups: + lines += [ + "---", + "", + "## Aggregate Checks", + "", + "> Each group runs as a **single SQL query**. " + "Result columns are matched to check names by metric key.", + "", + ] + for group in aggregate_groups: + lines += self._render_aggregate_group(group) + + if row_level_groups: + lines += [ + "---", + "", + "## Row-Level Checks", + "", + "> Row-level checks fetch **raw column values** and apply Python-side " + "validation per row. 
The threshold controls the maximum allowed fraction " + "of invalid rows before the check fails.", + "", + ] + for group in row_level_groups: + lines += self._render_row_level_group(group) + + return "\n".join(lines).rstrip() + + def _render_aggregate_group(self, group: DQCheckGroup) -> list[str]: + """Render one aggregate SQL group as a markdown subsection.""" + lines: list[str] = [ + f"### `{group.group_id}`", + "", + "| Check name | Metric key | Category |", + "|---|---|---|", + ] + for check in group.checks: + category = check.check_category or "—" + lines.append(f"| `{check.check_name}` | `{check.metric_key}` | {category} |") + + lines += [ + "", + "```sql", + group.query.strip(), + "```", + "", + ] + + # Unexpected queries — only show when present. + unexpected = [(c.check_name, c.unexpected_query) for c in group.checks if c.unexpected_query] + if unexpected: + lines += ["<details><summary>Unexpected-row queries</summary>", ""] + for check_name, uq in unexpected: + lines += [ + f"**`{check_name}`**", + "", + "```sql", + (uq or "").strip(), + "```", + "", + ] + lines += ["</details>", ""] + + return lines + + def _render_row_level_group(self, group: DQCheckGroup) -> list[str]: + """Render one row-level group as a markdown subsection with threshold info.""" + lines: list[str] = [ + f"### `{group.group_id}`", + "", + "| Check name | Metric key | Max invalid % |", + "|---|---|---|", + ] + for check in group.checks: + validator = self.validators.get(check.check_name) + max_pct = getattr(validator, "_max_invalid_pct", None) + threshold_str = f"{max_pct:.2%}" if max_pct is not None else "—" + lines.append(f"| `{check.check_name}` | `{check.metric_key}` | {threshold_str} |") + + lines += [ + "", + "```sql", + group.query.strip(), + "```", + "", + ] + return lines + + def _load_or_generate_plan(self, planner: SQLDQPlanner, schema_ctx: str) -> DQPlan: + """Return a cached plan when available, otherwise generate and cache a new one.""" + if not isinstance(self.prompts, 
dict): + raise TypeError("prompts must be a dict[str, str] before generating a DQ plan.") + + row_validator_thresholds = self._collect_row_validator_thresholds() + + plan_hash = _compute_plan_hash( + self.prompts, + self.prompt_version, + self.collect_unexpected, + self.row_level_sample_size, + schema_context=schema_ctx, + unexpected_sample_size=self.unexpected_sample_size, + validator_contexts=self.validator_contexts, + row_validator_thresholds=row_validator_thresholds, + ) + variable_key = f"{_PLAN_VARIABLE_PREFIX}{plan_hash}" + + cached_json = Variable.get(variable_key, None) + if cached_json is not None: + self.log.info("DQ plan cache hit — key: %r", variable_key) + plan = DQPlan.model_validate_json(cached_json) + if not plan.plan_hash: + plan.plan_hash = plan_hash + return plan + + self.log.info("DQ plan cache miss — generating via LLM (key: %r).", variable_key) + plan = planner.generate_plan(self.prompts, schema_ctx) + plan.plan_hash = plan_hash + Variable.set(variable_key, plan.model_dump_json()) + return plan + + def _validate_results( + self, + results_map: dict[str, Any], + plan: DQPlan, + ) -> list[DQCheckResult]: + """ + Apply validators to each metric value and return per-check results. + + For aggregate checks each validator callable receives the raw metric + value returned by the database. For row-level checks, where *value* is + a :class:`~airflow.providers.common.ai.utils.dq_models.RowLevelResult`, + the pass/fail decision compares ``invalid_pct`` against the validator's + ``_max_invalid_pct`` attribute (defaulting to ``0.0`` when absent). + Aggregate checks without a registered validator are logged and marked + as passed; row-level checks require an explicit validator and fail + when none is provided. + + :param results_map: ``{check_name: metric_value_or_RowLevelResult}`` as + returned by + :meth:`~airflow.providers.common.ai.utils.dq_planner.SQLDQPlanner.execute_plan`. + :param plan: The DQ plan whose groups and checks drive iteration order. 
+ :returns: Per-check + :class:`~airflow.providers.common.ai.utils.dq_models.DQCheckResult` + list in plan-group order. + :raises ValueError: If *results_map* is missing a key for any check in *plan*. + """ + check_results: list[DQCheckResult] = [] + + for group in plan.groups: + for check in group.checks: + if check.check_name not in results_map: + raise ValueError( + f"Planner did not return a result for check {check.check_name!r} " + f"(group {group.group_id!r}). Available keys: {sorted(results_map)}" + ) + value = results_map[check.check_name] + validator = self.validators.get(check.check_name) + + passed = True + failure_reason: str | None = None + + if isinstance(value, RowLevelResult): + if validator is None: + self.log.error( + "No validator found for row-level check %r (metric key: %r). " + "Row-level checks require an explicit validator.", + check.check_name, + check.metric_key, + ) + passed = False + failure_reason = ( + "Row-level check requires a registered row-level validator, " + "but none was provided." + ) + else: + if not hasattr(validator, "_max_invalid_pct"): + self.log.warning( + "Row-level validator for check %r has no '_max_invalid_pct' attribute — " + "defaulting threshold to 0.0%%. Every invalid row will fail the check.", + check.check_name, + ) + max_pct = getattr(validator, "_max_invalid_pct", 0.0) + passed = value.invalid_pct <= max_pct + if not passed: + failure_reason = ( + f"Row-level check failed: {value.invalid}/{value.total} rows invalid " + f"({value.invalid_pct:.4%}), threshold {max_pct:.4%}" + ) Review Comment: Row-level validation assumes `validator._max_invalid_pct` is numeric. If a user passes a non-numeric value (e.g. a string rendered via templating), `value.invalid_pct <= max_pct` will raise `TypeError` and fail the task with a low-signal exception. Consider validating/coercing `max_pct` to `float` and raising a clear `ValueError` that names the check when it's missing/invalid. Note that the missing case currently falls back to `0.0` with a warning (as the method docstring documents), so raising there would be a behavior change; at minimum, reject values such as `None` or strings that `float()` cannot convert.
-- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
