This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 33f81bfb93 Update snowflake naming for account names and locators. (#41775)
33f81bfb93 is described below

commit 33f81bfb93a25dfd190213c2a2aaa03958a0fb10
Author: Jakub Dardzinski <[email protected]>
AuthorDate: Tue Aug 27 10:41:29 2024 +0200

    Update snowflake naming for account names and locators. (#41775)
    
    Signed-off-by: Jakub Dardzinski <[email protected]>
---
 airflow/providers/snowflake/utils/openlineage.py   | 42 +++++++++++++++-------
 .../providers/snowflake/utils/test_openlineage.py  | 26 +++++++++++++-
 2 files changed, 54 insertions(+), 14 deletions(-)

diff --git a/airflow/providers/snowflake/utils/openlineage.py b/airflow/providers/snowflake/utils/openlineage.py
index bbf1a90387..12821d49f3 100644
--- a/airflow/providers/snowflake/utils/openlineage.py
+++ b/airflow/providers/snowflake/utils/openlineage.py
@@ -21,16 +21,34 @@ from urllib.parse import quote, urlparse, urlunparse
 
 def fix_account_name(name: str) -> str:
     """Fix account name to have the following format: <account_id>.<region>.<cloud>."""
-    spl = name.split(".")
-    if len(spl) == 1:
-        account = spl[0]
-        region, cloud = "us-west-1", "aws"
-    elif len(spl) == 2:
-        account, region = spl
-        cloud = "aws"
-    else:
-        account, region, cloud = spl
-    return f"{account}.{region}.{cloud}"
+    if not any(word in name for word in ["-", "_"]):
+        # If there is neither '-' nor '_' in the name, we append `.us-west-1.aws`
+        return f"{name}.us-west-1.aws"
+
+    if "." in name:
+        # Logic for account locator with dots remains unchanged
+        spl = name.split(".")
+        if len(spl) == 1:
+            account = spl[0]
+            region, cloud = "us-west-1", "aws"
+        elif len(spl) == 2:
+            account, region = spl
+            cloud = "aws"
+        else:
+            account, region, cloud = spl
+        return f"{account}.{region}.{cloud}"
+
+    # Check for existing accounts with cloud names
+    if cloud := next((c for c in ["aws", "gcp", "azure"] if c in name), ""):
+        parts = name.split(cloud)
+        account = parts[0].strip("-_.")
+
+        if not (region := parts[1].strip("-_.").replace("_", "-")):
+            return name
+        return f"{account}.{region}.{cloud}"
+
+    # Default case, return the original name
+    return name
 
 
 def fix_snowflake_sqlalchemy_uri(uri: str) -> str:
@@ -57,8 +75,6 @@ def fix_snowflake_sqlalchemy_uri(uri: str) -> str:
     if not hostname:
         return uri
 
-    # old account identifier like xy123456
-    if "." in hostname or not any(word in hostname for word in ["-", "_"]):
-        hostname = fix_account_name(hostname)
+    hostname = fix_account_name(hostname)
     # else - its new hostname, just return it
     return urlunparse((parts.scheme, hostname, parts.path, parts.params, parts.query, parts.fragment))
diff --git a/tests/providers/snowflake/utils/test_openlineage.py b/tests/providers/snowflake/utils/test_openlineage.py
index a85ed9c2af..393341c5ff 100644
--- a/tests/providers/snowflake/utils/test_openlineage.py
+++ b/tests/providers/snowflake/utils/test_openlineage.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 import pytest
 
-from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalchemy_uri
+from airflow.providers.snowflake.utils.openlineage import fix_account_name, fix_snowflake_sqlalchemy_uri
 
 
 @pytest.mark.parametrize(
@@ -60,3 +60,27 @@ from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalche
 )
 def test_snowflake_sqlite_account_urls(source, target):
     assert fix_snowflake_sqlalchemy_uri(source) == target
+
+
+# Unit Tests using pytest.mark.parametrize
+@pytest.mark.parametrize(
+    "name, expected",
+    [
+        ("xy12345", "xy12345.us-west-1.aws"),  # No '-' or '_' in name
+        ("xy12345.us-west-1.aws", "xy12345.us-west-1.aws"),  # Already complete locator
+        ("xy12345.us-west-2.gcp", "xy12345.us-west-2.gcp"),  # Already complete locator for GCP
+        ("xy12345aws", "xy12345aws.us-west-1.aws"),  # AWS without '-' or '_'
+        ("xy12345-aws", "xy12345-aws"),  # AWS with '-'
+        ("xy12345_gcp-europe-west1", "xy12345.europe-west1.gcp"),  # GCP with '_'
+        ("myaccount_gcp-asia-east1", "myaccount.asia-east1.gcp"),  # GCP with region and '_'
+        ("myaccount_azure-eastus", "myaccount.eastus.azure"),  # Azure with region
+        ("myorganization-1234", "myorganization-1234"),  # No change needed
+        ("my.organization", "my.organization.us-west-1.aws"),  # Dot in name
+    ],
+)
+def test_fix_account_name(name, expected):
+    assert fix_account_name(name) == expected
+    assert (
+        fix_snowflake_sqlalchemy_uri(f"snowflake://{name}/database/schema")
+        == f"snowflake://{expected}/database/schema"
+    )

Reply via email to