This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new e8f5d7680f fix: upgrade_catalog_perms and downgrade_catalog_perms implementation (#29860)
e8f5d7680f is described below

commit e8f5d7680ff14342b2ed46cc0b8c3bd4463fa3c2
Author: Michael S. Molina <[email protected]>
AuthorDate: Fri Aug 16 08:39:36 2024 -0400

    fix: upgrade_catalog_perms and downgrade_catalog_perms implementation (#29860)
---
 superset/migrations/shared/catalogs.py | 360 ++++++++++++++++++++++++---------
 1 file changed, 269 insertions(+), 91 deletions(-)

diff --git a/superset/migrations/shared/catalogs.py b/superset/migrations/shared/catalogs.py
index b09c71739f..b75214291b 100644
--- a/superset/migrations/shared/catalogs.py
+++ b/superset/migrations/shared/catalogs.py
@@ -18,7 +18,8 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Type
+from datetime import datetime
+from typing import Any, Type, Union
 
 import sqlalchemy as sa
 from alembic import op
@@ -35,8 +36,7 @@ from superset.migrations.shared.security_converge import (
 )
 from superset.models.core import Database
 
-logger = logging.getLogger(__name__)
-
+logger = logging.getLogger("alembic")  # presumably so progress shows in alembic's log output
 
 Base: Type[Any] = declarative_base()
 
@@ -95,6 +95,16 @@ class Slice(Base):
     schema_perm = sa.Column(sa.String(1000))
 
 
+ModelType = Union[Type[Query], Type[SavedQuery], Type[TabState], Type[TableSchema]]
+
+MODELS: list[tuple[ModelType, str]] = [  # (model, name of its database-id column)
+    (Query, "database_id"),
+    (SavedQuery, "db_id"),
+    (TabState, "database_id"),
+    (TableSchema, "database_id"),
+]
+
+
 def get_known_schemas(database_name: str, session: Session) -> list[str]:
     """
     Read all known schemas from the existing schema permissions.
@@ -112,6 +122,234 @@ def get_known_schemas(database_name: str, session: Session) -> list[str]:
     return sorted({name[0][1:-1].split("].[")[-1] for name in names})
 
 
+def get_batch_size(session: Session) -> int:
+    max_sqlite_in = 999  # presumably SQLite's default bound-parameter limit (pre-3.32)
+    return max_sqlite_in if session.bind.dialect.name == "sqlite" else 1_000_000
+
+
+def print_processed_batch(
+    start_time: datetime,
+    offset: int,
+    total_rows: int,
+    model: ModelType,
+    batch_size: int,
+) -> None:
+    """
+    Log (via `logger.info`) the progress of batch processing.
+
+    This function logs the progress of processing a batch of rows from a model.
+    It calculates the elapsed time since `start_time` and logs the number of rows
+    processed along with the percentage completion.
+
+    Parameters:
+        start_time (datetime): The start time of the batch processing.
+        offset (int): The number of rows processed before this batch.
+        total_rows (int): The total number of rows to process; must be non-zero.
+        model (ModelType): The model being processed.
+        batch_size (int): The size of the batch being processed.
+    """
+    elapsed_time = datetime.now() - start_time
+    elapsed_seconds = elapsed_time.total_seconds()  # rendered below as HH:MM:SS
+    elapsed_formatted = f"{int(elapsed_seconds // 3600):02}:{int((elapsed_seconds % 3600) // 60):02}:{int(elapsed_seconds % 60):02}"
+    rows_processed = min(offset + batch_size, total_rows)  # last batch may be partial
+    logger.info(
+        f"{elapsed_formatted} - {rows_processed:,} of {total_rows:,} {model.__tablename__} rows processed "
+        f"({(rows_processed / total_rows) * 100:.2f}%)"
+    )
+
+
+def update_catalog_column(
+    session: Session, database: Database, catalog: str, downgrade: bool = False
+) -> None:
+    """
+    Update the `catalog` column in the specified models to the given catalog.
+
+    This function iterates over a list of models defined by MODELS and updates
+    the `catalog` columnto the specified catalog or None depending on the 
downgrade
+    parameter. The update is performed in batches to optimize performance and 
reduce
+    memory usage.
+
+    Parameters:
+        session (Session): The SQLAlchemy session to use for database 
operations.
+        database (Database): The database instance containing the models to 
update.
+        catalog (Catalog): The new catalog value to set in the `catalog` 
column or
+            the default catalog if `downgrade` is True.
+        downgrade (bool): If True, the `catalog` column is set to None where 
the
+            catalog matches the specified catalog.
+    """
+    start_time = datetime.now()
+
+    logger.info(f"Updating {database.database_name} models to catalog 
{catalog}")
+
+    for model, column in MODELS:
+        # Get the total number of rows that match the condition
+        total_rows = (
+            session.query(sa.func.count(model.id))
+            .filter(getattr(model, column) == database.id)
+            .filter(model.catalog == catalog if downgrade else True)
+            .scalar()
+        )
+
+        logger.info(
+            f"Total rows to be processed for {model.__tablename__}: 
{total_rows:,}"
+        )
+
+        batch_size = get_batch_size(session)
+        limit_value = min(batch_size, total_rows)
+
+        # Update in batches using row numbers
+        for i in range(0, total_rows, batch_size):
+            subquery = (
+                session.query(model.id)
+                .filter(getattr(model, column) == database.id)
+                .filter(model.catalog == catalog if downgrade else True)
+                .order_by(model.id)
+                .offset(i)
+                .limit(limit_value)
+                .subquery()
+            )
+
+            # SQLite does not support multiple-table criteria within UPDATE
+            if session.bind.dialect.name == "sqlite":
+                ids_to_update = [row.id for row in 
session.query(subquery.c.id).all()]
+                if ids_to_update:
+                    session.execute(
+                        sa.update(model)
+                        .where(model.id.in_(ids_to_update))
+                        .values(catalog=None if downgrade else catalog)
+                        .execution_options(synchronize_session=False)
+                    )
+            else:
+                session.execute(
+                    sa.update(model)
+                    .where(model.id == subquery.c.id)
+                    .values(catalog=None if downgrade else catalog)
+                    .execution_options(synchronize_session=False)
+                )
+
+            print_processed_batch(start_time, i, total_rows, model, batch_size)
+
+
+def update_schema_catalog_perms(
+    session: Session,
+    database: Database,
+    catalog_perm: str | None,
+    catalog: str,
+    downgrade: bool = False,
+) -> None:
+    """
+    Update schema and catalog permissions for tables and charts in a given database.
+
+    This function updates `catalog`, `catalog_perm` and `schema_perm` on the tables of
+    the given database, and `catalog_perm` and `schema_perm` on their table charts.
+    Tables are selected by catalog: None when upgrading, `catalog` when downgrading.
+    Their `catalog` becomes `catalog` (None on downgrade); `catalog_perm` is set as given.
+
+    Args:
+        session (Session): The SQLAlchemy session to use for database operations.
+        database (Database): The database object whose tables and charts will be updated.
+        catalog_perm (str | None): The catalog permission to set on tables and charts.
+        catalog (str): The new catalog to set, or on downgrade the catalog to reset.
+        downgrade (bool, optional): If True, move tables in `catalog` back to None.
+                                    Defaults to False.
+    """
+    # Table id -> recomputed schema_perm for each table updated below (may be None)
+    mapping = {}
+
+    for table in (
+        session.query(SqlaTable)
+        .filter_by(database_id=database.id)
+        .filter_by(catalog=catalog if downgrade else None)
+    ):
+        schema_perm = security_manager.get_schema_perm(
+            database.database_name,
+            None if downgrade else catalog,
+            table.schema,
+        )
+        table.catalog = None if downgrade else catalog
+        table.catalog_perm = catalog_perm
+        table.schema_perm = schema_perm
+        mapping[table.id] = schema_perm
+
+    # Select all slices of type table that belong to the database
+    for chart in (
+        session.query(Slice)
+        .join(SqlaTable, Slice.datasource_id == SqlaTable.id)
+        .join(Database, SqlaTable.database_id == Database.id)
+        .filter(Database.id == database.id)
+        .filter(Slice.datasource_type == "table")
+    ):
+        # Skip charts whose table was not updated above or got a None schema_perm
+        if mapping.get(chart.datasource_id) is not None:
+            chart.catalog_perm = catalog_perm
+            chart.schema_perm = mapping[chart.datasource_id]
+
+
+def delete_models_non_default_catalog(
+    session: Session, database: Database, catalog: str
+) -> None:
+    """
+    Delete models that are not in the default catalog.
+
+    This function iterates over a list of models defined by MODELS and deletes
+    the rows where the `catalog` column does not match the specified catalog.
+
+    Parameters:
+        session (Session): The SQLAlchemy session to use for database 
operations.
+        database (Database): The database instance containing the models to 
delete.
+        catalog (Catalog): The catalog to use to filter the models to delete.
+    """
+    start_time = datetime.now()
+
+    logger.info(f"Deleting models not in the default catalog: {catalog}")
+
+    for model, column in MODELS:
+        # Get the total number of rows that match the condition
+        total_rows = (
+            session.query(sa.func.count(model.id))
+            .filter(getattr(model, column) == database.id)
+            .filter(model.catalog != catalog)
+            .scalar()
+        )
+
+        logger.info(
+            f"Total rows to be processed for {model.__tablename__}: 
{total_rows:,}"
+        )
+
+        batch_size = get_batch_size(session)
+        limit_value = min(batch_size, total_rows)
+
+        # Update in batches using row numbers
+        for i in range(0, total_rows, batch_size):
+            subquery = (
+                session.query(model.id)
+                .filter(getattr(model, column) == database.id)
+                .filter(model.catalog != catalog)
+                .order_by(model.id)
+                .offset(i)
+                .limit(limit_value)
+                .subquery()
+            )
+
+            # SQLite does not support multiple-table criteria within DELETE
+            if session.bind.dialect.name == "sqlite":
+                ids_to_delete = [row.id for row in 
session.query(subquery.c.id).all()]
+                if ids_to_delete:
+                    session.execute(
+                        sa.delete(model)
+                        .where(model.id.in_(ids_to_delete))
+                        .execution_options(synchronize_session=False)
+                    )
+            else:
+                session.execute(
+                    sa.delete(model)
+                    .where(model.id == subquery.c.id)
+                    .execution_options(synchronize_session=False)
+                )
+
+            print_processed_batch(start_time, i, total_rows, model, batch_size)
+
+
 def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
     """
     Update models and permissions when catalogs are introduced in a DB engine 
spec.
@@ -157,11 +395,13 @@ def upgrade_database_catalogs(
     """
     Upgrade a given database to support the default catalog.
     """
-    catalog_perm = security_manager.get_catalog_perm(
+    catalog_perm: str | None = security_manager.get_catalog_perm(
         database.database_name,
         default_catalog,
     )
-    pvms: dict[str, tuple[str, ...]] = {catalog_perm: ("catalog_access",)}
+    pvms: dict[str, tuple[str, ...]] = (
+        {catalog_perm: ("catalog_access",)} if catalog_perm else {}
+    )
 
     # rename existing schema permissions to include the catalog, and also find 
any new
     # schemas
@@ -170,39 +410,10 @@ def upgrade_database_catalogs(
 
     # update existing models that have a `catalog` column so it points to the 
default
     # catalog
-    models = [
-        (Query, "database_id"),
-        (SavedQuery, "db_id"),
-        (TabState, "database_id"),
-        (TableSchema, "database_id"),
-    ]
-    for model, column in models:
-        for instance in session.query(model).filter(
-            getattr(model, column) == database.id
-        ):
-            instance.catalog = default_catalog
+    update_catalog_column(session, database, default_catalog, False)
 
     # update `schema_perm` and `catalog_perm` for tables and charts
-    for table in session.query(SqlaTable).filter_by(
-        database_id=database.id,
-        catalog=None,
-    ):
-        schema_perm = security_manager.get_schema_perm(
-            database.database_name,
-            default_catalog,
-            table.schema,
-        )
-
-        table.catalog = default_catalog
-        table.catalog_perm = catalog_perm
-        table.schema_perm = schema_perm
-
-        for chart in session.query(Slice).filter_by(
-            datasource_id=table.id,
-            datasource_type="table",
-        ):
-            chart.catalog_perm = catalog_perm
-            chart.schema_perm = schema_perm
+    update_schema_catalog_perms(session, database, catalog_perm, 
default_catalog, False)
 
     # add any new catalogs discovered and their schemas
     new_catalog_pvms = add_non_default_catalogs(database, default_catalog, 
session)
@@ -233,13 +444,15 @@ def add_non_default_catalogs(
         # edited.
         return {}
 
-    pvms = {}
+    pvms: dict[str, tuple[str]] = {}
     for catalog in catalogs:
-        perm = security_manager.get_catalog_perm(database.database_name, 
catalog)
-        pvms[perm] = ("catalog_access",)
-
-        new_schema_pvms = create_schema_perms(database, catalog, session)
-        pvms.update(new_schema_pvms)
+        perm: str | None = security_manager.get_catalog_perm(
+            database.database_name, catalog
+        )
+        if perm:
+            pvms[perm] = ("catalog_access",)
+            new_schema_pvms = create_schema_perms(database, catalog)
+            pvms.update(new_schema_pvms)
 
     return pvms
 
@@ -266,12 +479,12 @@ def upgrade_schema_perms(
 
     perms = {}
     for schema in schemas:
-        current_perm = security_manager.get_schema_perm(
+        current_perm: str | None = security_manager.get_schema_perm(
             database.database_name,
             None,
             schema,
         )
-        new_perm = security_manager.get_schema_perm(
+        new_perm: str | None = security_manager.get_schema_perm(
             database.database_name,
             default_catalog,
             schema,
@@ -283,7 +496,7 @@ def upgrade_schema_perms(
             .one_or_none()
         ):
             existing_pvm.name = new_perm
-        else:
+        elif new_perm:
             # new schema discovered, need to create a new permission
             perms[new_perm] = ("schema_access",)
 
@@ -293,7 +506,6 @@ def upgrade_schema_perms(
 def create_schema_perms(
     database: Database,
     catalog: str,
-    session: Session,
 ) -> dict[str, tuple[str]]:
     """
     Create schema permissions for a given catalog.
@@ -307,12 +519,14 @@ def create_schema_perms(
         return {}
 
     return {
-        security_manager.get_schema_perm(
-            database.database_name,
-            catalog,
-            schema,
-        ): ("schema_access",)
+        perm: ("schema_access",)
         for schema in schemas
+        if (
+            perm := security_manager.get_schema_perm(
+                database.database_name, catalog, schema
+            )
+        )
+        is not None
     }
 
 
@@ -374,49 +588,13 @@ def downgrade_database_catalogs(
     # permissions associated with other catalogs
     downgrade_schema_perms(database, default_catalog, session)
 
-    # update existing models
-    models = [
-        (Query, "database_id"),
-        (SavedQuery, "db_id"),
-        (TabState, "database_id"),
-        (TableSchema, "database_id"),
-    ]
-    for model, column in models:
-        for instance in session.query(model).filter(
-            getattr(model, column) == database.id,
-            model.catalog == default_catalog,  # type: ignore
-        ):
-            instance.catalog = None
+    update_catalog_column(session, database, default_catalog, True)
 
-    # update `schema_perm` for tables and charts
-    for table in session.query(SqlaTable).filter_by(
-        database_id=database.id,
-        catalog=default_catalog,
-    ):
-        schema_perm = security_manager.get_schema_perm(
-            database.database_name,
-            None,
-            table.schema,
-        )
-
-        table.catalog = None
-        table.catalog_perm = None
-        table.schema_perm = schema_perm
-
-        for chart in session.query(Slice).filter_by(
-            datasource_id=table.id,
-            datasource_type="table",
-        ):
-            chart.catalog_perm = None
-            chart.schema_perm = schema_perm
+    # update `schema_perm` and `catalog_perm` for tables and charts
+    update_schema_catalog_perms(session, database, None, default_catalog, True)
 
     # delete models referencing non-default catalogs
-    for model, column in models:
-        for instance in session.query(model).filter(
-            getattr(model, column) == database.id,
-            model.catalog != default_catalog,  # type: ignore
-        ):
-            session.delete(instance)
+    delete_models_non_default_catalog(session, database, default_catalog)
 
     # delete datasets and any associated permissions
     for table in session.query(SqlaTable).filter(

Reply via email to