This is an automated email from the ASF dual-hosted git repository.

dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 3fae7dbc34 chore: Migrate get_or_create_table endpoint to api v1 (#22931)
3fae7dbc34 is described below

commit 3fae7dbc34b0ab98b255b1ca77cc0ddfdad0a495
Author: Jack Fragassi <[email protected]>
AuthorDate: Wed Feb 15 02:38:51 2023 -0800

    chore: Migrate get_or_create_table endpoint to api v1 (#22931)
---
 superset-frontend/src/SqlLab/actions/sqlLab.js     |  8 +-
 .../components/ExploreCtasResultsButton/index.tsx  |  6 +-
 superset/connectors/sqla/views.py                  |  2 -
 superset/datasets/api.py                           | 70 +++++++++++++++++
 superset/datasets/dao.py                           |  8 ++
 superset/datasets/schemas.py                       | 11 +++
 superset/views/base.py                             |  6 --
 superset/views/core.py                             |  4 +-
 tests/integration_tests/datasets/api_tests.py      | 89 ++++++++++++++++++++++
 tests/integration_tests/datasets/commands_tests.py | 51 ++++++++++++-
 10 files changed, 237 insertions(+), 18 deletions(-)

diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js
index 40aea66301..ab8abe0edc 100644
--- a/superset-frontend/src/SqlLab/actions/sqlLab.js
+++ b/superset-frontend/src/SqlLab/actions/sqlLab.js
@@ -1513,13 +1513,13 @@ export function createCtasDatasource(vizOptions) {
   return dispatch => {
     dispatch(createDatasourceStarted());
     return SupersetClient.post({
-      endpoint: '/superset/get_or_create_table/',
-      postPayload: { data: vizOptions },
+      endpoint: '/api/v1/dataset/get_or_create/',
+      jsonPayload: vizOptions,
     })
       .then(({ json }) => {
-        dispatch(createDatasourceSuccess(json));
+        dispatch(createDatasourceSuccess(json.result));
 
-        return json;
+        return json.result;
       })
       .catch(() => {
         const errorMsg = t('An error occurred while creating the data source');
diff --git a/superset-frontend/src/SqlLab/components/ExploreCtasResultsButton/index.tsx b/superset-frontend/src/SqlLab/components/ExploreCtasResultsButton/index.tsx
index 2fe1e14a07..a4c71139c0 100644
--- a/superset-frontend/src/SqlLab/components/ExploreCtasResultsButton/index.tsx
+++ b/superset-frontend/src/SqlLab/components/ExploreCtasResultsButton/index.tsx
@@ -48,10 +48,10 @@ const ExploreCtasResultsButton = ({
   const dispatch = useDispatch<(dispatch: any) => Promise<JsonObject>>();
 
   const buildVizOptions = {
-    datasourceName: table,
+    table_name: table,
     schema,
-    dbId,
-    templateParams,
+    database_id: dbId,
+    template_params: templateParams,
   };
 
   const visualize = () => {
diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py
index c502f527ac..86cb08bb86 100644
--- a/superset/connectors/sqla/views.py
+++ b/superset/connectors/sqla/views.py
@@ -35,7 +35,6 @@ from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP, RouteMethod
 from superset.superset_typing import FlaskResponse
 from superset.utils import core as utils
 from superset.views.base import (
-    create_table_permissions,
     DatasourceFilter,
     DeleteMixin,
     ListWidgetWithCheckboxes,
@@ -511,7 +510,6 @@ class TableModelView(  # pylint: disable=too-many-ancestors
     ) -> None:
         if fetch_metadata:
             item.fetch_metadata()
-        create_table_permissions(item)
         if flash_message:
             flash(
                 _(
diff --git a/superset/datasets/api.py b/superset/datasets/api.py
index 925c3c7cb8..d58a1dd3f6 100644
--- a/superset/datasets/api.py
+++ b/superset/datasets/api.py
@@ -61,6 +61,7 @@ from superset.datasets.schemas import (
     DatasetRelatedObjectsResponse,
     get_delete_ids_schema,
     get_export_ids_schema,
+    GetOrCreateDatasetSchema,
 )
 from superset.utils.core import parse_boolean_string
 from superset.views.base import DatasourceFilter, generate_download_headers
@@ -93,6 +94,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
         "refresh",
         "related_objects",
         "duplicate",
+        "get_or_create_dataset",
     }
     list_columns = [
         "id",
@@ -240,6 +242,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
     openapi_spec_component_schemas = (
         DatasetRelatedObjectsResponse,
         DatasetDuplicateSchema,
+        GetOrCreateDatasetSchema,
     )
 
     list_outer_default_load = True
@@ -877,3 +880,70 @@ class DatasetRestApi(BaseSupersetModelRestApi):
         )
         command.run()
         return self.response(200, message="OK")
+
+    @expose("/get_or_create/", methods=["POST"])
+    @protect()
+    @safe
+    @statsd_metrics
+    @event_logger.log_this_with_context(
+        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
+        f".get_or_create_dataset",
+        log_to_statsd=False,
+    )
+    def get_or_create_dataset(self) -> Response:
+        """Retrieve a dataset by name, or create it if it does not exist
+        ---
+        post:
+          summary: Retrieve a table by name, or create it if it does not exist
+          requestBody:
+            required: true
+            content:
+              application/json:
+                schema:
+                  $ref: '#/components/schemas/GetOrCreateDatasetSchema'
+          responses:
+            200:
+              description: The ID of the table
+              content:
+                application/json:
+                  schema:
+                    type: object
+                    properties:
+                      result:
+                        type: object
+                        properties:
+                          table_id:
+                            type: integer
+            400:
+              $ref: '#/components/responses/400'
+            401:
+              $ref: '#/components/responses/401'
+            422:
+              $ref: '#/components/responses/422'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        try:
+            body = GetOrCreateDatasetSchema().load(request.json)
+        except ValidationError as ex:
+            return self.response(400, message=ex.messages)
+        table_name = body["table_name"]
+        database_id = body["database_id"]
+        table = DatasetDAO.get_table_by_name(database_id, table_name)
+        if table:
+            return self.response(200, result={"table_id": table.id})
+
+        body["database"] = database_id
+        try:
+            tbl = CreateDatasetCommand(body).run()
+            return self.response(200, result={"table_id": tbl.id})
+        except DatasetInvalidError as ex:
+            return self.response_422(message=ex.normalized_messages())
+        except DatasetCreateFailedError as ex:
+            logger.error(
+                "Error creating model %s: %s",
+                self.__class__.__name__,
+                str(ex),
+                exc_info=True,
+            )
+            return self.response_422(message=ex.message)
diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py
index 1c55723b47..b158fce1fe 100644
--- a/superset/datasets/dao.py
+++ b/superset/datasets/dao.py
@@ -388,6 +388,14 @@ class DatasetDAO(BaseDAO):  # pylint: disable=too-many-public-methods
                 db.session.rollback()
             raise ex
 
+    @staticmethod
+    def get_table_by_name(database_id: int, table_name: str) -> Optional[SqlaTable]:
+        return (
+            db.session.query(SqlaTable)
+            .filter_by(database_id=database_id, table_name=table_name)
+            .one_or_none()
+        )
+
 
 class DatasetColumnDAO(BaseDAO):
     model_cls = TableColumn
diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py
index 223324da3a..103359a2c3 100644
--- a/superset/datasets/schemas.py
+++ b/superset/datasets/schemas.py
@@ -228,6 +228,17 @@ class ImportV1DatasetSchema(Schema):
     external_url = fields.String(allow_none=True)
 
 
+class GetOrCreateDatasetSchema(Schema):
+    table_name = fields.String(required=True, description="Name of table")
+    database_id = fields.Integer(
+        required=True, description="ID of database table belongs to"
+    )
+    schema = fields.String(
+        description="The schema the table belongs to", allow_none=True
+    )
+    template_params = fields.String(description="Template params for the table")
+
+
 class DatasetSchema(SQLAlchemyAutoSchema):
     """
     Schema for the ``Dataset`` model.
diff --git a/superset/views/base.py b/superset/views/base.py
index 0d69f1482f..f6651e5c74 100644
--- a/superset/views/base.py
+++ b/superset/views/base.py
@@ -299,12 +299,6 @@ def validate_sqlatable(table: models.SqlaTable) -> None:
         ) from ex
 
 
-def create_table_permissions(table: models.SqlaTable) -> None:
-    security_manager.add_permission_view_menu("datasource_access", table.get_perm())
-    if table.schema:
-        security_manager.add_permission_view_menu("schema_access", table.schema_perm)
-
-
 class BaseSupersetView(BaseView):
     @staticmethod
     def json_response(obj: Any, status: int = 200) -> FlaskResponse:
diff --git a/superset/views/core.py b/superset/views/core.py
index f1603837bb..fb371c209e 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -142,7 +142,6 @@ from superset.views.base import (
     api,
     BaseSupersetView,
     common_bootstrap_payload,
-    create_table_permissions,
     CsvResponse,
     data_payload_response,
     deprecated,
@@ -1927,6 +1926,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
     @has_access
     @expose("/get_or_create_table/", methods=["POST"])
     @event_logger.log_this
+    @deprecated()
     def sqllab_table_viz(self) -> FlaskResponse:  # pylint: disable=no-self-use
         """Gets or creates a table object with attributes passed to the API.
 
@@ -1956,11 +1956,11 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
                 table.schema = data.get("schema")
                 table.template_params = data.get("templateParams")
                 # needed for the table validation.
+                # fn can be deleted when this endpoint is removed
                 validate_sqlatable(table)
 
             db.session.add(table)
             table.fetch_metadata()
-            create_table_permissions(table)
             db.session.commit()
 
         return json_success(json.dumps({"table_id": table.id}))
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 95236af090..6e0551bd9f 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -34,6 +34,7 @@ from superset.dao.exceptions import (
     DAODeleteFailedError,
     DAOUpdateFailedError,
 )
+from superset.datasets.commands.exceptions import DatasetCreateFailedError
 from superset.datasets.models import Dataset
 from superset.extensions import db, security_manager
 from superset.models.core import Database
@@ -474,6 +475,7 @@ class TestDatasetApi(SupersetTestCase):
             "can_write",
             "can_export",
             "can_duplicate",
+            "can_get_or_create_dataset",
         }
 
     def test_create_dataset_item(self):
@@ -2302,3 +2304,90 @@ class TestDatasetApi(SupersetTestCase):
         }
         rv = self.post_assert_metric(uri, table_data, "duplicate")
         assert rv.status_code == 422
+
+    @pytest.mark.usefixtures("app_context", "virtual_dataset")
+    def test_get_or_create_dataset_already_exists(self):
+        """
+        Dataset API: Test get or create endpoint when table already exists
+        """
+        self.login(username="admin")
+        rv = self.client.post(
+            "api/v1/dataset/get_or_create/",
+            json={
+                "table_name": "virtual_dataset",
+                "database_id": get_example_database().id,
+            },
+        )
+        self.assertEqual(rv.status_code, 200)
+        response = json.loads(rv.data.decode("utf-8"))
+        dataset = (
+            db.session.query(SqlaTable)
+            .filter(SqlaTable.table_name == "virtual_dataset")
+            .one()
+        )
+        self.assertEqual(response["result"], {"table_id": dataset.id})
+
+    def test_get_or_create_dataset_database_not_found(self):
+        """
+        Dataset API: Test get or create endpoint when database doesn't exist
+        """
+        self.login(username="admin")
+        rv = self.client.post(
+            "api/v1/dataset/get_or_create/",
+            json={"table_name": "virtual_dataset", "database_id": 999},
+        )
+        self.assertEqual(rv.status_code, 422)
+        response = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(response["message"], {"database": ["Database does not exist"]})
+
+    @patch("superset.datasets.commands.create.CreateDatasetCommand.run")
+    def test_get_or_create_dataset_create_fails(self, command_run_mock):
+        """
+        Dataset API: Test get or create endpoint when create fails
+        """
+        command_run_mock.side_effect = DatasetCreateFailedError
+        self.login(username="admin")
+        rv = self.client.post(
+            "api/v1/dataset/get_or_create/",
+            json={
+                "table_name": "virtual_dataset",
+                "database_id": get_example_database().id,
+            },
+        )
+        self.assertEqual(rv.status_code, 422)
+        response = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(response["message"], "Dataset could not be created.")
+
+    def test_get_or_create_dataset_creates_table(self):
+        """
+        Dataset API: Test get or create endpoint when table is created
+        """
+        self.login(username="admin")
+
+        examples_db = get_example_database()
+        with examples_db.get_sqla_engine_with_context() as engine:
+            engine.execute("DROP TABLE IF EXISTS test_create_sqla_table_api")
+            engine.execute("CREATE TABLE test_create_sqla_table_api AS SELECT 2 as col")
+
+        rv = self.client.post(
+            "api/v1/dataset/get_or_create/",
+            json={
+                "table_name": "test_create_sqla_table_api",
+                "database_id": examples_db.id,
+                "template_params": '{"param": 1}',
+            },
+        )
+        self.assertEqual(rv.status_code, 200)
+        response = json.loads(rv.data.decode("utf-8"))
+        table = (
+            db.session.query(SqlaTable)
+            .filter_by(table_name="test_create_sqla_table_api")
+            .one()
+        )
+        self.assertEqual(response["result"], {"table_id": table.id})
+        self.assertEqual(table.template_params, '{"param": 1}')
+
+        db.session.delete(table)
+        with examples_db.get_sqla_engine_with_context() as engine:
+            engine.execute("DROP TABLE test_create_sqla_table_api")
+        db.session.commit()
diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py
index 5cc5c85bea..0ce98477a0 100644
--- a/tests/integration_tests/datasets/commands_tests.py
+++ b/tests/integration_tests/datasets/commands_tests.py
@@ -20,13 +20,18 @@ from unittest.mock import patch
 
 import pytest
 import yaml
+from sqlalchemy.exc import SQLAlchemyError
 
 from superset import db, security_manager
 from superset.commands.exceptions import CommandInvalidError
 from superset.commands.importers.exceptions import IncorrectVersionError
 from superset.connectors.sqla.models import SqlaTable
 from superset.databases.commands.importers.v1 import ImportDatabasesCommand
-from superset.datasets.commands.exceptions import DatasetNotFoundError
+from superset.datasets.commands.create import CreateDatasetCommand
+from superset.datasets.commands.exceptions import (
+    DatasetInvalidError,
+    DatasetNotFoundError,
+)
 from superset.datasets.commands.export import ExportDatasetsCommand
 from superset.datasets.commands.importers import v0, v1
 from superset.models.core import Database
@@ -519,3 +524,47 @@ def _get_table_from_list_by_name(name: str, tables: List[Any]):
         if table.table_name == name:
             return table
     raise ValueError(f"Table {name} does not exists in database")
+
+
+class TestCreateDatasetCommand(SupersetTestCase):
+    def test_database_not_found(self):
+        self.login(username="admin")
+        with self.assertRaises(DatasetInvalidError):
+            CreateDatasetCommand({"table_name": "table", "database": 9999}).run()
+
+    @patch("superset.models.core.Database.get_table")
+    def test_get_table_from_database_error(self, get_table_mock):
+        self.login(username="admin")
+        get_table_mock.side_effect = SQLAlchemyError
+        with self.assertRaises(DatasetInvalidError):
+            CreateDatasetCommand(
+                {"table_name": "table", "database": get_example_database().id}
+            ).run()
+
+    @patch("superset.security.manager.g")
+    @patch("superset.commands.utils.g")
+    def test_create_dataset_command(self, mock_g, mock_g2):
+        mock_g.user = security_manager.find_user("admin")
+        mock_g2.user = mock_g.user
+        examples_db = get_example_database()
+        with examples_db.get_sqla_engine_with_context() as engine:
+            engine.execute("DROP TABLE IF EXISTS test_create_dataset_command")
+            engine.execute(
+                "CREATE TABLE test_create_dataset_command AS SELECT 2 as col"
+            )
+
+        table = CreateDatasetCommand(
+            {"table_name": "test_create_dataset_command", "database": examples_db.id}
+        ).run()
+        fetched_table = (
+            db.session.query(SqlaTable)
+            .filter_by(table_name="test_create_dataset_command")
+            .one()
+        )
+        self.assertEqual(table, fetched_table)
+        self.assertEqual([owner.username for owner in table.owners], ["admin"])
+
+        db.session.delete(table)
+        with examples_db.get_sqla_engine_with_context() as engine:
+            engine.execute("DROP TABLE test_create_dataset_command")
+        db.session.commit()

Reply via email to