This is an automated email from the ASF dual-hosted git repository.

yjc pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new e2d3ea831a fix(db): use paginated_update for area chart migration (#20761)
e2d3ea831a is described below

commit e2d3ea831a7c634aeb2364a469a142c3514e4cf3
Author: Jesse Yang <[email protected]>
AuthorDate: Tue Jul 19 07:20:46 2022 -0700

    fix(db): use paginated_update for area chart migration (#20761)
---
 superset/migrations/shared/migrate_viz/__init__.py |  17 +++
 .../shared/migrate_viz/base.py}                    | 133 ++++++++++-----------
 .../migrations/shared/migrate_viz/processors.py    |  55 +++++++++
 superset/migrations/shared/utils.py                |  11 +-
 ..._13-00_c747c78868b6_migrating_legacy_treemap.py |  72 +----------
 ...-07_14-00_06e1e70058c7_migrating_legacy_area.py |  71 +----------
 .../06e1e70058c7_migrate_legacy_area__tests.py}    |  18 +--
 ...747c78868b6_migrating_legacy_treemap__tests.py} |  14 +--
 8 files changed, 168 insertions(+), 223 deletions(-)

diff --git a/superset/migrations/shared/migrate_viz/__init__.py b/superset/migrations/shared/migrate_viz/__init__.py
new file mode 100644
index 0000000000..aaa860e733
--- /dev/null
+++ b/superset/migrations/shared/migrate_viz/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from .processors import *
diff --git a/superset/utils/migrate_viz.py b/superset/migrations/shared/migrate_viz/base.py
similarity index 50%
rename from superset/utils/migrate_viz.py
rename to superset/migrations/shared/migrate_viz/base.py
index 6e59f1257f..024a58463e 100644
--- a/superset/utils/migrate_viz.py
+++ b/superset/migrations/shared/migrate_viz/base.py
@@ -17,21 +17,39 @@
 from __future__ import annotations
 
 import json
-from enum import Enum
-from typing import Dict, Set, Type, TYPE_CHECKING
+from typing import Dict, Set
 
-if TYPE_CHECKING:
-    from superset.models.slice import Slice
+from alembic import op
+from sqlalchemy import and_, Column, Integer, String, Text
+from sqlalchemy.ext.declarative import declarative_base
+
+from superset import db
+from superset.migrations.shared.utils import paginated_update, try_load_json
+
+Base = declarative_base()
+
+
+class Slice(Base):  # type: ignore
+    __tablename__ = "slices"
+
+    id = Column(Integer, primary_key=True)
+    slice_name = Column(String(250))
+    viz_type = Column(String(250))
+    params = Column(Text)
+    query_context = Column(Text)
+
+
+FORM_DATA_BAK_FIELD_NAME = "form_data_bak"
 
 
 class MigrateViz:
     remove_keys: Set[str] = set()
-    mapping_keys: Dict[str, str] = {}
+    rename_keys: Dict[str, str] = {}
     source_viz_type: str
     target_viz_type: str
 
     def __init__(self, form_data: str) -> None:
-        self.data = json.loads(form_data)
+        self.data = try_load_json(form_data)
 
     def _pre_action(self) -> None:
         """some actions before migrate"""
@@ -45,11 +63,11 @@ class MigrateViz:
 
         rv_data = {}
         for key, value in self.data.items():
-            if key in self.mapping_keys and self.mapping_keys[key] in rv_data:
+            if key in self.rename_keys and self.rename_keys[key] in rv_data:
                 raise ValueError("Duplicate key in target viz")
 
-            if key in self.mapping_keys:
-                rv_data[self.mapping_keys[key]] = value
+            if key in self.rename_keys:
+                rv_data[self.rename_keys[key]] = value
 
             if key in self.remove_keys:
                 continue
@@ -62,7 +80,7 @@ class MigrateViz:
         """some actions after migrate"""
 
     @classmethod
-    def upgrade(cls, slc: Slice) -> Slice:
+    def upgrade_slice(cls, slc: Slice) -> Slice:
         clz = cls(slc.params)
         slc.viz_type = cls.target_viz_type
         form_data_bak = clz.data.copy()
@@ -72,77 +90,56 @@ class MigrateViz:
         clz._post_action()
 
         # only backup params
-        slc.params = json.dumps({**clz.data, "form_data_bak": form_data_bak})
+        slc.params = json.dumps({**clz.data, FORM_DATA_BAK_FIELD_NAME: form_data_bak})
 
