timifasubaa closed pull request #5267: [hive-csv] Infer schema from csv
URL: https://github.com/apache/incubator-superset/pull/5267
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/requirements.txt b/requirements.txt
index f60b57e384..aaf97d1b7b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -34,6 +34,7 @@ six==1.11.0
 sqlalchemy==1.2.2
 sqlalchemy-utils==0.32.21
 sqlparse==0.2.4
+tableschema==1.1.0
 thrift==0.11.0
 thrift-sasl==0.3.0
 unicodecsv==0.14.1
diff --git a/setup.py b/setup.py
index 7adccb70f7..c5aa4b42da 100644
--- a/setup.py
+++ b/setup.py
@@ -90,6 +90,7 @@ def get_git_sha():
         'sqlalchemy',
         'sqlalchemy-utils',
         'sqlparse',
+        'tableschema',
         'thrift>=0.9.3',
         'thrift-sasl>=0.2.1',
         'unicodecsv',
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index d6a9144106..4f6b22e305 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -37,7 +37,7 @@
 from sqlalchemy.sql import text
 from sqlalchemy.sql.expression import TextAsFrom
 import sqlparse
-import unicodecsv
+from tableschema import Table
 from werkzeug.utils import secure_filename
 
 from superset import app, cache_util, conf, db, utils
@@ -134,7 +134,7 @@ def get_query_without_limit(cls, sql):
     @staticmethod
     def csv_to_df(**kwargs):
         kwargs['filepath_or_buffer'] = \
-            app.config['UPLOAD_FOLDER'] + kwargs['filepath_or_buffer']
+            config['UPLOAD_FOLDER'] + kwargs['filepath_or_buffer']
         kwargs['encoding'] = 'utf-8'
         kwargs['iterator'] = True
         chunks = pandas.read_csv(**kwargs)
@@ -156,7 +156,7 @@ def create_table_from_csv(form, table):
         def _allowed_file(filename):
             # Only allow specific file extensions as specified in the config
             extension = os.path.splitext(filename)[1]
-            return extension and extension[1:] in app.config['ALLOWED_EXTENSIONS']
+            return extension and extension[1:] in config['ALLOWED_EXTENSIONS']
 
         filename = secure_filename(form.csv_file.data.filename)
         if not _allowed_file(filename):
@@ -973,9 +973,15 @@ def fetch_data(cls, cursor, limit):
     @staticmethod
     def create_table_from_csv(form, table):
         """Uploads a csv file and creates a superset datasource in Hive."""
-        def get_column_names(filepath):
-            with open(filepath, 'rb') as f:
-                return next(unicodecsv.reader(f, encoding='utf-8-sig'))
+        def convert_to_hive_type(col_type):
+            """maps tableschema's types to hive types"""
+            tableschema_to_hive_types = {
+                'boolean': 'BOOLEAN',
+                'integer': 'INT',
+                'number': 'DOUBLE',
+                'string': 'STRING',
+            }
+            return tableschema_to_hive_types.get(col_type, 'STRING')
 
         table_name = form.name.data
         if config.get('UPLOADED_CSV_HIVE_NAMESPACE'):
@@ -988,21 +994,27 @@ def get_column_names(filepath):
                 config.get('UPLOADED_CSV_HIVE_NAMESPACE'), table_name)
         filename = form.csv_file.data.filename
 
-        bucket_path = app.config['CSV_TO_HIVE_UPLOAD_S3_BUCKET']
+        bucket_path = config['CSV_TO_HIVE_UPLOAD_S3_BUCKET']
 
         if not bucket_path:
             logging.info('No upload bucket specified')
             raise Exception(
                 'No upload bucket specified. You can specify one in the config file.')
 
-        upload_prefix = app.config['CSV_TO_HIVE_UPLOAD_DIRECTORY']
-        dest_path = os.path.join(table_name, filename)
+        table_name = form.name.data
+        filename = form.csv_file.data.filename
+        upload_prefix = config['CSV_TO_HIVE_UPLOAD_DIRECTORY']
 
-        upload_path = app.config['UPLOAD_FOLDER'] + \
+        upload_path = config['UPLOAD_FOLDER'] + \
             secure_filename(form.csv_file.data.filename)
-        column_names = get_column_names(upload_path)
-        schema_definition = ', '.join(
-            [s + ' STRING ' for s in column_names])
+
+        hive_table_schema = Table(upload_path).infer()
+        column_name_and_type = []
+        for column_info in hive_table_schema['fields']:
+            column_name_and_type.append(
+                '{} {}'.format(
+                    column_info['name'], convert_to_hive_type(column_info['type'])))
+        schema_definition = ', '.join(column_name_and_type)
 
         s3 = boto3.client('s3')
         location = os.path.join('s3a://', bucket_path, upload_prefix, table_name)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to