{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# K nearest neighbors and cross-validation\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"For this practical session, we will work on the real data mnist_digits.mat (digits), that can be downloaded from the course web page.\n",
"\n",
"For classification problems with $K$ classes, we call the \"confusion matrix\" associated to data $D_n=(x_t,y_t)$ the matrix $M \\in \\mathbb{N}^{K \\times K}$ such that $M_{i,j}$ is the number of elements with true class $i$ and predicted class $j$.\n",
"\n",
"**NB**: Given that there are more than $66000$ images in the dataset, we only work on a subset of these $66000$ images so as to not go beyond the memory of your computer.\n",
"\n",
"**1) Start by getting acquainted with the data. They are composed of a vector of labels `y` and images of size 28x28, given in matrix `x` of linearized vectors (each line of the matrix `x` corresponds to a single image).**"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/vnd.plotly.v1+html": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import plotly\n",
"import numpy as np\n",
"import plotly.plotly as py\n",
"import plotly.graph_objs as go\n",
"plotly.offline.init_notebook_mode()\n",
"import matplotlib.pyplot as plt\n",
"import scipy.io as sio\n",
"from scipy import stats\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeEAAAHiCAYAAADf3nSgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAG3dJREFUeJzt3X+8bXVd5/H3R5AUNEippoDADEzGZLAbUpZQkSNaXpqHFY5kOiQzltZoIw9JI6CaMWKyMvrBIzRpEiMruxpmPpSieIiDSGOgMd4BjBsqIj9CDK83PvPH3tjheC5n373P5es59/l8PO6Ds/Ze37W+Z7u9r7vW2Wet6u4AAA+9h42eAADsqUQYAAYRYQAYRIQBYBARBoBBRBgABhFh2INU1e9W1c+PngcwIcJsaFV1U1WdMHoee4qq+suq6qo6atnjb5s+fvx0+azp8g8sWWfv6WOHTZcf8A+Gqjq1qv6+qu6uqk9W1Z9V1aOr6p1V9Znpn89X1fYly7/1kHzjMCcRBlZUVXvPOfT/JnnBku08NsmxST61bL3bk5xTVXvNMJfjkvz3JM/r7kcneWKSS5Kku0/s7kd196OS/H6Sc+9f7u7/Muf3AA8JEWaPUVUvrKorqup1VXVnVd1QVd82ffzmqrq1qn5kyfrPrqprquqfps+ftWx7L6iqj1XVp6vqZ5YedVfVw6rqVVX1/6bPX1JVj9nJvI6vqm1V9VPTOXy8ql605Pn9q+qiqvrUdH+vqaqHzfM9TR1YVe+eHlH+VVUdumRfXVU/XlUfTfLR6WPfOF3/9qq6vqp+cJWX+veT/NCSuD4vyZ8k2b5svT+fPnbKKttLkm9J8r7uviZJuvv27n5Td989w1j4kiXC7GmemuRDSR6b5M1J3pLJX/DfkEkMfr2qHjVd955MjugOSPLsJC+pqpOSpKqOTPIbSZ6f5GuS7J/koCX7+YkkJyU5LsnXJrkjyfkPMq9/s2QbpyY5v6q+Yvrc66fPff10ey9I8qIlY3fle8p0zj+X5MAkf5tJNJc6abrNI6tqvyTvnm73qzIJ6m9U1b99kO/lliQfTvKM6fILkly0wnqd5GeS/GxVPfxBtpck70/y76vq7Kp6WlV92Srrw7ogwuxpbuzuN3b3vyT5gySHJDmnuz/X3X+RyZHZNyRJd/9ld/9dd9/X3R9KcnEmEUyS5yZ5e3f/TXdvT3JmJlG5339O8uru3tbdn0tyVpLnPsgp3s9P5/H57r40yWeSPGF6NPlDSc7o7ru7+6Yk/zPJD8/zPU39WXdfPp3Xq5N8a1UdsuT5/zE90vznJN+b5Kbp9nd09weT/NH0+38wFyV5QVU9IckB3f2+lVbq7i2ZnKb+0QfbWHf/dZL/kOQpSf4syaer6pdnOZUNX8rm/ZkPrFefXPL1PydJdy9/7FFJUlVPTfLaJE9Ksk+SL0vyh9P1vjbJzfcP6u7PVtWnl2zn0CR/UlX3LXnsX5J8dZJ/XGFen+7uHUuWPzudx4HTfX9syXMfywOPumf+nqaWzvszVXX7su/n5iXrHprkqVV155LH9k7yeyt8D0v9cSb/WPj0DOu+JskbV1uvu9+Z5J3TU/Hfmcn/Ftcn+e1Vtg9fskQYdu7NSX49yYndfW9V/UomUUySjyd5wv0rVtUjMzkdfL+bk/yn7r5iwTnclslR8qGZnOJNkq/LyiGf1ReOeqenqR+TySnk+y09or85yV919/fsyg6m/yh5Z5KXJHn8Kuu+u6q2JvmxGbd9X5L3VNV7M/kHEqxbTkfDzj06ye3TAB+T5D8uee6tSb5v+iGofZKcnaSWPP9bSX7h/g89VdVXVtXmXZ3A9BTzJdNtPXq6vVck+V/zfUtJkmdV1bdP5/1zSd7f3TfvZN13JDmiqn64qh4+/fMtVfXEGfbz00mOm55CX82rk5y+syeranNVnVxVX1ETx2Tyo4ErZ9g2fMkSYdi5H8
vkV2juzuRnvpfc/0R3X5fkZZl8COrjSe5OcmuSz01X+dUkW5L8xXT8lZl82GkeL8vkQ2I3JPmbTI7Q3zDntjId/7OZ/IrQN2fyQa0VTT99/IwkJ2dytPyJJL+Yyan5B9Xdt3T338wyoekZg//9IKvckeTFmXxi+58y+UfIL3X38g+VwbpS3b36WsCDmp7WvTPJ4d194+j5AOuDI2GYU1V9X1XtO/01nvOS/F2Sm8bOClhPVo1wVb1h+gv/1+7k+aqqX6uqrVX1oap6ytpPE74kbc7kFO0tSQ5PcnI7tQTsglVPR1fV0zP5ncWLuvuLPolYVc/K5GdWz8rkZ16/2t3z/uwLAPYYqx4Jd/flmXyAY2c2ZxLo7u4rkxxQVV+zVhMEgI1qLX4mfFAe+Mv92/LACwkAACtYi4t11AqPrXiOu6pOS3LadPGb12DfADDSbd39lfMOXosIb8uSK/AkOTgPvPrOF3T3BUkuSCZ3a1mDfQPASB9bfZWdW4vT0VsyuVB7VdWxSe7q7o+vwXYBYENb9Ui4qi5Ocnwm9yDdlsmVdh6eJN39W0kuzeST0Vszuej8i1beEgCw1LArZjkdDcAGcHV3b5p3sCtmAcAgIgwAg4gwAAwiwgAwiAgDwCAiDACDiDAADCLCADCICAPAICIMAIOIMAAMIsIAMIgIA8AgIgwAg4gwAAwiwgAwiAgDwCAiDACDiDAADCLCADCICAPAICIMAIOIMAAMIsIAMIgIA8AgIgwAg4gwAAwiwgAwiAgDwCAiDACDiDAADCLCADCICAPAICIMAIOIMAAMIsIAMIgIA8AgIgwAg4gwAAwiwgAwiAgDwCAiDACDiDAADCLCADCICAPAICIMAIOIMAAMIsIAMIgIA8AgIgwAg4gwAAwiwgAwiAgDwCAiDACDiDAADCLCADCICAPAICIMAIOIMAAMIsIAMIgIA8AgIgwAg4gwAAwiwgAwiAgDwCAiDACDiDAADCLCADCICAPAICIMAIOIMAAMIsIAMIgIA8Age4+eAOxuD3vYYv/WfP7znz/32COOOGKhfS/isMMOW2j8scceu9D4s88+e+6xb37zmxfa93333bfQeHioOBIGgEFEGAAGEWEAGESEAWCQmSJcVc+squuramtVvWqF57+uqi6rqmuq6kNV9ay1nyoAbCyrRriq9kpyfpITkxyZ5HlVdeSy1V6T5JLuPjrJyUl+Y60nCgAbzSxHwsck2drdN3T39iRvSbJ52Tqd5MunX++f5Ja1myIAbEyz/J7wQUluXrK8LclTl61zVpK/qKqXJdkvyQlrMjsA2MBmORKuFR7rZcvPS/K73X1wkmcl+b2q+qJtV9VpVfWBqvrArk8VADaWWSK8LckhS5YPzhefbj41ySVJ0t3vS/KIJAcu31B3X9Ddm7p703zTBYCNY5YIX5Xk8Kp6XFXtk8kHr7YsW+cfknx3klTVEzOJ8KfWcqIAsNGsGuHu3pHkpUneleQjmXwK+rqqOqeqnjNd7aeSvLiq/k+Si5O8sLuXn7IGAJaY6QYO3X1pkkuXPXbmkq8/nORpazs1ANjYXDELAAapUWeNq8rpama2yO0IX/GKVyy073PPPXeh8evV9u3bFxq/zz77zD32mGOOWWjfH/zgB+ce6zaI7KKrF/mwsSNhABhEhAFgEBEGgEFEGAAGEWEAGESEAWAQEQaAQUQYAAYRYQAYRIQBYBARBoBBRBgABhFhABhEhAFgEBEGgEHcT5h14aijjpp77DXXXLOGM9k1d99990Lj3/jGN8499ulPf/pC+37b29620PizzjprofGLOOKII+Yeu3Xr1jWcCXsA9xMGgPVIhAFgEBEGgEFEGAAGEWEAGESEAWAQEQaAQUQYAAYRYQAYRIQBYBARBoBBRBgABhFhABhEhAFgkL1HT4A9wymnnLLQ+Ne85jVzj73zzjsX2vdFF10099jXve51C+37Yx/72ELjF/HsZz972L7vvffehcbv2LFjjWYCu5
cjYQAYRIQBYBARBoBBRBgABhFhABhEhAFgEBEGgEFEGAAGEWEAGESEAWAQEQaAQUQYAAYRYQAYRIQBYBARBoBB3E+Yh8Q3fdM3LTT+nnvumXvspk2bFtr3DTfcsND49eqAAw4Ytu8Xv/jFC42/6aab1mYisJs5EgaAQUQYAAYRYQAYRIQBYBARBoBBRBgABhFhABhEhAFgEBEGgEFEGAAGEWEAGESEAWAQEQaAQUQYAAap7h6z46oxO4Y9xH777bfQ+CuuuGKh8Y9//OPnHnv00UcvtO+tW7cuNB52wdXdPff9Uh0JA8AgIgwAg4gwAAwiwgAwiAgDwCAiDACDiDAADCLCADCICAPAICIMAIOIMAAMIsIAMIgIA8AgIgwAg8wU4ap6ZlVdX1Vbq+pVO1nnB6vqw1V1XVW9eW2nCQAbz96rrVBVeyU5P8n3JNmW5Kqq2tLdH16yzuFJzkjytO6+o6q+andNGJjNeeedt9D4Jz/5yQuNf8973jP32G3bti20b1gvZjkSPibJ1u6+obu3J3lLks3L1nlxkvO7+44k6e5b13aaALDxzBLhg5LcvGR52/SxpY5IckRVXVFVV1bVM9dqggCwUa16OjpJrfBYr7Cdw5Mcn+TgJH9dVU/q7jsfsKGq05KcNsc8AWDDmeVIeFuSQ5YsH5zklhXW+dPu/nx335jk+kyi/ADdfUF3b+ruTfNOGAA2ilkifFWSw6vqcVW1T5KTk2xZts7bknxnklTVgZmcnr5hLScKABvNqhHu7h1JXprkXUk+kuSS7r6uqs6pqudMV3tXkk9X1YeTXJbkld396d01aQDYCGb5mXC6+9Ikly577MwlX3eSV0z/AAAzcMUsABhEhAFgEBEGgEFEGAAGEWEAGESEAWAQEQaAQWryK74Ddlw1Zsewi/bdd9+5x55wwgkL7fuFL3zh3GO/4zu+Y6F9P/axj11o/CLe+973LjT+ta997dxjL7/88oX2vX379oXGs+5cvcilmB0JA8AgIgwAg4gwAAwiwgAwiAgDwCAiDACDiDAADCLCADCICAPAICIMAIOIMAAMIsIAMIgIA8AgIgwAg4gwAAzifsJseE972tMWGr/IvWkX3fdI11577ULjP/OZz8w99qijjlpo34985CPnHnvZZZcttO/nPve5c4+94447Fto3Q7ifMACsRyIMAIOIMAAMIsIAMIgIA8AgIgwAg4gwAAwiwgAwiAgDwCAiDACDiDAADCLCADCICAPAICIMAIPsPXoCsLtt3rx5ofGL3I5w0dsBvv71r5977I033rjQvq+66qqFxt91111zjz322GMX2vdJJ50099jTTz99oX2feuqpc48977zzFto3648jYQAYRIQBYBARBoBBRBgABhFhABhEhAFgEBEGgEFEGAAGEWEAGESEAWAQEQaAQUQYAAYRYQAYRIQBYBARBoBB3E+YDe/ss89eaPw73vGOuccuej/h22+/faHx69WVV1650PgjjjhijWay6w499NBh+2b9cSQMAIOIMAAMIsIAMIgIA8AgIgwAg4gwAAwiwgAwiAgDwCAiDACDiDAADCLCADCICAPAICIMAIOIMAAMIsIAMIj7CbPh3XPPPQuNv/zyy9doJgAP5EgYAAYRYQAYRIQBYBARBoBBRBgABhFhABhEhAFgkJkiXFXPrKrrq2prVb3qQdZ7blV1VW1auykCwMa0aoSraq8k5yc5McmRSZ5XVUeusN6jk/xEkvev9SQBYCOa5Uj4mCRbu/uG7t6e5C1JNq+w3s8lOTfJvWs4PwDYsGaJ8EFJbl6yvG362BdU1dFJDunud6zh3ABgQ5vl2tG1wmP9hSerHpbkdUleuOqGqk5LctqskwOAjWyWI+FtSQ5ZsnxwkluWLD86yZOS/GVV3ZTk2CRbVvpwVndf0N2butsHtwDY480S4auSHF5Vj6uqfZKcnGTL/U92913dfWB3H9bdhyW5MslzuvsDu2XGALBBrHo6urt3VNVLk7
wryV5J3tDd11XVOUk+0N1bHnwLAHuOCy+8cPQUWEdmup9wd1+a5NJlj525k3WPX3xaALDxuWIWAAwiwgAwiAgDwCAiDACDiDAADCLCADCICAPAICIMAIOIMAAMIsIAMIgIA8AgIgwAg4gwAAwiwgAwyEy3MgRYT4488shh+/7EJz4xbN+sP46EAWAQEQaAQUQYAAYRYQAYRIQBYBARBoBBRBgABhFhABhEhAFgEBEGgEFEGAAGEWEAGESEAWAQEQaAQdzKEPiSc8YZZyw0/uUvf/ncY88888yF9n3bbbctNJ49iyNhABhEhAFgEBEGgEFEGAAGEWEAGESEAWAQEQaAQUQYAAYRYQAYRIQBYBARBoBBRBgABhFhABhEhAFgEBEGgEGqu8fsuGrMjpnbvvvuO/fYQw45ZKF9X3/99QuN56F32GGHzT32iiuuWGjf73vf++Yee/LJJy+07x07diw0nnXn6u7eNO9gR8IAMIgIA8AgIgwAg4gwAAwiwgAwiAgDwCAiDACDiDAADCLCADCICAPAICIMAIOIMAAMIsIAMIgIA8Age4+eAOvHWWedNffYRW9F6FaGD70zzjhjofGvfOUr5x572223LbTvt7/97XOPdStCHkqOhAFgEBEGgEFEGAAGEWEAGESEAWAQEQaAQUQYAAYRYQAYRIQBYBARBoBBRBgABhFhABhEhAFgEBEGgEFEGAAGcT/hPci+++670Pjjjjtu7rH777//Qvu+8MILFxq/Xh188MFzjz399NMX2vdLXvKShcbfdNNNc499xjOeMWzf8FByJAwAg8wU4ap6ZlVdX1Vbq+pVKzz/iqr6cFV9qKreU1WHrv1UAWBjWTXCVbVXkvOTnJjkyCTPq6ojl612TZJN3f3kJG9Ncu5aTxQANppZjoSPSbK1u2/o7u1J3pJk89IVuvuy7v7sdPHKJPP/IAsA9hCzRPigJDcvWd42fWxnTk3yzkUmBQB7glk+HV0rPNYrrlh1SpJNSVb8GG1VnZbktJlnBwAb2CwR3pbkkCXLBye5ZflKVXVCklcnOa67P7fShrr7giQXTNdfMeQAsKeY5XT0VUkOr6rHVdU+SU5OsmXpClV1dJLfTvKc7r517acJABvPqhHu7h1JXprkXUk+kuSS7r6uqs6pqudMV/ulJI9K8odV9bdVtWUnmwMApma6YlZ3X5rk0mWPnbnk6xPWeF4AsOG5YhYADCLCADCICAPAICIMAINU95hf1/V7wuvPKaecMvfYN73pTQvt+7Of/ezqK21Ae+2119xjH/GIRyy077POOmuh8eeeO/8l5O+9996F9g0Poau7e9O8gx0JA8AgIgwAg4gwAAwiwgAwiAgDwCAiDACDiDAADCLCADCICAPAICIMAIOIMAAMIsIAMIgIA8AgIgwAg4gwAAzifsLMrKrmHnvggQcutO+Xvexlc4/9/u///oX2/aQnPWnusRdffPFC+966devcY2+88caF9r3oPaDvu+++hcbDOuF+wgCwHokwAAwiwgAwiAgDwCAiDACDiDAADCLCADCICAPAICIMAIOIMAAMIsIAMIgIA8AgIgwAg4gwAAziVoYAMD+3MgSA9UiEAWAQEQaAQUQYAAYRYQAYRIQBYBARBoBBRBgABhFhABhEhAFgEBEGgEFEGAAGEWEAGESEAWAQEQaAQUQYAAYRYQAYRIQBYBARBoBBRBgABhFhABhEhAFgEBEGgEFEGAAGEWEAGESEAWAQEQaAQUQYAAYRYQAYRIQBYBARBoBBRBgABhFhABhEhAFgEBEGgEFEGAAGEWEAGESEAWAQEQaAQUQYAAYRYQAYRIQBYJCZIlxVz6yq66tqa1W9aoXnv6yq/mD6/Pur6rC1nigAbDSrRriq9kpyfpITkxyZ5HlVdeSy1U5Nckd3f0OS1yX5xbWeKABsNLMcCR+TZGt339Dd25O8JcnmZetsTvKm6ddvTfLdVVVrN00A2HhmifBBSW5esrxt+tiK63T3ji
R3JXnsWkwQADaqvWdYZ6Uj2p5jnVTVaUlOmy5+Lsm1M+yfXXdgkttGT2KD8truPl7b3cdru/s8YZHBs0R4W5JDliwfnOSWnayzrar2TrJ/ktuXb6i7L0hyQZJU1Qe6e9M8k+bBeW13H6/t7uO13X28trtPVX1gkfGznI6+KsnhVfW4qtonyclJtixbZ0uSH5l+/dwk7+3uLzoSBgD+1apHwt29o6pemuRdSfZK8obuvq6qzknyge7ekuTCJL9XVVszOQI+eXdOGgA2gllOR6e7L01y6bLHzlzy9b1JfmAX933BLq7P7Ly2u4/Xdvfx2u4+XtvdZ6HXtpw1BoAxXLYSAAYZEuHVLoPJbKrqkKq6rKo+UlXXVdVPTh9/TFW9u6o+Ov3vV4ye63pVVXtV1TVV9Y7p8uOml2b96PRSrfuMnuN6VFUHVNVbq+rvp+/fb/W+XRtV9fLp3wfXVtXFVfUI79v5VdUbqurWqrp2yWMrvldr4tembftQVT1lte0/5BGe8TKYzGZHkp/q7icmOTbJj09fy1cleU93H57kPdNl5vOTST6yZPkXk7xu+trekcklW9l1v5rkz7v7G5Mclclr7H27oKo6KMlPJNnU3U/K5MO0J8f7dhG/m+SZyx7b2Xv1xCSHT/+cluQ3V9v4iCPhWS6DyQy6++Pd/cHp13dn8hfZQXngZUTflOSkMTNc36rq4CTPTvI70+VK8l2ZXJo18drOpaq+PMnTM/mtinT39u6+M963a2XvJI+cXrNh3yQfj/ft3Lr78nzxdS929l7dnOSinrgyyQFV9TUPtv0REZ7lMpjsoumdq45O8v4kX93dH08moU7yVeNmtq79SpLTk9w3XX5skjunl2ZNvHfn9fVJPpXkjdNT/b9TVfvF+3Zh3f2PSc5L8g+ZxPeuJFfH+3at7ey9ust9GxHhmS5xyeyq6lFJ/ijJf+3ufxo9n42gqr43ya3dffXSh1dY1Xt31+2d5ClJfrO7j05yT5x6XhPTn01uTvK4JF+bZL9MTpEu5327e+zy3xEjIjzLZTCZUVU9PJMA/353//H04U/efwpk+t9bR81vHXtakudU1U2Z/MjkuzI5Mj5gepov8d6d17Yk27r7/dPlt2YSZe/bxZ2Q5Mbu/lR3fz7JHyf5tnjfrrWdvVd3uW8jIjzLZTCZwfRnlBcm+Uh3//KSp5ZeRvRHkvzpQz239a67z+jug7v7sEzeo+/t7ucnuSyTS7MmXtu5dPcnktxcVfdf+P67k3w43rdr4R+SHFtV+07/frj/tfW+XVs7e69uSfKC6aekj01y1/2nrXdmyMU6qupZmRxV3H8ZzF94yCexAVTVtyf56yR/l3/9ueVPZ/Jz4UuSfF0m/6f8ge7+ohtqMJuqOj7Jf+vu762qr8/kyPgxSa5Jckp3f27k/Najqvp3mXzgbZ8kNyR5USYHBd63C6qqs5P8UCa/PXFNkh/N5OeS3rdzqKqLkxyfyZ2oPpnkZ5O8LSu8V6f/8Pn1TD5N/dkkL+ruB73BgytmAcAgrpgFAIOIMAAMIsIAMIgIA8AgIgwAg4gwAAwiwgAwiAgDwCD/H6Oii4nNhzgVAAAAAElFTkSuQmCC\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Chargement des données\n",
"data = sio.loadmat('mnist_digits.mat')\n",
"X_total = np.array(data['x'])\n",
"Y_total = np.array(data['y'])\n",
"n_total = X_total.shape[0]\n",
"\n",
"# Affichage d'une image\n",
"ind_im = 20000\n",
"im = X_total[ind_im,:].reshape((28, 28))\n",
"\n",
"fig, (ax1) = plt.subplots(nrows=1, figsize=(8,8))\n",
"ax1.set_title('Image nombre MNIST')\n",
"ax1.imshow(im, extent=[0,100,0,1], aspect=100, cmap='gray')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**b) Select $6000$ images at random in the dataset.** \n",
"\n",
"**c) Split the images into two parts (with proportions $1/3,2/3$ for example) : a training set and a testing set.**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
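{
"cell_type": "markdown",
"metadata": {},
"source": [
"Steps b) and c) can be sketched as follows. This is a minimal illustration rather than the course solution: the helper name `subsample_and_split` and its parameters are hypothetical, and a random array stands in for `X_total` and `Y_total` so that the snippet runs without the data file. With the real data, the labels would presumably need flattening first, for example with `Y_total.ravel()`, since `loadmat` returns 2-D arrays.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def subsample_and_split(X, y, n_sub=6000, train_frac=1 / 3, seed=0):\n",
"    # Hypothetical helper: draw n_sub distinct rows at random,\n",
"    # then cut them into a training part and a test part.\n",
"    rng = np.random.default_rng(seed)\n",
"    idx = rng.choice(X.shape[0], size=n_sub, replace=False)\n",
"    n_train = int(round(train_frac * n_sub))\n",
"    train_idx, test_idx = idx[:n_train], idx[n_train:]\n",
"    return X[train_idx], y[train_idx], X[test_idx], y[test_idx]\n",
"\n",
"# Synthetic stand-in for the MNIST arrays (assumption: one flattened image per row).\n",
"X_demo = np.random.default_rng(1).random((100, 784))\n",
"y_demo = np.arange(100) % 10\n",
"X_tr, y_tr, X_te, y_te = subsample_and_split(X_demo, y_demo, n_sub=60)\n",
"print(X_tr.shape, X_te.shape)\n",
"```\n",
"\n",
"Sampling without replacement guarantees that no image appears in both the training and the test part."
]
},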
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**2) We will now implement the nearest neighbor classification rule. For this, you may use the function `cdist` from the module `scipy.spatial.distance`. This function allows, given two design matrices to compute all squared Euclidean distances between points.**\n",
"\n",
"**a) Write a function that takes as inputs the number $K$ of desired nearest neighbors, the training data, the testing data, and outputs the confusion matrix on the test data.**\n",
"\n",
"For the confusion matrix, you can use the commande `confusion_matrix` from the module `sklearn.metrics`."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"from scipy.spatial.distance import cdist\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix\n"
]
},
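{
"cell_type": "markdown",
"metadata": {},
"source": [
"One possible shape for the function asked for in a) is sketched below, under stated assumptions: the names `knn_predict` and `knn_confusion_matrix` are hypothetical, prediction is a plain majority vote among the $K$ nearest training points, and two synthetic Gaussian blobs replace the digit images. Unlike the signature described in a), the sketch passes the test labels as an explicit argument.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.spatial.distance import cdist\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"def knn_predict(k, X_train, y_train, X_test):\n",
"    # Map labels to 0..C-1 so that np.bincount can count votes.\n",
"    classes, y_idx = np.unique(np.ravel(y_train), return_inverse=True)\n",
"    # Pairwise squared Euclidean distances, shape (n_test, n_train).\n",
"    dists = cdist(X_test, X_train, metric='sqeuclidean')\n",
"    # Indices of the k closest training points for each test point.\n",
"    nearest = np.argsort(dists, axis=1)[:, :k]\n",
"    votes = np.array([np.bincount(row, minlength=len(classes))\n",
"                      for row in y_idx[nearest]])\n",
"    # argmax breaks ties in favor of the smallest class label.\n",
"    return classes[np.argmax(votes, axis=1)]\n",
"\n",
"def knn_confusion_matrix(k, X_train, y_train, X_test, y_test):\n",
"    # Rows index the true class, columns the predicted class.\n",
"    y_pred = knn_predict(k, X_train, y_train, X_test)\n",
"    return confusion_matrix(np.ravel(y_test), y_pred)\n",
"\n",
"# Synthetic, well-separated two-class data (assumption, not the MNIST digits).\n",
"rng = np.random.default_rng(0)\n",
"X_tr = np.vstack([rng.normal(0, 1, (30, 5)), rng.normal(10, 1, (30, 5))])\n",
"y_tr = np.repeat([0, 1], 30)\n",
"X_te = np.vstack([rng.normal(0, 1, (10, 5)), rng.normal(10, 1, (10, 5))])\n",
"y_te = np.repeat([0, 1], 10)\n",
"M = knn_confusion_matrix(3, X_tr, y_tr, X_te, y_te)\n",
"print(M)\n",
"```\n",
"\n",
"The classification error can be read off the matrix as one minus the ratio of its trace to the sum of its entries."
]
},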
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**b) Display the classification error on the training and testing data as a function of $k$ (number of nearest neighbors) (note that the implicit complexity / smoothness of the learned function is decreasing with the number of nearest neighbors). You can vary $k$ from 1 to 20 for example.**"
]
},
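{
"cell_type": "markdown",
"metadata": {},
"source": [
"The error curves can be computed along the following lines. This is only a sketch: `knn_error_curve` is a hypothetical name, the vote is a simple majority with ties going to the smallest label, and overlapping synthetic blobs stand in for the digits. Sorting the neighbors once and reusing the first columns avoids recomputing distances for every number of neighbors.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.spatial.distance import cdist\n",
"\n",
"def knn_error_curve(ks, X_train, y_train, X_eval, y_eval):\n",
"    # Misclassification rate on (X_eval, y_eval) for each k in ks.\n",
"    classes, y_idx = np.unique(np.ravel(y_train), return_inverse=True)\n",
"    # Sort the neighbors once; each k then reuses the first k columns.\n",
"    order = np.argsort(cdist(X_eval, X_train, metric='sqeuclidean'), axis=1)\n",
"    errors = []\n",
"    for k in ks:\n",
"        votes = np.array([np.bincount(row, minlength=len(classes))\n",
"                          for row in y_idx[order[:, :k]]])\n",
"        y_pred = classes[np.argmax(votes, axis=1)]\n",
"        errors.append(np.mean(y_pred != np.ravel(y_eval)))\n",
"    return np.array(errors)\n",
"\n",
"# Synthetic overlapping two-class data (assumption, not the MNIST digits).\n",
"rng = np.random.default_rng(0)\n",
"X = np.vstack([rng.normal(0, 1, (100, 2)), rng.normal(1.5, 1, (100, 2))])\n",
"y = np.repeat([0, 1], 100)\n",
"perm = rng.permutation(200)\n",
"X_tr, y_tr = X[perm[:100]], y[perm[:100]]\n",
"X_te, y_te = X[perm[100:]], y[perm[100:]]\n",
"\n",
"ks = np.arange(1, 21)\n",
"train_err = knn_error_curve(ks, X_tr, y_tr, X_tr, y_tr)\n",
"test_err = knn_error_curve(ks, X_tr, y_tr, X_te, y_te)\n",
"print(train_err[0], test_err.min())\n",
"```\n",
"\n",
"With one neighbor the training error is zero as long as no two training points coincide, because each point is its own nearest neighbor. With the notebook's imports, `plt.plot(ks, train_err)` and `plt.plot(ks, test_err)` would draw the two curves."
]
},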
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**c) Split your training data into a reduced training data and a validation set (this technique is called simple validation). By using the previous code, write a function that will select the best $K$ by using the validation set. **"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**d) Split several times randomly the original training data. Is the estimator of $K$ stable?**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
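{
"cell_type": "markdown",
"metadata": {},
"source": [
"Simple validation, and the stability check of d), might look like the sketch below. To keep it short, scikit-learn's `KNeighborsClassifier` is swapped in for the hand-written nearest neighbor rule; the function name `select_k_by_validation`, the 30% validation fraction and the synthetic data are all assumptions made for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"def select_k_by_validation(X_train, y_train, ks, val_frac=0.3, seed=0):\n",
"    # Hold out a random validation subset of the training data.\n",
"    rng = np.random.default_rng(seed)\n",
"    perm = rng.permutation(X_train.shape[0])\n",
"    n_val = int(round(val_frac * len(perm)))\n",
"    val_idx, fit_idx = perm[:n_val], perm[n_val:]\n",
"    val_errors = []\n",
"    for k in ks:\n",
"        clf = KNeighborsClassifier(n_neighbors=k)\n",
"        clf.fit(X_train[fit_idx], y_train[fit_idx])\n",
"        # score() returns the accuracy, so the error is its complement.\n",
"        val_errors.append(1.0 - clf.score(X_train[val_idx], y_train[val_idx]))\n",
"    val_errors = np.array(val_errors)\n",
"    # argmin returns the first minimum, so ties go to the earliest k in ks.\n",
"    return ks[int(np.argmin(val_errors))], val_errors\n",
"\n",
"# Synthetic overlapping two-class data (assumption, not the MNIST digits).\n",
"rng = np.random.default_rng(0)\n",
"X = np.vstack([rng.normal(0, 1, (150, 2)), rng.normal(1.5, 1, (150, 2))])\n",
"y = np.repeat([0, 1], 150)\n",
"ks = list(range(1, 21))\n",
"\n",
"# Repeating the split with different seeds gives a feel for the stability.\n",
"chosen = [select_k_by_validation(X, y, ks, seed=s)[0] for s in range(5)]\n",
"print(chosen)\n",
"```\n",
"\n",
"Each seed produces a different split, so comparing the entries of `chosen` is one way to see how much the selected value depends on the particular validation set."
]
},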
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**3) We now wish to select $K$ by cross-validation. Implement $V$-fold cross-validation for $V=8$ and select the best $K$. Vary randomly the split of the data into the folds and look at the behavior of the selected $K$. What do you notice?**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
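,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A possible sketch of $V$-fold cross-validation follows. The folds are built by hand from a random permutation, while scikit-learn's `KNeighborsClassifier` is swapped in for the hand-written classifier to keep the example short. The name `select_k_by_cv` and the synthetic data are assumptions made for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"def select_k_by_cv(X, y, ks, n_folds=8, seed=0):\n",
"    # Shuffle the indices and cut them into n_folds nearly equal folds.\n",
"    rng = np.random.default_rng(seed)\n",
"    folds = np.array_split(rng.permutation(X.shape[0]), n_folds)\n",
"    cv_errors = np.zeros(len(ks))\n",
"    for v in range(n_folds):\n",
"        # Fold v is held out; the other folds are used for fitting.\n",
"        val_idx = folds[v]\n",
"        fit_idx = np.concatenate([folds[u] for u in range(n_folds) if u != v])\n",
"        for i, k in enumerate(ks):\n",
"            clf = KNeighborsClassifier(n_neighbors=k)\n",
"            clf.fit(X[fit_idx], y[fit_idx])\n",
"            cv_errors[i] += 1.0 - clf.score(X[val_idx], y[val_idx])\n",
"    # Average the held-out error over the folds.\n",
"    cv_errors /= n_folds\n",
"    return ks[int(np.argmin(cv_errors))], cv_errors\n",
"\n",
"# Synthetic overlapping two-class data (assumption, not the MNIST digits).\n",
"rng = np.random.default_rng(0)\n",
"X = np.vstack([rng.normal(0, 1, (160, 2)), rng.normal(1.5, 1, (160, 2))])\n",
"y = np.repeat([0, 1], 160)\n",
"ks = list(range(1, 21))\n",
"\n",
"best_k, cv_errors = select_k_by_cv(X, y, ks)\n",
"# Re-drawing the folds with other seeds shows how much the selected K moves.\n",
"chosen = [select_k_by_cv(X, y, ks, seed=s)[0] for s in range(3)]\n",
"print(best_k, chosen)\n",
"```\n",
"\n",
"Every point is held out exactly once, so each averaged error uses all of the data, whereas a single validation split only uses the held-out part."
]
}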
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}