This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 4.0
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 804ecf48d24c65e09482f596f8f4ae5ac88f7eb1
Author: John Bodley <[email protected]>
AuthorDate: Mon May 13 11:55:59 2024 -0700

    fix: Update migration logic in #27119 (#28422)
---
 superset/migrations/shared/utils.py                | 33 +++++++++++----
 ...14-43_17fcea065655_change_text_to_mediumtext.py | 48 ++++++++++++++--------
 superset/models/sql_lab.py                         | 14 +++++--
 superset/utils/core.py                             |  6 ++-
 .../utils/pandas_postprocessing/contribution.py    | 16 ++++----
 5 files changed, 80 insertions(+), 37 deletions(-)

diff --git a/superset/migrations/shared/utils.py b/superset/migrations/shared/utils.py
index 2ae0dfeac1..d6a664f330 100644
--- a/superset/migrations/shared/utils.py
+++ b/superset/migrations/shared/utils.py
@@ -35,21 +35,40 @@ logger = logging.getLogger(__name__)
 DEFAULT_BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 1000))
 
 
-def table_has_column(table: str, column: str) -> bool:
+def get_table_column(
+    table_name: str,
+    column_name: str,
+) -> Optional[list[dict[str, Any]]]:
     """
-    Checks if a column exists in a given table.
+    Get the specified column.
 
-    :param table: A table name
-    :param column: A column name
-    :returns: True iff the column exists in the table
+    :param table_name: The Table name
+    :param column_name: The column name
+    :returns: The column
     """
 
     insp = inspect(op.get_context().bind)
 
     try:
-        return any(col["name"] == column for col in insp.get_columns(table))
+        for column in insp.get_columns(table_name):
+            if column["name"] == column_name:
+                return column
     except NoSuchTableError:
