This is an automated email from the ASF dual-hosted git repository.

elizabeth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 5a137820d0 fix: catch some potential errors on dual write (#20351)
5a137820d0 is described below

commit 5a137820d0fd192fe8466e9448a59e327d13eeb5
Author: Elizabeth Thompson <[email protected]>
AuthorDate: Mon Jun 13 17:30:13 2022 -0700

    fix: catch some potential errors on dual write (#20351)
    
    * catch some potential errors on dual write
    
    * fix test for sqlite
---
 superset/connectors/sqla/models.py              | 42 +++++++++------
 superset/connectors/sqla/utils.py               | 11 +++-
 tests/integration_tests/datasets/api_tests.py   |  6 +++
 tests/integration_tests/datasets/model_tests.py | 69 +++++++++++++++++++++++++
 tests/integration_tests/fixtures/datasource.py  | 52 ++++++++++++++++++-
 5 files changed, 163 insertions(+), 17 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 60eff5e630..3b40474331 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -66,6 +66,7 @@ from sqlalchemy import (
 )
 from sqlalchemy.engine.base import Connection
 from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
+from sqlalchemy.orm.exc import NoResultFound
 from sqlalchemy.orm.mapper import Mapper
 from sqlalchemy.schema import UniqueConstraint
 from sqlalchemy.sql import column, ColumnElement, literal_column, table
@@ -933,7 +934,8 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         if sql_query_mutator:
             sql = sql_query_mutator(
                 sql,
-                user_name=get_username(),  # TODO(john-bodley): Deprecate in 3.0.
+                # TODO(john-bodley): Deprecate in 3.0.
+                user_name=get_username(),
                 security_manager=security_manager,
                 database=self.database,
             )
@@ -2115,7 +2117,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         ]
 
     @staticmethod
-    def update_table(  # pylint: disable=unused-argument
+    def update_column(  # pylint: disable=unused-argument
         mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn]
     ) -> None:
         """
@@ -2130,7 +2132,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         # table is updated. This busts the cache key for all charts that use the table.
         session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id))
 
-        # if table itself has changed, shadow-writing will happen in `after_udpate` anyway
+        # if table itself has changed, shadow-writing will happen in `after_update` anyway
         if target.table not in session.dirty:
             dataset: NewDataset = (
                 session.query(NewDataset)
@@ -2146,17 +2148,27 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
 
             # update changed_on timestamp
             session.execute(update(NewDataset).where(NewDataset.id == dataset.id))
-
-            # update `Column` model as well
-            session.add(
-                target.to_sl_column(
-                    {
-                        target.uuid: session.query(NewColumn)
-                        .filter_by(uuid=target.uuid)
-                        .one_or_none()
-                    }
+            try:
+                column = session.query(NewColumn).filter_by(uuid=target.uuid).one()
+                # update `Column` model as well
+                session.merge(target.to_sl_column({target.uuid: column}))
+            except NoResultFound:
+                logger.warning("No column was found for %s", target)
+                # see if the column is in cache
+                column = next(
+                    find_cached_objects_in_session(
+                        session, NewColumn, uuids=[target.uuid]
+                    ),
+                    None,
                 )
-            )
+
+                if not column:
+                    # to be safe, use a different uuid and create a new column
+                    uuid = uuid4()
+                    target.uuid = uuid
+                    column = NewColumn(uuid=uuid)
+
+                session.add(target.to_sl_column({column.uuid: column}))
 
     @staticmethod
     def after_insert(
@@ -2441,9 +2453,9 @@ sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
 sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
 sa.event.listen(SqlaTable, "after_delete", SqlaTable.after_delete)
 sa.event.listen(SqlaTable, "after_update", SqlaTable.after_update)
-sa.event.listen(SqlMetric, "after_update", SqlaTable.update_table)
+sa.event.listen(SqlMetric, "after_update", SqlaTable.update_column)
 sa.event.listen(SqlMetric, "after_delete", SqlMetric.after_delete)
-sa.event.listen(TableColumn, "after_update", SqlaTable.update_table)
+sa.event.listen(TableColumn, "after_update", SqlaTable.update_column)
 sa.event.listen(TableColumn, "after_delete", TableColumn.after_delete)
 
 RLSFilterRoles = Table(
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 1786c5bf17..69a983156e 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import logging
 from contextlib import closing
 from typing import (
     Any,
@@ -35,6 +36,7 @@ from sqlalchemy.engine.url import URL as SqlaURL
 from sqlalchemy.exc import NoSuchTableError
 from sqlalchemy.ext.declarative import DeclarativeMeta
 from sqlalchemy.orm import Session
+from sqlalchemy.orm.exc import ObjectDeletedError
 from sqlalchemy.sql.type_api import TypeEngine
 
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
@@ -191,6 +193,7 @@ def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]:
 
 
 DeclarativeModel = TypeVar("DeclarativeModel", bound=DeclarativeMeta)
+logger = logging.getLogger(__name__)
 
 
 def find_cached_objects_in_session(
@@ -209,9 +212,15 @@ def find_cached_objects_in_session(
     if not ids and not uuids:
         return iter([])
     uuids = uuids or []
+    try:
+        items = set(session)
+    except ObjectDeletedError:
+        logger.warning("ObjectDeletedError", exc_info=True)
+        return iter(())
+
     return (
         item
         # `session` is an iterator of all known items
-        for item in set(session)
+        for item in items
         if isinstance(item, cls) and (item.id in ids if ids else item.uuid in uuids)
     )
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index e378811eb9..b1767bddad 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -35,6 +35,7 @@ from superset.dao.exceptions import (
     DAODeleteFailedError,
     DAOUpdateFailedError,
 )
+from superset.datasets.models import Dataset
 from superset.extensions import db, security_manager
 from superset.models.core import Database
 from superset.utils.core import backend, get_example_default_schema
@@ -1636,16 +1637,21 @@ class TestDatasetApi(SupersetTestCase):
         database = (
             db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
         )
+        shadow_dataset = (
+            db.session.query(Dataset).filter_by(uuid=dataset_config["uuid"]).one()
+        )
         assert database.database_name == "imported_database"
 
         assert len(database.tables) == 1
         dataset = database.tables[0]
         assert dataset.table_name == "imported_dataset"
         assert str(dataset.uuid) == dataset_config["uuid"]
+        assert str(shadow_dataset.uuid) == dataset_config["uuid"]
 
         dataset.owners = []
         database.owners = []
         db.session.delete(dataset)
+        db.session.delete(shadow_dataset)
         db.session.delete(database)
         db.session.commit()
 
diff --git a/tests/integration_tests/datasets/model_tests.py b/tests/integration_tests/datasets/model_tests.py
new file mode 100644
index 0000000000..31abce5494
--- /dev/null
+++ b/tests/integration_tests/datasets/model_tests.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest import mock
+
+import pytest
+from sqlalchemy.orm.exc import NoResultFound
+
+from superset.connectors.sqla.models import SqlaTable, TableColumn
+from superset.extensions import db
+from tests.integration_tests.base_tests import SupersetTestCase
+from tests.integration_tests.fixtures.datasource import load_dataset_with_columns
+
+
+class SqlaTableModelTest(SupersetTestCase):
+    @pytest.mark.usefixtures("load_dataset_with_columns")
+    def test_dual_update_column(self) -> None:
+        """
+        Test that when updating a sqla ``TableColumn``
+        That the shadow ``Column`` is also updated
+        """
+        dataset = db.session.query(SqlaTable).filter_by(table_name="students").first()
+        column = dataset.columns[0]
+        column_name = column.column_name
+        column.column_name = "new_column_name"
+        SqlaTable.update_column(None, None, target=column)
+
+        # refetch
+        dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one()
+        assert dataset.columns[0].column_name == "new_column_name"
+
+        # reset
+        column.column_name = column_name
+        SqlaTable.update_column(None, None, target=column)
+
+    @pytest.mark.usefixtures("load_dataset_with_columns")
+    @mock.patch("superset.columns.models.Column")
+    def test_dual_update_column_not_found(self, column_mock) -> None:
+        """
+        Test that when updating a sqla ``TableColumn``
+        That the shadow ``Column`` is also updated
+        """
+        dataset = db.session.query(SqlaTable).filter_by(table_name="students").first()
+        column = dataset.columns[0]
+        column_uuid = column.uuid
+        with mock.patch("sqlalchemy.orm.query.Query.one", side_effect=NoResultFound):
+            SqlaTable.update_column(None, None, target=column)
+
+        # refetch
+        dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one()
+        # it should create a new uuid
+        assert dataset.columns[0].uuid != column_uuid
+
+        # reset
+        column.uuid = column_uuid
+        SqlaTable.update_column(None, None, target=column)
diff --git a/tests/integration_tests/fixtures/datasource.py b/tests/integration_tests/fixtures/datasource.py
index b6f2476f66..574f43d52b 100644
--- a/tests/integration_tests/fixtures/datasource.py
+++ b/tests/integration_tests/fixtures/datasource.py
@@ -15,10 +15,20 @@
 # specific language governing permissions and limitations
 # under the License.
 """Fixtures for test_datasource.py"""
-from typing import Any, Dict
+from typing import Any, Dict, Generator
 
+import pytest
+from sqlalchemy import Column, create_engine, Date, Integer, MetaData, String, Table
+from sqlalchemy.ext.declarative.api import declarative_base
+
+from superset.columns.models import Column as Sl_Column
+from superset.connectors.sqla.models import SqlaTable, TableColumn
+from superset.extensions import db
+from superset.models.core import Database
+from superset.tables.models import Table as Sl_Table
 from superset.utils.core import get_example_default_schema
 from superset.utils.database import get_example_database
+from tests.integration_tests.test_app import app
 
 
 def get_datasource_post() -> Dict[str, Any]:
@@ -159,3 +169,43 @@ def get_datasource_post() -> Dict[str, Any]:
             },
         ],
     }
+
+
+@pytest.fixture()
+def load_dataset_with_columns() -> Generator[SqlaTable, None, None]:
+    with app.app_context():
+        engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True)
+        meta = MetaData()
+        session = db.session
+
+        students = Table(
+            "students",
+            meta,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(255)),
+            Column("lastname", String(255)),
+            Column("ds", Date),
+        )
+        meta.create_all(engine)
+
+        students.insert().values(name="George", ds="2021-01-01")
+
+        dataset = SqlaTable(
+            database_id=db.session.query(Database).first().id, table_name="students"
+        )
+        column = TableColumn(table_id=dataset.id, column_name="name")
+        dataset.columns = [column]
+        session.add(dataset)
+        session.commit()
+        yield dataset
+
+        # cleanup
+        students_table = meta.tables.get("students")
+        if students_table is not None:
+            base = declarative_base()
+            # needed for sqlite
+            session.commit()
+            base.metadata.drop_all(engine, [students_table], checkfirst=True)
+        session.delete(dataset)
+        session.delete(column)
+        session.commit()

Reply via email to