Github user iyerr3 commented on a diff in the pull request: https://github.com/apache/madlib/pull/291#discussion_r201878219 --- Diff: src/ports/postgres/modules/utilities/vec2cols.py_in --- @@ -0,0 +1,266 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import plpy +from control import MinWarning +from internal.db_utils import is_col_1d_array +from utilities import _assert +from utilities import ANY_ARRAY +from utilities import is_valid_psql_type +from utilities import py_list_to_sql_string +from utilities import split_quoted_delimited_str +from validate_args import cols_in_tbl_valid +from validate_args import get_cols +from validate_args import get_expr_type +from validate_args import input_tbl_valid +from validate_args import output_tbl_valid +from validate_args import table_exists + +def get_cols_to_keep_list(cols_to_output, source_table=None, vector_col=None): + """ + Get a list of columns based on the value of cols_to_output + Args: + @param cols_to_output: str, Either a * or a comma separated list of col names + @param source_table: str, optional. Source table name + @param vector_col: str, optional. Name of the column representing the vector + + Returns: + A list of column names (or an empty list) + """ + # If cols_to_output is empty/None, return empty list + if not cols_to_output: + return [] + if cols_to_output.strip() != "*": + # If cols_to_output is a comma separated list of names, return list + # of column names in cols_to_output. + return split_quoted_delimited_str(cols_to_output) + if source_table and vector_col: + # If cols_to_output is *, and both + # source_table and vector_col are non-null values, return a list of + # all columns in source_table except the vector_col. + return [col for col in get_cols(source_table) if col != vector_col] + return [] + +def validate_args(source_table, out_table, vector_col, feature_names, + cols_to_output): + """ + Validate args for vec2cols + """ + input_tbl_valid(source_table, 'vec2cols') + output_tbl_valid(out_table, 'vec2cols') + cols_to_validate = get_cols_to_keep_list(cols_to_output) + [vector_col] + cols_in_tbl_valid(source_table, cols_to_validate, 'vec2cols') + # Check if vector_col is an array (not null) + _assert(is_valid_psql_type(get_expr_type(vector_col, source_table), ANY_ARRAY), + "vec2cols: vector_col should refer to an array.") + # Check if vector_col is a 1-dimensional array + _assert(is_col_1d_array(source_table, vector_col), + "vec2cols: vector_col must be a 1-dimensional array.") + +def get_names_for_split_output_cols(source_table, vector_col, feature_names): + """ + Get list of names for the newly-split columns to include in the + output table. + Args: + @param: source_table, str. Source table + @param: vector_col, str. Column name containing the array input + @param: feature_names, list. Python list of the feature names to + use for the split elements in the vector_col array + """ + query = """ --- End diff -- I'm assuming this was meant to use the `is_col_1d_array` function?
---