2024-09-04 10:15:43 +02:00
|
|
|
# Import libraries
|
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
|
|
|
|
import mlflow
|
|
|
|
import mlflow.sklearn
|
|
|
|
import numpy as np
|
|
|
|
import pandas as pd
|
|
|
|
from sklearn.ensemble import GradientBoostingClassifier
|
|
|
|
from sklearn.metrics import roc_auc_score
|
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_args():
    """Build the CLI parser and return the parsed arguments.

    Returns:
        argparse.Namespace with attributes ``data`` (path to the input CSV),
        ``registered_model_name`` (name used to register and save the model),
        ``learning_rate`` (float, default 0.1) and ``n_estimators``
        (int, default 100).
    """
    # ``argparse`` is a module; the parser class is argparse.ArgumentParser.
    parser = argparse.ArgumentParser()

    # Input dataset. Required: it is passed straight to pd.read_csv below.
    parser.add_argument(
        "--data",
        type=str,
        required=True,
        help="path to input data",
    )

    # Model name. Required: it is used both as the registry name and as the
    # local save directory (os.path.join fails on None).
    parser.add_argument(
        "--registered_model_name",
        type=str,
        required=True,
        help="model name",
    )

    # Hyperparameters
    parser.add_argument(
        "--learning_rate",
        type=float,
        dest="learning_rate",
        default=0.1,
        help="learning rate",
    )
    parser.add_argument(
        "--n_estimators",
        type=int,
        dest="n_estimators",
        default=100,
        help="number of estimators",
    )

    return parser.parse_args()


def _split_features_and_labels(diabetes):
    """Return ``(X, y)`` NumPy arrays from the diabetes DataFrame.

    X holds the eight feature columns listed below, in that order; y is the
    ``Diabetic`` column. Raises KeyError if any of these columns is missing.
    """
    feature_columns = [
        "Pregnancies",
        "PlasmaGlucose",
        "DiastolicBloodPressure",
        "TricepsThickness",
        "SerumInsulin",
        "BMI",
        "DiabetesPedigree",
        "Age",
    ]
    return diabetes[feature_columns].values, diabetes["Diabetic"].values


def main():
    """Train a gradient-boosting classifier on the diabetes CSV and log it to MLflow.

    Steps:
        1. Parse CLI arguments and echo them.
        2. Read the CSV given by ``--data`` and hold out 30% of rows
           (``random_state=0``) as a test set.
        3. Fit a GradientBoostingClassifier with ``--learning_rate`` and
           ``--n_estimators``.
        4. Log test-set Accuracy and AUC as MLflow metrics.
        5. Log/register the model under ``--registered_model_name`` and save
           a local copy to ``<registered_model_name>/trained_model``.
    """
    args = _parse_args()
    print(" ".join(f"{k}={v}" for k, v in vars(args).items()))

    # Enable autologging so the fit() call below is captured by MLflow.
    mlflow.sklearn.autolog()

    # Using the run as a context manager guarantees it is ended even if a
    # later step (reading data, training, saving) raises.
    with mlflow.start_run():
        # Load the diabetes data (passed as an input dataset).
        print("input data:", args.data)
        diabetes = pd.read_csv(args.data)

        X, y = _split_features_and_labels(diabetes)

        # Split data into training set and test set.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.30, random_state=0
        )

        # Train a Gradient Boosting classification model with the specified
        # hyperparameters.
        print("Training a classification model")
        model = GradientBoostingClassifier(
            learning_rate=args.learning_rate, n_estimators=args.n_estimators
        ).fit(X_train, y_train)

        # Accuracy = fraction of test rows whose predicted label matches.
        y_hat = model.predict(X_test)
        accuracy = np.average(y_hat == y_test)
        print("Accuracy:", accuracy)
        mlflow.log_metric("Accuracy", float(accuracy))

        # AUC from the scores in column 1 of predict_proba.
        # NOTE(review): assumes "Diabetic" is a binary label, so column 1 is
        # the positive class (the larger value in model.classes_) — confirm.
        y_scores = model.predict_proba(X_test)
        auc = roc_auc_score(y_test, y_scores[:, 1])
        print("AUC: " + str(auc))
        mlflow.log_metric("AUC", float(auc))

        # Register the model in the workspace via MLflow.
        print("Registering the model via MLFlow")
        mlflow.sklearn.log_model(
            sk_model=model,
            registered_model_name=args.registered_model_name,
            artifact_path=args.registered_model_name,
        )

        # Also save the model to a local directory.
        mlflow.sklearn.save_model(
            sk_model=model,
            path=os.path.join(args.registered_model_name, "trained_model"),
        )
|
2024-09-04 10:15:43 +02:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: train and log only when run as a script, never on import.
    main()
|