-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
180 lines (157 loc) · 4.88 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""
Backend module for FastAPI application.
Author
------
Nicolas Rojas
"""
# imports
import os
from pydantic import BaseModel
import pandas as pd
from fastapi import FastAPI, HTTPException
import mlflow
import mysql.connector
def check_table_exists(table_name: str):
    """Check whether a table exists in the raw_data database; create it if missing.

    Parameters
    ----------
    table_name : str
        Name of the table to check (and create if absent).
    """
    connection = mysql.connector.connect(
        # NOTE(review): the original passed url="http://db_raw:8088", which
        # mysql.connector.connect() does not accept; host/port is the
        # supported form for the same endpoint.
        host="db_raw",
        port=8088,
        user="sqluser",
        password="supersecretaccess2024",
        database="raw_data",
    )
    try:
        cursor = connection.cursor()
        # Parameterized lookup: avoids SQL injection through table_name and
        # restricts the match to this database's schema.
        cursor.execute(
            "SELECT COUNT(*) FROM information_schema.tables "
            "WHERE table_schema = %s AND table_name = %s",
            ("raw_data", table_name),
        )
        (count,) = cursor.fetchone()
        if count == 0:
            print("----- table does not exist, creating it")
            # Identifiers cannot be parameterized; backtick-quote the name.
            # Column names must match the dataframe columns inserted by
            # store_data — the original `anual_income` typo broke inserts of
            # the `annual_income` column coming from ModelInput.
            cursor.execute(
                f"CREATE TABLE `{table_name}` ("
                "`id` BIGINT,"
                "`age` SMALLINT,"
                "`annual_income` BIGINT,"
                "`credit_score` SMALLINT,"
                "`loan_amount` BIGINT,"
                "`loan_duration_years` TINYINT,"
                "`number_of_open_accounts` SMALLINT,"
                "`had_past_default` TINYINT,"
                "`loan_approval` TINYINT"
                ")"
            )
        else:
            print("----- table already exists")
        cursor.close()
    finally:
        # Close the connection even when the query or DDL raises.
        connection.close()
def store_data(dataframe: pd.DataFrame, table_name: str):
    """Store dataframe rows in the given table of the raw_data database.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Data to insert; column names must match the table's columns.
    table_name : str
        Name of the destination table (created if it does not exist).
    """
    check_table_exists(table_name)
    connection = mysql.connector.connect(
        # NOTE(review): mysql.connector.connect() has no `url` argument (the
        # original passed url="http://db_raw:8088"); host/port is the
        # supported form.
        host="db_raw",
        port=8088,
        user="sqluser",
        password="supersecretaccess2024",
        database="raw_data",
    )
    try:
        cursor = connection.cursor()
        columns = ", ".join(f"`{name}`" for name in dataframe.columns)
        # One %s placeholder per dataframe column.
        placeholders = ", ".join(["%s"] * dataframe.shape[1])
        query = f"INSERT INTO `{table_name}` ({columns}) VALUES ({placeholders})"
        # executemany expects a sequence of row tuples.
        rows = list(dataframe.itertuples(index=False, name=None))
        cursor.executemany(query, rows)
        connection.commit()
        cursor.close()
    finally:
        # Close the connection even when the insert or commit raises.
        connection.close()
# connect to mlflow
# MinIO provides the S3-compatible artifact store behind MLflow; the AWS_*
# variables are the MinIO credentials the mlflow client reads from the
# environment. NOTE(review): credentials are hard-coded here — consider
# loading them from the environment/secrets instead.
os.environ["MLFLOW_S3_ENDPOINT_URL"] = "http://minio:8081"
os.environ["AWS_ACCESS_KEY_ID"] = "access2024minio"
os.environ["AWS_SECRET_ACCESS_KEY"] = "supersecretaccess2024"
mlflow.set_tracking_uri("http://mlflow:8083")
mlflow.set_experiment("mlflow_tracking_model")
# load model
# Resolve the "production" stage of the registered model and load it once at
# import time, so every /predict request reuses the same in-memory model.
MODEL_NAME = "clients_model"
MODEL_PRODUCTION_URI = f"models:/{MODEL_NAME}/production"
loaded_model = mlflow.pyfunc.load_model(model_uri=MODEL_PRODUCTION_URI)
# create FastAPI app
app = FastAPI()
class ModelInput(BaseModel):
    """Input model for FastAPI endpoint.

    Field names and order define the JSON contract of /predict and the
    column layout of the rows persisted by store_data.
    """
    # Client identifier; dropped before prediction but stored with the raw data.
    id: int
    age: float
    annual_income: float
    credit_score: float
    loan_amount: float
    loan_duration_years: int
    number_of_open_accounts: float
    had_past_default: int
@app.post("/predict/")
def predict(item: ModelInput):
    """Predict with loaded model over client data.

    Parameters
    ----------
    item : ModelInput
        Input data for model, received as JSON.

    Returns
    -------
    dict
        Dictionary with the integer prediction under key "prediction".

    Raises
    ------
    HTTPException
        With status 400 when the request cannot be processed.
    """
    try:
        # get data from model_input
        received_data = item.model_dump()
        # The model is trained on every feature except the client id.
        features = {key: value for key, value in received_data.items() if key != "id"}
        # Single-row DataFrame (one value per column); fill missing values
        # with 0 as the model expects.
        features_frame = pd.DataFrame([features]).fillna(0)
        # predict with model
        prediction = int(loaded_model.predict(features_frame)[0])
        # Persist the raw request together with its prediction.
        record = pd.DataFrame([received_data])
        record["loan_approval"] = prediction
        store_data(record, "predictions_data")
        # return prediction as JSON
        return {"prediction": prediction}
    except Exception as error:
        # Preserve original behavior: any failure (validation, model,
        # storage) surfaces to the caller as a 400 with the error detail.
        raise HTTPException(
            status_code=400, detail=f"Bad Request\n{error}"
        ) from error