This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new bc435e08d0 fix: overwrite update override columns on PUT /dataset (#20862)
bc435e08d0 is described below

commit bc435e08d01b87efcf8774f29a7078cee8988e39
Author: Hugh A. Miles II <[email protected]>
AuthorDate: Fri Jul 29 21:51:35 2022 -0400

    fix: overwrite update override columns on PUT /dataset (#20862)
    
    * update override columns
    
    * save
    
    * fix overwrite with session.flush
    
    * write test
    
    * write test
    
    * layup
    
    * address concerns
    
    * address concerns
---
 superset/datasets/commands/update.py          |  3 +-
 superset/datasets/dao.py                      | 50 +++++++++++++++++++--------
 tests/integration_tests/datasets/api_tests.py | 50 +++++++++++++++++++++++++++
 3 files changed, 88 insertions(+), 15 deletions(-)

diff --git a/superset/datasets/commands/update.py b/superset/datasets/commands/update.py
index e3c908cebb..483a98e76c 100644
--- a/superset/datasets/commands/update.py
+++ b/superset/datasets/commands/update.py
@@ -50,12 +50,13 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
         self,
         model_id: int,
         data: Dict[str, Any],
-        override_columns: bool = False,
+        override_columns: Optional[bool] = False,
     ):
         self._model_id = model_id
         self._properties = data.copy()
         self._model: Optional[SqlaTable] = None
         self.override_columns = override_columns
+        self._properties["override_columns"] = override_columns
 
     def run(self) -> Model:
         self.validate()
diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py
index a538a70c13..d260df3610 100644
--- a/superset/datasets/dao.py
+++ b/superset/datasets/dao.py
@@ -147,14 +147,22 @@ class DatasetDAO(BaseDAO):  # pylint: disable=too-many-public-methods
 
     @classmethod
     def update(
-        cls, model: SqlaTable, properties: Dict[str, Any], commit: bool = True
+        cls,
+        model: SqlaTable,
+        properties: Dict[str, Any],
+        commit: bool = True,
     ) -> Optional[SqlaTable]:
         """
         Updates a Dataset model on the metadata DB
         """
 
         if "columns" in properties:
-            cls.update_columns(model, properties.pop("columns"), commit=commit)
+            cls.update_columns(
+                model,
+                properties.pop("columns"),
+                commit=commit,
+                override_columns=bool(properties.get("override_columns")),
+            )
 
         if "metrics" in properties:
             cls.update_metrics(model, properties.pop("metrics"), commit=commit)
@@ -167,6 +175,7 @@ class DatasetDAO(BaseDAO):  # pylint: disable=too-many-public-methods
         model: SqlaTable,
         property_columns: List[Dict[str, Any]],
         commit: bool = True,
+        override_columns: bool = False,
     ) -> None:
         """
         Creates/updates and/or deletes a list of columns, based on a
@@ -180,24 +189,37 @@ class DatasetDAO(BaseDAO):  # pylint: disable=too-many-public-methods
 
         column_by_id = {column.id: column for column in model.columns}
         seen = set()
+        original_cols = {obj.id for obj in model.columns}
 
-        for properties in property_columns:
-            if "id" in properties:
-                seen.add(properties["id"])
+        if override_columns:
+            for id_ in original_cols:
+                DatasetDAO.delete_column(column_by_id[id_], commit=False)
 
-                DatasetDAO.update_column(
-                    column_by_id[properties["id"]],
-                    properties,
-                    commit=False,
-                )
-            else:
+            db.session.flush()
+
+            for properties in property_columns:
                 DatasetDAO.create_column(
                     {**properties, "table_id": model.id},
                     commit=False,
                 )
-
-        for id_ in {obj.id for obj in model.columns} - seen:
-            DatasetDAO.delete_column(column_by_id[id_], commit=False)
+        else:
+            for properties in property_columns:
+                if "id" in properties:
+                    seen.add(properties["id"])
+
+                    DatasetDAO.update_column(
+                        column_by_id[properties["id"]],
+                        properties,
+                        commit=False,
+                    )
+                else:
+                    DatasetDAO.create_column(
+                        {**properties, "table_id": model.id},
+                        commit=False,
+                    )
+
+            for id_ in {obj.id for obj in model.columns} - seen:
+                DatasetDAO.delete_column(column_by_id[id_], commit=False)
 
         if commit:
             db.session.commit()
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 46739f9631..a993f0c0b8 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -777,6 +777,56 @@ class TestDatasetApi(SupersetTestCase):
         db.session.delete(dataset)
         db.session.commit()
 
+    def test_update_dataset_item_w_override_columns_same_columns(self):
+        """
+        Dataset API: Test update dataset with override columns
+        """
+        if backend() == "sqlite":
+            return
+
+        # Add default dataset
+        main_db = get_main_database()
+        dataset = self.insert_default_dataset()
+        prev_col_len = len(dataset.columns)
+
+        cols = [
+            {
+                "column_name": c.column_name,
+                "description": c.description,
+                "expression": c.expression,
+                "type": c.type,
+                "advanced_data_type": c.advanced_data_type,
+                "verbose_name": c.verbose_name,
+            }
+            for c in dataset.columns
+        ]
+
+        cols.append(
+            {
+                "column_name": "new_col",
+                "description": "description",
+                "expression": "expression",
+                "type": "INTEGER",
+                "advanced_data_type": "ADVANCED_DATA_TYPE",
+                "verbose_name": "New Col",
+            }
+        )
+
+        self.login(username="admin")
+        dataset_data = {
+            "columns": cols,
+        }
+        uri = f"api/v1/dataset/{dataset.id}?override_columns=true"
+        rv = self.put_assert_metric(uri, dataset_data, "put")
+
+        assert rv.status_code == 200
+
+        columns = db.session.query(TableColumn).filter_by(table_id=dataset.id).all()
+        assert len(columns) != prev_col_len
+        assert len(columns) == 3
+        db.session.delete(dataset)
+        db.session.commit()
+
     def test_update_dataset_create_column_and_metric(self):
         """
         Dataset API: Test update dataset create column

Reply via email to