This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch snowflake-openlineage-dontuseexternalconnection
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit e94fb0809b59ed37ab0dc1400b9ab57a9413c38a
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Thu Apr 18 00:21:18 2024 +0200

    openlineage, snowflake: do not run queries for Snowflake
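    
    Instead of querying Snowflake's information_schema for table schemas,
    build the OpenLineage input/output datasets directly from the parsed
    SQL. Roughly, as a simplified sketch of the selection logic (not the
    exact provider code):
    
        if should_use_external_connection(hook):  # False for Snowflake hooks
            inputs, outputs = parser.parse_table_schemas(...)       # runs queries
        else:
            inputs, outputs = parser.get_metadata_from_parser(...)  # no queries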
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/providers/common/sql/operators/sql.py      |  9 +++
 airflow/providers/openlineage/sqlparser.py         | 61 +++++++++++++++----
 airflow/providers/openlineage/utils/utils.py       |  5 ++
 airflow/providers/snowflake/hooks/snowflake.py     | 24 ++------
 .../providers/snowflake/hooks/snowflake_sql_api.py |  8 +--
 tests/providers/snowflake/hooks/test_snowflake.py  | 37 +++++-------
 .../snowflake/hooks/test_snowflake_sql_api.py      | 32 +++++++---
 .../snowflake/operators/test_snowflake_sql.py      | 69 ++++------------------
 8 files changed, 124 insertions(+), 121 deletions(-)

diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index 1fd22b86b7..ea791992d5 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -309,6 +309,14 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
 
         hook = self.get_db_hook()
 
+        try:
+            from airflow.providers.openlineage.utils.utils import should_use_external_connection
+
+            use_external_connection = should_use_external_connection(hook)
+        except ImportError:
+            # OpenLineage provider release < 1.8.0 - we always use the connection
+            use_external_connection = True
+
         connection = hook.get_connection(getattr(hook, hook.conn_name_attr))
         try:
             database_info = hook.get_openlineage_database_info(connection)
@@ -334,6 +342,7 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
             database_info=database_info,
             database=self.database,
             sqlalchemy_engine=hook.get_sqlalchemy_engine(),
+            use_connection=use_external_connection,
         )
 
         return operator_lineage
diff --git a/airflow/providers/openlineage/sqlparser.py b/airflow/providers/openlineage/sqlparser.py
index c27dedc53c..f181ff8cce 100644
--- a/airflow/providers/openlineage/sqlparser.py
+++ b/airflow/providers/openlineage/sqlparser.py
@@ -29,6 +29,7 @@ from openlineage.client.facet import (
     ExtractionErrorRunFacet,
     SqlJobFacet,
 )
+from openlineage.client.run import Dataset
 from openlineage.common.sql import DbTableMeta, SqlMeta, parse
 
 from airflow.providers.openlineage.extractors.base import OperatorLineage
@@ -40,7 +41,6 @@ from airflow.providers.openlineage.utils.sql import (
 from airflow.typing_compat import TypedDict
 
 if TYPE_CHECKING:
-    from openlineage.client.run import Dataset
     from sqlalchemy.engine import Engine
 
     from airflow.hooks.base import BaseHook
@@ -104,6 +104,18 @@ class DatabaseInfo:
     normalize_name_method: Callable[[str], str] = default_normalize_name_method
 
 
+def from_table_meta(
+    table_meta: DbTableMeta, database: str | None, namespace: str, is_uppercase: bool
+) -> Dataset:
+    if table_meta.database:
+        name = table_meta.qualified_name
+    elif database:
+        name = f"{database}.{table_meta.schema}.{table_meta.name}"
+    else:
+        name = f"{table_meta.schema}.{table_meta.name}"
+    return Dataset(namespace=namespace, name=name if not is_uppercase else name.upper())
+
+
 class SQLParser:
     """Interface for openlineage-sql.
 
@@ -117,7 +129,7 @@ class SQLParser:
 
     def parse(self, sql: list[str] | str) -> SqlMeta | None:
         """Parse a single or a list of SQL statements."""
-        return parse(sql=sql, dialect=self.dialect)
+        return parse(sql=sql, dialect=self.dialect, default_schema=self.default_schema)
 
     def parse_table_schemas(
         self,
@@ -156,6 +168,23 @@ class SQLParser:
             else None,
         )
 
+    def get_metadata_from_parser(
+        self,
+        inputs: list[DbTableMeta],
+        outputs: list[DbTableMeta],
+        database_info: DatabaseInfo,
+        namespace: str = DEFAULT_NAMESPACE,
+        database: str | None = None,
+    ) -> tuple[list[Dataset], ...]:
+        database = database if database else database_info.database
+        return [
+            from_table_meta(dataset, database, namespace, database_info.is_uppercase_names)
+            for dataset in inputs
+        ], [
+            from_table_meta(dataset, database, namespace, database_info.is_uppercase_names)
+            for dataset in outputs
+        ]
+
     def attach_column_lineage(
         self, datasets: list[Dataset], database: str | None, parse_result: SqlMeta
     ) -> None:
@@ -204,6 +233,7 @@ class SQLParser:
         database_info: DatabaseInfo,
         database: str | None = None,
         sqlalchemy_engine: Engine | None = None,
+        use_connection: bool = True,
     ) -> OperatorLineage:
         """Parse SQL statement(s) and generate OpenLineage metadata.
 
@@ -242,15 +272,24 @@ class SQLParser:
             )
 
         namespace = self.create_namespace(database_info=database_info)
-        inputs, outputs = self.parse_table_schemas(
-            hook=hook,
-            inputs=parse_result.in_tables,
-            outputs=parse_result.out_tables,
-            namespace=namespace,
-            database=database,
-            database_info=database_info,
-            sqlalchemy_engine=sqlalchemy_engine,
-        )
+        if use_connection:
+            inputs, outputs = self.parse_table_schemas(
+                hook=hook,
+                inputs=parse_result.in_tables,
+                outputs=parse_result.out_tables,
+                namespace=namespace,
+                database=database,
+                database_info=database_info,
+                sqlalchemy_engine=sqlalchemy_engine,
+            )
+        else:
+            inputs, outputs = self.get_metadata_from_parser(
+                inputs=parse_result.in_tables,
+                outputs=parse_result.out_tables,
+                namespace=namespace,
+                database=database,
+                database_info=database_info,
+            )
 
         self.attach_column_lineage(outputs, database or database_info.database, parse_result)
 
diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py
index 1c777aff76..ad1f3b0951 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -384,3 +384,8 @@ def normalize_sql(sql: str | Iterable[str]):
         sql = [stmt for stmt in sql.split(";") if stmt != ""]
     sql = [obj for stmt in sql for obj in stmt.split(";") if obj != ""]
     return ";\n".join(sql)
+
+
+def should_use_external_connection(hook) -> bool:
+    # TODO: Add a check for overrides
+    return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook"]
diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py
index 075ae21e30..a9b3ee5209 100644
--- a/airflow/providers/snowflake/hooks/snowflake.py
+++ b/airflow/providers/snowflake/hooks/snowflake.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import os
 from contextlib import closing, contextmanager
+from functools import cached_property
 from io import StringIO
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload
@@ -177,6 +178,7 @@ class SnowflakeHook(DbApiHook):
             return extra_dict[field_name] or None
         return extra_dict.get(backcompat_key) or None
 
+    @cached_property
     def _get_conn_params(self) -> dict[str, str | None]:
         """Fetch connection params as a dict.
 
@@ -269,7 +271,7 @@ class SnowflakeHook(DbApiHook):
 
     def get_uri(self) -> str:
         """Override DbApiHook get_uri method for get_sqlalchemy_engine()."""
-        conn_params = self._get_conn_params()
+        conn_params = self._get_conn_params
         return self._conn_params_to_sqlalchemy_uri(conn_params)
 
     def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str:
@@ -283,7 +285,7 @@ class SnowflakeHook(DbApiHook):
 
     def get_conn(self) -> SnowflakeConnection:
         """Return a snowflake.connection object."""
-        conn_config = self._get_conn_params()
+        conn_config = self._get_conn_params
         conn = connector.connect(**conn_config)
         return conn
 
@@ -294,7 +296,7 @@ class SnowflakeHook(DbApiHook):
         :return: the created engine.
         """
         engine_kwargs = engine_kwargs or {}
-        conn_params = self._get_conn_params()
+        conn_params = self._get_conn_params
         if "insecure_mode" in conn_params:
             engine_kwargs.setdefault("connect_args", {})
             engine_kwargs["connect_args"]["insecure_mode"] = True
@@ -458,21 +460,7 @@ class SnowflakeHook(DbApiHook):
         return "snowflake"
 
     def get_openlineage_default_schema(self) -> str | None:
-        """
-        Attempt to get current schema.
-
-        Usually ``SELECT CURRENT_SCHEMA();`` should work.
-        However, apparently you may set ``database`` without ``schema``
-        and get results from ``SELECT CURRENT_SCHEMAS();`` but not
-        from ``SELECT CURRENT_SCHEMA();``.
-        It still may return nothing if no database is set in connection.
-        """
-        schema = self._get_conn_params()["schema"]
-        if not schema:
-            current_schemas = self.get_first("SELECT PARSE_JSON(CURRENT_SCHEMAS())[0]::string;")[0]
-            if current_schemas:
-                _, schema = current_schemas.split(".")
-        return schema
+        return self._get_conn_params["schema"]
 
     def _get_openlineage_authority(self, _) -> str:
         from openlineage.common.provider.snowflake import fix_snowflake_sqlalchemy_uri
diff --git a/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
index 6eec055eb5..3f52a43a1d 100644
--- a/airflow/providers/snowflake/hooks/snowflake_sql_api.py
+++ b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -86,7 +86,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
     @property
     def account_identifier(self) -> str:
         """Returns snowflake account identifier."""
-        conn_config = self._get_conn_params()
+        conn_config = self._get_conn_params
         account_identifier = f"https://{conn_config['account']}"
 
         if conn_config["region"]:
@@ -147,7 +147,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
             When executing the statement, Snowflake replaces placeholders (? and :name) in
             the statement with these specified values.
         """
-        conn_config = self._get_conn_params()
+        conn_config = self._get_conn_params
 
         req_id = uuid.uuid4()
         url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements"
@@ -186,7 +186,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
 
     def get_headers(self) -> dict[str, Any]:
         """Form auth headers based on either OAuth token or JWT token from 
private key."""
-        conn_config = self._get_conn_params()
+        conn_config = self._get_conn_params
 
         # Use OAuth if refresh_token and client_id and client_secret are provided
         if all(
@@ -225,7 +225,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
 
     def get_oauth_token(self) -> str:
         """Generate temporary OAuth access token using refresh token in 
connection details."""
-        conn_config = self._get_conn_params()
+        conn_config = self._get_conn_params
         url = f"{self.account_identifier}.snowflakecomputing.com/oauth/token-request"
         data = {
             "grant_type": "refresh_token",
diff --git a/tests/providers/snowflake/hooks/test_snowflake.py b/tests/providers/snowflake/hooks/test_snowflake.py
index fb5b1a5514..54a18eeca7 100644
--- a/tests/providers/snowflake/hooks/test_snowflake.py
+++ b/tests/providers/snowflake/hooks/test_snowflake.py
@@ -270,7 +270,7 @@ class TestPytestSnowflakeHook:
     ):
         with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
             assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() == expected_uri
-            assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == expected_conn_params
+            assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params == expected_conn_params
 
     def test_get_conn_params_should_support_private_auth_in_connection(
         self, encrypted_temporary_private_key: Path
@@ -288,7 +288,7 @@ class TestPytestSnowflakeHook:
             },
         }
         with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
-            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params
 
     @pytest.mark.parametrize("include_params", [True, False])
     def test_hook_param_beats_extra(self, include_params):
@@ -311,7 +311,7 @@ class TestPytestSnowflakeHook:
             assert hook_params != extras
             assert SnowflakeHook(
                 snowflake_conn_id="test_conn", **(hook_params if 
include_params else {})
-            )._get_conn_params() == {
+            )._get_conn_params == {
                 "user": None,
                 "password": "",
                 "application": "AIRFLOW",
@@ -340,7 +340,7 @@ class TestPytestSnowflakeHook:
             ).get_uri(),
         ):
             assert list(extras.values()) != list(extras_prefixed.values())
-            assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == {
+            assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params == {
                 "user": None,
                 "password": "",
                 "application": "AIRFLOW",
@@ -366,7 +366,7 @@ class TestPytestSnowflakeHook:
             },
         }
         with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
-            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params
 
     def test_get_conn_params_should_support_private_auth_with_unencrypted_key(
         self, non_encrypted_temporary_private_key
@@ -384,15 +384,15 @@ class TestPytestSnowflakeHook:
             },
         }
         with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
-            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params
         connection_kwargs["password"] = ""
         with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
-            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params
         connection_kwargs["password"] = _PASSWORD
         with mock.patch.dict(
             "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
         ), pytest.raises(TypeError, match="Password was given but private key is not encrypted."):
-            SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+            SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params
 
     def test_get_conn_params_should_fail_on_invalid_key(self):
         connection_kwargs = {
@@ -419,8 +419,7 @@ class TestPytestSnowflakeHook:
             AIRFLOW_SNOWFLAKE_PARTNER="PARTNER_NAME",
         ):
             assert (
-                SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()["application"]
-                == "PARTNER_NAME"
+                SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params["application"] == "PARTNER_NAME"
             )
 
     def test_get_conn_should_call_connect(self):
@@ -429,7 +428,7 @@ class TestPytestSnowflakeHook:
         ), mock.patch("airflow.providers.snowflake.hooks.snowflake.connector") as mock_connector:
             hook = SnowflakeHook(snowflake_conn_id="test_conn")
             conn = hook.get_conn()
-            mock_connector.connect.assert_called_once_with(**hook._get_conn_params())
+            mock_connector.connect.assert_called_once_with(**hook._get_conn_params)
             assert mock_connector.connect.return_value == conn
 
     def test_get_sqlalchemy_engine_should_support_pass_auth(self):
@@ -516,7 +515,7 @@ class TestPytestSnowflakeHook:
                 "session_parameters": {"AA": "AAA"},
                 "user": "user",
                 "warehouse": "TEST_WAREHOUSE",
-            } == hook._get_conn_params()
+            } == hook._get_conn_params
             assert (
                 "snowflake://user:pw@TEST_ACCOUNT.TEST_REGION/TEST_DATABASE/TEST_SCHEMA"
                 "?application=AIRFLOW&authenticator=TEST_AUTH&role=TEST_ROLE&warehouse=TEST_WAREHOUSE"
@@ -587,22 +586,14 @@ class TestPytestSnowflakeHook:
                 hook.run(sql=empty_statement)
             assert err.value.args[0] == "List of SQL statements is empty"
 
-    @pytest.mark.parametrize(
-        "returned_schema,expected_schema",
-        [([None], ""), (["DATABASE.SCHEMA"], "SCHEMA")],
-    )
-    @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_first")
-    def test_get_openlineage_default_schema_with_no_schema_set(
-        self, mock_get_first, returned_schema, expected_schema
-    ):
+    def test_get_openlineage_default_schema_with_no_schema_set(self):
         connection_kwargs = {
             **BASE_CONNECTION_KWARGS,
-            "schema": None,
+            "schema": "PUBLIC",
         }
         with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
             hook = SnowflakeHook(snowflake_conn_id="test_conn")
-            mock_get_first.return_value = returned_schema
-            assert hook.get_openlineage_default_schema() == expected_schema
+            assert hook.get_openlineage_default_schema() == "PUBLIC"
 
     @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_first")
     def test_get_openlineage_default_schema_with_schema_set(self, mock_get_first):
diff --git a/tests/providers/snowflake/hooks/test_snowflake_sql_api.py b/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
index bfd978755e..5ba18d6e12 100644
--- a/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
+++ b/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
@@ -20,7 +20,7 @@ import unittest
 import uuid
 from typing import TYPE_CHECKING, Any
 from unittest import mock
-from unittest.mock import AsyncMock
+from unittest.mock import AsyncMock, PropertyMock
 
 import pytest
 import requests
@@ -168,7 +168,10 @@ class TestSnowflakeSqlApiHook:
         ],
     )
     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
-    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
     def test_execute_query(
         self,
@@ -197,7 +200,10 @@ class TestSnowflakeSqlApiHook:
         [(SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"])],
     )
     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
-    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
     def test_execute_query_exception_without_statement_handle(
         self,
@@ -262,7 +268,10 @@ class TestSnowflakeSqlApiHook:
             with pytest.raises(AirflowException, match='Response: {"foo": "bar"}, Status Code: 500'):
                 hook.check_query_output(query_ids)

-    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
     def test_get_request_url_header_params(self, mock_get_header, mock_conn_param):
         """Test get_request_url_header_params by mocking _get_conn_params and get_headers"""
@@ -274,7 +283,10 @@ class TestSnowflakeSqlApiHook:
         assert url == "https://airflow.af_region.snowflakecomputing.com/api/v2/statements/uuid"

     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_private_key")
-    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
     @mock.patch("airflow.providers.snowflake.utils.sql_api_generate_jwt.JWTGenerator.get_token")
     def test_get_headers_should_support_private_key(self, mock_get_token, mock_conn_param, mock_private_key):
         """Test get_headers method by mocking get_private_key and _get_conn_params method"""
@@ -285,7 +297,10 @@ class TestSnowflakeSqlApiHook:
         assert result == HEADERS

     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_oauth_token")
-    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
     def test_get_headers_should_support_oauth(self, mock_conn_param, mock_oauth_token):
         """Test get_headers method by mocking get_oauth_token and _get_conn_params method"""
         mock_conn_param.return_value = CONN_PARAMS_OAUTH
@@ -296,7 +311,10 @@

     @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.HTTPBasicAuth")
     @mock.patch("requests.post")
-    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
     def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth):
         """Test get_oauth_token method makes the right http request"""
         BASIC_AUTH = {"Authorization": "Basic usernamepassword"}
diff --git a/tests/providers/snowflake/operators/test_snowflake_sql.py b/tests/providers/snowflake/operators/test_snowflake_sql.py
index a79064e341..87d77ca813 100644
--- a/tests/providers/snowflake/operators/test_snowflake_sql.py
+++ b/tests/providers/snowflake/operators/test_snowflake_sql.py
@@ -17,7 +17,8 @@
 # under the License.
 from __future__ import annotations
 
-from unittest.mock import MagicMock, call, patch
+from unittest import mock
+from unittest.mock import MagicMock, patch
 
 import pytest
 from _pytest.outcomes import importorskip
@@ -37,8 +38,6 @@ from openlineage.client.facet import (
     ColumnLineageDatasetFacet,
     ColumnLineageDatasetFacetFieldsAdditional,
     ColumnLineageDatasetFacetFieldsAdditionalInputFields,
-    SchemaDatasetFacet,
-    SchemaField,
     SqlJobFacet,
 )
 from openlineage.client.run import Dataset
@@ -163,7 +162,9 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
         )
 
 
-def test_execute_openlineage_events():
[email protected]("airflow.providers.openlineage.utils.utils.should_use_external_connection")
+def test_execute_openlineage_events(should_use_external_connection):
+    should_use_external_connection.return_value = False
     DB_NAME = "DATABASE"
     DB_SCHEMA_NAME = "PUBLIC"
 
@@ -174,9 +175,6 @@ def test_execute_openlineage_events():
         get_conn = MagicMock(name="conn")
         get_connection = MagicMock()
 
-        def get_first(self, *_):
-            return [f"{DB_NAME}.{DB_SCHEMA_NAME}"]
-
     dbapi_hook = SnowflakeHookForTests()
 
     class SnowflakeOperatorForTest(SnowflakeOperator):
@@ -185,7 +183,7 @@ def test_execute_openlineage_events():
 
     sql = (
         "INSERT INTO Test_table\n"
-        "SELECT t1.*, t2.additional_constant FROM 
ANOTHER_db.another_schema.popular_orders_day_of_week t1\n"
+        "SELECT t1.*, t2.additional_constant FROM 
ANOTHER_DB.ANOTHER_SCHEMA.popular_orders_day_of_week t1\n"
         "JOIN little_table t2 ON t1.order_day_of_week = 
t2.order_day_of_week;\n"
         "FORGOT TO COMMENT"
     )
@@ -223,6 +221,7 @@ def test_execute_openlineage_events():
     dbapi_hook.get_connection.return_value = Connection(
         conn_id="snowflake_default",
         conn_type="snowflake",
+        schema="PUBLIC",
         extra={
             "account": "test_account",
             "region": "us-east",
@@ -233,55 +232,17 @@ def test_execute_openlineage_events():
     dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = rows
 
     lineage = op.get_openlineage_facets_on_start()
-    assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == [
-        call(
-            "SELECT database.information_schema.columns.table_schema, database.information_schema.columns.table_name, "
-            "database.information_schema.columns.column_name, database.information_schema.columns.ordinal_position, "
-            "database.information_schema.columns.data_type, database.information_schema.columns.table_catalog \n"
-            "FROM database.information_schema.columns \n"
-            "WHERE database.information_schema.columns.table_name IN ('LITTLE_TABLE') "
-            "UNION ALL "
-            "SELECT another_db.information_schema.columns.table_schema, another_db.information_schema.columns.table_name, "
-            "another_db.information_schema.columns.column_name, another_db.information_schema.columns.ordinal_position, "
-            "another_db.information_schema.columns.data_type, another_db.information_schema.columns.table_catalog \n"
-            "FROM another_db.information_schema.columns \n"
-            "WHERE another_db.information_schema.columns.table_schema = 'ANOTHER_SCHEMA' "
-            "AND another_db.information_schema.columns.table_name IN ('POPULAR_ORDERS_DAY_OF_WEEK')"
-        ),
-        call(
-            "SELECT database.information_schema.columns.table_schema, database.information_schema.columns.table_name, "
-            "database.information_schema.columns.column_name, database.information_schema.columns.ordinal_position, "
-            "database.information_schema.columns.data_type, database.information_schema.columns.table_catalog \n"
-            "FROM database.information_schema.columns \n"
-            "WHERE database.information_schema.columns.table_name IN ('TEST_TABLE')"
-        ),
-    ]
+    # Not calling Snowflake
+    assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == []
 
     assert lineage.inputs == [
         Dataset(
             namespace="snowflake://test_account.us-east.aws",
-            name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.POPULAR_ORDERS_DAY_OF_WEEK",
-            facets={
-                "schema": SchemaDatasetFacet(
-                    fields=[
-                        SchemaField(name="ORDER_DAY_OF_WEEK", type="TEXT"),
-                        SchemaField(name="ORDER_PLACED_ON", 
type="TIMESTAMP_NTZ"),
-                        SchemaField(name="ORDERS_PLACED", type="NUMBER"),
-                    ]
-                )
-            },
+            name=f"{DB_NAME}.{DB_SCHEMA_NAME}.LITTLE_TABLE",
         ),
         Dataset(
             namespace="snowflake://test_account.us-east.aws",
-            name=f"{DB_NAME}.{DB_SCHEMA_NAME}.LITTLE_TABLE",
-            facets={
-                "schema": SchemaDatasetFacet(
-                    fields=[
-                        SchemaField(name="ORDER_DAY_OF_WEEK", type="TEXT"),
-                        SchemaField(name="ADDITIONAL_CONSTANT", type="TEXT"),
-                    ]
-                )
-            },
+            name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.POPULAR_ORDERS_DAY_OF_WEEK",
         ),
     ]
     assert lineage.outputs == [
@@ -289,14 +250,6 @@ def test_execute_openlineage_events():
             namespace="snowflake://test_account.us-east.aws",
             name=f"{DB_NAME}.{DB_SCHEMA_NAME}.TEST_TABLE",
             facets={
-                "schema": SchemaDatasetFacet(
-                    fields=[
-                        SchemaField(name="ORDER_DAY_OF_WEEK", type="TEXT"),
-                        SchemaField(name="ORDER_PLACED_ON", 
type="TIMESTAMP_NTZ"),
-                        SchemaField(name="ORDERS_PLACED", type="NUMBER"),
-                        SchemaField(name="ADDITIONAL_CONSTANT", type="TEXT"),
-                    ]
-                ),
                 "columnLineage": ColumnLineageDatasetFacet(
                     fields={
                         "additional_constant": 
ColumnLineageDatasetFacetFieldsAdditional(
