This is an automated email from the ASF dual-hosted git repository.
elizabeth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 5a137820d0 fix: catch some potential errors on dual write (#20351)
5a137820d0 is described below
commit 5a137820d0fd192fe8466e9448a59e327d13eeb5
Author: Elizabeth Thompson <[email protected]>
AuthorDate: Mon Jun 13 17:30:13 2022 -0700
fix: catch some potential errors on dual write (#20351)
* catch some potential errors on dual write
* fix test for sqlite
---
superset/connectors/sqla/models.py | 42 +++++++++------
superset/connectors/sqla/utils.py | 11 +++-
tests/integration_tests/datasets/api_tests.py | 6 +++
tests/integration_tests/datasets/model_tests.py | 69 +++++++++++++++++++++++++
tests/integration_tests/fixtures/datasource.py | 52 ++++++++++++++++++-
5 files changed, 163 insertions(+), 17 deletions(-)
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 60eff5e630..3b40474331 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -66,6 +66,7 @@ from sqlalchemy import (
)
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty,
Session
+from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table
@@ -933,7 +934,8 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
if sql_query_mutator:
sql = sql_query_mutator(
sql,
- user_name=get_username(), # TODO(john-bodley): Deprecate in
3.0.
+ # TODO(john-bodley): Deprecate in 3.0.
+ user_name=get_username(),
security_manager=security_manager,
database=self.database,
)
@@ -2115,7 +2117,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
]
@staticmethod
- def update_table( # pylint: disable=unused-argument
+ def update_column( # pylint: disable=unused-argument
mapper: Mapper, connection: Connection, target: Union[SqlMetric,
TableColumn]
) -> None:
"""
@@ -2130,7 +2132,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
# table is updated. This busts the cache key for all charts that use
the table.
session.execute(update(SqlaTable).where(SqlaTable.id ==
target.table.id))
- # if table itself has changed, shadow-writing will happen in
`after_udpate` anyway
+ # if table itself has changed, shadow-writing will happen in
`after_update` anyway
if target.table not in session.dirty:
dataset: NewDataset = (
session.query(NewDataset)
@@ -2146,17 +2148,27 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
# update changed_on timestamp
session.execute(update(NewDataset).where(NewDataset.id ==
dataset.id))
-
- # update `Column` model as well
- session.add(
- target.to_sl_column(
- {
- target.uuid: session.query(NewColumn)
- .filter_by(uuid=target.uuid)
- .one_or_none()
- }
+ try:
+ column =
session.query(NewColumn).filter_by(uuid=target.uuid).one()
+ # update `Column` model as well
+ session.merge(target.to_sl_column({target.uuid: column}))
+ except NoResultFound:
+ logger.warning("No column was found for %s", target)
+ # see if the column is in cache
+ column = next(
+ find_cached_objects_in_session(
+ session, NewColumn, uuids=[target.uuid]
+ ),
+ None,
)
- )
+
+ if not column:
+ # to be safe, use a different uuid and create a new column
+ uuid = uuid4()
+ target.uuid = uuid
+ column = NewColumn(uuid=uuid)
+
+ session.add(target.to_sl_column({column.uuid: column}))
@staticmethod
def after_insert(
@@ -2441,9 +2453,9 @@ sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
sa.event.listen(SqlaTable, "after_delete", SqlaTable.after_delete)
sa.event.listen(SqlaTable, "after_update", SqlaTable.after_update)
-sa.event.listen(SqlMetric, "after_update", SqlaTable.update_table)
+sa.event.listen(SqlMetric, "after_update", SqlaTable.update_column)
sa.event.listen(SqlMetric, "after_delete", SqlMetric.after_delete)
-sa.event.listen(TableColumn, "after_update", SqlaTable.update_table)
+sa.event.listen(TableColumn, "after_update", SqlaTable.update_column)
sa.event.listen(TableColumn, "after_delete", TableColumn.after_delete)
RLSFilterRoles = Table(
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 1786c5bf17..69a983156e 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import logging
from contextlib import closing
from typing import (
Any,
@@ -35,6 +36,7 @@ from sqlalchemy.engine.url import URL as SqlaURL
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import Session
+from sqlalchemy.orm.exc import ObjectDeletedError
from sqlalchemy.sql.type_api import TypeEngine
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
@@ -191,6 +193,7 @@ def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]:
DeclarativeModel = TypeVar("DeclarativeModel", bound=DeclarativeMeta)
+logger = logging.getLogger(__name__)
def find_cached_objects_in_session(
@@ -209,9 +212,15 @@ def find_cached_objects_in_session(
if not ids and not uuids:
return iter([])
uuids = uuids or []
+ try:
+ items = set(session)
+ except ObjectDeletedError:
+ logger.warning("ObjectDeletedError", exc_info=True)
+ return iter(())
+
return (
item
# `session` is an iterator of all known items
- for item in set(session)
+ for item in items
if isinstance(item, cls) and (item.id in ids if ids else item.uuid in
uuids)
)
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index e378811eb9..b1767bddad 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -35,6 +35,7 @@ from superset.dao.exceptions import (
DAODeleteFailedError,
DAOUpdateFailedError,
)
+from superset.datasets.models import Dataset
from superset.extensions import db, security_manager
from superset.models.core import Database
from superset.utils.core import backend, get_example_default_schema
@@ -1636,16 +1637,21 @@ class TestDatasetApi(SupersetTestCase):
database = (
db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
)
+ shadow_dataset = (
+
db.session.query(Dataset).filter_by(uuid=dataset_config["uuid"]).one()
+ )
assert database.database_name == "imported_database"
assert len(database.tables) == 1
dataset = database.tables[0]
assert dataset.table_name == "imported_dataset"
assert str(dataset.uuid) == dataset_config["uuid"]
+ assert str(shadow_dataset.uuid) == dataset_config["uuid"]
dataset.owners = []
database.owners = []
db.session.delete(dataset)
+ db.session.delete(shadow_dataset)
db.session.delete(database)
db.session.commit()
diff --git a/tests/integration_tests/datasets/model_tests.py b/tests/integration_tests/datasets/model_tests.py
new file mode 100644
index 0000000000..31abce5494
--- /dev/null
+++ b/tests/integration_tests/datasets/model_tests.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest import mock
+
+import pytest
+from sqlalchemy.orm.exc import NoResultFound
+
+from superset.connectors.sqla.models import SqlaTable, TableColumn
+from superset.extensions import db
+from tests.integration_tests.base_tests import SupersetTestCase
+from tests.integration_tests.fixtures.datasource import
load_dataset_with_columns
+
+
+class SqlaTableModelTest(SupersetTestCase):
+ @pytest.mark.usefixtures("load_dataset_with_columns")
+ def test_dual_update_column(self) -> None:
+ """
+ Test that when updating a sqla ``TableColumn``
+ That the shadow ``Column`` is also updated
+ """
+ dataset =
db.session.query(SqlaTable).filter_by(table_name="students").first()
+ column = dataset.columns[0]
+ column_name = column.column_name
+ column.column_name = "new_column_name"
+ SqlaTable.update_column(None, None, target=column)
+
+ # refetch
+ dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one()
+ assert dataset.columns[0].column_name == "new_column_name"
+
+ # reset
+ column.column_name = column_name
+ SqlaTable.update_column(None, None, target=column)
+
+ @pytest.mark.usefixtures("load_dataset_with_columns")
+ @mock.patch("superset.columns.models.Column")
+ def test_dual_update_column_not_found(self, column_mock) -> None:
+ """
+ Test that when updating a sqla ``TableColumn``
+ That the shadow ``Column`` is also updated
+ """
+ dataset =
db.session.query(SqlaTable).filter_by(table_name="students").first()
+ column = dataset.columns[0]
+ column_uuid = column.uuid
+ with mock.patch("sqlalchemy.orm.query.Query.one",
side_effect=NoResultFound):
+ SqlaTable.update_column(None, None, target=column)
+
+ # refetch
+ dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one()
+ # it should create a new uuid
+ assert dataset.columns[0].uuid != column_uuid
+
+ # reset
+ column.uuid = column_uuid
+ SqlaTable.update_column(None, None, target=column)
diff --git a/tests/integration_tests/fixtures/datasource.py b/tests/integration_tests/fixtures/datasource.py
index b6f2476f66..574f43d52b 100644
--- a/tests/integration_tests/fixtures/datasource.py
+++ b/tests/integration_tests/fixtures/datasource.py
@@ -15,10 +15,20 @@
# specific language governing permissions and limitations
# under the License.
"""Fixtures for test_datasource.py"""
-from typing import Any, Dict
+from typing import Any, Dict, Generator
+import pytest
+from sqlalchemy import Column, create_engine, Date, Integer, MetaData, String,
Table
+from sqlalchemy.ext.declarative.api import declarative_base
+
+from superset.columns.models import Column as Sl_Column
+from superset.connectors.sqla.models import SqlaTable, TableColumn
+from superset.extensions import db
+from superset.models.core import Database
+from superset.tables.models import Table as Sl_Table
from superset.utils.core import get_example_default_schema
from superset.utils.database import get_example_database
+from tests.integration_tests.test_app import app
def get_datasource_post() -> Dict[str, Any]:
@@ -159,3 +169,43 @@ def get_datasource_post() -> Dict[str, Any]:
},
],
}
+
+
+@pytest.fixture()
+def load_dataset_with_columns() -> Generator[SqlaTable, None, None]:
+ with app.app_context():
+ engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"],
echo=True)
+ meta = MetaData()
+ session = db.session
+
+ students = Table(
+ "students",
+ meta,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(255)),
+ Column("lastname", String(255)),
+ Column("ds", Date),
+ )
+ meta.create_all(engine)
+
+ students.insert().values(name="George", ds="2021-01-01")
+
+ dataset = SqlaTable(
+ database_id=db.session.query(Database).first().id,
table_name="students"
+ )
+ column = TableColumn(table_id=dataset.id, column_name="name")
+ dataset.columns = [column]
+ session.add(dataset)
+ session.commit()
+ yield dataset
+
+ # cleanup
+ students_table = meta.tables.get("students")
+ if students_table is not None:
+ base = declarative_base()
+ # needed for sqlite
+ session.commit()
+ base.metadata.drop_all(engine, [students_table], checkfirst=True)
+ session.delete(dataset)
+ session.delete(column)
+ session.commit()