wine_ml_project/Untitled.ipynb
User Name 2a7969d0d5 init
2025-06-07 23:27:56 +02:00

955 lines
36 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "aaad2c40-c784-466b-9f49-7a60e1cbdb4c",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"import joblib\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from sklearn import metrics\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score\n",
"from sklearn.naive_bayes import GaussianNB\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from sklearn.svm import SVC\n",
"from sklearn.tree import DecisionTreeClassifier\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "23f82a60-419e-4af2-9401-1996bec33332",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>fixed acidity</th>\n",
" <th>volatile acidity</th>\n",
" <th>citric acid</th>\n",
" <th>residual sugar</th>\n",
" <th>chlorides</th>\n",
" <th>free sulfur dioxide</th>\n",
" <th>total sulfur dioxide</th>\n",
" <th>density</th>\n",
" <th>pH</th>\n",
" <th>sulphates</th>\n",
" <th>alcohol</th>\n",
" <th>quality</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>7.4</td>\n",
" <td>0.70</td>\n",
" <td>0.00</td>\n",
" <td>1.9</td>\n",
" <td>0.076</td>\n",
" <td>11.0</td>\n",
" <td>34.0</td>\n",
" <td>0.9978</td>\n",
" <td>3.51</td>\n",
" <td>0.56</td>\n",
" <td>9.4</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>7.8</td>\n",
" <td>0.88</td>\n",
" <td>0.00</td>\n",
" <td>2.6</td>\n",
" <td>0.098</td>\n",
" <td>25.0</td>\n",
" <td>67.0</td>\n",
" <td>0.9968</td>\n",
" <td>3.20</td>\n",
" <td>0.68</td>\n",
" <td>9.8</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>7.8</td>\n",
" <td>0.76</td>\n",
" <td>0.04</td>\n",
" <td>2.3</td>\n",
" <td>0.092</td>\n",
" <td>15.0</td>\n",
" <td>54.0</td>\n",
" <td>0.9970</td>\n",
" <td>3.26</td>\n",
" <td>0.65</td>\n",
" <td>9.8</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>11.2</td>\n",
" <td>0.28</td>\n",
" <td>0.56</td>\n",
" <td>1.9</td>\n",
" <td>0.075</td>\n",
" <td>17.0</td>\n",
" <td>60.0</td>\n",
" <td>0.9980</td>\n",
" <td>3.16</td>\n",
" <td>0.58</td>\n",
" <td>9.8</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>7.4</td>\n",
" <td>0.70</td>\n",
" <td>0.00</td>\n",
" <td>1.9</td>\n",
" <td>0.076</td>\n",
" <td>11.0</td>\n",
" <td>34.0</td>\n",
" <td>0.9978</td>\n",
" <td>3.51</td>\n",
" <td>0.56</td>\n",
" <td>9.4</td>\n",
" <td>5</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" fixed acidity volatile acidity citric acid residual sugar chlorides \\\n",
"0 7.4 0.70 0.00 1.9 0.076 \n",
"1 7.8 0.88 0.00 2.6 0.098 \n",
"2 7.8 0.76 0.04 2.3 0.092 \n",
"3 11.2 0.28 0.56 1.9 0.075 \n",
"4 7.4 0.70 0.00 1.9 0.076 \n",
"\n",
" free sulfur dioxide total sulfur dioxide density pH sulphates \\\n",
"0 11.0 34.0 0.9978 3.51 0.56 \n",
"1 25.0 67.0 0.9968 3.20 0.68 \n",
"2 15.0 54.0 0.9970 3.26 0.65 \n",
"3 17.0 60.0 0.9980 3.16 0.58 \n",
"4 11.0 34.0 0.9978 3.51 0.56 \n",
"\n",
" alcohol quality \n",
"0 9.4 5 \n",
"1 9.8 5 \n",
"2 9.8 5 \n",
"3 9.8 6 \n",
"4 9.4 5 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Read the red-wine quality CSV from the working directory and preview the first rows.\n",
"data = pd.read_csv(filepath_or_buffer='winequality-red.csv')\n",
"data.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cbb402ec-23b9-4360-849a-2fccd12d48e8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 1599 entries, 0 to 1598\n",
"Data columns (total 12 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 fixed acidity 1599 non-null float64\n",
" 1 volatile acidity 1599 non-null float64\n",
" 2 citric acid 1599 non-null float64\n",
" 3 residual sugar 1599 non-null float64\n",
" 4 chlorides 1599 non-null float64\n",
" 5 free sulfur dioxide 1599 non-null float64\n",
" 6 total sulfur dioxide 1599 non-null float64\n",
" 7 density 1599 non-null float64\n",
" 8 pH 1599 non-null float64\n",
" 9 sulphates 1599 non-null float64\n",
" 10 alcohol 1599 non-null float64\n",
" 11 quality 1599 non-null int64 \n",
"dtypes: float64(11), int64(1)\n",
"memory usage: 150.0 KB\n"
]
},
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_c3f81_row0_col0, #T_c3f81_row1_col0, #T_c3f81_row1_col7, #T_c3f81_row2_col0, #T_c3f81_row2_col3, #T_c3f81_row2_col4, #T_c3f81_row2_col7, #T_c3f81_row3_col0, #T_c3f81_row4_col0, #T_c3f81_row4_col1, #T_c3f81_row4_col2, #T_c3f81_row4_col3, #T_c3f81_row4_col4, #T_c3f81_row4_col5, #T_c3f81_row4_col6, #T_c3f81_row4_col7, #T_c3f81_row5_col0, #T_c3f81_row6_col0, #T_c3f81_row7_col0, #T_c3f81_row7_col2, #T_c3f81_row7_col7, #T_c3f81_row8_col0, #T_c3f81_row9_col0, #T_c3f81_row10_col0, #T_c3f81_row11_col0 {\n",
" background-color: #fff7fb;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row0_col1, #T_c3f81_row10_col6 {\n",
" background-color: #e0deed;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row0_col2, #T_c3f81_row0_col7, #T_c3f81_row3_col1, #T_c3f81_row3_col5, #T_c3f81_row3_col7, #T_c3f81_row8_col6 {\n",
" background-color: #f7f0f7;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row0_col3 {\n",
" background-color: #5c9fc9;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c3f81_row0_col4 {\n",
" background-color: #b8c6e0;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row0_col5 {\n",
" background-color: #dad9ea;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row0_col6 {\n",
" background-color: #e7e3f0;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row1_col1, #T_c3f81_row1_col5, #T_c3f81_row1_col6, #T_c3f81_row9_col6 {\n",
" background-color: #fef6fa;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row1_col2, #T_c3f81_row2_col1, #T_c3f81_row2_col2, #T_c3f81_row2_col5, #T_c3f81_row2_col6, #T_c3f81_row8_col2, #T_c3f81_row9_col2, #T_c3f81_row9_col7 {\n",
" background-color: #fef6fb;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row1_col3, #T_c3f81_row1_col4, #T_c3f81_row7_col6, #T_c3f81_row8_col7, #T_c3f81_row9_col1, #T_c3f81_row9_col5 {\n",
" background-color: #fdf5fa;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row3_col2, #T_c3f81_row3_col6, #T_c3f81_row7_col4, #T_c3f81_row9_col3 {\n",
" background-color: #f9f2f8;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row3_col3 {\n",
" background-color: #efe9f3;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row3_col4, #T_c3f81_row8_col5 {\n",
" background-color: #f2ecf5;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row5_col1 {\n",
" background-color: #b1c2de;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row5_col2 {\n",
" background-color: #b9c6e0;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row5_col3, #T_c3f81_row7_col3, #T_c3f81_row11_col1 {\n",
" background-color: #ede8f3;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row5_col4 {\n",
" background-color: #bbc7e0;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row5_col5 {\n",
" background-color: #a9bfdc;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row5_col6 {\n",
" background-color: #b3c3de;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row5_col7 {\n",
" background-color: #d1d2e6;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row6_col1, #T_c3f81_row6_col2, #T_c3f81_row6_col4, #T_c3f81_row6_col5, #T_c3f81_row6_col6, #T_c3f81_row6_col7, #T_c3f81_row10_col3 {\n",
" background-color: #023858;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c3f81_row6_col3 {\n",
" background-color: #1379b5;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c3f81_row7_col1, #T_c3f81_row9_col4 {\n",
" background-color: #fcf4fa;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row7_col5, #T_c3f81_row11_col2, #T_c3f81_row11_col7 {\n",
" background-color: #fbf4f9;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row8_col1 {\n",
" background-color: #f5eef6;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row8_col3 {\n",
" background-color: #b7c5df;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row8_col4 {\n",
" background-color: #e8e4f0;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row10_col1, #T_c3f81_row11_col4 {\n",
" background-color: #d6d6e9;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row10_col2 {\n",
" background-color: #faf3f9;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row10_col4 {\n",
" background-color: #8fb4d6;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row10_col5 {\n",
" background-color: #cacee5;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row10_col7 {\n",
" background-color: #f8f1f8;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row11_col3 {\n",
" background-color: #acc0dd;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row11_col5 {\n",
" background-color: #e6e2ef;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row11_col6 {\n",
" background-color: #f1ebf4;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
"<table id=\"T_c3f81\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_c3f81_level0_col0\" class=\"col_heading level0 col0\" >count</th>\n",
" <th id=\"T_c3f81_level0_col1\" class=\"col_heading level0 col1\" >mean</th>\n",
" <th id=\"T_c3f81_level0_col2\" class=\"col_heading level0 col2\" >std</th>\n",
" <th id=\"T_c3f81_level0_col3\" class=\"col_heading level0 col3\" >min</th>\n",
" <th id=\"T_c3f81_level0_col4\" class=\"col_heading level0 col4\" >25%</th>\n",
" <th id=\"T_c3f81_level0_col5\" class=\"col_heading level0 col5\" >50%</th>\n",
" <th id=\"T_c3f81_level0_col6\" class=\"col_heading level0 col6\" >75%</th>\n",
" <th id=\"T_c3f81_level0_col7\" class=\"col_heading level0 col7\" >max</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row0\" class=\"row_heading level0 row0\" >fixed acidity</th>\n",
" <td id=\"T_c3f81_row0_col0\" class=\"data row0 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row0_col1\" class=\"data row0 col1\" >8.319637</td>\n",
" <td id=\"T_c3f81_row0_col2\" class=\"data row0 col2\" >1.741096</td>\n",
" <td id=\"T_c3f81_row0_col3\" class=\"data row0 col3\" >4.600000</td>\n",
" <td id=\"T_c3f81_row0_col4\" class=\"data row0 col4\" >7.100000</td>\n",
" <td id=\"T_c3f81_row0_col5\" class=\"data row0 col5\" >7.900000</td>\n",
" <td id=\"T_c3f81_row0_col6\" class=\"data row0 col6\" >9.200000</td>\n",
" <td id=\"T_c3f81_row0_col7\" class=\"data row0 col7\" >15.900000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row1\" class=\"row_heading level0 row1\" >volatile acidity</th>\n",
" <td id=\"T_c3f81_row1_col0\" class=\"data row1 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row1_col1\" class=\"data row1 col1\" >0.527821</td>\n",
" <td id=\"T_c3f81_row1_col2\" class=\"data row1 col2\" >0.179060</td>\n",
" <td id=\"T_c3f81_row1_col3\" class=\"data row1 col3\" >0.120000</td>\n",
" <td id=\"T_c3f81_row1_col4\" class=\"data row1 col4\" >0.390000</td>\n",
" <td id=\"T_c3f81_row1_col5\" class=\"data row1 col5\" >0.520000</td>\n",
" <td id=\"T_c3f81_row1_col6\" class=\"data row1 col6\" >0.640000</td>\n",
" <td id=\"T_c3f81_row1_col7\" class=\"data row1 col7\" >1.580000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row2\" class=\"row_heading level0 row2\" >citric acid</th>\n",
" <td id=\"T_c3f81_row2_col0\" class=\"data row2 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row2_col1\" class=\"data row2 col1\" >0.270976</td>\n",
" <td id=\"T_c3f81_row2_col2\" class=\"data row2 col2\" >0.194801</td>\n",
" <td id=\"T_c3f81_row2_col3\" class=\"data row2 col3\" >0.000000</td>\n",
" <td id=\"T_c3f81_row2_col4\" class=\"data row2 col4\" >0.090000</td>\n",
" <td id=\"T_c3f81_row2_col5\" class=\"data row2 col5\" >0.260000</td>\n",
" <td id=\"T_c3f81_row2_col6\" class=\"data row2 col6\" >0.420000</td>\n",
" <td id=\"T_c3f81_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row3\" class=\"row_heading level0 row3\" >residual sugar</th>\n",
" <td id=\"T_c3f81_row3_col0\" class=\"data row3 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row3_col1\" class=\"data row3 col1\" >2.538806</td>\n",
" <td id=\"T_c3f81_row3_col2\" class=\"data row3 col2\" >1.409928</td>\n",
" <td id=\"T_c3f81_row3_col3\" class=\"data row3 col3\" >0.900000</td>\n",
" <td id=\"T_c3f81_row3_col4\" class=\"data row3 col4\" >1.900000</td>\n",
" <td id=\"T_c3f81_row3_col5\" class=\"data row3 col5\" >2.200000</td>\n",
" <td id=\"T_c3f81_row3_col6\" class=\"data row3 col6\" >2.600000</td>\n",
" <td id=\"T_c3f81_row3_col7\" class=\"data row3 col7\" >15.500000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row4\" class=\"row_heading level0 row4\" >chlorides</th>\n",
" <td id=\"T_c3f81_row4_col0\" class=\"data row4 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row4_col1\" class=\"data row4 col1\" >0.087467</td>\n",
" <td id=\"T_c3f81_row4_col2\" class=\"data row4 col2\" >0.047065</td>\n",
" <td id=\"T_c3f81_row4_col3\" class=\"data row4 col3\" >0.012000</td>\n",
" <td id=\"T_c3f81_row4_col4\" class=\"data row4 col4\" >0.070000</td>\n",
" <td id=\"T_c3f81_row4_col5\" class=\"data row4 col5\" >0.079000</td>\n",
" <td id=\"T_c3f81_row4_col6\" class=\"data row4 col6\" >0.090000</td>\n",
" <td id=\"T_c3f81_row4_col7\" class=\"data row4 col7\" >0.611000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row5\" class=\"row_heading level0 row5\" >free sulfur dioxide</th>\n",
" <td id=\"T_c3f81_row5_col0\" class=\"data row5 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row5_col1\" class=\"data row5 col1\" >15.874922</td>\n",
" <td id=\"T_c3f81_row5_col2\" class=\"data row5 col2\" >10.460157</td>\n",
" <td id=\"T_c3f81_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_c3f81_row5_col4\" class=\"data row5 col4\" >7.000000</td>\n",
" <td id=\"T_c3f81_row5_col5\" class=\"data row5 col5\" >14.000000</td>\n",
" <td id=\"T_c3f81_row5_col6\" class=\"data row5 col6\" >21.000000</td>\n",
" <td id=\"T_c3f81_row5_col7\" class=\"data row5 col7\" >72.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row6\" class=\"row_heading level0 row6\" >total sulfur dioxide</th>\n",
" <td id=\"T_c3f81_row6_col0\" class=\"data row6 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row6_col1\" class=\"data row6 col1\" >46.467792</td>\n",
" <td id=\"T_c3f81_row6_col2\" class=\"data row6 col2\" >32.895324</td>\n",
" <td id=\"T_c3f81_row6_col3\" class=\"data row6 col3\" >6.000000</td>\n",
" <td id=\"T_c3f81_row6_col4\" class=\"data row6 col4\" >22.000000</td>\n",
" <td id=\"T_c3f81_row6_col5\" class=\"data row6 col5\" >38.000000</td>\n",
" <td id=\"T_c3f81_row6_col6\" class=\"data row6 col6\" >62.000000</td>\n",
" <td id=\"T_c3f81_row6_col7\" class=\"data row6 col7\" >289.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row7\" class=\"row_heading level0 row7\" >density</th>\n",
" <td id=\"T_c3f81_row7_col0\" class=\"data row7 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row7_col1\" class=\"data row7 col1\" >0.996747</td>\n",
" <td id=\"T_c3f81_row7_col2\" class=\"data row7 col2\" >0.001887</td>\n",
" <td id=\"T_c3f81_row7_col3\" class=\"data row7 col3\" >0.990070</td>\n",
" <td id=\"T_c3f81_row7_col4\" class=\"data row7 col4\" >0.995600</td>\n",
" <td id=\"T_c3f81_row7_col5\" class=\"data row7 col5\" >0.996750</td>\n",
" <td id=\"T_c3f81_row7_col6\" class=\"data row7 col6\" >0.997835</td>\n",
" <td id=\"T_c3f81_row7_col7\" class=\"data row7 col7\" >1.003690</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row8\" class=\"row_heading level0 row8\" >pH</th>\n",
" <td id=\"T_c3f81_row8_col0\" class=\"data row8 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row8_col1\" class=\"data row8 col1\" >3.311113</td>\n",
" <td id=\"T_c3f81_row8_col2\" class=\"data row8 col2\" >0.154386</td>\n",
" <td id=\"T_c3f81_row8_col3\" class=\"data row8 col3\" >2.740000</td>\n",
" <td id=\"T_c3f81_row8_col4\" class=\"data row8 col4\" >3.210000</td>\n",
" <td id=\"T_c3f81_row8_col5\" class=\"data row8 col5\" >3.310000</td>\n",
" <td id=\"T_c3f81_row8_col6\" class=\"data row8 col6\" >3.400000</td>\n",
" <td id=\"T_c3f81_row8_col7\" class=\"data row8 col7\" >4.010000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row9\" class=\"row_heading level0 row9\" >sulphates</th>\n",
" <td id=\"T_c3f81_row9_col0\" class=\"data row9 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row9_col1\" class=\"data row9 col1\" >0.658149</td>\n",
" <td id=\"T_c3f81_row9_col2\" class=\"data row9 col2\" >0.169507</td>\n",
" <td id=\"T_c3f81_row9_col3\" class=\"data row9 col3\" >0.330000</td>\n",
" <td id=\"T_c3f81_row9_col4\" class=\"data row9 col4\" >0.550000</td>\n",
" <td id=\"T_c3f81_row9_col5\" class=\"data row9 col5\" >0.620000</td>\n",
" <td id=\"T_c3f81_row9_col6\" class=\"data row9 col6\" >0.730000</td>\n",
" <td id=\"T_c3f81_row9_col7\" class=\"data row9 col7\" >2.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row10\" class=\"row_heading level0 row10\" >alcohol</th>\n",
" <td id=\"T_c3f81_row10_col0\" class=\"data row10 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row10_col1\" class=\"data row10 col1\" >10.422983</td>\n",
" <td id=\"T_c3f81_row10_col2\" class=\"data row10 col2\" >1.065668</td>\n",
" <td id=\"T_c3f81_row10_col3\" class=\"data row10 col3\" >8.400000</td>\n",
" <td id=\"T_c3f81_row10_col4\" class=\"data row10 col4\" >9.500000</td>\n",
" <td id=\"T_c3f81_row10_col5\" class=\"data row10 col5\" >10.200000</td>\n",
" <td id=\"T_c3f81_row10_col6\" class=\"data row10 col6\" >11.100000</td>\n",
" <td id=\"T_c3f81_row10_col7\" class=\"data row10 col7\" >14.900000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row11\" class=\"row_heading level0 row11\" >quality</th>\n",
" <td id=\"T_c3f81_row11_col0\" class=\"data row11 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row11_col1\" class=\"data row11 col1\" >5.636023</td>\n",
" <td id=\"T_c3f81_row11_col2\" class=\"data row11 col2\" >0.807569</td>\n",
" <td id=\"T_c3f81_row11_col3\" class=\"data row11 col3\" >3.000000</td>\n",
" <td id=\"T_c3f81_row11_col4\" class=\"data row11 col4\" >5.000000</td>\n",
" <td id=\"T_c3f81_row11_col5\" class=\"data row11 col5\" >6.000000</td>\n",
" <td id=\"T_c3f81_row11_col6\" class=\"data row11 col6\" >6.000000</td>\n",
" <td id=\"T_c3f81_row11_col7\" class=\"data row11 col7\" >8.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x2ded8e0b820>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Print dtypes and non-null counts, then display summary statistics\n",
"# (one row per column) shaded with a per-statistic colour gradient.\n",
"data.info()\n",
"summary = data.describe().transpose()\n",
"summary.style.background_gradient(axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4b1487c6-ae6c-4ee6-a0e1-03a156654bda",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"fixed acidity 0\n",
"volatile acidity 0\n",
"citric acid 0\n",
"residual sugar 0\n",
"chlorides 0\n",
"free sulfur dioxide 0\n",
"total sulfur dioxide 0\n",
"density 0\n",
"pH 0\n",
"sulphates 0\n",
"alcohol 0\n",
"quality 0\n",
"dtype: int64"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Count the missing values in each column.\n",
"data.isnull().sum()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2d090c6d-5de2-41d5-a098-3516dcdb7155",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>fixed acidity</th>\n",
" <th>volatile acidity</th>\n",
" <th>citric acid</th>\n",
" <th>residual sugar</th>\n",
" <th>chlorides</th>\n",
" <th>free sulfur dioxide</th>\n",
" <th>total sulfur dioxide</th>\n",
" <th>density</th>\n",
" <th>pH</th>\n",
" <th>sulphates</th>\n",
" <th>alcohol</th>\n",
" <th>quality</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>7.4</td>\n",
" <td>0.70</td>\n",
" <td>0.00</td>\n",
" <td>1.9</td>\n",
" <td>0.076</td>\n",
" <td>11.0</td>\n",
" <td>34.0</td>\n",
" <td>0.9978</td>\n",
" <td>3.51</td>\n",
" <td>0.56</td>\n",
" <td>9.4</td>\n",
" <td>Average</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>7.8</td>\n",
" <td>0.88</td>\n",
" <td>0.00</td>\n",
" <td>2.6</td>\n",
" <td>0.098</td>\n",
" <td>25.0</td>\n",
" <td>67.0</td>\n",
" <td>0.9968</td>\n",
" <td>3.20</td>\n",
" <td>0.68</td>\n",
" <td>9.8</td>\n",
" <td>Average</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>7.8</td>\n",
" <td>0.76</td>\n",
" <td>0.04</td>\n",
" <td>2.3</td>\n",
" <td>0.092</td>\n",
" <td>15.0</td>\n",
" <td>54.0</td>\n",
" <td>0.9970</td>\n",
" <td>3.26</td>\n",
" <td>0.65</td>\n",
" <td>9.8</td>\n",
" <td>Average</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>11.2</td>\n",
" <td>0.28</td>\n",
" <td>0.56</td>\n",
" <td>1.9</td>\n",
" <td>0.075</td>\n",
" <td>17.0</td>\n",
" <td>60.0</td>\n",
" <td>0.9980</td>\n",
" <td>3.16</td>\n",
" <td>0.58</td>\n",
" <td>9.8</td>\n",
" <td>Average</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>7.4</td>\n",
" <td>0.70</td>\n",
" <td>0.00</td>\n",
" <td>1.9</td>\n",
" <td>0.076</td>\n",
" <td>11.0</td>\n",
" <td>34.0</td>\n",
" <td>0.9978</td>\n",
" <td>3.51</td>\n",
" <td>0.56</td>\n",
" <td>9.4</td>\n",
" <td>Average</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" fixed acidity volatile acidity citric acid residual sugar chlorides \\\n",
"0 7.4 0.70 0.00 1.9 0.076 \n",
"1 7.8 0.88 0.00 2.6 0.098 \n",
"2 7.8 0.76 0.04 2.3 0.092 \n",
"3 11.2 0.28 0.56 1.9 0.075 \n",
"4 7.4 0.70 0.00 1.9 0.076 \n",
"\n",
" free sulfur dioxide total sulfur dioxide density pH sulphates \\\n",
"0 11.0 34.0 0.9978 3.51 0.56 \n",
"1 25.0 67.0 0.9968 3.20 0.68 \n",
"2 15.0 54.0 0.9970 3.26 0.65 \n",
"3 17.0 60.0 0.9980 3.16 0.58 \n",
"4 11.0 34.0 0.9978 3.51 0.56 \n",
"\n",
" alcohol quality \n",
"0 9.4 Average \n",
"1 9.8 Average \n",
"2 9.8 Average \n",
"3 9.8 Average \n",
"4 9.4 Average "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Collapse the numeric quality scores into three categorical labels.\n",
"quality_labels = {3: 'Bad', 4: 'Bad', 5: 'Average', 6: 'Average', 7: 'Good', 8: 'Good'}\n",
"data = data.replace({'quality': quality_labels})\n",
"\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "0d672775-f7ef-468e-ac49-a6ab5254ee3a",
"metadata": {},
"outputs": [],
"source": [
"X = data.drop(columns='quality')\n",
"y = data.quality\n",
"scaler = MinMaxScaler(feature_range=(0, 1))\n",
"X_scaled = scaler.fit_transform(X)\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "32357ab4-74d6-4a32-8ff9-249e04984341",
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "13b6ce95-34a6-48ea-a4b8-bdd69c3ca496",
"metadata": {},
"outputs": [],
"source": [
"# Candidate classifiers, compared with their default hyperparameters.\n",
"# The tree-based estimators are seeded so repeated runs give the same scores.\n",
"models = {\n",
"    \"Logistic Regression\": LogisticRegression(),\n",
"    \"SVM\": SVC(),\n",
"    \"Random Forest\": RandomForestClassifier(random_state=42),\n",
"    \"Decision Tree\": DecisionTreeClassifier(random_state=42),\n",
"    \"K-Nearest Neighbors\": KNeighborsClassifier(),\n",
"    \"Naive Bayes\": GaussianNB()\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "9096f5a4-8329-47d0-9b7c-8e3123191a3b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training Logistic Regression...\n",
"Logistic Regression Accuracy: 0.828125\n",
"Training SVM...\n",
"SVM Accuracy: 0.8375\n",
"Training Random Forest...\n",
"Random Forest Accuracy: 0.865625\n",
"Training Decision Tree...\n",
"Decision Tree Accuracy: 0.79375\n",
"Training K-Nearest Neighbors...\n",
"K-Nearest Neighbors Accuracy: 0.840625\n",
"Training Naive Bayes...\n",
"Naive Bayes Accuracy: 0.803125\n"
]
}
],
"source": [
"# Fit every candidate on the training split and record its accuracy on the test split.\n",
"results = {}\n",
"for name, estimator in models.items():\n",
"    print(f\"Training {name}...\")\n",
"    predictions = estimator.fit(X_train, y_train).predict(X_test)\n",
"    score = metrics.accuracy_score(y_test, predictions)\n",
"    results[name] = score\n",
"    print(f\"{name} Accuracy: {score}\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d8f786c8-67e1-4131-bb2c-22d91d7fafab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Model Comparison (Accuracy):\n",
"Random Forest: 0.8656\n",
"K-Nearest Neighbors: 0.8406\n",
"SVM: 0.8375\n",
"Logistic Regression: 0.8281\n",
"Naive Bayes: 0.8031\n",
"Decision Tree: 0.7937\n"
]
}
],
"source": [
"# Rank the candidates from highest to lowest test accuracy and print the table.\n",
"sorted_results = sorted(results.items(), key=lambda pair: pair[1], reverse=True)\n",
"print(\"\\nModel Comparison (Accuracy):\")\n",
"for name, score in sorted_results:\n",
"    print(f\"{name}: {score:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "b6eec30b-3a13-4f99-b841-a8633a8b4ad7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 27 candidates, totalling 135 fits\n",
"Best parameters for Random Forest: {'max_depth': 30, 'min_samples_split': 2, 'n_estimators': 50}\n"
]
}
],
"source": [
"# Hyperparameter tuning for Random Forest: exhaustive 5-fold search over the grid below,\n",
"# fitted on the training split only.\n",
"param_grid = {\n",
"    'n_estimators': [50, 100, 200],\n",
"    'max_depth': [10, 20, 30],\n",
"    'min_samples_split': [2, 5, 10]\n",
"}\n",
"\n",
"# Seed the forest so the selected parameters and the refitted best estimator are reproducible.\n",
"grid_search = GridSearchCV(RandomForestClassifier(random_state=42), param_grid, cv=5, n_jobs=-1, verbose=1)\n",
"grid_search.fit(X_train, y_train)\n",
"print(f\"Best parameters for Random Forest: {grid_search.best_params_}\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "488c6fe7-1f79-4e6a-9649-fd82c0acd44e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cross-validation scores for Random Forest: [0.834375 0.821875 0.840625 0.8125 0.84326019]\n",
"Mean cross-validation score: 0.830527037617555\n"
]
}
],
"source": [
"# Re-estimate the tuned Random Forest's accuracy with 5-fold cross-validation\n",
"# over the scaled feature matrix.\n",
"best_model = grid_search.best_estimator_\n",
"cv_scores = cross_val_score(estimator=best_model, X=X_scaled, y=y, cv=5)\n",
"print(f\"Cross-validation scores for Random Forest: {cv_scores}\")\n",
"print(f\"Mean cross-validation score: {np.mean(cv_scores)}\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "81b08573-19a3-4508-9c9e-4a36ced59197",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" Average 0.88 0.95 0.92 262\n",
" Bad 1.00 0.00 0.00 11\n",
" Good 0.67 0.51 0.58 47\n",
"\n",
" accuracy 0.86 320\n",
" macro avg 0.85 0.49 0.50 320\n",
"weighted avg 0.85 0.86 0.83 320\n",
"\n"
]
}
],
"source": [
"# Per-class precision / recall / F1 of the tuned model on the held-out test split.\n",
"# zero_division=0 scores a class that is never predicted as 0 precision rather than\n",
"# a perfect 1.0, so the macro average is not inflated by classes the model misses.\n",
"y_pred = best_model.predict(X_test)\n",
"print(\"\\nClassification Report:\")\n",
"print(metrics.classification_report(y_test, y_pred, zero_division=0))\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "191fa955-1969-4224-8420-a608c1463e31",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model and Scaler have been saved.\n"
]
}
],
"source": [
"# Persist the tuned Random Forest (best estimator from the grid search) for later inference.\n",
"joblib.dump(best_model, 'best_model.pkl')\n",
"\n",
"# Persist the fitted MinMaxScaler so new inputs can be scaled the same way as the training data.\n",
"joblib.dump(scaler, 'scaler.pkl')\n",
"\n",
"print(\"Model and Scaler have been saved.\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f75662b-1b38-4510-bccd-adefabbe635f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}