This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 33f81bfb93 Update snowflake naming for account names and locators. (#41775)
33f81bfb93 is described below
commit 33f81bfb93a25dfd190213c2a2aaa03958a0fb10
Author: Jakub Dardzinski <[email protected]>
AuthorDate: Tue Aug 27 10:41:29 2024 +0200
Update snowflake naming for account names and locators. (#41775)
Signed-off-by: Jakub Dardzinski <[email protected]>
---
airflow/providers/snowflake/utils/openlineage.py | 42 +++++++++++++++-------
.../providers/snowflake/utils/test_openlineage.py | 26 +++++++++++++-
2 files changed, 54 insertions(+), 14 deletions(-)
diff --git a/airflow/providers/snowflake/utils/openlineage.py b/airflow/providers/snowflake/utils/openlineage.py
index bbf1a90387..12821d49f3 100644
--- a/airflow/providers/snowflake/utils/openlineage.py
+++ b/airflow/providers/snowflake/utils/openlineage.py
@@ -21,16 +21,34 @@ from urllib.parse import quote, urlparse, urlunparse
def fix_account_name(name: str) -> str:
"""Fix account name to have the following format:
<account_id>.<region>.<cloud>."""
- spl = name.split(".")
- if len(spl) == 1:
- account = spl[0]
- region, cloud = "us-west-1", "aws"
- elif len(spl) == 2:
- account, region = spl
- cloud = "aws"
- else:
- account, region, cloud = spl
- return f"{account}.{region}.{cloud}"
+ if not any(word in name for word in ["-", "_"]):
+ # If there is neither '-' nor '_' in the name, we append `.us-west-1.aws`
+ return f"{name}.us-west-1.aws"
+
+ if "." in name:
+ # Logic for account locator with dots remains unchanged
+ spl = name.split(".")
+ if len(spl) == 1:
+ account = spl[0]
+ region, cloud = "us-west-1", "aws"
+ elif len(spl) == 2:
+ account, region = spl
+ cloud = "aws"
+ else:
+ account, region, cloud = spl
+ return f"{account}.{region}.{cloud}"
+
+ # Check for existing accounts with cloud names
+ if cloud := next((c for c in ["aws", "gcp", "azure"] if c in name), ""):
+ parts = name.split(cloud)
+ account = parts[0].strip("-_.")
+
+ if not (region := parts[1].strip("-_.").replace("_", "-")):
+ return name
+ return f"{account}.{region}.{cloud}"
+
+ # Default case, return the original name
+ return name
def fix_snowflake_sqlalchemy_uri(uri: str) -> str:
@@ -57,8 +75,6 @@ def fix_snowflake_sqlalchemy_uri(uri: str) -> str:
if not hostname:
return uri
- # old account identifier like xy123456
- if "." in hostname or not any(word in hostname for word in ["-", "_"]):
- hostname = fix_account_name(hostname)
+ hostname = fix_account_name(hostname)
# else - its new hostname, just return it
return urlunparse((parts.scheme, hostname, parts.path, parts.params, parts.query, parts.fragment))
diff --git a/tests/providers/snowflake/utils/test_openlineage.py b/tests/providers/snowflake/utils/test_openlineage.py
index a85ed9c2af..393341c5ff 100644
--- a/tests/providers/snowflake/utils/test_openlineage.py
+++ b/tests/providers/snowflake/utils/test_openlineage.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import pytest
-from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalchemy_uri
+from airflow.providers.snowflake.utils.openlineage import fix_account_name, fix_snowflake_sqlalchemy_uri
@pytest.mark.parametrize(
@@ -60,3 +60,27 @@ from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalche
)
def test_snowflake_sqlite_account_urls(source, target):
assert fix_snowflake_sqlalchemy_uri(source) == target
+
+
+# Unit Tests using pytest.mark.parametrize
+@pytest.mark.parametrize(
+ "name, expected",
+ [
+ ("xy12345", "xy12345.us-west-1.aws"), # No '-' or '_' in name
+ ("xy12345.us-west-1.aws", "xy12345.us-west-1.aws"), # Already complete locator
+ ("xy12345.us-west-2.gcp", "xy12345.us-west-2.gcp"), # Already complete locator for GCP
+ ("xy12345aws", "xy12345aws.us-west-1.aws"), # AWS without '-' or '_'
+ ("xy12345-aws", "xy12345-aws"), # AWS with '-'
+ ("xy12345_gcp-europe-west1", "xy12345.europe-west1.gcp"), # GCP with '_'
+ ("myaccount_gcp-asia-east1", "myaccount.asia-east1.gcp"), # GCP with region and '_'
+ ("myaccount_azure-eastus", "myaccount.eastus.azure"), # Azure with region
+ ("myorganization-1234", "myorganization-1234"), # No change needed
+ ("my.organization", "my.organization.us-west-1.aws"), # Dot in name
+ ],
+)
+def test_fix_account_name(name, expected):
+ assert fix_account_name(name) == expected
+ assert (
+ fix_snowflake_sqlalchemy_uri(f"snowflake://{name}/database/schema")
+ == f"snowflake://{expected}/database/schema"
+ )