# User Name c06dea8d7e init
# 2025-06-07 23:27:56 +02:00
#
# 85 lines
# 2.8 KiB
# Python

from fastapi import FastAPI, Form, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
import logging
import joblib
import numpy as np
import pandas as pd
import os
import uvicorn
# Configure root logging: INFO and above, written by a single StreamHandler
# (stderr by default) with a timestamp / logger name / level prefix.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format=_LOG_FORMAT,
    level=logging.INFO,
)

# Application logger shared by every endpoint in this module.
logger = logging.getLogger("FastAPI Iris Predictor")
# Load the trained classifier once, at import time, so every request reuses it.
# "model.pkl" is resolved relative to the process working directory.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only point this at a model file you produced and trust.
try:
    model = joblib.load("model.pkl")
except Exception as load_error:
    # Log the cause, then abort the import: without a model the service
    # cannot answer prediction requests.
    logger.error("Failed to load the model: %s", load_error)
    raise RuntimeError("Model loading failed.") from load_error
else:
    logger.info("Model loaded successfully.")
# Display names for the model's class labels, looked up by the POST /predict
# handler. NOTE(review): assumes the model predicts integer labels 0/1/2 in
# exactly this order — confirm against the training script.
species_mapping = dict(
    enumerate(["Iris Setosa", "Iris Versicolor", "Iris Virginica"])
)

app = FastAPI()

# Serve files from the local "static" directory (CSS, HTML, ...) under the
# /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the landing page, ``static/index.html``.

    The path is resolved relative to the process working directory.

    Returns:
        The page text. FastAPI wraps it in an ``HTMLResponse`` because of the
        ``response_class`` declared on the route. If the file cannot be
        opened or decoded, the traceback is logged and a 500
        ``HTMLResponse`` with a short error message is returned instead.
    """
    try:
        # Explicit UTF-8: without it open() uses the platform locale encoding,
        # which can mis-decode non-ASCII HTML (e.g. cp1252 on Windows).
        with open("static/index.html", "r", encoding="utf-8") as file:
            page = file.read()
    except (OSError, UnicodeDecodeError) as e:
        # logger.exception also records the traceback.
        logger.exception("Error serving home page: %s", e)
        return HTMLResponse(content="Error loading the home page.", status_code=500)
    # Log only after the read succeeded, so a failed read is not also
    # reported as "served".
    logger.info("Home page served.")
    return page
@app.get("/predict", response_class=HTMLResponse)
async def predict():
    """Serve the prediction form page, ``static/predict.html``.

    The path is resolved relative to the process working directory.

    Returns:
        The page text. FastAPI wraps it in an ``HTMLResponse`` because of the
        ``response_class`` declared on the route. If the file cannot be
        opened or decoded, the traceback is logged and a 500
        ``HTMLResponse`` with a short error message is returned instead.
    """
    try:
        # Explicit UTF-8: without it open() uses the platform locale encoding,
        # which can mis-decode non-ASCII HTML (e.g. cp1252 on Windows).
        with open("static/predict.html", "r", encoding="utf-8") as file:
            page = file.read()
    except (OSError, UnicodeDecodeError) as e:
        # logger.exception also records the traceback.
        logger.exception("Error serving prediction page: %s", e)
        return HTMLResponse(content="Error loading the prediction page.", status_code=500)
    # Log only after the read succeeded, so a failed read is not also
    # reported as "served".
    logger.info("Prediction page served.")
    return page
@app.post("/predict")
async def predict_species(
    sepal_length: float = Form(...),
    sepal_width: float = Form(...),
    petal_length: float = Form(...),
    petal_width: float = Form(...),
):
    """Predict the iris species for one flower submitted as form fields.

    The four required form fields are packed into a one-row DataFrame whose
    column names match the field names. That row is passed to the
    module-level ``model``, and the predicted label is translated through
    ``species_mapping``.

    Returns:
        ``{"prediction": <species name>}`` on success. The name is
        ``"Unknown"`` when the predicted label is not a key of
        ``species_mapping``. If building the input or predicting raises, the
        traceback is logged and a ``JSONResponse`` with status 500 and
        ``{"error": "Failed to make a prediction"}`` is returned.
    """
    feature_names = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
    try:
        # Named columns so the model receives labelled features.
        # NOTE(review): assumes the model was fitted on a DataFrame with these
        # exact column names — confirm against the training script.
        input_df = pd.DataFrame(
            [[sepal_length, sepal_width, petal_length, petal_width]],
            columns=feature_names,
        )
        logger.info("Input data received: %s", input_df)
        # predict() returns one label per input row; there is exactly one row.
        prediction = model.predict(input_df)[0]
        species = species_mapping.get(prediction, "Unknown")
    except Exception as e:
        # Request boundary: turn any failure into a 500. logger.exception
        # keeps the traceback, which the previous logger.error call dropped.
        logger.exception("Error making prediction: %s", e)
        return JSONResponse(content={"error": "Failed to make a prediction"}, status_code=500)
    logger.info("Prediction made: %s", species)
    return {"prediction": species}
if __name__ == "__main__":
    # Hosting platforms such as Heroku supply the listening port via $PORT;
    # fall back to 8000 when it is not set (e.g. local runs).
    listen_port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)