This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 3e8e441  feat: implement csv upload configuration func for the schema enforcement (#9734)
3e8e441 is described below

commit 3e8e441bfcb5b188840d3d08265073db1f67d1a6
Author: Bogdan <[email protected]>
AuthorDate: Thu May 21 13:49:53 2020 -0700

    feat: implement csv upload configuration func for the schema enforcement (#9734)
    
    * Implement CSV upload function for schema enforcement
    
    Implement function-controlled CSV upload schema
    
    Refactor and fix tests
    
    Fix Hive as well
    
    * Add explore_database_id to the database extras
    
    Co-authored-by: bogdan kyryliuk <[email protected]>
---
 superset/config.py               |  16 ++++
 superset/db_engine_specs/base.py |  62 ++++--------
 superset/db_engine_specs/hive.py |  55 ++++-------
 superset/models/core.py          |   8 +-
 superset/views/database/views.py | 113 ++++++++++++++++------
 tests/core_tests.py              | 198 ++++++++++++++++++++++++++-------------
 6 files changed, 272 insertions(+), 180 deletions(-)

diff --git a/superset/config.py b/superset/config.py
index 1668f57..e3d3ffb 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -586,11 +586,27 @@ CSV_TO_HIVE_UPLOAD_S3_BUCKET = None
 # The directory within the bucket specified above that will
 # contain all the external tables
 CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/"
+# Function that creates upload directory dynamically based on the
+# database used, user and schema provided.
+CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC: Callable[
+    ["Database", "models.User", str], Optional[str]
+] = lambda database, user, schema: CSV_TO_HIVE_UPLOAD_DIRECTORY
 
 # The namespace within hive where the tables created from
 # uploading CSVs will be stored.
 UPLOADED_CSV_HIVE_NAMESPACE = None
 
+# Function that computes the allowed schemas for the CSV uploads.
+# Allowed schemas will be a union of schemas_allowed_for_csv_upload
+# db configuration and a result of this function.
+
+# mypy doesn't catch that if case ensures list content being always str
+ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
+    ["Database", "models.User"], List[str]
+] = lambda database, user: [
+    UPLOADED_CSV_HIVE_NAMESPACE  # type: ignore
+] if UPLOADED_CSV_HIVE_NAMESPACE else []
+
 # A dictionary of items that gets merged into the Jinja context for
 # SQL Lab. The existing context gets updated with this dictionary,
 # meaning values for existing keys get overwritten by the content of this
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 845b22c..a593f59 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -18,7 +18,6 @@
 import hashlib
 import json
 import logging
-import os
 import re
 from contextlib import closing
 from datetime import datetime
@@ -49,11 +48,11 @@ from sqlalchemy.orm import Session
 from sqlalchemy.sql import quoted_name, text
 from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, 
TextAsFrom
 from sqlalchemy.types import TypeEngine
-from wtforms.form import Form
 
 from superset import app, sql_parse
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.models.sql_lab import Query
+from superset.sql_parse import Table
 from superset.utils import core as utils
 
 if TYPE_CHECKING:
@@ -454,55 +453,26 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         df.to_sql(**kwargs)
 
     @classmethod
-    def create_table_from_csv(cls, form: Form, database: "Database") -> None:
+    def create_table_from_csv(  # pylint: disable=too-many-arguments
+        cls,
+        filename: str,
+        table: Table,
+        database: "Database",
+        csv_to_df_kwargs: Dict[str, Any],
+        df_to_sql_kwargs: Dict[str, Any],
+    ) -> None:
         """
         Create table from contents of a csv. Note: this method does not create
         metadata for the table.
-
-        :param form: Parameters defining how to process data
-        :param database: Database model object for the target database
         """
-
-        def _allowed_file(filename: str) -> bool:
-            # Only allow specific file extensions as specified in the config
-            extension = os.path.splitext(filename)[1].lower()
-            return (
-                extension is not None and extension[1:] in 
config["ALLOWED_EXTENSIONS"]
-            )
-
-        filename = form.csv_file.data.filename
-
-        if not _allowed_file(filename):
-            raise Exception("Invalid file type selected")
-        csv_to_df_kwargs = {
-            "filepath_or_buffer": filename,
-            "sep": form.sep.data,
-            "header": form.header.data if form.header.data else 0,
-            "index_col": form.index_col.data,
-            "mangle_dupe_cols": form.mangle_dupe_cols.data,
-            "skipinitialspace": form.skipinitialspace.data,
-            "skiprows": form.skiprows.data,
-            "nrows": form.nrows.data,
-            "skip_blank_lines": form.skip_blank_lines.data,
-            "parse_dates": form.parse_dates.data,
-            "infer_datetime_format": form.infer_datetime_format.data,
-            "chunksize": 10000,
-        }
-        df = cls.csv_to_df(**csv_to_df_kwargs)
-
+        df = cls.csv_to_df(filepath_or_buffer=filename, **csv_to_df_kwargs,)
         engine = cls.get_engine(database)
-
-        df_to_sql_kwargs = {
-            "df": df,
-            "name": form.name.data,
-            "con": engine,
-            "schema": form.schema.data,
-            "if_exists": form.if_exists.data,
-            "index": form.index.data,
-            "index_label": form.index_label.data,
-            "chunksize": 10000,
-        }
-        cls.df_to_sql(**df_to_sql_kwargs)
+        if table.schema:
+            # only add schema when it is preset and non empty
+            df_to_sql_kwargs["schema"] = table.schema
+        if engine.dialect.supports_multivalues_insert:
+            df_to_sql_kwargs["method"] = "multi"
+        cls.df_to_sql(df=df, con=engine, **df_to_sql_kwargs)
 
     @classmethod
     def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index cb810f1..3fb09ef 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -23,18 +23,19 @@ from typing import Any, Dict, List, Optional, Tuple, 
TYPE_CHECKING
 from urllib import parse
 
 import pandas as pd
+from flask import g
 from sqlalchemy import Column
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import make_url, URL
 from sqlalchemy.orm import Session
 from sqlalchemy.sql.expression import ColumnClause, Select
-from wtforms.form import Form
 
 from superset import app, cache, conf
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.db_engine_specs.presto import PrestoEngineSpec
 from superset.models.sql_lab import Query
+from superset.sql_parse import Table
 from superset.utils import core as utils
 
 if TYPE_CHECKING:
@@ -105,8 +106,13 @@ class HiveEngineSpec(PrestoEngineSpec):
             return []
 
     @classmethod
-    def create_table_from_csv(  # pylint: disable=too-many-locals
-        cls, form: Form, database: "Database"
+    def create_table_from_csv(  # pylint: disable=too-many-arguments, 
too-many-locals
+        cls,
+        filename: str,
+        table: Table,
+        database: "Database",
+        csv_to_df_kwargs: Dict[str, Any],
+        df_to_sql_kwargs: Dict[str, Any],
     ) -> None:
         """Uploads a csv file and creates a superset datasource in Hive."""
 
@@ -128,38 +134,16 @@ class HiveEngineSpec(PrestoEngineSpec):
                 "No upload bucket specified. You can specify one in the config 
file."
             )
 
-        table_name = form.name.data
-        schema_name = form.schema.data
-
-        if config["UPLOADED_CSV_HIVE_NAMESPACE"]:
-            if "." in table_name or schema_name:
-                raise Exception(
-                    "You can't specify a namespace. "
-                    "All tables will be uploaded to the `{}` namespace".format(
-                        config["HIVE_NAMESPACE"]
-                    )
-                )
-            full_table_name = "{}.{}".format(
-                config["UPLOADED_CSV_HIVE_NAMESPACE"], table_name
-            )
-        else:
-            if "." in table_name and schema_name:
-                raise Exception(
-                    "You can't specify a namespace both in the name of the 
table "
-                    "and in the schema field. Please remove one"
-                )
-
-            full_table_name = (
-                "{}.{}".format(schema_name, table_name) if schema_name else 
table_name
-            )
-
-        filename = form.csv_file.data.filename
-        upload_prefix = config["CSV_TO_HIVE_UPLOAD_DIRECTORY"]
+        upload_prefix = config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"](
+            database, g.user, table.schema
+        )
 
         # Optional dependency
