Github user njayaram2 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/291#discussion_r204589559
  
    --- Diff: src/ports/postgres/modules/utilities/transform_vec_cols.py_in ---
    @@ -0,0 +1,513 @@
    +# Licensed to the Apache Software Foundation (ASF) under one
    +# or more contributor license agreements.  See the NOTICE file
    +# distributed with this work for additional information
    +# regarding copyright ownership.  The ASF licenses this file
    +# to you under the Apache License, Version 2.0 (the
    +# "License"); you may not use this file except in compliance
    +# with the License.  You may obtain a copy of the License at
    +#
    +#   http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing,
    +# software distributed under the License is distributed on an
    +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    +# KIND, either express or implied.  See the License for the
    +# specific language governing permissions and limitations
    +# under the License.
    +
    +import plpy
    +from control import MinWarning
    +from internal.db_utils import is_col_1d_array
    +from internal.db_utils import quote_literal
    +from utilities import _assert
    +from utilities import add_postfix
    +from utilities import ANY_ARRAY
    +from utilities import is_psql_boolean_type
    +from utilities import is_psql_char_type
    +from utilities import is_psql_numeric_type
    +from utilities import is_valid_psql_type
    +from utilities import py_list_to_sql_string
    +from utilities import split_quoted_delimited_str
    +from validate_args import is_var_valid
    +from validate_args import get_cols
    +from validate_args import get_cols_and_types
    +from validate_args import get_expr_type
    +from validate_args import input_tbl_valid
    +from validate_args import output_tbl_valid
    +from validate_args import table_exists
    +
    +class vec_cols_helper:
    +    def __init__(self):
    +        self.all_cols = None
    +
    +    def get_cols_as_list(self, cols_to_process, source_table=None, exclude_cols=None):
    +        """
    +            Get a list of columns based on the value of cols_to_process
    +            Args:
    +            @param cols_to_process: str, Either a * or a comma-separated list of col names
    +            @param source_table: str, optional. Source table name
    +            @param exclude_cols: str, optional. Comma-separated list of the col(s) to exclude
    +                                 from the source table, only used if cols_to_process is *
    +            Returns:
    +            A list of column names (or an empty list)
    +        """
    +        # If cols_to_process is empty/None, return empty list
    +        if not cols_to_process:
    +            return []
    +        if cols_to_process.strip() != "*":
    +            # If cols_to_process is a comma separated list of names, return list
    +            # of column names in cols_to_process.
    +            return [col for col in split_quoted_delimited_str(cols_to_process)
    +                    if col not in split_quoted_delimited_str(exclude_cols)]
    +        if source_table:
    +            if not self.all_cols:
    +                self.all_cols = get_cols(source_table)
    +            return [col for col in self.all_cols
    +                    if col not in split_quoted_delimited_str(exclude_cols)]
    +        return []
    +
    +    def get_type_class(self, arg):
    +        if is_psql_numeric_type(arg):
    +            return "double precision"
    +        elif is_psql_char_type(arg):
    +            return "text"
    +        else:
    +            return arg
    +
    +class vec2cols:
    +    def __init__(self):
    +        self.get_cols_helper = vec_cols_helper()
    +        self.module_name = self.__class__.__name__
    +
    +    def validate_args(self, source_table, output_table, vector_col, feature_names,
    +                      cols_to_output):
    +        """
    +            Validate args for vec2cols
    +        """
    +        input_tbl_valid(source_table, self.module_name)
    +        output_tbl_valid(output_table, self.module_name)
    +        is_var_valid(source_table, cols_to_output)
    +        is_var_valid(source_table, vector_col)
    +        _assert(is_valid_psql_type(get_expr_type(vector_col, source_table), ANY_ARRAY),
    +            "{0}: vector_col should refer to an array.".format(self.module_name))
    +        _assert(is_col_1d_array(source_table, vector_col),
    +            "{0}: vector_col must be a 1-dimensional array.".format(self.module_name))
    +
    +    def get_names_for_split_output_cols(self, source_table, vector_col, feature_names):
    +        """
    +            Get list of names for the newly-split columns to include in the
    +            output table.
    +            Args:
    +            @param: source_table, str. Source table
    +            @param: vector_col, str. Column name containing the array input
    +            @param: feature_names, list. Python list of the feature names to
    +                    use for the split elements in the vector_col array
    +        """
    +        query = """
    +            SELECT array_upper({0}, 1) AS n_x
    +            FROM {1}
    +            LIMIT 1
    +        """.format(vector_col, source_table)
    +        result = plpy.execute(query)[0]["n_x"]
    +        if not result:
    +            plpy.error('{0}: Column to split ({1}) must not be an empty array'
    +                .format(self.module_name, vector_col))
    +        if not feature_names:
    +            # Create custom col names for output columns, with prefix "f".
    +            feature_names = ["f{0}".format(i+1) for i in range(result)]
    +        else:
    +            # Check if the array dimension is equal to the number of col names
    +            # specified in feature_names.
    +            _assert(result == len(feature_names),
    +                    "{0}: Mismatch between size of vector_col and number 
of "
    +                    "cols in feature_names.".format(self.module_name))
    +        return feature_names
    +
    +    def validate_output_cols(self, features_to_unnest, cols_to_keep):
    +        # If there are more than 1600 columns for the output table, we give a
    +        # warning as it might give an error
    +        MAX_OUTPUT_COLUMN_COUNT = 1600
    +        _assert(len(features_to_unnest)+len(cols_to_keep) < MAX_OUTPUT_COLUMN_COUNT,
    +                "{0}: The output exceeds the max number of columns that " +
    +                "can be created ({1})".format(self.module_name, MAX_OUTPUT_COLUMN_COUNT))
    +        # Check if newly created col names have the same name as existing cols
    +        duplicate_col_names = set(features_to_unnest).intersection(set(cols_to_keep))
    +        _assert(len(duplicate_col_names) == 0,
    +                "{0}: Conflicting column names. Column names in source "
    +                "table cannot be {1}".format(self.module_name,
    +                                            list(duplicate_col_names)))
    +
    +    def vec2cols(self, schema_madlib, source_table, output_table,
    +                 vector_col, feature_names, cols_to_output, **kwargs):
    +        """
    +            Split up a column of array entries into multiple columns, each column
    +            corresponding to one array position
    +            Args:
    +            @param: schema_madlib, str. The schema with madlib installed
    +            @param: source_table, str. The source table
    +            @param: output_table, str. The output table
    +            @param: vector_col, str. The column with array entries to split up
    +            @param: feature_names, list. Python list of the feature names to use
    +                    for the split elements in the vector_col array
    +            @param: cols_to_output, str. Comma-separated list of the columns in
    +                    the source_table to include in the output_table
    +        """
    +        self.validate_args(source_table, output_table, vector_col, feature_names,
    +                           cols_to_output)
    +
    +        # Get names of columns to use for the split vector_col
    +        features_to_unnest = self.get_names_for_split_output_cols(source_table,
    +                             vector_col, feature_names)
    +        cols_to_keep = self.get_cols_helper.get_cols_as_list(cols_to_output,
    +                       source_table)
    +
    +        self.validate_output_cols(features_to_unnest, cols_to_keep)
    +
    +        # Construct the output query and populate the output table with all the
    +        # correct parameters
    +        select_new_cols = ', '.join(["{0}[{1}] AS {2}".format(vector_col,
    +            i+1, features_to_unnest[i]) for i in range(len(features_to_unnest))])
    +        cols_from_src_table = ', '.join(cols_to_keep)+', ' if cols_to_keep else ''
    +        query = """
    +        CREATE TABLE {output_table} AS
    +        SELECT {cols_from_src_table} {select_new_cols}
    +        FROM {source_table}
    +        """.format(**locals())
    +        plpy.execute(query)
    +
    +    def vec2cols_help_message(self, schema_madlib, message, **kwargs):
    +        """
    +            Help message for vec2cols function
    +        """
    +        summary_string = """
    
+-----------------------------------------------------------------------------------
    +                                    SUMMARY
    
+-----------------------------------------------------------------------------------
    +Functionality: Vector to Columns
    +
    +The MADlib vec2cols function enables the user to split up a single column into
    +multiple columns, given that the input column contains array entries. For example,
    +if the input column contained ARRAY[1, 2, 3] in one of its rows, the output table
    +will contain 3 different columns, one for each element of the array.
    +
    +For more details on function usage:
    +    SELECT {schema_madlib}.vec2cols('usage');
    +
    +For a small example on using the function:
    +    SELECT {schema_madlib}.vec2cols('example');
    +    """.format(schema_madlib=schema_madlib)
    +
    +        usage_string = """
    
+-----------------------------------------------------------------------------------
    +                                    USAGE
    
+-----------------------------------------------------------------------------------
    +SELECT {schema_madlib}.vec2cols(
    +    'source_table',     -- str, Name of the source table that contains the data
    +    'output_table',     -- str, Name of the output view or table
    +    'vector_col',       -- str, Name of the array entry column to be split
    +    'feature_names',    -- array, Optional parameter to provide a text array of
    +                        -- the feature names for the newly split columns (if not
    +                        -- provided, default names f0, f1, ... will be used)
    +    'cols_to_output'    -- str, Optional parameter to specify any other columns
    +                        -- in the source_table to include in the output_table
    +                        -- (default none of them, also supports '*' as input)
    +    """.format(schema_madlib=schema_madlib)
    +
    +        example_string = """
    
+-----------------------------------------------------------------------------------
    +                                    EXAMPLE
    
+-----------------------------------------------------------------------------------
    +-- Create an input data set:
    +
    +DROP TABLE IF EXISTS golf CASCADE;
    +CREATE TABLE golf (
    +    id integer NOT NULL,
    +    "OUTLOOK" text,
    +    temperature double precision,
    +    humidity double precision,
    +    "Temp_Humidity" double precision[],
    +    clouds_airquality text[],
    +    windy boolean,
    +    class text,
    +    observation_weight double precision
    +);
    +INSERT INTO golf VALUES
    +(1,'sunny', 85, 85, ARRAY[85, 85],ARRAY['none', 'unhealthy'], 'false','Don''t Play', 5.0),
    +(2, 'sunny', 80, 90, ARRAY[80, 90], ARRAY['none', 'moderate'], 'true', 'Don''t Play', 5.0),
    +(3, 'overcast', 83, 78, ARRAY[83, 78], ARRAY['low', 'moderate'], 'false', 'Play', 1.5),
    +(4, 'rain', 70, 96, ARRAY[70, 96], ARRAY['low', 'moderate'], 'false', 'Play', 1.0),
    +(5, 'rain', 68, 80, ARRAY[68, 80], ARRAY['medium', 'good'], 'false', 'Play', 1.0),
    +(6, 'rain', 65, 70, ARRAY[65, 70], ARRAY['low', 'unhealthy'], 'true', 'Don''t Play', 1.0),
    +(7, 'overcast', 64, 65, ARRAY[64, 65], ARRAY['medium', 'moderate'], 'true', 'Play', 1.5),
    +(8, 'sunny', 72, 95, ARRAY[72, 95], ARRAY['high', 'unhealthy'], 'false', 'Don''t Play', 5.0),
    +(9, 'sunny', 69, 70, ARRAY[69, 70], ARRAY['high', 'good'], 'false', 'Play', 5.0),
    +(10, 'rain', 75, 80, ARRAY[75, 80], ARRAY['medium', 'good'], 'false', 'Play', 1.0),
    +(11, 'sunny', 75, 70, ARRAY[75, 70], ARRAY['none', 'good'], 'true', 'Play', 5.0),
    +(12, 'overcast', 72, 90, ARRAY[72, 90], ARRAY['medium', 'moderate'], 'true', 'Play', 1.5),
    +(13, 'overcast', 81, 75, ARRAY[81, 75], ARRAY['medium', 'moderate'], 'false', 'Play', 1.5),
    +(14, 'rain', 71, 80, ARRAY[71, 80], ARRAY['low', 'unhealthy'], 'true', 'Don''t Play', 1.0);
    +
    +-- Call the vec2cols function on the 'clouds_airquality' column, to split it up
    +
    +DROP TABLE IF EXISTS output_table;
    +SELECT {schema_madlib}.vec2cols(
    +    'golf',               -- source table
    +    'output_table',       -- output table
    +    'clouds_airquality',  -- column with array entries to split
    +    ARRAY['a', 'b'],      -- feature_names array (will use 'a' to name the first new column, and 'b' for the second)
    +    '"OUTLOOK", id'       -- columns to keep from source table (as a comma-separated list)
    +);
    +
    +SELECT * FROM output_table ORDER BY id;
    + OUTLOOK  | id |   a    |     b
    +----------+----+--------+-----------
    + sunny    |  1 | none   | unhealthy
    + sunny    |  2 | none   | moderate
    + overcast |  3 | low    | moderate
    + rain     |  4 | low    | moderate
    + rain     |  5 | medium | good
    + rain     |  6 | low    | unhealthy
    + overcast |  7 | medium | moderate
    + sunny    |  8 | high   | unhealthy
    + sunny    |  9 | high   | good
    + rain     | 10 | medium | good
    + sunny    | 11 | none   | good
    + overcast | 12 | medium | moderate
    + overcast | 13 | medium | moderate
    + rain     | 14 | low    | unhealthy
    +(14 rows)
    +""".format(schema_madlib=schema_madlib)
    +        
    +        if not message:
    +            return summary_string
    +        elif message.lower() in ('usage', 'help', '?'):
    +            return usage_string
    +        elif message.lower() in ('example', 'examples'):
    +            return example_string
    +        else:
    +            return """
    +No such option. Use "SELECT {schema_madlib}.vec2cols()" for help.
    +        """.format(schema_madlib=schema_madlib)
    +
    +class cols2vec:
    +    def __init__(self):
    +        self.get_cols_helper = vec_cols_helper()
    +        self.module_name = self.__class__.__name__
    +
    +    def validate_args(self, source_table, output_table,
    +                      list_of_features, list_of_features_to_exclude, cols_to_output):
    +        """
    +            Function to validate input parameters
    +        """
    +        input_tbl_valid(source_table, self.module_name)
    +        output_tbl_valid(output_table, self.module_name)
    +
    +        _assert(list_of_features and list_of_features.strip(), "{0}: List of "
    +                    "features cannot be empty".format(self.module_name))
    +        if list_of_features.strip() != '*':
    +            is_var_valid(source_table, list_of_features)
    +        
    +        if list_of_features_to_exclude:
    +            if list_of_features_to_exclude.strip() == "*":
    +                plpy.error("{0}: Cannot exclude all columns from being "
    +                    "features".format(self.module_name))
    +            elif list_of_features.strip() != '*':
    +                plpy.info("{0} NOTICE: will exclude given column(s) even 
though "
    +                    "list of features was not *".format(self.module_name))
    +
    +        is_var_valid(source_table, list_of_features_to_exclude)
    +        is_var_valid(source_table, cols_to_output)
    +
    +    def get_and_validate_feature_types(self, source_table, features_to_nest):
    +        """
    +            This function will validate and return the appropriate type to cast
    +            the final SQL array. Will fail if feature types do not belong to the
    +            same group (numeric, text, etc.) or if any type is an array.
    +
    +            If all features to nest are of the same type, we just return that type
    +            and do not cast. Else, if they only constitute integers and smallints,
    +            we cast everything to an integer. Else, if they are all part of the same
    +            type group, we return the most comprehensive type in that group. Else, throw an error.
    +        """
    +        all_cols_and_types = get_cols_and_types(source_table)
    +        distinct_types = set([col_type[1] for col_type in all_cols_and_types
    +            if col_type[0] in features_to_nest])
    +        for expr_type in distinct_types:
    +            _assert(not is_valid_psql_type(expr_type, ANY_ARRAY),
    +                "{0}: Feature columns to nest cannot be of type array"
    +                .format(self.module_name))
    +        if len(distinct_types) > 1:
    +            if distinct_types == {'integer', 'smallint'}:
    --- End diff --
    
    Is there a reason why `bigint` is not included here? If the distinct types were `{integer, bigint}`, this function would return `double precision`, right? That seems fine functionality-wise, in the sense that no precision is lost in the output array.
    
    I created a table named `t1` with the following schema:
    ```
    madlib-pg94=# \d+ t1;
                                               Table "public.t1"
     Column |       Type       |                    Modifiers                     | Storage | Stats target | Description
    --------+------------------+--------------------------------------------------+---------+--------------+-------------
     id     | integer          | not null default nextval('t1_id_seq'::regclass)  | plain   |              |
     c1     | integer          |                                                  | plain   |              |
     c2     | bigint           |                                                  | plain   |              |
     c3     | smallint         |                                                  | plain   |              |
     c4     | double precision |                                                  | plain   |              |
    ```
    Calling `select madlib.cols2vec('t1','t1_out','c1,c2');` resulted in an output array of type `double precision[]`. Is that the expected behavior?
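    
    For comparison (this is just stock PostgreSQL type resolution, not anything the patch does): combining an `integer` and a `bigint` in a plain ARRAY constructor resolves to `bigint[]`, whereas the type-group fallback in this function casts to `double precision[]`:
    ```
    -- Stock PostgreSQL resolves integer + bigint to bigint[] in an ARRAY constructor,
    -- while the type-group fallback here would cast the result to double precision[].
    SELECT pg_typeof(ARRAY[1::integer, 2::bigint]);
    -- returns: bigint[]
    ```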

