http://git-wip-us.apache.org/repos/asf/madlib-site/blob/acd339f6/community-artifacts/SVM-v1.ipynb ---------------------------------------------------------------------- diff --git a/community-artifacts/SVM-v1.ipynb b/community-artifacts/SVM-v1.ipynb new file mode 100644 index 0000000..405710d --- /dev/null +++ b/community-artifacts/SVM-v1.ipynb @@ -0,0 +1,2806 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Support Vector Machines\n", + "Support Vector Machines (SVMs) are models for regression and classification tasks. SVM models have two particularly desirable features: robustness in the presence of noisy data and applicability to a variety of data configurations. At its core, a linear SVM model is a hyperplane separating two distinct classes of data (in the case of classification problems), in such a way that the distance between the hyperplane and the nearest training data point (called the margin) is maximized. Vectors that lie on this margin are called support vectors. With the support vectors fixed, perturbations of vectors beyond the margin will not affect the model; this contributes to the model's robustness. By substituting a kernel function for the usual inner product, one can approximate a large variety of decision boundaries in addition to linear hyperplanes." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The sql extension is already loaded. 
To reload it, use:\n", + " %reload_ext sql\n" + ] + } + ], + "source": [ + "%load_ext sql" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "u'Connected: gpadmin@madlib'" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Greenplum Database 5.4.0 on GCP (demo machine)\n", + "%sql postgresql://gpadmin@35.184.253.255:5432/madlib\n", + " \n", + "# PostgreSQL local\n", + "#%sql postgresql://fmcquillan@localhost:5432/madlib\n", + "\n", + "# Greenplum Database 4.3.10.0\n", + "#%sql postgresql://gpdbchina@10.194.10.68:61000/madlib" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>version</th>\n", + " </tr>\n", + " <tr>\n", + " <td>MADlib version: 1.15-dev, git revision: rc/1.14-rc1-25-gda13eb7, cmake configuration time: Tue Jul 10 21:37:52 UTC 2018, build type: release, build system: Linux-2.6.32-696.20.1.el6.x86_64, C compiler: gcc 4.4.7, C++ compiler: g++ 4.4.7</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(u'MADlib version: 1.15-dev, git revision: rc/1.14-rc1-25-gda13eb7, cmake configuration time: Tue Jul 10 21:37:52 UTC 2018, build type: release, build system: Linux-2.6.32-696.20.1.el6.x86_64, C compiler: gcc 4.4.7, C++ compiler: g++ 4.4.7',)]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%sql select madlib.version();" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": true + }, + "source": [ + "# Classification\n", + "# 1. 
Create input data set" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "Done.\n", + "15 rows affected.\n", + "15 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>tax</th>\n", + " <th>bedroom</th>\n", + " <th>bath</th>\n", + " <th>price</th>\n", + " <th>size</th>\n", + " <th>lot</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>590</td>\n", + " <td>2</td>\n", + " <td>1.0</td>\n", + " <td>50000</td>\n", + " <td>770</td>\n", + " <td>22100</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>1050</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>85000</td>\n", + " <td>1410</td>\n", + " <td>12000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>20</td>\n", + " <td>3</td>\n", + " <td>1.0</td>\n", + " <td>22500</td>\n", + " <td>1060</td>\n", + " <td>3500</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>870</td>\n", + " <td>2</td>\n", + " <td>2.0</td>\n", + " <td>90000</td>\n", + " <td>1300</td>\n", + " <td>17500</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>1320</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>133000</td>\n", + " <td>1500</td>\n", + " <td>30000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>1350</td>\n", + " <td>2</td>\n", + " <td>1.0</td>\n", + " <td>90500</td>\n", + " <td>820</td>\n", + " <td>25700</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>2790</td>\n", + " <td>3</td>\n", + " <td>2.5</td>\n", + " <td>260000</td>\n", + " <td>2130</td>\n", + " <td>25000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>8</td>\n", + " <td>680</td>\n", + " <td>2</td>\n", + " <td>1.0</td>\n", + " <td>142500</td>\n", + " <td>1170</td>\n", + " <td>22000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>9</td>\n", + " <td>1840</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " 
<td>160000</td>\n", + " <td>1500</td>\n", + " <td>19000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>10</td>\n", + " <td>3680</td>\n", + " <td>4</td>\n", + " <td>2.0</td>\n", + " <td>240000</td>\n", + " <td>2790</td>\n", + " <td>20000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>11</td>\n", + " <td>1660</td>\n", + " <td>3</td>\n", + " <td>1.0</td>\n", + " <td>87000</td>\n", + " <td>1030</td>\n", + " <td>17500</td>\n", + " </tr>\n", + " <tr>\n", + " <td>12</td>\n", + " <td>1620</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>118600</td>\n", + " <td>1250</td>\n", + " <td>20000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>13</td>\n", + " <td>3100</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>140000</td>\n", + " <td>1760</td>\n", + " <td>38000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>14</td>\n", + " <td>2070</td>\n", + " <td>2</td>\n", + " <td>3.0</td>\n", + " <td>148000</td>\n", + " <td>1550</td>\n", + " <td>14000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>15</td>\n", + " <td>650</td>\n", + " <td>3</td>\n", + " <td>1.5</td>\n", + " <td>65000</td>\n", + " <td>1450</td>\n", + " <td>12000</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(1, 590, 2, 1.0, 50000, 770, 22100),\n", + " (2, 1050, 3, 2.0, 85000, 1410, 12000),\n", + " (3, 20, 3, 1.0, 22500, 1060, 3500),\n", + " (4, 870, 2, 2.0, 90000, 1300, 17500),\n", + " (5, 1320, 3, 2.0, 133000, 1500, 30000),\n", + " (6, 1350, 2, 1.0, 90500, 820, 25700),\n", + " (7, 2790, 3, 2.5, 260000, 2130, 25000),\n", + " (8, 680, 2, 1.0, 142500, 1170, 22000),\n", + " (9, 1840, 3, 2.0, 160000, 1500, 19000),\n", + " (10, 3680, 4, 2.0, 240000, 2790, 20000),\n", + " (11, 1660, 3, 1.0, 87000, 1030, 17500),\n", + " (12, 1620, 3, 2.0, 118600, 1250, 20000),\n", + " (13, 3100, 3, 2.0, 140000, 1760, 38000),\n", + " (14, 2070, 2, 3.0, 148000, 1550, 14000),\n", + " (15, 650, 3, 1.5, 65000, 1450, 12000)]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"%%sql \n", + "DROP TABLE IF EXISTS houses;\n", + "\n", + "CREATE TABLE houses (id INT, tax INT, bedroom INT, bath FLOAT, price INT,\n", + " size INT, lot INT);\n", + "\n", + "INSERT INTO houses VALUES \n", + " (1 , 590 , 2 , 1 , 50000 , 770 , 22100),\n", + " (2 , 1050 , 3 , 2 , 85000 , 1410 , 12000),\n", + " (3 , 20 , 3 , 1 , 22500 , 1060 , 3500),\n", + " (4 , 870 , 2 , 2 , 90000 , 1300 , 17500),\n", + " (5 , 1320 , 3 , 2 , 133000 , 1500 , 30000),\n", + " (6 , 1350 , 2 , 1 , 90500 , 820 , 25700),\n", + " (7 , 2790 , 3 , 2.5 , 260000 , 2130 , 25000),\n", + " (8 , 680 , 2 , 1 , 142500 , 1170 , 22000),\n", + " (9 , 1840 , 3 , 2 , 160000 , 1500 , 19000),\n", + " (10 , 3680 , 4 , 2 , 240000 , 2790 , 20000),\n", + " (11 , 1660 , 3 , 1 , 87000 , 1030 , 17500),\n", + " (12 , 1620 , 3 , 2 , 118600 , 1250 , 20000),\n", + " (13 , 3100 , 3 , 2 , 140000 , 1760 , 38000),\n", + " (14 , 2070 , 2 , 3 , 148000 , 1550 , 14000),\n", + " (15 , 650 , 3 , 1.5 , 65000 , 1450 , 12000);\n", + " \n", + "SELECT * FROM houses ORDER BY id;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "# 2. Train linear classification model\n", + "Categorical variable is price < $100,0000." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>coef</th>\n", + " <th>loss</th>\n", + " <th>norm_of_gradient</th>\n", + " <th>num_iterations</th>\n", + " <th>num_rows_processed</th>\n", + " <th>num_rows_skipped</th>\n", + " <th>dep_var_mapping</th>\n", + " </tr>\n", + " <tr>\n", + " <td>[0.124749754442359, -0.002823869432027, 0.0751780666986316, 0.00163774992345709]</td>\n", + " <td>0.647742474881</td>\n", + " <td>4412.03185101</td>\n", + " <td>100</td>\n", + " <td>15</td>\n", + " <td>0</td>\n", + " <td>[False, True]</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[([0.124749754442359, -0.002823869432027, 0.0751780666986316, 0.00163774992345709], 0.647742474880954, 4412.03185100955, 100, 15L, 0L, [False, True])]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS houses_svm, houses_svm_summary;\n", + "\n", + "SELECT madlib.svm_classification('houses',\n", + " 'houses_svm',\n", + " 'price < 100000',\n", + " 'ARRAY[1, tax, bath, size]'\n", + " );\n", + "SELECT * FROM houses_svm;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3. Predict using linear model\n", + "We want to predict if house price is less than $100,000. We use the training data set for prediction as well, which is not usual but serves to show the syntax. The predicted results are in the \"prediction\" column and the actual data is in the \"actual\" column." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "15 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>tax</th>\n", + " <th>bedroom</th>\n", + " <th>bath</th>\n", + " <th>price</th>\n", + " <th>size</th>\n", + " <th>lot</th>\n", + " <th>prediction</th>\n", + " <th>decision_function</th>\n", + " <th>actual</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>590</td>\n", + " <td>2</td>\n", + " <td>1.0</td>\n", + " <td>50000</td>\n", + " <td>770</td>\n", + " <td>22100</td>\n", + " <td>False</td>\n", + " <td>-0.205087702693</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>1050</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>85000</td>\n", + " <td>1410</td>\n", + " <td>12000</td>\n", + " <td>False</td>\n", + " <td>-0.380729623714</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>20</td>\n", + " <td>3</td>\n", + " <td>1.0</td>\n", + " <td>22500</td>\n", + " <td>1060</td>\n", + " <td>3500</td>\n", + " <td>True</td>\n", + " <td>1.87946535136</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>870</td>\n", + " <td>2</td>\n", + " <td>2.0</td>\n", + " <td>90000</td>\n", + " <td>1300</td>\n", + " <td>17500</td>\n", + " <td>False</td>\n", + " <td>-0.0525856175296</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>1320</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>133000</td>\n", + " <td>1500</td>\n", + " <td>30000</td>\n", + " <td>False</td>\n", + " <td>-0.99577687725</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>1350</td>\n", + " <td>2</td>\n", + " <td>1.0</td>\n", + " <td>90500</td>\n", + " <td>820</td>\n", + " <td>25700</td>\n", + " <td>False</td>\n", + " 
<td>-2.26934097486</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>2790</td>\n", + " <td>3</td>\n", + " <td>2.5</td>\n", + " <td>260000</td>\n", + " <td>2130</td>\n", + " <td>25000</td>\n", + " <td>False</td>\n", + " <td>-4.0774934572</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>8</td>\n", + " <td>680</td>\n", + " <td>2</td>\n", + " <td>1.0</td>\n", + " <td>142500</td>\n", + " <td>1170</td>\n", + " <td>22000</td>\n", + " <td>True</td>\n", + " <td>0.195864017807</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>9</td>\n", + " <td>1840</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>160000</td>\n", + " <td>1500</td>\n", + " <td>19000</td>\n", + " <td>False</td>\n", + " <td>-2.4641889819</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>10</td>\n", + " <td>3680</td>\n", + " <td>4</td>\n", + " <td>2.0</td>\n", + " <td>240000</td>\n", + " <td>2790</td>\n", + " <td>20000</td>\n", + " <td>False</td>\n", + " <td>-5.54741133557</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>11</td>\n", + " <td>1660</td>\n", + " <td>3</td>\n", + " <td>1.0</td>\n", + " <td>87000</td>\n", + " <td>1030</td>\n", + " <td>17500</td>\n", + " <td>False</td>\n", + " <td>-2.80081301486</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <td>12</td>\n", + " <td>1620</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>118600</td>\n", + " <td>1250</td>\n", + " <td>20000</td>\n", + " <td>False</td>\n", + " <td>-2.25237518772</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>13</td>\n", + " <td>3100</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>140000</td>\n", + " <td>1760</td>\n", + " <td>38000</td>\n", + " <td>False</td>\n", + " <td>-5.59644948616</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>14</td>\n", + " <td>2070</td>\n", + " <td>2</td>\n", + " <td>3.0</td>\n", + " <td>148000</td>\n", + " <td>1550</td>\n", + " 
<td>14000</td>\n", + " <td>False</td>\n", + " <td>-2.9566133884</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>15</td>\n", + " <td>650</td>\n", + " <td>3</td>\n", + " <td>1.5</td>\n", + " <td>65000</td>\n", + " <td>1450</td>\n", + " <td>12000</td>\n", + " <td>True</td>\n", + " <td>0.776739112686</td>\n", + " <td>True</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(1, 590, 2, 1.0, 50000, 770, 22100, False, -0.205087702692976, True),\n", + " (2, 1050, 3, 2.0, 85000, 1410, 12000, False, -0.380729623714223, True),\n", + " (3, 20, 3, 1.0, 22500, 1060, 3500, True, 1.87946535136497, True),\n", + " (4, 870, 2, 2.0, 90000, 1300, 17500, False, -0.0525856175296444, True),\n", + " (5, 1320, 3, 2.0, 133000, 1500, 30000, False, -0.995776877250374, False),\n", + " (6, 1350, 2, 1.0, 90500, 820, 25700, False, -2.26934097486064, True),\n", + " (7, 2790, 3, 2.5, 260000, 2130, 25000, False, -4.07749345720278, False),\n", + " (8, 680, 2, 1.0, 142500, 1170, 22000, True, 0.195864017807432, False),\n", + " (9, 1840, 3, 2.0, 160000, 1500, 19000, False, -2.46418898190441, False),\n", + " (10, 3680, 4, 2.0, 240000, 2790, 20000, False, -5.54741133557444, False),\n", + " (11, 1660, 3, 1.0, 87000, 1030, 17500, False, -2.80081301486302, True),\n", + " (12, 1620, 3, 2.0, 118600, 1250, 20000, False, -2.25237518772275, False),\n", + " (13, 3100, 3, 2.0, 140000, 1760, 38000, False, -5.59644948615959, False),\n", + " (14, 2070, 2, 3.0, 148000, 1550, 14000, False, -2.95661338839914, False),\n", + " (15, 650, 3, 1.5, 65000, 1450, 12000, True, 0.776739112685544, True)]" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS houses_pred;\n", + "\n", + "SELECT madlib.svm_predict('houses_svm', \n", + " 'houses', \n", + " 'id', \n", + " 'houses_pred');\n", + "\n", + "SELECT *, price < 100000 AS actual FROM houses JOIN houses_pred USING (id) ORDER BY id;" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "Count the miss-classifications:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>count</th>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(6L,)]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT COUNT(*) FROM houses_pred JOIN houses USING (id) \n", + "WHERE houses_pred.prediction != (houses.price < 100000);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 4. Train using Gaussian kernel\n", + "Next generate a nonlinear model using a Gaussian kernel. This time we specify the initial step size and maximum number of iterations to run. As part of the kernel parameter, we choose 10 as the dimension of the space where we train SVM. A larger number will lead to a more powerful model but run the risk of overfitting. As a result, the model will be a 10 dimensional vector, instead of 4 as in the case of linear model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>coef</th>\n", + " <th>loss</th>\n", + " <th>norm_of_gradient</th>\n", + " <th>num_iterations</th>\n", + " <th>num_rows_processed</th>\n", + " <th>num_rows_skipped</th>\n", + " <th>dep_var_mapping</th>\n", + " </tr>\n", + " <tr>\n", + " <td>[-1.67275666209207, 1.5191640881642, -0.503066422926726, 1.33250956564454, 2.23009854231314, -0.0602475029497933, 1.97466397155921, 2.3668779833279, 0.577739846910355, 2.81255996089823]</td>\n", + " <td>0.0571869097341</td>\n", + " <td>1.18281830047</td>\n", + " <td>177</td>\n", + " <td>15</td>\n", + " <td>0</td>\n", + " <td>[False, True]</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[([-1.67275666209207, 1.5191640881642, -0.503066422926726, 1.33250956564454, 2.23009854231314, -0.0602475029497933, 1.97466397155921, 2.3668779833279, 0.577739846910355, 2.81255996089823], 0.0571869097340992, 1.18281830047046, 177, 15L, 0L, [False, True])]" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS houses_svm_gaussian, houses_svm_gaussian_summary, houses_svm_gaussian_random;\n", + "\n", + "SELECT madlib.svm_classification( 'houses',\n", + " 'houses_svm_gaussian',\n", + " 'price < 100000',\n", + " 'ARRAY[1, tax, bath, size]',\n", + " 'gaussian',\n", + " 'n_components=10',\n", + " '',\n", + " 'init_stepsize=1, max_iter=200'\n", + " );\n", + "\n", + "SELECT * FROM houses_svm_gaussian;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 5. Predict using Gaussian model\n", + "The predicted results are in the \"prediction\" column and the actual data is in the \"actual\" column." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "15 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>tax</th>\n", + " <th>bedroom</th>\n", + " <th>bath</th>\n", + " <th>price</th>\n", + " <th>size</th>\n", + " <th>lot</th>\n", + " <th>prediction</th>\n", + " <th>decision_function</th>\n", + " <th>actual</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>590</td>\n", + " <td>2</td>\n", + " <td>1.0</td>\n", + " <td>50000</td>\n", + " <td>770</td>\n", + " <td>22100</td>\n", + " <td>True</td>\n", + " <td>1.64923454025</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>1050</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>85000</td>\n", + " <td>1410</td>\n", + " <td>12000</td>\n", + " <td>True</td>\n", + " <td>1.34505433447</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>20</td>\n", + " <td>3</td>\n", + " <td>1.0</td>\n", + " <td>22500</td>\n", + " <td>1060</td>\n", + " <td>3500</td>\n", + " <td>True</td>\n", + " <td>1.00000000092</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>870</td>\n", + " <td>2</td>\n", + " <td>2.0</td>\n", + " <td>90000</td>\n", + " <td>1300</td>\n", + " <td>17500</td>\n", + " <td>True</td>\n", + " <td>1.00000000712</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>1320</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>133000</td>\n", + " <td>1500</td>\n", + " <td>30000</td>\n", + " <td>False</td>\n", + " <td>-1.00000001729</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>1350</td>\n", + " <td>2</td>\n", + " <td>1.0</td>\n", + " <td>90500</td>\n", + " <td>820</td>\n", + " <td>25700</td>\n", + " <td>True</td>\n", + " 
<td>1.11113745879</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>2790</td>\n", + " <td>3</td>\n", + " <td>2.5</td>\n", + " <td>260000</td>\n", + " <td>2130</td>\n", + " <td>25000</td>\n", + " <td>False</td>\n", + " <td>-0.29148279088</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>8</td>\n", + " <td>680</td>\n", + " <td>2</td>\n", + " <td>1.0</td>\n", + " <td>142500</td>\n", + " <td>1170</td>\n", + " <td>22000</td>\n", + " <td>False</td>\n", + " <td>-1.00000000609</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>9</td>\n", + " <td>1840</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>160000</td>\n", + " <td>1500</td>\n", + " <td>19000</td>\n", + " <td>False</td>\n", + " <td>-1.23665846847</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>10</td>\n", + " <td>3680</td>\n", + " <td>4</td>\n", + " <td>2.0</td>\n", + " <td>240000</td>\n", + " <td>2790</td>\n", + " <td>20000</td>\n", + " <td>False</td>\n", + " <td>-1.0938201061</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>11</td>\n", + " <td>1660</td>\n", + " <td>3</td>\n", + " <td>1.0</td>\n", + " <td>87000</td>\n", + " <td>1030</td>\n", + " <td>17500</td>\n", + " <td>True</td>\n", + " <td>1.62636283239</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <td>12</td>\n", + " <td>1620</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>118600</td>\n", + " <td>1250</td>\n", + " <td>20000</td>\n", + " <td>False</td>\n", + " <td>-1.60116812307</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>13</td>\n", + " <td>3100</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>140000</td>\n", + " <td>1760</td>\n", + " <td>38000</td>\n", + " <td>False</td>\n", + " <td>-1.09173031656</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>14</td>\n", + " <td>2070</td>\n", + " <td>2</td>\n", + " <td>3.0</td>\n", + " <td>148000</td>\n", + " <td>1550</td>\n", + " 
<td>14000</td>\n", + " <td>False</td>\n", + " <td>-3.16301875478</td>\n", + " <td>False</td>\n", + " </tr>\n", + " <tr>\n", + " <td>15</td>\n", + " <td>650</td>\n", + " <td>3</td>\n", + " <td>1.5</td>\n", + " <td>65000</td>\n", + " <td>1450</td>\n", + " <td>12000</td>\n", + " <td>True</td>\n", + " <td>1.00000000486</td>\n", + " <td>True</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(1, 590, 2, 1.0, 50000, 770, 22100, True, 1.64923454025379, True),\n", + " (2, 1050, 3, 2.0, 85000, 1410, 12000, True, 1.34505433446611, True),\n", + " (3, 20, 3, 1.0, 22500, 1060, 3500, True, 1.0000000009249, True),\n", + " (4, 870, 2, 2.0, 90000, 1300, 17500, True, 1.00000000711647, True),\n", + " (5, 1320, 3, 2.0, 133000, 1500, 30000, False, -1.00000001728685, False),\n", + " (6, 1350, 2, 1.0, 90500, 820, 25700, True, 1.11113745878827, True),\n", + " (7, 2790, 3, 2.5, 260000, 2130, 25000, False, -0.291482790879796, False),\n", + " (8, 680, 2, 1.0, 142500, 1170, 22000, False, -1.00000000609094, False),\n", + " (9, 1840, 3, 2.0, 160000, 1500, 19000, False, -1.23665846846941, False),\n", + " (10, 3680, 4, 2.0, 240000, 2790, 20000, False, -1.09382010610257, False),\n", + " (11, 1660, 3, 1.0, 87000, 1030, 17500, True, 1.62636283239171, True),\n", + " (12, 1620, 3, 2.0, 118600, 1250, 20000, False, -1.6011681230749, False),\n", + " (13, 3100, 3, 2.0, 140000, 1760, 38000, False, -1.09173031656082, False),\n", + " (14, 2070, 2, 3.0, 148000, 1550, 14000, False, -3.16301875478316, False),\n", + " (15, 650, 3, 1.5, 65000, 1450, 12000, True, 1.00000000486389, True)]" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS houses_pred_gaussian;\n", + "\n", + "SELECT madlib.svm_predict('houses_svm_gaussian', \n", + " 'houses', \n", + " 'id', \n", + " 'houses_pred_gaussian');\n", + "\n", + "SELECT *, price < 100000 AS actual FROM houses JOIN houses_pred_gaussian USING (id) ORDER BY id;" + ] + }, 
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Count the misclassifications. Note that the Gaussian model produces a more accurate result than the linear model for this small data set:" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>count</th>\n", + " </tr>\n", + " <tr>\n", + " <td>0</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(0L,)]" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT COUNT(*) FROM houses_pred_gaussian JOIN houses USING (id) \n", + "WHERE houses_pred_gaussian.prediction != (houses.price < 100000);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 6. Balancing data sets\n", + "For a data set with unbalanced class sizes, set the class_weight parameter to 'balanced' when building the model:" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>coef</th>\n", + " <th>loss</th>\n", + " <th>norm_of_gradient</th>\n", + " <th>num_iterations</th>\n", + " <th>num_rows_processed</th>\n", + " <th>num_rows_skipped</th>\n", + " <th>dep_var_mapping</th>\n", + " </tr>\n", + " <tr>\n", + " <td>[0.891926151039837, 0.169282494673541, -2.26539133689874, 0.526518499596676, -0.900664505989526, 0.508112011288015, -0.355474591147659, 1.23127975981665, 1.53694964239487, 1.46496058633682]</td>\n", + " <td>0.569002744458</td>\n", + " <td>0.989597662459</td>\n", + " <td>183</td>\n", + " <td>15</td>\n", + " <td>0</td>\n", + " <td>[False, True]</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + 
"[([0.891926151039837, 0.169282494673541, -2.26539133689874, 0.526518499596676, -0.900664505989526, 0.508112011288015, -0.355474591147659, 1.23127975981665, 1.53694964239487, 1.46496058633682], 0.56900274445785, 0.989597662458527, 183, 15L, 0L, [False, True])]" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS houses_svm_gaussian, houses_svm_gaussian_summary, houses_svm_gaussian_random;\n", + "\n", + "SELECT madlib.svm_classification( 'houses',\n", + " 'houses_svm_gaussian',\n", + " 'price < 150000',\n", + " 'ARRAY[1, tax, bath, size]',\n", + " 'gaussian',\n", + " 'n_components=10',\n", + " '',\n", + " 'init_stepsize=1, max_iter=200, class_weight=balanced'\n", + " );\n", + "\n", + "SELECT * FROM houses_svm_gaussian;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Regression\n", + "# 1. Create input data set\n", + "For regression we use part of the well known abalone data set https://archive.ics.uci.edu/ml/datasets/abalone :" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "Done.\n", + "20 rows affected.\n", + "20 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>sex</th>\n", + " <th>length</th>\n", + " <th>diameter</th>\n", + " <th>height</th>\n", + " <th>rings</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>M</td>\n", + " <td>0.455</td>\n", + " <td>0.365</td>\n", + " <td>0.095</td>\n", + " <td>15</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>M</td>\n", + " <td>0.35</td>\n", + " <td>0.265</td>\n", + " <td>0.09</td>\n", + " <td>7</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>F</td>\n", + " <td>0.53</td>\n", + " <td>0.42</td>\n", + " <td>0.135</td>\n", + " <td>9</td>\n", + " </tr>\n", + " <tr>\n", + " 
<td>4</td>\n", + " <td>M</td>\n", + " <td>0.44</td>\n", + " <td>0.365</td>\n", + " <td>0.125</td>\n", + " <td>10</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>I</td>\n", + " <td>0.33</td>\n", + " <td>0.255</td>\n", + " <td>0.08</td>\n", + " <td>7</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>I</td>\n", + " <td>0.425</td>\n", + " <td>0.3</td>\n", + " <td>0.095</td>\n", + " <td>8</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>F</td>\n", + " <td>0.53</td>\n", + " <td>0.415</td>\n", + " <td>0.15</td>\n", + " <td>20</td>\n", + " </tr>\n", + " <tr>\n", + " <td>8</td>\n", + " <td>F</td>\n", + " <td>0.545</td>\n", + " <td>0.425</td>\n", + " <td>0.125</td>\n", + " <td>16</td>\n", + " </tr>\n", + " <tr>\n", + " <td>9</td>\n", + " <td>M</td>\n", + " <td>0.475</td>\n", + " <td>0.37</td>\n", + " <td>0.125</td>\n", + " <td>9</td>\n", + " </tr>\n", + " <tr>\n", + " <td>10</td>\n", + " <td>F</td>\n", + " <td>0.55</td>\n", + " <td>0.44</td>\n", + " <td>0.15</td>\n", + " <td>19</td>\n", + " </tr>\n", + " <tr>\n", + " <td>11</td>\n", + " <td>F</td>\n", + " <td>0.525</td>\n", + " <td>0.38</td>\n", + " <td>0.14</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <td>12</td>\n", + " <td>M</td>\n", + " <td>0.43</td>\n", + " <td>0.35</td>\n", + " <td>0.11</td>\n", + " <td>10</td>\n", + " </tr>\n", + " <tr>\n", + " <td>13</td>\n", + " <td>M</td>\n", + " <td>0.49</td>\n", + " <td>0.38</td>\n", + " <td>0.135</td>\n", + " <td>11</td>\n", + " </tr>\n", + " <tr>\n", + " <td>14</td>\n", + " <td>F</td>\n", + " <td>0.535</td>\n", + " <td>0.405</td>\n", + " <td>0.145</td>\n", + " <td>10</td>\n", + " </tr>\n", + " <tr>\n", + " <td>15</td>\n", + " <td>F</td>\n", + " <td>0.47</td>\n", + " <td>0.355</td>\n", + " <td>0.1</td>\n", + " <td>10</td>\n", + " </tr>\n", + " <tr>\n", + " <td>16</td>\n", + " <td>M</td>\n", + " <td>0.5</td>\n", + " <td>0.4</td>\n", + " <td>0.13</td>\n", + " <td>12</td>\n", + " </tr>\n", + " <tr>\n", + " <td>17</td>\n", + " 
<td>I</td>\n", + " <td>0.355</td>\n", + " <td>0.28</td>\n", + " <td>0.085</td>\n", + " <td>7</td>\n", + " </tr>\n", + " <tr>\n", + " <td>18</td>\n", + " <td>F</td>\n", + " <td>0.44</td>\n", + " <td>0.34</td>\n", + " <td>0.1</td>\n", + " <td>10</td>\n", + " </tr>\n", + " <tr>\n", + " <td>19</td>\n", + " <td>M</td>\n", + " <td>0.365</td>\n", + " <td>0.295</td>\n", + " <td>0.08</td>\n", + " <td>7</td>\n", + " </tr>\n", + " <tr>\n", + " <td>20</td>\n", + " <td>M</td>\n", + " <td>0.45</td>\n", + " <td>0.32</td>\n", + " <td>0.1</td>\n", + " <td>9</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(1, u'M', 0.455, 0.365, 0.095, 15),\n", + " (2, u'M', 0.35, 0.265, 0.09, 7),\n", + " (3, u'F', 0.53, 0.42, 0.135, 9),\n", + " (4, u'M', 0.44, 0.365, 0.125, 10),\n", + " (5, u'I', 0.33, 0.255, 0.08, 7),\n", + " (6, u'I', 0.425, 0.3, 0.095, 8),\n", + " (7, u'F', 0.53, 0.415, 0.15, 20),\n", + " (8, u'F', 0.545, 0.425, 0.125, 16),\n", + " (9, u'M', 0.475, 0.37, 0.125, 9),\n", + " (10, u'F', 0.55, 0.44, 0.15, 19),\n", + " (11, u'F', 0.525, 0.38, 0.14, 14),\n", + " (12, u'M', 0.43, 0.35, 0.11, 10),\n", + " (13, u'M', 0.49, 0.38, 0.135, 11),\n", + " (14, u'F', 0.535, 0.405, 0.145, 10),\n", + " (15, u'F', 0.47, 0.355, 0.1, 10),\n", + " (16, u'M', 0.5, 0.4, 0.13, 12),\n", + " (17, u'I', 0.355, 0.28, 0.085, 7),\n", + " (18, u'F', 0.44, 0.34, 0.1, 10),\n", + " (19, u'M', 0.365, 0.295, 0.08, 7),\n", + " (20, u'M', 0.45, 0.32, 0.1, 9)]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS abalone;\n", + "\n", + "CREATE TABLE abalone (id INT, sex TEXT, length FLOAT, diameter FLOAT, height FLOAT, rings INT);\n", + "\n", + "INSERT INTO abalone VALUES\n", + "(1,'M',0.455,0.365,0.095,15),\n", + "(2,'M',0.35,0.265,0.09,7),\n", + "(3,'F',0.53,0.42,0.135,9),\n", + "(4,'M',0.44,0.365,0.125,10),\n", + "(5,'I',0.33,0.255,0.08,7),\n", + "(6,'I',0.425,0.3,0.095,8),\n", + 
"(7,'F',0.53,0.415,0.15,20),\n", + "(8,'F',0.545,0.425,0.125,16),\n", + "(9,'M',0.475,0.37,0.125,9),\n", + "(10,'F',0.55,0.44,0.15,19),\n", + "(11,'F',0.525,0.38,0.14,14),\n", + "(12,'M',0.43,0.35,0.11,10),\n", + "(13,'M',0.49,0.38,0.135,11),\n", + "(14,'F',0.535,0.405,0.145,10),\n", + "(15,'F',0.47,0.355,0.1,10),\n", + "(16,'M',0.5,0.4,0.13,12),\n", + "(17,'I',0.355,0.28,0.085,7),\n", + "(18,'F',0.44,0.34,0.1,10),\n", + "(19,'M',0.365,0.295,0.08,7),\n", + "(20,'M',0.45,0.32,0.1,9);\n", + "\n", + "SELECT * FROM abalone ORDER BY id;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2. Train linear regression model" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>coef</th>\n", + " <th>loss</th>\n", + " <th>norm_of_gradient</th>\n", + " <th>num_iterations</th>\n", + " <th>num_rows_processed</th>\n", + " <th>num_rows_skipped</th>\n", + " <th>dep_var_mapping</th>\n", + " </tr>\n", + " <tr>\n", + " <td>[1.998949892503, 0.918517478913099, 0.712125856084095, 0.229379472956877]</td>\n", + " <td>8.29033295818</td>\n", + " <td>23.2251777858</td>\n", + " <td>100</td>\n", + " <td>20</td>\n", + " <td>0</td>\n", + " <td>[None]</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[([1.998949892503, 0.918517478913099, 0.712125856084095, 0.229379472956877], 8.29033295818392, 23.225177785827, 100, 20L, 0L, [None])]" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS abalone_svm_regression, abalone_svm_regression_summary;\n", + "\n", + "SELECT madlib.svm_regression('abalone',\n", + " 'abalone_svm_regression',\n", + " 'rings',\n", + " 'ARRAY[1, length, diameter, height]'\n", + " );\n", + "\n", + "SELECT * FROM 
abalone_svm_regression;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3. Predict using linear model" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "20 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>sex</th>\n", + " <th>length</th>\n", + " <th>diameter</th>\n", + " <th>height</th>\n", + " <th>rings</th>\n", + " <th>prediction</th>\n", + " <th>decision_function</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>M</td>\n", + " <td>0.455</td>\n", + " <td>0.365</td>\n", + " <td>0.095</td>\n", + " <td>15</td>\n", + " <td>2.69859233281</td>\n", + " <td>2.69859233281</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>M</td>\n", + " <td>0.35</td>\n", + " <td>0.265</td>\n", + " <td>0.09</td>\n", + " <td>7</td>\n", + " <td>2.52978851455</td>\n", + " <td>2.52978851455</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>F</td>\n", + " <td>0.53</td>\n", + " <td>0.42</td>\n", + " <td>0.135</td>\n", + " <td>9</td>\n", + " <td>2.81582324473</td>\n", + " <td>2.81582324473</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>M</td>\n", + " <td>0.44</td>\n", + " <td>0.365</td>\n", + " <td>0.125</td>\n", + " <td>10</td>\n", + " <td>2.69169595482</td>\n", + " <td>2.69169595482</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>I</td>\n", + " <td>0.33</td>\n", + " <td>0.255</td>\n", + " <td>0.08</td>\n", + " <td>7</td>\n", + " <td>2.50200311168</td>\n", + " <td>2.50200311168</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>I</td>\n", + " <td>0.425</td>\n", + " <td>0.3</td>\n", + " <td>0.095</td>\n", + " <td>8</td>\n", + " <td>2.6247486278</td>\n", + " <td>2.6247486278</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>F</td>\n", + " <td>0.53</td>\n", + " 
<td>0.415</td>\n", + " <td>0.15</td>\n", + " <td>20</td>\n", + " <td>2.81570330755</td>\n", + " <td>2.81570330755</td>\n", + " </tr>\n", + " <tr>\n", + " <td>8</td>\n", + " <td>F</td>\n", + " <td>0.545</td>\n", + " <td>0.425</td>\n", + " <td>0.125</td>\n", + " <td>16</td>\n", + " <td>2.83086784147</td>\n", + " <td>2.83086784147</td>\n", + " </tr>\n", + " <tr>\n", + " <td>9</td>\n", + " <td>M</td>\n", + " <td>0.475</td>\n", + " <td>0.37</td>\n", + " <td>0.125</td>\n", + " <td>9</td>\n", + " <td>2.72740469586</td>\n", + " <td>2.72740469586</td>\n", + " </tr>\n", + " <tr>\n", + " <td>10</td>\n", + " <td>F</td>\n", + " <td>0.55</td>\n", + " <td>0.44</td>\n", + " <td>0.15</td>\n", + " <td>19</td>\n", + " <td>2.85187680353</td>\n", + " <td>2.85187680353</td>\n", + " </tr>\n", + " <tr>\n", + " <td>11</td>\n", + " <td>F</td>\n", + " <td>0.525</td>\n", + " <td>0.38</td>\n", + " <td>0.14</td>\n", + " <td>14</td>\n", + " <td>2.78389252046</td>\n", + " <td>2.78389252046</td>\n", + " </tr>\n", + " <tr>\n", + " <td>12</td>\n", + " <td>M</td>\n", + " <td>0.43</td>\n", + " <td>0.35</td>\n", + " <td>0.11</td>\n", + " <td>10</td>\n", + " <td>2.66838820009</td>\n", + " <td>2.66838820009</td>\n", + " </tr>\n", + " <tr>\n", + " <td>13</td>\n", + " <td>M</td>\n", + " <td>0.49</td>\n", + " <td>0.38</td>\n", + " <td>0.135</td>\n", + " <td>11</td>\n", + " <td>2.75059751133</td>\n", + " <td>2.75059751133</td>\n", + " </tr>\n", + " <tr>\n", + " <td>14</td>\n", + " <td>F</td>\n", + " <td>0.535</td>\n", + " <td>0.405</td>\n", + " <td>0.145</td>\n", + " <td>10</td>\n", + " <td>2.81202773901</td>\n", + " <td>2.81202773901</td>\n", + " </tr>\n", + " <tr>\n", + " <td>15</td>\n", + " <td>F</td>\n", + " <td>0.47</td>\n", + " <td>0.355</td>\n", + " <td>0.1</td>\n", + " <td>10</td>\n", + " <td>2.7063957338</td>\n", + " <td>2.7063957338</td>\n", + " </tr>\n", + " <tr>\n", + " <td>16</td>\n", + " <td>M</td>\n", + " <td>0.5</td>\n", + " <td>0.4</td>\n", + " <td>0.13</td>\n", + " <td>12</td>\n", + " 
<td>2.77287830588</td>\n", + " <td>2.77287830588</td>\n", + " </tr>\n", + " <tr>\n", + " <td>17</td>\n", + " <td>I</td>\n", + " <td>0.355</td>\n", + " <td>0.28</td>\n", + " <td>0.085</td>\n", + " <td>7</td>\n", + " <td>2.54391609242</td>\n", + " <td>2.54391609242</td>\n", + " </tr>\n", + " <tr>\n", + " <td>18</td>\n", + " <td>F</td>\n", + " <td>0.44</td>\n", + " <td>0.34</td>\n", + " <td>0.1</td>\n", + " <td>10</td>\n", + " <td>2.66815832159</td>\n", + " <td>2.66815832159</td>\n", + " </tr>\n", + " <tr>\n", + " <td>19</td>\n", + " <td>M</td>\n", + " <td>0.365</td>\n", + " <td>0.295</td>\n", + " <td>0.08</td>\n", + " <td>7</td>\n", + " <td>2.56263625769</td>\n", + " <td>2.56263625769</td>\n", + " </tr>\n", + " <tr>\n", + " <td>20</td>\n", + " <td>M</td>\n", + " <td>0.45</td>\n", + " <td>0.32</td>\n", + " <td>0.1</td>\n", + " <td>9</td>\n", + " <td>2.66310097926</td>\n", + " <td>2.66310097926</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(1, u'M', 0.455, 0.365, 0.095, 15, 2.69859233281006, 2.69859233281006),\n", + " (2, u'M', 0.35, 0.265, 0.09, 7, 2.52978851455099, 2.52978851455099),\n", + " (3, u'F', 0.53, 0.42, 0.135, 9, 2.81582324473145, 2.81582324473145),\n", + " (4, u'M', 0.44, 0.365, 0.125, 10, 2.69169595481507, 2.69169595481507),\n", + " (5, u'I', 0.33, 0.255, 0.08, 7, 2.50200311168232, 2.50200311168232),\n", + " (6, u'I', 0.425, 0.3, 0.095, 8, 2.6247486277972, 2.6247486277972),\n", + " (7, u'F', 0.53, 0.415, 0.15, 20, 2.81570330754538, 2.81570330754538),\n", + " (8, u'F', 0.545, 0.425, 0.125, 16, 2.83086784146599, 2.83086784146599),\n", + " (9, u'M', 0.475, 0.37, 0.125, 9, 2.72740469585745, 2.72740469585745),\n", + " (10, u'F', 0.55, 0.44, 0.15, 19, 2.85187680352574, 2.85187680352574),\n", + " (11, u'F', 0.525, 0.38, 0.14, 14, 2.7838925204583, 2.7838925204583),\n", + " (12, u'M', 0.43, 0.35, 0.11, 10, 2.66838820009033, 2.66838820009033),\n", + " (13, u'M', 0.49, 0.38, 0.135, 11, 2.75059751133156, 2.75059751133156),\n", + " (14, u'F', 0.535, 
0.405, 0.145, 10, 2.81202773901432, 2.81202773901432),\n", + " (15, u'F', 0.47, 0.355, 0.1, 10, 2.7063957337977, 2.7063957337977),\n", + " (16, u'M', 0.5, 0.4, 0.13, 12, 2.77287830587759, 2.77287830587759),\n", + " (17, u'I', 0.355, 0.28, 0.085, 7, 2.54391609242204, 2.54391609242204),\n", + " (18, u'F', 0.44, 0.34, 0.1, 10, 2.66815832158905, 2.66815832158905),\n", + " (19, u'M', 0.365, 0.295, 0.08, 7, 2.56263625768764, 2.56263625768764),\n", + " (20, u'M', 0.45, 0.32, 0.1, 9, 2.6631009792565, 2.6631009792565)]" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS abalone_regr;\n", + "\n", + "SELECT madlib.svm_predict('abalone_svm_regression',\n", + " 'abalone', \n", + " 'id', \n", + " 'abalone_regr');\n", + "\n", + "SELECT * FROM abalone JOIN abalone_regr USING (id) ORDER BY id;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "RMS error:" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>rms_error</th>\n", + " </tr>\n", + " <tr>\n", + " <td>9.08842725553</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(9.08842725552861,)]" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT SQRT(AVG((rings-prediction)*(rings-prediction))) as rms_error FROM abalone \n", + "JOIN abalone_regr USING (id);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 4. 
Train using Gaussian model" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>coef</th>\n", + " <th>loss</th>\n", + " <th>norm_of_gradient</th>\n", + " <th>num_iterations</th>\n", + " <th>num_rows_processed</th>\n", + " <th>num_rows_skipped</th>\n", + " <th>dep_var_mapping</th>\n", + " </tr>\n", + " <tr>\n", + " <td>[4.49016341280977, 2.19062972461334, -2.04673653356154, 1.11216153651262, 2.83478599238881, -4.23122821845785, 4.17684533744501, -5.36892552740644, 0.775782561685621, -3.62606941016707]</td>\n", + " <td>2.66850539542</td>\n", + " <td>0.974400795364</td>\n", + " <td>163</td>\n", + " <td>20</td>\n", + " <td>0</td>\n", + " <td>[None]</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[([4.49016341280977, 2.19062972461334, -2.04673653356154, 1.11216153651262, 2.83478599238881, -4.23122821845785, 4.17684533744501, -5.36892552740644, 0.775782561685621, -3.62606941016707], 2.66850539541894, 0.97440079536379, 163, 20L, 0L, [None])]" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS abalone_svm_gaussian_regression, abalone_svm_gaussian_regression_summary, abalone_svm_gaussian_regression_random;\n", + "\n", + "SELECT madlib.svm_regression( 'abalone',\n", + " 'abalone_svm_gaussian_regression',\n", + " 'rings',\n", + " 'ARRAY[1, length, diameter, height]',\n", + " 'gaussian',\n", + " 'n_components=10',\n", + " '',\n", + " 'init_stepsize=1, max_iter=200'\n", + " );\n", + "\n", + "SELECT * FROM abalone_svm_gaussian_regression;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 5. 
Predict using Gaussian model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "20 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>sex</th>\n", + " <th>length</th>\n", + " <th>diameter</th>\n", + " <th>height</th>\n", + " <th>rings</th>\n", + " <th>prediction</th>\n", + " <th>decision_function</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>M</td>\n", + " <td>0.455</td>\n", + " <td>0.365</td>\n", + " <td>0.095</td>\n", + " <td>15</td>\n", + " <td>9.9302009808</td>\n", + " <td>9.9302009808</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>M</td>\n", + " <td>0.35</td>\n", + " <td>0.265</td>\n", + " <td>0.09</td>\n", + " <td>7</td>\n", + " <td>9.87712610207</td>\n", + " <td>9.87712610207</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>F</td>\n", + " <td>0.53</td>\n", + " <td>0.42</td>\n", + " <td>0.135</td>\n", + " <td>9</td>\n", + " <td>10.0459812729</td>\n", + " <td>10.0459812729</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>M</td>\n", + " <td>0.44</td>\n", + " <td>0.365</td>\n", + " <td>0.125</td>\n", + " <td>10</td>\n", + " <td>10.018415777</td>\n", + " <td>10.018415777</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>I</td>\n", + " <td>0.33</td>\n", + " <td>0.255</td>\n", + " <td>0.08</td>\n", + " <td>7</td>\n", + " <td>9.81382643977</td>\n", + " <td>9.81382643977</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>I</td>\n", + " <td>0.425</td>\n", + " <td>0.3</td>\n", + " <td>0.095</td>\n", + " <td>8</td>\n", + " <td>9.973725783</td>\n", + " <td>9.973725783</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>F</td>\n", + " <td>0.53</td>\n", + " <td>0.415</td>\n", + " <td>0.15</td>\n", + " <td>20</td>\n", + " <td>10.1032556038</td>\n", + " 
<td>10.1032556038</td>\n", + " </tr>\n", + " <tr>\n", + " <td>8</td>\n", + " <td>F</td>\n", + " <td>0.545</td>\n", + " <td>0.425</td>\n", + " <td>0.125</td>\n", + " <td>16</td>\n", + " <td>10.0140320794</td>\n", + " <td>10.0140320794</td>\n", + " </tr>\n", + " <tr>\n", + " <td>9</td>\n", + " <td>M</td>\n", + " <td>0.475</td>\n", + " <td>0.37</td>\n", + " <td>0.125</td>\n", + " <td>9</td>\n", + " <td>10.0478657373</td>\n", + " <td>10.0478657373</td>\n", + " </tr>\n", + " <tr>\n", + " <td>10</td>\n", + " <td>F</td>\n", + " <td>0.55</td>\n", + " <td>0.44</td>\n", + " <td>0.15</td>\n", + " <td>19</td>\n", + " <td>10.0698224494</td>\n", + " <td>10.0698224494</td>\n", + " </tr>\n", + " <tr>\n", + " <td>11</td>\n", + " <td>F</td>\n", + " <td>0.525</td>\n", + " <td>0.38</td>\n", + " <td>0.14</td>\n", + " <td>14</td>\n", + " <td>10.1259635318</td>\n", + " <td>10.1259635318</td>\n", + " </tr>\n", + " <tr>\n", + " <td>12</td>\n", + " <td>M</td>\n", + " <td>0.43</td>\n", + " <td>0.35</td>\n", + " <td>0.11</td>\n", + " <td>10</td>\n", + " <td>9.97481060063</td>\n", + " <td>9.97481060063</td>\n", + " </tr>\n", + " <tr>\n", + " <td>13</td>\n", + " <td>M</td>\n", + " <td>0.49</td>\n", + " <td>0.38</td>\n", + " <td>0.135</td>\n", + " <td>11</td>\n", + " <td>10.0805427887</td>\n", + " <td>10.0805427887</td>\n", + " </tr>\n", + " <tr>\n", + " <td>14</td>\n", + " <td>F</td>\n", + " <td>0.535</td>\n", + " <td>0.405</td>\n", + " <td>0.145</td>\n", + " <td>10</td>\n", + " <td>10.107947317</td>\n", + " <td>10.107947317</td>\n", + " </tr>\n", + " <tr>\n", + " <td>15</td>\n", + " <td>F</td>\n", + " <td>0.47</td>\n", + " <td>0.355</td>\n", + " <td>0.1</td>\n", + " <td>10</td>\n", + " <td>9.97781238334</td>\n", + " <td>9.97781238334</td>\n", + " </tr>\n", + " <tr>\n", + " <td>16</td>\n", + " <td>M</td>\n", + " <td>0.5</td>\n", + " <td>0.4</td>\n", + " <td>0.13</td>\n", + " <td>12</td>\n", + " <td>10.0409088715</td>\n", + " <td>10.0409088715</td>\n", + " </tr>\n", + " <tr>\n", + " 
<td>17</td>\n", + " <td>I</td>\n", + " <td>0.355</td>\n", + " <td>0.28</td>\n", + " <td>0.085</td>\n", + " <td>7</td>\n", + " <td>9.8548093316</td>\n", + " <td>9.8548093316</td>\n", + " </tr>\n", + " <tr>\n", + " <td>18</td>\n", + " <td>F</td>\n", + " <td>0.44</td>\n", + " <td>0.34</td>\n", + " <td>0.1</td>\n", + " <td>10</td>\n", + " <td>9.96407219215</td>\n", + " <td>9.96407219215</td>\n", + " </tr>\n", + " <tr>\n", + " <td>19</td>\n", + " <td>M</td>\n", + " <td>0.365</td>\n", + " <td>0.295</td>\n", + " <td>0.08</td>\n", + " <td>7</td>\n", + " <td>9.83873423654</td>\n", + " <td>9.83873423654</td>\n", + " </tr>\n", + " <tr>\n", + " <td>20</td>\n", + " <td>M</td>\n", + " <td>0.45</td>\n", + " <td>0.32</td>\n", + " <td>0.1</td>\n", + " <td>9</td>\n", + " <td>10.0003544239</td>\n", + " <td>10.0003544239</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(1, u'M', 0.455, 0.365, 0.095, 15, 9.93020098079582, 9.93020098079582),\n", + " (2, u'M', 0.35, 0.265, 0.09, 7, 9.87712610207203, 9.87712610207203),\n", + " (3, u'F', 0.53, 0.42, 0.135, 9, 10.045981272917, 10.045981272917),\n", + " (4, u'M', 0.44, 0.365, 0.125, 10, 10.0184157770077, 10.0184157770077),\n", + " (5, u'I', 0.33, 0.255, 0.08, 7, 9.81382643976989, 9.81382643976989),\n", + " (6, u'I', 0.425, 0.3, 0.095, 8, 9.97372578299521, 9.97372578299521),\n", + " (7, u'F', 0.53, 0.415, 0.15, 20, 10.1032556037805, 10.1032556037805),\n", + " (8, u'F', 0.545, 0.425, 0.125, 16, 10.0140320794144, 10.0140320794144),\n", + " (9, u'M', 0.475, 0.37, 0.125, 9, 10.0478657373155, 10.0478657373155),\n", + " (10, u'F', 0.55, 0.44, 0.15, 19, 10.0698224493735, 10.0698224493735),\n", + " (11, u'F', 0.525, 0.38, 0.14, 14, 10.1259635317559, 10.1259635317559),\n", + " (12, u'M', 0.43, 0.35, 0.11, 10, 9.97481060062509, 9.97481060062509),\n", + " (13, u'M', 0.49, 0.38, 0.135, 11, 10.0805427887436, 10.0805427887436),\n", + " (14, u'F', 0.535, 0.405, 0.145, 10, 10.107947317027, 10.107947317027),\n", + " (15, u'F', 0.47, 0.355, 0.1, 
10, 9.97781238333585, 9.97781238333585),\n", + " (16, u'M', 0.5, 0.4, 0.13, 12, 10.0409088715201, 10.0409088715201),\n", + " (17, u'I', 0.355, 0.28, 0.085, 7, 9.85480933160473, 9.85480933160473),\n", + " (18, u'F', 0.44, 0.34, 0.1, 10, 9.96407219215287, 9.96407219215287),\n", + " (19, u'M', 0.365, 0.295, 0.08, 7, 9.83873423654298, 9.83873423654298),\n", + " (20, u'M', 0.45, 0.32, 0.1, 9, 10.0003544238551, 10.0003544238551)]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS abalone_gaussian_regr;\n", + "\n", + "SELECT madlib.svm_predict('abalone_svm_gaussian_regression', \n", + " 'abalone', \n", + " 'id', \n", + " 'abalone_gaussian_regr');\n", + "\n", + "SELECT * FROM abalone JOIN abalone_gaussian_regr USING (id) ORDER BY id;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compute the RMS error. Note this produces a more accurate result than the linear case for this small data set:" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>rms_error</th>\n", + " </tr>\n", + " <tr>\n", + " <td>3.84176368344</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(3.84176368343915,)]" + ] + }, + "execution_count": 63, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT SQRT(AVG((rings-prediction)*(rings-prediction))) as rms_error FROM abalone \n", + "JOIN abalone_gaussian_regr USING (id);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 6. 
Cross validation\n", + "Let's run cross validation for different initial step sizes and lambda values:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>coef</th>\n", + " <th>loss</th>\n", + " <th>norm_of_gradient</th>\n", + " <th>num_iterations</th>\n", + " <th>num_rows_processed</th>\n", + " <th>num_rows_skipped</th>\n", + " <th>dep_var_mapping</th>\n", + " </tr>\n", + " <tr>\n", + " <td>[4.49016341280977, 2.19062972461334, -2.04673653356154, 1.11216153651262, 2.83478599238881, -4.23122821845785, 4.17684533744501, -5.36892552740644, 0.775782561685621, -3.62606941016707]</td>\n", + " <td>2.63941855054</td>\n", + " <td>1.07622244533</td>\n", + " <td>163</td>\n", + " <td>20</td>\n", + " <td>0</td>\n", + " <td>[None]</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[([4.49016341280977, 2.19062972461334, -2.04673653356154, 1.11216153651262, 2.83478599238881, -4.23122821845785, 4.17684533744501, -5.36892552740644, 0.775782561685621, -3.62606941016707], 2.63941855054256, 1.07622244533275, 163, 20L, 0L, [None])]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS abalone_svm_gaussian_regression, abalone_svm_gaussian_regression_summary, \n", + "abalone_svm_gaussian_regression_random, abalone_svm_gaussian_regression_cv;\n", + "\n", + "SELECT madlib.svm_regression( 'abalone',\n", + " 'abalone_svm_gaussian_regression',\n", + " 'rings',\n", + " 'ARRAY[1, length, diameter, height]',\n", + " 'gaussian',\n", + " 'n_components=10',\n", + " '',\n", + " 'init_stepsize=[0.01,1], n_folds=3, max_iter=200, lambda=[0.01, 0.1, 0.5], validation_result=abalone_svm_gaussian_regression_cv'\n", + " );\n", + "\n", + "SELECT * FROM 
abalone_svm_gaussian_regression;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "View the summary table, which shows that the final model parameters are those that produced \n", + "the lowest error in the cross validation runs:" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>method</th>\n", + " <th>version_number</th>\n", + " <th>source_table</th>\n", + " <th>model_table</th>\n", + " <th>dependent_varname</th>\n", + " <th>independent_varname</th>\n", + " <th>kernel_func</th>\n", + " <th>kernel_params</th>\n", + " <th>grouping_col</th>\n", + " <th>optim_params</th>\n", + " <th>reg_params</th>\n", + " <th>num_all_groups</th>\n", + " <th>num_failed_groups</th>\n", + " <th>total_rows_processed</th>\n", + " <th>total_rows_skipped</th>\n", + " </tr>\n", + " <tr>\n", + " <td>SVR</td>\n", + " <td>1.15-dev</td>\n", + " <td>abalone</td>\n", + " <td>abalone_svm_gaussian_regression</td>\n", + " <td>rings</td>\n", + " <td>ARRAY[1, length, diameter, height]</td>\n", + " <td>gaussian</td>\n", + " <td>gamma=0.25, n_components=10,random_state=1, fit_intercept=False, fit_in_memory=True</td>\n", + " <td>NULL</td>\n", + " <td> init_stepsize=1.0,<br> decay_factor=0.9,<br> max_iter=200,<br> tolerance=1e-10,<br> epsilon=0.01,<br> eps_table=,<br> class_weight=<br> </td>\n", + " <td>lambda=0.01, norm=l2, n_folds=3</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>20</td>\n", + " <td>0</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(u'SVR', u'1.15-dev', u'abalone', u'abalone_svm_gaussian_regression', u'rings', u'ARRAY[1, length, diameter, height]', u'gaussian', u'gamma=0.25, n_components=10,random_state=1, fit_intercept=False, fit_in_memory=True', u'NULL', u' init_stepsize=1.0,\\n decay_factor=0.9,\\n max_iter=200,\\n tolerance=1e-10,\\n 
epsilon=0.01,\\n eps_table=,\\n class_weight=\\n ', u'lambda=0.01, norm=l2, n_folds=3', 1, 0, 20L, 0L)]" + ] + }, + "execution_count": 65, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%sql SELECT * FROM abalone_svm_gaussian_regression_summary;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "View the values for cross validation:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>init_stepsize</th>\n", + " <th>lambda</th>\n", + " <th>mean_score</th>\n", + " <th>std_dev_score</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1.0</td>\n", + " <td>0.01</td>\n", + " <td>-4.06711568585</td>\n", + " <td>0.435966381366</td>\n", + " </tr>\n", + " <tr>\n", + " <td>1.0</td>\n", + " <td>0.1</td>\n", + " <td>-4.08068428345</td>\n", + " <td>0.44660797513</td>\n", + " </tr>\n", + " <tr>\n", + " <td>1.0</td>\n", + " <td>0.5</td>\n", + " <td>-4.52576046087</td>\n", + " <td>0.20597876382</td>\n", + " </tr>\n", + " <tr>\n", + " <td>0.01</td>\n", + " <td>0.01</td>\n", + " <td>-11.0231044189</td>\n", + " <td>0.739956548721</td>\n", + " </tr>\n", + " <tr>\n", + " <td>0.01</td>\n", + " <td>0.1</td>\n", + " <td>-11.0244799274</td>\n", + " <td>0.740029346709</td>\n", + " </tr>\n", + " <tr>\n", + " <td>0.01</td>\n", + " <td>0.5</td>\n", + " <td>-11.0305445077</td>\n", + " <td>0.740350338532</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(Decimal('1.0'), Decimal('0.01'), Decimal('-4.06711568585'), Decimal('0.435966381366')),\n", + " (Decimal('1.0'), Decimal('0.1'), Decimal('-4.08068428345'), Decimal('0.44660797513')),\n", + " (Decimal('1.0'), Decimal('0.5'), Decimal('-4.52576046087'), Decimal('0.20597876382')),\n", + " (Decimal('0.01'), Decimal('0.01'), Decimal('-11.0231044189'), Decimal('0.739956548721')),\n", + " 
(Decimal('0.01'), Decimal('0.1'), Decimal('-11.0244799274'), Decimal('0.740029346709')),\n", + " (Decimal('0.01'), Decimal('0.5'), Decimal('-11.0305445077'), Decimal('0.740350338532'))]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT * FROM abalone_svm_gaussian_regression_cv;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 7. Predict using cross-validated Gaussian regression model:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>svm_predict</th>\n", + " </tr>\n", + " <tr>\n", + " <td></td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[('',)]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS abalone_gaussian_regr;\n", + "SELECT madlib.svm_predict('abalone_svm_gaussian_regression', \n", + " 'abalone', \n", + " 'id', \n", + " 'abalone_gaussian_regr');" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compute the RMS error. 
Note this is the same RMS error as in the previous run with the Gaussian kernel, since cross validation selected init_stepsize=1.0 and lambda=0.01 and the fitted coefficients are unchanged:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>rms_error</th>\n", + " </tr>\n", + " <tr>\n", + " <td>3.84176368344</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(3.84176368343915,)]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT SQRT(AVG((rings-prediction)*(rings-prediction))) as rms_error FROM abalone \n", + "JOIN abalone_gaussian_regr USING (id);" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": true + }, + "source": [ + "# Novelty detection \n", + "# 1. Train a non-linear one-class SVM\n", + "Train with a Gaussian kernel on the housing data set. Note that the dependent variable is not a parameter for one-class:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>coef</th>\n", + " <th>loss</th>\n", + " <th>norm_of_gradient</th>\n", + " <th>num_iterations</th>\n", + " <th>num_rows_processed</th>\n", + " <th>num_rows_skipped</th>\n", + " <th>dep_var_mapping</th>\n", + " </tr>\n", + " <tr>\n", + " <td>[0.0207901288823711, -0.00103437489314969, 0.00407820868429805, 0.0274910360546609, 0.0105696547048294, -0.00313332466259033, -0.0216703145014011, 0.0363248037825208, -0.0211400498166549, -0.00827402232219555, 0.0265909439934851, 0.0282462482323058, -0.0407407195393746, 0.0191290942177852, -0.00313542082923064, -0.0191740603622109, 0.0143626646548982, -0.0620527674181034, -0.000319831622794402, 0.00388104709972051, 
0.00248129433065678, 0.00764915273571186, 0.014492283562898, 0.0184730815984353, -0.00745840880633255, -0.0232208663374367, -0.010724056217189, 0.00541494627043399, 0.0150679846777238, 0.0204022414812525, -0.0294626167089617, -0.00399506510201406, -0.0231139983460727, 0.0242203153309423, -0.0421196963278802, 0.0112202149916885, -0.00720876723524249, 0.0213674589734111, -0.00260107056222295, -0.0130652059444514, 0.0710580616012718, 0.0519822855717347, 0.00961050532247376, 0.0390561950837254, -0.0152620688050253, 0.0100336750737295, 0.0632488712630204, -0.0549714494076944, -0.007684860916257, 0.0322104572263339, -0.00832311210931705, 0.0279669244721609, 0.0455147539995411, -0.0639670005155479, -0.00965055072583972, 0.00648588125681694]</td>\n", + " <td>0.944016313708</td>\n", + " <td>14.5271059047</td>\n", + " <td>100</td>\n", + " <td>16</td>\n", + " <td>-1</td>\n", + " <td>[-1.0, 1.0]</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[([0.0207901288823711, -0.00103437489314969, 0.00407820868429805, 0.0274910360546609, 0.0105696547048294, -0.00313332466259033, -0.0216703145014011, 0.0363248037825208, -0.0211400498166549, -0.00827402232219555, 0.0265909439934851, 0.0282462482323058, -0.0407407195393746, 0.0191290942177852, -0.00313542082923064, -0.0191740603622109, 0.0143626646548982, -0.0620527674181034, -0.000319831622794402, 0.00388104709972051, 0.00248129433065678, 0.00764915273571186, 0.014492283562898, 0.0184730815984353, -0.00745840880633255, -0.0232208663374367, -0.010724056217189, 0.00541494627043399, 0.0150679846777238, 0.0204022414812525, -0.0294626167089617, -0.00399506510201406, -0.0231139983460727, 0.0242203153309423, -0.0421196963278802, 0.0112202149916885, -0.00720876723524249, 0.0213674589734111, -0.00260107056222295, -0.0130652059444514, 0.0710580616012718, 0.0519822855717347, 0.00961050532247376, 0.0390561950837254, -0.0152620688050253, 0.0100336750737295, 0.0632488712630204, -0.0549714494076944, -0.007684860916257, 0.0322104572263339, 
-0.00832311210931705, 0.0279669244721609, 0.0455147539995411, -0.0639670005155479, -0.00965055072583972, 0.00648588125681694], 0.944016313708205, 14.5271059047443, 100, 16L, -1L, [-1.0, 1.0])]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS houses_one_class_gaussian, houses_one_class_gaussian_summary, houses_one_class_gaussian_random;\n", + "\n", + "SELECT madlib.svm_one_class('houses',\n", + " 'houses_one_class_gaussian',\n", + " 'ARRAY[1,tax,bedroom,bath,size,lot,price]',\n", + " 'gaussian',\n", + " 'gamma=0.5,n_components=55, random_state=3',\n", + " NULL,\n", + " 'max_iter=100, init_stepsize=10,lambda=10, tolerance=0'\n", + " );\n", + "\n", + "SELECT * FROM houses_one_class_gaussian;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2. Create test data\n", + "For the novelty detection using one-class, let's create a test data set using the last 3 values from the training set plus an outlier at the end (10x price):" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "Done.\n", + "4 rows affected.\n", + "4 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>tax</th>\n", + " <th>bedroom</th>\n", + " <th>bath</th>\n", + " <th>price</th>\n", + " <th>size</th>\n", + " <th>lot</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>3100</td>\n", + " <td>3</td>\n", + " <td>2.0</td>\n", + " <td>140000</td>\n", + " <td>1760</td>\n", + " <td>38000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>2070</td>\n", + " <td>2</td>\n", + " <td>3.0</td>\n", + " <td>148000</td>\n", + " <td>1550</td>\n", + " <td>14000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>650</td>\n", + " <td>3</td>\n", + " <td>1.5</td>\n", + " <td>65000</td>\n", + 
" <td>1450</td>\n", + " <td>12000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>65
<TRUNCATED>