This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new a3a2a68  feat: API endpoint to import charts (#11744)
a3a2a68 is described below

commit a3a2a68f019482d4d19af843213d34e515b50cf9
Author: Beto Dealmeida <[email protected]>
AuthorDate: Fri Nov 20 14:40:27 2020 -0800

    feat: API endpoint to import charts (#11744)
    
    * ImportChartsCommand
    
    * feat: API endpoint to import charts
    
    * Add dispatcher
    
    * Fix docstring
---
 superset/charts/api.py                           | 56 ++++++++++++++
 superset/charts/commands/importers/dispatcher.py | 70 ++++++++++++++++++
 tests/charts/api_tests.py                        | 94 ++++++++++++++++++++++--
 tests/databases/api_tests.py                     |  4 +-
 tests/datasets/api_tests.py                      | 12 +--
 5 files changed, 223 insertions(+), 13 deletions(-)

diff --git a/superset/charts/api.py b/superset/charts/api.py
index 2327ed8..8e8789d 100644
--- a/superset/charts/api.py
+++ b/superset/charts/api.py
@@ -44,6 +44,7 @@ from superset.charts.commands.exceptions import (
     ChartUpdateFailedError,
 )
 from superset.charts.commands.export import ExportChartsCommand
+from superset.charts.commands.importers.dispatcher import ImportChartsCommand
 from superset.charts.commands.update import UpdateChartCommand
 from superset.charts.dao import ChartDAO
 from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter
@@ -59,6 +60,7 @@ from superset.charts.schemas import (
     screenshot_query_schema,
     thumbnail_query_schema,
 )
+from superset.commands.exceptions import CommandInvalidError
 from superset.constants import RouteMethod
 from superset.exceptions import SupersetSecurityException
 from superset.extensions import event_logger
@@ -86,6 +88,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
 
     include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {
         RouteMethod.EXPORT,
+        RouteMethod.IMPORT,
         RouteMethod.RELATED,
         "bulk_delete",  # not using RouteMethod since locally defined
         "data",
@@ -823,3 +826,56 @@ class ChartRestApi(BaseSupersetModelRestApi):
             for request_id in requested_ids
         ]
         return self.response(200, result=res)
+
+    @expose("/import/", methods=["POST"])
+    @protect()
+    @safe
+    @statsd_metrics
+    def import_(self) -> Response:
+        """Import chart(s) with associated datasets and databases
+        ---
+        post:
+          requestBody:
+            content:
+              application/zip:
+                schema:
+                  type: string
+                  format: binary
+          responses:
+            200:
+              description: Chart import result
+              content:
+                application/json:
+                  schema:
+                    type: object
+                    properties:
+                      message:
+                        type: string
+            400:
+              $ref: '#/components/responses/400'
+            401:
+              $ref: '#/components/responses/401'
+            422:
+              $ref: '#/components/responses/422'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        upload = request.files.get("file")
+        if not upload:
+            return self.response_400()
+        with ZipFile(upload) as bundle:
+            contents = {
+                file_name: bundle.read(file_name).decode()
+                for file_name in bundle.namelist()
+            }
+
+        command = ImportChartsCommand(contents)
+        try:
+            command.run()
+            return self.response(200, message="OK")
+        except CommandInvalidError as exc:
+            logger.warning("Import chart failed")
+            return self.response_422(message=exc.normalized_messages())
+        except Exception as exc:  # pylint: disable=broad-except
+            logger.exception("Import chart failed")
+            return self.response_500(message=str(exc))
diff --git a/superset/charts/commands/importers/dispatcher.py b/superset/charts/commands/importers/dispatcher.py
new file mode 100644
index 0000000..e7f0149
--- /dev/null
+++ b/superset/charts/commands/importers/dispatcher.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import logging
+from typing import Any, Dict
+
+from marshmallow.exceptions import ValidationError
+
+from superset.charts.commands.importers import v1
+from superset.commands.base import BaseCommand
+from superset.commands.exceptions import CommandInvalidError
+from superset.commands.importers.exceptions import IncorrectVersionError
+
+logger = logging.getLogger(__name__)
+
+command_versions = [
+    v1.ImportChartsCommand,
+]
+
+
+class ImportChartsCommand(BaseCommand):
+    """
+    Import charts.
+
+    This command dispatches the import to different versions of the command
+    until it finds one that matches.
+    """
+
+    # pylint: disable=unused-argument
+    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+        self.contents = contents
+
+    def run(self) -> None:
+        # iterate over all commands until we find a version that can
+        # handle the contents
+        for version in command_versions:
+            command = version(self.contents)
+            try:
+                command.run()
+                return
+            except IncorrectVersionError:
+                # file is not handled by command, skip
+                pass
+            except (CommandInvalidError, ValidationError) as exc:
+                # found right version, but file is invalid
+                logger.info("Command failed validation")
+                raise exc
+            except Exception as exc:
+                # validation succeeded but something went wrong
+                logger.exception("Error running import command")
+                raise exc
+
+        raise CommandInvalidError("Could not find a valid command to import file")
+
+    def validate(self) -> None:
+        pass
diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py
index 00decdc..3acfdb9 100644
--- a/tests/charts/api_tests.py
+++ b/tests/charts/api_tests.py
@@ -21,11 +21,12 @@ from typing import List, Optional
 from datetime import datetime
 from io import BytesIO
 from unittest import mock
-from zipfile import is_zipfile
+from zipfile import is_zipfile, ZipFile
 
 import humanize
 import prison
 import pytest
+import yaml
 from sqlalchemy import and_
 from sqlalchemy.sql import func
 
@@ -35,12 +36,19 @@ from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice
 from tests.test_app import app
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.extensions import db, security_manager
-from superset.models.core import FavStar, FavStarClassName
+from superset.models.core import Database, FavStar, FavStarClassName
 from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
 from superset.utils import core as utils
 from tests.base_api_tests import ApiOwnersTestCaseMixin
 from tests.base_tests import SupersetTestCase
+from tests.fixtures.importexport import (
+    chart_config,
+    chart_metadata_config,
+    database_config,
+    dataset_config,
+    dataset_metadata_config,
+)
 from tests.fixtures.query_context import get_query_context
 
 CHART_DATA_URI = "api/v1/chart/data"
@@ -1131,7 +1139,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
 
     def test_export_chart(self):
         """
-        Chart API: Test export dataset
+        Chart API: Test export chart
         """
         example_chart = db.session.query(Slice).all()[0]
         argument = [example_chart.id]
@@ -1147,7 +1155,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
 
     def test_export_chart_not_found(self):
         """
-        Dataset API: Test export dataset not found
+        Chart API: Test export chart not found
         """
         # Just one does not exist and we get 404
         argument = [-1, 1]
@@ -1159,7 +1167,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
 
     def test_export_chart_gamma(self):
         """
-        Dataset API: Test export dataset has gamma
+        Chart API: Test export chart has gamma
         """
         example_chart = db.session.query(Slice).all()[0]
         argument = [example_chart.id]
@@ -1169,3 +1177,79 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
         rv = self.client.get(uri)
 
         assert rv.status_code == 404
+
+    def test_import_chart(self):
+        """
+        Chart API: Test import chart
+        """
+        self.login(username="admin")
+        uri = "api/v1/chart/import/"
+
+        buf = BytesIO()
+        with ZipFile(buf, "w") as bundle:
+            with bundle.open("metadata.yaml", "w") as fp:
+                fp.write(yaml.safe_dump(chart_metadata_config).encode())
+            with bundle.open("databases/imported_database.yaml", "w") as fp:
+                fp.write(yaml.safe_dump(database_config).encode())
+            with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
+                fp.write(yaml.safe_dump(dataset_config).encode())
+            with bundle.open("charts/imported_chart.yaml", "w") as fp:
+                fp.write(yaml.safe_dump(chart_config).encode())
+        buf.seek(0)
+
+        form_data = {
+            "file": (buf, "chart_export.zip"),
+        }
+        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
+        response = json.loads(rv.data.decode("utf-8"))
+
+        assert rv.status_code == 200
+        assert response == {"message": "OK"}
+
+        database = (
+            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
+        )
+        assert database.database_name == "imported_database"
+
+        assert len(database.tables) == 1
+        dataset = database.tables[0]
+        assert dataset.table_name == "imported_dataset"
+        assert str(dataset.uuid) == dataset_config["uuid"]
+
+        chart = db.session.query(Slice).filter_by(uuid=chart_config["uuid"]).one()
+        assert chart.table == dataset
+
+        db.session.delete(chart)
+        db.session.delete(dataset)
+        db.session.delete(database)
+        db.session.commit()
+
+    def test_import_chart_invalid(self):
+        """
+        Chart API: Test import invalid chart
+        """
+        self.login(username="admin")
+        uri = "api/v1/chart/import/"
+
+        buf = BytesIO()
+        with ZipFile(buf, "w") as bundle:
+            with bundle.open("metadata.yaml", "w") as fp:
+                fp.write(yaml.safe_dump(dataset_metadata_config).encode())
+            with bundle.open("databases/imported_database.yaml", "w") as fp:
+                fp.write(yaml.safe_dump(database_config).encode())
+            with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
+                fp.write(yaml.safe_dump(dataset_config).encode())
+            with bundle.open("charts/imported_chart.yaml", "w") as fp:
+                fp.write(yaml.safe_dump(chart_config).encode())
+        buf.seek(0)
+
+        form_data = {
+            "file": (buf, "chart_export.zip"),
+        }
+        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
+        response = json.loads(rv.data.decode("utf-8"))
+
+        assert rv.status_code == 422
+        assert response == {
+            "message": {"metadata.yaml": {"type": ["Must be equal to Slice."]}}
+        }
diff --git a/tests/databases/api_tests.py b/tests/databases/api_tests.py
index e38b64a..88213b6 100644
--- a/tests/databases/api_tests.py
+++ b/tests/databases/api_tests.py
@@ -840,7 +840,7 @@ class TestDatabaseApi(SupersetTestCase):
                 fp.write(yaml.safe_dump(database_metadata_config).encode())
             with bundle.open("databases/imported_database.yaml", "w") as fp:
                 fp.write(yaml.safe_dump(database_config).encode())
-            with bundle.open("datasets/import_dataset.yaml", "w") as fp:
+            with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
                 fp.write(yaml.safe_dump(dataset_config).encode())
         buf.seek(0)
 
@@ -880,7 +880,7 @@ class TestDatabaseApi(SupersetTestCase):
                 fp.write(yaml.safe_dump(dataset_metadata_config).encode())
             with bundle.open("databases/imported_database.yaml", "w") as fp:
                 fp.write(yaml.safe_dump(database_config).encode())
-            with bundle.open("datasets/import_dataset.yaml", "w") as fp:
+            with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
                 fp.write(yaml.safe_dump(dataset_config).encode())
         buf.seek(0)
 
diff --git a/tests/datasets/api_tests.py b/tests/datasets/api_tests.py
index 59854c4..ea16c61 100644
--- a/tests/datasets/api_tests.py
+++ b/tests/datasets/api_tests.py
@@ -1176,7 +1176,7 @@ class TestDatasetApi(SupersetTestCase):
         for table_name in self.fixture_tables_names:
             assert table_name in [ds["table_name"] for ds in data["result"]]
 
-    def test_import_dataset(self):
+    def test_imported_dataset(self):
         """
         Dataset API: Test import dataset
         """
@@ -1189,7 +1189,7 @@ class TestDatasetApi(SupersetTestCase):
                 fp.write(yaml.safe_dump(dataset_metadata_config).encode())
             with bundle.open("databases/imported_database.yaml", "w") as fp:
                 fp.write(yaml.safe_dump(database_config).encode())
-            with bundle.open("datasets/import_dataset.yaml", "w") as fp:
+            with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
                 fp.write(yaml.safe_dump(dataset_config).encode())
         buf.seek(0)
 
@@ -1216,7 +1216,7 @@ class TestDatasetApi(SupersetTestCase):
         db.session.delete(database)
         db.session.commit()
 
-    def test_import_dataset_invalid(self):
+    def test_imported_dataset_invalid(self):
         """
         Dataset API: Test import invalid dataset
         """
@@ -1229,7 +1229,7 @@ class TestDatasetApi(SupersetTestCase):
                 fp.write(yaml.safe_dump(database_metadata_config).encode())
             with bundle.open("databases/imported_database.yaml", "w") as fp:
                 fp.write(yaml.safe_dump(database_config).encode())
-            with bundle.open("datasets/import_dataset.yaml", "w") as fp:
+            with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
                 fp.write(yaml.safe_dump(dataset_config).encode())
         buf.seek(0)
 
@@ -1244,7 +1244,7 @@ class TestDatasetApi(SupersetTestCase):
             "message": {"metadata.yaml": {"type": ["Must be equal to SqlaTable."]}}
         }
 
-    def test_import_dataset_invalid_v0_validation(self):
+    def test_imported_dataset_invalid_v0_validation(self):
         """
         Dataset API: Test import invalid dataset
         """
@@ -1255,7 +1255,7 @@ class TestDatasetApi(SupersetTestCase):
         with ZipFile(buf, "w") as bundle:
             with bundle.open("databases/imported_database.yaml", "w") as fp:
                 fp.write(yaml.safe_dump(database_config).encode())
-            with bundle.open("datasets/import_dataset.yaml", "w") as fp:
+            with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
                 fp.write(yaml.safe_dump(dataset_config).encode())
         buf.seek(0)
 

Reply via email to