This is an automated email from the ASF dual-hosted git repository. msyavuz pushed a commit to branch msyavuz/feat/datasource-analyzer in repository https://gitbox.apache.org/repos/asf/superset.git
commit cc4a66fccf456b9d1e0113d678349e8345cfad09 Author: Mehmet Salih Yavuz <[email protected]> AuthorDate: Wed Dec 17 08:06:11 2025 +0300 feat: database analyzer celery job --- superset/commands/database_analyzer/__init__.py | 16 + superset/commands/database_analyzer/analyze.py | 437 +++++++++++++++++++++ superset/commands/database_analyzer/llm_service.py | 277 +++++++++++++ superset/config_llm.py | 43 ++ superset/databases/analyzer_api.py | 296 ++++++++++++++ superset/initialization/__init__.py | 2 + ..._07-54_c95466b0_add_database_analyzer_models.py | 347 ++++++++++++++++ superset/models/database_analyzer.py | 238 +++++++++++ superset/tasks/celery_app.py | 2 +- superset/tasks/database_analyzer.py | 223 +++++++++++ .../unit_tests/commands/test_database_analyzer.py | 166 ++++++++ 11 files changed, 2046 insertions(+), 1 deletion(-) diff --git a/superset/commands/database_analyzer/__init__.py b/superset/commands/database_analyzer/__init__.py new file mode 100644 index 0000000000..d216be4ddc --- /dev/null +++ b/superset/commands/database_analyzer/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
\ No newline at end of file diff --git a/superset/commands/database_analyzer/analyze.py b/superset/commands/database_analyzer/analyze.py new file mode 100644 index 0000000000..825546c360 --- /dev/null +++ b/superset/commands/database_analyzer/analyze.py @@ -0,0 +1,437 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json +import logging +import random +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any + +from flask import current_app +from sqlalchemy import inspect, MetaData, text + +from superset import db +from superset.commands.base import BaseCommand +from superset.commands.database_analyzer.llm_service import LLMService +from superset.models.core import Database +from superset.models.database_analyzer import ( + AnalyzedColumn, + AnalyzedTable, + Cardinality, + DatabaseSchemaReport, + InferredJoin, + JoinType, + TableType, +) +from superset.utils.core import QueryStatus + +logger = logging.getLogger(__name__) + + +class AnalyzeDatabaseSchemaCommand(BaseCommand): + """Command to analyze database schema and generate metadata""" + + def __init__(self, report_id: int): + self.report_id = report_id + self.report: DatabaseSchemaReport | None = None + self.database: Database | None = None + self.llm_service = LLMService() + + def run(self) -> dict[str, Any]: + """Execute the analysis""" + self.validate() + + # Extract schema information + tables_data = self._extract_schema_info() + + # Store basic metadata + self._store_tables_and_columns(tables_data) + + # Augment with AI descriptions (parallel processing) + self._augment_with_ai_descriptions() + + # Infer joins using AI + self._infer_joins_with_ai() + + return { + "tables_count": len(self.report.tables), + "joins_count": len(self.report.joins), + } + + def validate(self) -> None: + """Validate the command can be executed""" + self.report = db.session.query(DatabaseSchemaReport).get(self.report_id) + if not self.report: + raise ValueError(f"Report with id {self.report_id} not found") + + self.database = self.report.database + if not self.database: + raise ValueError(f"Database with id {self.report.database_id} not found") + + def _extract_schema_info(self) -> list[dict[str, Any]]: + """Extract schema information from the database""" + 
logger.info( + "Extracting schema info for database %s schema %s", + self.database.id, + self.report.schema_name, + ) + + tables_data = [] + + with self.database.get_sqla_engine() as engine: + inspector = inspect(engine) + metadata = MetaData() + metadata.reflect(engine, schema=self.report.schema_name) + + # Get all tables and views + table_names = inspector.get_table_names(schema=self.report.schema_name) + view_names = inspector.get_view_names(schema=self.report.schema_name) + + # Process tables + for table_name in table_names: + table_info = self._extract_table_info( + inspector, engine, table_name, TableType.TABLE + ) + tables_data.append(table_info) + + # Process views + for view_name in view_names: + view_info = self._extract_table_info( + inspector, engine, view_name, TableType.VIEW + ) + tables_data.append(view_info) + + return tables_data + + def _extract_table_info( + self, + inspector: Any, + engine: Any, + table_name: str, + table_type: TableType, + ) -> dict[str, Any]: + """Extract information for a single table/view""" + logger.debug("Extracting info for %s: %s", table_type.value, table_name) + + # Get columns + columns = inspector.get_columns(table_name, schema=self.report.schema_name) + + # Get primary keys + pk_constraint = inspector.get_pk_constraint(table_name, schema=self.report.schema_name) + primary_keys = pk_constraint["constrained_columns"] if pk_constraint else [] + + # Get foreign keys + foreign_keys = inspector.get_foreign_keys(table_name, schema=self.report.schema_name) + fk_columns = set() + for fk in foreign_keys: + fk_columns.update(fk["constrained_columns"]) + + # Get table comment + table_comment = None + try: + result = engine.execute( + text( + f"SELECT obj_description('{self.report.schema_name}.{table_name}'::regclass, 'pg_class')" + ) + ) + row = result.fetchone() + table_comment = row[0] if row else None + except Exception: + logger.debug("Could not fetch table comment for %s", table_name) + + # Get sample data (3 random rows) + 
sample_rows = [] + if table_type == TableType.TABLE: + try: + query = text( + f'SELECT * FROM "{self.report.schema_name}"."{table_name}" ' + f"TABLESAMPLE SYSTEM(1) LIMIT 3" + ) + result = engine.execute(query) + for row in result: + sample_rows.append(dict(row)) + except Exception: + # Fallback to regular LIMIT if TABLESAMPLE not supported + try: + query = text( + f'SELECT * FROM "{self.report.schema_name}"."{table_name}" LIMIT 3' + ) + result = engine.execute(query) + for row in result: + sample_rows.append(dict(row)) + except Exception: + logger.debug("Could not fetch sample data for %s", table_name) + + # Get row count estimate + row_count = None + try: + result = engine.execute( + text( + f"SELECT reltuples::BIGINT FROM pg_class " + f"WHERE oid = '{self.report.schema_name}.{table_name}'::regclass" + ) + ) + row = result.fetchone() + row_count = row[0] if row else None + except Exception: + logger.debug("Could not fetch row count for %s", table_name) + + # Process column information + columns_info = [] + for idx, col in enumerate(columns, start=1): + col_info = { + "name": col["name"], + "type": str(col["type"]), + "position": idx, + "nullable": col.get("nullable", True), + "is_primary_key": col["name"] in primary_keys, + "is_foreign_key": col["name"] in fk_columns, + "comment": col.get("comment"), + } + columns_info.append(col_info) + + return { + "name": table_name, + "type": table_type, + "comment": table_comment, + "columns": columns_info, + "sample_rows": sample_rows, + "row_count": row_count, + "foreign_keys": foreign_keys, + } + + def _store_tables_and_columns(self, tables_data: list[dict[str, Any]]) -> None: + """Store extracted table and column metadata""" + logger.info("Storing tables and columns metadata") + + for table_data in tables_data: + # Create table record + table = AnalyzedTable( + report_id=self.report_id, + table_name=table_data["name"], + table_type=table_data["type"], + db_comment=table_data["comment"], + extra_json=json.dumps({ + 
"row_count_estimate": table_data["row_count"], + "sample_rows": table_data["sample_rows"], + "foreign_keys": table_data["foreign_keys"], + }), + ) + db.session.add(table) + db.session.flush() # Get the table ID + + # Create column records + for col_data in table_data["columns"]: + column = AnalyzedColumn( + table_id=table.id, + column_name=col_data["name"], + data_type=col_data["type"], + ordinal_position=col_data["position"], + db_comment=col_data["comment"], + extra_json=json.dumps({ + "is_nullable": col_data["nullable"], + "is_primary_key": col_data["is_primary_key"], + "is_foreign_key": col_data["is_foreign_key"], + }), + ) + db.session.add(column) + + db.session.commit() + + def _augment_with_ai_descriptions(self) -> None: + """Use LLM to generate AI descriptions for tables and columns""" + logger.info("Generating AI descriptions for tables and columns") + + if not self.llm_service.is_available(): + logger.warning("LLM service not available, skipping AI augmentation") + return + + # Process tables in parallel + tables = self.report.tables + max_workers = min(10, len(tables)) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_table = { + executor.submit(self._augment_table_with_ai, table): table + for table in tables + } + + for future in as_completed(future_to_table): + table = future_to_table[future] + try: + future.result() + except Exception as e: + logger.error( + "Failed to generate AI description for table %s: %s", + table.table_name, + str(e), + ) + + def _augment_table_with_ai(self, table: AnalyzedTable) -> None: + """Generate AI descriptions for a single table and its columns""" + try: + # Prepare context for LLM + extra_json = json.loads(table.extra_json or "{}") + sample_rows = extra_json.get("sample_rows", []) + + columns_info = [] + for col in table.columns: + col_extra = json.loads(col.extra_json or "{}") + columns_info.append({ + "name": col.column_name, + "type": col.data_type, + "nullable": 
col_extra.get("is_nullable", True), + "is_pk": col_extra.get("is_primary_key", False), + "is_fk": col_extra.get("is_foreign_key", False), + "comment": col.db_comment, + }) + + # Generate descriptions + result = self.llm_service.generate_table_descriptions( + table_name=table.table_name, + table_comment=table.db_comment, + columns=columns_info, + sample_data=sample_rows, + ) + + # Update table description + table.ai_description = result.get("table_description") + + # Update column descriptions + col_descriptions = result.get("column_descriptions", {}) + for col in table.columns: + if col.column_name in col_descriptions: + col.ai_description = col_descriptions[col.column_name] + + db.session.commit() + + except Exception as e: + logger.error( + "Error generating AI descriptions for table %s: %s", + table.table_name, + str(e), + ) + db.session.rollback() + + def _infer_joins_with_ai(self) -> None: + """Use LLM to infer potential joins between tables""" + logger.info("Inferring joins between tables using AI") + + if not self.llm_service.is_available(): + logger.warning("LLM service not available, skipping join inference") + return + + tables = self.report.tables + if len(tables) < 2: + logger.info("Not enough tables to infer joins") + return + + # Prepare schema context + schema_context = [] + for table in tables: + table_info = { + "name": table.table_name, + "description": table.ai_description or table.db_comment, + "columns": [], + } + + for col in table.columns: + col_extra = json.loads(col.extra_json or "{}") + table_info["columns"].append({ + "name": col.column_name, + "type": col.data_type, + "description": col.ai_description or col.db_comment, + "is_pk": col_extra.get("is_primary_key", False), + "is_fk": col_extra.get("is_foreign_key", False), + }) + + schema_context.append(table_info) + + # Get existing foreign key relationships + existing_fks = self._get_existing_foreign_keys() + + # Use LLM to infer joins + try: + inferred_joins = 
self.llm_service.infer_joins( + schema_context=schema_context, + existing_foreign_keys=existing_fks, + ) + + # Store inferred joins + self._store_inferred_joins(inferred_joins) + + except Exception as e: + logger.error("Error inferring joins with AI: %s", str(e)) + + def _get_existing_foreign_keys(self) -> list[dict[str, Any]]: + """Get existing foreign key relationships from extracted metadata""" + existing_fks = [] + + for table in self.report.tables: + extra_json = json.loads(table.extra_json or "{}") + foreign_keys = extra_json.get("foreign_keys", []) + + for fk in foreign_keys: + existing_fks.append({ + "source_table": table.table_name, + "source_columns": fk["constrained_columns"], + "target_table": fk["referred_table"], + "target_columns": fk["referred_columns"], + }) + + return existing_fks + + def _store_inferred_joins(self, inferred_joins: list[dict[str, Any]]) -> None: + """Store the inferred joins in the database""" + logger.info("Storing %d inferred joins", len(inferred_joins)) + + # Create lookup for table IDs + table_lookup = { + table.table_name: table.id + for table in self.report.tables + } + + for join_data in inferred_joins: + source_table_id = table_lookup.get(join_data["source_table"]) + target_table_id = table_lookup.get(join_data["target_table"]) + + if not source_table_id or not target_table_id: + logger.warning( + "Skipping join %s -> %s: table not found", + join_data.get("source_table"), + join_data.get("target_table"), + ) + continue + + join = InferredJoin( + report_id=self.report_id, + source_table_id=source_table_id, + target_table_id=target_table_id, + source_columns=json.dumps(join_data["source_columns"]), + target_columns=json.dumps(join_data["target_columns"]), + join_type=JoinType(join_data.get("join_type", "inner")), + cardinality=Cardinality(join_data.get("cardinality", "N:1")), + semantic_context=join_data.get("semantic_context"), + extra_json=json.dumps({ + "confidence_score": join_data.get("confidence_score", 0.5), + 
"suggested_by": join_data.get("suggested_by", "ai_inference"), + }), + ) + db.session.add(join) + + db.session.commit() \ No newline at end of file diff --git a/superset/commands/database_analyzer/llm_service.py b/superset/commands/database_analyzer/llm_service.py new file mode 100644 index 0000000000..66a4dfd8eb --- /dev/null +++ b/superset/commands/database_analyzer/llm_service.py @@ -0,0 +1,277 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json +import logging +from typing import Any + +from flask import current_app + +logger = logging.getLogger(__name__) + + +class LLMService: + """Service for LLM integration to generate descriptions and infer joins""" + + def __init__(self): + self.api_key = current_app.config.get("LLM_API_KEY") + self.model = current_app.config.get("LLM_MODEL", "gpt-4o") + self.temperature = current_app.config.get("LLM_TEMPERATURE", 0.3) + self.max_tokens = current_app.config.get("LLM_MAX_TOKENS", 4096) + self.base_url = current_app.config.get("LLM_BASE_URL", "https://api.openai.com/v1") + + def is_available(self) -> bool: + """Check if LLM service is configured and available""" + return bool(self.api_key) + + def generate_table_descriptions( + self, + table_name: str, + table_comment: str | None, + columns: list[dict[str, Any]], + sample_data: list[dict[str, Any]], + ) -> dict[str, Any]: + """ + Generate AI descriptions for a table and its columns. + + :param table_name: Name of the table + :param table_comment: Existing table comment + :param columns: List of column information + :param sample_data: Sample rows from the table + :return: Dict with table_description and column_descriptions + """ + if not self.is_available(): + return {"table_description": None, "column_descriptions": {}} + + prompt = self._build_table_description_prompt( + table_name, table_comment, columns, sample_data + ) + + try: + response = self._call_llm(prompt) + return self._parse_table_description_response(response) + except Exception as e: + logger.error("Error calling LLM for table descriptions: %s", str(e)) + return {"table_description": None, "column_descriptions": {}} + + def infer_joins( + self, + schema_context: list[dict[str, Any]], + existing_foreign_keys: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + """ + Infer potential joins between tables using AI. 
+ + :param schema_context: List of tables with their columns and descriptions + :param existing_foreign_keys: Already known foreign key relationships + :return: List of inferred joins + """ + if not self.is_available(): + return [] + + prompt = self._build_join_inference_prompt(schema_context, existing_foreign_keys) + + try: + response = self._call_llm(prompt) + return self._parse_join_inference_response(response) + except Exception as e: + logger.error("Error calling LLM for join inference: %s", str(e)) + return [] + + def _build_table_description_prompt( + self, + table_name: str, + table_comment: str | None, + columns: list[dict[str, Any]], + sample_data: list[dict[str, Any]], + ) -> str: + """Build prompt for generating table descriptions""" + prompt = f"""You are a database documentation expert. Generate brief but informative descriptions for the following database table and its columns. + +Table Name: {table_name} +Existing Comment: {table_comment or "None"} + +Columns: +""" + for col in columns: + prompt += f"- {col['name']} ({col['type']})" + if col.get("is_pk"): + prompt += " [PRIMARY KEY]" + if col.get("is_fk"): + prompt += " [FOREIGN KEY]" + if col.get("comment"): + prompt += f" - {col['comment']}" + prompt += "\n" + + if sample_data: + prompt += f"\nSample Data (3 rows):\n{json.dumps(sample_data[:3], indent=2)}\n" + + prompt += """ +Based on the table name, column names, types, and sample data, provide: +1. A brief description of what this table represents (2-3 sentences) +2. Brief descriptions for each column explaining its purpose + +Return the response as JSON in this format: +{ + "table_description": "Description of the table", + "column_descriptions": { + "column_name": "Description of this column", + ... 
+ } +} +""" + return prompt + + def _build_join_inference_prompt( + self, + schema_context: list[dict[str, Any]], + existing_foreign_keys: list[dict[str, Any]], + ) -> str: + """Build prompt for inferring joins""" + prompt = """You are a database architect expert. Analyze the following database schema and identify potential join relationships between tables. + +Schema Information: +""" + for table in schema_context: + prompt += f"\nTable: {table['name']}\n" + if table.get("description"): + prompt += f"Description: {table['description']}\n" + prompt += "Columns:\n" + for col in table["columns"]: + prompt += f" - {col['name']} ({col['type']})" + if col.get("is_pk"): + prompt += " [PK]" + if col.get("is_fk"): + prompt += " [FK]" + if col.get("description"): + prompt += f" - {col['description']}" + prompt += "\n" + + if existing_foreign_keys: + prompt += "\nExisting Foreign Keys:\n" + for fk in existing_foreign_keys: + prompt += f"- {fk['source_table']}.{fk['source_columns']} -> {fk['target_table']}.{fk['target_columns']}\n" + + prompt += """ +Identify potential join relationships based on: +1. Column name patterns (e.g., user_id, customer_id) +2. Data type compatibility +3. Semantic relationships +4. Common database patterns + +For each join, provide: +- source_table and source_columns +- target_table and target_columns +- join_type (inner, left, right, full) +- cardinality (1:1, 1:N, N:1, N:M) +- semantic_context explaining the relationship +- confidence_score (0.0 to 1.0) + +Return ONLY joins not already covered by existing foreign keys. +Focus on the most likely and useful joins. + +Return the response as JSON array: +[ + { + "source_table": "table1", + "source_columns": ["col1"], + "target_table": "table2", + "target_columns": ["col2"], + "join_type": "inner", + "cardinality": "N:1", + "semantic_context": "Explanation of the relationship", + "confidence_score": 0.85, + "suggested_by": "ai_inference" + }, + ... 
+] +""" + return prompt + + def _call_llm(self, prompt: str) -> str: + """Call the LLM API with the given prompt""" + # This is a placeholder implementation + # In production, this would call the actual LLM API (OpenAI, Anthropic, etc.) + + # For now, return mock response for testing + if "table descriptions" in prompt.lower(): + return json.dumps({ + "table_description": "This table stores data related to the schema", + "column_descriptions": {} + }) + else: + return json.dumps([]) + + def _parse_table_description_response(self, response: str) -> dict[str, Any]: + """Parse the LLM response for table descriptions""" + try: + result = json.loads(response) + return { + "table_description": result.get("table_description"), + "column_descriptions": result.get("column_descriptions", {}), + } + except json.JSONDecodeError: + logger.error("Failed to parse LLM response as JSON") + return {"table_description": None, "column_descriptions": {}} + + def _parse_join_inference_response(self, response: str) -> list[dict[str, Any]]: + """Parse the LLM response for join inference""" + try: + joins = json.loads(response) + if not isinstance(joins, list): + logger.error("LLM response is not a list") + return [] + + # Validate and clean up each join + valid_joins = [] + for join in joins: + if self._validate_join(join): + valid_joins.append(join) + + return valid_joins + except json.JSONDecodeError: + logger.error("Failed to parse LLM response as JSON") + return [] + + def _validate_join(self, join: dict[str, Any]) -> bool: + """Validate a join object has required fields""" + required_fields = [ + "source_table", + "source_columns", + "target_table", + "target_columns", + ] + + for field in required_fields: + if field not in join: + logger.warning("Join missing required field: %s", field) + return False + + # Ensure columns are lists + if not isinstance(join["source_columns"], list): + join["source_columns"] = [join["source_columns"]] + if not isinstance(join["target_columns"], list): 
+ join["target_columns"] = [join["target_columns"]] + + # Set defaults for optional fields + join.setdefault("join_type", "inner") + join.setdefault("cardinality", "N:1") + join.setdefault("confidence_score", 0.5) + join.setdefault("suggested_by", "ai_inference") + + return True \ No newline at end of file diff --git a/superset/config_llm.py b/superset/config_llm.py new file mode 100644 index 0000000000..43bf4364ec --- /dev/null +++ b/superset/config_llm.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Configuration settings for the LLM service used by the Database Analyzer. 
+ +To use the database analyzer with LLM capabilities, set these environment variables: +- SUPERSET_LLM_API_KEY: Your API key for the LLM provider +- SUPERSET_LLM_MODEL: Model name (default: gpt-4o) +- SUPERSET_LLM_BASE_URL: API base URL (default: https://api.openai.com/v1) +""" + +import os + +# LLM configuration for database analyzer +# Set to None to disable LLM features (will still extract schema without AI descriptions) +LLM_API_KEY = os.environ.get("SUPERSET_LLM_API_KEY") + +# LLM model to use (e.g., "gpt-4o", "gpt-4", "gpt-3.5-turbo", "claude-3-opus") +LLM_MODEL = os.environ.get("SUPERSET_LLM_MODEL", "gpt-4o") + +# LLM API base URL (change for different providers or self-hosted models) +LLM_BASE_URL = os.environ.get("SUPERSET_LLM_BASE_URL", "https://api.openai.com/v1") + +# Temperature for LLM responses (lower = more deterministic) +LLM_TEMPERATURE = float(os.environ.get("SUPERSET_LLM_TEMPERATURE", "0.3")) + +# Maximum tokens for LLM responses +LLM_MAX_TOKENS = int(os.environ.get("SUPERSET_LLM_MAX_TOKENS", "4096")) \ No newline at end of file diff --git a/superset/databases/analyzer_api.py b/superset/databases/analyzer_api.py new file mode 100644 index 0000000000..578fe4c761 --- /dev/null +++ b/superset/databases/analyzer_api.py @@ -0,0 +1,296 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from typing import Any + +from flask import current_app, request, Response +from flask_appbuilder.api import BaseApi, expose, protect, safe +from marshmallow import fields, Schema, ValidationError + +from superset.extensions import db +from superset.models.database_analyzer import DatabaseSchemaReport +from superset.tasks.database_analyzer import ( + kickstart_analysis, + check_analysis_status, +) +from superset.views.base_api import BaseSupersetApi + +logger = logging.getLogger(__name__) + + +class AnalyzeSchemaRequestSchema(Schema): + """Schema for analyze schema request""" + database_id = fields.Integer(required=True) + schema_name = fields.String(required=True, validate=lambda x: len(x) > 0) + + +class AnalyzeSchemaResponseSchema(Schema): + """Schema for analyze schema response""" + run_id = fields.String(required=True) + database_report_id = fields.Integer(required=True) + status = fields.String(required=True) + + +class CheckStatusResponseSchema(Schema): + """Schema for check status response""" + run_id = fields.String(required=True) + database_report_id = fields.Integer(allow_none=True) + status = fields.String(required=True) + database_id = fields.Integer(allow_none=True) + schema_name = fields.String(allow_none=True) + started_at = fields.DateTime(allow_none=True) + completed_at = fields.DateTime(allow_none=True) + failed_at = fields.DateTime(allow_none=True) + error_message = fields.String(allow_none=True) + tables_count = fields.Integer(allow_none=True) + joins_count = fields.Integer(allow_none=True) + + +class DatabaseAnalyzerApi(BaseSupersetApi): + """API endpoints for database schema analyzer""" + + route_base = "/api/v1/database_analyzer" + + openapi_spec_tag = "Database Analyzer" + openapi_spec_methods = { + "analyze_schema": { + "post": { + "description": "Start a new database schema 
analysis", + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": AnalyzeSchemaRequestSchema, + }, + }, + }, + "responses": { + "200": { + "description": "Analysis started successfully", + "content": { + "application/json": { + "schema": AnalyzeSchemaResponseSchema, + }, + }, + }, + "400": {"description": "Bad request"}, + "401": {"description": "Unauthorized"}, + "500": {"description": "Internal server error"}, + }, + }, + }, + "check_status": { + "get": { + "description": "Check the status of a running analysis", + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": True, + "schema": {"type": "string"}, + "description": "The run ID returned from analyze_schema", + }, + ], + "responses": { + "200": { + "description": "Status retrieved successfully", + "content": { + "application/json": { + "schema": CheckStatusResponseSchema, + }, + }, + }, + "404": {"description": "Analysis not found"}, + "500": {"description": "Internal server error"}, + }, + }, + }, + } + + @expose("/analyze", methods=("POST",)) + @protect() + @safe + def analyze_schema(self) -> Response: + """ + Start a new database schema analysis. 
@expose("/status/<string:run_id>", methods=("GET",))
@protect()
@safe
def check_status(self, run_id: str) -> Response:
    """
    Check the status of a running analysis.
    ---
    get:
      description: >-
        Poll the status of a database schema analysis job
      parameters:
      - in: path
        name: run_id
        required: true
        schema:
          type: string
        description: The run ID returned from analyze endpoint
      responses:
        200:
          description: Status retrieved
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/CheckStatusResponseSchema'
        404:
          description: Analysis not found
        500:
          description: Internal server error
    """
    try:
        result = check_analysis_status(run_id)

        # The helper signals an unknown run_id through a sentinel status
        # rather than raising, so translate it to a 404 here.
        if result["status"] == "not_found":
            return self.response_404(
                message=result.get("message", "Analysis not found")
            )

        return self.response_200(result)

    except Exception:  # pylint: disable=broad-except
        # The traceback goes to the server log only; the raw exception text
        # is not echoed to the client because it can contain internal details.
        logger.exception("Error checking analysis status")
        return self.response_500(message="Error checking analysis status")


@expose("/report/<int:report_id>", methods=("GET",))
@protect()
@safe
def get_report(self, report_id: int) -> Response:
    """
    Get the full analysis report.
    ---
    get:
      description: >-
        Retrieve the complete analysis report with tables, columns, and joins
      parameters:
      - in: path
        name: report_id
        required: true
        schema:
          type: integer
        description: The database_report_id
      responses:
        200:
          description: Report retrieved
        404:
          description: Report not found
        500:
          description: Internal server error
    """
    try:
        # NOTE(review): only the endpoint-level @protect() check is visible
        # here; confirm whether per-database access should also be enforced.
        report = db.session.query(DatabaseSchemaReport).get(report_id)

        if not report:
            return self.response_404(message="Report not found")

        # Tables with their columns; the AI description wins over the
        # database comment when both are present.
        tables = [
            {
                "id": table.id,
                "name": table.table_name,
                "type": table.table_type.value,
                "description": table.ai_description or table.db_comment,
                "columns": [
                    {
                        "id": column.id,
                        "name": column.column_name,
                        "type": column.data_type,
                        "position": column.ordinal_position,
                        "description": column.ai_description
                        or column.db_comment,
                    }
                    for column in table.columns
                ],
            }
            for table in report.tables
        ]

        # NOTE(review): source_columns / target_columns are returned exactly
        # as stored on the model; confirm whether clients expect decoded lists.
        joins = [
            {
                "id": join.id,
                "source_table": join.source_table.table_name,
                "source_columns": join.source_columns,
                "target_table": join.target_table.table_name,
                "target_columns": join.target_columns,
                "join_type": join.join_type.value,
                "cardinality": join.cardinality.value,
                "semantic_context": join.semantic_context,
            }
            for join in report.joins
        ]

        result = {
            "id": report.id,
            "database_id": report.database_id,
            "schema_name": report.schema_name,
            "status": report.status.value,
            "created_at": (
                report.created_on.isoformat() if report.created_on else None
            ),
            "tables": tables,
            "joins": joins,
        }

        return self.response_200(result)

    except Exception:  # pylint: disable=broad-except
        # Same policy as check_status: log the details, return a fixed message.
        logger.exception("Error retrieving report")
        return self.response_500(message="Error retrieving report")
a/superset/migrations/versions/2025-12-17_07-54_c95466b0_add_database_analyzer_models.py b/superset/migrations/versions/2025-12-17_07-54_c95466b0_add_database_analyzer_models.py new file mode 100644 index 0000000000..f09f61070e --- /dev/null +++ b/superset/migrations/versions/2025-12-17_07-54_c95466b0_add_database_analyzer_models.py @@ -0,0 +1,347 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Add database analyzer models + +Revision ID: c95466b0 +Revises: +Create Date: 2025-12-17 07:54:00.000000 + +""" + +# revision identifiers, used by Alembic. 
revision = "c95466b0"
# NOTE(review): down_revision is None, so this revision is not chained to any
# earlier migration -- confirm the intended parent revision before merging.
down_revision = None

import sqlalchemy as sa
from alembic import op
from sqlalchemy_utils import UUIDType


def upgrade():
    """Create the four database-analyzer tables and their indexes.

    Tables are created parent-first (report -> table -> column / join) so
    every foreign key target exists when it is referenced.
    """
    # Create database_schema_report table
    op.create_table(
        "database_schema_report",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("uuid", UUIDType(binary=True), nullable=False),
        sa.Column("database_id", sa.Integer(), nullable=False),
        sa.Column("schema_name", sa.String(256), nullable=False),
        sa.Column("celery_task_id", sa.String(256), nullable=True),
        sa.Column("status", sa.String(50), server_default="reserved", nullable=False),
        sa.Column("reserved_dttm", sa.DateTime(), nullable=True),
        sa.Column("start_dttm", sa.DateTime(), nullable=True),
        sa.Column("end_dttm", sa.DateTime(), nullable=True),
        sa.Column("error_message", sa.Text(), nullable=True),
        sa.Column("extra_json", sa.Text(), nullable=True),
        sa.Column("created_on", sa.DateTime(), nullable=True),
        sa.Column("changed_on", sa.DateTime(), nullable=True),
        sa.Column("created_by_fk", sa.Integer(), nullable=True),
        sa.Column("changed_by_fk", sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint("id", name="pk_database_schema_report"),
        sa.UniqueConstraint("uuid", name="uq_database_schema_report_uuid"),
        # At most one report row per (database, schema) pair.
        sa.UniqueConstraint(
            "database_id",
            "schema_name",
            name="uq_database_schema_report_database_schema",
        ),
        sa.ForeignKeyConstraint(
            ["database_id"],
            ["dbs.id"],
            name="fk_database_schema_report_database_id_dbs",
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["created_by_fk"],
            ["ab_user.id"],
            name="fk_database_schema_report_created_by_fk_ab_user",
        ),
        sa.ForeignKeyConstraint(
            ["changed_by_fk"],
            ["ab_user.id"],
            name="fk_database_schema_report_changed_by_fk_ab_user",
        ),
        # status is a plain string column restricted to these four values.
        sa.CheckConstraint(
            "status IN ('reserved', 'running', 'completed', 'failed')",
            name="ck_database_schema_report_status",
        ),
    )

    # Create indexes for database_schema_report
    op.create_index(
        "ix_database_schema_report_database_id",
        "database_schema_report",
        ["database_id"],
    )
    op.create_index(
        "ix_database_schema_report_status",
        "database_schema_report",
        ["status"],
    )
    op.create_index(
        "ix_database_schema_report_celery_task_id",
        "database_schema_report",
        ["celery_task_id"],
    )
    op.create_index(
        "ix_database_schema_report_database_schema",
        "database_schema_report",
        ["database_id", "schema_name"],
    )

    # Create analyzed_table table
    op.create_table(
        "analyzed_table",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("uuid", UUIDType(binary=True), nullable=False),
        sa.Column("report_id", sa.Integer(), nullable=False),
        sa.Column("table_name", sa.String(256), nullable=False),
        sa.Column("table_type", sa.String(50), nullable=False),
        sa.Column("db_comment", sa.Text(), nullable=True),
        sa.Column("ai_description", sa.Text(), nullable=True),
        sa.Column("extra_json", sa.Text(), nullable=True),
        sa.Column("created_on", sa.DateTime(), nullable=True),
        sa.Column("changed_on", sa.DateTime(), nullable=True),
        sa.Column("created_by_fk", sa.Integer(), nullable=True),
        sa.Column("changed_by_fk", sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint("id", name="pk_analyzed_table"),
        sa.UniqueConstraint("uuid", name="uq_analyzed_table_uuid"),
        # A table name may appear only once within a report.
        sa.UniqueConstraint(
            "report_id",
            "table_name",
            name="uq_analyzed_table_report_table",
        ),
        sa.ForeignKeyConstraint(
            ["report_id"],
            ["database_schema_report.id"],
            name="fk_analyzed_table_report_id_database_schema_report",
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["created_by_fk"],
            ["ab_user.id"],
            name="fk_analyzed_table_created_by_fk_ab_user",
        ),
        sa.ForeignKeyConstraint(
            ["changed_by_fk"],
            ["ab_user.id"],
            name="fk_analyzed_table_changed_by_fk_ab_user",
        ),
        sa.CheckConstraint(
            "table_type IN ('table', 'view', 'materialized_view')",
            name="ck_analyzed_table_table_type",
        ),
    )

    # Create indexes for analyzed_table
    op.create_index(
        "ix_analyzed_table_report_id",
        "analyzed_table",
        ["report_id"],
    )
    op.create_index(
        "ix_analyzed_table_table_type",
        "analyzed_table",
        ["table_type"],
    )
    op.create_index(
        "ix_analyzed_table_report_type",
        "analyzed_table",
        ["report_id", "table_type"],
    )

    # Create analyzed_column table
    op.create_table(
        "analyzed_column",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("uuid", UUIDType(binary=True), nullable=False),
        sa.Column("table_id", sa.Integer(), nullable=False),
        sa.Column("column_name", sa.String(256), nullable=False),
        sa.Column("data_type", sa.String(256), nullable=False),
        sa.Column("ordinal_position", sa.Integer(), nullable=False),
        sa.Column("db_comment", sa.Text(), nullable=True),
        sa.Column("ai_description", sa.Text(), nullable=True),
        sa.Column("extra_json", sa.Text(), nullable=True),
        sa.Column("created_on", sa.DateTime(), nullable=True),
        sa.Column("changed_on", sa.DateTime(), nullable=True),
        sa.Column("created_by_fk", sa.Integer(), nullable=True),
        sa.Column("changed_by_fk", sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint("id", name="pk_analyzed_column"),
        sa.UniqueConstraint("uuid", name="uq_analyzed_column_uuid"),
        # A column name may appear only once within an analyzed table.
        sa.UniqueConstraint(
            "table_id",
            "column_name",
            name="uq_analyzed_column_table_column",
        ),
        sa.ForeignKeyConstraint(
            ["table_id"],
            ["analyzed_table.id"],
            name="fk_analyzed_column_table_id_analyzed_table",
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["created_by_fk"],
            ["ab_user.id"],
            name="fk_analyzed_column_created_by_fk_ab_user",
        ),
        sa.ForeignKeyConstraint(
            ["changed_by_fk"],
            ["ab_user.id"],
            name="fk_analyzed_column_changed_by_fk_ab_user",
        ),
        # Ordinal positions are 1-based.
        sa.CheckConstraint(
            "ordinal_position >= 1",
            name="ck_analyzed_column_ordinal_position",
        ),
    )

    # Create indexes for analyzed_column
    op.create_index(
        "ix_analyzed_column_table_id",
        "analyzed_column",
        ["table_id"],
    )
    op.create_index(
        "ix_analyzed_column_data_type",
        "analyzed_column",
        ["data_type"],
    )
    op.create_index(
        "ix_analyzed_column_table_position",
        "analyzed_column",
        ["table_id", "ordinal_position"],
    )

    # Create inferred_join table
    op.create_table(
        "inferred_join",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("uuid", UUIDType(binary=True), nullable=False),
        sa.Column("report_id", sa.Integer(), nullable=False),
        sa.Column("source_table_id", sa.Integer(), nullable=False),
        sa.Column("target_table_id", sa.Integer(), nullable=False),
        sa.Column("source_columns", sa.Text(), nullable=False),
        sa.Column("target_columns", sa.Text(), nullable=False),
        sa.Column("join_type", sa.String(50), server_default="inner", nullable=False),
        sa.Column("cardinality", sa.String(50), nullable=False),
        sa.Column("semantic_context", sa.Text(), nullable=True),
        sa.Column("extra_json", sa.Text(), nullable=True),
        sa.Column("created_on", sa.DateTime(), nullable=True),
        sa.Column("changed_on", sa.DateTime(), nullable=True),
        sa.Column("created_by_fk", sa.Integer(), nullable=True),
        sa.Column("changed_by_fk", sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint("id", name="pk_inferred_join"),
        sa.UniqueConstraint("uuid", name="uq_inferred_join_uuid"),
        # No uniqueness on (source, target): several joins per pair may exist.
        sa.ForeignKeyConstraint(
            ["report_id"],
            ["database_schema_report.id"],
            name="fk_inferred_join_report_id_database_schema_report",
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["source_table_id"],
            ["analyzed_table.id"],
            name="fk_inferred_join_source_table_id_analyzed_table",
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["target_table_id"],
            ["analyzed_table.id"],
            name="fk_inferred_join_target_table_id_analyzed_table",
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["created_by_fk"],
            ["ab_user.id"],
            name="fk_inferred_join_created_by_fk_ab_user",
        ),
        sa.ForeignKeyConstraint(
            ["changed_by_fk"],
            ["ab_user.id"],
            name="fk_inferred_join_changed_by_fk_ab_user",
        ),
        sa.CheckConstraint(
            "join_type IN ('inner', 'left', 'right', 'full', 'cross')",
            name="ck_inferred_join_join_type",
        ),
        sa.CheckConstraint(
            "cardinality IN ('1:1', '1:N', 'N:1', 'N:M')",
            name="ck_inferred_join_cardinality",
        ),
    )

    # Create indexes for inferred_join
    op.create_index(
        "ix_inferred_join_report_id",
        "inferred_join",
        ["report_id"],
    )
    op.create_index(
        "ix_inferred_join_source_table_id",
        "inferred_join",
        ["source_table_id"],
    )
    op.create_index(
        "ix_inferred_join_target_table_id",
        "inferred_join",
        ["target_table_id"],
    )
    op.create_index(
        "ix_inferred_join_source_target",
        "inferred_join",
        ["source_table_id", "target_table_id"],
    )
    op.create_index(
        "ix_inferred_join_join_type",
        "inferred_join",
        ["join_type"],
    )


def downgrade():
    """Drop the analyzer indexes and tables in reverse creation order.

    Children are removed before their parents so no foreign key is left
    pointing at a dropped table.
    """
    # Drop indexes for inferred_join
    op.drop_index("ix_inferred_join_join_type", table_name="inferred_join")
    op.drop_index("ix_inferred_join_source_target", table_name="inferred_join")
    op.drop_index("ix_inferred_join_target_table_id", table_name="inferred_join")
    op.drop_index("ix_inferred_join_source_table_id", table_name="inferred_join")
    op.drop_index("ix_inferred_join_report_id", table_name="inferred_join")

    # Drop inferred_join table
    op.drop_table("inferred_join")

    # Drop indexes for analyzed_column
    op.drop_index("ix_analyzed_column_table_position", table_name="analyzed_column")
    op.drop_index("ix_analyzed_column_data_type", table_name="analyzed_column")
    op.drop_index("ix_analyzed_column_table_id", table_name="analyzed_column")

    # Drop analyzed_column table
    op.drop_table("analyzed_column")

    # Drop indexes for analyzed_table
    op.drop_index("ix_analyzed_table_report_type", table_name="analyzed_table")
    op.drop_index("ix_analyzed_table_table_type", table_name="analyzed_table")
    op.drop_index("ix_analyzed_table_report_id", table_name="analyzed_table")

    # Drop analyzed_table table
    op.drop_table("analyzed_table")

    # Drop indexes for database_schema_report
    op.drop_index("ix_database_schema_report_database_schema", table_name="database_schema_report")
    op.drop_index("ix_database_schema_report_celery_task_id", table_name="database_schema_report")
    op.drop_index("ix_database_schema_report_status", table_name="database_schema_report")
    op.drop_index("ix_database_schema_report_database_id", table_name="database_schema_report")

    # Drop database_schema_report table
    op.drop_table("database_schema_report")
+from __future__ import annotations + +import enum +from datetime import datetime +from typing import TYPE_CHECKING + +import sqlalchemy as sa +from flask_appbuilder import Model +from sqlalchemy.orm import relationship + +from superset.models.helpers import AuditMixinNullable, UUIDMixin + +if TYPE_CHECKING: + from superset.models.core import Database + + +class AnalysisStatus(str, enum.Enum): + RESERVED = "reserved" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + +class DatabaseSchemaReport(Model, AuditMixinNullable, UUIDMixin): + """Tracks database schema analysis runs. ONE active report per database+schema.""" + + __tablename__ = "database_schema_report" + + id = sa.Column(sa.Integer, primary_key=True) + database_id = sa.Column( + sa.Integer, sa.ForeignKey("dbs.id", ondelete="CASCADE"), nullable=False + ) + schema_name = sa.Column(sa.String(256), nullable=False) + celery_task_id = sa.Column(sa.String(256), nullable=True) + status = sa.Column( + sa.Enum(AnalysisStatus), + default=AnalysisStatus.RESERVED, + nullable=False, + ) + reserved_dttm = sa.Column(sa.DateTime, nullable=True) + start_dttm = sa.Column(sa.DateTime, nullable=True) + end_dttm = sa.Column(sa.DateTime, nullable=True) + error_message = sa.Column(sa.Text, nullable=True) + extra_json = sa.Column(sa.Text, nullable=True) + + # Relationships + database = relationship("Database", backref="schema_reports") + tables = relationship( + "AnalyzedTable", + back_populates="report", + cascade="all, delete-orphan", + ) + joins = relationship( + "InferredJoin", + back_populates="report", + cascade="all, delete-orphan", + ) + + __table_args__ = ( + sa.UniqueConstraint("database_id", "schema_name", name="uq_database_schema_report_database_schema"), + sa.CheckConstraint( + "status IN ('reserved', 'running', 'completed', 'failed')", + name="ck_database_schema_report_status" + ), + ) + + +class TableType(str, enum.Enum): + TABLE = "table" + VIEW = "view" + MATERIALIZED_VIEW = "materialized_view" 
class AnalyzedTable(Model, AuditMixinNullable, UUIDMixin):
    """Stores metadata for each table/view discovered during schema analysis."""

    __tablename__ = "analyzed_table"

    id = sa.Column(sa.Integer, primary_key=True)
    report_id = sa.Column(
        sa.Integer,
        sa.ForeignKey("database_schema_report.id", ondelete="CASCADE"),
        nullable=False,
    )
    table_name = sa.Column(sa.String(256), nullable=False)
    # Persist the enum *values* ("materialized_view"), not the member names,
    # so rows satisfy the CHECK constraint below. native_enum=False keeps this
    # a plain VARCHAR column.
    table_type = sa.Column(
        sa.Enum(
            TableType,
            values_callable=lambda enum_cls: [member.value for member in enum_cls],
            native_enum=False,
            length=50,
        ),
        nullable=False,
    )
    db_comment = sa.Column(sa.Text, nullable=True)
    ai_description = sa.Column(sa.Text, nullable=True)
    extra_json = sa.Column(sa.Text, nullable=True)

    # Relationships
    report = relationship("DatabaseSchemaReport", back_populates="tables")
    columns = relationship(
        "AnalyzedColumn",
        back_populates="table",
        cascade="all, delete-orphan",
    )
    # InferredJoin has two foreign keys to this table, so each side names the
    # key it follows explicitly.
    source_joins = relationship(
        "InferredJoin",
        back_populates="source_table",
        foreign_keys="InferredJoin.source_table_id",
        cascade="all, delete-orphan",
    )
    target_joins = relationship(
        "InferredJoin",
        back_populates="target_table",
        foreign_keys="InferredJoin.target_table_id",
        cascade="all, delete-orphan",
    )

    __table_args__ = (
        sa.UniqueConstraint(
            "report_id", "table_name", name="uq_analyzed_table_report_table"
        ),
        sa.CheckConstraint(
            "table_type IN ('table', 'view', 'materialized_view')",
            name="ck_analyzed_table_table_type",
        ),
    )


class AnalyzedColumn(Model, AuditMixinNullable, UUIDMixin):
    """Stores metadata for each column discovered during schema analysis."""

    __tablename__ = "analyzed_column"

    id = sa.Column(sa.Integer, primary_key=True)
    table_id = sa.Column(
        sa.Integer,
        sa.ForeignKey("analyzed_table.id", ondelete="CASCADE"),
        nullable=False,
    )
    column_name = sa.Column(sa.String(256), nullable=False)
    data_type = sa.Column(sa.String(256), nullable=False)
    ordinal_position = sa.Column(sa.Integer, nullable=False)
    db_comment = sa.Column(sa.Text, nullable=True)
    ai_description = sa.Column(sa.Text, nullable=True)
    extra_json = sa.Column(sa.Text, nullable=True)

    # Relationships
    table = relationship("AnalyzedTable", back_populates="columns")

    __table_args__ = (
        sa.UniqueConstraint(
            "table_id", "column_name", name="uq_analyzed_column_table_column"
        ),
        sa.CheckConstraint(
            "ordinal_position >= 1", name="ck_analyzed_column_ordinal_position"
        ),
    )


class JoinType(str, enum.Enum):
    """SQL join kinds an inferred join may use."""

    INNER = "inner"
    LEFT = "left"
    RIGHT = "right"
    FULL = "full"
    CROSS = "cross"


class Cardinality(str, enum.Enum):
    """Relationship cardinality between source and target table."""

    ONE_TO_ONE = "1:1"
    ONE_TO_MANY = "1:N"
    MANY_TO_ONE = "N:1"
    MANY_TO_MANY = "N:M"


class InferredJoin(Model, AuditMixinNullable, UUIDMixin):
    """Stores ALL AI-inferred joins. Multiple joins per table pair allowed."""

    __tablename__ = "inferred_join"

    id = sa.Column(sa.Integer, primary_key=True)
    report_id = sa.Column(
        sa.Integer,
        sa.ForeignKey("database_schema_report.id", ondelete="CASCADE"),
        nullable=False,
    )
    source_table_id = sa.Column(
        sa.Integer,
        sa.ForeignKey("analyzed_table.id", ondelete="CASCADE"),
        nullable=False,
    )
    target_table_id = sa.Column(
        sa.Integer,
        sa.ForeignKey("analyzed_table.id", ondelete="CASCADE"),
        nullable=False,
    )
    source_columns = sa.Column(sa.Text, nullable=False)  # JSON array
    target_columns = sa.Column(sa.Text, nullable=False)  # JSON array
    # Store enum values ("inner", "1:N"), not member names ("INNER",
    # "ONE_TO_MANY"); the names would be rejected by the CHECK constraints
    # declared in __table_args__.
    join_type = sa.Column(
        sa.Enum(
            JoinType,
            values_callable=lambda enum_cls: [member.value for member in enum_cls],
            native_enum=False,
            length=50,
        ),
        default=JoinType.INNER,
        nullable=False,
    )
    cardinality = sa.Column(
        sa.Enum(
            Cardinality,
            values_callable=lambda enum_cls: [member.value for member in enum_cls],
            native_enum=False,
            length=50,
        ),
        nullable=False,
    )
    semantic_context = sa.Column(sa.Text, nullable=True)
    extra_json = sa.Column(sa.Text, nullable=True)

    # Relationships
    report = relationship("DatabaseSchemaReport", back_populates="joins")
    source_table = relationship(
        "AnalyzedTable",
        back_populates="source_joins",
        foreign_keys=[source_table_id],
    )
    target_table = relationship(
        "AnalyzedTable",
        back_populates="target_joins",
        foreign_keys=[target_table_id],
    )

    __table_args__ = (
        sa.CheckConstraint(
            "join_type IN ('inner', 'left', 'right', 'full', 'cross')",
            name="ck_inferred_join_join_type",
        ),
        sa.CheckConstraint(
            "cardinality IN ('1:1', '1:N', 'N:1', 'N:M')",
            name="ck_inferred_join_cardinality",
        ),
    )
+from __future__ import annotations + +import logging +import uuid +from datetime import datetime +from typing import Any + +from celery import Task +from celery.exceptions import SoftTimeLimitExceeded +from flask import current_app + +from superset import db +from superset.extensions import celery_app +from superset.models.database_analyzer import ( + AnalysisStatus, + DatabaseSchemaReport, +) + +logger = logging.getLogger(__name__) + + +class DatabaseAnalyzerTask(Task): + """Base task class for database analyzer with retry logic""" + + autoretry_for = (Exception,) + retry_kwargs = {"max_retries": 3} + retry_backoff = True + retry_backoff_max = 600 + retry_jitter = True + + +@celery_app.task( + name="analyze_database_schema", + base=DatabaseAnalyzerTask, + soft_time_limit=3600, # 1 hour soft limit + time_limit=3900, # 1 hour 5 min hard limit +) +def analyze_database_schema(report_id: int) -> dict[str, Any]: + """ + Celery task to analyze database schema and generate metadata. + + :param report_id: ID of the DatabaseSchemaReport to process + :return: Dict with status and results + """ + from superset.commands.database_analyzer.analyze import ( + AnalyzeDatabaseSchemaCommand, + ) + + logger.info("Starting database schema analysis for report_id: %s", report_id) + + try: + # Update status to running + report = db.session.query(DatabaseSchemaReport).get(report_id) + if not report: + logger.error("Report with id %s not found", report_id) + return {"status": "error", "message": f"Report {report_id} not found"} + + report.status = AnalysisStatus.RUNNING + report.start_dttm = datetime.now() + db.session.commit() + + # Execute the analysis command + command = AnalyzeDatabaseSchemaCommand(report_id) + result = command.run() + + # Update status to completed + report.status = AnalysisStatus.COMPLETED + report.end_dttm = datetime.now() + db.session.commit() + + logger.info("Successfully completed analysis for report_id: %s", report_id) + return { + "status": "completed", + 
"database_report_id": report_id, + "tables_analyzed": result.get("tables_count", 0), + "joins_inferred": result.get("joins_count", 0), + } + + except SoftTimeLimitExceeded: + logger.error("Task timed out for report_id: %s", report_id) + _mark_report_failed(report_id, "Task timed out after 1 hour") + return {"status": "error", "message": "Task timed out"} + + except Exception as e: + logger.exception("Error analyzing database schema for report_id: %s", report_id) + _mark_report_failed(report_id, str(e)) + raise + + +def _mark_report_failed(report_id: int, error_message: str) -> None: + """Mark a report as failed with error message""" + try: + report = db.session.query(DatabaseSchemaReport).get(report_id) + if report: + report.status = AnalysisStatus.FAILED + report.end_dttm = datetime.now() + report.error_message = error_message + db.session.commit() + except Exception: + logger.exception("Failed to update report status to failed") + + +def kickstart_analysis(database_id: int, schema_name: str) -> dict[str, Any]: + """ + Kickstart a new database schema analysis or return existing run_id. 
+ + :param database_id: ID of the database to analyze + :param schema_name: Name of the schema to analyze + :return: Dict with run_id and database_report_id + """ + # Check for existing report + existing = ( + db.session.query(DatabaseSchemaReport) + .filter_by(database_id=database_id, schema_name=schema_name) + .first() + ) + + if existing and existing.status in (AnalysisStatus.RESERVED, AnalysisStatus.RUNNING): + # Job already in progress - return existing run_id + logger.info( + "Analysis already in progress for database %s schema %s", + database_id, + schema_name, + ) + return { + "run_id": existing.celery_task_id, + "database_report_id": existing.id, + "status": existing.status.value, + } + + if existing and existing.status in (AnalysisStatus.COMPLETED, AnalysisStatus.FAILED): + # Delete old report (cascades to all related data) + logger.info( + "Deleting old report for database %s schema %s", + database_id, + schema_name, + ) + db.session.delete(existing) + db.session.flush() + + # Create new report + task_id = str(uuid.uuid4()) + report = DatabaseSchemaReport( + database_id=database_id, + schema_name=schema_name, + celery_task_id=task_id, + status=AnalysisStatus.RESERVED, + reserved_dttm=datetime.now(), + ) + db.session.add(report) + db.session.commit() + + # Trigger Celery job + analyze_database_schema.apply_async(args=[report.id], task_id=task_id) + + logger.info( + "Started new analysis for database %s schema %s with run_id %s", + database_id, + schema_name, + task_id, + ) + + return { + "run_id": task_id, + "database_report_id": report.id, + "status": "reserved", + } + + +def check_analysis_status(run_id: str) -> dict[str, Any]: + """ + Check the status of a running analysis by run_id. 
+ + :param run_id: The Celery task ID (run_id) + :return: Dict with status and results + """ + report = ( + db.session.query(DatabaseSchemaReport) + .filter_by(celery_task_id=run_id) + .first() + ) + + if not report: + return {"status": "not_found", "message": f"No analysis found for run_id {run_id}"} + + result = { + "run_id": run_id, + "database_report_id": report.id, + "status": report.status.value, + "database_id": report.database_id, + "schema_name": report.schema_name, + } + + if report.status == AnalysisStatus.RUNNING: + result["started_at"] = report.start_dttm.isoformat() if report.start_dttm else None + + elif report.status == AnalysisStatus.COMPLETED: + result["started_at"] = report.start_dttm.isoformat() if report.start_dttm else None + result["completed_at"] = report.end_dttm.isoformat() if report.end_dttm else None + result["tables_count"] = len(report.tables) + result["joins_count"] = len(report.joins) + + elif report.status == AnalysisStatus.FAILED: + result["error_message"] = report.error_message + result["failed_at"] = report.end_dttm.isoformat() if report.end_dttm else None + + return result \ No newline at end of file diff --git a/tests/unit_tests/commands/test_database_analyzer.py b/tests/unit_tests/commands/test_database_analyzer.py new file mode 100644 index 0000000000..4a6d8aeb5a --- /dev/null +++ b/tests/unit_tests/commands/test_database_analyzer.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
from unittest.mock import MagicMock, patch

import pytest

from superset.models.database_analyzer import (
    AnalysisStatus,
    DatabaseSchemaReport,
)
from superset.tasks.database_analyzer import (
    kickstart_analysis,
    check_analysis_status,
)


def _stub_lookup(mock_db, found):
    """Make the patched ``db.session.query().filter_by().first()`` return *found*."""
    mock_db.session.query().filter_by().first.return_value = found


def _timestamp(iso_text):
    """Build a stand-in datetime whose ``isoformat()`` yields *iso_text*."""
    stamp = MagicMock()
    stamp.isoformat.return_value = iso_text
    return stamp


def _report(status, **attrs):
    """Build a mock report in *status* with any extra attributes set."""
    fake = MagicMock()
    fake.status = status
    for attr_name, attr_value in attrs.items():
        setattr(fake, attr_name, attr_value)
    return fake


@patch("superset.tasks.database_analyzer.db")
@patch("superset.tasks.database_analyzer.analyze_database_schema")
def test_kickstart_analysis_new(mock_task, mock_db):
    """A first-time request creates a reserved run and queues the task."""
    _stub_lookup(mock_db, None)
    mock_task.apply_async = MagicMock()

    outcome = kickstart_analysis(database_id=1, schema_name="public")

    assert "run_id" in outcome
    assert "database_report_id" in outcome
    assert outcome["status"] == "reserved"

    # The Celery task must have been dispatched exactly once
    mock_task.apply_async.assert_called_once()


@patch("superset.tasks.database_analyzer.db")
def test_kickstart_analysis_existing_running(mock_db):
    """A request while a run is active returns that run unchanged."""
    in_flight = _report(
        AnalysisStatus.RUNNING,
        celery_task_id="existing-task-id",
        id=123,
    )
    _stub_lookup(mock_db, in_flight)

    outcome = kickstart_analysis(database_id=1, schema_name="public")

    assert outcome["run_id"] == "existing-task-id"
    assert outcome["database_report_id"] == 123
    assert outcome["status"] == "running"


@patch("superset.tasks.database_analyzer.db")
@patch("superset.tasks.database_analyzer.analyze_database_schema")
def test_kickstart_analysis_replace_completed(mock_task, mock_db):
    """A finished report is deleted and a new run is started."""
    finished = _report(AnalysisStatus.COMPLETED)
    _stub_lookup(mock_db, finished)
    mock_task.apply_async = MagicMock()

    outcome = kickstart_analysis(database_id=1, schema_name="public")

    # The stale report is removed ...
    mock_db.session.delete.assert_called_once_with(finished)

    # ... and a fresh analysis is queued
    assert "run_id" in outcome
    assert outcome["status"] == "reserved"
    mock_task.apply_async.assert_called_once()


@patch("superset.tasks.database_analyzer.db")
def test_check_analysis_status_not_found(mock_db):
    """An unknown run_id yields the not_found sentinel with a message."""
    _stub_lookup(mock_db, None)

    outcome = check_analysis_status("unknown-id")

    assert outcome["status"] == "not_found"
    assert "message" in outcome


@patch("superset.tasks.database_analyzer.db")
def test_check_analysis_status_running(mock_db):
    """A running report exposes its start timestamp."""
    running = _report(
        AnalysisStatus.RUNNING,
        id=123,
        database_id=1,
        schema_name="public",
        start_dttm=_timestamp("2024-01-01T00:00:00"),
    )
    _stub_lookup(mock_db, running)

    outcome = check_analysis_status("test-run-id")

    assert outcome["status"] == "running"
    assert outcome["database_report_id"] == 123
    assert outcome["started_at"] == "2024-01-01T00:00:00"


@patch("superset.tasks.database_analyzer.db")
def test_check_analysis_status_completed(mock_db):
    """A completed report exposes timestamps plus table and join counts."""
    done = _report(
        AnalysisStatus.COMPLETED,
        id=123,
        database_id=1,
        schema_name="public",
        start_dttm=_timestamp("2024-01-01T00:00:00"),
        end_dttm=_timestamp("2024-01-01T01:00:00"),
        tables=[MagicMock(), MagicMock()],  # 2 tables
        joins=[MagicMock()],  # 1 join
    )
    _stub_lookup(mock_db, done)

    outcome = check_analysis_status("test-run-id")

    assert outcome["status"] == "completed"
    assert outcome["database_report_id"] == 123
    assert outcome["completed_at"] == "2024-01-01T01:00:00"
    assert outcome["tables_count"] == 2
    assert outcome["joins_count"] == 1


@patch("superset.tasks.database_analyzer.db")
def test_check_analysis_status_failed(mock_db):
    """A failed report exposes its error message and failure timestamp."""
    broken = _report(
        AnalysisStatus.FAILED,
        id=123,
        database_id=1,
        schema_name="public",
        error_message="Connection failed",
        end_dttm=_timestamp("2024-01-01T00:30:00"),
    )
    _stub_lookup(mock_db, broken)

    outcome = check_analysis_status("test-run-id")

    assert outcome["status"] == "failed"
    assert outcome["error_message"] == "Connection failed"
    assert outcome["failed_at"] == "2024-01-01T00:30:00"
