khannaekta commented on code in PR #581: URL: https://github.com/apache/madlib/pull/581#discussion_r897094694
########## src/ports/postgres/modules/mxgboost/madlib_xgboost.sql_in: ########## @@ -0,0 +1,714 @@ +/* ----------------------------------------------------------------------- *//** + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @file madlib_xgboost.sql_in + * @brief SQL functions for xgboost + * @date 05/12/2022 + * + */ + /* ----------------------------------------------------------------------- */ + + +/** +@addtogroup grp_xgboost + +@brief +This module allows you to use SQL to build gradient boosted tree +models designed in XGBoost [1]. + +<div class="toc"><b>Contents</b> +<ul> +<li><a href="#train">Training Function</a></li> +<li><a href="#predict">Prediction Function</a></li> +<li><a href="#examples">Examples</a></li> +</ul> +</div> + +XGBoost is an optimized distributed gradient boosting library. The MADlib +implementation is designed to train multiple models with different parameters +to achieve parallelism on Greenplum systems (PostgreSQL is also supported.) + +The main use case supported is classification. Regression is not currently +supported. 
+ +Please note that the XGBoost implementation has the following limitations: +- All of the dataset is collected in a single segment, compressed by zlib and +stored in a single cells (one column of one row). PostgreSQL has a 1 GB limit +for each cell which means the supported dataset size is limited. +- The collected data is passed to the XGBoost in a single segment. This means, +without grid search, the XGBoost is using only a single segment which might +increase the run time. + +XGBoost requires the following python libraries available in every host machine: +- pandas v0.42 +- scikit-learn v0.19 +- xgboost v0.82 + +Versions are provided to ensure a python2 supported libraries can be obtained. +Other versions of these libraries might work just as well. + +@anchor train +@par Training Function +<pre class="syntax"> +SELECT xgboost( + source_table, + output_table_name, + id_column, + dependent_variable, + list_of_features, + list_of_features_to_exclude, + params_str, + sample_weights, + train_set_size, + train_set_split_var +) +</pre> + +\b Arguments +<dl class="arglist"> + <dt>source_table</dt> + <dd>TEXT. Name of the table containing the training data.</dd> + + <dt>output_table</dt> + <dd>TEXT. Name of the generated table containing the model. + If a table with the same name already exists, an + error will be returned. A summary table + named <em>\<output_table_name\>_summary</em> is also + created. + </DD> + + <DT>id_column</DT> + <DD>TEXT. Name of the column containing id information in the training data. + This is a mandatory argument and the values are expected to be unique for each + row. + </DD> + + <DT>dependent_variable</DT> + <DD>TEXT. Name of the column that contains the output (response) for + training. Boolean, integer and text types are accepted for classification. + </DD> + + <DT>list_of_features</DT> + <DD>TEXT. Comma-separated string of column names or expressions to use as predictors. 
+ Can also be a '*' implying all columns are to be used as predictors (except for the + ones included in the next argument that lists exclusions). + + <DT>list_of_features_to_exclude (optional)</DT> + <DD>TEXT[]. Text array of column names to exclude from the predictors + list. The names in this parameter should be identical to the names used in the table and + quoted appropriately. The id_column and dependent_variable are excluded by default. + </DD> + + <DT>params_str (optional)</DT> + <DD>TEXT. Comma-separated string of key-value pairs used for initializing + XGBClassifier. Each parameter can be assigned an array of values. These values + are expanded into multiple parameter sets for grid search. + + Note that these key-value pairs are passed to the xgboost as is and xgboost + accepts any parameter thanks to kwargs. If there is a typo in the list, it + might get ignored by xgboost. + </DD> + + <DT>sample_weights (optional)</DT> + <DD>TEXT. Column name containing numerical weights for sample_weights parameter of + XGBClassifier fit function. + </DD> + + <DT>train_set_size (optional)</DT> + <DD>DOUBLE PRECISION. The proportion of the dataset to be used in training. + </DD> + + <DT>train_set_split_var (optional)</DT> + <DD>TEXT. Column name containing information on whether the associated row will be used for training or testing. If this parameter is set, train_set_size will be ignored. + </DD> +</DL> + +\b Output +<dl class="arglist"> +<DD> + The model table produced by the training function contains the following columns: + + <table class="output"> + <tr> + <th>model</th> + <td>BYTEA8. Trained XGBoost model stored in binary + format (not human readable).</td> + </tr> + <tr> + <th>features</th> + <td>TEXT[]. Ordered levels (values) of categorical variables + corresponding to the categorical features. Used to help + interpret the trained model.</td> + </tr> + <tr> + <th>params_index</th> + <td>INTEGER. The index of the model. 
During grid search, each parameter + set is given an index to distinguish different models. + </td> + </tr> + + </table> + + A summary table named <em>\<output_table\>_summary</em> is also created at + the same time, which has the following columns: + + <table class="output"> + <tr> + <th>mdl_train_ts</th> + <td>TEXT. The timestamp for the initiation of the xgboost command.</td> + </tr> + <tr> + <th>mdl_name</th> + <td>TEXT. The name of the model set to <source_table>_xgboost.</td> + </tr> + <tr> + <th>features</th> + <td>TEXT[]. The list of features inputed to the XGBoost.</td> + </tr> + <tr> + <th>params</th> + <td>TEXT. The XGBClassifier parameters used for this particular model.</td> + </tr> + <tr> + <th>fnames</th> + <td>TEXT. The list of features in sorted order. This column is used for importance</td> + </tr> + <tr> + <th>importance</th> + <td>DOUBLE PRECISION[]. Importance values for each feature in the sorted order.</td> + </tr> + <tr> + <th>recall</th> + <td>DOUBLE PRECISION[]. Recall values for each class.</td> + </tr> + <tr> + <th>fscore</th> + <td>DOUBLE PRECISION[]. fscore values for each class.</td> + </tr> + <tr> + <th>support</th> + <td>DOUBLE PRECISION[]. support values for each class.</td> + </tr> + <tr> + <th>test_ids</th> + <td>INTEGER[]. The ids of the rows used for this particular model.</td> + </tr> + <tr> + <th>params_index</th> + <td>INTEGER. The index of the parameter set used for this model.</td> + </tr> + </table> + + </DD> +</DL> + +@anchor predict +@par Prediction Function + +<pre class="syntax"> +SELECT xgboost_predict( + test_table, + model_table, + predict_output_table, + id_column, + class_label, + model_filters +); +</pre> + +\b Arguments +<dl class="arglist"> + <dt>test_table</dt> + <dd>TEXT. Name of the table containing the test data.</dd> + + <dt>model_table</dt> + <dd>TEXT. Name of the table containing the xgboost train output.</dd> + + <dt>predict_output_table</dt> + <dd>TEXT. 
Name of the generated table containing the prediction. + If a table with the same name already exists, an + error will be returned. A metrics table named + <em>\<output_table_name\>_metrics</em> and an roc curve table named + <em>\<output_table_name\>_roc_curve</em> are also created if class_label is + not NULL. + </DD> + + <DT>id_column</DT> + <DD>TEXT. Name of the column containing id information in the test data. + This is a mandatory argument and the values are expected to be unique for each + row. + </DD> + + <DT>class_label (optional)</DT> + <DD>TEXT. Name of the column containing class_label for the metrics and + roc_curve calculations. + </DD> + + <DT>model_filters (optional)</DT> + <DD>TEXT. The filter for the model_table in case the user wants to use a + subset of the models from the grid search. + </DD> +</DL> + +\b Output + +XGBoost prediction function creates three tables: <predict_output_table>, +<predict_output_table>_metrics and <predict_output_table>_roc_curve. +Some fields of the metrics table as well as the roc_curve table are available +only if a class_label column is provided. + +<predict_output_table> +<dl class="arglist"> +<DD> +<table class="output"> + <tr> + <th>id_column</th> + <td>INTEGER. Name of the column containing id information in the test data.</td> + </tr> + <tr> + <th>class_label_predicted</th> + <td>TEXT. The predicted value for the given data id.</td> + </tr> + <tr> + <th>class_label_proba_predicted</th> + <td>DOUBLE PRECISION[]. The prediction probabilities for the given data id.</td> + </tr> +</table> + +<predict_output_table>_metrics +<table class="output"> + <tr> + <th>precision</th> + <td>DOUBLE PRECISION[]. Precision values for each class in the sorted order.</td> + </tr> + <tr> + <th>recall</th> + <td>DOUBLE PRECISION[]. Recall values for each class.</td> + </tr> + <tr> + <th>fscore</th> + <td>DOUBLE PRECISION[]. fscore values for each class.</td> + </tr> + <tr> + <th>support</th> + <td>DOUBLE PRECISION[]. 
support values for each class.</td> + </tr> + <tr> + <th>roc_auc_scores</th> + <td>DOUBLE PRECISION[]. roc_auc scores for each class.</td> + </tr> + <tr> + <th>feature_names</th> + <td>TEXT[]. The list of features in sorted order. This column is used for importance</td> + </tr> + <tr> + <th>feature_importance_scores</th> + <td>DOUBLE PRECISION[]. Importance values for each feature in the sorted order.</td> + </tr> +</table> + +<predict_output_table>_roc_curve +<table class="output"> + <tr> + <th>fpr</th> + <td>DOUBLE PRECISION[]. False positive rate for the roc curve.</td> + </tr> + <tr> + <th>tpr</th> + <td>DOUBLE PRECISION[]. True positive rate for the roc curve.</td> + </tr> + <tr> + <th>thresholds</th> + <td>DOUBLE PRECISION[]. Threshold values for the roc curve.</td> + </tr> +</table> + +@anchor examples +@par Examples + +<a href="example/madlib_xgboost_example.sql">Download the example sql file here.</a> + +-# Show the input data set. +<pre class="example"> +SELECT * FROM abalone LIMIT 10; +</pre> +<pre class="result"> + id | sex | length | diameter | height | whole_weight | shucked_weight | viscera_weight | shell_weight | rings +----+-----+--------+----------+--------+--------------+----------------+----------------+--------------+------- + 1 | M | 0.455 | 0.365 | 0.095 | 0.514 | 0.2245 | 0.101 | 0.15 | 15 + 12 | M | 0.43 | 0.35 | 0.11 | 0.406 | 0.1675 | 0.081 | 0.135 | 10 + 15 | F | 0.47 | 0.355 | 0.1 | 0.4755 | 0.1675 | 0.0805 | 0.185 | 10 + 20 | M | 0.45 | 0.32 | 0.1 | 0.381 | 0.1705 | 0.075 | 0.115 | 9 + 23 | F | 0.565 | 0.44 | 0.155 | 0.9395 | 0.4275 | 0.214 | 0.27 | 12 + 26 | F | 0.56 | 0.44 | 0.14 | 0.9285 | 0.3825 | 0.188 | 0.3 | 11 + 30 | M | 0.575 | 0.425 | 0.14 | 0.8635 | 0.393 | 0.227 | 0.2 | 11 + 31 | M | 0.58 | 0.47 | 0.165 | 0.9975 | 0.3935 | 0.242 | 0.33 | 10 + 35 | F | 0.705 | 0.55 | 0.2 | 1.7095 | 0.633 | 0.4115 | 0.49 | 13 + 36 | M | 0.465 | 0.355 | 0.105 | 0.4795 | 0.227 | 0.124 | 0.125 | 8 +</pre> + +-# Run XGBoost for a single parameter 
set and show the summary table +<pre class="example"> +DROP TABLE IF EXISTS xgb_out, xgb_out_summary; +SELECT xgboost( + 'abalone', -- Training table + 'id', -- Id column + 'sex', -- Class label column + '*', -- Independent variables + NULL, -- Columns to exclude from features + $$ + { + 'learning_rate': [0.01], #Regularization on weights (eta). For smaller values, increase n_estimators + 'max_depth': [9],#Larger values could lead to overfitting + 'subsample': [0.85],#introduce randomness in samples picked to prevent overfitting + 'colsample_bytree': [0.85],#introduce randomness in features picked to prevent overfitting + 'min_child_weight': [10],#larger values will prevent over-fitting + 'n_estimators':[100] #More estimators, lesser variance (better fit on test set) + } + $$, -- XGBoost grid search parameters + 'xgb_out', -- Grid search results table. + '', -- Class weights + 0.8, -- Training set size ratio + NULL -- Variable used to do the test/train split. +); +\\x on +SELECT * FROM xgb_out_summary; +</pre> +<pre class="result"> +-[ RECORD 1 ]+--------------------------------------------- +mdl_train_ts | 2022-05-16 17:42:45.033789+03 +mdl_name | abalone_xgboost +features | {length,diameter,height,whole_weight,shucked_weight,viscera_weight,shell_weight,rings} +params | ('colsample_bytree=0.85', 'learning_rate=0.01', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=9') +fnames | {viscera_weight,whole_weight,shucked_weight,shell_weight,length,rings,diameter,height} +importance | {1206,1189,1184,867,766,622,589,526} +precision | {0.4295774647887324,0.6858006042296072,0.4318181818181818} +recall | {0.46387832699619774,0.8021201413427562,0.328719723183391} +fscore | {0.4460694698354662,0.739413680781759,0.3732809430255403} +support | {263.0,283.0,289.0} +test_ids | {486,3627,432,2766,132,2397,3313,2346...} +</pre> + +-# Run XGBoost prediction. For this example we are using the same abalone table for prediction as well. 
+<pre class="example"> +DROP TABLE IF EXISTS xgb_score_out, xgb_score_out_metrics, xgb_score_out_roc_curve; +SELECT xgboost_predict( + 'abalone', -- test_table + 'xgb_out', -- model_table + 'xgb_score_out', -- predict_output_table + 'id', -- id_column + 'sex' -- class_label +); +\\x off +SELECT * FROM xgb_score_out LIMIT 10; +</pre> +<pre class="result"> + id | sex_predicted | sex_proba_predicted +----+---------------+------------------------------------------------ + 5 | I | {0.168330997229,0.632858633995,0.198810324073} + 6 | I | {0.202547758818,0.574552714825,0.222899526358} + 9 | I | {0.255484640598,0.44379144907,0.300723910332} + 10 | M | {0.347418963909,0.242429286242,0.410151779652} + 11 | F | {0.4157371521,0.316571623087,0.267691165209} + 13 | F | {0.338543832302,0.32665386796,0.334802359343} + 14 | F | {0.400314897299,0.293721526861,0.305963635445} + 17 | I | {0.175603896379,0.608917593956,0.215478509665} + 21 | M | {0.280931055546,0.333337903023,0.385731071234} + 25 | F | {0.498989373446,0.185877665877,0.315133005381} +</pre> + +-# Show roc curve and metrics tables +<pre class="example"> +SELECT * FROM xgb_score_out_roc_curve limit 10;</pre> +<pre class="result"> + fpr | tpr | thresholds +---------------------------+---------------------------+--------------------------- + {0.00000,0.00000,0.00000} | {0.00077,0.00075,0.00065} | {0.54791,0.65354,0.49352} + {0.00000,0.00000,0.00000} | {0.00077,0.00075,0.00065} | {0.54791,0.65354,0.49352} + {0.00000,0.00000,0.00000} | {0.00077,0.00075,0.00065} | {0.54791,0.65354,0.49352} + {0.00000,0.00000,0.00000} | {0.00077,0.00075,0.00065} | {0.54791,0.65354,0.49352} + {0.00000,0.00000,0.00000} | {0.00077,0.00075,0.00065} | {0.54791,0.65354,0.49352} + {0.00000,0.00000,0.00000} | {0.00077,0.00075,0.00065} | {0.54791,0.65354,0.49352} + {0.00000,0.00000,0.00000} | {0.00077,0.00075,0.00065} | {0.54791,0.65354,0.49352} + {0.00000,0.00000,0.00000} | {0.00077,0.00075,0.00065} | {0.54791,0.65354,0.49352} + 
{0.00000,0.00000,0.00000} | {0.00077,0.00075,0.00065} | {0.54791,0.65354,0.49352} + {0.00000,0.00000,0.00000} | {0.00077,0.00075,0.00065} | {0.54791,0.65354,0.49352} +</pre> +<pre class="example"> +\\x on +SELECT * FROM xgb_score_out_metrics; +</pre> +<pre class="result"> +-[ RECORD 1 ]-------------+--------------------------------------------------------------------------------------- +precision | {0.549047282992237,0.70062893081761,0.598290598290598} +recall | {0.595256312165264,0.830104321907601,0.458115183246073} +fscore | {0.571218795888399,0.759890859481583,0.518902891030393} +support | {1307,1342,1528} +roc_auc_scores | {0.77659,0.9091,0.74816} +feature_names | {viscera_weight,whole_weight,shucked_weight,shell_weight,length,rings,diameter,height} +feature_importance_scores | {1206,1189,1184,867,766,622,589,526} +</pre> + +-# Run XGBoost grid search with parameter options +<pre class="example"> +DROP TABLE IF EXISTS xgb_out, xgb_out_summary; +SELECT xgboost( + 'abalone', -- Training table + 'id', -- Id column + 'sex', -- Class label column + '*', -- Independent variables + NULL, -- Columns to exclude from features + $$ + { + 'learning_rate': [0.01,0.1], #Regularization on weights (eta). For smaller values, increase n_estimators + 'max_depth': [9,12],#Larger values could lead to overfitting + 'subsample': [0.85],#introduce randomness in samples picked to prevent overfitting + 'colsample_bytree': [0.85],#introduce randomness in features picked to prevent overfitting + 'min_child_weight': [10],#larger values will prevent over-fitting + 'n_estimators':[100] #More estimators, lesser variance (better fit on test set) + } + $$, -- XGBoost grid search parameters + 'xgb_out', -- Grid search results table. + '', -- Class weights + 0.8, -- Training set size ratio + NULL -- Variable used to do the test/train split. 
+); +\\x on +SELECT * FROM xgb_out_summary; +</pre> +<pre class="result"> +-[ RECORD 1 ]+--------------------------------------------- +mdl_train_ts | 2022-05-16 17:56:11.488767+03 +mdl_name | abalone_xgboost +features | {length,diameter,height,whole_weight,shucked_weight,viscera_weight,shell_weight,rings} +params | ('colsample_bytree=0.85', 'learning_rate=0.01', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=12') +fnames | {viscera_weight,whole_weight,shucked_weight,shell_weight,length,rings,diameter,height} +importance | {1276,1205,1200,906,864,822,580,415} +precision | {0.4090909090909091,0.752411575562701,0.5} +recall | {0.45188284518828453,0.8153310104529616,0.42071197411003236} +fscore | {0.42942345924453285,0.782608695652174,0.45694200351493847} +support | {239.0,287.0,309.0} +test_ids | {3975,35,1759,1469,3437,3951...} +-[ RECORD 2 ]+--------------------------------------------- +mdl_train_ts | 2022-05-16 17:56:11.488767+03 +mdl_name | abalone_xgboost +features | {length,diameter,height,whole_weight,shucked_weight,viscera_weight,shell_weight,rings} +params | ('colsample_bytree=0.85', 'learning_rate=0.01', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=9') +fnames | {viscera_weight,whole_weight,shucked_weight,shell_weight,length,rings,diameter,height} +importance | {1268,1196,1182,860,832,668,595,461} +precision | {0.46096654275092935,0.6566265060240963,0.5213675213675214} +recall | {0.4901185770750988,0.8288973384030418,0.3824451410658307} +fscore | {0.475095785440613,0.7327731092436974,0.4412296564195299} +support | {253.0,263.0,319.0} +test_ids | {2988,2303,2034,3085,2465,2887...} +-[ RECORD 3 ]+--------------------------------------------- +mdl_train_ts | 2022-05-16 17:56:11.488767+03 +mdl_name | abalone_xgboost +features | {length,diameter,height,whole_weight,shucked_weight,viscera_weight,shell_weight,rings} +params | ('colsample_bytree=0.85', 'learning_rate=0.1', 'min_child_weight=10', 
'n_estimators=100', 'subsample=0.85', 'max_depth=9') +fnames | {shucked_weight,whole_weight,viscera_weight,shell_weight,length,diameter,height,rings} +importance | {998,948,862,765,629,489,441,383} +precision | {0.4635036496350365,0.6986301369863014,0.45724907063197023} +recall | {0.4847328244274809,0.75,0.40863787375415284} +fscore | {0.47388059701492535,0.7234042553191489,0.43157894736842106} +support | {262.0,272.0,301.0} +test_ids | {396,2150,809,2846,1108,1841...} +-[ RECORD 3 ]+--------------------------------------------- +mdl_train_ts | 2022-05-16 17:56:11.488767+03 +mdl_name | abalone_xgboost +features | {length,diameter,height,whole_weight,shucked_weight,viscera_weight,shell_weight,rings} +params | ('colsample_bytree=0.85', 'learning_rate=0.1', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=12') +fnames | {shucked_weight,viscera_weight,whole_weight,shell_weight,length,diameter,height,rings} +importance | {1101,1056,1045,958,680,621,458,458} +precision | {0.40484429065743943,0.7016949152542373,0.43824701195219123} +recall | {0.4517374517374517,0.75,0.36666666666666664} +fscore | {0.42700729927007297,0.7250437828371278,0.3992740471869329} +support | {259.0,276.0,300.0} +test_ids | {1740,2777,1907,581,3525,1022...} +</pre> + +-# Run XGBoost prediction using params_index=2 +<pre class="example"> +DROP TABLE IF EXISTS xgb_score_out, xgb_score_out_metrics, xgb_score_out_roc_curve; +SELECT xgboost_predict( + 'abalone', -- test_table + 'xgb_out', -- model_table + 'xgb_score_out', -- predict_output_table + 'id', -- id_column + 'sex', -- class_label + 2 -- params_index +); +\\x off +SELECT * FROM xgb_score_out LIMIT 10; +</pre> +<pre class="result"> + id | sex_predicted | sex_proba_predicted +----+---------------+------------------------------------------------ + 5 | I | {0.178420484066,0.599636971951,0.221942588687} + 6 | I | {0.185853347182,0.602131128311,0.212015494704} + 9 | I | {0.253592431545,0.440728843212,0.30567869544} + 10 | M | 
{0.374555230141,0.249226689339,0.376218110323} + 11 | F | {0.402999937534,0.336779236794,0.260220855474} + 13 | I | {0.338803291321,0.36541134119,0.295785397291} + 14 | F | {0.377499818802,0.301990658045,0.320509523153} + 17 | I | {0.179169252515,0.602536559105,0.218294218183} + 21 | M | {0.27216938138,0.33275142312,0.395079195499} + 25 | F | {0.449239164591,0.187060594559,0.36370024085} +</pre> + + + + +@anchor literature +@literature + +[1] Chen, T., & Guestrin, C. (2016). XGBoost: A Scalable Tree Boosting System. +In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge +Discovery and Data Mining (pp. 785–794). New York, NY, USA: ACM. +https://doi.org/10.1145/2939672.2939785 + +*/ + +m4_include(`SQLCommon.m4') + +DROP FUNCTION IF EXISTS MADLIB_SCHEMA.__serialize_pandas_dframe_as_bytea__( + TEXT, + TEXT, + TEXT, + TEXT[] +); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.__serialize_pandas_dframe_as_bytea__( + source_table TEXT, + id_column TEXT, + class_label TEXT, + features TEXT +) +RETURNS BYTEA +AS +$$ + PythonFunctionBodyOnly(mxgboost, madlib_xgboost) + with AOControl(False): + return madlib_xgboost.serialize_pandas_dframe_as_bytea(schema_madlib, source_table, id_column, class_label, features) +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + + + +DROP TYPE IF EXISTS MADLIB_SCHEMA.xgb_gridsearch_train_results_type CASCADE; +CREATE TYPE MADLIB_SCHEMA.xgb_gridsearch_train_results_type +AS +( + features TEXT[], + mdl BYTEA, + params TEXT, + fnames TEXT[], + importance TEXT[], + precision TEXT[], + recall TEXT[], + fscore TEXT[], + support TEXT[], + test_ids INTEGER[] +); + +DROP FUNCTION IF EXISTS MADLIB_SCHEMA.__xgboost_train_parallel__( + BYTEA, + TEXT[], + TEXT, + TEXT, + TEXT, + NUMERIC +); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.__xgboost_train_parallel__( + dframe BYTEA, + features_all TEXT[], + class_label TEXT, + params TEXT, + class_weights TEXT, + train_set_size NUMERIC, + 
id_column TEXT, + train_set_split_var TEXT +) +RETURNS MADLIB_SCHEMA.xgb_gridsearch_train_results_type +AS +$$ + PythonFunctionBodyOnly(mxgboost, madlib_xgboost) + with AOControl(False): + return madlib_xgboost.xgboost_train(schema_madlib, dframe, features_all, class_label, params, class_weights, train_set_size, id_column, train_set_split_var) +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + +DROP FUNCTION IF EXISTS MADLIB_SCHEMA.xgboost( + TEXT, + TEXT, + TEXT, + TEXT[], + TEXT, + TEXT, + TEXT, + TEXT, + NUMERIC, + TEXT +); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.xgboost(

Review Comment:
Since the params after `list_of_features` are optional, we should probably support the following call as well (AFAIK we do this for the other modules):

```
SELECT xgboost(
    'abalone',  -- Training table
    'id',       -- Id column
    'sex',      -- Class label column
    'rings');   -- Independent variables
```
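
One way to allow the short call form requested above is a thin SQL overload that forwards to the full function and passes NULL for every optional argument. The sketch below is plain PostgreSQL and uses a hypothetical `xgboost_demo` stand-in, because the full `xgboost` signature and return type are cut off in the quoted hunk. The argument order follows the example calls in the diff and is an assumption, as are the parameter names and the `TEXT` return type.

```sql
-- Hypothetical stand-in for the full 10-argument entry point.
-- The real MADLIB_SCHEMA.xgboost signature is not visible in the hunk,
-- so the names, types and return type here are assumptions.
CREATE OR REPLACE FUNCTION xgboost_demo(
    source_table        TEXT,
    id_column           TEXT,
    class_label         TEXT,
    list_of_features    TEXT,
    features_to_exclude TEXT,
    params_str          TEXT,
    output_table        TEXT,
    class_weights       TEXT,
    train_set_size      NUMERIC,
    train_set_split_var TEXT
) RETURNS TEXT AS $$
    -- Placeholder body: only reports which table and features were passed.
    SELECT format('train on %s using features %s', $1, $4);
$$ LANGUAGE sql;

-- Short overload: only the arguments up to list_of_features are required.
-- Every optional argument is forwarded as a typed NULL so the call
-- resolves unambiguously to the 10-argument version.
CREATE OR REPLACE FUNCTION xgboost_demo(
    source_table     TEXT,
    id_column        TEXT,
    class_label      TEXT,
    list_of_features TEXT
) RETURNS TEXT AS $$
    SELECT xgboost_demo($1, $2, $3, $4,
                        NULL::TEXT, NULL::TEXT, NULL::TEXT, NULL::TEXT,
                        NULL::NUMERIC, NULL::TEXT);
$$ LANGUAGE sql;

-- The short form from the review comment now resolves to the overload.
SELECT xgboost_demo('abalone', 'id', 'sex', 'rings');
```

With this pattern the full function has to treat NULL as "use the default" for each optional argument. How the output table name would be supplied in the four-argument form is left open here, since the suggested call does not include one.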