-        query_context = json.loads(slc.query_context or "{}")
+        query_context = try_load_json(slc.query_context)
         if "form_data" in query_context:
             query_context["form_data"] = clz.data
             slc.query_context = json.dumps(query_context)
         return slc
 
     @classmethod
-    def downgrade(cls, slc: Slice) -> Slice:
-        form_data = json.loads(slc.params)
-        if "form_data_bak" in form_data and "viz_type" in form_data.get(
-            "form_data_bak"
-        ):
-            form_data_bak = form_data["form_data_bak"]
+    def downgrade_slice(cls, slc: Slice) -> Slice:
+        form_data = try_load_json(slc.params)
+        form_data_bak = form_data.get(FORM_DATA_BAK_FIELD_NAME, {})
+        if "viz_type" in form_data_bak:
             slc.params = json.dumps(form_data_bak)
             slc.viz_type = form_data_bak.get("viz_type")
-
-            query_context = json.loads(slc.query_context or "{}")
+            query_context = try_load_json(slc.query_context)
             if "form_data" in query_context:
                 query_context["form_data"] = form_data_bak
                 slc.query_context = json.dumps(query_context)
         return slc
 
-
-class MigrateTreeMap(MigrateViz):
-    source_viz_type = "treemap"
-    target_viz_type = "treemap_v2"
-    remove_keys = {"metrics"}
-
-    def _pre_action(self) -> None:
-        if (
-            "metrics" in self.data
-            and isinstance(self.data["metrics"], list)
-            and len(self.data["metrics"]) > 0
+    @classmethod
+    def upgrade(cls) -> None:
+        bind = op.get_bind()
+        session = db.Session(bind=bind)
+        slices = session.query(Slice).filter(Slice.viz_type == cls.source_viz_type)
+        for slc in paginated_update(
+            slices,
+            lambda current, total: print(
+                f"  Updating {current}/{total} charts", end="\r"
+            ),
         ):
-            self.data["metric"] = self.data["metrics"][0]
-
+            new_viz = cls.upgrade_slice(slc)
+            session.merge(new_viz)
 
-class MigrateArea(MigrateViz):
-    source_viz_type = "area"
-    target_viz_type = "echarts_area"
-    remove_keys = {"contribution", "stacked_style", "x_axis_label"}
-
-    def _pre_action(self) -> None:
-        if self.data.get("contribution"):
-            self.data["contributionMode"] = "row"
-
-        stacked = self.data.get("stacked_style")
-        if stacked:
-            stacked_map = {
-                "expand": "Expand",
-                "stack": "Stack",
-            }
-            self.data["show_extra_controls"] = True
-            self.data["stack"] = stacked_map.get(stacked)
-
-        x_axis_label = self.data.get("x_axis_label")
-        if x_axis_label:
-            self.data["x_axis_title"] = x_axis_label
-            self.data["x_axis_title_margin"] = 30
-
-
-# pylint: disable=invalid-name
-class MigrateVizEnum(str, Enum):
-    # the Enum member name is viz_type in database
-    treemap = "treemap"
-    area = "area"
-
-
-get_migrate_class: Dict[MigrateVizEnum, Type[MigrateViz]] = {
-    MigrateVizEnum.treemap: MigrateTreeMap,
-    MigrateVizEnum.area: MigrateArea,
-}
+    @classmethod
+    def downgrade(cls) -> None:
+        bind = op.get_bind()
+        session = db.Session(bind=bind)
+        slices = session.query(Slice).filter(
+            and_(
+                Slice.viz_type == cls.target_viz_type,
+                Slice.params.like(f"%{FORM_DATA_BAK_FIELD_NAME}%"),
+            )
+        )
+        for slc in paginated_update(
+            slices,
+            lambda current, total: print(
+                f"  Downgrading {current}/{total} charts", end="\r"
+            ),
+        ):
+            new_viz = cls.downgrade_slice(slc)
+            session.merge(new_viz)
diff --git a/superset/migrations/shared/migrate_viz/processors.py b/superset/migrations/shared/migrate_viz/processors.py
new file mode 100644
index 0000000000..3584856beb
--- /dev/null
+++ b/superset/migrations/shared/migrate_viz/processors.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from .base import MigrateViz
+
+
+class MigrateTreeMap(MigrateViz):
+    source_viz_type = "treemap"
+    target_viz_type = "treemap_v2"
+    remove_keys = {"metrics"}
+
+    def _pre_action(self) -> None:
+        if (
+            "metrics" in self.data
+            and isinstance(self.data["metrics"], list)
+            and len(self.data["metrics"]) > 0
+        ):
+            self.data["metric"] = self.data["metrics"][0]
+
+
+class MigrateAreaChart(MigrateViz):
+    source_viz_type = "area"
+    target_viz_type = "echarts_area"
+    remove_keys = {"contribution", "stacked_style", "x_axis_label"}
+
+    def _pre_action(self) -> None:
+        if self.data.get("contribution"):
+            self.data["contributionMode"] = "row"
+
+        stacked = self.data.get("stacked_style")
+        if stacked:
+            stacked_map = {
+                "expand": "Expand",
+                "stack": "Stack",
+            }
+            self.data["show_extra_controls"] = True
+            self.data["stack"] = stacked_map.get(stacked)
+
+        x_axis_label = self.data.get("x_axis_label")
+        if x_axis_label:
+            self.data["x_axis_title"] = x_axis_label
+            self.data["x_axis_title_margin"] = 30
diff --git a/superset/migrations/shared/utils.py b/superset/migrations/shared/utils.py
index 614590409b..14987ea0b4 100644
--- a/superset/migrations/shared/utils.py
+++ b/superset/migrations/shared/utils.py
@@ -14,10 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import json
 import logging
 import os
 import time
-from typing import Any, Callable, Iterator, Optional, Union
+from typing import Any, Callable, Dict, Iterator, Optional, Union
 from uuid import uuid4
 
 from alembic import op
@@ -115,3 +116,11 @@ def paginated_update(
         if print_page_progress:
             print_page_progress(end, count)
         start += batch_size
+
+
+def try_load_json(data: Optional[str]) -> Dict[str, Any]:
+    try:
+        return data and json.loads(data) or {}
+    except json.decoder.JSONDecodeError:
+        print(f"Failed to parse: {data}")
+        return {}
diff --git a/superset/migrations/versions/2022-07-07_13-00_c747c78868b6_migrating_legacy_treemap.py b/superset/migrations/versions/2022-07-07_13-00_c747c78868b6_migrating_legacy_treemap.py
index 5f93e7cd75..99a7250e9e 100644
--- a/superset/migrations/versions/2022-07-07_13-00_c747c78868b6_migrating_legacy_treemap.py
+++ b/superset/migrations/versions/2022-07-07_13-00_c747c78868b6_migrating_legacy_treemap.py
@@ -21,82 +21,16 @@ Revises: cdcf3d64daf4
 Create Date: 2022-06-30 22:04:17.686635
 
 """
+from superset.migrations.shared.migrate_viz import MigrateTreeMap
 
 # revision identifiers, used by Alembic.
-
 revision = "c747c78868b6"
 down_revision = "cdcf3d64daf4"
 
-from alembic import op
-from sqlalchemy import and_, Column, Integer, String, Text
-from sqlalchemy.ext.declarative import declarative_base
-
-from superset import db
-from superset.utils.migrate_viz import get_migrate_class, MigrateVizEnum
-
-treemap_processor = get_migrate_class[MigrateVizEnum.treemap]
-
-Base = declarative_base()
-
-
-class Slice(Base):
-    __tablename__ = "slices"
-
-    id = Column(Integer, primary_key=True)
-    slice_name = Column(String(250))
-    viz_type = Column(String(250))
-    params = Column(Text)
-    query_context = Column(Text)
-
 
 def upgrade():
-    bind = op.get_bind()
-    session = db.Session(bind=bind)
-
-    slices = session.query(Slice).filter(
-        Slice.viz_type == treemap_processor.source_viz_type
-    )
-    total = slices.count()
-    idx = 0
-    for slc in slices.yield_per(1000):
-        try:
-            idx += 1
-            print(f"Upgrading ({idx}/{total}): {slc.slice_name}#{slc.id}")
-            new_viz = treemap_processor.upgrade(slc)
-            session.merge(new_viz)
-        except Exception as exc:
-            print(
-                "Error while processing migration: '{}'\nError: {}\n".format(
-                    slc.slice_name, str(exc)
-                )
-            )
-    session.commit()
-    session.close()
+    MigrateTreeMap.upgrade()
 
 
 def downgrade():
-    bind = op.get_bind()
-    session = db.Session(bind=bind)
-
-    slices = session.query(Slice).filter(
-        and_(
-            Slice.viz_type == treemap_processor.target_viz_type,
-            Slice.params.like("%form_data_bak%"),
-        )
-    )
-    total = slices.count()
-    idx = 0
-    for slc in slices.yield_per(1000):
-        try:
-            idx += 1
-            print(f"Downgrading ({idx}/{total}): {slc.slice_name}#{slc.id}")
-            new_viz = treemap_processor.downgrade(slc)
-            session.merge(new_viz)
-        except Exception as exc:
-            print(
-                "Error while processing migration: '{}'\nError: {}\n".format(
-                    slc.slice_name, str(exc)
-                )
-            )
-    session.commit()
-    session.close()
+    MigrateTreeMap.downgrade()
diff --git a/superset/migrations/versions/2022-07-07_14-00_06e1e70058c7_migrating_legacy_area.py b/superset/migrations/versions/2022-07-07_14-00_06e1e70058c7_migrating_legacy_area.py
index 3def02268d..de40780991 100644
--- a/superset/migrations/versions/2022-07-07_14-00_06e1e70058c7_migrating_legacy_area.py
+++ b/superset/migrations/versions/2022-07-07_14-00_06e1e70058c7_migrating_legacy_area.py
@@ -21,81 +21,16 @@ Revises: c747c78868b6
 Create Date: 2022-06-13 14:17:51.872706
 
 """
+from superset.migrations.shared.migrate_viz import MigrateAreaChart
 
 # revision identifiers, used by Alembic.
 revision = "06e1e70058c7"
 down_revision = "c747c78868b6"
 
-from alembic import op
-from sqlalchemy import and_, Column, Integer, String, Text
-from sqlalchemy.ext.declarative import declarative_base
-
-from superset import db
-from superset.utils.migrate_viz import get_migrate_class, MigrateVizEnum
-
-area_processor = get_migrate_class[MigrateVizEnum.area]
-
-Base = declarative_base()
-
-
-class Slice(Base):
-    __tablename__ = "slices"
-
-    id = Column(Integer, primary_key=True)
-    slice_name = Column(String(250))
-    viz_type = Column(String(250))
-    params = Column(Text)
-    query_context = Column(Text)
-
 
 def upgrade():
-    bind = op.get_bind()
-    session = db.Session(bind=bind)
-
-    slices = session.query(Slice).filter(
-        Slice.viz_type == area_processor.source_viz_type
-    )
-    total = slices.count()
-    idx = 0
-    for slc in slices.yield_per(1000):
-        try:
-            idx += 1
-            print(f"Upgrading ({idx}/{total}): {slc.slice_name}#{slc.id}")
-            new_viz = area_processor.upgrade(slc)
-            session.merge(new_viz)
-        except Exception as exc:
-            print(
-                "Error while processing migration: '{}'\nError: {}\n".format(
-                    slc.slice_name, str(exc)
-                )
-            )
-    session.commit()
-    session.close()
+    MigrateAreaChart.upgrade()
 
 
 def downgrade():
-    bind = op.get_bind()
-    session = db.Session(bind=bind)
-
-    slices = session.query(Slice).filter(
-        and_(
-            Slice.viz_type == area_processor.target_viz_type,
-            Slice.params.like("%form_data_bak%"),
-        )
-    )
-    total = slices.count()
-    idx = 0
-    for slc in slices.yield_per(1000):
-        try:
-            idx += 1
-            print(f"Downgrading ({idx}/{total}): {slc.slice_name}#{slc.id}")
-            new_viz = area_processor.downgrade(slc)
-            session.merge(new_viz)
-        except Exception as exc:
-            print(
-                "Error while processing migration: '{}'\nError: {}\n".format(
-                    slc.slice_name, str(exc)
-                )
-            )
-    session.commit()
-    session.close()
+    MigrateAreaChart.downgrade()
diff --git a/tests/unit_tests/utils/viz_migration/area_migration_test.py b/tests/integration_tests/migrations/06e1e70058c7_migrate_legacy_area__tests.py
similarity index 86%
rename from tests/unit_tests/utils/viz_migration/area_migration_test.py
rename to tests/integration_tests/migrations/06e1e70058c7_migrate_legacy_area__tests.py
index 8857a96c94..f02d069b2b 100644
--- a/tests/unit_tests/utils/viz_migration/area_migration_test.py
+++ b/tests/integration_tests/migrations/06e1e70058c7_migrate_legacy_area__tests.py
@@ -17,7 +17,7 @@
 import json
 
 from superset.app import SupersetApp
-from superset.utils.migrate_viz import get_migrate_class, MigrateVizEnum
+from superset.migrations.shared.migrate_viz import MigrateAreaChart
 
 area_form_data = """{
   "adhoc_filters": [],
@@ -60,21 +60,19 @@ area_form_data = """{
 }
 """
 
-area_processor = get_migrate_class[MigrateVizEnum.area]
-
 
 def test_area_migrate(app_context: SupersetApp) -> None:
     from superset.models.slice import Slice
 
     slc = Slice(
-        viz_type="area",
+        viz_type=MigrateAreaChart.source_viz_type,
         datasource_type="table",
         params=area_form_data,
         query_context=f'{{"form_data": {area_form_data}}}',
     )
 
-    slc = area_processor.upgrade(slc)
-    assert slc.viz_type == area_processor.target_viz_type
+    slc = MigrateAreaChart.upgrade_slice(slc)
+    assert slc.viz_type == MigrateAreaChart.target_viz_type
     # verify form_data
     new_form_data = json.loads(slc.params)
     assert new_form_data["contributionMode"] == "row"
@@ -89,11 +87,13 @@ def test_area_migrate(app_context: SupersetApp) -> None:
 
     # verify query_context
     new_query_context = json.loads(slc.query_context)
-    assert new_query_context["form_data"]["viz_type"] == area_processor.target_viz_type
+    assert (
+        new_query_context["form_data"]["viz_type"] == MigrateAreaChart.target_viz_type
+    )
 
     # downgrade
-    slc = area_processor.downgrade(slc)
-    assert slc.viz_type == area_processor.source_viz_type
+    slc = MigrateAreaChart.downgrade_slice(slc)
+    assert slc.viz_type == MigrateAreaChart.source_viz_type
     assert json.dumps(json.loads(slc.params), sort_keys=True) == json.dumps(
         json.loads(area_form_data), sort_keys=True
     )
diff --git a/tests/unit_tests/utils/viz_migration/treemap_migration_test.py b/tests/integration_tests/migrations/c747c78868b6_migrating_legacy_treemap__tests.py
similarity index 87%
rename from tests/unit_tests/utils/viz_migration/treemap_migration_test.py
rename to tests/integration_tests/migrations/c747c78868b6_migrating_legacy_treemap__tests.py
index 4bec5dec83..3e9ef33092 100644
--- a/tests/unit_tests/utils/viz_migration/treemap_migration_test.py
+++ b/tests/integration_tests/migrations/c747c78868b6_migrating_legacy_treemap__tests.py
@@ -17,7 +17,7 @@
 import json
 
 from superset.app import SupersetApp
-from superset.utils.migrate_viz import get_migrate_class, MigrateVizEnum
+from superset.migrations.shared.migrate_viz import MigrateTreeMap
 
 treemap_form_data = """{
   "adhoc_filters": [
@@ -57,21 +57,19 @@ treemap_form_data = """{
 }
 """
 
-treemap_processor = get_migrate_class[MigrateVizEnum.treemap]
-
 
 def test_treemap_migrate(app_context: SupersetApp) -> None:
     from superset.models.slice import Slice
 
     slc = Slice(
-        viz_type="treemap",
+        viz_type=MigrateTreeMap.source_viz_type,
         datasource_type="table",
         params=treemap_form_data,
         query_context=f'{{"form_data": {treemap_form_data}}}',
     )
 
-    slc = treemap_processor.upgrade(slc)
-    assert slc.viz_type == treemap_processor.target_viz_type
+    slc = MigrateTreeMap.upgrade_slice(slc)
+    assert slc.viz_type == MigrateTreeMap.target_viz_type
     # verify form_data
     new_form_data = json.loads(slc.params)
     assert new_form_data["metric"] == "sum__num"
@@ -86,8 +84,8 @@ def test_treemap_migrate(app_context: SupersetApp) -> None:
     assert new_query_context["form_data"]["viz_type"] == "treemap_v2"
 
     # downgrade
-    slc = treemap_processor.downgrade(slc)
-    assert slc.viz_type == treemap_processor.source_viz_type
+    slc = MigrateTreeMap.downgrade_slice(slc)
+    assert slc.viz_type == MigrateTreeMap.source_viz_type
     assert json.dumps(json.loads(slc.params), sort_keys=True) == json.dumps(
         json.loads(treemap_form_data), sort_keys=True
     )

Reply via email to