85 lines
2.8 KiB
Python
85 lines
2.8 KiB
Python
from fastapi import FastAPI, Form, Request
|
|
from fastapi.responses import HTMLResponse, JSONResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
import logging
|
|
import joblib
|
|
import numpy as np
|
|
import pandas as pd
|
|
import os
|
|
import uvicorn
|
|
|
|
|
|
# Configure logging
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
handlers=[logging.StreamHandler()]
|
|
)
|
|
|
|
logger = logging.getLogger("FastAPI Iris Predictor")
|
|
|
|
try:
|
|
model = joblib.load("model.pkl")
|
|
logger.info("Model loaded successfully.")
|
|
except Exception as e:
|
|
logger.error("Failed to load the model: %s", e)
|
|
raise RuntimeError("Model loading failed.") from e
|
|
|
|
app = FastAPI()
|
|
|
|
# Mount static folder for serving static files like CSS, HTML
|
|
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
|
|
|
species_mapping = {0: "Iris Setosa", 1: "Iris Versicolor", 2: "Iris Virginica"}
|
|
|
|
@app.get("/", response_class=HTMLResponse)
|
|
async def home():
|
|
try:
|
|
# Serve the home page
|
|
with open("static/index.html", "r") as file:
|
|
logger.info("Home page served.")
|
|
return file.read()
|
|
except Exception as e:
|
|
logger.error("Error serving home page: %s", e)
|
|
return HTMLResponse(content="Error loading the home page.", status_code=500)
|
|
|
|
@app.get("/predict", response_class=HTMLResponse)
|
|
async def predict():
|
|
# Serve the prediction page
|
|
try:
|
|
with open("static/predict.html", "r") as file:
|
|
logger.info("Prediction page served.")
|
|
return file.read()
|
|
except Exception as e:
|
|
logger.error("Error serving prediction page: %s", e)
|
|
return HTMLResponse(content="Error loading the prediction page.", status_code=500)
|
|
|
|
@app.post("/predict")
|
|
async def predict_species(
|
|
sepal_length: float = Form(...),
|
|
sepal_width: float = Form(...),
|
|
petal_length: float = Form(...),
|
|
petal_width: float = Form(...)
|
|
):
|
|
feature_names = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
|
|
try:
|
|
input_df = pd.DataFrame([[sepal_length, sepal_width, petal_length, petal_width]], columns=feature_names)
|
|
logger.info("Input data received: %s", input_df)
|
|
|
|
# Predict using the model
|
|
prediction = model.predict(input_df)[0]
|
|
|
|
# Get the species name
|
|
species = species_mapping.get(prediction, "Unknown")
|
|
logger.info("Prediction made: %s", species)
|
|
|
|
return { "prediction": species}
|
|
except Exception as e:
|
|
logger.error("Error making prediction: %s", e)
|
|
return JSONResponse(content={"error": "Failed to make a prediction"}, status_code=500)
|
|
|
|
if __name__ == "__main__":
|
|
port = int(os.getenv("PORT", 8000)) # Use Heroku's PORT
|
|
uvicorn.run(app, host="0.0.0.0", port=port)
|
|
|