This commit is contained in:
User Name 2025-06-07 23:27:56 +02:00
commit 2a7969d0d5
6 changed files with 2690 additions and 0 deletions

954
Untitled.ipynb Normal file
View File

@ -0,0 +1,954 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "aaad2c40-c784-466b-9f49-7a60e1cbdb4c",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import warnings\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn import metrics\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.svm import SVC\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.naive_bayes import GaussianNB\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "23f82a60-419e-4af2-9401-1996bec33332",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>fixed acidity</th>\n",
" <th>volatile acidity</th>\n",
" <th>citric acid</th>\n",
" <th>residual sugar</th>\n",
" <th>chlorides</th>\n",
" <th>free sulfur dioxide</th>\n",
" <th>total sulfur dioxide</th>\n",
" <th>density</th>\n",
" <th>pH</th>\n",
" <th>sulphates</th>\n",
" <th>alcohol</th>\n",
" <th>quality</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>7.4</td>\n",
" <td>0.70</td>\n",
" <td>0.00</td>\n",
" <td>1.9</td>\n",
" <td>0.076</td>\n",
" <td>11.0</td>\n",
" <td>34.0</td>\n",
" <td>0.9978</td>\n",
" <td>3.51</td>\n",
" <td>0.56</td>\n",
" <td>9.4</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>7.8</td>\n",
" <td>0.88</td>\n",
" <td>0.00</td>\n",
" <td>2.6</td>\n",
" <td>0.098</td>\n",
" <td>25.0</td>\n",
" <td>67.0</td>\n",
" <td>0.9968</td>\n",
" <td>3.20</td>\n",
" <td>0.68</td>\n",
" <td>9.8</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>7.8</td>\n",
" <td>0.76</td>\n",
" <td>0.04</td>\n",
" <td>2.3</td>\n",
" <td>0.092</td>\n",
" <td>15.0</td>\n",
" <td>54.0</td>\n",
" <td>0.9970</td>\n",
" <td>3.26</td>\n",
" <td>0.65</td>\n",
" <td>9.8</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>11.2</td>\n",
" <td>0.28</td>\n",
" <td>0.56</td>\n",
" <td>1.9</td>\n",
" <td>0.075</td>\n",
" <td>17.0</td>\n",
" <td>60.0</td>\n",
" <td>0.9980</td>\n",
" <td>3.16</td>\n",
" <td>0.58</td>\n",
" <td>9.8</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>7.4</td>\n",
" <td>0.70</td>\n",
" <td>0.00</td>\n",
" <td>1.9</td>\n",
" <td>0.076</td>\n",
" <td>11.0</td>\n",
" <td>34.0</td>\n",
" <td>0.9978</td>\n",
" <td>3.51</td>\n",
" <td>0.56</td>\n",
" <td>9.4</td>\n",
" <td>5</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" fixed acidity volatile acidity citric acid residual sugar chlorides \\\n",
"0 7.4 0.70 0.00 1.9 0.076 \n",
"1 7.8 0.88 0.00 2.6 0.098 \n",
"2 7.8 0.76 0.04 2.3 0.092 \n",
"3 11.2 0.28 0.56 1.9 0.075 \n",
"4 7.4 0.70 0.00 1.9 0.076 \n",
"\n",
" free sulfur dioxide total sulfur dioxide density pH sulphates \\\n",
"0 11.0 34.0 0.9978 3.51 0.56 \n",
"1 25.0 67.0 0.9968 3.20 0.68 \n",
"2 15.0 54.0 0.9970 3.26 0.65 \n",
"3 17.0 60.0 0.9980 3.16 0.58 \n",
"4 11.0 34.0 0.9978 3.51 0.56 \n",
"\n",
" alcohol quality \n",
"0 9.4 5 \n",
"1 9.8 5 \n",
"2 9.8 5 \n",
"3 9.8 6 \n",
"4 9.4 5 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = pd.read_csv('winequality-red.csv')\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cbb402ec-23b9-4360-849a-2fccd12d48e8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 1599 entries, 0 to 1598\n",
"Data columns (total 12 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 fixed acidity 1599 non-null float64\n",
" 1 volatile acidity 1599 non-null float64\n",
" 2 citric acid 1599 non-null float64\n",
" 3 residual sugar 1599 non-null float64\n",
" 4 chlorides 1599 non-null float64\n",
" 5 free sulfur dioxide 1599 non-null float64\n",
" 6 total sulfur dioxide 1599 non-null float64\n",
" 7 density 1599 non-null float64\n",
" 8 pH 1599 non-null float64\n",
" 9 sulphates 1599 non-null float64\n",
" 10 alcohol 1599 non-null float64\n",
" 11 quality 1599 non-null int64 \n",
"dtypes: float64(11), int64(1)\n",
"memory usage: 150.0 KB\n"
]
},
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_c3f81_row0_col0, #T_c3f81_row1_col0, #T_c3f81_row1_col7, #T_c3f81_row2_col0, #T_c3f81_row2_col3, #T_c3f81_row2_col4, #T_c3f81_row2_col7, #T_c3f81_row3_col0, #T_c3f81_row4_col0, #T_c3f81_row4_col1, #T_c3f81_row4_col2, #T_c3f81_row4_col3, #T_c3f81_row4_col4, #T_c3f81_row4_col5, #T_c3f81_row4_col6, #T_c3f81_row4_col7, #T_c3f81_row5_col0, #T_c3f81_row6_col0, #T_c3f81_row7_col0, #T_c3f81_row7_col2, #T_c3f81_row7_col7, #T_c3f81_row8_col0, #T_c3f81_row9_col0, #T_c3f81_row10_col0, #T_c3f81_row11_col0 {\n",
" background-color: #fff7fb;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row0_col1, #T_c3f81_row10_col6 {\n",
" background-color: #e0deed;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row0_col2, #T_c3f81_row0_col7, #T_c3f81_row3_col1, #T_c3f81_row3_col5, #T_c3f81_row3_col7, #T_c3f81_row8_col6 {\n",
" background-color: #f7f0f7;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row0_col3 {\n",
" background-color: #5c9fc9;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c3f81_row0_col4 {\n",
" background-color: #b8c6e0;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row0_col5 {\n",
" background-color: #dad9ea;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row0_col6 {\n",
" background-color: #e7e3f0;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row1_col1, #T_c3f81_row1_col5, #T_c3f81_row1_col6, #T_c3f81_row9_col6 {\n",
" background-color: #fef6fa;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row1_col2, #T_c3f81_row2_col1, #T_c3f81_row2_col2, #T_c3f81_row2_col5, #T_c3f81_row2_col6, #T_c3f81_row8_col2, #T_c3f81_row9_col2, #T_c3f81_row9_col7 {\n",
" background-color: #fef6fb;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row1_col3, #T_c3f81_row1_col4, #T_c3f81_row7_col6, #T_c3f81_row8_col7, #T_c3f81_row9_col1, #T_c3f81_row9_col5 {\n",
" background-color: #fdf5fa;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row3_col2, #T_c3f81_row3_col6, #T_c3f81_row7_col4, #T_c3f81_row9_col3 {\n",
" background-color: #f9f2f8;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row3_col3 {\n",
" background-color: #efe9f3;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row3_col4, #T_c3f81_row8_col5 {\n",
" background-color: #f2ecf5;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row5_col1 {\n",
" background-color: #b1c2de;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row5_col2 {\n",
" background-color: #b9c6e0;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row5_col3, #T_c3f81_row7_col3, #T_c3f81_row11_col1 {\n",
" background-color: #ede8f3;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row5_col4 {\n",
" background-color: #bbc7e0;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row5_col5 {\n",
" background-color: #a9bfdc;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row5_col6 {\n",
" background-color: #b3c3de;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row5_col7 {\n",
" background-color: #d1d2e6;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row6_col1, #T_c3f81_row6_col2, #T_c3f81_row6_col4, #T_c3f81_row6_col5, #T_c3f81_row6_col6, #T_c3f81_row6_col7, #T_c3f81_row10_col3 {\n",
" background-color: #023858;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c3f81_row6_col3 {\n",
" background-color: #1379b5;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c3f81_row7_col1, #T_c3f81_row9_col4 {\n",
" background-color: #fcf4fa;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row7_col5, #T_c3f81_row11_col2, #T_c3f81_row11_col7 {\n",
" background-color: #fbf4f9;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row8_col1 {\n",
" background-color: #f5eef6;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row8_col3 {\n",
" background-color: #b7c5df;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row8_col4 {\n",
" background-color: #e8e4f0;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row10_col1, #T_c3f81_row11_col4 {\n",
" background-color: #d6d6e9;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row10_col2 {\n",
" background-color: #faf3f9;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row10_col4 {\n",
" background-color: #8fb4d6;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row10_col5 {\n",
" background-color: #cacee5;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row10_col7 {\n",
" background-color: #f8f1f8;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row11_col3 {\n",
" background-color: #acc0dd;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row11_col5 {\n",
" background-color: #e6e2ef;\n",
" color: #000000;\n",
"}\n",
"#T_c3f81_row11_col6 {\n",
" background-color: #f1ebf4;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
"<table id=\"T_c3f81\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_c3f81_level0_col0\" class=\"col_heading level0 col0\" >count</th>\n",
" <th id=\"T_c3f81_level0_col1\" class=\"col_heading level0 col1\" >mean</th>\n",
" <th id=\"T_c3f81_level0_col2\" class=\"col_heading level0 col2\" >std</th>\n",
" <th id=\"T_c3f81_level0_col3\" class=\"col_heading level0 col3\" >min</th>\n",
" <th id=\"T_c3f81_level0_col4\" class=\"col_heading level0 col4\" >25%</th>\n",
" <th id=\"T_c3f81_level0_col5\" class=\"col_heading level0 col5\" >50%</th>\n",
" <th id=\"T_c3f81_level0_col6\" class=\"col_heading level0 col6\" >75%</th>\n",
" <th id=\"T_c3f81_level0_col7\" class=\"col_heading level0 col7\" >max</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row0\" class=\"row_heading level0 row0\" >fixed acidity</th>\n",
" <td id=\"T_c3f81_row0_col0\" class=\"data row0 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row0_col1\" class=\"data row0 col1\" >8.319637</td>\n",
" <td id=\"T_c3f81_row0_col2\" class=\"data row0 col2\" >1.741096</td>\n",
" <td id=\"T_c3f81_row0_col3\" class=\"data row0 col3\" >4.600000</td>\n",
" <td id=\"T_c3f81_row0_col4\" class=\"data row0 col4\" >7.100000</td>\n",
" <td id=\"T_c3f81_row0_col5\" class=\"data row0 col5\" >7.900000</td>\n",
" <td id=\"T_c3f81_row0_col6\" class=\"data row0 col6\" >9.200000</td>\n",
" <td id=\"T_c3f81_row0_col7\" class=\"data row0 col7\" >15.900000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row1\" class=\"row_heading level0 row1\" >volatile acidity</th>\n",
" <td id=\"T_c3f81_row1_col0\" class=\"data row1 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row1_col1\" class=\"data row1 col1\" >0.527821</td>\n",
" <td id=\"T_c3f81_row1_col2\" class=\"data row1 col2\" >0.179060</td>\n",
" <td id=\"T_c3f81_row1_col3\" class=\"data row1 col3\" >0.120000</td>\n",
" <td id=\"T_c3f81_row1_col4\" class=\"data row1 col4\" >0.390000</td>\n",
" <td id=\"T_c3f81_row1_col5\" class=\"data row1 col5\" >0.520000</td>\n",
" <td id=\"T_c3f81_row1_col6\" class=\"data row1 col6\" >0.640000</td>\n",
" <td id=\"T_c3f81_row1_col7\" class=\"data row1 col7\" >1.580000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row2\" class=\"row_heading level0 row2\" >citric acid</th>\n",
" <td id=\"T_c3f81_row2_col0\" class=\"data row2 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row2_col1\" class=\"data row2 col1\" >0.270976</td>\n",
" <td id=\"T_c3f81_row2_col2\" class=\"data row2 col2\" >0.194801</td>\n",
" <td id=\"T_c3f81_row2_col3\" class=\"data row2 col3\" >0.000000</td>\n",
" <td id=\"T_c3f81_row2_col4\" class=\"data row2 col4\" >0.090000</td>\n",
" <td id=\"T_c3f81_row2_col5\" class=\"data row2 col5\" >0.260000</td>\n",
" <td id=\"T_c3f81_row2_col6\" class=\"data row2 col6\" >0.420000</td>\n",
" <td id=\"T_c3f81_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row3\" class=\"row_heading level0 row3\" >residual sugar</th>\n",
" <td id=\"T_c3f81_row3_col0\" class=\"data row3 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row3_col1\" class=\"data row3 col1\" >2.538806</td>\n",
" <td id=\"T_c3f81_row3_col2\" class=\"data row3 col2\" >1.409928</td>\n",
" <td id=\"T_c3f81_row3_col3\" class=\"data row3 col3\" >0.900000</td>\n",
" <td id=\"T_c3f81_row3_col4\" class=\"data row3 col4\" >1.900000</td>\n",
" <td id=\"T_c3f81_row3_col5\" class=\"data row3 col5\" >2.200000</td>\n",
" <td id=\"T_c3f81_row3_col6\" class=\"data row3 col6\" >2.600000</td>\n",
" <td id=\"T_c3f81_row3_col7\" class=\"data row3 col7\" >15.500000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row4\" class=\"row_heading level0 row4\" >chlorides</th>\n",
" <td id=\"T_c3f81_row4_col0\" class=\"data row4 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row4_col1\" class=\"data row4 col1\" >0.087467</td>\n",
" <td id=\"T_c3f81_row4_col2\" class=\"data row4 col2\" >0.047065</td>\n",
" <td id=\"T_c3f81_row4_col3\" class=\"data row4 col3\" >0.012000</td>\n",
" <td id=\"T_c3f81_row4_col4\" class=\"data row4 col4\" >0.070000</td>\n",
" <td id=\"T_c3f81_row4_col5\" class=\"data row4 col5\" >0.079000</td>\n",
" <td id=\"T_c3f81_row4_col6\" class=\"data row4 col6\" >0.090000</td>\n",
" <td id=\"T_c3f81_row4_col7\" class=\"data row4 col7\" >0.611000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row5\" class=\"row_heading level0 row5\" >free sulfur dioxide</th>\n",
" <td id=\"T_c3f81_row5_col0\" class=\"data row5 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row5_col1\" class=\"data row5 col1\" >15.874922</td>\n",
" <td id=\"T_c3f81_row5_col2\" class=\"data row5 col2\" >10.460157</td>\n",
" <td id=\"T_c3f81_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_c3f81_row5_col4\" class=\"data row5 col4\" >7.000000</td>\n",
" <td id=\"T_c3f81_row5_col5\" class=\"data row5 col5\" >14.000000</td>\n",
" <td id=\"T_c3f81_row5_col6\" class=\"data row5 col6\" >21.000000</td>\n",
" <td id=\"T_c3f81_row5_col7\" class=\"data row5 col7\" >72.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row6\" class=\"row_heading level0 row6\" >total sulfur dioxide</th>\n",
" <td id=\"T_c3f81_row6_col0\" class=\"data row6 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row6_col1\" class=\"data row6 col1\" >46.467792</td>\n",
" <td id=\"T_c3f81_row6_col2\" class=\"data row6 col2\" >32.895324</td>\n",
" <td id=\"T_c3f81_row6_col3\" class=\"data row6 col3\" >6.000000</td>\n",
" <td id=\"T_c3f81_row6_col4\" class=\"data row6 col4\" >22.000000</td>\n",
" <td id=\"T_c3f81_row6_col5\" class=\"data row6 col5\" >38.000000</td>\n",
" <td id=\"T_c3f81_row6_col6\" class=\"data row6 col6\" >62.000000</td>\n",
" <td id=\"T_c3f81_row6_col7\" class=\"data row6 col7\" >289.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row7\" class=\"row_heading level0 row7\" >density</th>\n",
" <td id=\"T_c3f81_row7_col0\" class=\"data row7 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row7_col1\" class=\"data row7 col1\" >0.996747</td>\n",
" <td id=\"T_c3f81_row7_col2\" class=\"data row7 col2\" >0.001887</td>\n",
" <td id=\"T_c3f81_row7_col3\" class=\"data row7 col3\" >0.990070</td>\n",
" <td id=\"T_c3f81_row7_col4\" class=\"data row7 col4\" >0.995600</td>\n",
" <td id=\"T_c3f81_row7_col5\" class=\"data row7 col5\" >0.996750</td>\n",
" <td id=\"T_c3f81_row7_col6\" class=\"data row7 col6\" >0.997835</td>\n",
" <td id=\"T_c3f81_row7_col7\" class=\"data row7 col7\" >1.003690</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row8\" class=\"row_heading level0 row8\" >pH</th>\n",
" <td id=\"T_c3f81_row8_col0\" class=\"data row8 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row8_col1\" class=\"data row8 col1\" >3.311113</td>\n",
" <td id=\"T_c3f81_row8_col2\" class=\"data row8 col2\" >0.154386</td>\n",
" <td id=\"T_c3f81_row8_col3\" class=\"data row8 col3\" >2.740000</td>\n",
" <td id=\"T_c3f81_row8_col4\" class=\"data row8 col4\" >3.210000</td>\n",
" <td id=\"T_c3f81_row8_col5\" class=\"data row8 col5\" >3.310000</td>\n",
" <td id=\"T_c3f81_row8_col6\" class=\"data row8 col6\" >3.400000</td>\n",
" <td id=\"T_c3f81_row8_col7\" class=\"data row8 col7\" >4.010000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row9\" class=\"row_heading level0 row9\" >sulphates</th>\n",
" <td id=\"T_c3f81_row9_col0\" class=\"data row9 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row9_col1\" class=\"data row9 col1\" >0.658149</td>\n",
" <td id=\"T_c3f81_row9_col2\" class=\"data row9 col2\" >0.169507</td>\n",
" <td id=\"T_c3f81_row9_col3\" class=\"data row9 col3\" >0.330000</td>\n",
" <td id=\"T_c3f81_row9_col4\" class=\"data row9 col4\" >0.550000</td>\n",
" <td id=\"T_c3f81_row9_col5\" class=\"data row9 col5\" >0.620000</td>\n",
" <td id=\"T_c3f81_row9_col6\" class=\"data row9 col6\" >0.730000</td>\n",
" <td id=\"T_c3f81_row9_col7\" class=\"data row9 col7\" >2.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row10\" class=\"row_heading level0 row10\" >alcohol</th>\n",
" <td id=\"T_c3f81_row10_col0\" class=\"data row10 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row10_col1\" class=\"data row10 col1\" >10.422983</td>\n",
" <td id=\"T_c3f81_row10_col2\" class=\"data row10 col2\" >1.065668</td>\n",
" <td id=\"T_c3f81_row10_col3\" class=\"data row10 col3\" >8.400000</td>\n",
" <td id=\"T_c3f81_row10_col4\" class=\"data row10 col4\" >9.500000</td>\n",
" <td id=\"T_c3f81_row10_col5\" class=\"data row10 col5\" >10.200000</td>\n",
" <td id=\"T_c3f81_row10_col6\" class=\"data row10 col6\" >11.100000</td>\n",
" <td id=\"T_c3f81_row10_col7\" class=\"data row10 col7\" >14.900000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c3f81_level0_row11\" class=\"row_heading level0 row11\" >quality</th>\n",
" <td id=\"T_c3f81_row11_col0\" class=\"data row11 col0\" >1599.000000</td>\n",
" <td id=\"T_c3f81_row11_col1\" class=\"data row11 col1\" >5.636023</td>\n",
" <td id=\"T_c3f81_row11_col2\" class=\"data row11 col2\" >0.807569</td>\n",
" <td id=\"T_c3f81_row11_col3\" class=\"data row11 col3\" >3.000000</td>\n",
" <td id=\"T_c3f81_row11_col4\" class=\"data row11 col4\" >5.000000</td>\n",
" <td id=\"T_c3f81_row11_col5\" class=\"data row11 col5\" >6.000000</td>\n",
" <td id=\"T_c3f81_row11_col6\" class=\"data row11 col6\" >6.000000</td>\n",
" <td id=\"T_c3f81_row11_col7\" class=\"data row11 col7\" >8.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x2ded8e0b820>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.info()\n",
"data.describe().T.style.background_gradient(axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4b1487c6-ae6c-4ee6-a0e1-03a156654bda",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"fixed acidity 0\n",
"volatile acidity 0\n",
"citric acid 0\n",
"residual sugar 0\n",
"chlorides 0\n",
"free sulfur dioxide 0\n",
"total sulfur dioxide 0\n",
"density 0\n",
"pH 0\n",
"sulphates 0\n",
"alcohol 0\n",
"quality 0\n",
"dtype: int64"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.isna().sum()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2d090c6d-5de2-41d5-a098-3516dcdb7155",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>fixed acidity</th>\n",
" <th>volatile acidity</th>\n",
" <th>citric acid</th>\n",
" <th>residual sugar</th>\n",
" <th>chlorides</th>\n",
" <th>free sulfur dioxide</th>\n",
" <th>total sulfur dioxide</th>\n",
" <th>density</th>\n",
" <th>pH</th>\n",
" <th>sulphates</th>\n",
" <th>alcohol</th>\n",
" <th>quality</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>7.4</td>\n",
" <td>0.70</td>\n",
" <td>0.00</td>\n",
" <td>1.9</td>\n",
" <td>0.076</td>\n",
" <td>11.0</td>\n",
" <td>34.0</td>\n",
" <td>0.9978</td>\n",
" <td>3.51</td>\n",
" <td>0.56</td>\n",
" <td>9.4</td>\n",
" <td>Average</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>7.8</td>\n",
" <td>0.88</td>\n",
" <td>0.00</td>\n",
" <td>2.6</td>\n",
" <td>0.098</td>\n",
" <td>25.0</td>\n",
" <td>67.0</td>\n",
" <td>0.9968</td>\n",
" <td>3.20</td>\n",
" <td>0.68</td>\n",
" <td>9.8</td>\n",
" <td>Average</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>7.8</td>\n",
" <td>0.76</td>\n",
" <td>0.04</td>\n",
" <td>2.3</td>\n",
" <td>0.092</td>\n",
" <td>15.0</td>\n",
" <td>54.0</td>\n",
" <td>0.9970</td>\n",
" <td>3.26</td>\n",
" <td>0.65</td>\n",
" <td>9.8</td>\n",
" <td>Average</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>11.2</td>\n",
" <td>0.28</td>\n",
" <td>0.56</td>\n",
" <td>1.9</td>\n",
" <td>0.075</td>\n",
" <td>17.0</td>\n",
" <td>60.0</td>\n",
" <td>0.9980</td>\n",
" <td>3.16</td>\n",
" <td>0.58</td>\n",
" <td>9.8</td>\n",
" <td>Average</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>7.4</td>\n",
" <td>0.70</td>\n",
" <td>0.00</td>\n",
" <td>1.9</td>\n",
" <td>0.076</td>\n",
" <td>11.0</td>\n",
" <td>34.0</td>\n",
" <td>0.9978</td>\n",
" <td>3.51</td>\n",
" <td>0.56</td>\n",
" <td>9.4</td>\n",
" <td>Average</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" fixed acidity volatile acidity citric acid residual sugar chlorides \\\n",
"0 7.4 0.70 0.00 1.9 0.076 \n",
"1 7.8 0.88 0.00 2.6 0.098 \n",
"2 7.8 0.76 0.04 2.3 0.092 \n",
"3 11.2 0.28 0.56 1.9 0.075 \n",
"4 7.4 0.70 0.00 1.9 0.076 \n",
"\n",
" free sulfur dioxide total sulfur dioxide density pH sulphates \\\n",
"0 11.0 34.0 0.9978 3.51 0.56 \n",
"1 25.0 67.0 0.9968 3.20 0.68 \n",
"2 15.0 54.0 0.9970 3.26 0.65 \n",
"3 17.0 60.0 0.9980 3.16 0.58 \n",
"4 11.0 34.0 0.9978 3.51 0.56 \n",
"\n",
" alcohol quality \n",
"0 9.4 Average \n",
"1 9.8 Average \n",
"2 9.8 Average \n",
"3 9.8 Average \n",
"4 9.4 Average "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = data.replace({'quality': {8: 'Good', 7: 'Good', 6: 'Average', 5: 'Average', 4: 'Bad', 3: 'Bad'}})\n",
"\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "0d672775-f7ef-468e-ac49-a6ab5254ee3a",
"metadata": {},
"outputs": [],
"source": [
"X = data.drop(columns='quality')\n",
"y = data.quality\n",
"scaler = MinMaxScaler(feature_range=(0, 1))\n",
"X_scaled = scaler.fit_transform(X)\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "32357ab4-74d6-4a32-8ff9-249e04984341",
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "13b6ce95-34a6-48ea-a4b8-bdd69c3ca496",
"metadata": {},
"outputs": [],
"source": [
"models = {\n",
" \"Logistic Regression\": LogisticRegression(),\n",
" \"SVM\": SVC(),\n",
" \"Random Forest\": RandomForestClassifier(),\n",
" \"Decision Tree\": DecisionTreeClassifier(),\n",
" \"K-Nearest Neighbors\": KNeighborsClassifier(),\n",
" \"Naive Bayes\": GaussianNB()\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "9096f5a4-8329-47d0-9b7c-8e3123191a3b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training Logistic Regression...\n",
"Logistic Regression Accuracy: 0.828125\n",
"Training SVM...\n",
"SVM Accuracy: 0.8375\n",
"Training Random Forest...\n",
"Random Forest Accuracy: 0.865625\n",
"Training Decision Tree...\n",
"Decision Tree Accuracy: 0.79375\n",
"Training K-Nearest Neighbors...\n",
"K-Nearest Neighbors Accuracy: 0.840625\n",
"Training Naive Bayes...\n",
"Naive Bayes Accuracy: 0.803125\n"
]
}
],
"source": [
"results = {}\n",
"for model_name, model in models.items():\n",
" print(f\"Training {model_name}...\")\n",
" model.fit(X_train, y_train)\n",
" y_pred = model.predict(X_test)\n",
" accuracy = metrics.accuracy_score(y_test, y_pred)\n",
" results[model_name] = accuracy\n",
" print(f\"{model_name} Accuracy: {accuracy}\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d8f786c8-67e1-4131-bb2c-22d91d7fafab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Model Comparison (Accuracy):\n",
"Random Forest: 0.8656\n",
"K-Nearest Neighbors: 0.8406\n",
"SVM: 0.8375\n",
"Logistic Regression: 0.8281\n",
"Naive Bayes: 0.8031\n",
"Decision Tree: 0.7937\n"
]
}
],
"source": [
"sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)\n",
"print(\"\\nModel Comparison (Accuracy):\")\n",
"for model_name, accuracy in sorted_results:\n",
" print(f\"{model_name}: {accuracy:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "b6eec30b-3a13-4f99-b841-a8633a8b4ad7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 27 candidates, totalling 135 fits\n",
"Best parameters for Random Forest: {'max_depth': 30, 'min_samples_split': 2, 'n_estimators': 50}\n"
]
}
],
"source": [
"# Example: Hyperparameter tuning for Random Forest\n",
"param_grid = {\n",
" 'n_estimators': [50, 100, 200],\n",
" 'max_depth': [10, 20, 30],\n",
" 'min_samples_split': [2, 5, 10]\n",
"}\n",
"\n",
"grid_search = GridSearchCV(RandomForestClassifier(), param_grid, cv=5, n_jobs=-1, verbose=1)\n",
"grid_search.fit(X_train, y_train)\n",
"print(f\"Best parameters for Random Forest: {grid_search.best_params_}\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "488c6fe7-1f79-4e6a-9649-fd82c0acd44e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cross-validation scores for Random Forest: [0.834375 0.821875 0.840625 0.8125 0.84326019]\n",
"Mean cross-validation score: 0.830527037617555\n"
]
}
],
"source": [
"# Cross-validation score for the best model (Random Forest)\n",
"best_model = grid_search.best_estimator_\n",
"cv_scores = cross_val_score(best_model, X_scaled, y, cv=5)\n",
"print(f\"Cross-validation scores for Random Forest: {cv_scores}\")\n",
"print(f\"Mean cross-validation score: {cv_scores.mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "81b08573-19a3-4508-9c9e-4a36ced59197",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" Average 0.88 0.95 0.92 262\n",
" Bad 1.00 0.00 0.00 11\n",
" Good 0.67 0.51 0.58 47\n",
"\n",
" accuracy 0.86 320\n",
" macro avg 0.85 0.49 0.50 320\n",
"weighted avg 0.85 0.86 0.83 320\n",
"\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"\n",
"# Classification report with zero_division parameter to handle undefined precision\n",
"y_pred = best_model.predict(X_test)\n",
"print(\"\\nClassification Report:\")\n",
"print(classification_report(y_test, y_pred, zero_division=1))\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "191fa955-1969-4224-8420-a608c1463e31",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model and Scaler have been saved.\n"
]
}
],
"source": [
"import joblib\n",
"\n",
"# Save the best model (Random Forest with Hyperparameter Tuning)\n",
"joblib.dump(best_model, 'best_model.pkl')\n",
"\n",
"# Save the scaler (MinMax SScaler used for feature scaling)\n",
"joblib.dump(scaler, 'scaler.pkl')\n",
"\n",
"print(\"Model and Scaler have been saved.\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f75662b-1b38-4510-bccd-adefabbe635f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

75
app.py Normal file
View File

@ -0,0 +1,75 @@
import streamlit as st
import requests
# FastAPI backend URL
BASE_URL = "http://127.0.0.1:8000"
# Define pages
def home_page():
st.title("Wine Quality Prediction API")
st.write("This is the home page of the Wine Quality Prediction App.")
# Communicate with the FastAPI `/` endpoint
try:
response = requests.get(f"{BASE_URL}/")
if response.status_code == 200:
data = response.json()
st.success(data.get("message", "Welcome to the API!"))
else:
st.error("Failed to fetch data from the backend.")
except Exception as e:
st.error(f"Error: {e}")
def prediction_page():
st.title("Wine Quality Prediction")
st.write("Enter the wine features to predict its quality:")
# Input fields for wine features
fixed_acidity = st.number_input("Fixed Acidity", value=7.0)
volatile_acidity = st.number_input("Volatile Acidity", value=0.27)
citric_acid = st.number_input("Citric Acid", value=0.36)
residual_sugar = st.number_input("Residual Sugar", value=20.7)
chlorides = st.number_input("Chlorides", value=0.045)
free_sulfur_dioxide = st.number_input("Free Sulfur Dioxide", value=45.0)
total_sulfur_dioxide = st.number_input("Total Sulfur Dioxide", value=170.0)
density = st.number_input("Density", value=1.001)
pH = st.number_input("pH", value=3.0)
sulphates = st.number_input("Sulphates", value=0.45)
alcohol = st.number_input("Alcohol", value=8.8)
# Predict button
if st.button("Predict Quality"):
# Prepare the payload
payload = {
"fixed_acidity": fixed_acidity,
"volatile_acidity": volatile_acidity,
"citric_acid": citric_acid,
"residual_sugar": residual_sugar,
"chlorides": chlorides,
"free_sulfur_dioxide": free_sulfur_dioxide,
"total_sulfur_dioxide": total_sulfur_dioxide,
"density": density,
"pH": pH,
"sulphates": sulphates,
"alcohol": alcohol
}
# Communicate with the FastAPI `/predict` endpoint
try:
response = requests.post(f"{BASE_URL}/predict", json=payload)
if response.status_code == 200:
result = response.json()
st.success(f"Predicted Wine Quality: {result['predicted_quality']}")
else:
st.error("Failed to get prediction from the backend.")
except Exception as e:
st.error(f"Error: {e}")
# Streamlit page navigation
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", ["Home", "Predict Wine Quality"])
if page == "Home":
home_page()
elif page == "Predict Wine Quality":
prediction_page()

BIN
best_model.pkl Normal file

Binary file not shown.

61
main.py Normal file
View File

@ -0,0 +1,61 @@
from fastapi import FastAPI
from pydantic import BaseModel
import joblib
import numpy as np
import pandas as pd
# Load the saved model and scaler
model = joblib.load('best_model.pkl')
scaler = joblib.load('scaler.pkl')
app = FastAPI()
class WineFeatures(BaseModel):
fixed_acidity: float
volatile_acidity: float
citric_acid: float
residual_sugar: float
chlorides: float
free_sulfur_dioxide: float
total_sulfur_dioxide: float
density: float
pH: float
sulphates: float
alcohol: float
@app.get("/")
def home():
return {
"message": "Welcome to the Wine Quality Prediction API! Use the /predict endpoint to predict wine quality."
}
# Define the prediction endpoint
@app.post("/predict")
def predict(wine: WineFeatures):
# Extract the features from the incoming request
features = np.array([
[
wine.fixed_acidity,
wine.volatile_acidity,
wine.citric_acid,
wine.residual_sugar,
wine.chlorides,
wine.free_sulfur_dioxide,
wine.total_sulfur_dioxide,
wine.density,
wine.pH,
wine.sulphates,
wine.alcohol
]
])
# Scale the input features using the saved scaler
scaled_features = scaler.transform(features)
# Make the prediction using the loaded model
prediction = model.predict(scaled_features)
# Return the prediction (wine quality)
return {"predicted_quality": str(prediction[0])}

BIN
scaler.pkl Normal file

Binary file not shown.

1600
winequality-red.csv Normal file

File diff suppressed because it is too large Load Diff