bhirsz commented on code in PR #26368:
URL: https://github.com/apache/airflow/pull/26368#discussion_r981108819
##########
airflow/providers/google/cloud/operators/bigquery.py:
##########
@@ -520,6 +524,241 @@ def execute_complete(self, context: Context, event:
dict[str, Any]) -> None:
)
class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, SQLColumnCheckOperator):
    """
    BigQueryColumnCheckOperator subclasses the SQLColumnCheckOperator
    in order to provide a job id for OpenLineage to parse. See base class
    docstring for usage.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BigQueryColumnCheckOperator`

    :param table: the table name
    :param column_mapping: a dictionary relating columns to their checks
    :param partition_clause: a string SQL statement added to a WHERE clause
        to partition data
    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
    :param use_legacy_sql: Whether to use legacy SQL (true)
        or standard SQL (false). Sent to BigQuery as the query job's
        ``useLegacySql`` setting.
    :param location: The geographic location of the job. See details at:
        https://cloud.google.com/bigquery/docs/locations#specifying_your_location
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :param labels: a dictionary containing labels for the table, passed to BigQuery
    """

    def __init__(
        self,
        *,
        table: str,
        column_mapping: dict,
        partition_clause: str | None = None,
        gcp_conn_id: str = "google_cloud_default",
        use_legacy_sql: bool = True,
        location: str | None = None,
        impersonation_chain: str | Sequence[str] | None = None,
        labels: dict | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            table=table, column_mapping=column_mapping, partition_clause=partition_clause, **kwargs
        )
        self.table = table
        self.column_mapping = column_mapping
        self.partition_clause = partition_clause
        self.gcp_conn_id = gcp_conn_id
        self.use_legacy_sql = use_legacy_sql
        self.location = location
        self.impersonation_chain = impersonation_chain
        self.labels = labels
        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
        self.sql = ""

    def _job_configuration(self) -> dict:
        """Return the BigQuery query-job configuration for the SQL currently in ``self.sql``."""
        # ``useLegacySql`` is sent explicitly; otherwise this operator's
        # ``use_legacy_sql`` argument would never reach BigQuery.
        return {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}}

    def _submit_job(
        self,
        hook: BigQueryHook,
        job_id: str,
    ) -> BigQueryJob:
        """Submit a new job and get the job id for polling the status using Trigger."""
        return hook.insert_job(
            configuration=self._job_configuration(),
            project_id=hook.project_id,
            location=self.location,
            job_id=job_id,
            nowait=False,
        )

    def execute(self, context=None):
        """
        Perform checks on the given columns.

        One BigQuery query job is submitted per column in ``column_mapping``.
        Each check's ``result`` and ``success`` keys are written back into that
        mapping, and the job id of every submitted job is pushed to XCom under
        the key ``job_id``.

        :param context: the Airflow task context; ``logical_date`` and ``ti`` are read from it.
        :raises AirflowException: if a query returns zero rows or any check fails.
        """
        hook = self.get_db_hook()
        failed_tests = []
        # The partition filter is identical for every column, so build it once.
        partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else ""
        for column, column_checks in self.column_mapping.items():
            checks = list(column_checks)
            checks_sql = ",".join(self.column_checks[check].replace("column", column) for check in checks)
            self.sql = f"SELECT {checks_sql} FROM {self.table} {partition_clause_statement};"

            job_id = hook.generate_job_id(
                dag_id=self.dag_id,
                task_id=self.task_id,
                logical_date=context["logical_date"],
                # Same configuration that _submit_job sends, so the id is generated
                # from the query actually being run (presumably hashed into the id —
                # confirm against BigQueryHook.generate_job_id).
                configuration=self._job_configuration(),
            )
            job = self._submit_job(hook, job_id=job_id)
            context["ti"].xcom_push(key="job_id", value=job.job_id)
            # The SELECT yields one aggregate value per check, in the same order as ``checks``.
            records = list(job.result().to_dataframe().values.flatten())

            if not records:
                raise AirflowException(f"The following query returned zero rows: {self.sql}")

            self.log.info("Record: %s", records)

            for idx, result in enumerate(records):
                check_values = column_checks[checks[idx]]
                tolerance = check_values.get("tolerance")
                check_values["result"] = result
                check_values["success"] = self._get_match(check_values, result, tolerance)

            failed_tests.extend(_get_failed_checks(column_checks, column))
        if failed_tests:
            raise AirflowException(
                f"Test failed.\nResults:\n{records!s}\n"
                "The following tests have failed:"
                f"\n{''.join(failed_tests)}"
            )

        self.log.info("All tests have passed")
