This is an automated email from the ASF dual-hosted git repository.
michaelsmolina pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new e8f5d7680f fix: upgrade_catalog_perms and downgrade_catalog_perms
implementation (#29860)
e8f5d7680f is described below
commit e8f5d7680ff14342b2ed46cc0b8c3bd4463fa3c2
Author: Michael S. Molina <[email protected]>
AuthorDate: Fri Aug 16 08:39:36 2024 -0400
fix: upgrade_catalog_perms and downgrade_catalog_perms implementation
(#29860)
---
superset/migrations/shared/catalogs.py | 360 ++++++++++++++++++++++++---------
1 file changed, 269 insertions(+), 91 deletions(-)
diff --git a/superset/migrations/shared/catalogs.py
b/superset/migrations/shared/catalogs.py
index b09c71739f..b75214291b 100644
--- a/superset/migrations/shared/catalogs.py
+++ b/superset/migrations/shared/catalogs.py
@@ -18,7 +18,8 @@
from __future__ import annotations
import logging
-from typing import Any, Type
+from datetime import datetime
+from typing import Any, Type, Union
import sqlalchemy as sa
from alembic import op
@@ -35,8 +36,7 @@ from superset.migrations.shared.security_converge import (
)
from superset.models.core import Database
-logger = logging.getLogger(__name__)
-
+logger = logging.getLogger("alembic")
Base: Type[Any] = declarative_base()
@@ -95,6 +95,16 @@ class Slice(Base):
schema_perm = sa.Column(sa.String(1000))
+ModelType = Union[Type[Query], Type[SavedQuery], Type[TabState],
Type[TableSchema]]
+
+MODELS: list[tuple[ModelType, str]] = [
+ (Query, "database_id"),
+ (SavedQuery, "db_id"),
+ (TabState, "database_id"),
+ (TableSchema, "database_id"),
+]
+
+
def get_known_schemas(database_name: str, session: Session) -> list[str]:
"""
Read all known schemas from the existing schema permissions.
@@ -112,6 +122,234 @@ def get_known_schemas(database_name: str, session:
Session) -> list[str]:
return sorted({name[0][1:-1].split("].[")[-1] for name in names})
+def get_batch_size(session: Session) -> int:
+ max_sqlite_in = 999
+ return max_sqlite_in if session.bind.dialect.name == "sqlite" else
1_000_000
+
+
+def print_processed_batch(
+ start_time: datetime,
+ offset: int,
+ total_rows: int,
+ model: ModelType,
+ batch_size: int,
+) -> None:
+ """
+ Print the progress of batch processing.
+
+ This function logs the progress of processing a batch of rows from a model.
+ It calculates the elapsed time since the start of the batch processing and
+ logs the number of rows processed along with the percentage completion.
+
+ Parameters:
+ start_time (datetime): The start time of the batch processing.
+ offset (int): The current offset in the batch processing.
+ total_rows (int): The total number of rows to process.
+ model (ModelType): The model being processed.
+ batch_size (int): The size of the batch being processed.
+ """
+ elapsed_time = datetime.now() - start_time
+ elapsed_seconds = elapsed_time.total_seconds()
+ elapsed_formatted = f"{int(elapsed_seconds //
3600):02}:{int((elapsed_seconds % 3600) // 60):02}:{int(elapsed_seconds %
60):02}"
+ rows_processed = min(offset + batch_size, total_rows)
+ logger.info(
+ f"{elapsed_formatted} - {rows_processed:,} of {total_rows:,}
{model.__tablename__} rows processed "
+ f"({(rows_processed / total_rows) * 100:.2f}%)"
+ )
+
+
+def update_catalog_column(
+ session: Session, database: Database, catalog: str, downgrade: bool = False
+) -> None:
+ """
+ Update the `catalog` column in the specified models to the given catalog.
+
+ This function iterates over a list of models defined by MODELS and updates
+ the `catalog` column to the specified catalog or None depending on the
downgrade
+ parameter. The update is performed in batches to optimize performance and
reduce
+ memory usage.
+
+ Parameters:
+ session (Session): The SQLAlchemy session to use for database
operations.
+ database (Database): The database instance containing the models to
update.
+ catalog (str): The new catalog value to set in the `catalog`
column or
+ the default catalog if `downgrade` is True.
+ downgrade (bool): If True, the `catalog` column is set to None where
the
+ catalog matches the specified catalog.
+ """
+ start_time = datetime.now()
+
+ logger.info(f"Updating {database.database_name} models to catalog
{catalog}")
+
+ for model, column in MODELS:
+ # Get the total number of rows that match the condition
+ total_rows = (
+ session.query(sa.func.count(model.id))
+ .filter(getattr(model, column) == database.id)
+ .filter(model.catalog == catalog if downgrade else True)
+ .scalar()
+ )
+
+ logger.info(
+ f"Total rows to be processed for {model.__tablename__}:
{total_rows:,}"
+ )
+
+ batch_size = get_batch_size(session)
+ limit_value = min(batch_size, total_rows)
+
+ # Update in batches using row numbers
+ for i in range(0, total_rows, batch_size):
+ subquery = (
+ session.query(model.id)
+ .filter(getattr(model, column) == database.id)
+ .filter(model.catalog == catalog if downgrade else True)
+ .order_by(model.id)
+ .offset(i)
+ .limit(limit_value)
+ .subquery()
+ )
+
+ # SQLite does not support multiple-table criteria within UPDATE
+ if session.bind.dialect.name == "sqlite":
+ ids_to_update = [row.id for row in
session.query(subquery.c.id).all()]
+ if ids_to_update:
+ session.execute(
+ sa.update(model)
+ .where(model.id.in_(ids_to_update))
+ .values(catalog=None if downgrade else catalog)
+ .execution_options(synchronize_session=False)
+ )
+ else:
+ session.execute(
+ sa.update(model)
+ .where(model.id == subquery.c.id)
+ .values(catalog=None if downgrade else catalog)
+ .execution_options(synchronize_session=False)
+ )
+
+ print_processed_batch(start_time, i, total_rows, model, batch_size)
+
+
+def update_schema_catalog_perms(
+ session: Session,
+ database: Database,
+ catalog_perm: str | None,
+ catalog: str,
+ downgrade: bool = False,
+) -> None:
+ """
+ Update schema and catalog permissions for tables and charts in a given
database.
+
+ This function updates the `catalog`, `catalog_perm`, and `schema_perm`
fields for
+ tables and charts associated with the specified database. If `downgrade`
is True,
+ the `catalog` and `catalog_perm` fields are set to None, otherwise they
are set
+ to the provided `catalog` and `catalog_perm` values.
+
+ Args:
+ session (Session): The SQLAlchemy session to use for database
operations.
+ database (Database): The database object whose tables and charts will
be updated.
+ catalog_perm (str | None): The new catalog permission to set.
+ catalog (str): The new catalog to set.
+ downgrade (bool, optional): If True, reset the `catalog` and
`catalog_perm` fields to None.
+ Defaults to False.
+ """
+ # Mapping of table id to schema permission
+ mapping = {}
+
+ for table in (
+ session.query(SqlaTable)
+ .filter_by(database_id=database.id)
+ .filter_by(catalog=catalog if downgrade else None)
+ ):
+ schema_perm = security_manager.get_schema_perm(
+ database.database_name,
+ None if downgrade else catalog,
+ table.schema,
+ )
+ table.catalog = None if downgrade else catalog
+ table.catalog_perm = catalog_perm
+ table.schema_perm = schema_perm
+ mapping[table.id] = schema_perm
+
+ # Select all slices of type table that belong to the database
+ for chart in (
+ session.query(Slice)
+ .join(SqlaTable, Slice.datasource_id == SqlaTable.id)
+ .join(Database, SqlaTable.database_id == Database.id)
+ .filter(Database.id == database.id)
+ .filter(Slice.datasource_type == "table")
+ ):
+ # We only care about tables that exist in the mapping
+ if mapping.get(chart.datasource_id) is not None:
+ chart.catalog_perm = catalog_perm
+ chart.schema_perm = mapping[chart.datasource_id]
+
+
+def delete_models_non_default_catalog(
+ session: Session, database: Database, catalog: str
+) -> None:
+ """
+ Delete models that are not in the default catalog.
+
+ This function iterates over a list of models defined by MODELS and deletes
+ the rows where the `catalog` column does not match the specified catalog.
+
+ Parameters:
+ session (Session): The SQLAlchemy session to use for database
operations.
+ database (Database): The database instance containing the models to
delete.
+ catalog (str): The catalog to use to filter the models to delete.
+ """
+ start_time = datetime.now()
+
+ logger.info(f"Deleting models not in the default catalog: {catalog}")
+
+ for model, column in MODELS:
+ # Get the total number of rows that match the condition
+ total_rows = (
+ session.query(sa.func.count(model.id))
+ .filter(getattr(model, column) == database.id)
+ .filter(model.catalog != catalog)
+ .scalar()
+ )
+
+ logger.info(
+ f"Total rows to be processed for {model.__tablename__}:
{total_rows:,}"
+ )
+
+ batch_size = get_batch_size(session)
+ limit_value = min(batch_size, total_rows)
+
+ # Delete in batches using row numbers
+ for i in range(0, total_rows, batch_size):
+ subquery = (
+ session.query(model.id)
+ .filter(getattr(model, column) == database.id)
+ .filter(model.catalog != catalog)
+ .order_by(model.id)
+ .offset(i)
+ .limit(limit_value)
+ .subquery()
+ )
+
+ # SQLite does not support multiple-table criteria within DELETE
+ if session.bind.dialect.name == "sqlite":
+ ids_to_delete = [row.id for row in
session.query(subquery.c.id).all()]
+ if ids_to_delete:
+ session.execute(
+ sa.delete(model)
+ .where(model.id.in_(ids_to_delete))
+ .execution_options(synchronize_session=False)
+ )
+ else:
+ session.execute(
+ sa.delete(model)
+ .where(model.id == subquery.c.id)
+ .execution_options(synchronize_session=False)
+ )
+
+ print_processed_batch(start_time, i, total_rows, model, batch_size)
+
+
def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
"""
Update models and permissions when catalogs are introduced in a DB engine
spec.
@@ -157,11 +395,13 @@ def upgrade_database_catalogs(
"""
Upgrade a given database to support the default catalog.
"""
- catalog_perm = security_manager.get_catalog_perm(
+ catalog_perm: str | None = security_manager.get_catalog_perm(
database.database_name,
default_catalog,
)
- pvms: dict[str, tuple[str, ...]] = {catalog_perm: ("catalog_access",)}
+ pvms: dict[str, tuple[str, ...]] = (
+ {catalog_perm: ("catalog_access",)} if catalog_perm else {}
+ )
# rename existing schema permissions to include the catalog, and also find
any new
# schemas
@@ -170,39 +410,10 @@ def upgrade_database_catalogs(
# update existing models that have a `catalog` column so it points to the
default
# catalog
- models = [
- (Query, "database_id"),
- (SavedQuery, "db_id"),
- (TabState, "database_id"),
- (TableSchema, "database_id"),
- ]
- for model, column in models:
- for instance in session.query(model).filter(
- getattr(model, column) == database.id
- ):
- instance.catalog = default_catalog
+ update_catalog_column(session, database, default_catalog, False)
# update `schema_perm` and `catalog_perm` for tables and charts
- for table in session.query(SqlaTable).filter_by(
- database_id=database.id,
- catalog=None,
- ):
- schema_perm = security_manager.get_schema_perm(
- database.database_name,
- default_catalog,
- table.schema,
- )
-
- table.catalog = default_catalog
- table.catalog_perm = catalog_perm
- table.schema_perm = schema_perm
-
- for chart in session.query(Slice).filter_by(
- datasource_id=table.id,
- datasource_type="table",
- ):
- chart.catalog_perm = catalog_perm
- chart.schema_perm = schema_perm
+ update_schema_catalog_perms(session, database, catalog_perm,
default_catalog, False)
# add any new catalogs discovered and their schemas
new_catalog_pvms = add_non_default_catalogs(database, default_catalog,
session)
@@ -233,13 +444,15 @@ def add_non_default_catalogs(
# edited.
return {}
- pvms = {}
+ pvms: dict[str, tuple[str]] = {}
for catalog in catalogs:
- perm = security_manager.get_catalog_perm(database.database_name,
catalog)
- pvms[perm] = ("catalog_access",)
-
- new_schema_pvms = create_schema_perms(database, catalog, session)
- pvms.update(new_schema_pvms)
+ perm: str | None = security_manager.get_catalog_perm(
+ database.database_name, catalog
+ )
+ if perm:
+ pvms[perm] = ("catalog_access",)
+ new_schema_pvms = create_schema_perms(database, catalog)
+ pvms.update(new_schema_pvms)
return pvms
@@ -266,12 +479,12 @@ def upgrade_schema_perms(
perms = {}
for schema in schemas:
- current_perm = security_manager.get_schema_perm(
+ current_perm: str | None = security_manager.get_schema_perm(
database.database_name,
None,
schema,
)
- new_perm = security_manager.get_schema_perm(
+ new_perm: str | None = security_manager.get_schema_perm(
database.database_name,
default_catalog,
schema,
@@ -283,7 +496,7 @@ def upgrade_schema_perms(
.one_or_none()
):
existing_pvm.name = new_perm
- else:
+ elif new_perm:
# new schema discovered, need to create a new permission
perms[new_perm] = ("schema_access",)
@@ -293,7 +506,6 @@ def upgrade_schema_perms(
def create_schema_perms(
database: Database,
catalog: str,
- session: Session,
) -> dict[str, tuple[str]]:
"""
Create schema permissions for a given catalog.
@@ -307,12 +519,14 @@ def create_schema_perms(
return {}
return {
- security_manager.get_schema_perm(
- database.database_name,
- catalog,
- schema,
- ): ("schema_access",)
+ perm: ("schema_access",)
for schema in schemas
+ if (
+ perm := security_manager.get_schema_perm(
+ database.database_name, catalog, schema
+ )
+ )
+ is not None
}
@@ -374,49 +588,13 @@ def downgrade_database_catalogs(
# permissions associated with other catalogs
downgrade_schema_perms(database, default_catalog, session)
- # update existing models
- models = [
- (Query, "database_id"),
- (SavedQuery, "db_id"),
- (TabState, "database_id"),
- (TableSchema, "database_id"),
- ]
- for model, column in models:
- for instance in session.query(model).filter(
- getattr(model, column) == database.id,
- model.catalog == default_catalog, # type: ignore
- ):
- instance.catalog = None
+ update_catalog_column(session, database, default_catalog, True)
- # update `schema_perm` for tables and charts
- for table in session.query(SqlaTable).filter_by(
- database_id=database.id,
- catalog=default_catalog,
- ):
- schema_perm = security_manager.get_schema_perm(
- database.database_name,
- None,
- table.schema,
- )
-
- table.catalog = None
- table.catalog_perm = None
- table.schema_perm = schema_perm
-
- for chart in session.query(Slice).filter_by(
- datasource_id=table.id,
- datasource_type="table",
- ):
- chart.catalog_perm = None
- chart.schema_perm = schema_perm
+ # update `schema_perm` and `catalog_perm` for tables and charts
+ update_schema_catalog_perms(session, database, None, default_catalog, True)
# delete models referencing non-default catalogs
- for model, column in models:
- for instance in session.query(model).filter(
- getattr(model, column) == database.id,
- model.catalog != default_catalog, # type: ignore
- ):
- session.delete(instance)
+ delete_models_non_default_catalog(session, database, default_catalog)
# delete datasets and any associated permissions
for table in session.query(SqlaTable).filter(