This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 44b97e1687 Add OpenLineage support for Redshift SQL. (#35794)
44b97e1687 is described below
commit 44b97e168733b08b308f16b2738b6c15e8a35862
Author: Jakub Dardzinski <[email protected]>
AuthorDate: Thu Jan 4 18:49:35 2024 +0100
Add OpenLineage support for Redshift SQL. (#35794)
Add flat information schema query support in SQLParser.
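For illustration only (not part of the diff below): a hook opts into the new
flat mode by returning a DatabaseInfo configured like this sketch, where the
authority value is a made-up example:

    from airflow.providers.openlineage.sqlparser import DatabaseInfo

    db_info = DatabaseInfo(
        scheme="redshift",
        authority="my-cluster.eu-north-1:5439",  # illustrative value
        information_schema_table_name="SVV_REDSHIFT_COLUMNS",
        is_information_schema_cross_db=True,
        use_flat_cross_db_query=True,
    )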
Signed-off-by: Jakub Dardzinski <[email protected]>
Co-authored-by: Niko Oliveira <[email protected]>
---
airflow/providers/amazon/aws/hooks/redshift_sql.py | 61 ++++++
airflow/providers/openlineage/sqlparser.py | 13 +-
airflow/providers/openlineage/utils/sql.py | 97 ++++++---
.../amazon/aws/hooks/test_redshift_sql.py | 44 ++++
.../amazon/aws/operators/test_redshift_sql.py | 241 +++++++++++++++++++++
tests/providers/openlineage/utils/test_sql.py | 62 ++++--
6 files changed, 472 insertions(+), 46 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/redshift_sql.py b/airflow/providers/amazon/aws/hooks/redshift_sql.py
index 66659cb0a1..580efc1443 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_sql.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -30,6 +30,7 @@ from airflow.providers.common.sql.hooks.sql import DbApiHook
if TYPE_CHECKING:
from airflow.models.connection import Connection
+ from airflow.providers.openlineage.sqlparser import DatabaseInfo
class RedshiftSQLHook(DbApiHook):
@@ -197,3 +198,63 @@ class RedshiftSQLHook(DbApiHook):
conn_kwargs_dejson = self.conn.extra_dejson
conn_kwargs: dict = {**conn_params, **conn_kwargs_dejson}
return redshift_connector.connect(**conn_kwargs)
+
+    def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
+        """Returns Redshift-specific information for OpenLineage."""
+        from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+        authority = self._get_openlineage_redshift_authority_part(connection)
+
+        return DatabaseInfo(
+            scheme="redshift",
+            authority=authority,
+            database=connection.schema,
+            information_schema_table_name="SVV_REDSHIFT_COLUMNS",
+            information_schema_columns=[
+                "schema_name",
+                "table_name",
+                "column_name",
+                "ordinal_position",
+                "data_type",
+                "database_name",
+            ],
+            is_information_schema_cross_db=True,
+            use_flat_cross_db_query=True,
+        )
+
+    def _get_openlineage_redshift_authority_part(self, connection: Connection) -> str:
+        from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+        port = connection.port or 5439
+
+        cluster_identifier = None
+
+        if connection.extra_dejson.get("iam", False):
+            cluster_identifier = connection.extra_dejson.get("cluster_identifier")
+            region_name = AwsBaseHook(aws_conn_id=self.aws_conn_id).region_name
+            identifier = f"{cluster_identifier}.{region_name}"
+        if not cluster_identifier:
+            identifier = self._get_identifier_from_hostname(connection.host)
+        return f"{identifier}:{port}"
+
+    def _get_identifier_from_hostname(self, hostname: str) -> str:
+        parts = hostname.split(".")
+        if "amazonaws.com" in hostname and len(parts) == 6:
+            return f"{parts[0]}.{parts[2]}"
+        else:
+            self.log.debug(
+                """Could not parse identifier from hostname '%s'.
+            You are probably using an IP address to connect to the Redshift cluster.
+            Expected format: 'cluster_identifier.id.region_name.redshift.amazonaws.com'
+            Falling back to the whole hostname.""",
+                hostname,
+            )
+            return hostname
+
+    def get_openlineage_database_dialect(self, connection: Connection) -> str:
+        """Returns the Redshift dialect."""
+        return "redshift"
+
+    def get_openlineage_default_schema(self) -> str | None:
+        """Returns the current schema. This is usually changed with the ``SEARCH_PATH`` parameter."""
+        return self.get_first("SELECT CURRENT_SCHEMA();")[0]
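A minimal usage sketch of the new authority helper (not part of the patch;
the hostname is illustrative): given a plain hostname connection, the
cluster identifier and region are parsed out of the hostname:

    from airflow.models.connection import Connection
    from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook

    conn = Connection(
        conn_type="redshift",
        host="my-cluster.abc123.eu-north-1.redshift.amazonaws.com",
        port=5439,
    )
    # Prints "my-cluster.eu-north-1:5439": identifier and region from the
    # hostname, plus the port (default 5439).
    print(RedshiftSQLHook()._get_openlineage_redshift_authority_part(conn))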
diff --git a/airflow/providers/openlineage/sqlparser.py b/airflow/providers/openlineage/sqlparser.py
index 41c378fc27..d54c19dbc8 100644
--- a/airflow/providers/openlineage/sqlparser.py
+++ b/airflow/providers/openlineage/sqlparser.py
@@ -67,6 +67,7 @@ class GetTableSchemasParams(TypedDict):
is_cross_db: bool
information_schema_columns: list[str]
information_schema_table: str
+ use_flat_cross_db_query: bool
is_uppercase_names: bool
database: str | None
@@ -83,6 +84,8 @@ class DatabaseInfo:
     :param database: Takes precedence over parsed database name.
     :param information_schema_columns: List of column names from information schema table.
     :param information_schema_table_name: Information schema table name.
+    :param use_flat_cross_db_query: Specifies if a single information schema table should be used
+        for cross-database queries (e.g. for Redshift).
     :param is_information_schema_cross_db: Specifies if information schema contains
         cross-database data.
     :param is_uppercase_names: Specifies if database accepts only uppercase names (e.g. Snowflake).
@@ -95,6 +98,7 @@ class DatabaseInfo:
database: str | None = None
information_schema_columns: list[str] = DEFAULT_INFORMATION_SCHEMA_COLUMNS
information_schema_table_name: str = DEFAULT_INFORMATION_SCHEMA_TABLE_NAME
+ use_flat_cross_db_query: bool = False
is_information_schema_cross_db: bool = False
is_uppercase_names: bool = False
normalize_name_method: Callable[[str], str] = default_normalize_name_method
@@ -133,6 +137,7 @@ class SQLParser:
"information_schema_table":
database_info.information_schema_table_name,
"is_uppercase_names": database_info.is_uppercase_names,
"database": database or database_info.database,
+ "use_flat_cross_db_query": database_info.use_flat_cross_db_query,
}
return get_table_schemas(
hook,
@@ -297,9 +302,10 @@ class SQLParser:
tables: list[DbTableMeta],
normalize_name: Callable[[str], str],
is_cross_db: bool,
- information_schema_columns,
- information_schema_table,
- is_uppercase_names,
+ information_schema_columns: list[str],
+ information_schema_table: str,
+ is_uppercase_names: bool,
+ use_flat_cross_db_query: bool,
database: str | None = None,
sqlalchemy_engine: Engine | None = None,
) -> str:
@@ -314,6 +320,7 @@ class SQLParser:
columns=information_schema_columns,
information_schema_table_name=information_schema_table,
tables_hierarchy=tables_hierarchy,
+ use_flat_cross_db_query=use_flat_cross_db_query,
uppercase_names=is_uppercase_names,
sqlalchemy_engine=sqlalchemy_engine,
)
diff --git a/airflow/providers/openlineage/utils/sql.py b/airflow/providers/openlineage/utils/sql.py
index 7bd6043040..da08fa68d4 100644
--- a/airflow/providers/openlineage/utils/sql.py
+++ b/airflow/providers/openlineage/utils/sql.py
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional
from attrs import define
from openlineage.client.facet import SchemaDatasetFacet, SchemaField
from openlineage.client.run import Dataset
-from sqlalchemy import Column, MetaData, Table, and_, union_all
+from sqlalchemy import Column, MetaData, Table, and_, or_, union_all
if TYPE_CHECKING:
from sqlalchemy.engine import Engine
@@ -42,7 +42,7 @@ class ColumnIndex(IntEnum):
ORDINAL_POSITION = 3
# Use 'udt_name' which is the underlying type of column
UDT_NAME = 4
- # Database is optional as 5th column
+ # Database is optional as 6th column
DATABASE = 5
@@ -145,55 +145,96 @@ def create_information_schema_query(
     information_schema_table_name: str,
     tables_hierarchy: TablesHierarchy,
     uppercase_names: bool = False,
+    use_flat_cross_db_query: bool = False,
     sqlalchemy_engine: Engine | None = None,
 ) -> str:
     """Creates query for getting table schemas from information schema."""
     metadata = MetaData(sqlalchemy_engine)
     select_statements = []
-    for db, schema_mapping in tables_hierarchy.items():
-        # Information schema table name is expected to be "< information_schema schema >.<view/table name>"
-        # usually "information_schema.columns". In order to use table identifier correct for various table
-        # we need to pass first part of dot-separated identifier as `schema` argument to `sqlalchemy.Table`.
-        if db:
-            # Use database as first part of table identifier.
-            schema = db
-            table_name = information_schema_table_name
-        else:
-            # When no database passed, use schema as first part of table identifier.
-            schema, table_name = information_schema_table_name.split(".")
+    # Don't iterate over tables hierarchy, just pass it to query single information schema table
+    if use_flat_cross_db_query:
         information_schema_table = Table(
-            table_name,
+            information_schema_table_name,
             metadata,
             *[Column(column) for column in columns],
-            schema=schema,
             quote=False,
         )
-        filter_clauses = create_filter_clauses(schema_mapping, information_schema_table, uppercase_names)
-        select_statements.append(information_schema_table.select().filter(*filter_clauses))
+        filter_clauses = create_filter_clauses(
+            tables_hierarchy,
+            information_schema_table,
+            uppercase_names=uppercase_names,
+        )
+        select_statements.append(information_schema_table.select().filter(filter_clauses))
+    else:
+        for db, schema_mapping in tables_hierarchy.items():
+            # Information schema table name is expected to be "< information_schema schema >.<view/table name>"
+            # usually "information_schema.columns". In order to use table identifier correct for various table
+            # we need to pass first part of dot-separated identifier as `schema` argument to `sqlalchemy.Table`.
+            if db:
+                # Use database as first part of table identifier.
+                schema = db
+                table_name = information_schema_table_name
+            else:
+                # When no database passed, use schema as first part of table identifier.
+                schema, table_name = information_schema_table_name.split(".")
+            information_schema_table = Table(
+                table_name,
+                metadata,
+                *[Column(column) for column in columns],
+                schema=schema,
+                quote=False,
+            )
+            filter_clauses = create_filter_clauses(
+                {None: schema_mapping},
+                information_schema_table,
+                uppercase_names=uppercase_names,
+            )
+            select_statements.append(information_schema_table.select().filter(filter_clauses))
     return str(
         union_all(*select_statements).compile(sqlalchemy_engine, compile_kwargs={"literal_binds": True})
     )
 def create_filter_clauses(
-    schema_mapping: dict, information_schema_table: Table, uppercase_names: bool = False
+    mapping: dict,
+    information_schema_table: Table,
+    uppercase_names: bool = False,
 ) -> ClauseElement:
"""
Creates comprehensive filter clauses for all tables in one database.
- :param schema_mapping: a dictionary of schema names and list of tables in
each
+ :param mapping: a nested dictionary of database, schema names and list of
tables in each
:param information_schema_table: `sqlalchemy.Table` instance used to
construct clauses
For most SQL dbs it contains `table_name` and `table_schema` columns,
therefore it is expected the table has them defined.
:param uppercase_names: if True use schema and table names uppercase
"""
+ table_schema_column_name =
information_schema_table.columns[ColumnIndex.SCHEMA].name
+ table_name_column_name =
information_schema_table.columns[ColumnIndex.TABLE_NAME].name
+ try:
+ table_database_column_name =
information_schema_table.columns[ColumnIndex.DATABASE].name
+ except IndexError:
+ table_database_column_name = ""
+
filter_clauses = []
- for schema, tables in schema_mapping.items():
- filter_clause = information_schema_table.c.table_name.in_(
- name.upper() if uppercase_names else name for name in tables
- )
- if schema:
- schema = schema.upper() if uppercase_names else schema
- filter_clause = and_(information_schema_table.c.table_schema ==
schema, filter_clause)
- filter_clauses.append(filter_clause)
- return filter_clauses
+ for db, schema_mapping in mapping.items():
+ schema_level_clauses = []
+ for schema, tables in schema_mapping.items():
+ filter_clause =
information_schema_table.c[table_name_column_name].in_(
+ name.upper() if uppercase_names else name for name in tables
+ )
+ if schema:
+ schema = schema.upper() if uppercase_names else schema
+ filter_clause = and_(
+ information_schema_table.c[table_schema_column_name] ==
schema, filter_clause
+ )
+ schema_level_clauses.append(filter_clause)
+ if db and table_database_column_name:
+ db = db.upper() if uppercase_names else db
+ filter_clause = and_(
+ information_schema_table.c[table_database_column_name] == db,
or_(*schema_level_clauses)
+ )
+ filter_clauses.append(filter_clause)
+ else:
+ filter_clauses.extend(schema_level_clauses)
+ return or_(*filter_clauses)
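A minimal end-to-end sketch of the flat mode added above (the database,
schema, and table names are illustrative; assumes the patched
create_information_schema_query):

    from airflow.providers.openlineage.utils.sql import create_information_schema_query

    query = create_information_schema_query(
        columns=[
            "schema_name",
            "table_name",
            "column_name",
            "ordinal_position",
            "data_type",
            "database_name",
        ],
        information_schema_table_name="SVV_REDSHIFT_COLUMNS",
        tables_hierarchy={"another_db": {"another_schema": ["popular_orders_day_of_week"]}},
        use_flat_cross_db_query=True,
    )
    # Instead of one UNION ALL branch per database, this compiles to a single
    # SELECT over SVV_REDSHIFT_COLUMNS, roughly:
    #   ... WHERE database_name = 'another_db'
    #       AND schema_name = 'another_schema'
    #       AND table_name IN ('popular_orders_day_of_week')
    print(query)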
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_sql.py b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
index 4871522489..aced8cae13 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_sql.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
@@ -31,6 +31,7 @@ LOGIN_PASSWORD = "password"
LOGIN_HOST = "host"
LOGIN_PORT = 5439
LOGIN_SCHEMA = "dev"
+MOCK_REGION_NAME = "eu-north-1"
class TestRedshiftSQLHookConn:
@@ -240,3 +241,46 @@ class TestRedshiftSQLHookConn:
ClusterIdentifier=expected_cluster_identifier,
AutoCreate=False,
)
+
+    @mock.patch.dict("os.environ", AIRFLOW_CONN_AWS_DEFAULT=f"aws://?region_name={MOCK_REGION_NAME}")
+    @pytest.mark.parametrize(
+        "connection_host, connection_extra, expected_identity",
+        [
+            # test without a connection host but with a cluster_identifier in connection extra
+            (
+                None,
+                {"iam": True, "cluster_identifier": "cluster_identifier_from_extra"},
+                f"cluster_identifier_from_extra.{MOCK_REGION_NAME}",
+            ),
+            # test with a connection host and without a cluster_identifier in connection extra
+            (
+                "cluster_identifier_from_host.id.my_region.redshift.amazonaws.com",
+                {"iam": True},
+                "cluster_identifier_from_host.my_region",
+            ),
+            # test with both connection host and cluster_identifier in connection extra
+            (
+                "cluster_identifier_from_host.x.y",
+                {"iam": True, "cluster_identifier": "cluster_identifier_from_extra"},
+                f"cluster_identifier_from_extra.{MOCK_REGION_NAME}",
+            ),
+            # test when hostname doesn't match pattern
+            (
+                "1.2.3.4",
+                {},
+                "1.2.3.4",
+            ),
+        ],
+    )
+    def test_get_openlineage_redshift_authority_part(
+        self,
+        connection_host,
+        connection_extra,
+        expected_identity,
+    ):
+        self.connection.host = connection_host
+        self.connection.extra = json.dumps(connection_extra)
+
+        assert f"{expected_identity}:{LOGIN_PORT}" == self.db_hook._get_openlineage_redshift_authority_part(
+            self.connection
+        )
diff --git a/tests/providers/amazon/aws/operators/test_redshift_sql.py b/tests/providers/amazon/aws/operators/test_redshift_sql.py
new file mode 100644
index 0000000000..d1c6e26151
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_redshift_sql.py
@@ -0,0 +1,241 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, call, patch
+
+import pytest
+from openlineage.client.facet import (
+ ColumnLineageDatasetFacet,
+ ColumnLineageDatasetFacetFieldsAdditional,
+ ColumnLineageDatasetFacetFieldsAdditionalInputFields,
+ SchemaDatasetFacet,
+ SchemaField,
+ SqlJobFacet,
+)
+from openlineage.client.run import Dataset
+
+from airflow.models.connection import Connection
+from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
+
+MOCK_REGION_NAME = "eu-north-1"
+
+
+class TestRedshiftSQLOpenLineage:
+    @patch.dict("os.environ", AIRFLOW_CONN_AWS_DEFAULT=f"aws://?region_name={MOCK_REGION_NAME}")
+    @pytest.mark.parametrize(
+        "connection_host, connection_extra, expected_identity",
+        [
+            # test without a connection host but with a cluster_identifier in connection extra
+            (
+                None,
+                {"iam": True, "cluster_identifier": "cluster_identifier_from_extra"},
+                f"cluster_identifier_from_extra.{MOCK_REGION_NAME}",
+            ),
+            # test with a connection host and without a cluster_identifier in connection extra
+            (
+                "cluster_identifier_from_host.id.my_region.redshift.amazonaws.com",
+                {"iam": True},
+                "cluster_identifier_from_host.my_region",
+            ),
+            # test with both connection host and cluster_identifier in connection extra
+            (
+                "cluster_identifier_from_host.x.y",
+                {"iam": True, "cluster_identifier": "cluster_identifier_from_extra"},
+                f"cluster_identifier_from_extra.{MOCK_REGION_NAME}",
+            ),
+            # test when hostname doesn't match pattern
+            (
+                "1.2.3.4",
+                {},
+                "1.2.3.4",
+            ),
+        ],
+    )
+ @patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
+ def test_execute_openlineage_events(
+        self, mock_aws_hook_conn, connection_host, connection_extra, expected_identity
+    ):
+ DB_NAME = "database"
+ DB_SCHEMA_NAME = "public"
+
+ ANOTHER_DB_NAME = "another_db"
+ ANOTHER_DB_SCHEMA = "another_schema"
+
+ # Mock AWS Connection
+ mock_aws_hook_conn.get_cluster_credentials.return_value = {
+ "DbPassword": "aws_token",
+ "DbUser": "IAM:user",
+ }
+
+ class RedshiftSQLHookForTests(RedshiftSQLHook):
+ get_conn = MagicMock(name="conn")
+ get_connection = MagicMock()
+
+ def get_first(self, *_):
+ return [f"{DB_NAME}.{DB_SCHEMA_NAME}"]
+
+ dbapi_hook = RedshiftSQLHookForTests()
+
+ class RedshiftOperatorForTest(SQLExecuteQueryOperator):
+ def get_db_hook(self):
+ return dbapi_hook
+
+        sql = (
+            "INSERT INTO Test_table\n"
+            "SELECT t1.*, t2.additional_constant FROM ANOTHER_db.another_schema.popular_orders_day_of_week t1\n"
+            "JOIN little_table t2 ON t1.order_day_of_week = t2.order_day_of_week;\n"
+            "FORGOT TO COMMENT"
+        )
+ op = RedshiftOperatorForTest(task_id="redshift-operator", sql=sql)
+ rows = [
+ [
+ (
+ ANOTHER_DB_SCHEMA,
+ "popular_orders_day_of_week",
+ "order_day_of_week",
+ 1,
+ "varchar",
+ ANOTHER_DB_NAME,
+ ),
+ (
+ ANOTHER_DB_SCHEMA,
+ "popular_orders_day_of_week",
+ "order_placed_on",
+ 2,
+ "timestamp",
+ ANOTHER_DB_NAME,
+ ),
+ (
+ ANOTHER_DB_SCHEMA,
+ "popular_orders_day_of_week",
+ "orders_placed",
+ 3,
+ "int4",
+ ANOTHER_DB_NAME,
+ ),
+                (DB_SCHEMA_NAME, "little_table", "order_day_of_week", 1, "varchar", DB_NAME),
+                (DB_SCHEMA_NAME, "little_table", "additional_constant", 2, "varchar", DB_NAME),
+ ],
+ [
+                (DB_SCHEMA_NAME, "test_table", "order_day_of_week", 1, "varchar", DB_NAME),
+                (DB_SCHEMA_NAME, "test_table", "order_placed_on", 2, "timestamp", DB_NAME),
+                (DB_SCHEMA_NAME, "test_table", "orders_placed", 3, "int4", DB_NAME),
+                (DB_SCHEMA_NAME, "test_table", "additional_constant", 4, "varchar", DB_NAME),
+ ],
+ ]
+ dbapi_hook.get_connection.return_value = Connection(
+ conn_id="redshift_default",
+ conn_type="redshift",
+ host=connection_host,
+ extra=connection_extra,
+ )
+        dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = rows
+
+ lineage = op.get_openlineage_facets_on_start()
+        assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == [
+            call(
+                "SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
+                "SVV_REDSHIFT_COLUMNS.table_name, "
+                "SVV_REDSHIFT_COLUMNS.column_name, "
+                "SVV_REDSHIFT_COLUMNS.ordinal_position, "
+                "SVV_REDSHIFT_COLUMNS.data_type, "
+                "SVV_REDSHIFT_COLUMNS.database_name \n"
+                "FROM SVV_REDSHIFT_COLUMNS \n"
+                "WHERE SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') "
+                "OR SVV_REDSHIFT_COLUMNS.database_name = 'another_db' "
+                "AND SVV_REDSHIFT_COLUMNS.schema_name = 'another_schema' AND "
+                "SVV_REDSHIFT_COLUMNS.table_name IN ('popular_orders_day_of_week')"
+            ),
+ call(
+ "SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
+ "SVV_REDSHIFT_COLUMNS.table_name, "
+ "SVV_REDSHIFT_COLUMNS.column_name, "
+ "SVV_REDSHIFT_COLUMNS.ordinal_position, "
+ "SVV_REDSHIFT_COLUMNS.data_type, "
+ "SVV_REDSHIFT_COLUMNS.database_name \n"
+ "FROM SVV_REDSHIFT_COLUMNS \n"
+ "WHERE SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')"
+ ),
+ ]
+
+ expected_namespace = f"redshift://{expected_identity}:5439"
+
+        assert lineage.inputs == [
+            Dataset(
+                namespace=expected_namespace,
+                name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.popular_orders_day_of_week",
+                facets={
+                    "schema": SchemaDatasetFacet(
+                        fields=[
+                            SchemaField(name="order_day_of_week", type="varchar"),
+                            SchemaField(name="order_placed_on", type="timestamp"),
+                            SchemaField(name="orders_placed", type="int4"),
+                        ]
+                    )
+                },
+            ),
+            Dataset(
+                namespace=expected_namespace,
+                name=f"{DB_NAME}.{DB_SCHEMA_NAME}.little_table",
+                facets={
+                    "schema": SchemaDatasetFacet(
+                        fields=[
+                            SchemaField(name="order_day_of_week", type="varchar"),
+                            SchemaField(name="additional_constant", type="varchar"),
+                        ]
+                    )
+                },
+            ),
+        ]
+        assert lineage.outputs == [
+            Dataset(
+                namespace=expected_namespace,
+                name=f"{DB_NAME}.{DB_SCHEMA_NAME}.test_table",
+                facets={
+                    "schema": SchemaDatasetFacet(
+                        fields=[
+                            SchemaField(name="order_day_of_week", type="varchar"),
+                            SchemaField(name="order_placed_on", type="timestamp"),
+                            SchemaField(name="orders_placed", type="int4"),
+                            SchemaField(name="additional_constant", type="varchar"),
+                        ]
+                    ),
+                    "columnLineage": ColumnLineageDatasetFacet(
+                        fields={
+                            "additional_constant": ColumnLineageDatasetFacetFieldsAdditional(
+                                inputFields=[
+                                    ColumnLineageDatasetFacetFieldsAdditionalInputFields(
+                                        namespace=expected_namespace,
+                                        name="database.public.little_table",
+                                        field="additional_constant",
+                                    )
+                                ],
+                                transformationDescription="",
+                                transformationType="",
+                            )
+                        }
+                    ),
+                },
+            )
+        ]
+
+ assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)}
+
+ assert lineage.run_facets["extractionError"].failedTasks == 1
diff --git a/tests/providers/openlineage/utils/test_sql.py b/tests/providers/openlineage/utils/test_sql.py
index 8567920578..180defbeec 100644
--- a/tests/providers/openlineage/utils/test_sql.py
+++ b/tests/providers/openlineage/utils/test_sql.py
@@ -262,31 +262,63 @@ def test_get_table_schemas_with_other_database():
@pytest.mark.parametrize(
"schema_mapping, expected",
[
-        pytest.param({None: ["C1", "C2"]}, ["information_schema.columns.table_name IN ('C1', 'C2')"]),
+        pytest.param({None: {None: ["C1", "C2"]}}, "information_schema.columns.table_name IN ('C1', 'C2')"),
         pytest.param(
-            {"Schema1": ["Table1"], "Schema2": ["Table2"]},
-            [
-                "information_schema.columns.table_schema = 'Schema1' AND "
-                "information_schema.columns.table_name IN ('Table1')",
-                "information_schema.columns.table_schema = 'Schema2' AND "
-                "information_schema.columns.table_name IN ('Table2')",
-            ],
+            {None: {"Schema1": ["Table1"], "Schema2": ["Table2"]}},
+            "information_schema.columns.table_schema = 'Schema1' AND "
+            "information_schema.columns.table_name IN ('Table1') OR "
+            "information_schema.columns.table_schema = 'Schema2' AND "
+            "information_schema.columns.table_name IN ('Table2')",
         ),
         pytest.param(
-            {"Schema1": ["Table1", "Table2"]},
-            [
-                "information_schema.columns.table_schema = 'Schema1' AND "
-                "information_schema.columns.table_name IN ('Table1', 'Table2')",
-            ],
+            {None: {"Schema1": ["Table1", "Table2"]}},
+            "information_schema.columns.table_schema = 'Schema1' AND "
+            "information_schema.columns.table_name IN ('Table1', 'Table2')",
+        ),
+        pytest.param(
+            {"Database1": {"Schema1": ["Table1", "Table2"]}},
+            "information_schema.columns.table_database = 'Database1' "
+            "AND information_schema.columns.table_schema = 'Schema1' "
+            "AND information_schema.columns.table_name IN ('Table1', 'Table2')",
+        ),
+        pytest.param(
+            {"Database1": {"Schema1": ["Table1", "Table2"], "Schema2": ["Table3", "Table4"]}},
+            "information_schema.columns.table_database = 'Database1' "
+            "AND (information_schema.columns.table_schema = 'Schema1' "
+            "AND information_schema.columns.table_name IN ('Table1', 'Table2') "
+            "OR information_schema.columns.table_schema = 'Schema2' "
+            "AND information_schema.columns.table_name IN ('Table3', 'Table4'))",
+        ),
+        pytest.param(
+            {"Database1": {"Schema1": ["Table1", "Table2"]}, "Database2": {"Schema2": ["Table3", "Table4"]}},
+            "information_schema.columns.table_database = 'Database1' "
+            "AND information_schema.columns.table_schema = 'Schema1' "
+            "AND information_schema.columns.table_name IN ('Table1', 'Table2') OR "
+            "information_schema.columns.table_database = 'Database2' "
+            "AND information_schema.columns.table_schema = 'Schema2' "
+            "AND information_schema.columns.table_name IN ('Table3', 'Table4')",
),
],
)
def test_create_filter_clauses(schema_mapping, expected):
information_table = Table(
-        "columns", MetaData(), *[Column("table_name"), Column("table_schema")], schema="information_schema"
+ "columns",
+ MetaData(),
+ *[
+ Column(name)
+ for name in [
+ "table_schema",
+ "table_name",
+ "column_name",
+ "ordinal_position",
+ "udt_name",
+ "table_database",
+ ]
+ ],
+ schema="information_schema",
)
clauses = create_filter_clauses(schema_mapping, information_table)
-    assert [str(clause.compile(compile_kwargs={"literal_binds": True})) for clause in clauses] == expected
+    assert str(clauses.compile(compile_kwargs={"literal_binds": True})) == expected
def test_create_create_information_schema_query():