-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathapp.py
158 lines (133 loc) · 5.93 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from typing import Optional

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from starlette.responses import HTMLResponse
from uvicorn import run as app_run

# Importing constants and pipeline modules from the project
from src.constants import APP_HOST, APP_PORT
from src.pipelines.predict_pipeline import AdData, AdDataClassifier
from src.pipelines.train_pipeline import TrainPipeline
# The FastAPI application object that every route below is registered on.
app = FastAPI()

# Serve static assets (stylesheets and the like) from the local "static"
# folder under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Jinja2 renderer; template files are looked up in the local "templates" folder.
templates = Jinja2Templates(directory="templates")

# Cross-Origin Resource Sharing: accept any origin, any method and any header.
# NOTE(review): the CORS rules do not permit a "*" origin together with
# credentials, and Starlette's docs say origins should be listed explicitly
# when allow_credentials=True. Confirm this combination is intended before
# exposing the service publicly.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
class DataForm:
    """
    Collect the ad-related fields submitted through the prediction form.

    Every field attribute starts as ``None`` and is populated by
    :meth:`get_user_data`. Values are stored exactly as Starlette's form
    parser returns them: a ``str`` for text inputs (an ``UploadFile`` for file
    inputs), or ``None`` when the field is absent. No numeric conversion is
    performed in this class.
    """

    # Attribute name -> name of the HTML form field it is read from. Two form
    # fields contain a hyphen, which cannot appear in a Python identifier, so
    # the corresponding attributes use an underscore instead.
    _FORM_KEYS = {
        "age": "age",
        "gender_Male": "gender_Male",
        "gender_Non_Binary": "gender_Non-Binary",
        "device_type_Mobile": "device_type_Mobile",
        "device_type_Tablet": "device_type_Tablet",
        "ad_position_Side": "ad_position_Side",
        "ad_position_Top": "ad_position_Top",
        "browsing_history_Entertainment": "browsing_history_Entertainment",
        "browsing_history_News": "browsing_history_News",
        "browsing_history_Shopping": "browsing_history_Shopping",
        "browsing_history_Social_Media": "browsing_history_Social-Media",
        "time_of_day_Evening": "time_of_day_Evening",
        "time_of_day_Morning": "time_of_day_Morning",
        "time_of_day_Night": "time_of_day_Night",
    }

    def __init__(self, request: Request):
        """
        Args:
            request: The incoming request whose form body will be read.
        """
        self.request: Request = request
        # Annotated as str rather than int: these hold the raw submitted
        # values, which the form parser does not convert.
        self.age: Optional[str] = None
        self.gender_Male: Optional[str] = None
        self.gender_Non_Binary: Optional[str] = None
        self.device_type_Mobile: Optional[str] = None
        self.device_type_Tablet: Optional[str] = None
        self.ad_position_Side: Optional[str] = None
        self.ad_position_Top: Optional[str] = None
        self.browsing_history_Entertainment: Optional[str] = None
        self.browsing_history_News: Optional[str] = None
        self.browsing_history_Shopping: Optional[str] = None
        self.browsing_history_Social_Media: Optional[str] = None
        self.time_of_day_Evening: Optional[str] = None
        self.time_of_day_Morning: Optional[str] = None
        self.time_of_day_Night: Optional[str] = None

    async def get_user_data(self):
        """
        Read the request's form body and copy each expected field onto the
        matching attribute.

        Fields missing from the submission are set to ``None``.
        """
        form = await self.request.form()
        for attr, key in self._FORM_KEYS.items():
            setattr(self, attr, form.get(key))
# GET / : render the empty ad-data input form.
# NOTE(review): the OpenAPI tag reads "authentication" although this route
# only renders a page; confirm the tag is intended.
@app.get("/", tags=["authentication"])
async def index(request: Request):
    """
    Serve the ad-data input page.

    The template receives the request and a placeholder ``context`` value of
    ``"Rendering"``; the POST handler re-renders the same template with the
    prediction text instead.
    """
    page_context = {"request": request, "context": "Rendering"}
    return templates.TemplateResponse("addata.html", page_context)
# GET /train : run the full training pipeline on demand.
@app.get("/train")
async def trainRouteClient():
    """
    Run the model training pipeline and report the outcome as plain text.

    Returns:
        ``Response`` with body ``"Training successful!!!"`` (HTTP 200) when the
        pipeline completes, or ``"Error Occurred! <exception>"`` with HTTP 500
        when constructing or running the pipeline raises.
    """
    # NOTE(review): run_pipeline() is a synchronous call inside an async
    # handler, so the event loop (and therefore every other request) is
    # blocked for its whole duration. Consider a background worker if
    # training takes noticeable time.
    try:
        train_pipeline = TrainPipeline()
        train_pipeline.run_pipeline()
    except Exception as e:
        # Top-level boundary: report the failure with a server-error status so
        # clients can detect it from the status code, not only the body text.
        return Response(f"Error Occurred! {e}", status_code=500)
    else:
        return Response("Training successful!!!")
# POST / : read the submitted form, run the classifier, re-render the page.
@app.post("/")
async def predictRouteClient(request: Request):
    """
    Predict whether the user described by the submitted form will click the ad.

    Args:
        request: The form-submission request.

    Returns:
        The ``addata.html`` template rendered with ``context`` set to
        ``"User Will Click Ad"`` or ``"User Will Not Click Ad"``. If any step
        (form parsing, feature building, prediction, rendering) raises, a JSON
        body ``{"status": False, "error": "<message>"}`` with HTTP 500.
    """
    try:
        form = DataForm(request)
        await form.get_user_data()

        # Copy every parsed form field into the prediction pipeline's input.
        ad_data = AdData(
            age=form.age,
            gender_Male=form.gender_Male,
            gender_Non_Binary=form.gender_Non_Binary,
            device_type_Mobile=form.device_type_Mobile,
            device_type_Tablet=form.device_type_Tablet,
            ad_position_Side=form.ad_position_Side,
            ad_position_Top=form.ad_position_Top,
            browsing_history_Entertainment=form.browsing_history_Entertainment,
            browsing_history_News=form.browsing_history_News,
            browsing_history_Shopping=form.browsing_history_Shopping,
            browsing_history_Social_Media=form.browsing_history_Social_Media,
            time_of_day_Evening=form.time_of_day_Evening,
            time_of_day_Morning=form.time_of_day_Morning,
            time_of_day_Night=form.time_of_day_Night,
        )

        # Convert the collected values into the DataFrame the classifier consumes.
        ad_df = ad_data.get_ad_input_data_frame()

        # NOTE(review): the classifier is constructed on every request; if its
        # constructor loads a model from storage, consider creating it once.
        model_predictor = AdDataClassifier()
        value = model_predictor.predict(dataframe=ad_df)[0]

        # A predicted label of 1 means "will click"; anything else "will not".
        status = "User Will Click Ad" if value == 1 else "User Will Not Click Ad"

        # Re-render the input page with the prediction text.
        return templates.TemplateResponse(
            "addata.html",
            {"request": request, "context": status},
        )
    except Exception as e:
        # Top-level boundary: keep the original JSON error body but signal the
        # failure with HTTP 500 instead of FastAPI's default 200 for a dict.
        return JSONResponse(
            content={"status": False, "error": f"{e}"},
            status_code=500,
        )
# Start the API with uvicorn only when this file is executed directly, not
# when it is imported as a module.
if __name__ == "__main__":
    app_run(app, port=APP_PORT, host=APP_HOST)