bhirsz commented on code in PR #26368:
URL: https://github.com/apache/airflow/pull/26368#discussion_r981107404


##########
airflow/providers/google/cloud/operators/bigquery.py:
##########
@@ -520,6 +524,241 @@ def execute_complete(self, context: Context, event: 
dict[str, Any]) -> None:
         )
 
 
+class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, 
SQLColumnCheckOperator):
+    """
+    BigQueryColumnCheckOperator subclasses the SQLColumnCheckOperator
+    in order to provide a job id for OpenLineage to parse. See base class
+    docstring for usage.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:BigQueryColumnCheckOperator`
+
+    :param table: the table name
+    :param column_mapping: a dictionary relating columns to their checks
+    :param partition_clause: a string SQL statement added to a WHERE clause
+        to partition data
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google 
Cloud.
+    :param use_legacy_sql: Whether to use legacy SQL (true)
+        or standard SQL (false).
+    :param location: The geographic location of the job. See details at:
+        
https://cloud.google.com/bigquery/docs/locations#specifying_your_location
+    :param impersonation_chain: Optional service account to impersonate using 
short-term
+        credentials, or chained list of accounts required to get the 
access_token
+        of the last account in the list, which will be impersonated in the 
request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding 
identity, with first
+        account from the list granting this role to the originating account 
(templated).
+    :param labels: a dictionary containing labels for the table, passed to 
BigQuery
+    """
+
+    def __init__(
+        self,
+        *,
+        table: str,
+        column_mapping: dict,
+        partition_clause: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        use_legacy_sql: bool = True,
+        location: str | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
+        labels: dict | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            table=table, column_mapping=column_mapping, 
partition_clause=partition_clause, **kwargs
+        )
+        self.table = table
+        self.column_mapping = column_mapping
+        self.partition_clause = partition_clause
+        self.gcp_conn_id = gcp_conn_id
+        self.use_legacy_sql = use_legacy_sql
+        self.location = location
+        self.impersonation_chain = impersonation_chain
+        self.labels = labels
+        # OpenLineage needs a valid SQL query with the input/output table(s) 
to parse
+        self.sql = ""
+
+    def _submit_job(
+        self,
+        hook: BigQueryHook,
+        job_id: str,
+    ) -> BigQueryJob:
+        """Submit a new job and get the job id for polling the status using 
Trigger."""
+        configuration = {"query": {"query": self.sql}}
+
+        return hook.insert_job(
+            configuration=configuration,
+            project_id=hook.project_id,
+            location=self.location,
+            job_id=job_id,
+            nowait=False,
+        )
+
+    def execute(self, context=None):
+        """Perform checks on the given columns."""
+        hook = self.get_db_hook()
+        failed_tests = []
+        for column in self.column_mapping:
+            checks = [*self.column_mapping[column]]
+            checks_sql = ",".join([self.column_checks[check].replace("column", 
column) for check in checks])
+            partition_clause_statement = f"WHERE {self.partition_clause}" if 
self.partition_clause else ""
+            self.sql = f"SELECT {checks_sql} FROM {self.table} 
{partition_clause_statement};"
+
+            job_id = hook.generate_job_id(
+                dag_id=self.dag_id,
+                task_id=self.task_id,
+                logical_date=context["logical_date"],
+                configuration=self.configuration,
+            )
+            job = self._submit_job(hook, job_id=job_id)
+            context["ti"].xcom_push(key="job_id", value=job.job_id)
+            records = list(job.result().to_dataframe().values.flatten())
+
+            if not records:
+                raise AirflowException(f"The following query returned zero 
rows: {self.sql}")
+
+            self.log.info("Record: %s", records)
+
+            for idx, result in enumerate(records):
+                tolerance = 
self.column_mapping[column][checks[idx]].get("tolerance")
+
+                self.column_mapping[column][checks[idx]]["result"] = result
+                self.column_mapping[column][checks[idx]]["success"] = 
self._get_match(
+                    self.column_mapping[column][checks[idx]], result, tolerance
+                )
+
+            
failed_tests.extend(_get_failed_checks(self.column_mapping[column], column))
+        if failed_tests:
+            raise AirflowException(
+                f"Test failed.\nResults:\n{records!s}\n"
+                "The following tests have failed:"
+                f"\n{''.join(failed_tests)}"
+            )
+
+        self.log.info("All tests have passed")
+
+
+class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
+    """
+    BigQueryTableCheckOperator subclasses the SQLTableCheckOperator
+    in order to provide a job id for OpenLineage to parse. See base class
+    for usage.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:BigQueryTableCheckOperator`
+
+    :param table: the table name
+    :param checks: a dictionary of check names and boolean SQL statements
+    :param partition_clause: a string SQL statement added to a WHERE clause
+        to partition data
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google 
Cloud.
+    :param use_legacy_sql: Whether to use legacy SQL (true)
+        or standard SQL (false).
+    :param location: The geographic location of the job. See details at:
+        
https://cloud.google.com/bigquery/docs/locations#specifying_your_location
+    :param impersonation_chain: Optional service account to impersonate using 
short-term
+        credentials, or chained list of accounts required to get the 
access_token
+        of the last account in the list, which will be impersonated in the 
request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding 
identity, with first
+        account from the list granting this role to the originating account 
(templated).
+    :param labels: a dictionary containing labels for the table, passed to 
BigQuery
+    """
+
+    def __init__(
+        self,
+        *,
+        table: str,
+        checks: dict,
+        partition_clause: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        use_legacy_sql: bool = True,
+        location: str | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
+        labels: dict | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(table=table, checks=checks, 
partition_clause=partition_clause, **kwargs)
+        self.table = table
+        self.checks = checks
+        self.partition_clause = partition_clause
+        self.gcp_conn_id = gcp_conn_id
+        self.use_legacy_sql = use_legacy_sql
+        self.location = location
+        self.impersonation_chain = impersonation_chain
+        self.labels = labels
+        # OpenLineage needs a valid SQL query with the input/output table(s) 
to parse
+        self.sql = ""
+
+    def _submit_job(
+        self,
+        hook: BigQueryHook,
+        job_id: str,
+    ) -> BigQueryJob:
+        """Submit a new job and get the job id for polling the status using 
Trigger."""
+        configuration = {"query": {"query": self.sql}}
+
+        return hook.insert_job(
+            configuration=configuration,
+            project_id=hook.project_id,
+            location=self.location,
+            job_id=job_id,
+            nowait=False,
+        )
+
+    def execute(self, context=None):
+        """Execute the given checks on the table."""
+        hook = self.get_db_hook()
+        checks_sql = " UNION ALL ".join(
+            [
+                self.sql_check_template.replace("check_statement", 
value["check_statement"])
+                .replace("_check_name", check_name)
+                .replace("table", self.table)
+                for check_name, value in self.checks.items()
+            ]
+        )
+        partition_clause_statement = f"WHERE {self.partition_clause}" if 
self.partition_clause else ""
+        self.sql = f"SELECT check_name, check_result FROM ({checks_sql}) "
+        f"AS check_table {partition_clause_statement};"
+
+        job_id = hook.generate_job_id(
+            dag_id=self.dag_id,
+            task_id=self.task_id,
+            logical_date=context["logical_date"],
+            configuration=self.configuration,

Review Comment:
   `self.configuration` does not exist (or I can't see it). However, I see that 
BigQuery allows `job_id` to be passed as an empty string. In that case (check 
other operators), the `_custom_job_id` method will be invoked to set the job ID, 
so you probably don't need to call `generate_job_id`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to