This is an automated email from the ASF dual-hosted git repository.

dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new e5a4c78dc7 chore: Migrate /superset/csv/<client_id> to API v1 (#22913)
e5a4c78dc7 is described below

commit e5a4c78dc7c7bd9ed3b0e201283d726d4eacd0d9
Author: Diego Medina <[email protected]>
AuthorDate: Wed Feb 15 07:48:24 2023 -0300

    chore: Migrate /superset/csv/<client_id> to API v1 (#22913)
---
 docs/static/resources/openapi.json                 |  90 ++++++--
 .../src/SqlLab/components/ResultSet/index.tsx      |   5 +-
 superset/sqllab/api.py                             |  67 +++++-
 superset/sqllab/commands/export.py                 | 136 +++++++++++
 superset/views/core.py                             |   1 +
 tests/integration_tests/sql_lab/api_tests.py       |  39 +++-
 tests/integration_tests/sql_lab/commands_tests.py  | 251 ++++++++++++++++-----
 7 files changed, 518 insertions(+), 71 deletions(-)

diff --git a/docs/static/resources/openapi.json b/docs/static/resources/openapi.json
index cc92f091e2..18ea7a47f8 100644
--- a/docs/static/resources/openapi.json
+++ b/docs/static/resources/openapi.json
@@ -345,7 +345,7 @@
       "AnnotationLayerRestApi.get_list": {
         "properties": {
           "changed_by": {
-            "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User1"
+            "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User"
           },
           "changed_on": {
             "format": "date-time",
@@ -356,7 +356,7 @@
             "readOnly": true
           },
           "created_by": {
-            "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User"
+            "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User1"
           },
           "created_on": {
             "format": "date-time",
@@ -502,13 +502,13 @@
       "AnnotationRestApi.get_list": {
         "properties": {
           "changed_by": {
-            "$ref": "#/components/schemas/AnnotationRestApi.get_list.User1"
+            "$ref": "#/components/schemas/AnnotationRestApi.get_list.User"
           },
           "changed_on_delta_humanized": {
             "readOnly": true
           },
           "created_by": {
-            "$ref": "#/components/schemas/AnnotationRestApi.get_list.User"
+            "$ref": "#/components/schemas/AnnotationRestApi.get_list.User1"
           },
           "end_dttm": {
             "format": "date-time",
@@ -1223,6 +1223,7 @@
             "example": false
           },
           "periods": {
+            "description": "Time periods (in units of `time_grain`) to predict into the future",
             "example": 7,
             "format": "int32",
             "type": "integer"
@@ -1578,6 +1579,7 @@
             "type": "string"
           },
           "from_dttm": {
+            "description": "Start timestamp of time range",
             "format": "int32",
             "nullable": true,
             "type": "integer"
@@ -1603,6 +1605,7 @@
             "type": "integer"
           },
           "stacktrace": {
+            "description": "Stacktrace if there was an error",
             "nullable": true,
             "type": "string"
           },
@@ -1620,6 +1623,7 @@
             "type": "string"
           },
           "to_dttm": {
+            "description": "End timestamp of time range",
             "format": "int32",
             "nullable": true,
             "type": "integer"
@@ -2232,6 +2236,7 @@
             "type": "string"
           },
           "rolling_type_options": {
+            "description": "Optional options to pass to rolling method. Needed for e.g. quantile operation.",
             "example": {},
             "type": "object"
           },
@@ -3027,13 +3032,13 @@
       "CssTemplateRestApi.get_list": {
         "properties": {
           "changed_by": {
-            "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User1"
+            "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User"
           },
           "changed_on_delta_humanized": {
             "readOnly": true
           },
           "created_by": {
-            "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User"
+            "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User1"
           },
           "created_on": {
             "format": "date-time",
@@ -5230,7 +5235,7 @@
       "DatasetRestApi.get_list": {
         "properties": {
           "changed_by": {
-            "$ref": "#/components/schemas/DatasetRestApi.get_list.User1"
+            "$ref": "#/components/schemas/DatasetRestApi.get_list.User"
           },
           "changed_by_name": {
             "readOnly": true
@@ -5273,7 +5278,7 @@
             "readOnly": true
           },
           "owners": {
-            "$ref": "#/components/schemas/DatasetRestApi.get_list.User"
+            "$ref": "#/components/schemas/DatasetRestApi.get_list.User1"
           },
           "schema": {
             "maxLength": 255,
@@ -5317,14 +5322,6 @@
             "maxLength": 64,
             "type": "string"
           },
-          "id": {
-            "format": "int32",
-            "type": "integer"
-          },
-          "last_name": {
-            "maxLength": 64,
-            "type": "string"
-          },
           "username": {
             "maxLength": 64,
             "type": "string"
@@ -5332,7 +5329,6 @@
         },
         "required": [
           "first_name",
-          "last_name",
           "username"
         ],
         "type": "object"
@@ -5343,6 +5339,14 @@
             "maxLength": 64,
             "type": "string"
           },
+          "id": {
+            "format": "int32",
+            "type": "integer"
+          },
+          "last_name": {
+            "maxLength": 64,
+            "type": "string"
+          },
           "username": {
             "maxLength": 64,
             "type": "string"
@@ -5350,6 +5354,7 @@
         },
         "required": [
           "first_name",
+          "last_name",
           "username"
         ],
         "type": "object"
@@ -19997,6 +20002,57 @@
         ]
       }
     },
+    "/api/v1/sqllab/export/{client_id}/": {
+      "get": {
+        "parameters": [
+          {
+            "description": "The SQL query result identifier",
+            "in": "path",
+            "name": "client_id",
+            "required": true,
+            "schema": {
+              "type": "string"
+            }
+          }
+        ],
+        "responses": {
+          "200": {
+            "content": {
+              "text/csv": {
+                "schema": {
+                  "type": "string"
+                }
+              }
+            },
+            "description": "SQL query results"
+          },
+          "400": {
+            "$ref": "#/components/responses/400"
+          },
+          "401": {
+            "$ref": "#/components/responses/401"
+          },
+          "403": {
+            "$ref": "#/components/responses/403"
+          },
+          "404": {
+            "$ref": "#/components/responses/404"
+          },
+          "500": {
+            "$ref": "#/components/responses/500"
+          }
+        },
+        "security": [
+          {
+            "jwt": []
+          }
+        ],
+        "summary": "Exports the SQL query results to a CSV",
+        "tags": [
+          "SQL Lab"
+        ]
+      }
+    },
     "/api/v1/sqllab/results/": {
       "get": {
         "parameters": [
diff --git a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx
index 62912d66a2..fad6c98bc9 100644
--- a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx
+++ b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx
@@ -219,6 +219,9 @@ const ResultSet = ({
     }
   };
 
+  const getExportCsvUrl = (clientId: string) =>
+    `/api/v1/sqllab/export/${clientId}/`;
+
   const renderControls = () => {
     if (search || visualize || csv) {
       let { data } = query.results;
@@ -257,7 +260,7 @@ const ResultSet = ({
               />
             )}
             {csv && (
-              <Button buttonSize="small" href={`/superset/csv/${query.id}`}>
+              <Button buttonSize="small" href={getExportCsvUrl(query.id)}>
                 <i className="fa fa-file-text-o" /> {t('Download to CSV')}
               </Button>
             )}
diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py
index 283c3ab638..f73ef749d4 100644
--- a/superset/sqllab/api.py
+++ b/superset/sqllab/api.py
@@ -16,6 +16,7 @@
 # under the License.
 import logging
 from typing import Any, cast, Dict, Optional
+from urllib import parse
 
 import simplejson as json
 from flask import request
@@ -32,6 +33,7 @@ from superset.queries.dao import QueryDAO
 from superset.sql_lab import get_sql_results
 from superset.sqllab.command_status import SqlJsonExecutionStatus
 from superset.sqllab.commands.execute import CommandResult, ExecuteSqlCommand
+from superset.sqllab.commands.export import SqlResultExportCommand
 from superset.sqllab.commands.results import SqlExecutionResultsCommand
 from superset.sqllab.exceptions import (
     QueryIsForbiddenToAccessException,
@@ -53,7 +55,7 @@ from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext
 from superset.sqllab.validators import CanAccessQueryValidatorImpl
 from superset.superset_typing import FlaskResponse
 from superset.utils import core as utils
-from superset.views.base import json_success
+from superset.views.base import CsvResponse, generate_download_headers, json_success
 from superset.views.base_api import BaseSupersetApi, requires_json, statsd_metrics
 
 config = app.config
@@ -79,6 +81,69 @@ class SqlLabRestApi(BaseSupersetApi):
         QueryExecutionResponseSchema,
     )
 
+    @expose("/export/<string:client_id>/")
+    @protect()
+    @statsd_metrics
+    @event_logger.log_this_with_context(
+        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
+        f".export_csv",
+        log_to_statsd=False,
+    )
+    def export_csv(self, client_id: str) -> CsvResponse:
+        """Exports the SQL query results to a CSV
+        ---
+        get:
+          summary: >-
+            Exports the SQL query results to a CSV
+          parameters:
+          - in: path
+            schema:
+              type: string
+            name: client_id
+            description: The SQL query result identifier
+          responses:
+            200:
+              description: SQL query results
+              content:
+                text/csv:
+                  schema:
+                    type: string
+            400:
+              $ref: '#/components/responses/400'
+            401:
+              $ref: '#/components/responses/401'
+            403:
+              $ref: '#/components/responses/403'
+            404:
+              $ref: '#/components/responses/404'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        result = SqlResultExportCommand(client_id=client_id).run()
+
+        query = result.get("query")
+        data = result.get("data")
+        row_count = result.get("count")
+
+        quoted_csv_name = parse.quote(query.name)
+        response = CsvResponse(
+            data, headers=generate_download_headers("csv", quoted_csv_name)
+        )
+        event_info = {
+            "event_type": "data_export",
+            "client_id": client_id,
+            "row_count": row_count,
+            "database": query.database.name,
+            "schema": query.schema,
+            "sql": query.sql,
+            "exported_format": "csv",
+        }
+        event_rep = repr(event_info)
+        logger.debug(
+            "CSV exported: %s", event_rep, extra={"superset_event": event_info}
+        )
+        return response
+
     @expose("/results/")
     @protect()
     @statsd_metrics
diff --git a/superset/sqllab/commands/export.py b/superset/sqllab/commands/export.py
new file mode 100644
index 0000000000..feca664225
--- /dev/null
+++ b/superset/sqllab/commands/export.py
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=too-few-public-methods, too-many-arguments
+from __future__ import annotations
+
+import logging
+from typing import cast, TypedDict
+
+import pandas as pd
+from flask_babel import gettext as __
+
+from superset import app, db, results_backend, results_backend_use_msgpack
+from superset.commands.base import BaseCommand
+from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.exceptions import SupersetErrorException, SupersetSecurityException
+from superset.models.sql_lab import Query
+from superset.sql_parse import ParsedQuery
+from superset.sqllab.limiting_factor import LimitingFactor
+from superset.utils import core as utils, csv
+from superset.views.utils import _deserialize_results_payload
+
+config = app.config
+
+logger = logging.getLogger(__name__)
+
+
+class SqlExportResult(TypedDict):
+    query: Query
+    count: int
+    data: str  # CSV text produced by csv.df_to_escaped_csv
+
+
+class SqlResultExportCommand(BaseCommand):
+    _client_id: str
+    _query: Query
+
+    def __init__(
+        self,
+        client_id: str,
+    ) -> None:
+        self._client_id = client_id
+
+    def validate(self) -> None:
+        self._query = (
+            db.session.query(Query).filter_by(client_id=self._client_id).one_or_none()
+        )
+        if self._query is None:
+            raise SupersetErrorException(
+                SupersetError(
+                    message=__(
+                        "The query associated with these results could not be found. "
+                        "You need to re-run the original query."
+                    ),
+                    error_type=SupersetErrorType.RESULTS_BACKEND_ERROR,
+                    level=ErrorLevel.ERROR,
+                ),
+                status=404,
+            )
+
+        try:
+            self._query.raise_for_access()
+        except SupersetSecurityException as ex:
+            raise SupersetErrorException(
+                SupersetError(
+                    message=__("Cannot access the query"),
+                    error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
+                    level=ErrorLevel.ERROR,
+                ),
+                status=403,
+            ) from ex
+
+    def run(
+        self,
+    ) -> SqlExportResult:
+        self.validate()
+        # Prefer the cached result set; otherwise re-execute the query below.
+        blob = None
+        if results_backend and self._query.results_key:
+            logger.info(
+                "Fetching CSV from results backend [%s]", self._query.results_key
+            )
+            blob = results_backend.get(self._query.results_key)
+        if blob:
+            logger.info("Decompressing")
+            payload = utils.zlib_decompress(
+                blob, decode=not results_backend_use_msgpack
+            )
+            obj = _deserialize_results_payload(
+                payload, self._query, cast(bool, results_backend_use_msgpack)
+            )
+
+            df = pd.DataFrame(
+                data=obj["data"],
+                dtype=object,
+                columns=[c["name"] for c in obj["columns"]],
+            )
+
+            logger.info("Using pandas to convert to CSV")
+        else:
+            logger.info("Running a query to turn into CSV")
+            if self._query.select_sql:
+                sql = self._query.select_sql
+                limit = None
+            else:
+                sql = self._query.executed_sql
+                limit = ParsedQuery(sql).limit
+            if limit is not None and self._query.limiting_factor in {
+                LimitingFactor.QUERY,
+                LimitingFactor.DROPDOWN,
+                LimitingFactor.QUERY_AND_DROPDOWN,
+            }:
+                # remove extra row from `increased_limit`
+                limit -= 1
+            df = self._query.database.get_df(sql, self._query.schema)[:limit]
+
+        csv_data = csv.df_to_escaped_csv(df, index=False, **config["CSV_EXPORT"])
+
+        return {
+            "query": self._query,
+            "count": len(df.index),
+            "data": csv_data,
+        }
diff --git a/superset/views/core.py b/superset/views/core.py
index fb371c209e..cad8483bfd 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -2401,6 +2401,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
     @has_access
     @event_logger.log_this
     @expose("/csv/<client_id>")
+    @deprecated()
     def csv(self, client_id: str) -> FlaskResponse:  # pylint: disable=no-self-use
         """Download the query results as csv."""
         logger.info("Exporting CSV file [%s]", client_id)
diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py
index 4c2080ad4c..93beb380f0 100644
--- a/tests/integration_tests/sql_lab/api_tests.py
+++ b/tests/integration_tests/sql_lab/api_tests.py
@@ -19,6 +19,9 @@
 import datetime
 import json
 import random
+import csv
+import pandas as pd
+import io
 
 import pytest
 import prison
@@ -26,7 +29,7 @@ from sqlalchemy.sql import func
 from unittest import mock
 
 from tests.integration_tests.test_app import app
-from superset import sql_lab
+from superset import db, sql_lab
 from superset.common.db_query_status import QueryStatus
 from superset.models.core import Database
 from superset.utils.database import get_example_database, get_main_database
@@ -176,3 +179,37 @@ class TestSqlLabApi(SupersetTestCase):
         self.assertEqual(result_limited, expected_limited)
 
         app.config["RESULTS_BACKEND_USE_MSGPACK"] = use_msgpack
+
+    @mock.patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None)
+    @mock.patch("superset.models.core.Database.get_df")
+    def test_export_results(self, get_df_mock: mock.Mock) -> None:
+        self.login()
+
+        database = get_example_database()
+        query_obj = Query(
+            client_id="test",
+            database=database,
+            tab_name="test_tab",
+            sql_editor_id="test_editor_id",
+            sql="select * from bar",
+            select_sql=None,
+            executed_sql="select * from bar limit 2",
+            limit=100,
+            select_as_cta=False,
+            rows=104,
+            error_message="none",
+            results_key="test_abc",
+        )
+
+        db.session.add(query_obj)
+        db.session.commit()
+
+        get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]})
+
+        resp = self.get_resp("/api/v1/sqllab/export/test/")
+        data = csv.reader(io.StringIO(resp))
+        expected_data = csv.reader(io.StringIO("foo\n1\n2"))
+
+        self.assertEqual(list(expected_data), list(data))
+        db.session.delete(query_obj)
+        db.session.commit()
diff --git a/tests/integration_tests/sql_lab/commands_tests.py b/tests/integration_tests/sql_lab/commands_tests.py
index 74c1fe7082..edb7155237 100644
--- a/tests/integration_tests/sql_lab/commands_tests.py
+++ b/tests/integration_tests/sql_lab/commands_tests.py
@@ -15,23 +15,208 @@
 # specific language governing permissions and limitations
 # under the License.
 from unittest import mock, skip
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
+import pandas as pd
 import pytest
 
 from superset import db, sql_lab
 from superset.common.db_query_status import QueryStatus
-from superset.errors import SupersetErrorType
-from superset.exceptions import SerializationError, SupersetErrorException
+from superset.errors import ErrorLevel, SupersetErrorType
+from superset.exceptions import (
+    SerializationError,
+    SupersetError,
+    SupersetErrorException,
+    SupersetSecurityException,
+)
 from superset.models.core import Database
 from superset.models.sql_lab import Query
-from superset.sqllab.commands import results
+from superset.sqllab.commands import export, results
+from superset.sqllab.limiting_factor import LimitingFactor
 from superset.utils import core as utils
+from superset.utils.database import get_example_database
 from tests.integration_tests.base_tests import SupersetTestCase
 
 
+class TestSqlResultExportCommand(SupersetTestCase):
+    @pytest.fixture()
+    def create_database_and_query(self):
+        with self.create_app().app_context():
+            database = get_example_database()
+            query_obj = Query(
+                client_id="test",
+                database=database,
+                tab_name="test_tab",
+                sql_editor_id="test_editor_id",
+                sql="select * from bar",
+                select_sql="select * from bar",
+                executed_sql="select * from bar",
+                limit=100,
+                select_as_cta=False,
+                rows=104,
+                error_message="none",
+                results_key="abc_query",
+            )
+
+            db.session.add(query_obj)
+            db.session.commit()
+
+            yield
+
+            db.session.delete(query_obj)
+            db.session.commit()
+
+    @pytest.mark.usefixtures("create_database_and_query")
+    def test_validation_query_not_found(self) -> None:
+        command = export.SqlResultExportCommand("asdf")
+
+        with pytest.raises(SupersetErrorException) as ex_info:
+            command.run()
+        assert ex_info.value.error.error_type == SupersetErrorType.RESULTS_BACKEND_ERROR
+
+    @pytest.mark.usefixtures("create_database_and_query")
+    def test_validation_invalid_access(self) -> None:
+        command = export.SqlResultExportCommand("test")
+
+        with mock.patch(
+            "superset.security_manager.raise_for_access",
+            side_effect=SupersetSecurityException(
+                SupersetError(
+                    "dummy",
+                    SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
+                    ErrorLevel.ERROR,
+                )
+            ),
+        ):
+            with pytest.raises(SupersetErrorException) as ex_info:
+                command.run()
+            assert (
+                ex_info.value.error.error_type
+                == SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR
+            )
+
+    @pytest.mark.usefixtures("create_database_and_query")
+    @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None)
+    @patch("superset.models.core.Database.get_df")
+    def test_run_no_results_backend_select_sql(self, get_df_mock: Mock) -> None:
+        command = export.SqlResultExportCommand("test")
+
+        get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]})
+        result = command.run()
+
+        data = result.get("data")
+        count = result.get("count")
+        query = result.get("query")
+
+        assert data == "foo\n1\n2\n3\n"
+        assert count == 3
+        assert query.client_id == "test"
+
+    @pytest.mark.usefixtures("create_database_and_query")
+    @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None)
+    @patch("superset.models.core.Database.get_df")
+    def test_run_no_results_backend_executed_sql(self, get_df_mock: Mock) -> None:
+        query_obj = db.session.query(Query).filter_by(client_id="test").one()
+        query_obj.executed_sql = "select * from bar limit 2"
+        query_obj.select_sql = None
+        db.session.commit()
+
+        command = export.SqlResultExportCommand("test")
+
+        get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]})
+        result = command.run()
+
+        data = result.get("data")
+        count = result.get("count")
+        query = result.get("query")
+
+        assert data == "foo\n1\n2\n"
+        assert count == 2
+        assert query.client_id == "test"
+
+    @pytest.mark.usefixtures("create_database_and_query")
+    @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None)
+    @patch("superset.models.core.Database.get_df")
+    def test_run_no_results_backend_executed_sql_limiting_factor(
+        self, get_df_mock: Mock
+    ) -> None:
+        query_obj = db.session.query(Query).filter_by(results_key="abc_query").one()
+        query_obj.executed_sql = "select * from bar limit 2"
+        query_obj.select_sql = None
+        query_obj.limiting_factor = LimitingFactor.DROPDOWN
+        db.session.commit()
+
+        command = export.SqlResultExportCommand("test")
+
+        get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]})
+
+        result = command.run()
+
+        data = result.get("data")
+        count = result.get("count")
+        query = result.get("query")
+
+        assert data == "foo\n1\n"
+        assert count == 1
+        assert query.client_id == "test"
+
+    @pytest.mark.usefixtures("create_database_and_query")
+    @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None)
+    @patch("superset.sqllab.commands.export.results_backend_use_msgpack", False)
+    def test_run_with_results_backend(self) -> None:
+        command = export.SqlResultExportCommand("test")
+
+        data = [{"foo": i} for i in range(5)]
+        payload = {
+            "columns": [{"name": "foo"}],
+            "data": data,
+        }
+        serialized_payload = sql_lab._serialize_payload(payload, False)
+        compressed = utils.zlib_compress(serialized_payload)
+
+        with mock.patch.object(export, "results_backend") as backend_mock:
+            backend_mock.get.return_value = compressed
+
+            result = command.run()
+
+        data = result.get("data")
+        count = result.get("count")
+        query = result.get("query")
+
+        assert data == "foo\n0\n1\n2\n3\n4\n"
+        assert count == 5
+        assert query.client_id == "test"
+
+
 class TestSqlExecutionResultsCommand(SupersetTestCase):
-    @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
+    @pytest.fixture()
+    def create_database_and_query(self):
+        with self.create_app().app_context():
+            database = get_example_database()
+            query_obj = Query(
+                client_id="test",
+                database=database,
+                tab_name="test_tab",
+                sql_editor_id="test_editor_id",
+                sql="select * from bar",
+                select_sql="select * from bar",
+                executed_sql="select * from bar",
+                limit=100,
+                select_as_cta=False,
+                rows=104,
+                error_message="none",
+                results_key="abc_query",
+            )
+
+            db.session.add(query_obj)
+            db.session.commit()
+
+            yield
+
+            db.session.delete(query_obj)
+            db.session.commit()
+
+    @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
     def test_validation_no_results_backend(self) -> None:
         results.results_backend = None
 
@@ -44,7 +229,7 @@ class TestSqlExecutionResultsCommand(SupersetTestCase):
             == SupersetErrorType.RESULTS_BACKEND_NOT_CONFIGURED_ERROR
         )
 
-    @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
+    @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
     def test_validation_data_cannot_be_retrieved(self) -> None:
         results.results_backend = mock.Mock()
         results.results_backend.get.return_value = None
@@ -55,8 +240,8 @@ class TestSqlExecutionResultsCommand(SupersetTestCase):
             command.run()
         assert ex_info.value.error.error_type == SupersetErrorType.RESULTS_BACKEND_ERROR
 
-    @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
-    def test_validation_query_not_found(self) -> None:
+    @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
+    def test_validation_data_not_found(self) -> None:
         data = [{"col_0": i} for i in range(100)]
         payload = {
             "status": QueryStatus.SUCCESS,
@@ -75,8 +260,9 @@ class TestSqlExecutionResultsCommand(SupersetTestCase):
             command.run()
         assert ex_info.value.error.error_type == SupersetErrorType.RESULTS_BACKEND_ERROR
 
-    @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
-    def test_validation_query_not_found2(self) -> None:
+    @pytest.mark.usefixtures("create_database_and_query")
+    @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
+    def test_validation_query_not_found(self) -> None:
         data = [{"col_0": i} for i in range(104)]
         payload = {
             "status": QueryStatus.SUCCESS,
@@ -89,38 +275,20 @@ class TestSqlExecutionResultsCommand(SupersetTestCase):
         results.results_backend = mock.Mock()
         results.results_backend.get.return_value = compressed
 
-        database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
-        query_obj = Query(
-            client_id="foo",
-            database=database,
-            tab_name="test_tab",
-            sql_editor_id="test_editor_id",
-            sql="select * from bar",
-            select_sql="select * from bar",
-            executed_sql="select * from bar",
-            limit=100,
-            select_as_cta=False,
-            rows=104,
-            error_message="none",
-            results_key="test_abc",
-        )
-
-        db.session.add(database)
-        db.session.add(query_obj)
-
         with mock.patch(
             "superset.views.utils._deserialize_results_payload",
             side_effect=SerializationError(),
         ):
             with pytest.raises(SupersetErrorException) as ex_info:
-                command = results.SqlExecutionResultsCommand("test", 1000)
+                command = results.SqlExecutionResultsCommand("test_other", 1000)
                 command.run()
             assert (
                 ex_info.value.error.error_type
                 == SupersetErrorType.RESULTS_BACKEND_ERROR
             )
 
-    @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
+    @pytest.mark.usefixtures("create_database_and_query")
+    @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
     def test_run_succeeds(self) -> None:
         data = [{"col_0": i} for i in range(104)]
         payload = {
@@ -134,26 +302,7 @@ class TestSqlExecutionResultsCommand(SupersetTestCase):
         results.results_backend = mock.Mock()
         results.results_backend.get.return_value = compressed
 
-        database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
-        query_obj = Query(
-            client_id="foo",
-            database=database,
-            tab_name="test_tab",
-            sql_editor_id="test_editor_id",
-            sql="select * from bar",
-            select_sql="select * from bar",
-            executed_sql="select * from bar",
-            limit=100,
-            select_as_cta=False,
-            rows=104,
-            error_message="none",
-            results_key="test_abc",
-        )
-
-        db.session.add(database)
-        db.session.add(query_obj)
-
-        command = results.SqlExecutionResultsCommand("test_abc", 1000)
+        command = results.SqlExecutionResultsCommand("abc_query", 1000)
         result = command.run()
 
         assert result.get("status") == "success"

Reply via email to