-        return False
+        pass
+
+    return None
+
+
+def table_has_column(table_name: str, column_name: str) -> bool:
+    """
+    Checks if a column exists in a given table.
+
+    :param table_name: A table name
+    :param column_name: A column name
+    :returns: True iff the column exists in the table
+    """
+
+    return bool(get_table_column(table_name, column_name))
 
 
 uuid_by_dialect = {
diff --git a/superset/migrations/versions/2024-02-14_14-43_17fcea065655_change_text_to_mediumtext.py b/superset/migrations/versions/2024-02-14_14-43_17fcea065655_change_text_to_mediumtext.py
index e63ab6ac56..1f4474eeed 100644
--- a/superset/migrations/versions/2024-02-14_14-43_17fcea065655_change_text_to_mediumtext.py
+++ b/superset/migrations/versions/2024-02-14_14-43_17fcea065655_change_text_to_mediumtext.py
@@ -28,8 +28,10 @@ down_revision = "87d38ad83218"
 
 import sqlalchemy as sa
 from alembic import op
+from sqlalchemy.dialects.mysql import MEDIUMTEXT, TEXT
 from sqlalchemy.dialects.mysql.base import MySQLDialect
 
+from superset.migrations.shared.utils import get_table_column
 from superset.utils.core import MediumText
 
 TABLE_COLUMNS = [
@@ -38,8 +40,6 @@ TABLE_COLUMNS = [
     "dashboards.css",
     "keyvalue.value",
     "query.extra_json",
-    "query.executed_sql",
-    "query.select_sql",
     "report_execution_log.value_row_json",
     "report_recipient.recipient_config_json",
     "report_schedule.sql",
@@ -65,23 +65,35 @@ NOT_NULL_COLUMNS = ["keyvalue.value", "row_level_security_filters.clause"]
 
 def upgrade():
     if isinstance(op.get_bind().dialect, MySQLDialect):
-        for column in TABLE_COLUMNS:
-            with op.batch_alter_table(column.split(".")[0]) as batch_op:
-                batch_op.alter_column(
-                    column.split(".")[1],
-                    existing_type=sa.Text(),
-                    type_=MediumText(),
-                    existing_nullable=column not in NOT_NULL_COLUMNS,
-                )
+        for item in TABLE_COLUMNS:
+            table_name, column_name = item.split(".")
+
+            if (column := get_table_column(table_name, column_name)) and isinstance(
+                column["type"],
+                TEXT,
+            ):
+                with op.batch_alter_table(table_name) as batch_op:
+                    batch_op.alter_column(
+                        column_name,
+                        existing_type=sa.Text(),
+                        type_=MediumText(),
+                        existing_nullable=item not in NOT_NULL_COLUMNS,
+                    )
 
 
 def downgrade():
     if isinstance(op.get_bind().dialect, MySQLDialect):
-        for column in TABLE_COLUMNS:
-            with op.batch_alter_table(column.split(".")[0]) as batch_op:
-                batch_op.alter_column(
-                    column.split(".")[1],
-                    existing_type=MediumText(),
-                    type_=sa.Text(),
-                    existing_nullable=column not in NOT_NULL_COLUMNS,
-                )
+        for item in TABLE_COLUMNS:
+            table_name, column_name = item.split(".")
+
+            if (column := get_table_column(table_name, column_name)) and isinstance(
+                column["type"],
+                MEDIUMTEXT,
+            ):
+                with op.batch_alter_table(table_name) as batch_op:
+                    batch_op.alter_column(
+                        column_name,
+                        existing_type=MediumText(),
+                        type_=sa.Text(),
+                        existing_nullable=item not in NOT_NULL_COLUMNS,
+                    )
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 25c21cdfc8..78c29bd2bb 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -57,7 +57,13 @@ from superset.models.helpers import (
 )
 from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table
 from superset.sqllab.limiting_factor import LimitingFactor
-from superset.utils.core import get_column_name, MediumText, QueryStatus, user_label
+from superset.utils.core import (
+    get_column_name,
+    LongText,
+    MediumText,
+    QueryStatus,
+    user_label,
+)
 
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import TableColumn
@@ -107,11 +113,11 @@ class Query(
     tab_name = Column(String(256))
     sql_editor_id = Column(String(256))
     schema = Column(String(256))
-    sql = Column(MediumText())
+    sql = Column(LongText())
     # Query to retrieve the results,
     # used only in case of select_as_cta_used is true.
-    select_sql = Column(MediumText())
-    executed_sql = Column(MediumText())
+    select_sql = Column(LongText())
+    executed_sql = Column(LongText())
     # Could be configured in the superset config.
     limit = Column(Integer)
     limiting_factor = Column(
diff --git a/superset/utils/core.py b/superset/utils/core.py
index f605ef99c1..6649f34717 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -68,7 +68,7 @@ from flask_babel.speaklater import LazyString
 from pandas.api.types import infer_dtype
 from pandas.core.dtypes.common import is_numeric_dtype
 from sqlalchemy import event, exc, inspect, select, Text
-from sqlalchemy.dialects.mysql import MEDIUMTEXT
+from sqlalchemy.dialects.mysql import LONGTEXT, MEDIUMTEXT
 from sqlalchemy.engine import Connection, Engine
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.sql.type_api import Variant
@@ -1469,6 +1469,10 @@ def MediumText() -> Variant:  # pylint:disable=invalid-name
     return Text().with_variant(MEDIUMTEXT(), "mysql")
 
 
+def LongText() -> Variant:  # pylint:disable=invalid-name
+    return Text().with_variant(LONGTEXT(), "mysql")
+
+
 def shortid() -> str:
     return f"{uuid.uuid4()}"[-12:]
 
diff --git a/superset/utils/pandas_postprocessing/contribution.py b/superset/utils/pandas_postprocessing/contribution.py
index 46144ec019..e0deb08322 100644
--- a/superset/utils/pandas_postprocessing/contribution.py
+++ b/superset/utils/pandas_postprocessing/contribution.py
@@ -74,7 +74,8 @@ def contribution(
     if len(rename_columns) != len(actual_columns):
         raise InvalidPostProcessingError(
             _(
-                "`rename_columns` must have the same length as `columns` + `time_shift_columns`."
+                "`rename_columns` must have the same length as "
+                + "`columns` + `time_shift_columns`."
             )
         )
     # limit to selected columns
@@ -105,10 +106,10 @@ def get_column_groups(
     :param df: DataFrame to group columns from
     :param time_shifts: List of time shifts to group by
     :param rename_columns: List of new column names
-    :return: Dictionary with two keys: 'non_time_shift' and 'time_shifts'. 'non_time_shift'
-    maps to a tuple of original and renamed columns without a time shift. 'time_shifts' maps
-    to a dictionary where each key is a time shift and each value is a tuple of original and
-    renamed columns with that time shift.
+    :return: Dictionary with two keys: 'non_time_shift' and 'time_shifts'.
+    'non_time_shift' maps to a tuple of original and renamed columns without a time shift.
+    'time_shifts' maps to a dictionary where each key is a time shift and each value is a
+    tuple of original and renamed columns with that time shift.
     """
     result: dict[str, Any] = {
         "non_time_shift": ([], []),  # take the form of ([A, B, C], [X, Y, Z])
@@ -139,8 +140,9 @@ def calculate_row_contribution(
     """
     Calculate the contribution of each column to the row total and update the DataFrame.
 
-    This function calculates the contribution of each selected column to the total of the row,
-    and updates the DataFrame with these contribution percentages in place of the original values.
+    This function calculates the contribution of each selected column to the total of
+    the row, and updates the DataFrame with these contribution percentages in place of
+    the original values.
 
     :param df: The DataFrame to calculate contributions for.
     :param columns: A list of column names to calculate contributions for.

Reply via email to