Github user njayaram2 commented on a diff in the pull request: https://github.com/apache/madlib/pull/291#discussion_r204589559 --- Diff: src/ports/postgres/modules/utilities/transform_vec_cols.py_in --- @@ -0,0 +1,513 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import plpy +from control import MinWarning +from internal.db_utils import is_col_1d_array +from internal.db_utils import quote_literal +from utilities import _assert +from utilities import add_postfix +from utilities import ANY_ARRAY +from utilities import is_psql_boolean_type +from utilities import is_psql_char_type +from utilities import is_psql_numeric_type +from utilities import is_valid_psql_type +from utilities import py_list_to_sql_string +from utilities import split_quoted_delimited_str +from validate_args import is_var_valid +from validate_args import get_cols +from validate_args import get_cols_and_types +from validate_args import get_expr_type +from validate_args import input_tbl_valid +from validate_args import output_tbl_valid +from validate_args import table_exists + +class vec_cols_helper: + def __init__(self): + self.all_cols = None + + def get_cols_as_list(self, cols_to_process, source_table=None, exclude_cols=None): + """ + Get a list of columns based on the value of cols_to_process + Args: + @param cols_to_process: str, Either a * or a comma-separated list of col names + @param source_table: str, optional. Source table name + @param exclude_cols: str, optional. Comma-separated list of the col(s) to exclude + from the source table, only used if cols_to_process is * + Returns: + A list of column names (or an empty list) + """ + # If cols_to_process is empty/None, return empty list + if not cols_to_process: + return [] + if cols_to_process.strip() != "*": + # If cols_to_process is a comma separated list of names, return list + # of column names in cols_to_process. + return [col for col in split_quoted_delimited_str(cols_to_process) + if col not in split_quoted_delimited_str(exclude_cols)] + if source_table: + if not self.all_cols: + self.all_cols = get_cols(source_table) + return [col for col in self.all_cols + if col not in split_quoted_delimited_str(exclude_cols)] + return [] + + def get_type_class(self, arg): + if is_psql_numeric_type(arg): + return "double precision" + elif is_psql_char_type(arg): + return "text" + else: + return arg + +class vec2cols: + def __init__(self): + self.get_cols_helper = vec_cols_helper() + self.module_name = self.__class__.__name__ + + def validate_args(self, source_table, output_table, vector_col, feature_names, + cols_to_output): + """ + Validate args for vec2cols + """ + input_tbl_valid(source_table, self.module_name) + output_tbl_valid(output_table, self.module_name) + is_var_valid(source_table, cols_to_output) + is_var_valid(source_table, vector_col) + _assert(is_valid_psql_type(get_expr_type(vector_col, source_table), ANY_ARRAY), + "{0}: vector_col should refer to an array.".format(self.module_name)) + _assert(is_col_1d_array(source_table, vector_col), + "{0}: vector_col must be a 1-dimensional array.".format(self.module_name)) + + def get_names_for_split_output_cols(self, source_table, vector_col, feature_names): + """ + Get list of names for the newly-split columns to include in the + output table. + Args: + @param: source_table, str. Source table + @param: vector_col, str. Column name containing the array input + @param: feature_names, list. Python list of the feature names to + use for the split elements in the vector_col array + """ + query = """ + SELECT array_upper({0}, 1) AS n_x + FROM {1} + LIMIT 1 + """.format(vector_col, source_table) + result = plpy.execute(query)[0]["n_x"] + if not result: + plpy.error('{0}: Column to split ({1}) must not be an empty array' + .format(self.module_name, vector_col)) + if not feature_names: + # Create custom col names for output columns, with prefix "f". + feature_names = ["f{0}".format(i+1) for i in range(result)] + else: + # Check if the array dimension is equal to the number of col names + # specified in feature_names. + _assert(result == len(feature_names), + "{0}: Mismatch between size of vector_col and number of " + "cols in feature_names.".format(self.module_name)) + return feature_names + + def validate_output_cols(self, features_to_unnest, cols_to_keep): + # If there are more than 1600 columns for the output table, we give a + # warning as it might give an error + MAX_OUTPUT_COLUMN_COUNT = 1600 + _assert(len(features_to_unnest)+len(cols_to_keep) < MAX_OUTPUT_COLUMN_COUNT, + "{0}: The output exceeds the max number of columns that " + + "can be created ({1})".format(self.module_name, MAX_OUTPUT_COLUMN_COUNT)) + # Check if newly created col names have the same name as existing cols + duplicate_col_names = set(features_to_unnest).intersection(set(cols_to_keep)) + _assert(len(duplicate_col_names) == 0, + "{0}: Conflicting column names. Column names in source " + "table cannot be {1}".format(self.module_name, + list(duplicate_col_names))) + + def vec2cols(self, schema_madlib, source_table, output_table, + vector_col, feature_names, cols_to_output, **kwargs): + """ + Split up a column of array entries into multiple columns, each column + corresponding to one array position + Args: + @param: schema_madlib, str. The schema with madlib installed + @param: source_table, str. The source table + @param: output_table, str. The output table + @param: vector_col, str. The column with array entries to split up + @param: feature_names, list. Python list of the feature names to use + for the split elements in the vector_col array + @param: cols_to_output, str. Comma-separated list of the columns in + the source_table to include in the output_table + """ + self.validate_args(source_table, output_table, vector_col, feature_names, + cols_to_output) + + # Get names of columns to use for the split vector_col + features_to_unnest = self.get_names_for_split_output_cols(source_table, + vector_col, feature_names) + cols_to_keep = self.get_cols_helper.get_cols_as_list(cols_to_output, + source_table) + + self.validate_output_cols(features_to_unnest, cols_to_keep) + + # Construct the output query and populate the output table with all the + # correct parameters + select_new_cols = ', '.join(["{0}[{1}] AS {2}".format(vector_col, + i+1, features_to_unnest[i]) for i in range(len(features_to_unnest))]) + cols_from_src_table = ', '.join(cols_to_keep)+', ' if cols_to_keep else '' + query = """ + CREATE TABLE {output_table} AS + SELECT {cols_from_src_table} {select_new_cols} + FROM {source_table} + """.format(**locals()) + plpy.execute(query) + + def vec2cols_help_message(self, schema_madlib, message, **kwargs): + """ + Help message for vec2cols function + """ + summary_string = """ +----------------------------------------------------------------------------------- + SUMMARY +----------------------------------------------------------------------------------- +Functionality: Vector to Columns + +The MADlib vec2cols function enables the user to split up a single column into +multiple columns, given that the input column contains array entries. For example, +if the input column contained ARRAY[1, 2, 3] in one of its rows, the output table +will contain 3 different columns, one for each element of the array. + +For more details on function usage: + SELECT {schema_madlib}.vec2cols('usage'); + +For a small example on using the function: + SELECT {schema_madlib}.vec2cols('example'); + """.format(schema_madlib=schema_madlib) + + usage_string = """ +----------------------------------------------------------------------------------- + USAGE +----------------------------------------------------------------------------------- +SELECT {schema_madlib}.vec2cols( + 'source_table', -- str, Name of the source table that contains the data + 'output_table', -- str, Name of the output view or table + 'vector_col', -- str, Name of the array entry column to be split + 'feature_names', -- array, Optional parameter to provide a text array of + -- the feature names for the newly split columns (if not + -- provided, default names f0, f1, ... will be used) + 'cols_to_output' -- str, Optional parameter to specify any other columns + -- in the source_table to include in the output_table + -- (default none of them, also supports '*' as input) + """.format(schema_madlib=schema_madlib) + + example_string = """ +----------------------------------------------------------------------------------- + EXAMPLE +----------------------------------------------------------------------------------- +-- Create an input data set: + +DROP TABLE IF EXISTS golf CASCADE; +CREATE TABLE golf ( + id integer NOT NULL, + "OUTLOOK" text, + temperature double precision, + humidity double precision, + "Temp_Humidity" double precision[], + clouds_airquality text[], + windy boolean, + class text, + observation_weight double precision +); +INSERT INTO golf VALUES +(1,'sunny', 85, 85, ARRAY[85, 85],ARRAY['none', 'unhealthy'], 'false','Don''t Play', 5.0), +(2, 'sunny', 80, 90, ARRAY[80, 90], ARRAY['none', 'moderate'], 'true', 'Don''t Play', 5.0), +(3, 'overcast', 83, 78, ARRAY[83, 78], ARRAY['low', 'moderate'], 'false', 'Play', 1.5), +(4, 'rain', 70, 96, ARRAY[70, 96], ARRAY['low', 'moderate'], 'false', 'Play', 1.0), +(5, 'rain', 68, 80, ARRAY[68, 80], ARRAY['medium', 'good'], 'false', 'Play', 1.0), +(6, 'rain', 65, 70, ARRAY[65, 70], ARRAY['low', 'unhealthy'], 'true', 'Don''t Play', 1.0), +(7, 'overcast', 64, 65, ARRAY[64, 65], ARRAY['medium', 'moderate'], 'true', 'Play', 1.5), +(8, 'sunny', 72, 95, ARRAY[72, 95], ARRAY['high', 'unhealthy'], 'false', 'Don''t Play', 5.0), +(9, 'sunny', 69, 70, ARRAY[69, 70], ARRAY['high', 'good'], 'false', 'Play', 5.0), +(10, 'rain', 75, 80, ARRAY[75, 80], ARRAY['medium', 'good'], 'false', 'Play', 1.0), +(11, 'sunny', 75, 70, ARRAY[75, 70], ARRAY['none', 'good'], 'true', 'Play', 5.0), +(12, 'overcast', 72, 90, ARRAY[72, 90], ARRAY['medium', 'moderate'], 'true', 'Play', 1.5), +(13, 'overcast', 81, 75, ARRAY[81, 75], ARRAY['medium', 'moderate'], 'false', 'Play', 1.5), +(14, 'rain', 71, 80, ARRAY[71, 80], ARRAY['low', 'unhealthy'], 'true', 'Don''t Play', 1.0); + +-- Call the vec2cols function on the 'clouds_airquality' column, to split it up + +DROP TABLE IF EXISTS output_table; +SELECT {schema_madlib}.vec2cols( + 'golf', -- source table + 'output_table', -- output table + 'clouds_airquality', -- column with array entries to split + ARRAY['a', 'b'], -- feature_names array (will use 'a' to name the first new column, and 'b' for the second) + '"OUTLOOK", id' -- columns to keep from source table (as a comma-separated list) +); + +SELECT * FROM output_table ORDER BY id; + OUTLOOK | id | a | b +----------+----+--------+----------- + sunny | 1 | none | unhealthy + sunny | 2 | none | moderate + overcast | 3 | low | moderate + rain | 4 | low | moderate + rain | 5 | medium | good + rain | 6 | low | unhealthy + overcast | 7 | medium | moderate + sunny | 8 | high | unhealthy + sunny | 9 | high | good + rain | 10 | medium | good + sunny | 11 | none | good + overcast | 12 | medium | moderate + overcast | 13 | medium | moderate + rain | 14 | low | unhealthy +(14 rows) +""".format(schema_madlib=schema_madlib) + + if not message: + return summary_string + elif message.lower() in ('usage', 'help', '?'): + return usage_string + elif message.lower() in ('example', 'examples'): + return example_string + else: + return """ +No such option. Use "SELECT {schema_madlib}.vec2cols()" for help. + """.format(schema_madlib=schema_madlib) + +class cols2vec: + def __init__(self): + self.get_cols_helper = vec_cols_helper() + self.module_name = self.__class__.__name__ + + def validate_args(self, source_table, output_table, + list_of_features, list_of_features_to_exclude, cols_to_output): + """ + Function to validate input parameters + """ + input_tbl_valid(source_table, self.module_name) + output_tbl_valid(output_table, self.module_name) + + _assert(list_of_features and list_of_features.strip(), "{0}: List of " + "features cannot be empty".format(self.module_name)) + if list_of_features.strip() != '*': + is_var_valid(source_table, list_of_features) + + if list_of_features_to_exclude: + if list_of_features_to_exclude.strip() == "*": + plpy.error("{0}: Cannot exclude all columns from being " + "features".format(self.module_name)) + elif list_of_features.strip() != '*': + plpy.info("{0} NOTICE: will exclude given column(s) even though " + "list of features was not *".format(self.module_name)) + + is_var_valid(source_table, list_of_features_to_exclude) + is_var_valid(source_table, cols_to_output) + + def get_and_validate_feature_types(self, source_table, features_to_nest): + """ + This function will validate and return the appropriate type to cast + the final SQL array. Will fail if feature types do not belong to the + same group (numeric, text, etc.) or if any type is an array. + + If all features to nest are of the same type, we just return that type + and do not cast. Else, if they only constitute integers and smallints, + we cast everything to an integer. Else, if they are all part of the same + type group, we return the most comprehensive type in that group. Else, throw an error. + """ + all_cols_and_types = get_cols_and_types(source_table) + distinct_types = set([col_type[1] for col_type in all_cols_and_types + if col_type[0] in features_to_nest]) + for expr_type in distinct_types: + _assert(not is_valid_psql_type(expr_type, ANY_ARRAY), + "{0}: Feature columns to nest cannot be of type array" + .format(self.module_name)) + if len(distinct_types) > 1: + if distinct_types == {'integer', 'smallint'}: --- End diff -- Is there a reason why `bigint` is not here? If the distinct types were `{integer, bigint}`, then this function would return back `double` right? It seems to be fine functionality-wise, in the sense that no precision is lost in the output array though. I created a table named `t1` with the following schema: ``` madlib-pg94=# \d+ t1; Table "public.t1" Column | Type | Modifiers | Storage | Stats target | Description --------+------------------+-------------------------------------------------+---------+--------------+------------- id | integer | not null default nextval('t1_id_seq'::regclass) | plain | | c1 | integer | | plain | | c2 | bigint | | plain | | c3 | smallint | | plain | | c4 | double precision | | plain | | ``` Calling `select madlib.cols2vec('t1','t1_out','c1,c2');` resulted in an output array of type `double precision[]`. Is that the expected behavior?
---