This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 88029e2  fix dataset update table (#19269)
88029e2 is described below

commit 88029e21b6068f845d806cfc10d478a5d972ffa5
Author: Elizabeth Thompson <[email protected]>
AuthorDate: Mon Mar 21 09:43:51 2022 -0700

    fix dataset update table (#19269)
---
 superset/connectors/sqla/models.py       | 272 +++++++++++++++++--------------
 tests/unit_tests/datasets/test_models.py |  81 ++++++++-
 2 files changed, 225 insertions(+), 128 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 62ae8c9..bbd1b5d 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1863,11 +1863,20 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
 
         session.execute(update(SqlaTable).where(SqlaTable.id == 
target.table.id))
 
-        # update ``Column`` model as well
         dataset = (
-            
session.query(NewDataset).filter_by(sqlatable_id=target.table.id).one()
+            session.query(NewDataset)
+            .filter_by(sqlatable_id=target.table.id)
+            .one_or_none()
         )
 
+        if not dataset:
+            # if dataset is not found create a new copy
+            # of the dataset instead of updating the existing
+
+            SqlaTable.write_shadow_dataset(target.table, database, session)
+            return
+
+        # update ``Column`` model as well
         if isinstance(target, TableColumn):
             columns = [
                 column
@@ -1923,7 +1932,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
             column.extra_json = json.dumps(extra_json) if extra_json else None
 
     @staticmethod
-    def after_insert(  # pylint: disable=too-many-locals
+    def after_insert(
         mapper: Mapper, connection: Connection, target: "SqlaTable",
     ) -> None:
         """
@@ -1938,135 +1947,18 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
 
         For more context: https://github.com/apache/superset/issues/14909
         """
+        session = inspect(target).session
         # set permissions
         security_manager.set_perm(mapper, connection, target)
 
-        session = inspect(target).session
-
         # get DB-specific conditional quoter for expressions that point to 
columns or
         # table names
         database = (
             target.database
             or session.query(Database).filter_by(id=target.database_id).one()
         )
-        engine = database.get_sqla_engine(schema=target.schema)
-        conditional_quote = engine.dialect.identifier_preparer.quote
-
-        # create columns
-        columns = []
-        for column in target.columns:
-            # ``is_active`` might be ``None`` at this point, but it defaults 
to ``True``.
-            if column.is_active is False:
-                continue
-
-            extra_json = json.loads(column.extra or "{}")
-            for attr in {"groupby", "filterable", "verbose_name", 
"python_date_format"}:
-                value = getattr(column, attr)
-                if value:
-                    extra_json[attr] = value
-
-            columns.append(
-                NewColumn(
-                    name=column.column_name,
-                    type=column.type or "Unknown",
-                    expression=column.expression
-                    or conditional_quote(column.column_name),
-                    description=column.description,
-                    is_temporal=column.is_dttm,
-                    is_aggregation=False,
-                    is_physical=column.expression is None,
-                    is_spatial=False,
-                    is_partition=False,
-                    is_increase_desired=True,
-                    extra_json=json.dumps(extra_json) if extra_json else None,
-                    is_managed_externally=target.is_managed_externally,
-                    external_url=target.external_url,
-                ),
-            )
-
-        # create metrics
-        for metric in target.metrics:
-            extra_json = json.loads(metric.extra or "{}")
-            for attr in {"verbose_name", "metric_type", "d3format"}:
-                value = getattr(metric, attr)
-                if value:
-                    extra_json[attr] = value
-
-            is_additive = (
-                metric.metric_type
-                and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
-            )
-
-            columns.append(
-                NewColumn(
-                    name=metric.metric_name,
-                    type="Unknown",  # figuring this out would require a type 
inferrer
-                    expression=metric.expression,
-                    warning_text=metric.warning_text,
-                    description=metric.description,
-                    is_aggregation=True,
-                    is_additive=is_additive,
-                    is_physical=False,
-                    is_spatial=False,
-                    is_partition=False,
-                    is_increase_desired=True,
-                    extra_json=json.dumps(extra_json) if extra_json else None,
-                    is_managed_externally=target.is_managed_externally,
-                    external_url=target.external_url,
-                ),
-            )
-
-        # physical dataset
-        tables = []
-        if target.sql is None:
-            physical_columns = [column for column in columns if 
column.is_physical]
-
-            # create table
-            table = NewTable(
-                name=target.table_name,
-                schema=target.schema,
-                catalog=None,  # currently not supported
-                database_id=target.database_id,
-                columns=physical_columns,
-                is_managed_externally=target.is_managed_externally,
-                external_url=target.external_url,
-            )
-            tables.append(table)
 
-        # virtual dataset
-        else:
-            # mark all columns as virtual (not physical)
-            for column in columns:
-                column.is_physical = False
-
-            # find referenced tables
-            parsed = ParsedQuery(target.sql)
-            referenced_tables = parsed.tables
-
-            # predicate for finding the referenced tables
-            predicate = or_(
-                *[
-                    and_(
-                        NewTable.schema == (table.schema or target.schema),
-                        NewTable.name == table.table,
-                    )
-                    for table in referenced_tables
-                ]
-            )
-            tables = session.query(NewTable).filter(predicate).all()
-
-        # create the new dataset
-        dataset = NewDataset(
-            sqlatable_id=target.id,
-            name=target.table_name,
-            expression=target.sql or conditional_quote(target.table_name),
-            tables=tables,
-            columns=columns,
-            is_physical=target.sql is None,
-            is_managed_externally=target.is_managed_externally,
-            external_url=target.external_url,
-        )
-        session.add(dataset)
+        SqlaTable.write_shadow_dataset(target, database, session)
 
     @staticmethod
     def after_delete(  # pylint: disable=unused-argument
@@ -2301,6 +2193,142 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         dataset.expression = target.sql or conditional_quote(target.table_name)
         dataset.is_physical = target.sql is None
 
+    @staticmethod
+    def write_shadow_dataset(  # pylint: disable=too-many-locals
+        dataset: "SqlaTable", database: Database, session: Session
+    ) -> None:
+        """
+        Shadow write the dataset to new models.
+
+        The ``SqlaTable`` model is currently being migrated to two new models, 
``Table``
+        and ``Dataset``. In the first phase of the migration the new models 
are populated
+        whenever ``SqlaTable`` is modified (created, updated, or deleted).
+
+        In the second phase of the migration reads will be done from the new 
models.
+        Finally, in the third phase of the migration the old models will be 
removed.
+
+        For more context: https://github.com/apache/superset/issues/14909
+        """
+
+        engine = database.get_sqla_engine(schema=dataset.schema)
+        conditional_quote = engine.dialect.identifier_preparer.quote
+
+        # create columns
+        columns = []
+        for column in dataset.columns:
+            # ``is_active`` might be ``None`` at this point, but it defaults 
to ``True``.
+            if column.is_active is False:
+                continue
+
+            extra_json = json.loads(column.extra or "{}")
+            for attr in {"groupby", "filterable", "verbose_name", 
"python_date_format"}:
+                value = getattr(column, attr)
+                if value:
+                    extra_json[attr] = value
+
+            columns.append(
+                NewColumn(
+                    name=column.column_name,
+                    type=column.type or "Unknown",
+                    expression=column.expression
+                    or conditional_quote(column.column_name),
+                    description=column.description,
+                    is_temporal=column.is_dttm,
+                    is_aggregation=False,
+                    is_physical=column.expression is None,
+                    is_spatial=False,
+                    is_partition=False,
+                    is_increase_desired=True,
+                    extra_json=json.dumps(extra_json) if extra_json else None,
+                    is_managed_externally=dataset.is_managed_externally,
+                    external_url=dataset.external_url,
+                ),
+            )
+
+        # create metrics
+        for metric in dataset.metrics:
+            extra_json = json.loads(metric.extra or "{}")
+            for attr in {"verbose_name", "metric_type", "d3format"}:
+                value = getattr(metric, attr)
+                if value:
+                    extra_json[attr] = value
+
+            is_additive = (
+                metric.metric_type
+                and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
+            )
+
+            columns.append(
+                NewColumn(
+                    name=metric.metric_name,
+                    type="Unknown",  # figuring this out would require a type 
inferrer
+                    expression=metric.expression,
+                    warning_text=metric.warning_text,
+                    description=metric.description,
+                    is_aggregation=True,
+                    is_additive=is_additive,
+                    is_physical=False,
+                    is_spatial=False,
+                    is_partition=False,
+                    is_increase_desired=True,
+                    extra_json=json.dumps(extra_json) if extra_json else None,
+                    is_managed_externally=dataset.is_managed_externally,
+                    external_url=dataset.external_url,
+                ),
+            )
+
+        # physical dataset
+        tables = []
+        if dataset.sql is None:
+            physical_columns = [column for column in columns if 
column.is_physical]
+
+            # create table
+            table = NewTable(
+                name=dataset.table_name,
+                schema=dataset.schema,
+                catalog=None,  # currently not supported
+                database_id=dataset.database_id,
+                columns=physical_columns,
+                is_managed_externally=dataset.is_managed_externally,
+                external_url=dataset.external_url,
+            )
+            tables.append(table)
+
+        # virtual dataset
+        else:
+            # mark all columns as virtual (not physical)
+            for column in columns:
+                column.is_physical = False
+
+            # find referenced tables
+            parsed = ParsedQuery(dataset.sql)
+            referenced_tables = parsed.tables
+
+            # predicate for finding the referenced tables
+            predicate = or_(
+                *[
+                    and_(
+                        NewTable.schema == (table.schema or dataset.schema),
+                        NewTable.name == table.table,
+                    )
+                    for table in referenced_tables
+                ]
+            )
+            tables = session.query(NewTable).filter(predicate).all()
+
+        # create the new dataset
+        new_dataset = NewDataset(
+            sqlatable_id=dataset.id,
+            name=dataset.table_name,
+            expression=dataset.sql or conditional_quote(dataset.table_name),
+            tables=tables,
+            columns=columns,
+            is_physical=dataset.sql is None,
+            is_managed_externally=dataset.is_managed_externally,
+            external_url=dataset.external_url,
+        )
+        session.add(new_dataset)
+
 
 sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
 sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
diff --git a/tests/unit_tests/datasets/test_models.py b/tests/unit_tests/datasets/test_models.py
index eab0a8a..095b502 100644
--- a/tests/unit_tests/datasets/test_models.py
+++ b/tests/unit_tests/datasets/test_models.py
@@ -980,9 +980,9 @@ def test_update_sqlatable_schema(
     sqla_table.schema = "new_schema"
     session.flush()
 
-    dataset = session.query(Dataset).one()
-    assert dataset.tables[0].schema == "new_schema"
-    assert dataset.tables[0].id == 2
+    new_dataset = session.query(Dataset).one()
+    assert new_dataset.tables[0].schema == "new_schema"
+    assert new_dataset.tables[0].id == 2
 
 
 def test_update_sqlatable_metric(
@@ -1098,9 +1098,9 @@ def test_update_virtual_sqlatable_references(
     session.flush()
 
     # check that new dataset has both tables
-    dataset = session.query(Dataset).one()
-    assert dataset.tables == [table1, table2]
-    assert dataset.expression == "SELECT a, b FROM table_a JOIN table_b"
+    new_dataset = session.query(Dataset).one()
+    assert new_dataset.tables == [table1, table2]
+    assert new_dataset.expression == "SELECT a, b FROM table_a JOIN table_b"
 
 
 def test_quote_expressions(app_context: None, session: Session) -> None:
@@ -1242,3 +1242,72 @@ def test_update_physical_sqlatable(
 
     # check that dataset points to the original table
     assert dataset.tables[0].database_id == 1
+
+
+def test_update_physical_sqlatable_no_dataset(
+    mocker: MockFixture, app_context: None, session: Session
+) -> None:
+    """
+    Test updating the table on a physical dataset that it creates
+    a new dataset if one didn't already exist.
+
+    When updating the table on a physical dataset by pointing it somewhere 
else (change
+    in database ID, schema, or table name) we should point the ``Dataset`` to 
an
+    existing ``Table`` if possible, and create a new one otherwise.
+    """
+    # patch session
+    mocker.patch(
+        "superset.security.SupersetSecurityManager.get_session", 
return_value=session
+    )
+    mocker.patch("superset.datasets.dao.db.session", session)
+
+    from superset.columns.models import Column
+    from superset.connectors.sqla.models import SqlaTable, TableColumn
+    from superset.datasets.models import Dataset
+    from superset.models.core import Database
+    from superset.tables.models import Table
+    from superset.tables.schemas import TableSchema
+
+    engine = session.get_bind()
+    Dataset.metadata.create_all(engine)  # pylint: disable=no-member
+
+    columns = [
+        TableColumn(column_name="a", type="INTEGER"),
+    ]
+
+    sqla_table = SqlaTable(
+        table_name="old_dataset",
+        columns=columns,
+        metrics=[],
+        database=Database(database_name="my_database", 
sqlalchemy_uri="sqlite://"),
+    )
+    session.add(sqla_table)
+    session.flush()
+
+    # check that the table was created
+    table = session.query(Table).one()
+    assert table.id == 1
+
+    dataset = session.query(Dataset).one()
+    assert dataset.tables == [table]
+
+    # point ``SqlaTable`` to a different database
+    new_database = Database(
+        database_name="my_other_database", sqlalchemy_uri="sqlite://"
+    )
+    session.add(new_database)
+    session.flush()
+    sqla_table.database = new_database
+    session.flush()
+
+    new_dataset = session.query(Dataset).one()
+
+    # check that dataset now points to the new table
+    assert new_dataset.tables[0].database_id == 2
+
+    # point ``SqlaTable`` back
+    sqla_table.database_id = 1
+    session.flush()
+
+    # check that dataset points to the original table
+    assert new_dataset.tables[0].database_id == 1

Reply via email to