init
This commit is contained in:
commit
2a7969d0d5
954
Untitled.ipynb
Normal file
954
Untitled.ipynb
Normal file
@ -0,0 +1,954 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "aaad2c40-c784-466b-9f49-7a60e1cbdb4c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import seaborn as sns\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import warnings\n",
|
||||
"from sklearn.preprocessing import MinMaxScaler\n",
|
||||
"from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score\n",
|
||||
"from sklearn.ensemble import RandomForestClassifier\n",
|
||||
"from sklearn import metrics\n",
|
||||
"from sklearn.linear_model import LogisticRegression\n",
|
||||
"from sklearn.svm import SVC\n",
|
||||
"from sklearn.tree import DecisionTreeClassifier\n",
|
||||
"from sklearn.neighbors import KNeighborsClassifier\n",
|
||||
"from sklearn.naive_bayes import GaussianNB\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "23f82a60-419e-4af2-9401-1996bec33332",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>fixed acidity</th>\n",
|
||||
" <th>volatile acidity</th>\n",
|
||||
" <th>citric acid</th>\n",
|
||||
" <th>residual sugar</th>\n",
|
||||
" <th>chlorides</th>\n",
|
||||
" <th>free sulfur dioxide</th>\n",
|
||||
" <th>total sulfur dioxide</th>\n",
|
||||
" <th>density</th>\n",
|
||||
" <th>pH</th>\n",
|
||||
" <th>sulphates</th>\n",
|
||||
" <th>alcohol</th>\n",
|
||||
" <th>quality</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>7.4</td>\n",
|
||||
" <td>0.70</td>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>1.9</td>\n",
|
||||
" <td>0.076</td>\n",
|
||||
" <td>11.0</td>\n",
|
||||
" <td>34.0</td>\n",
|
||||
" <td>0.9978</td>\n",
|
||||
" <td>3.51</td>\n",
|
||||
" <td>0.56</td>\n",
|
||||
" <td>9.4</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>7.8</td>\n",
|
||||
" <td>0.88</td>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>2.6</td>\n",
|
||||
" <td>0.098</td>\n",
|
||||
" <td>25.0</td>\n",
|
||||
" <td>67.0</td>\n",
|
||||
" <td>0.9968</td>\n",
|
||||
" <td>3.20</td>\n",
|
||||
" <td>0.68</td>\n",
|
||||
" <td>9.8</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>7.8</td>\n",
|
||||
" <td>0.76</td>\n",
|
||||
" <td>0.04</td>\n",
|
||||
" <td>2.3</td>\n",
|
||||
" <td>0.092</td>\n",
|
||||
" <td>15.0</td>\n",
|
||||
" <td>54.0</td>\n",
|
||||
" <td>0.9970</td>\n",
|
||||
" <td>3.26</td>\n",
|
||||
" <td>0.65</td>\n",
|
||||
" <td>9.8</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>11.2</td>\n",
|
||||
" <td>0.28</td>\n",
|
||||
" <td>0.56</td>\n",
|
||||
" <td>1.9</td>\n",
|
||||
" <td>0.075</td>\n",
|
||||
" <td>17.0</td>\n",
|
||||
" <td>60.0</td>\n",
|
||||
" <td>0.9980</td>\n",
|
||||
" <td>3.16</td>\n",
|
||||
" <td>0.58</td>\n",
|
||||
" <td>9.8</td>\n",
|
||||
" <td>6</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>7.4</td>\n",
|
||||
" <td>0.70</td>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>1.9</td>\n",
|
||||
" <td>0.076</td>\n",
|
||||
" <td>11.0</td>\n",
|
||||
" <td>34.0</td>\n",
|
||||
" <td>0.9978</td>\n",
|
||||
" <td>3.51</td>\n",
|
||||
" <td>0.56</td>\n",
|
||||
" <td>9.4</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" fixed acidity volatile acidity citric acid residual sugar chlorides \\\n",
|
||||
"0 7.4 0.70 0.00 1.9 0.076 \n",
|
||||
"1 7.8 0.88 0.00 2.6 0.098 \n",
|
||||
"2 7.8 0.76 0.04 2.3 0.092 \n",
|
||||
"3 11.2 0.28 0.56 1.9 0.075 \n",
|
||||
"4 7.4 0.70 0.00 1.9 0.076 \n",
|
||||
"\n",
|
||||
" free sulfur dioxide total sulfur dioxide density pH sulphates \\\n",
|
||||
"0 11.0 34.0 0.9978 3.51 0.56 \n",
|
||||
"1 25.0 67.0 0.9968 3.20 0.68 \n",
|
||||
"2 15.0 54.0 0.9970 3.26 0.65 \n",
|
||||
"3 17.0 60.0 0.9980 3.16 0.58 \n",
|
||||
"4 11.0 34.0 0.9978 3.51 0.56 \n",
|
||||
"\n",
|
||||
" alcohol quality \n",
|
||||
"0 9.4 5 \n",
|
||||
"1 9.8 5 \n",
|
||||
"2 9.8 5 \n",
|
||||
"3 9.8 6 \n",
|
||||
"4 9.4 5 "
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"data = pd.read_csv('winequality-red.csv')\n",
|
||||
"data.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "cbb402ec-23b9-4360-849a-2fccd12d48e8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||||
"RangeIndex: 1599 entries, 0 to 1598\n",
|
||||
"Data columns (total 12 columns):\n",
|
||||
" # Column Non-Null Count Dtype \n",
|
||||
"--- ------ -------------- ----- \n",
|
||||
" 0 fixed acidity 1599 non-null float64\n",
|
||||
" 1 volatile acidity 1599 non-null float64\n",
|
||||
" 2 citric acid 1599 non-null float64\n",
|
||||
" 3 residual sugar 1599 non-null float64\n",
|
||||
" 4 chlorides 1599 non-null float64\n",
|
||||
" 5 free sulfur dioxide 1599 non-null float64\n",
|
||||
" 6 total sulfur dioxide 1599 non-null float64\n",
|
||||
" 7 density 1599 non-null float64\n",
|
||||
" 8 pH 1599 non-null float64\n",
|
||||
" 9 sulphates 1599 non-null float64\n",
|
||||
" 10 alcohol 1599 non-null float64\n",
|
||||
" 11 quality 1599 non-null int64 \n",
|
||||
"dtypes: float64(11), int64(1)\n",
|
||||
"memory usage: 150.0 KB\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<style type=\"text/css\">\n",
|
||||
"#T_c3f81_row0_col0, #T_c3f81_row1_col0, #T_c3f81_row1_col7, #T_c3f81_row2_col0, #T_c3f81_row2_col3, #T_c3f81_row2_col4, #T_c3f81_row2_col7, #T_c3f81_row3_col0, #T_c3f81_row4_col0, #T_c3f81_row4_col1, #T_c3f81_row4_col2, #T_c3f81_row4_col3, #T_c3f81_row4_col4, #T_c3f81_row4_col5, #T_c3f81_row4_col6, #T_c3f81_row4_col7, #T_c3f81_row5_col0, #T_c3f81_row6_col0, #T_c3f81_row7_col0, #T_c3f81_row7_col2, #T_c3f81_row7_col7, #T_c3f81_row8_col0, #T_c3f81_row9_col0, #T_c3f81_row10_col0, #T_c3f81_row11_col0 {\n",
|
||||
" background-color: #fff7fb;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row0_col1, #T_c3f81_row10_col6 {\n",
|
||||
" background-color: #e0deed;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row0_col2, #T_c3f81_row0_col7, #T_c3f81_row3_col1, #T_c3f81_row3_col5, #T_c3f81_row3_col7, #T_c3f81_row8_col6 {\n",
|
||||
" background-color: #f7f0f7;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row0_col3 {\n",
|
||||
" background-color: #5c9fc9;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row0_col4 {\n",
|
||||
" background-color: #b8c6e0;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row0_col5 {\n",
|
||||
" background-color: #dad9ea;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row0_col6 {\n",
|
||||
" background-color: #e7e3f0;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row1_col1, #T_c3f81_row1_col5, #T_c3f81_row1_col6, #T_c3f81_row9_col6 {\n",
|
||||
" background-color: #fef6fa;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row1_col2, #T_c3f81_row2_col1, #T_c3f81_row2_col2, #T_c3f81_row2_col5, #T_c3f81_row2_col6, #T_c3f81_row8_col2, #T_c3f81_row9_col2, #T_c3f81_row9_col7 {\n",
|
||||
" background-color: #fef6fb;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row1_col3, #T_c3f81_row1_col4, #T_c3f81_row7_col6, #T_c3f81_row8_col7, #T_c3f81_row9_col1, #T_c3f81_row9_col5 {\n",
|
||||
" background-color: #fdf5fa;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row3_col2, #T_c3f81_row3_col6, #T_c3f81_row7_col4, #T_c3f81_row9_col3 {\n",
|
||||
" background-color: #f9f2f8;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row3_col3 {\n",
|
||||
" background-color: #efe9f3;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row3_col4, #T_c3f81_row8_col5 {\n",
|
||||
" background-color: #f2ecf5;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row5_col1 {\n",
|
||||
" background-color: #b1c2de;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row5_col2 {\n",
|
||||
" background-color: #b9c6e0;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row5_col3, #T_c3f81_row7_col3, #T_c3f81_row11_col1 {\n",
|
||||
" background-color: #ede8f3;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row5_col4 {\n",
|
||||
" background-color: #bbc7e0;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row5_col5 {\n",
|
||||
" background-color: #a9bfdc;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row5_col6 {\n",
|
||||
" background-color: #b3c3de;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row5_col7 {\n",
|
||||
" background-color: #d1d2e6;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row6_col1, #T_c3f81_row6_col2, #T_c3f81_row6_col4, #T_c3f81_row6_col5, #T_c3f81_row6_col6, #T_c3f81_row6_col7, #T_c3f81_row10_col3 {\n",
|
||||
" background-color: #023858;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row6_col3 {\n",
|
||||
" background-color: #1379b5;\n",
|
||||
" color: #f1f1f1;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row7_col1, #T_c3f81_row9_col4 {\n",
|
||||
" background-color: #fcf4fa;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row7_col5, #T_c3f81_row11_col2, #T_c3f81_row11_col7 {\n",
|
||||
" background-color: #fbf4f9;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row8_col1 {\n",
|
||||
" background-color: #f5eef6;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row8_col3 {\n",
|
||||
" background-color: #b7c5df;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row8_col4 {\n",
|
||||
" background-color: #e8e4f0;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row10_col1, #T_c3f81_row11_col4 {\n",
|
||||
" background-color: #d6d6e9;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row10_col2 {\n",
|
||||
" background-color: #faf3f9;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row10_col4 {\n",
|
||||
" background-color: #8fb4d6;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row10_col5 {\n",
|
||||
" background-color: #cacee5;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row10_col7 {\n",
|
||||
" background-color: #f8f1f8;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row11_col3 {\n",
|
||||
" background-color: #acc0dd;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row11_col5 {\n",
|
||||
" background-color: #e6e2ef;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"#T_c3f81_row11_col6 {\n",
|
||||
" background-color: #f1ebf4;\n",
|
||||
" color: #000000;\n",
|
||||
"}\n",
|
||||
"</style>\n",
|
||||
"<table id=\"T_c3f81\">\n",
|
||||
" <thead>\n",
|
||||
" <tr>\n",
|
||||
" <th class=\"blank level0\" > </th>\n",
|
||||
" <th id=\"T_c3f81_level0_col0\" class=\"col_heading level0 col0\" >count</th>\n",
|
||||
" <th id=\"T_c3f81_level0_col1\" class=\"col_heading level0 col1\" >mean</th>\n",
|
||||
" <th id=\"T_c3f81_level0_col2\" class=\"col_heading level0 col2\" >std</th>\n",
|
||||
" <th id=\"T_c3f81_level0_col3\" class=\"col_heading level0 col3\" >min</th>\n",
|
||||
" <th id=\"T_c3f81_level0_col4\" class=\"col_heading level0 col4\" >25%</th>\n",
|
||||
" <th id=\"T_c3f81_level0_col5\" class=\"col_heading level0 col5\" >50%</th>\n",
|
||||
" <th id=\"T_c3f81_level0_col6\" class=\"col_heading level0 col6\" >75%</th>\n",
|
||||
" <th id=\"T_c3f81_level0_col7\" class=\"col_heading level0 col7\" >max</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_c3f81_level0_row0\" class=\"row_heading level0 row0\" >fixed acidity</th>\n",
|
||||
" <td id=\"T_c3f81_row0_col0\" class=\"data row0 col0\" >1599.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row0_col1\" class=\"data row0 col1\" >8.319637</td>\n",
|
||||
" <td id=\"T_c3f81_row0_col2\" class=\"data row0 col2\" >1.741096</td>\n",
|
||||
" <td id=\"T_c3f81_row0_col3\" class=\"data row0 col3\" >4.600000</td>\n",
|
||||
" <td id=\"T_c3f81_row0_col4\" class=\"data row0 col4\" >7.100000</td>\n",
|
||||
" <td id=\"T_c3f81_row0_col5\" class=\"data row0 col5\" >7.900000</td>\n",
|
||||
" <td id=\"T_c3f81_row0_col6\" class=\"data row0 col6\" >9.200000</td>\n",
|
||||
" <td id=\"T_c3f81_row0_col7\" class=\"data row0 col7\" >15.900000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_c3f81_level0_row1\" class=\"row_heading level0 row1\" >volatile acidity</th>\n",
|
||||
" <td id=\"T_c3f81_row1_col0\" class=\"data row1 col0\" >1599.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row1_col1\" class=\"data row1 col1\" >0.527821</td>\n",
|
||||
" <td id=\"T_c3f81_row1_col2\" class=\"data row1 col2\" >0.179060</td>\n",
|
||||
" <td id=\"T_c3f81_row1_col3\" class=\"data row1 col3\" >0.120000</td>\n",
|
||||
" <td id=\"T_c3f81_row1_col4\" class=\"data row1 col4\" >0.390000</td>\n",
|
||||
" <td id=\"T_c3f81_row1_col5\" class=\"data row1 col5\" >0.520000</td>\n",
|
||||
" <td id=\"T_c3f81_row1_col6\" class=\"data row1 col6\" >0.640000</td>\n",
|
||||
" <td id=\"T_c3f81_row1_col7\" class=\"data row1 col7\" >1.580000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_c3f81_level0_row2\" class=\"row_heading level0 row2\" >citric acid</th>\n",
|
||||
" <td id=\"T_c3f81_row2_col0\" class=\"data row2 col0\" >1599.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row2_col1\" class=\"data row2 col1\" >0.270976</td>\n",
|
||||
" <td id=\"T_c3f81_row2_col2\" class=\"data row2 col2\" >0.194801</td>\n",
|
||||
" <td id=\"T_c3f81_row2_col3\" class=\"data row2 col3\" >0.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row2_col4\" class=\"data row2 col4\" >0.090000</td>\n",
|
||||
" <td id=\"T_c3f81_row2_col5\" class=\"data row2 col5\" >0.260000</td>\n",
|
||||
" <td id=\"T_c3f81_row2_col6\" class=\"data row2 col6\" >0.420000</td>\n",
|
||||
" <td id=\"T_c3f81_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_c3f81_level0_row3\" class=\"row_heading level0 row3\" >residual sugar</th>\n",
|
||||
" <td id=\"T_c3f81_row3_col0\" class=\"data row3 col0\" >1599.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row3_col1\" class=\"data row3 col1\" >2.538806</td>\n",
|
||||
" <td id=\"T_c3f81_row3_col2\" class=\"data row3 col2\" >1.409928</td>\n",
|
||||
" <td id=\"T_c3f81_row3_col3\" class=\"data row3 col3\" >0.900000</td>\n",
|
||||
" <td id=\"T_c3f81_row3_col4\" class=\"data row3 col4\" >1.900000</td>\n",
|
||||
" <td id=\"T_c3f81_row3_col5\" class=\"data row3 col5\" >2.200000</td>\n",
|
||||
" <td id=\"T_c3f81_row3_col6\" class=\"data row3 col6\" >2.600000</td>\n",
|
||||
" <td id=\"T_c3f81_row3_col7\" class=\"data row3 col7\" >15.500000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_c3f81_level0_row4\" class=\"row_heading level0 row4\" >chlorides</th>\n",
|
||||
" <td id=\"T_c3f81_row4_col0\" class=\"data row4 col0\" >1599.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row4_col1\" class=\"data row4 col1\" >0.087467</td>\n",
|
||||
" <td id=\"T_c3f81_row4_col2\" class=\"data row4 col2\" >0.047065</td>\n",
|
||||
" <td id=\"T_c3f81_row4_col3\" class=\"data row4 col3\" >0.012000</td>\n",
|
||||
" <td id=\"T_c3f81_row4_col4\" class=\"data row4 col4\" >0.070000</td>\n",
|
||||
" <td id=\"T_c3f81_row4_col5\" class=\"data row4 col5\" >0.079000</td>\n",
|
||||
" <td id=\"T_c3f81_row4_col6\" class=\"data row4 col6\" >0.090000</td>\n",
|
||||
" <td id=\"T_c3f81_row4_col7\" class=\"data row4 col7\" >0.611000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_c3f81_level0_row5\" class=\"row_heading level0 row5\" >free sulfur dioxide</th>\n",
|
||||
" <td id=\"T_c3f81_row5_col0\" class=\"data row5 col0\" >1599.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row5_col1\" class=\"data row5 col1\" >15.874922</td>\n",
|
||||
" <td id=\"T_c3f81_row5_col2\" class=\"data row5 col2\" >10.460157</td>\n",
|
||||
" <td id=\"T_c3f81_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row5_col4\" class=\"data row5 col4\" >7.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row5_col5\" class=\"data row5 col5\" >14.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row5_col6\" class=\"data row5 col6\" >21.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row5_col7\" class=\"data row5 col7\" >72.000000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_c3f81_level0_row6\" class=\"row_heading level0 row6\" >total sulfur dioxide</th>\n",
|
||||
" <td id=\"T_c3f81_row6_col0\" class=\"data row6 col0\" >1599.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row6_col1\" class=\"data row6 col1\" >46.467792</td>\n",
|
||||
" <td id=\"T_c3f81_row6_col2\" class=\"data row6 col2\" >32.895324</td>\n",
|
||||
" <td id=\"T_c3f81_row6_col3\" class=\"data row6 col3\" >6.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row6_col4\" class=\"data row6 col4\" >22.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row6_col5\" class=\"data row6 col5\" >38.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row6_col6\" class=\"data row6 col6\" >62.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row6_col7\" class=\"data row6 col7\" >289.000000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_c3f81_level0_row7\" class=\"row_heading level0 row7\" >density</th>\n",
|
||||
" <td id=\"T_c3f81_row7_col0\" class=\"data row7 col0\" >1599.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row7_col1\" class=\"data row7 col1\" >0.996747</td>\n",
|
||||
" <td id=\"T_c3f81_row7_col2\" class=\"data row7 col2\" >0.001887</td>\n",
|
||||
" <td id=\"T_c3f81_row7_col3\" class=\"data row7 col3\" >0.990070</td>\n",
|
||||
" <td id=\"T_c3f81_row7_col4\" class=\"data row7 col4\" >0.995600</td>\n",
|
||||
" <td id=\"T_c3f81_row7_col5\" class=\"data row7 col5\" >0.996750</td>\n",
|
||||
" <td id=\"T_c3f81_row7_col6\" class=\"data row7 col6\" >0.997835</td>\n",
|
||||
" <td id=\"T_c3f81_row7_col7\" class=\"data row7 col7\" >1.003690</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_c3f81_level0_row8\" class=\"row_heading level0 row8\" >pH</th>\n",
|
||||
" <td id=\"T_c3f81_row8_col0\" class=\"data row8 col0\" >1599.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row8_col1\" class=\"data row8 col1\" >3.311113</td>\n",
|
||||
" <td id=\"T_c3f81_row8_col2\" class=\"data row8 col2\" >0.154386</td>\n",
|
||||
" <td id=\"T_c3f81_row8_col3\" class=\"data row8 col3\" >2.740000</td>\n",
|
||||
" <td id=\"T_c3f81_row8_col4\" class=\"data row8 col4\" >3.210000</td>\n",
|
||||
" <td id=\"T_c3f81_row8_col5\" class=\"data row8 col5\" >3.310000</td>\n",
|
||||
" <td id=\"T_c3f81_row8_col6\" class=\"data row8 col6\" >3.400000</td>\n",
|
||||
" <td id=\"T_c3f81_row8_col7\" class=\"data row8 col7\" >4.010000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_c3f81_level0_row9\" class=\"row_heading level0 row9\" >sulphates</th>\n",
|
||||
" <td id=\"T_c3f81_row9_col0\" class=\"data row9 col0\" >1599.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row9_col1\" class=\"data row9 col1\" >0.658149</td>\n",
|
||||
" <td id=\"T_c3f81_row9_col2\" class=\"data row9 col2\" >0.169507</td>\n",
|
||||
" <td id=\"T_c3f81_row9_col3\" class=\"data row9 col3\" >0.330000</td>\n",
|
||||
" <td id=\"T_c3f81_row9_col4\" class=\"data row9 col4\" >0.550000</td>\n",
|
||||
" <td id=\"T_c3f81_row9_col5\" class=\"data row9 col5\" >0.620000</td>\n",
|
||||
" <td id=\"T_c3f81_row9_col6\" class=\"data row9 col6\" >0.730000</td>\n",
|
||||
" <td id=\"T_c3f81_row9_col7\" class=\"data row9 col7\" >2.000000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_c3f81_level0_row10\" class=\"row_heading level0 row10\" >alcohol</th>\n",
|
||||
" <td id=\"T_c3f81_row10_col0\" class=\"data row10 col0\" >1599.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row10_col1\" class=\"data row10 col1\" >10.422983</td>\n",
|
||||
" <td id=\"T_c3f81_row10_col2\" class=\"data row10 col2\" >1.065668</td>\n",
|
||||
" <td id=\"T_c3f81_row10_col3\" class=\"data row10 col3\" >8.400000</td>\n",
|
||||
" <td id=\"T_c3f81_row10_col4\" class=\"data row10 col4\" >9.500000</td>\n",
|
||||
" <td id=\"T_c3f81_row10_col5\" class=\"data row10 col5\" >10.200000</td>\n",
|
||||
" <td id=\"T_c3f81_row10_col6\" class=\"data row10 col6\" >11.100000</td>\n",
|
||||
" <td id=\"T_c3f81_row10_col7\" class=\"data row10 col7\" >14.900000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th id=\"T_c3f81_level0_row11\" class=\"row_heading level0 row11\" >quality</th>\n",
|
||||
" <td id=\"T_c3f81_row11_col0\" class=\"data row11 col0\" >1599.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row11_col1\" class=\"data row11 col1\" >5.636023</td>\n",
|
||||
" <td id=\"T_c3f81_row11_col2\" class=\"data row11 col2\" >0.807569</td>\n",
|
||||
" <td id=\"T_c3f81_row11_col3\" class=\"data row11 col3\" >3.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row11_col4\" class=\"data row11 col4\" >5.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row11_col5\" class=\"data row11 col5\" >6.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row11_col6\" class=\"data row11 col6\" >6.000000</td>\n",
|
||||
" <td id=\"T_c3f81_row11_col7\" class=\"data row11 col7\" >8.000000</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"<pandas.io.formats.style.Styler at 0x2ded8e0b820>"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"data.info()\n",
|
||||
"data.describe().T.style.background_gradient(axis=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "4b1487c6-ae6c-4ee6-a0e1-03a156654bda",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"fixed acidity 0\n",
|
||||
"volatile acidity 0\n",
|
||||
"citric acid 0\n",
|
||||
"residual sugar 0\n",
|
||||
"chlorides 0\n",
|
||||
"free sulfur dioxide 0\n",
|
||||
"total sulfur dioxide 0\n",
|
||||
"density 0\n",
|
||||
"pH 0\n",
|
||||
"sulphates 0\n",
|
||||
"alcohol 0\n",
|
||||
"quality 0\n",
|
||||
"dtype: int64"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"data.isna().sum()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "2d090c6d-5de2-41d5-a098-3516dcdb7155",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>fixed acidity</th>\n",
|
||||
" <th>volatile acidity</th>\n",
|
||||
" <th>citric acid</th>\n",
|
||||
" <th>residual sugar</th>\n",
|
||||
" <th>chlorides</th>\n",
|
||||
" <th>free sulfur dioxide</th>\n",
|
||||
" <th>total sulfur dioxide</th>\n",
|
||||
" <th>density</th>\n",
|
||||
" <th>pH</th>\n",
|
||||
" <th>sulphates</th>\n",
|
||||
" <th>alcohol</th>\n",
|
||||
" <th>quality</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>7.4</td>\n",
|
||||
" <td>0.70</td>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>1.9</td>\n",
|
||||
" <td>0.076</td>\n",
|
||||
" <td>11.0</td>\n",
|
||||
" <td>34.0</td>\n",
|
||||
" <td>0.9978</td>\n",
|
||||
" <td>3.51</td>\n",
|
||||
" <td>0.56</td>\n",
|
||||
" <td>9.4</td>\n",
|
||||
" <td>Average</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>7.8</td>\n",
|
||||
" <td>0.88</td>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>2.6</td>\n",
|
||||
" <td>0.098</td>\n",
|
||||
" <td>25.0</td>\n",
|
||||
" <td>67.0</td>\n",
|
||||
" <td>0.9968</td>\n",
|
||||
" <td>3.20</td>\n",
|
||||
" <td>0.68</td>\n",
|
||||
" <td>9.8</td>\n",
|
||||
" <td>Average</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>7.8</td>\n",
|
||||
" <td>0.76</td>\n",
|
||||
" <td>0.04</td>\n",
|
||||
" <td>2.3</td>\n",
|
||||
" <td>0.092</td>\n",
|
||||
" <td>15.0</td>\n",
|
||||
" <td>54.0</td>\n",
|
||||
" <td>0.9970</td>\n",
|
||||
" <td>3.26</td>\n",
|
||||
" <td>0.65</td>\n",
|
||||
" <td>9.8</td>\n",
|
||||
" <td>Average</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>11.2</td>\n",
|
||||
" <td>0.28</td>\n",
|
||||
" <td>0.56</td>\n",
|
||||
" <td>1.9</td>\n",
|
||||
" <td>0.075</td>\n",
|
||||
" <td>17.0</td>\n",
|
||||
" <td>60.0</td>\n",
|
||||
" <td>0.9980</td>\n",
|
||||
" <td>3.16</td>\n",
|
||||
" <td>0.58</td>\n",
|
||||
" <td>9.8</td>\n",
|
||||
" <td>Average</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>7.4</td>\n",
|
||||
" <td>0.70</td>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>1.9</td>\n",
|
||||
" <td>0.076</td>\n",
|
||||
" <td>11.0</td>\n",
|
||||
" <td>34.0</td>\n",
|
||||
" <td>0.9978</td>\n",
|
||||
" <td>3.51</td>\n",
|
||||
" <td>0.56</td>\n",
|
||||
" <td>9.4</td>\n",
|
||||
" <td>Average</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" fixed acidity volatile acidity citric acid residual sugar chlorides \\\n",
|
||||
"0 7.4 0.70 0.00 1.9 0.076 \n",
|
||||
"1 7.8 0.88 0.00 2.6 0.098 \n",
|
||||
"2 7.8 0.76 0.04 2.3 0.092 \n",
|
||||
"3 11.2 0.28 0.56 1.9 0.075 \n",
|
||||
"4 7.4 0.70 0.00 1.9 0.076 \n",
|
||||
"\n",
|
||||
" free sulfur dioxide total sulfur dioxide density pH sulphates \\\n",
|
||||
"0 11.0 34.0 0.9978 3.51 0.56 \n",
|
||||
"1 25.0 67.0 0.9968 3.20 0.68 \n",
|
||||
"2 15.0 54.0 0.9970 3.26 0.65 \n",
|
||||
"3 17.0 60.0 0.9980 3.16 0.58 \n",
|
||||
"4 11.0 34.0 0.9978 3.51 0.56 \n",
|
||||
"\n",
|
||||
" alcohol quality \n",
|
||||
"0 9.4 Average \n",
|
||||
"1 9.8 Average \n",
|
||||
"2 9.8 Average \n",
|
||||
"3 9.8 Average \n",
|
||||
"4 9.4 Average "
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"data = data.replace({'quality': {8: 'Good', 7: 'Good', 6: 'Average', 5: 'Average', 4: 'Bad', 3: 'Bad'}})\n",
|
||||
"\n",
|
||||
"data.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "0d672775-f7ef-468e-ac49-a6ab5254ee3a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X = data.drop(columns='quality')\n",
|
||||
"y = data.quality\n",
|
||||
"scaler = MinMaxScaler(feature_range=(0, 1))\n",
|
||||
"X_scaled = scaler.fit_transform(X)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "32357ab4-74d6-4a32-8ff9-249e04984341",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "13b6ce95-34a6-48ea-a4b8-bdd69c3ca496",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"models = {\n",
|
||||
" \"Logistic Regression\": LogisticRegression(),\n",
|
||||
" \"SVM\": SVC(),\n",
|
||||
" \"Random Forest\": RandomForestClassifier(),\n",
|
||||
" \"Decision Tree\": DecisionTreeClassifier(),\n",
|
||||
" \"K-Nearest Neighbors\": KNeighborsClassifier(),\n",
|
||||
" \"Naive Bayes\": GaussianNB()\n",
|
||||
"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "9096f5a4-8329-47d0-9b7c-8e3123191a3b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Training Logistic Regression...\n",
|
||||
"Logistic Regression Accuracy: 0.828125\n",
|
||||
"Training SVM...\n",
|
||||
"SVM Accuracy: 0.8375\n",
|
||||
"Training Random Forest...\n",
|
||||
"Random Forest Accuracy: 0.865625\n",
|
||||
"Training Decision Tree...\n",
|
||||
"Decision Tree Accuracy: 0.79375\n",
|
||||
"Training K-Nearest Neighbors...\n",
|
||||
"K-Nearest Neighbors Accuracy: 0.840625\n",
|
||||
"Training Naive Bayes...\n",
|
||||
"Naive Bayes Accuracy: 0.803125\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"results = {}\n",
|
||||
"for model_name, model in models.items():\n",
|
||||
" print(f\"Training {model_name}...\")\n",
|
||||
" model.fit(X_train, y_train)\n",
|
||||
" y_pred = model.predict(X_test)\n",
|
||||
" accuracy = metrics.accuracy_score(y_test, y_pred)\n",
|
||||
" results[model_name] = accuracy\n",
|
||||
" print(f\"{model_name} Accuracy: {accuracy}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "d8f786c8-67e1-4131-bb2c-22d91d7fafab",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Model Comparison (Accuracy):\n",
|
||||
"Random Forest: 0.8656\n",
|
||||
"K-Nearest Neighbors: 0.8406\n",
|
||||
"SVM: 0.8375\n",
|
||||
"Logistic Regression: 0.8281\n",
|
||||
"Naive Bayes: 0.8031\n",
|
||||
"Decision Tree: 0.7937\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)\n",
|
||||
"print(\"\\nModel Comparison (Accuracy):\")\n",
|
||||
"for model_name, accuracy in sorted_results:\n",
|
||||
" print(f\"{model_name}: {accuracy:.4f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "b6eec30b-3a13-4f99-b841-a8633a8b4ad7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Fitting 5 folds for each of 27 candidates, totalling 135 fits\n",
|
||||
"Best parameters for Random Forest: {'max_depth': 30, 'min_samples_split': 2, 'n_estimators': 50}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Example: Hyperparameter tuning for Random Forest\n",
|
||||
"param_grid = {\n",
|
||||
" 'n_estimators': [50, 100, 200],\n",
|
||||
" 'max_depth': [10, 20, 30],\n",
|
||||
" 'min_samples_split': [2, 5, 10]\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"grid_search = GridSearchCV(RandomForestClassifier(), param_grid, cv=5, n_jobs=-1, verbose=1)\n",
|
||||
"grid_search.fit(X_train, y_train)\n",
|
||||
"print(f\"Best parameters for Random Forest: {grid_search.best_params_}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "488c6fe7-1f79-4e6a-9649-fd82c0acd44e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Cross-validation scores for Random Forest: [0.834375 0.821875 0.840625 0.8125 0.84326019]\n",
|
||||
"Mean cross-validation score: 0.830527037617555\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Cross-validation score for the best model (Random Forest)\n",
|
||||
"best_model = grid_search.best_estimator_\n",
|
||||
"cv_scores = cross_val_score(best_model, X_scaled, y, cv=5)\n",
|
||||
"print(f\"Cross-validation scores for Random Forest: {cv_scores}\")\n",
|
||||
"print(f\"Mean cross-validation score: {cv_scores.mean()}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "81b08573-19a3-4508-9c9e-4a36ced59197",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Classification Report:\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" Average 0.88 0.95 0.92 262\n",
|
||||
" Bad 1.00 0.00 0.00 11\n",
|
||||
" Good 0.67 0.51 0.58 47\n",
|
||||
"\n",
|
||||
" accuracy 0.86 320\n",
|
||||
" macro avg 0.85 0.49 0.50 320\n",
|
||||
"weighted avg 0.85 0.86 0.83 320\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from sklearn.metrics import classification_report\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Classification report with zero_division parameter to handle undefined precision\n",
|
||||
"y_pred = best_model.predict(X_test)\n",
|
||||
"print(\"\\nClassification Report:\")\n",
|
||||
"print(classification_report(y_test, y_pred, zero_division=1))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "191fa955-1969-4224-8420-a608c1463e31",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model and Scaler have been saved.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import joblib\n",
|
||||
"\n",
|
||||
"# Save the best model (Random Forest with Hyperparameter Tuning)\n",
|
||||
"joblib.dump(best_model, 'best_model.pkl')\n",
|
||||
"\n",
|
||||
"# Save the scaler (MinMax SScaler used for feature scaling)\n",
|
||||
"joblib.dump(scaler, 'scaler.pkl')\n",
|
||||
"\n",
|
||||
"print(\"Model and Scaler have been saved.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8f75662b-1b38-4510-bccd-adefabbe635f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.13.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
75
app.py
Normal file
75
app.py
Normal file
@ -0,0 +1,75 @@
|
||||
import streamlit as st
|
||||
import requests
|
||||
|
||||
# FastAPI backend URL
|
||||
BASE_URL = "http://127.0.0.1:8000"
|
||||
|
||||
# Define pages
|
||||
def home_page():
|
||||
st.title("Wine Quality Prediction API")
|
||||
st.write("This is the home page of the Wine Quality Prediction App.")
|
||||
|
||||
# Communicate with the FastAPI `/` endpoint
|
||||
try:
|
||||
response = requests.get(f"{BASE_URL}/")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
st.success(data.get("message", "Welcome to the API!"))
|
||||
else:
|
||||
st.error("Failed to fetch data from the backend.")
|
||||
except Exception as e:
|
||||
st.error(f"Error: {e}")
|
||||
|
||||
def prediction_page():
|
||||
st.title("Wine Quality Prediction")
|
||||
st.write("Enter the wine features to predict its quality:")
|
||||
|
||||
# Input fields for wine features
|
||||
fixed_acidity = st.number_input("Fixed Acidity", value=7.0)
|
||||
volatile_acidity = st.number_input("Volatile Acidity", value=0.27)
|
||||
citric_acid = st.number_input("Citric Acid", value=0.36)
|
||||
residual_sugar = st.number_input("Residual Sugar", value=20.7)
|
||||
chlorides = st.number_input("Chlorides", value=0.045)
|
||||
free_sulfur_dioxide = st.number_input("Free Sulfur Dioxide", value=45.0)
|
||||
total_sulfur_dioxide = st.number_input("Total Sulfur Dioxide", value=170.0)
|
||||
density = st.number_input("Density", value=1.001)
|
||||
pH = st.number_input("pH", value=3.0)
|
||||
sulphates = st.number_input("Sulphates", value=0.45)
|
||||
alcohol = st.number_input("Alcohol", value=8.8)
|
||||
|
||||
# Predict button
|
||||
if st.button("Predict Quality"):
|
||||
# Prepare the payload
|
||||
payload = {
|
||||
"fixed_acidity": fixed_acidity,
|
||||
"volatile_acidity": volatile_acidity,
|
||||
"citric_acid": citric_acid,
|
||||
"residual_sugar": residual_sugar,
|
||||
"chlorides": chlorides,
|
||||
"free_sulfur_dioxide": free_sulfur_dioxide,
|
||||
"total_sulfur_dioxide": total_sulfur_dioxide,
|
||||
"density": density,
|
||||
"pH": pH,
|
||||
"sulphates": sulphates,
|
||||
"alcohol": alcohol
|
||||
}
|
||||
|
||||
# Communicate with the FastAPI `/predict` endpoint
|
||||
try:
|
||||
response = requests.post(f"{BASE_URL}/predict", json=payload)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
st.success(f"Predicted Wine Quality: {result['predicted_quality']}")
|
||||
else:
|
||||
st.error("Failed to get prediction from the backend.")
|
||||
except Exception as e:
|
||||
st.error(f"Error: {e}")
|
||||
|
||||
# Streamlit page navigation
|
||||
st.sidebar.title("Navigation")
|
||||
page = st.sidebar.radio("Go to", ["Home", "Predict Wine Quality"])
|
||||
|
||||
if page == "Home":
|
||||
home_page()
|
||||
elif page == "Predict Wine Quality":
|
||||
prediction_page()
|
||||
BIN
best_model.pkl
Normal file
BIN
best_model.pkl
Normal file
Binary file not shown.
61
main.py
Normal file
61
main.py
Normal file
@ -0,0 +1,61 @@
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
import joblib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# Load the saved model and scaler
|
||||
model = joblib.load('best_model.pkl')
|
||||
scaler = joblib.load('scaler.pkl')
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
class WineFeatures(BaseModel):
|
||||
fixed_acidity: float
|
||||
volatile_acidity: float
|
||||
citric_acid: float
|
||||
residual_sugar: float
|
||||
chlorides: float
|
||||
free_sulfur_dioxide: float
|
||||
total_sulfur_dioxide: float
|
||||
density: float
|
||||
pH: float
|
||||
sulphates: float
|
||||
alcohol: float
|
||||
|
||||
@app.get("/")
|
||||
def home():
|
||||
return {
|
||||
"message": "Welcome to the Wine Quality Prediction API! Use the /predict endpoint to predict wine quality."
|
||||
}
|
||||
|
||||
|
||||
# Define the prediction endpoint
|
||||
@app.post("/predict")
|
||||
def predict(wine: WineFeatures):
|
||||
# Extract the features from the incoming request
|
||||
features = np.array([
|
||||
[
|
||||
wine.fixed_acidity,
|
||||
wine.volatile_acidity,
|
||||
wine.citric_acid,
|
||||
wine.residual_sugar,
|
||||
wine.chlorides,
|
||||
wine.free_sulfur_dioxide,
|
||||
wine.total_sulfur_dioxide,
|
||||
wine.density,
|
||||
wine.pH,
|
||||
wine.sulphates,
|
||||
wine.alcohol
|
||||
]
|
||||
])
|
||||
# Scale the input features using the saved scaler
|
||||
scaled_features = scaler.transform(features)
|
||||
|
||||
|
||||
# Make the prediction using the loaded model
|
||||
prediction = model.predict(scaled_features)
|
||||
|
||||
# Return the prediction (wine quality)
|
||||
return {"predicted_quality": str(prediction[0])}
|
||||
BIN
scaler.pkl
Normal file
BIN
scaler.pkl
Normal file
Binary file not shown.
1600
winequality-red.csv
Normal file
1600
winequality-red.csv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user