This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3dc99d8a28 feat: Add openlineage support for CopyFromExternalStageToSnowflakeOperator (#36535)
3dc99d8a28 is described below
commit 3dc99d8a285aaadeb83797e691c9f6ec93ff9c93
Author: Kacper Muda <[email protected]>
AuthorDate: Mon Jan 8 13:02:46 2024 +0100
feat: Add openlineage support for CopyFromExternalStageToSnowflakeOperator (#36535)
---
.../snowflake/transfers/copy_into_snowflake.py | 163 +++++++++++++++++++-
.../transfers/test_copy_into_snowflake.py | 168 ++++++++++++++++++++-
2 files changed, 327 insertions(+), 4 deletions(-)
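
For context, a minimal usage sketch of the operator this commit extends (the connection id, stage, table, and file format below are illustrative, not part of this change):

    from airflow.providers.snowflake.transfers.copy_into_snowflake import (
        CopyFromExternalStageToSnowflakeOperator,
    )

    # Copies files from an external stage into a Snowflake table; with this
    # commit the operator also emits OpenLineage input/output datasets.
    copy_into_table = CopyFromExternalStageToSnowflakeOperator(
        task_id="copy_into_table",
        snowflake_conn_id="snowflake_default",
        table="my_table",
        schema="my_schema",
        stage="my_stage",
        file_format="(type = 'CSV')",
    )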
diff --git a/airflow/providers/snowflake/transfers/copy_into_snowflake.py b/airflow/providers/snowflake/transfers/copy_into_snowflake.py
index 10071add1a..342d5dc35a 100644
--- a/airflow/providers/snowflake/transfers/copy_into_snowflake.py
+++ b/airflow/providers/snowflake/transfers/copy_into_snowflake.py
@@ -108,8 +108,12 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
self.copy_options = copy_options
self.validation_mode = validation_mode
+ self.hook: SnowflakeHook | None = None
+ self._sql: str | None = None
+ self._result: list[dict[str, Any]] = []
+
def execute(self, context: Any) -> None:
- snowflake_hook = SnowflakeHook(
+ self.hook = SnowflakeHook(
snowflake_conn_id=self.snowflake_conn_id,
warehouse=self.warehouse,
database=self.database,
@@ -127,7 +131,7 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
if self.columns_array:
into = f"{into}({', '.join(self.columns_array)})"
- sql = f"""
+ self._sql = f"""
COPY INTO {into}
FROM @{self.stage}/{self.prefix or ""}
{"FILES=(" + ",".join(map(enclose_param, self.files)) + ")" if
self.files else ""}
@@ -137,5 +141,158 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
{self.validation_mode or ""}
"""
self.log.info("Executing COPY command...")
- snowflake_hook.run(sql=sql, autocommit=self.autocommit)
+ self._result = self.hook.run( # type: ignore # mypy does not work well with return_dictionaries=True
+ sql=self._sql,
+ autocommit=self.autocommit,
+ handler=lambda x: x.fetchall(),
+ return_dictionaries=True,
+ )
self.log.info("COPY command completed")
+
+ @staticmethod
+ def _extract_openlineage_unique_dataset_paths(
+ query_result: list[dict[str, Any]],
+ ) -> tuple[list[tuple[str, str]], list[str]]:
+ """Extracts and returns unique OpenLineage dataset paths and file
paths that failed to be parsed.
+
+ Each row in the results is expected to have a 'file' field, which is a URI.
+ The function parses these URIs and constructs a set of unique OpenLineage (namespace, name) tuples.
+ Additionally, it captures any URIs that cannot be parsed or processed
+ and returns them in a separate error list.
+
+ For Azure, Snowflake has a unique way of representing URIs:
+ azure://<account_name>.blob.core.windows.net/<container_name>/path/to/file.csv
+ that is transformed by this function to a Dataset with a more universal naming convention:
+ Dataset(namespace="wasbs://container_name@account_name", name="path/to"), as described at
+ https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md#wasbs-azure-blob-storage
+
+ :param query_result: A list of dictionaries, each containing a 'file' key with a URI value.
+ :return: Two lists - the first is a sorted list of tuples, each representing a unique dataset path,
+ and the second contains any URIs that cannot be parsed or processed correctly.
+
+ >>> method = CopyFromExternalStageToSnowflakeOperator._extract_openlineage_unique_dataset_paths
+
+ >>> results = [{"file":
"azure://my_account.blob.core.windows.net/azure_container/dir3/file.csv"}]
+ >>> method(results)
+ ([('wasbs://azure_container@my_account', 'dir3')], [])
+
+ >>> results = [{"file":
"azure://my_account.blob.core.windows.net/azure_container"}]
+ >>> method(results)
+ ([('wasbs://azure_container@my_account', '/')], [])
+
+ >>> results = [{"file": "s3://bucket"}, {"file": "gcs://bucket/"},
{"file": "s3://bucket/a.csv"}]
+ >>> method(results)
+ ([('gcs://bucket', '/'), ('s3://bucket', '/')], [])
+
+ >>> results = [{"file": "s3://bucket/dir/file.csv"}, {"file":
"gcs://bucket/dir/dir2/a.txt"}]
+ >>> method(results)
+ ([('gcs://bucket', 'dir/dir2'), ('s3://bucket', 'dir')], [])
+
+ >>> results = [
+ ... {"file": "s3://bucket/dir/file.csv"},
+ ... {"file":
"azure://my_account.something_new.windows.net/azure_container"},
+ ... ]
+ >>> method(results)
+ ([('s3://bucket', 'dir')], ['azure://my_account.something_new.windows.net/azure_container'])
+ """
+ import re
+ from pathlib import Path
+ from urllib.parse import urlparse
+
+ azure_regex = r"azure:\/\/(\w+)?\.blob.core.windows.net\/(\w+)\/?(.*)?"
+ extraction_error_files = []
+ unique_dataset_paths = set()
+
+ for row in query_result:
+ uri = urlparse(row["file"])
+ if uri.scheme == "azure":
+ match = re.fullmatch(azure_regex, row["file"])
+ if not match:
+ extraction_error_files.append(row["file"])
+ continue
+ account_name, container_name, name = match.groups()
+ namespace = f"wasbs://{container_name}@{account_name}"
+ else:
+ namespace = f"{uri.scheme}://{uri.netloc}"
+ name = uri.path.lstrip("/")
+
+ name = Path(name).parent.as_posix()
+ if name in ("", "."):
+ name = "/"
+
+ unique_dataset_paths.add((namespace, name))
+
+ return sorted(unique_dataset_paths), sorted(extraction_error_files)
+
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """Implement _on_complete because we rely on return value of a
query."""
+ import re
+
+ from openlineage.client.facet import (
+ ExternalQueryRunFacet,
+ ExtractionError,
+ ExtractionErrorRunFacet,
+ SqlJobFacet,
+ )
+ from openlineage.client.run import Dataset
+
+ from airflow.providers.openlineage.extractors import OperatorLineage
+ from airflow.providers.openlineage.sqlparser import SQLParser
+
+ if not self._sql:
+ return OperatorLineage()
+
+ query_results = self._result or []
+ # If no files were uploaded we get [{"status": "0 files were uploaded..."}]
+ if len(query_results) == 1 and query_results[0].get("status"):
+ query_results = []
+ unique_dataset_paths, extraction_error_files = self._extract_openlineage_unique_dataset_paths(
+ query_results
+ )
+ input_datasets = [Dataset(namespace=namespace, name=name) for namespace, name in unique_dataset_paths]
+
+ run_facets = {}
+ if extraction_error_files:
+ self.log.debug(
+ f"Unable to extract Dataset namespace and name "
+ f"for the following files: `{extraction_error_files}`."
+ )
+ run_facets["extractionError"] = ExtractionErrorRunFacet(
+ totalTasks=len(query_results),
+ failedTasks=len(extraction_error_files),
+ errors=[
+ ExtractionError(
+ errorMessage="Unable to extract Dataset namespace and
name.",
+ stackTrace=None,
+ task=file_uri,
+ taskNumber=None,
+ )
+ for file_uri in extraction_error_files
+ ],
+ )
+
+ connection = self.hook.get_connection(getattr(self.hook, str(self.hook.conn_name_attr)))
+ database_info = self.hook.get_openlineage_database_info(connection)
+
+ dest_name = self.table
+ schema = self.hook.get_openlineage_default_schema()
+ database = database_info.database
+ if schema:
+ dest_name = f"{schema}.{dest_name}"
+ if database:
+ dest_name = f"{database}.{dest_name}"
+
+ snowflake_namespace = SQLParser.create_namespace(database_info)
+ query = SQLParser.normalize_sql(self._sql)
+ query = re.sub(r"\n+", "\n", re.sub(r" +", " ", query))
+
+ run_facets["externalQuery"] = ExternalQueryRunFacet(
+ externalQueryId=self.hook.query_ids[0], source=snowflake_namespace
+ )
+
+ return OperatorLineage(
+ inputs=input_datasets,
+ outputs=[Dataset(namespace=snowflake_namespace, name=dest_name)],
+ job_facets={"sql": SqlJobFacet(query=query)},
+ run_facets=run_facets,
+ )
diff --git a/tests/providers/snowflake/transfers/test_copy_into_snowflake.py b/tests/providers/snowflake/transfers/test_copy_into_snowflake.py
index 76268d077d..27e02dc41c 100644
--- a/tests/providers/snowflake/transfers/test_copy_into_snowflake.py
+++ b/tests/providers/snowflake/transfers/test_copy_into_snowflake.py
@@ -16,8 +16,20 @@
# under the License.
from __future__ import annotations
+from typing import Callable
from unittest import mock
+from openlineage.client.facet import (
+ ExternalQueryRunFacet,
+ ExtractionError,
+ ExtractionErrorRunFacet,
+ SqlJobFacet,
+)
+from openlineage.client.run import Dataset
+from pytest import mark
+
+from airflow.providers.openlineage.extractors import OperatorLineage
+from airflow.providers.openlineage.sqlparser import DatabaseInfo
from airflow.providers.snowflake.transfers.copy_into_snowflake import CopyFromExternalStageToSnowflakeOperator
@@ -62,4 +74,158 @@ class TestCopyFromExternalStageToSnowflake:
validation_mode
"""
- mock_hook.return_value.run.assert_called_once_with(sql=sql, autocommit=True)
+ mock_hook.return_value.run.assert_called_once_with(
+ sql=sql, autocommit=True, return_dictionaries=True, handler=mock.ANY
+ )
+
+ handler = mock_hook.return_value.run.mock_calls[0].kwargs.get("handler")
+ assert isinstance(handler, Callable)
+
+ @mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook")
+ def test_get_openlineage_facets_on_complete(self, mock_hook):
+ mock_hook().run.return_value = [
+ {"file": "s3://aws_bucket_name/dir1/file.csv"},
+ {"file": "s3://aws_bucket_name_2"},
+ {"file": "gcs://gcs_bucket_name/dir2/file.csv"},
+ {"file": "gcs://gcs_bucket_name_2"},
+ {"file":
"azure://my_account.blob.core.windows.net/azure_container/dir3/file.csv"},
+ {"file":
"azure://my_account.blob.core.windows.net/azure_container_2"},
+ ]
+ mock_hook().get_openlineage_database_info.return_value = DatabaseInfo(
+ scheme="snowflake_scheme", authority="authority", database="actual_database"
+ )
+ mock_hook().get_openlineage_default_schema.return_value = "actual_schema"
+ mock_hook().query_ids = ["query_id_123"]
+
+ expected_inputs = [
+ Dataset(namespace="gcs://gcs_bucket_name", name="dir2"),
+ Dataset(namespace="gcs://gcs_bucket_name_2", name="/"),
+ Dataset(namespace="s3://aws_bucket_name", name="dir1"),
+ Dataset(namespace="s3://aws_bucket_name_2", name="/"),
+ Dataset(namespace="wasbs://azure_container@my_account",
name="dir3"),
+ Dataset(namespace="wasbs://azure_container_2@my_account",
name="/"),
+ ]
+ expected_outputs = [
+ Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table")
+ ]
+ expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV"""
+
+ op = CopyFromExternalStageToSnowflakeOperator(
+ task_id="test",
+ table="table",
+ stage="stage",
+ database="",
+ schema="schema",
+ file_format="CSV",
+ )
+ op.execute(None)
+ result = op.get_openlineage_facets_on_complete(None)
+ assert result == OperatorLineage(
+ inputs=expected_inputs,
+ outputs=expected_outputs,
+ run_facets={
+ "externalQuery": ExternalQueryRunFacet(
+ externalQueryId="query_id_123",
source="snowflake_scheme://authority"
+ )
+ },
+ job_facets={"sql": SqlJobFacet(query=expected_sql)},
+ )
+
+ @mark.parametrize("rows", (None, []))
+ @mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook")
+ def test_get_openlineage_facets_on_complete_with_empty_inputs(self, mock_hook, rows):
+ mock_hook().run.return_value = rows
+ mock_hook().get_openlineage_database_info.return_value = DatabaseInfo(
+ scheme="snowflake_scheme", authority="authority", database="actual_database"
+ )
+ mock_hook().get_openlineage_default_schema.return_value = "actual_schema"
+ mock_hook().query_ids = ["query_id_123"]
+
+ expected_outputs = [
+ Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table")
+ ]
+ expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV"""
+
+ op = CopyFromExternalStageToSnowflakeOperator(
+ task_id="test",
+ table="table",
+ stage="stage",
+ database="",
+ schema="schema",
+ file_format="CSV",
+ )
+ op.execute(None)
+ result = op.get_openlineage_facets_on_complete(None)
+ assert result == OperatorLineage(
+ inputs=[],
+ outputs=expected_outputs,
+ run_facets={
+ "externalQuery": ExternalQueryRunFacet(
+ externalQueryId="query_id_123",
source="snowflake_scheme://authority"
+ )
+ },
+ job_facets={"sql": SqlJobFacet(query=expected_sql)},
+ )
+
+ @mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook")
+ def test_get_openlineage_facets_on_complete_unsupported_azure_uri(self, mock_hook):
+ mock_hook().run.return_value = [
+ {"file": "s3://aws_bucket_name/dir1/file.csv"},
+ {"file": "gs://gcp_bucket_name/dir2/file.csv"},
+ {"file":
"azure://my_account.weird-url.net/azure_container/dir3/file.csv"},
+ {"file": "azure://my_account.another_weird-url.net/con/file.csv"},
+ ]
+ mock_hook().get_openlineage_database_info.return_value = DatabaseInfo(
+ scheme="snowflake_scheme", authority="authority", database="actual_database"
+ )
+ mock_hook().get_openlineage_default_schema.return_value = "actual_schema"
+ mock_hook().query_ids = ["query_id_123"]
+
+ expected_inputs = [
+ Dataset(namespace="gs://gcp_bucket_name", name="dir2"),
+ Dataset(namespace="s3://aws_bucket_name", name="dir1"),
+ ]
+ expected_outputs = [
+ Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table")
+ ]
+ expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV"""
+ expected_run_facets = {
+ "extractionError": ExtractionErrorRunFacet(
+ totalTasks=4,
+ failedTasks=2,
+ errors=[
+ ExtractionError(
+ errorMessage="Unable to extract Dataset namespace and name.",
+ stackTrace=None,
+ task="azure://my_account.another_weird-url.net/con/file.csv",
+ taskNumber=None,
+ ),
+ ExtractionError(
+ errorMessage="Unable to extract Dataset namespace and name.",
+ stackTrace=None,
+ task="azure://my_account.weird-url.net/azure_container/dir3/file.csv",
+ taskNumber=None,
+ ),
+ ],
+ ),
+ "externalQuery": ExternalQueryRunFacet(
+ externalQueryId="query_id_123",
source="snowflake_scheme://authority"
+ ),
+ }
+
+ op = CopyFromExternalStageToSnowflakeOperator(
+ task_id="test",
+ table="table",
+ stage="stage",
+ database="",
+ schema="schema",
+ file_format="CSV",
+ )
+ op.execute(None)
+ result = op.get_openlineage_facets_on_complete(None)
+ assert result == OperatorLineage(
+ inputs=expected_inputs,
+ outputs=expected_outputs,
+ run_facets=expected_run_facets,
+ job_facets={"sql": SqlJobFacet(query=expected_sql)},
+ )
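
The new tests can be run locally with, for example (assuming a dev environment with the openlineage provider dependencies installed):

    pytest tests/providers/snowflake/transfers/test_copy_into_snowflake.py -k openlineage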