-        from tableschema import Table  # pylint: disable=import-error
+        from tableschema import (  # pylint: disable=import-error
+            Table as TableSchemaTable,
+        )
 
-        hive_table_schema = Table(filename).infer()
+        hive_table_schema = TableSchemaTable(filename).infer()
         column_name_and_type = []
         for column_info in hive_table_schema["fields"]:
             column_name_and_type.append(
@@ -173,13 +157,14 @@ class HiveEngineSpec(PrestoEngineSpec):
         import boto3  # pylint: disable=import-error
 
         s3 = boto3.client("s3")
-        location = os.path.join("s3a://", bucket_path, upload_prefix, 
table_name)
+        location = os.path.join("s3a://", bucket_path, upload_prefix, 
table.table)
         s3.upload_file(
             filename,
             bucket_path,
-            os.path.join(upload_prefix, table_name, 
os.path.basename(filename)),
+            os.path.join(upload_prefix, table.table, 
os.path.basename(filename)),
         )
-        sql = f"""CREATE TABLE {full_table_name} ( {schema_definition} )
+        # TODO(bkyryliuk): support other delimiters
+        sql = f"""CREATE TABLE {str(table)} ( {schema_definition} )
             ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS
             TEXTFILE LOCATION '{location}'
             tblproperties ('skip.header.line.count'='1')"""
diff --git a/superset/models/core.py b/superset/models/core.py
index 94383c8..abcb210 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -609,7 +609,13 @@ class Database(
     def get_schema_access_for_csv_upload(  # pylint: disable=invalid-name
         self,
     ) -> List[str]:
-        return self.get_extra().get("schemas_allowed_for_csv_upload", [])
+        allowed_databases = 
self.get_extra().get("schemas_allowed_for_csv_upload", [])
+        if hasattr(g, "user"):
+            extra_allowed_databases = config["ALLOWED_USER_CSV_SCHEMA_FUNC"](
+                self, g.user
+            )
+            allowed_databases += extra_allowed_databases
+        return sorted(set(allowed_databases))
 
     @property
     def sqlalchemy_uri_decrypted(self) -> str:
diff --git a/superset/views/database/views.py b/superset/views/database/views.py
index 47bf1c3..208d351 100644
--- a/superset/views/database/views.py
+++ b/superset/views/database/views.py
@@ -30,6 +30,7 @@ from superset import app, db
 from superset.connectors.sqla.models import SqlaTable
 from superset.constants import RouteMethod
 from superset.exceptions import CertificateException
+from superset.sql_parse import Table
 from superset.utils import core as utils
 from superset.views.base import DeleteMixin, SupersetModelView, YamlExportMixin
 
@@ -109,66 +110,116 @@ class CsvToDatabaseView(SimpleFormView):
 
     def form_post(self, form):
         database = form.con.data
-        schema_name = form.schema.data or ""
+        csv_table = Table(table=form.name.data, schema=form.schema.data)
 
-        if not schema_allows_csv_upload(database, schema_name):
+        if not schema_allows_csv_upload(database, csv_table.schema):
             message = _(
                 'Database "%(database_name)s" schema "%(schema_name)s" '
                 "is not allowed for csv uploads. Please contact your Superset 
Admin.",
                 database_name=database.database_name,
-                schema_name=schema_name,
+                schema_name=csv_table.schema,
             )
             flash(message, "danger")
             return redirect("/csvtodatabaseview/form")
 
-        csv_filename = form.csv_file.data.filename
-        extension = os.path.splitext(csv_filename)[1].lower()
-        path = tempfile.NamedTemporaryFile(
-            dir=app.config["UPLOAD_FOLDER"], suffix=extension, delete=False
+        if "." in csv_table.table and csv_table.schema:
+            message = _(
+                "You cannot specify a namespace both in the name of the table: 
"
+                '"%(csv_table.table)s" and in the schema field: '
+                '"%(csv_table.schema)s". Please remove one',
+                table=csv_table.table,
+                schema=csv_table.schema,
+            )
+            flash(message, "danger")
+            return redirect("/csvtodatabaseview/form")
+
+        uploaded_tmp_file_path = tempfile.NamedTemporaryFile(
+            dir=app.config["UPLOAD_FOLDER"],
+            suffix=os.path.splitext(form.csv_file.data.filename)[1].lower(),
+            delete=False,
         ).name
-        form.csv_file.data.filename = path
 
         try:
             utils.ensure_path_exists(config["UPLOAD_FOLDER"])
-            upload_stream_write(form.csv_file.data, path)
-            table_name = form.name.data
+            upload_stream_write(form.csv_file.data, uploaded_tmp_file_path)
 
             con = form.data.get("con")
             database = (
                 
db.session.query(models.Database).filter_by(id=con.data.get("id")).one()
             )
-            database.db_engine_spec.create_table_from_csv(form, database)
-            table = (
+            csv_to_df_kwargs = {
+                "sep": form.sep.data,
+                "header": form.header.data if form.header.data else 0,
+                "index_col": form.index_col.data,
+                "mangle_dupe_cols": form.mangle_dupe_cols.data,
+                "skipinitialspace": form.skipinitialspace.data,
+                "skiprows": form.skiprows.data,
+                "nrows": form.nrows.data,
+                "skip_blank_lines": form.skip_blank_lines.data,
+                "parse_dates": form.parse_dates.data,
+                "infer_datetime_format": form.infer_datetime_format.data,
+                "chunksize": 1000,
+            }
+            df_to_sql_kwargs = {
+                "name": csv_table.table,
+                "if_exists": form.if_exists.data,
+                "index": form.index.data,
+                "index_label": form.index_label.data,
+                "chunksize": 1000,
+            }
+            database.db_engine_spec.create_table_from_csv(
+                uploaded_tmp_file_path,
+                csv_table,
+                database,
+                csv_to_df_kwargs,
+                df_to_sql_kwargs,
+            )
+
+            # Connect table to the database that should be used for 
exploration.
+            # E.g. if hive was used to upload a csv, presto will be a better 
option
+            # to explore the table.
+            expore_database = database
+            explore_database_id = 
database.get_extra().get("explore_database_id", None)
+            if explore_database_id:
+                expore_database = (
+                    db.session.query(models.Database)
+                    .filter_by(id=explore_database_id)
+                    .one_or_none()
+                    or database
+                )
+
+            sqla_table = (
                 db.session.query(SqlaTable)
                 .filter_by(
-                    table_name=table_name,
-                    schema=form.schema.data,
-                    database_id=database.id,
+                    table_name=csv_table.table,
+                    schema=csv_table.schema,
+                    database_id=expore_database.id,
                 )
                 .one_or_none()
             )
-            if table:
-                table.fetch_metadata()
-            if not table:
-                table = SqlaTable(table_name=table_name)
-                table.database = database
-                table.database_id = database.id
-                table.user_id = g.user.id
-                table.schema = form.schema.data
-                table.fetch_metadata()
-                db.session.add(table)
+
+            if sqla_table:
+                sqla_table.fetch_metadata()
+            if not sqla_table:
+                sqla_table = SqlaTable(table_name=csv_table.table)
+                sqla_table.database = expore_database
+                sqla_table.database_id = database.id
+                sqla_table.user_id = g.user.id
+                sqla_table.schema = csv_table.schema
+                sqla_table.fetch_metadata()
+                db.session.add(sqla_table)
             db.session.commit()
         except Exception as ex:  # pylint: disable=broad-except
             db.session.rollback()
             try:
-                os.remove(path)
+                os.remove(uploaded_tmp_file_path)
             except OSError:
                 pass
             message = _(
                 'Unable to upload CSV file "%(filename)s" to table '
                 '"%(table_name)s" in database "%(db_name)s". '
                 "Error message: %(error_msg)s",
-                filename=csv_filename,
+                filename=form.csv_file.data.filename,
                 table_name=form.name.data,
                 db_name=database.database_name,
                 error_msg=str(ex),
@@ -178,14 +229,14 @@ class CsvToDatabaseView(SimpleFormView):
             stats_logger.incr("failed_csv_upload")
             return redirect("/csvtodatabaseview/form")
 
-        os.remove(path)
+        os.remove(uploaded_tmp_file_path)
         # Go back to welcome page / splash screen
         message = _(
             'CSV file "%(csv_filename)s" uploaded to table "%(table_name)s" in 
'
             'database "%(db_name)s"',
-            csv_filename=csv_filename,
-            table_name=form.name.data,
-            db_name=table.database.database_name,
+            csv_filename=form.csv_file.data.filename,
+            table_name=str(csv_table),
+            db_name=sqla_table.database.database_name,
         )
         flash(message, "info")
         stats_logger.incr("successful_csv_upload")
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 5c4f2b7..12293c1 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -24,6 +24,8 @@ import io
 import json
 import logging
 import os
+from typing import Dict, List, Optional
+
 import pytz
 import random
 import re
@@ -44,6 +46,7 @@ from superset import (
     is_feature_enabled,
 )
 from superset.connectors.sqla.models import SqlaTable
+from superset.datasets.dao import DatasetDAO
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.db_engine_specs.mssql import MssqlEngineSpec
 from superset.models import core as models
@@ -769,102 +772,163 @@ class CoreTests(SupersetTestCase):
         self.get_json_resp(slc_url, {"form_data": json.dumps(slc.form_data)})
         self.assertEqual(1, qry.count())
 
-    def test_import_csv(self):
-        self.login(username="admin")
-        table_name = "".join(random.choice(string.ascii_uppercase) for _ in 
range(5))
+    def create_sample_csvfile(self, filename: str, content: List[str]) -> None:
+        with open(filename, "w+") as test_file:
+            for l in content:
+                test_file.write(f"{l}\n")
 
-        filename_1 = "testCSV.csv"
-        test_file_1 = open(filename_1, "w+")
-        test_file_1.write("a,b\n")
-        test_file_1.write("john,1\n")
-        test_file_1.write("paul,2\n")
-        test_file_1.close()
+    def enable_csv_upload(self, database: models.Database) -> None:
+        """Enables csv upload in the given database."""
+        database.allow_csv_upload = True
+        db.session.commit()
+        add_datasource_page = self.get_resp("/databaseview/list/")
+        self.assertIn("Upload a CSV", add_datasource_page)
 
-        filename_2 = "testCSV2.csv"
-        test_file_2 = open(filename_2, "w+")
-        test_file_2.write("b,c,d\n")
-        test_file_2.write("john,1,x\n")
-        test_file_2.write("paul,2,y\n")
-        test_file_2.close()
+        form_get = self.get_resp("/csvtodatabaseview/form")
+        self.assertIn("CSV to Database configuration", form_get)
 
-        example_db = utils.get_example_database()
-        example_db.allow_csv_upload = True
-        db_id = example_db.id
-        db.session.commit()
+    def upload_csv(
+        self, filename: str, table_name: str, extra: Optional[Dict[str, str]] 
= None
+    ):
         form_data = {
-            "csv_file": open(filename_1, "rb"),
+            "csv_file": open(filename, "rb"),
             "sep": ",",
             "name": table_name,
-            "con": db_id,
+            "con": utils.get_example_database().id,
             "if_exists": "fail",
             "index_label": "test_label",
             "mangle_dupe_cols": False,
         }
-        url = "/databaseview/list/"
-        add_datasource_page = self.get_resp(url)
-        self.assertIn("Upload a CSV", add_datasource_page)
-
-        url = "/csvtodatabaseview/form"
-        form_get = self.get_resp(url)
-        self.assertIn("CSV to Database configuration", form_get)
+        if extra:
+            form_data.update(extra)
+        return self.get_resp("/csvtodatabaseview/form", data=form_data)
 
+    @mock.patch(
+        "superset.models.core.config",
+        {**app.config, "ALLOWED_USER_CSV_SCHEMA_FUNC": lambda d, u: 
["admin_database"]},
+    )
+    def test_import_csv_enforced_schema(self):
+        if utils.get_example_database().backend == "sqlite":
+            # sqlite doesn't support schema / database creation
+            return
+        self.login(username="admin")
+        table_name = "".join(random.choice(string.ascii_lowercase) for _ in 
range(5))
+        full_table_name = f"admin_database.{table_name}"
+        filename = "testCSV.csv"
+        self.create_sample_csvfile(filename, ["a,b", "john,1", "paul,2"])
         try:
-            # initial upload with fail mode
-            resp = self.get_resp(url, data=form_data)
+            self.enable_csv_upload(utils.get_example_database())
+
+            # no schema specified, fail upload
+            resp = self.upload_csv(filename, table_name)
             self.assertIn(
-                f'CSV file "{filename_1}" uploaded to table "{table_name}"', 
resp
+                'Database "examples" schema "None" is not allowed for csv 
uploads', resp
             )
 
-            # upload again with fail mode; should fail
-            form_data["csv_file"] = open(filename_1, "rb")
-            resp = self.get_resp(url, data=form_data)
+            # user specified schema matches the expected schema, append
+            success_msg = f'CSV file "{filename}" uploaded to table 
"{full_table_name}"'
+            resp = self.upload_csv(
+                filename,
+                table_name,
+                extra={"schema": "admin_database", "if_exists": "append"},
+            )
+            self.assertIn(success_msg, resp)
+
+            resp = self.upload_csv(
+                filename,
+                table_name,
+                extra={"schema": "admin_database", "if_exists": "replace"},
+            )
+            self.assertIn(success_msg, resp)
+
+            # user specified schema doesn't match, fail
+            resp = self.upload_csv(filename, table_name, extra={"schema": 
"gold"})
             self.assertIn(
-                f'Unable to upload CSV file "{filename_1}" to table 
"{table_name}"',
+                'Database "examples" schema "gold" is not allowed for csv 
uploads',
                 resp,
             )
+        finally:
+            os.remove(filename)
+
+    def test_import_csv_explore_database(self):
+        if utils.get_example_database().backend == "sqlite":
+            # sqlite doesn't support schema / database creation
+            return
+        explore_db_id = utils.get_example_database().id
+
+        upload_db = utils.get_or_create_db(
+            "csv_explore_db", app.config["SQLALCHEMY_DATABASE_URI"]
+        )
+        upload_db_id = upload_db.id
+        extra = upload_db.get_extra()
+        extra["explore_database_id"] = explore_db_id
+        upload_db.extra = json.dumps(extra)
+        db.session.commit()
+
+        self.login(username="admin")
+        self.enable_csv_upload(DatasetDAO.get_database_by_id(upload_db_id))
+        table_name = "".join(random.choice(string.ascii_uppercase) for _ in 
range(5))
+
+        f = "testCSV.csv"
+        self.create_sample_csvfile(f, ["a,b", "john,1", "paul,2"])
+        # initial upload with fail mode
+        resp = self.upload_csv(f, table_name)
+        self.assertIn(f'CSV file "{f}" uploaded to table "{table_name}"', resp)
+        table = self.get_table_by_name(table_name)
+        self.assertEqual(table.database_id, explore_db_id)
+
+        # cleanup
+        db.session.delete(table)
+        db.session.delete(DatasetDAO.get_database_by_id(upload_db_id))
+        db.session.commit()
+        os.remove(f)
+
+    def test_import_csv(self):
+        self.login(username="admin")
+        table_name = "".join(random.choice(string.ascii_uppercase) for _ in 
range(5))
+
+        f1 = "testCSV.csv"
+        self.create_sample_csvfile(f1, ["a,b", "john,1", "paul,2"])
+        f2 = "testCSV2.csv"
+        self.create_sample_csvfile(f2, ["b,c,d", "john,1,x", "paul,2,y"])
+        self.enable_csv_upload(utils.get_example_database())
+
+        try:
+            success_msg_f1 = f'CSV file "{f1}" uploaded to table 
"{table_name}"'
+
+            # initial upload with fail mode
+            resp = self.upload_csv(f1, table_name)
+            self.assertIn(success_msg_f1, resp)
+
+            # upload again with fail mode; should fail
+            fail_msg = f'Unable to upload CSV file "{f1}" to table 
"{table_name}"'
+            resp = self.upload_csv(f1, table_name)
+            self.assertIn(fail_msg, resp)
 
             # upload again with append mode
-            form_data["csv_file"] = open(filename_1, "rb")
-            form_data["if_exists"] = "append"
-            resp = self.get_resp(url, data=form_data)
-            self.assertIn(
-                f'CSV file "{filename_1}" uploaded to table "{table_name}"', 
resp
-            )
+            resp = self.upload_csv(f1, table_name, extra={"if_exists": 
"append"})
+            self.assertIn(success_msg_f1, resp)
 
             # upload again with replace mode
-            form_data["csv_file"] = open(filename_1, "rb")
-            form_data["if_exists"] = "replace"
-            resp = self.get_resp(url, data=form_data)
-            self.assertIn(
-                f'CSV file "{filename_1}" uploaded to table "{table_name}"', 
resp
-            )
+            resp = self.upload_csv(f1, table_name, extra={"if_exists": 
"replace"})
+            self.assertIn(success_msg_f1, resp)
 
             # try to append to table from file with different schema
-            form_data["csv_file"] = open(filename_2, "rb")
-            form_data["if_exists"] = "append"
-            resp = self.get_resp(url, data=form_data)
-            self.assertIn(
-                f'Unable to upload CSV file "{filename_2}" to table 
"{table_name}"',
-                resp,
-            )
+            resp = self.upload_csv(f2, table_name, extra={"if_exists": 
"append"})
+            fail_msg_f2 = f'Unable to upload CSV file "{f2}" to table 
"{table_name}"'
+            self.assertIn(fail_msg_f2, resp)
 
             # replace table from file with different schema
-            form_data["csv_file"] = open(filename_2, "rb")
-            form_data["if_exists"] = "replace"
-            resp = self.get_resp(url, data=form_data)
-            self.assertIn(
-                f'CSV file "{filename_2}" uploaded to table "{table_name}"', 
resp
-            )
-            table = (
-                db.session.query(SqlaTable)
-                .filter_by(table_name=table_name, database_id=db_id)
-                .first()
-            )
+            resp = self.upload_csv(f2, table_name, extra={"if_exists": 
"replace"})
+            success_msg_f2 = f'CSV file "{f2}" uploaded to table 
"{table_name}"'
+            self.assertIn(success_msg_f2, resp)
+
+            table = self.get_table_by_name(table_name)
             # make sure the new column name is reflected in the table metadata
             self.assertIn("d", table.column_names)
         finally:
-            os.remove(filename_1)
-            os.remove(filename_2)
+            os.remove(f1)
+            os.remove(f2)
 
     def test_dataframe_timezone(self):
         tz = pytz.FixedOffset(60)

Reply via email to