{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "aaad2c40-c784-466b-9f49-7a60e1cbdb4c",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import warnings\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn import metrics\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.svm import SVC\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.naive_bayes import GaussianNB\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "23f82a60-419e-4af2-9401-1996bec33332",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" fixed acidity | \n",
" volatile acidity | \n",
" citric acid | \n",
" residual sugar | \n",
" chlorides | \n",
" free sulfur dioxide | \n",
" total sulfur dioxide | \n",
" density | \n",
" pH | \n",
" sulphates | \n",
" alcohol | \n",
" quality | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 7.4 | \n",
" 0.70 | \n",
" 0.00 | \n",
" 1.9 | \n",
" 0.076 | \n",
" 11.0 | \n",
" 34.0 | \n",
" 0.9978 | \n",
" 3.51 | \n",
" 0.56 | \n",
" 9.4 | \n",
" 5 | \n",
"
\n",
" \n",
" | 1 | \n",
" 7.8 | \n",
" 0.88 | \n",
" 0.00 | \n",
" 2.6 | \n",
" 0.098 | \n",
" 25.0 | \n",
" 67.0 | \n",
" 0.9968 | \n",
" 3.20 | \n",
" 0.68 | \n",
" 9.8 | \n",
" 5 | \n",
"
\n",
" \n",
" | 2 | \n",
" 7.8 | \n",
" 0.76 | \n",
" 0.04 | \n",
" 2.3 | \n",
" 0.092 | \n",
" 15.0 | \n",
" 54.0 | \n",
" 0.9970 | \n",
" 3.26 | \n",
" 0.65 | \n",
" 9.8 | \n",
" 5 | \n",
"
\n",
" \n",
" | 3 | \n",
" 11.2 | \n",
" 0.28 | \n",
" 0.56 | \n",
" 1.9 | \n",
" 0.075 | \n",
" 17.0 | \n",
" 60.0 | \n",
" 0.9980 | \n",
" 3.16 | \n",
" 0.58 | \n",
" 9.8 | \n",
" 6 | \n",
"
\n",
" \n",
" | 4 | \n",
" 7.4 | \n",
" 0.70 | \n",
" 0.00 | \n",
" 1.9 | \n",
" 0.076 | \n",
" 11.0 | \n",
" 34.0 | \n",
" 0.9978 | \n",
" 3.51 | \n",
" 0.56 | \n",
" 9.4 | \n",
" 5 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" fixed acidity volatile acidity citric acid residual sugar chlorides \\\n",
"0 7.4 0.70 0.00 1.9 0.076 \n",
"1 7.8 0.88 0.00 2.6 0.098 \n",
"2 7.8 0.76 0.04 2.3 0.092 \n",
"3 11.2 0.28 0.56 1.9 0.075 \n",
"4 7.4 0.70 0.00 1.9 0.076 \n",
"\n",
" free sulfur dioxide total sulfur dioxide density pH sulphates \\\n",
"0 11.0 34.0 0.9978 3.51 0.56 \n",
"1 25.0 67.0 0.9968 3.20 0.68 \n",
"2 15.0 54.0 0.9970 3.26 0.65 \n",
"3 17.0 60.0 0.9980 3.16 0.58 \n",
"4 11.0 34.0 0.9978 3.51 0.56 \n",
"\n",
" alcohol quality \n",
"0 9.4 5 \n",
"1 9.8 5 \n",
"2 9.8 5 \n",
"3 9.8 6 \n",
"4 9.4 5 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = pd.read_csv('winequality-red.csv')\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cbb402ec-23b9-4360-849a-2fccd12d48e8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"RangeIndex: 1599 entries, 0 to 1598\n",
"Data columns (total 12 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 fixed acidity 1599 non-null float64\n",
" 1 volatile acidity 1599 non-null float64\n",
" 2 citric acid 1599 non-null float64\n",
" 3 residual sugar 1599 non-null float64\n",
" 4 chlorides 1599 non-null float64\n",
" 5 free sulfur dioxide 1599 non-null float64\n",
" 6 total sulfur dioxide 1599 non-null float64\n",
" 7 density 1599 non-null float64\n",
" 8 pH 1599 non-null float64\n",
" 9 sulphates 1599 non-null float64\n",
" 10 alcohol 1599 non-null float64\n",
" 11 quality 1599 non-null int64 \n",
"dtypes: float64(11), int64(1)\n",
"memory usage: 150.0 KB\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
" \n",
" \n",
" | | \n",
" count | \n",
" mean | \n",
" std | \n",
" min | \n",
" 25% | \n",
" 50% | \n",
" 75% | \n",
" max | \n",
"
\n",
" \n",
" \n",
" \n",
" | fixed acidity | \n",
" 1599.000000 | \n",
" 8.319637 | \n",
" 1.741096 | \n",
" 4.600000 | \n",
" 7.100000 | \n",
" 7.900000 | \n",
" 9.200000 | \n",
" 15.900000 | \n",
"
\n",
" \n",
" | volatile acidity | \n",
" 1599.000000 | \n",
" 0.527821 | \n",
" 0.179060 | \n",
" 0.120000 | \n",
" 0.390000 | \n",
" 0.520000 | \n",
" 0.640000 | \n",
" 1.580000 | \n",
"
\n",
" \n",
" | citric acid | \n",
" 1599.000000 | \n",
" 0.270976 | \n",
" 0.194801 | \n",
" 0.000000 | \n",
" 0.090000 | \n",
" 0.260000 | \n",
" 0.420000 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" | residual sugar | \n",
" 1599.000000 | \n",
" 2.538806 | \n",
" 1.409928 | \n",
" 0.900000 | \n",
" 1.900000 | \n",
" 2.200000 | \n",
" 2.600000 | \n",
" 15.500000 | \n",
"
\n",
" \n",
" | chlorides | \n",
" 1599.000000 | \n",
" 0.087467 | \n",
" 0.047065 | \n",
" 0.012000 | \n",
" 0.070000 | \n",
" 0.079000 | \n",
" 0.090000 | \n",
" 0.611000 | \n",
"
\n",
" \n",
" | free sulfur dioxide | \n",
" 1599.000000 | \n",
" 15.874922 | \n",
" 10.460157 | \n",
" 1.000000 | \n",
" 7.000000 | \n",
" 14.000000 | \n",
" 21.000000 | \n",
" 72.000000 | \n",
"
\n",
" \n",
" | total sulfur dioxide | \n",
" 1599.000000 | \n",
" 46.467792 | \n",
" 32.895324 | \n",
" 6.000000 | \n",
" 22.000000 | \n",
" 38.000000 | \n",
" 62.000000 | \n",
" 289.000000 | \n",
"
\n",
" \n",
" | density | \n",
" 1599.000000 | \n",
" 0.996747 | \n",
" 0.001887 | \n",
" 0.990070 | \n",
" 0.995600 | \n",
" 0.996750 | \n",
" 0.997835 | \n",
" 1.003690 | \n",
"
\n",
" \n",
" | pH | \n",
" 1599.000000 | \n",
" 3.311113 | \n",
" 0.154386 | \n",
" 2.740000 | \n",
" 3.210000 | \n",
" 3.310000 | \n",
" 3.400000 | \n",
" 4.010000 | \n",
"
\n",
" \n",
" | sulphates | \n",
" 1599.000000 | \n",
" 0.658149 | \n",
" 0.169507 | \n",
" 0.330000 | \n",
" 0.550000 | \n",
" 0.620000 | \n",
" 0.730000 | \n",
" 2.000000 | \n",
"
\n",
" \n",
" | alcohol | \n",
" 1599.000000 | \n",
" 10.422983 | \n",
" 1.065668 | \n",
" 8.400000 | \n",
" 9.500000 | \n",
" 10.200000 | \n",
" 11.100000 | \n",
" 14.900000 | \n",
"
\n",
" \n",
" | quality | \n",
" 1599.000000 | \n",
" 5.636023 | \n",
" 0.807569 | \n",
" 3.000000 | \n",
" 5.000000 | \n",
" 6.000000 | \n",
" 6.000000 | \n",
" 8.000000 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.info()\n",
"data.describe().T.style.background_gradient(axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4b1487c6-ae6c-4ee6-a0e1-03a156654bda",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"fixed acidity 0\n",
"volatile acidity 0\n",
"citric acid 0\n",
"residual sugar 0\n",
"chlorides 0\n",
"free sulfur dioxide 0\n",
"total sulfur dioxide 0\n",
"density 0\n",
"pH 0\n",
"sulphates 0\n",
"alcohol 0\n",
"quality 0\n",
"dtype: int64"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.isna().sum()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2d090c6d-5de2-41d5-a098-3516dcdb7155",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" fixed acidity | \n",
" volatile acidity | \n",
" citric acid | \n",
" residual sugar | \n",
" chlorides | \n",
" free sulfur dioxide | \n",
" total sulfur dioxide | \n",
" density | \n",
" pH | \n",
" sulphates | \n",
" alcohol | \n",
" quality | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 7.4 | \n",
" 0.70 | \n",
" 0.00 | \n",
" 1.9 | \n",
" 0.076 | \n",
" 11.0 | \n",
" 34.0 | \n",
" 0.9978 | \n",
" 3.51 | \n",
" 0.56 | \n",
" 9.4 | \n",
" Average | \n",
"
\n",
" \n",
" | 1 | \n",
" 7.8 | \n",
" 0.88 | \n",
" 0.00 | \n",
" 2.6 | \n",
" 0.098 | \n",
" 25.0 | \n",
" 67.0 | \n",
" 0.9968 | \n",
" 3.20 | \n",
" 0.68 | \n",
" 9.8 | \n",
" Average | \n",
"
\n",
" \n",
" | 2 | \n",
" 7.8 | \n",
" 0.76 | \n",
" 0.04 | \n",
" 2.3 | \n",
" 0.092 | \n",
" 15.0 | \n",
" 54.0 | \n",
" 0.9970 | \n",
" 3.26 | \n",
" 0.65 | \n",
" 9.8 | \n",
" Average | \n",
"
\n",
" \n",
" | 3 | \n",
" 11.2 | \n",
" 0.28 | \n",
" 0.56 | \n",
" 1.9 | \n",
" 0.075 | \n",
" 17.0 | \n",
" 60.0 | \n",
" 0.9980 | \n",
" 3.16 | \n",
" 0.58 | \n",
" 9.8 | \n",
" Average | \n",
"
\n",
" \n",
" | 4 | \n",
" 7.4 | \n",
" 0.70 | \n",
" 0.00 | \n",
" 1.9 | \n",
" 0.076 | \n",
" 11.0 | \n",
" 34.0 | \n",
" 0.9978 | \n",
" 3.51 | \n",
" 0.56 | \n",
" 9.4 | \n",
" Average | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" fixed acidity volatile acidity citric acid residual sugar chlorides \\\n",
"0 7.4 0.70 0.00 1.9 0.076 \n",
"1 7.8 0.88 0.00 2.6 0.098 \n",
"2 7.8 0.76 0.04 2.3 0.092 \n",
"3 11.2 0.28 0.56 1.9 0.075 \n",
"4 7.4 0.70 0.00 1.9 0.076 \n",
"\n",
" free sulfur dioxide total sulfur dioxide density pH sulphates \\\n",
"0 11.0 34.0 0.9978 3.51 0.56 \n",
"1 25.0 67.0 0.9968 3.20 0.68 \n",
"2 15.0 54.0 0.9970 3.26 0.65 \n",
"3 17.0 60.0 0.9980 3.16 0.58 \n",
"4 11.0 34.0 0.9978 3.51 0.56 \n",
"\n",
" alcohol quality \n",
"0 9.4 Average \n",
"1 9.8 Average \n",
"2 9.8 Average \n",
"3 9.8 Average \n",
"4 9.4 Average "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = data.replace({'quality': {8: 'Good', 7: 'Good', 6: 'Average', 5: 'Average', 4: 'Bad', 3: 'Bad'}})\n",
"\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "0d672775-f7ef-468e-ac49-a6ab5254ee3a",
"metadata": {},
"outputs": [],
"source": [
"X = data.drop(columns='quality')\n",
"y = data.quality\n",
"scaler = MinMaxScaler(feature_range=(0, 1))\n",
"X_scaled = scaler.fit_transform(X)\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "32357ab4-74d6-4a32-8ff9-249e04984341",
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "13b6ce95-34a6-48ea-a4b8-bdd69c3ca496",
"metadata": {},
"outputs": [],
"source": [
"models = {\n",
" \"Logistic Regression\": LogisticRegression(),\n",
" \"SVM\": SVC(),\n",
" \"Random Forest\": RandomForestClassifier(),\n",
" \"Decision Tree\": DecisionTreeClassifier(),\n",
" \"K-Nearest Neighbors\": KNeighborsClassifier(),\n",
" \"Naive Bayes\": GaussianNB()\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "9096f5a4-8329-47d0-9b7c-8e3123191a3b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training Logistic Regression...\n",
"Logistic Regression Accuracy: 0.828125\n",
"Training SVM...\n",
"SVM Accuracy: 0.8375\n",
"Training Random Forest...\n",
"Random Forest Accuracy: 0.865625\n",
"Training Decision Tree...\n",
"Decision Tree Accuracy: 0.79375\n",
"Training K-Nearest Neighbors...\n",
"K-Nearest Neighbors Accuracy: 0.840625\n",
"Training Naive Bayes...\n",
"Naive Bayes Accuracy: 0.803125\n"
]
}
],
"source": [
"results = {}\n",
"for model_name, model in models.items():\n",
" print(f\"Training {model_name}...\")\n",
" model.fit(X_train, y_train)\n",
" y_pred = model.predict(X_test)\n",
" accuracy = metrics.accuracy_score(y_test, y_pred)\n",
" results[model_name] = accuracy\n",
" print(f\"{model_name} Accuracy: {accuracy}\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d8f786c8-67e1-4131-bb2c-22d91d7fafab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Model Comparison (Accuracy):\n",
"Random Forest: 0.8656\n",
"K-Nearest Neighbors: 0.8406\n",
"SVM: 0.8375\n",
"Logistic Regression: 0.8281\n",
"Naive Bayes: 0.8031\n",
"Decision Tree: 0.7937\n"
]
}
],
"source": [
"sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)\n",
"print(\"\\nModel Comparison (Accuracy):\")\n",
"for model_name, accuracy in sorted_results:\n",
" print(f\"{model_name}: {accuracy:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "b6eec30b-3a13-4f99-b841-a8633a8b4ad7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 27 candidates, totalling 135 fits\n",
"Best parameters for Random Forest: {'max_depth': 30, 'min_samples_split': 2, 'n_estimators': 50}\n"
]
}
],
"source": [
"# Example: Hyperparameter tuning for Random Forest\n",
"param_grid = {\n",
" 'n_estimators': [50, 100, 200],\n",
" 'max_depth': [10, 20, 30],\n",
" 'min_samples_split': [2, 5, 10]\n",
"}\n",
"\n",
"grid_search = GridSearchCV(RandomForestClassifier(), param_grid, cv=5, n_jobs=-1, verbose=1)\n",
"grid_search.fit(X_train, y_train)\n",
"print(f\"Best parameters for Random Forest: {grid_search.best_params_}\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "488c6fe7-1f79-4e6a-9649-fd82c0acd44e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cross-validation scores for Random Forest: [0.834375 0.821875 0.840625 0.8125 0.84326019]\n",
"Mean cross-validation score: 0.830527037617555\n"
]
}
],
"source": [
"# Cross-validation score for the best model (Random Forest)\n",
"best_model = grid_search.best_estimator_\n",
"cv_scores = cross_val_score(best_model, X_scaled, y, cv=5)\n",
"print(f\"Cross-validation scores for Random Forest: {cv_scores}\")\n",
"print(f\"Mean cross-validation score: {cv_scores.mean()}\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "81b08573-19a3-4508-9c9e-4a36ced59197",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" Average 0.88 0.95 0.92 262\n",
" Bad 1.00 0.00 0.00 11\n",
" Good 0.67 0.51 0.58 47\n",
"\n",
" accuracy 0.86 320\n",
" macro avg 0.85 0.49 0.50 320\n",
"weighted avg 0.85 0.86 0.83 320\n",
"\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"\n",
"# Classification report with zero_division parameter to handle undefined precision\n",
"y_pred = best_model.predict(X_test)\n",
"print(\"\\nClassification Report:\")\n",
"print(classification_report(y_test, y_pred, zero_division=1))\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "191fa955-1969-4224-8420-a608c1463e31",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model and Scaler have been saved.\n"
]
}
],
"source": [
"import joblib\n",
"\n",
"# Save the best model (Random Forest with Hyperparameter Tuning)\n",
"joblib.dump(best_model, 'best_model.pkl')\n",
"\n",
"# Save the scaler (MinMax SScaler used for feature scaling)\n",
"joblib.dump(scaler, 'scaler.pkl')\n",
"\n",
"print(\"Model and Scaler have been saved.\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f75662b-1b38-4510-bccd-adefabbe635f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}