71 lines
2.2 KiB
Python
71 lines
2.2 KiB
Python
"""
|
|
Script to train a model from a registered tabular dataset on a remote Azure ML compute target
|
|
Based on:
|
|
https://learn.microsoft.com/en-us/azure/machine-learning/how-to-train-scikit-learn
|
|
"""
|
|
from azure.ai.ml import Input, command
|
|
from azure.ai.ml.constants import AssetTypes
|
|
|
|
from compute_aml import create_or_load_aml
|
|
from data_tabular import create_tabular_dataset, name_dataset
|
|
from environment import custom_env_name
|
|
from initialize_constants import AML_COMPUTE_NAME
|
|
from ml_client import create_or_load_ml_client
|
|
|
|
# Identifiers and paths shared by the job definition in main().
experiment_name: str = "mslearn-train-diabetes"  # used as both experiment and display name
experiment_folder: str = "./diabetes_training"  # passed to the job as its `code` directory
script_name: str = "diabetes_training.py"  # script launched by the job's `python` command
registered_model_name: str = "diabetes_model"  # forwarded as --registered_model_name
|
|
|
|
|
|
def _latest_dataset_version(ml_client, dataset_name):
    """Return the latest registered version of the data asset ``dataset_name``.

    The version is resolved explicitly because ``@latest`` does not work in
    ``azureml:<name>:<version>`` dataset paths.

    Raises:
        LookupError: if no data asset named ``dataset_name`` (with a latest
            version) is listed by ``ml_client.data.list()``.
    """
    latest_version = next(
        (
            dataset.latest_version
            for dataset in ml_client.data.list()
            if dataset.name == dataset_name
        ),
        None,  # avoid an opaque StopIteration when the asset is missing
    )
    if latest_version is None:
        raise LookupError(
            f"Data asset '{dataset_name}' was not found in the workspace; "
            "check that create_tabular_dataset() registered it."
        )
    return latest_version


def main():
    """Register the training inputs and run the diabetes training job remotely.

    Steps: create/load the ML client, create/load the compute resources,
    register the tabular dataset, then submit a command job that runs
    ``script_name`` on ``AML_COMPUTE_NAME`` and stream its output until the
    job finishes.

    Returns:
        The job re-fetched after streaming, so its ``status`` reflects the
        final state of the run.
    """
    # 1. Create or load an ML client
    ml_client = create_or_load_ml_client()

    # 2. Create (or load) the compute resources
    create_or_load_aml()

    # 3. Create and register the tabular dataset, then pin its latest version
    create_tabular_dataset()
    latest_version_dataset = _latest_dataset_version(ml_client, name_dataset)

    # 4. Define the job: script name, data input and model name are job inputs
    #    that the command line forwards to the training script
    job = command(
        inputs=dict(
            script_name=script_name,
            data=Input(
                type=AssetTypes.URI_FILE,
                # @latest doesn't work with dataset paths, so use the explicit version
                path=f"azureml:{name_dataset}:{latest_version_dataset}",
            ),
            registered_model_name=registered_model_name,
        ),
        code=experiment_folder,
        command=(
            "python ${{inputs.script_name}}"
            + " --data ${{inputs.data}}"
            + " --registered_model_name ${{inputs.registered_model_name}}"
        ),
        environment=f"{custom_env_name}@latest",
        compute=AML_COMPUTE_NAME,
        experiment_name=experiment_name,
        display_name=experiment_name,
    )

    # Submit the command job
    returned_job = ml_client.jobs.create_or_update(job)

    # Stream the output and wait until the job is finished
    ml_client.jobs.stream(returned_job.name)

    # Refresh the job after streaming and surface its final status instead of
    # discarding it
    returned_job = ml_client.jobs.get(name=returned_job.name)
    print(f"Job {returned_job.name} finished with status: {returned_job.status}")
    return returned_job
|
|
|
|
|
|
if __name__ == "__main__":
    # Submit the training job only when run as a script, never on import.
    main()
|