+
+
+class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
+ """
+ BigQueryTableCheckOperator subclasses the SQLTableCheckOperator
+ in order to provide a job id for OpenLineage to parse. See base class
+ for usage.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:BigQueryTableCheckOperator`
+
+ :param table: the table name
+ :param checks: a dictionary of check names and boolean SQL statements
+ :param partition_clause: a string SQL statement added to a WHERE clause
+ to partition data
+ :param gcp_conn_id: (Optional) The connection ID used to connect to Google
Cloud.
+ :param use_legacy_sql: Whether to use legacy SQL (true)
+ or standard SQL (false).
+ :param location: The geographic location of the job. See details at:
+
https://cloud.google.com/bigquery/docs/locations#specifying_your_location
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ :param labels: a dictionary containing labels for the table, passed to
BigQuery
+ """
+
+ def __init__(
+ self,
+ *,
+ table: str,
+ checks: dict,
+ partition_clause: str | None = None,
+ gcp_conn_id: str = "google_cloud_default",
+ use_legacy_sql: bool = True,
+ location: str | None = None,
+ impersonation_chain: str | Sequence[str] | None = None,
+ labels: dict | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(table=table, checks=checks,
partition_clause=partition_clause, **kwargs)
+ self.table = table
+ self.checks = checks
+ self.partition_clause = partition_clause
+ self.gcp_conn_id = gcp_conn_id
+ self.use_legacy_sql = use_legacy_sql
+ self.location = location
+ self.impersonation_chain = impersonation_chain
+ self.labels = labels
+ # OpenLineage needs a valid SQL query with the input/output table(s)
to parse
+ self.sql = ""
+
+ def _submit_job(
+ self,
+ hook: BigQueryHook,
+ job_id: str,
+ ) -> BigQueryJob:
+ """Submit a new job and get the job id for polling the status using
Trigger."""
+ configuration = {"query": {"query": self.sql}}
+
+ return hook.insert_job(
+ configuration=configuration,
+ project_id=hook.project_id,
+ location=self.location,
+ job_id=job_id,
+ nowait=False,
+ )
+
+ def execute(self, context=None):
+ """Execute the given checks on the table."""
+ hook = self.get_db_hook()
+ checks_sql = " UNION ALL ".join(
+ [
+ self.sql_check_template.replace("check_statement",
value["check_statement"])
+ .replace("_check_name", check_name)
+ .replace("table", self.table)
+ for check_name, value in self.checks.items()
+ ]
+ )
+ partition_clause_statement = f"WHERE {self.partition_clause}" if
self.partition_clause else ""
+ self.sql = f"SELECT check_name, check_result FROM ({checks_sql}) "
+ f"AS check_table {partition_clause_statement};"
+
+ job_id = hook.generate_job_id(
+ dag_id=self.dag_id,
+ task_id=self.task_id,
+ logical_date=context["logical_date"],
+ configuration=self.configuration,
+ )
+ job = self._submit_job(hook, job_id=job_id)
+ context["ti"].xcom_push(key="job_id", value=job.job_id)
+ records = job.result().to_dataframe()
+
+ if records.empty:
+ raise AirflowException(f"The following query returned zero rows:
{self.sql}")
+
+ records.columns = records.columns.str.lower()
+ self.log.info("Record:\n%s", records)
+
+ for row in records.iterrows():
Review Comment:
I don't know what type `records` is, but you could try unpacking each row for more
clarity instead of indexing with `row[1]`:
```
for _, columns in records.iterrows():
    check = columns.get("check_name")
    ...
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]