uranusjr commented on code in PR #31398: URL: https://github.com/apache/airflow/pull/31398#discussion_r1244779032
########## airflow/providers/openlineage/utils/sql.py: ########## @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from collections import defaultdict +from contextlib import closing +from enum import IntEnum +from typing import TYPE_CHECKING, Dict, Iterator, List + +from attrs import define, field +from sqlalchemy import Column, MetaData, Table, and_, union_all + +from openlineage.client.facet import SchemaDatasetFacet, SchemaField +from openlineage.client.run import Dataset + +if TYPE_CHECKING: + from sqlalchemy.engine import Engine + from sqlalchemy.sql import ClauseElement + + from airflow.hooks.base import BaseHook + + +logger = logging.getLogger(__name__) + + +class ColumnIndex(IntEnum): + """Enumerates the indices of columns in information schema view.""" + + SCHEMA = 0 + TABLE_NAME = 1 + COLUMN_NAME = 2 + ORDINAL_POSITION = 3 + # Use 'udt_name' which is the underlying type of column + UDT_NAME = 4 + # Database is optional as 5th column + DATABASE = 5 + + +TablesHierarchy = Dict[str, Dict[str, List[str]]] + + +@define +class TableSchema: + """Temporary object used to construct OpenLineage Dataset.""" + + table: str = field() + schema: str | None = field() 
+ database: str | None = field() + fields: list[SchemaField] = field() + + def to_dataset(self, namespace: str, database: str | None = None) -> Dataset: + # Prefix the table name with database and schema name using + # the format: {database_name}.{table_schema}.{table_name}. + name = ".".join( + filter( + lambda x: x is not None, # type: ignore + [self.database if self.database else database, self.schema, self.table], + ) + ) + return Dataset( + namespace=namespace, + name=name, + facets={"schema": SchemaDatasetFacet(fields=self.fields)} if len(self.fields) is not None else {}, + ) + + +def execute_query_on_hook(hook: BaseHook, query: str) -> Iterator[tuple]: + with closing(hook.get_conn()) as conn: + with closing(conn.cursor()) as cursor: + return cursor.execute(query).fetchall() + + +def get_table_schemas( + hook: BaseHook, + namespace: str, + database: str | None, + in_query: str | None, + out_query: str | None, +) -> tuple[list[Dataset], ...]: + """ + This function queries database for table schemas using provided hook. + Responsibility to provide queries for this function is on particular extractors. + If query for input or output table isn't provided, the query is skipped. + """ + in_datasets: list[Dataset] = [] + out_datasets: list[Dataset] = [] + # Do not query if we did not get both queries + if not in_query and not out_query: + return [], [] + + with closing(hook.get_conn()) as conn: + with closing(conn.cursor()) as cursor: + if in_query: + cursor.execute(in_query) + in_datasets += [x.to_dataset(namespace, database) for x in parse_query_result(cursor)] + if out_query: + cursor.execute(out_query) + out_datasets += [x.to_dataset(namespace, database) for x in parse_query_result(cursor)] + return in_datasets, out_datasets + + +def parse_query_result(cursor) -> list[TableSchema]: + """ + This function fetches results from DB-API 2.0 cursor + For each row it creates :class:`TableSchema`. + Returns list of table schemas. 
+ """ + schemas: dict = {} + columns: dict = defaultdict(list) + for row in cursor.fetchall(): + table_schema_name: str = row[ColumnIndex.SCHEMA] + table_name: str = row[ColumnIndex.TABLE_NAME] + table_column: SchemaField = SchemaField( + name=row[ColumnIndex.COLUMN_NAME], + type=row[ColumnIndex.UDT_NAME], + description=None, + ) + ordinal_position = row[ColumnIndex.ORDINAL_POSITION] + try: + table_database = row[ColumnIndex.DATABASE] + except IndexError: + table_database = None + + # Attempt to get table schema + table_key = ".".join(filter(None, [table_database, table_schema_name, table_name])) + + schemas[table_key] = TableSchema( + table=table_name, schema=table_schema_name, database=table_database, fields=[] + ) + columns[table_key].append((ordinal_position, table_column)) + + for schema in schemas.values(): + table_key = ".".join(filter(None, [schema.database, schema.schema, schema.table])) + schema.fields = [x for _, x in sorted(columns[table_key])] + + return list(schemas.values()) + + +def create_information_schema_query( + columns: list[str], + information_schema_table_name: str, + tables_hierarchy: TablesHierarchy, + uppercase_names: bool = False, + sqlalchemy_engine: Engine | None = None, +) -> str: + """This function creates query for getting table schemas from information schema.""" + metadata = MetaData(sqlalchemy_engine) + select_statements = [] + for db, schema_mapping in tables_hierarchy.items(): + schema, table_name = information_schema_table_name.split(".") + if db: + schema = f"{db}.{schema}" + information_schema_table = Table( + table_name, metadata, *[Column(column) for column in columns], schema=schema + ) + filter_clauses = create_filter_clauses(schema_mapping, information_schema_table, uppercase_names) + select_statements.append(information_schema_table.select().filter(*filter_clauses)) + return str( + union_all(*select_statements).compile(sqlalchemy_engine, compile_kwargs={"literal_binds": True}) + ) + + +def create_filter_clauses( + 
+ schema_mapping: dict, information_schema_table: Table, uppercase_names: bool = False +) -> ClauseElement: + """ + Creates comprehensive filter clauses for all tables in one database (assuming hierarchy + of database -> schema -> table). + + :param schema_mapping: a dictionary of schema names and list of tables in each + :param information_schema_table: `sqlalchemy.Table` instance used to construct clauses + For most SQL dbs it contains `table_name` and `table_schema` columns, + therefore it is expected the table has them defined. + :param uppercase_names: if True use schema and table names uppercase + """ + filter_clauses = [] + for schema, tables in schema_mapping.items(): + filter_clause = information_schema_table.c.table_name.in_( + map(lambda name: name.upper() if uppercase_names else name, tables)

Review Comment:

```suggestion
name.upper() if uppercase_names else name for name in tables
```

Seems shorter and easier to understand to me.

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at: [email protected]
