This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 88029e2 fix dataset update table (#19269)
88029e2 is described below
commit 88029e21b6068f845d806cfc10d478a5d972ffa5
Author: Elizabeth Thompson <[email protected]>
AuthorDate: Mon Mar 21 09:43:51 2022 -0700
fix dataset update table (#19269)
---
superset/connectors/sqla/models.py | 272 +++++++++++++++++--------------
tests/unit_tests/datasets/test_models.py | 81 ++++++++-
2 files changed, 225 insertions(+), 128 deletions(-)
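In short: ``SqlaTable.before_update`` previously fetched the shadow ``NewDataset`` with ``.one()``, which raises ``NoResultFound`` whenever no shadow row was ever written; the fix switches to ``.one_or_none()`` and falls back to creating the shadow dataset via the new ``write_shadow_dataset`` helper. A minimal sketch of the pattern, condensed from the diff below:

    dataset = (
        session.query(NewDataset)
        .filter_by(sqlatable_id=target.table.id)
        .one_or_none()
    )
    if not dataset:
        # no shadow dataset yet -- create one instead of updating
        SqlaTable.write_shadow_dataset(target.table, database, session)
        return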
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 62ae8c9..bbd1b5d 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1863,11 +1863,20 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id))
- # update ``Column`` model as well
dataset = (
- session.query(NewDataset).filter_by(sqlatable_id=target.table.id).one()
+ session.query(NewDataset)
+ .filter_by(sqlatable_id=target.table.id)
+ .one_or_none()
)
+ if not dataset:
+ # if the dataset is not found, create a new copy
+ # of the dataset instead of updating the existing one
+
+ SqlaTable.write_shadow_dataset(target.table, database, session)
+ return
+
+ # update ``Column`` model as well
if isinstance(target, TableColumn):
columns = [
column
@@ -1923,7 +1932,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
column.extra_json = json.dumps(extra_json) if extra_json else None
@staticmethod
- def after_insert( # pylint: disable=too-many-locals
+ def after_insert(
mapper: Mapper, connection: Connection, target: "SqlaTable",
) -> None:
"""
@@ -1938,135 +1947,18 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
For more context: https://github.com/apache/superset/issues/14909
"""
+ session = inspect(target).session
# set permissions
security_manager.set_perm(mapper, connection, target)
- session = inspect(target).session
-
# get DB-specific conditional quoter for expressions that point to columns or
# table names
database = (
target.database
or session.query(Database).filter_by(id=target.database_id).one()
)
- engine = database.get_sqla_engine(schema=target.schema)
- conditional_quote = engine.dialect.identifier_preparer.quote
-
- # create columns
- columns = []
- for column in target.columns:
- # ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
- if column.is_active is False:
- continue
-
- extra_json = json.loads(column.extra or "{}")
- for attr in {"groupby", "filterable", "verbose_name",
"python_date_format"}:
- value = getattr(column, attr)
- if value:
- extra_json[attr] = value
-
- columns.append(
- NewColumn(
- name=column.column_name,
- type=column.type or "Unknown",
- expression=column.expression
- or conditional_quote(column.column_name),
- description=column.description,
- is_temporal=column.is_dttm,
- is_aggregation=False,
- is_physical=column.expression is None,
- is_spatial=False,
- is_partition=False,
- is_increase_desired=True,
- extra_json=json.dumps(extra_json) if extra_json else None,
- is_managed_externally=target.is_managed_externally,
- external_url=target.external_url,
- ),
- )
-
- # create metrics
- for metric in target.metrics:
- extra_json = json.loads(metric.extra or "{}")
- for attr in {"verbose_name", "metric_type", "d3format"}:
- value = getattr(metric, attr)
- if value:
- extra_json[attr] = value
-
- is_additive = (
- metric.metric_type
- and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
- )
-
- columns.append(
- NewColumn(
- name=metric.metric_name,
- type="Unknown", # figuring this out would require a type
inferrer
- expression=metric.expression,
- warning_text=metric.warning_text,
- description=metric.description,
- is_aggregation=True,
- is_additive=is_additive,
- is_physical=False,
- is_spatial=False,
- is_partition=False,
- is_increase_desired=True,
- extra_json=json.dumps(extra_json) if extra_json else None,
- is_managed_externally=target.is_managed_externally,
- external_url=target.external_url,
- ),
- )
-
- # physical dataset
- tables = []
- if target.sql is None:
- physical_columns = [column for column in columns if column.is_physical]
-
- # create table
- table = NewTable(
- name=target.table_name,
- schema=target.schema,
- catalog=None, # currently not supported
- database_id=target.database_id,
- columns=physical_columns,
- is_managed_externally=target.is_managed_externally,
- external_url=target.external_url,
- )
- tables.append(table)
- # virtual dataset
- else:
- # mark all columns as virtual (not physical)
- for column in columns:
- column.is_physical = False
-
- # find referenced tables
- parsed = ParsedQuery(target.sql)
- referenced_tables = parsed.tables
-
- # predicate for finding the referenced tables
- predicate = or_(
- *[
- and_(
- NewTable.schema == (table.schema or target.schema),
- NewTable.name == table.table,
- )
- for table in referenced_tables
- ]
- )
- tables = session.query(NewTable).filter(predicate).all()
-
- # create the new dataset
- dataset = NewDataset(
- sqlatable_id=target.id,
- name=target.table_name,
- expression=target.sql or conditional_quote(target.table_name),
- tables=tables,
- columns=columns,
- is_physical=target.sql is None,
- is_managed_externally=target.is_managed_externally,
- external_url=target.external_url,
- )
- session.add(dataset)
+ SqlaTable.write_shadow_dataset(target, database, session)
@staticmethod
def after_delete( # pylint: disable=unused-argument
@@ -2301,6 +2193,142 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
dataset.expression = target.sql or conditional_quote(target.table_name)
dataset.is_physical = target.sql is None
+ @staticmethod
+ def write_shadow_dataset( # pylint: disable=too-many-locals
+ dataset: "SqlaTable", database: Database, session: Session
+ ) -> None:
+ """
+ Shadow write the dataset to new models.
+
+ The ``SqlaTable`` model is currently being migrated to two new models, ``Table``
+ and ``Dataset``. In the first phase of the migration the new models are populated
+ whenever ``SqlaTable`` is modified (created, updated, or deleted).
+
+ In the second phase of the migration reads will be done from the new models.
+ Finally, in the third phase of the migration the old models will be removed.
+
+ For more context: https://github.com/apache/superset/issues/14909
+ """
+
+ engine = database.get_sqla_engine(schema=dataset.schema)
+ conditional_quote = engine.dialect.identifier_preparer.quote
+
+ # create columns
+ columns = []
+ for column in dataset.columns:
+ # ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
+ if column.is_active is False:
+ continue
+
+ extra_json = json.loads(column.extra or "{}")
+ for attr in {"groupby", "filterable", "verbose_name",
"python_date_format"}:
+ value = getattr(column, attr)
+ if value:
+ extra_json[attr] = value
+
+ columns.append(
+ NewColumn(
+ name=column.column_name,
+ type=column.type or "Unknown",
+ expression=column.expression
+ or conditional_quote(column.column_name),
+ description=column.description,
+ is_temporal=column.is_dttm,
+ is_aggregation=False,
+ is_physical=column.expression is None,
+ is_spatial=False,
+ is_partition=False,
+ is_increase_desired=True,
+ extra_json=json.dumps(extra_json) if extra_json else None,
+ is_managed_externally=dataset.is_managed_externally,
+ external_url=dataset.external_url,
+ ),
+ )
+
+ # create metrics
+ for metric in dataset.metrics:
+ extra_json = json.loads(metric.extra or "{}")
+ for attr in {"verbose_name", "metric_type", "d3format"}:
+ value = getattr(metric, attr)
+ if value:
+ extra_json[attr] = value
+
+ is_additive = (
+ metric.metric_type
+ and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
+ )
+
+ columns.append(
+ NewColumn(
+ name=metric.metric_name,
+ type="Unknown", # figuring this out would require a type
inferrer
+ expression=metric.expression,
+ warning_text=metric.warning_text,
+ description=metric.description,
+ is_aggregation=True,
+ is_additive=is_additive,
+ is_physical=False,
+ is_spatial=False,
+ is_partition=False,
+ is_increase_desired=True,
+ extra_json=json.dumps(extra_json) if extra_json else None,
+ is_managed_externally=dataset.is_managed_externally,
+ external_url=dataset.external_url,
+ ),
+ )
+
+ # physical dataset
+ tables = []
+ if dataset.sql is None:
+ physical_columns = [column for column in columns if column.is_physical]
+
+ # create table
+ table = NewTable(
+ name=dataset.table_name,
+ schema=dataset.schema,
+ catalog=None, # currently not supported
+ database_id=dataset.database_id,
+ columns=physical_columns,
+ is_managed_externally=dataset.is_managed_externally,
+ external_url=dataset.external_url,
+ )
+ tables.append(table)
+
+ # virtual dataset
+ else:
+ # mark all columns as virtual (not physical)
+ for column in columns:
+ column.is_physical = False
+
+ # find referenced tables
+ parsed = ParsedQuery(dataset.sql)
+ referenced_tables = parsed.tables
+
+ # predicate for finding the referenced tables
+ predicate = or_(
+ *[
+ and_(
+ NewTable.schema == (table.schema or dataset.schema),
+ NewTable.name == table.table,
+ )
+ for table in referenced_tables
+ ]
+ )
+ tables = session.query(NewTable).filter(predicate).all()
+
+ # create the new dataset
+ new_dataset = NewDataset(
+ sqlatable_id=dataset.id,
+ name=dataset.table_name,
+ expression=dataset.sql or conditional_quote(dataset.table_name),
+ tables=tables,
+ columns=columns,
+ is_physical=dataset.sql is None,
+ is_managed_externally=dataset.is_managed_externally,
+ external_url=dataset.external_url,
+ )
+ session.add(new_dataset)
+
sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
diff --git a/tests/unit_tests/datasets/test_models.py b/tests/unit_tests/datasets/test_models.py
index eab0a8a..095b502 100644
--- a/tests/unit_tests/datasets/test_models.py
+++ b/tests/unit_tests/datasets/test_models.py
@@ -980,9 +980,9 @@ def test_update_sqlatable_schema(
sqla_table.schema = "new_schema"
session.flush()
- dataset = session.query(Dataset).one()
- assert dataset.tables[0].schema == "new_schema"
- assert dataset.tables[0].id == 2
+ new_dataset = session.query(Dataset).one()
+ assert new_dataset.tables[0].schema == "new_schema"
+ assert new_dataset.tables[0].id == 2
def test_update_sqlatable_metric(
@@ -1098,9 +1098,9 @@ def test_update_virtual_sqlatable_references(
session.flush()
# check that new dataset has both tables
- dataset = session.query(Dataset).one()
- assert dataset.tables == [table1, table2]
- assert dataset.expression == "SELECT a, b FROM table_a JOIN table_b"
+ new_dataset = session.query(Dataset).one()
+ assert new_dataset.tables == [table1, table2]
+ assert new_dataset.expression == "SELECT a, b FROM table_a JOIN table_b"
def test_quote_expressions(app_context: None, session: Session) -> None:
@@ -1242,3 +1242,72 @@ def test_update_physical_sqlatable(
# check that dataset points to the original table
assert dataset.tables[0].database_id == 1
+
+
+def test_update_physical_sqlatable_no_dataset(
+ mocker: MockFixture, app_context: None, session: Session
+) -> None:
+ """
+ Test that updating the table on a physical dataset creates a new
+ dataset if one didn't already exist.
+
+ When updating the table on a physical dataset by pointing it somewhere else (change
+ in database ID, schema, or table name) we should point the ``Dataset`` to an
+ existing ``Table`` if possible, and create a new one otherwise.
+ """
+ # patch session
+ mocker.patch(
+ "superset.security.SupersetSecurityManager.get_session",
return_value=session
+ )
+ mocker.patch("superset.datasets.dao.db.session", session)
+
+ from superset.columns.models import Column
+ from superset.connectors.sqla.models import SqlaTable, TableColumn
+ from superset.datasets.models import Dataset
+ from superset.models.core import Database
+ from superset.tables.models import Table
+ from superset.tables.schemas import TableSchema
+
+ engine = session.get_bind()
+ Dataset.metadata.create_all(engine) # pylint: disable=no-member
+
+ columns = [
+ TableColumn(column_name="a", type="INTEGER"),
+ ]
+
+ sqla_table = SqlaTable(
+ table_name="old_dataset",
+ columns=columns,
+ metrics=[],
+ database=Database(database_name="my_database",
sqlalchemy_uri="sqlite://"),
+ )
+ session.add(sqla_table)
+ session.flush()
+
+ # check that the table was created
+ table = session.query(Table).one()
+ assert table.id == 1
+
+ dataset = session.query(Dataset).one()
+ assert dataset.tables == [table]
+
+ # point ``SqlaTable`` to a different database
+ new_database = Database(
+ database_name="my_other_database", sqlalchemy_uri="sqlite://"
+ )
+ session.add(new_database)
+ session.flush()
+ sqla_table.database = new_database
+ session.flush()
+
+ new_dataset = session.query(Dataset).one()
+
+ # check that dataset now points to the new table
+ assert new_dataset.tables[0].database_id == 2
+
+ # point ``SqlaTable`` back
+ sqla_table.database_id = 1
+ session.flush()
+
+ # check that dataset points to the original table
+ assert new_dataset.tables[0].database_id == 1
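Assuming a standard Superset development checkout with the unit-test dependencies installed, the new regression test can be run on its own with something like:

    pytest tests/unit_tests/datasets/test_models.py::test_update_physical_sqlatable_no_dataset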