This is an automated email from the ASF dual-hosted git repository. njayaram pushed a commit to branch bugfix-svm-class-weights in repository https://gitbox.apache.org/repos/asf/madlib.git
commit dec12cd0d034372553545dd3f40bb9fb18bbbbfc Author: Nandish Jayaram <[email protected]> AuthorDate: Fri May 31 13:41:05 2019 -0700 SVM: Fix class weights when specified as a mapping JIRA: MADLIB-1346 Providing a mapping for the class_weight param was resulting in an error since the dependent var's levels were not quoted when used in an internal query. This PR now quotes the values. Additionally, this PR adds the following new error checks for the class_weight param: 1. Permitted values are `balanced`, a mapping of the form `{key:value}`, or an empty string. 2. Keys in a mapping must be legitimate class labels in the source table. 3. Values in a mapping must be numbers. 4. A mapping must contain at least one and at most two key-value pairs. This PR still does not address the issue of having special characters in class levels used in a mapping. MADLIB-1354 was created to track that issue. Closes #403 --- src/ports/postgres/modules/svm/svm.py_in | 59 ++++-- src/ports/postgres/modules/svm/svm.sql_in | 35 ++-- src/ports/postgres/modules/svm/test/svm.sql_in | 273 +++++++++++++++---------- 3 files changed, 225 insertions(+), 142 deletions(-) diff --git a/src/ports/postgres/modules/svm/svm.py_in b/src/ports/postgres/modules/svm/svm.py_in index d2d22c4..2890ae1 100644 --- a/src/ports/postgres/modules/svm/svm.py_in +++ b/src/ports/postgres/modules/svm/svm.py_in @@ -8,19 +8,19 @@ from kernel_approximation import create_kernel, load_kernel from utilities.control import MinWarning from utilities.in_mem_group_control import GroupIterationController +from utilities.utilities import _assert +from utilities.utilities import _string_to_array +from utilities.utilities import _string_to_array_with_quotes +from utilities.utilities import add_postfix from utilities.utilities import extract_keyvalue_params +from utilities.utilities import get_grouping_col_str +from utilities.utilities import num_features, num_samples from utilities.utilities import preprocess_keyvalue_params -from 
utilities.utilities import add_postfix -from utilities.utilities import _string_to_array_with_quotes -from utilities.utilities import _string_to_array -from utilities.utilities import _assert from utilities.utilities import unique_string -from utilities.utilities import num_features, num_samples from utilities.validate_args import cols_in_tbl_valid from utilities.validate_args import explicit_bool_to_text from utilities.validate_args import get_expr_type -from utilities.utilities import get_grouping_col_str from utilities.validate_args import input_tbl_valid from utilities.validate_args import is_var_valid from utilities.validate_args import output_tbl_valid @@ -945,6 +945,7 @@ def _compute_class_weight_sql(source_table, dependent_varname, return "1" dep_to_weight = defaultdict(float) + class_weight_str = class_weight_str.strip() if class_weight_str == "balanced": # use half of n_samples since only doing binary classification # Change the '2' to n_classes for multinomial @@ -956,21 +957,55 @@ def _compute_class_weight_sql(source_table, dependent_varname, src=source_table)) for each_count in bin_count: dep_to_weight[each_count['k']] = n_samples_per_class / each_count['v'] - else: - class_weight_splits = preprocess_keyvalue_params(class_weight_str, - split_char=':') + elif _is_class_weights_str_a_mapping(class_weight_str): + # preprocess_keyvalue_params() does not seem to handle special + # chars as expected. TODO: Fix it in MADLIB-1354. + class_weight_splits = preprocess_keyvalue_params( + class_weight_str, split_char=':') + + _assert(class_weight_splits and len(class_weight_splits)<=2, + "SVM: Only binary classification is supported. 
The " + "class_weight param should have at least one and at most " + "two labels in it.") + # Cast the distinct class values' array to a text array since a + # numeric class will show up with suffix 'L' sometimes, and that + # may cause issues when we try to check if a class level specified + # in class_weight (a string) exists in the distinct class levels + # or not. + distinct_class_levels = plpy.execute(""" + SELECT array_agg(DISTINCT({0}))::TEXT[] AS labels + FROM {1} + """.format(dependent_varname, source_table))[0]['labels'] for each_pair in class_weight_splits: k, v = each_pair.split(":") - dep_to_weight[k.strip()] = float(v.strip()) + _assert(k in distinct_class_levels, + "SVM: Key '{0}' in '{1}' is not a valid class label.". + format(k, class_weight_str)) + try: + dep_to_weight[k.strip()] = float(v.strip()) + except ValueError: + plpy.error("SVM: Weights for a class label must be numeric." + " Invalid class_weights param ({0})".format( + class_weight_str)) + else: + plpy.error("SVM: Invalid class_weight param ({0})".format( + class_weight_str)) class_weight_sql = "CASE " for k, v in dep_to_weight.items(): - class_weight_sql += ("WHEN {dep} = {k} THEN {v}::FLOAT8 \n". - format(dep=dependent_varname, k=k, v=v)) + class_weight_sql += ("WHEN {dep}=$madlib${k}$madlib$ THEN {v}::FLOAT8 \n". 
+ format(dep=dependent_varname, k=k, v=v)) class_weight_sql += "ELSE 1.0 END" return class_weight_sql # ------------------------------------------------------------------------- +def _is_class_weights_str_a_mapping(class_weight_str): + """ + Check if the class_weight_str begins with a '{' and ends with a '}' + """ + return len(class_weight_str)>2 and class_weight_str[0]=='{' and \ + class_weight_str[-1]=='}' + def _svm_parsed_params(schema_madlib, source_table, model_table, dependent_varname, independent_varname, diff --git a/src/ports/postgres/modules/svm/svm.sql_in b/src/ports/postgres/modules/svm/svm.sql_in index 2320179..cb6b69e 100644 --- a/src/ports/postgres/modules/svm/svm.sql_in +++ b/src/ports/postgres/modules/svm/svm.sql_in @@ -45,7 +45,7 @@ a hyperplane or other nonlinear decision boundary. @anchor svm_classification @par Classification Training Function -The SVM classification training function has the following format: +The SVM binary classification training function has the following format: <pre class="syntax"> svm_classification( source_table, @@ -70,9 +70,9 @@ svm_classification( </DD> <DT>dependent_varname</DT> - <DD> TEXT. Name of the dependent variable column. For classification, this column - can contain values of any type, but must assume exactly two distinct values. - Otherwise, an error will be thrown. + <DD> TEXT. Name of the dependent variable column. For classification, this + column can contain values of any type, but must assume exactly two distinct + values since only binary classification is currently supported. </DD> <DT>independent_varname</DT> @@ -480,19 +480,20 @@ while the other k - 1 folds form the training set. </DD> <DT>class_weight</dt> -<DD>Default: 1 for classification, 'balanced' for one-class novelty detection, -n/a for regression. - -Set the weight for the positive and negative classes. If not given, all classes -are set to have weight one. 
-If class_weight = balanced, values of y are automatically adjusted as inversely -proportional to class frequencies in the input data i.e. the weights are set as -n_samples / (n_classes * bincount(y)). - -Alternatively, class_weight can be a mapping, giving the weight for each class. -Eg. For dependent variable values 'a' and 'b', the class_weight can be -{a: 2, b: 3}. This would lead to each 'a' tuple's y value multiplied by 2 and -each 'b' y value will be multiplied by 3. +<DD>Default: NULL for classification, 'balanced' for one-class novelty detection. +This param is not applicable for regression. + +Set the weight for the classes. If not given (empty/NULL), all classes are set to have +equal weight. If 'class_weight = balanced', the class weights are automatically set +to be inversely proportional to class frequencies in the input data, i.e., the weights +are set as n_samples / (2 * bincount(y)). + +Alternatively, 'class_weight' can be a mapping, giving the weight for each class. +E.g., for dependent variable values 'a' and 'b', the 'class_weight' might be +{a: 1, b: 3}. This gives three times the weight to observations with class value +'b' compared to 'a'. (In the SVM algorithm, this translates into observations +with class value 'b' contributing 3x to learning in the stochastic gradient step +compared to 'a'.) For regression, the class weights are always one.
</DD> diff --git a/src/ports/postgres/modules/svm/test/svm.sql_in b/src/ports/postgres/modules/svm/test/svm.sql_in index 217b9a0..9dea0bb 100644 --- a/src/ports/postgres/modules/svm/test/svm.sql_in +++ b/src/ports/postgres/modules/svm/test/svm.sql_in @@ -684,122 +684,122 @@ CREATE TABLE svm_unbalanced ( index bigint, x1 double precision, x2 double precision, - y bigint + y bigint, + y_text text ); -COPY svm_unbalanced (index, x1, x2, y) FROM stdin delimiter '|'; -0|2.43651804549486251|-0.917634620475113127|0 -1|-0.792257628395183544|-1.60945293323425576|0 -2|1.29811144398701783|-3.45230804532042423|0 -3|2.61721764632472009|-1.14181035134265407|0 -4|0.478558644085647855|-0.374055563216115106|0 -5|2.19316190556746093|-3.09021106424648107|0 -6|-0.483625806020261229|-0.576081532002623464|0 -7|1.70065416350315601|-1.64983690097104629|0 -8|-0.258642311325653629|-1.31678762688205753|0 -9|0.0633206200733892471|0.87422282057373335|0 -10|-1.65092876581938186|1.7170855647594212|0 -11|1.35238608088919321|0.753741508352802292|0 -12|1.35128392389661767|-1.02559178876149959|0 -13|-0.184335338277972272|-1.40365415138860317|0 -14|-0.40183211943902386|0.795533200107279015|0 -15|-1.03749112758796347|-0.595130290283966024|0 -16|-1.03075905017939906|-1.26780846224807942|0 -17|-1.00686919625522853|-0.0189968983783520423|0 -18|-1.67596552295291668|0.351623546725638225|0 -19|2.48970326566480571|1.11306624086600348|0 -20|-0.287753328542422415|-1.3314434461272544|0 -21|-1.12073744062625646|2.53868190154161999|0 -22|0.0762116321640434469|-0.955493469854030053|0 -23|0.286373227001199049|3.15038270471826332|0 -24|0.180238428722443722|0.925804664561128865|0 -25|0.450255479933741265|-0.528374769740277972|0 -26|-1.71377729703321036|-0.524014083619316229|0 -27|-0.313341350062167179|0.879934786773296507|0 -28|1.25847512081175728|1.39665312195533597|0 -29|0.428380987881388176|1.32771174640609213|0 -30|-1.1315969114949791|1.87930223284993181|0 -31|0.769394730627013246|-0.447139252654073505|0 
-32|0.73277721980624555|-0.113357569531583588|0 -33|1.69744408117714052|2.27972522463329819|0 -34|3.27836310979974233|-2.09474450323220651|0 -35|-2.16617070814438417|-0.756698794419676801|0 -36|0.240055604171745707|1.31425338167433736|0 -37|0.473452420862407852|-3.03330182373600454|0 -38|-0.459306018942557737|1.24196196391086922|0 -39|0.345142103046575111|1.14301677046803718|0 -40|-0.333492213915538904|-0.301137103394996164|0 -41|0.279842086482426478|0.615077470812384508|0 -42|0.297449580190154605|0.178512968711188214|0 -43|-1.00599342943354575|0.56634567948137915|0 -44|0.182731906487155399|1.6942258618678796|0 -45|1.7983768198522605|0.277734626225915771|0 -46|-0.562927425135171244|-0.958095611181333573|0 -47|0.635241531096169321|0.116010102522839123|0 -48|-0.515780513356613346|0.065395285251370408|0 -49|-0.930001265922193898|1.04704805110832844|0 -50|-0.670692847178997353|1.8367615572082483|0 -51|0.605237462686200045|0.890367784855600419|0 -52|-1.64236776861156275|0.254073649588002159|0 -53|1.11083467664441216|-1.43055090271190188|0 -54|-0.399327759005433103|0.0489218200400378389|0 -55|-2.05967598037013344|0.472739088063437674|0 -56|1.2692409713775501|-1.28927391124797941|0 -57|0.525818967996161013|-1.96842511685614774|0 -58|-0.0580432638990766719|-2.42365853205494197|0 -59|1.682126562353496|0.613350806905241686|0 -60|-0.0369254338136675297|-1.16274242875373934|0 -61|1.91063389523816496|2.95065262388210225|0 -62|-2.78697279667012809|1.85424604567923046|0 -63|2.44147612972335981|0.507017544861713687|0 -64|-1.79890204850277913|1.29501797631603233|0 -65|-0.271380453117225695|-0.905880941689885866|0 -66|-1.84508720350044264|0.825806243964323006|0 -67|1.18921029887902163|-0.935296094519687427|0 -68|0.78086450561005627|-1.71651208443471415|0 -69|1.20279154780701703|0.0698509476362183107|0 -70|-0.279854657861023148|-0.152618808793717808|0 -71|1.30332923550880198|1.12561745979751215|0 -72|0.794197986529063815|0.206551814996079108|0 
-73|0.116731691869058879|0.927570392997786763|0 -74|0.348741838768106827|1.02382711029672779|0 -75|-0.465175160277089994|-3.65225664616070844|0 -76|1.55823690278912119|3.28046947046138637|0 -77|0.662046665352873154|-0.150232849925249656|0 -78|-0.204667115844049508|-0.178581281662214819|0 -79|0.0261141124500068982|-1.68302809312033252|0 -80|-0.775641686880341852|-1.49554024147539444|0 -81|0.373198742081655654|-0.444961728556294012|0 -82|0.742816985966940679|-0.26205473961375142|0 -83|1.47950278173186289|0.320300852003162662|0 -84|3.28604959345460035|-2.8445413843366385|0 -85|-0.970375032382362002|1.35223033747306642|0 -86|3.79248856020959701|-0.37295216657319008|0 -87|0.0655034897675836614|-0.339471363770407764|0 -88|1.9971856688813876|-0.430961795214028331|0 -89|1.02010475981715665|-0.479702398348006764|0 -90|-1.9088381328689914|0.470321580695148234|0 -91|0.754777220152989092|1.93983882379839256|0 -92|-0.165670539625974472|-0.926043095568541363|0 -93|0.844141644928539492|0.361105638356598369|0 -94|0.420997615683958493|-0.109669055620916639|0 -95|1.7405078549906543|0.554239074563585565|0 -96|2.85698806251147186|1.66658504784075689|0 -97|0.988574694150315292|-2.44115751092438593|0 -98|0.903478920443443689|0.630423305470589446|0 -99|1.2164275092053336|1.56666314206088808|0 -100|1.79956090410553671|2.41200280922520394|1 -101|1.71884728449045499|2.97743903750451722|1 -102|1.33402416674137592|1.1196557198006083|1 -103|1.1746393670879498|1.55472220791847571|1 -104|1.4404423007201359|2.97803945185182073|1 -105|1.83675025096090794|1.32866210531128193|1 -106|2.55719148838989607|1.7067380305892037|1 -107|1.38157331172930142|2.43791946382464975|1 -108|2.31168108828901619|1.7825216585223862|1 -109|2.70377000012061419|2.06455078985536256|1 +COPY svm_unbalanced (index, x1, x2, y, y_text) FROM stdin delimiter '|'; +0|2.43651804549486251|-0.917634620475113127|0|zero +1|-0.792257628395183544|-1.60945293323425576|0|zero +2|1.29811144398701783|-3.45230804532042423|0|zero 
+3|2.61721764632472009|-1.14181035134265407|0|zero +4|0.478558644085647855|-0.374055563216115106|0|zero +5|2.19316190556746093|-3.09021106424648107|0|zero +6|-0.483625806020261229|-0.576081532002623464|0|zero +7|1.70065416350315601|-1.64983690097104629|0|zero +8|-0.258642311325653629|-1.31678762688205753|0|zero +9|0.0633206200733892471|0.87422282057373335|0|zero +10|-1.65092876581938186|1.7170855647594212|0|zero +11|1.35238608088919321|0.753741508352802292|0|zero +12|1.35128392389661767|-1.02559178876149959|0|zero +13|-0.184335338277972272|-1.40365415138860317|0|zero +14|-0.40183211943902386|0.795533200107279015|0|zero +15|-1.03749112758796347|-0.595130290283966024|0|zero +16|-1.03075905017939906|-1.26780846224807942|0|zero +17|-1.00686919625522853|-0.0189968983783520423|0|zero +18|-1.67596552295291668|0.351623546725638225|0|zero +19|2.48970326566480571|1.11306624086600348|0|zero +20|-0.287753328542422415|-1.3314434461272544|0|zero +21|-1.12073744062625646|2.53868190154161999|0|zero +22|0.0762116321640434469|-0.955493469854030053|0|zero +23|0.286373227001199049|3.15038270471826332|0|zero +24|0.180238428722443722|0.925804664561128865|0|zero +25|0.450255479933741265|-0.528374769740277972|0|zero +26|-1.71377729703321036|-0.524014083619316229|0|zero +27|-0.313341350062167179|0.879934786773296507|0|zero +28|1.25847512081175728|1.39665312195533597|0|zero +29|0.428380987881388176|1.32771174640609213|0|zero +30|-1.1315969114949791|1.87930223284993181|0|zero +31|0.769394730627013246|-0.447139252654073505|0|zero +32|0.73277721980624555|-0.113357569531583588|0|zero +33|1.69744408117714052|2.27972522463329819|0|zero +34|3.27836310979974233|-2.09474450323220651|0|zero +35|-2.16617070814438417|-0.756698794419676801|0|zero +36|0.240055604171745707|1.31425338167433736|0|zero +37|0.473452420862407852|-3.03330182373600454|0|zero +38|-0.459306018942557737|1.24196196391086922|0|zero +39|0.345142103046575111|1.14301677046803718|0|zero 
+40|-0.333492213915538904|-0.301137103394996164|0|zero +41|0.279842086482426478|0.615077470812384508|0|zero +42|0.297449580190154605|0.178512968711188214|0|zero +43|-1.00599342943354575|0.56634567948137915|0|zero +44|0.182731906487155399|1.6942258618678796|0|zero +45|1.7983768198522605|0.277734626225915771|0|zero +46|-0.562927425135171244|-0.958095611181333573|0|zero +47|0.635241531096169321|0.116010102522839123|0|zero +48|-0.515780513356613346|0.065395285251370408|0|zero +49|-0.930001265922193898|1.04704805110832844|0|zero +50|-0.670692847178997353|1.8367615572082483|0|zero +51|0.605237462686200045|0.890367784855600419|0|zero +52|-1.64236776861156275|0.254073649588002159|0|zero +53|1.11083467664441216|-1.43055090271190188|0|zero +54|-0.399327759005433103|0.0489218200400378389|0|zero +55|-2.05967598037013344|0.472739088063437674|0|zero +56|1.2692409713775501|-1.28927391124797941|0|zero +57|0.525818967996161013|-1.96842511685614774|0|zero +58|-0.0580432638990766719|-2.42365853205494197|0|zero +59|1.682126562353496|0.613350806905241686|0|zero +60|-0.0369254338136675297|-1.16274242875373934|0|zero +61|1.91063389523816496|2.95065262388210225|0|zero +62|-2.78697279667012809|1.85424604567923046|0|zero +63|2.44147612972335981|0.507017544861713687|0|zero +64|-1.79890204850277913|1.29501797631603233|0|zero +65|-0.271380453117225695|-0.905880941689885866|0|zero +66|-1.84508720350044264|0.825806243964323006|0|zero +67|1.18921029887902163|-0.935296094519687427|0|zero +68|0.78086450561005627|-1.71651208443471415|0|zero +69|1.20279154780701703|0.0698509476362183107|0|zero +70|-0.279854657861023148|-0.152618808793717808|0|zero +71|1.30332923550880198|1.12561745979751215|0|zero +72|0.794197986529063815|0.206551814996079108|0|zero +73|0.116731691869058879|0.927570392997786763|0|zero +74|0.348741838768106827|1.02382711029672779|0|zero +75|-0.465175160277089994|-3.65225664616070844|0|zero +76|1.55823690278912119|3.28046947046138637|0|zero 
+77|0.662046665352873154|-0.150232849925249656|0|zero +78|-0.204667115844049508|-0.178581281662214819|0|zero +79|0.0261141124500068982|-1.68302809312033252|0|zero +80|-0.775641686880341852|-1.49554024147539444|0|zero +81|0.373198742081655654|-0.444961728556294012|0|zero +82|0.742816985966940679|-0.26205473961375142|0|zero +83|1.47950278173186289|0.320300852003162662|0|zero +84|3.28604959345460035|-2.8445413843366385|0|zero +85|-0.970375032382362002|1.35223033747306642|0|zero +86|3.79248856020959701|-0.37295216657319008|0|zero +87|0.0655034897675836614|-0.339471363770407764|0|zero +88|1.9971856688813876|-0.430961795214028331|0|zero +89|1.02010475981715665|-0.479702398348006764|0|zero +90|-1.9088381328689914|0.470321580695148234|0|zero +91|0.754777220152989092|1.93983882379839256|0|zero +92|-0.165670539625974472|-0.926043095568541363|0|zero +93|0.844141644928539492|0.361105638356598369|0|zero +94|0.420997615683958493|-0.109669055620916639|0|zero +95|1.7405078549906543|0.554239074563585565|0|zero +96|2.85698806251147186|1.66658504784075689|0|zero +97|0.988574694150315292|-2.44115751092438593|0|zero +98|0.903478920443443689|0.630423305470589446|0|zero +99|1.2164275092053336|1.56666314206088808|0|zero +100|1.79956090410553671|2.41200280922520394|1|one +101|1.71884728449045499|2.97743903750451722|1|one +102|1.33402416674137592|1.1196557198006083|1|one +103|1.1746393670879498|1.55472220791847571|1|one +104|1.4404423007201359|2.97803945185182073|1|one +105|1.83675025096090794|1.32866210531128193|1|one +106|2.55719148838989607|1.7067380305892037|1|one +107|1.38157331172930142|2.43791946382464975|1|one +108|2.31168108828901619|1.7825216585223862|1|one +109|2.70377000012061419|2.06455078985536256|1|one \. - DROP TABLE IF EXISTS svm_out, svm_out_summary; SELECT svm_classification( 'svm_unbalanced', @@ -822,6 +822,53 @@ FROM svm_unbalanced JOIN svm_predict_out using (index) WHERE y = prediction and y = 1; +-- Test case with class_weight specified as a mapping. 
svm_unbalanced has +-- unbalanced data with 10x more examples for class 0 compared to 1. A +-- mapping with {1:10, 0:1} should be the same as balanced. +DROP TABLE IF EXISTS svm_out, svm_out_summary; +SELECT svm_classification( + 'svm_unbalanced', + 'svm_out', + 'y', + 'ARRAY[1, x1, x2]', + 'linear', + NULL, + NULL, + 'max_iter=1000, init_stepsize=0.1, class_weight={1:10, 0:1}' + ); + +DROP TABLE IF EXISTS svm_predict_out; +SELECT svm_predict('svm_out', 'svm_unbalanced', 'index', 'svm_predict_out'); +-- we check if the accuracy in prediction the unbalanced class is relatively +-- good. Without the class weight, this can go as low as 50%. +SELECT assert(count(*)/10. >= 0.70, 'Prediction accuracy for unbalanced numeric class with mapping class_weight is too low') +FROM svm_unbalanced JOIN svm_predict_out +using (index) +WHERE y = prediction and y = 1; + +-- Test case for class_weight with text class values. +DROP TABLE IF EXISTS svm_out, svm_out_summary; +SELECT svm_classification( + 'svm_unbalanced', + 'svm_out', + 'y_text', + 'ARRAY[1, x1, x2]', + 'linear', + NULL, + NULL, + 'max_iter=1000, init_stepsize=0.1, class_weight={zero:1, one:10}' + ); + +DROP TABLE IF EXISTS svm_predict_out; +SELECT svm_predict('svm_out', 'svm_unbalanced', 'index', 'svm_predict_out'); + +-- we check if the accuracy in prediction the unbalanced class is relatively +-- good. Without the class weight, this can go as low as 50%. +SELECT assert(count(*)/10. >= 0.70, 'Prediction accuracy for unbalanced text class with mapping class_weight is too low') +FROM svm_unbalanced JOIN svm_predict_out +using (index) +WHERE y_text = prediction and y_text = 'one'; + -- Cross validation tests SELECT svm_one_class( 'svm_normalized',
