diff --git a/assets/training/finetune_acft_image/components/finetune/medimage_adapter/asset.yaml b/assets/training/finetune_acft_image/components/finetune/medimage_adapter/asset.yaml
new file mode 100644
index 0000000000..1c98f3aee5
--- /dev/null
+++ b/assets/training/finetune_acft_image/components/finetune/medimage_adapter/asset.yaml
@@ -0,0 +1,3 @@
+type: component
+spec: spec.yaml
+categories: ["Foundational Models", "Finetune"]
diff --git a/assets/training/finetune_acft_image/components/finetune/medimage_adapter/spec.yaml b/assets/training/finetune_acft_image/components/finetune/medimage_adapter/spec.yaml
new file mode 100644
index 0000000000..91bdd983d6
--- /dev/null
+++ b/assets/training/finetune_acft_image/components/finetune/medimage_adapter/spec.yaml
@@ -0,0 +1,114 @@
+$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
+name: medimage_adapter_finetune
+version: 0.0.3
+type: command
+
+is_deterministic: True
+
+display_name: Medical Image Adapter Finetune
+description: Component to finetune a classification adapter on medical image embeddings.
+
+environment: azureml:/subscriptions/dbd697c3-ef40-488f-83e6-5ad4dfb78f9b/resourceGroups/rdondera/providers/Microsoft.MachineLearningServices/workspaces/validatr/environments/medimage-adapter-pt-ft/versions/13
+
+code: ../../../src/medimage_insight_adapter_finetune
+
+distribution:
+  type: pytorch
+
+inputs:
+  train_data_path:
+    type: uri_folder
+    optional: false
+    description: Folder containing the training embeddings PKL file.
+    mode: ro_mount
+
+  validation_data_path:
+    type: uri_folder
+    optional: false
+    description: Folder containing the validation embeddings PKL file.
+    mode: ro_mount
+
+  train_dataloader_batch_size:
+    type: integer
+    min: 1
+    default: 8
+    optional: true
+    description: Batch size for the training dataloader.
+
+  validation_dataloader_batch_size:
+    type: integer
+    min: 1
+    default: 1
+    optional: true
+    description: Batch size for the validation dataloader.
+
+  train_dataloader_workers:
+    type: integer
+    min: 0
+    default: 2
+    optional: true
+    description: Number of workers for the training dataloader.
+
+  validation_dataloader_workers:
+    type: integer
+    min: 0
+    default: 2
+    optional: true
+    description: Number of workers for the validation dataloader.
+
+  output_classes:
+    type: integer
+    min: 1
+    default: 5
+    optional: true
+    description: Number of output classes.
+
+  hidden_dimensions:
+    type: integer
+    min: 1
+    default: 512
+    optional: true
+    description: Number of hidden dimensions.
+
+  input_channels:
+    type: integer
+    min: 1
+    default: 1024
+    optional: true
+    description: Number of input channels.
+
+  learning_rate:
+    type: number
+    default: 0.0003
+    optional: true
+    description: Learning rate for training.
+
+  max_epochs:
+    type: integer
+    min: 1
+    default: 10
+    optional: true
+    description: Maximum number of epochs for training.
+
+outputs:
+  output_model_path:
+    type: uri_folder
+    description: Path to save the output model.
+    mode: rw_mount
+
+command: >-
+  python medimage_train.py
+  --task_name "AdapterTrain"
+  --train_data_path "${{inputs.train_data_path}}"
+  --validation_data_path "${{inputs.validation_data_path}}"
+  $[[--train_dataloader_batch_size "${{inputs.train_dataloader_batch_size}}"]]
+  $[[--validation_dataloader_batch_size "${{inputs.validation_dataloader_batch_size}}"]]
+  $[[--train_dataloader_workers "${{inputs.train_dataloader_workers}}"]]
+  $[[--validation_dataloader_workers "${{inputs.validation_dataloader_workers}}"]]
+  $[[--output_classes "${{inputs.output_classes}}"]]
+  $[[--hidden_dimensions "${{inputs.hidden_dimensions}}"]]
+  $[[--input_channels "${{inputs.input_channels}}"]]
+  $[[--learning_rate "${{inputs.learning_rate}}"]]
+  $[[--max_epochs "${{inputs.max_epochs}}"]]
+  --output_model_path "${{outputs.output_model_path}}"
+
diff --git a/assets/training/finetune_acft_image/components/pipeline_components/medimage_insight_ft/asset.yaml b/assets/training/finetune_acft_image/components/pipeline_components/medimage_insight_ft/asset.yaml
new file mode 100644
index 0000000000..2d8741e825
--- /dev/null
+++ b/assets/training/finetune_acft_image/components/pipeline_components/medimage_insight_ft/asset.yaml
@@ -0,0 +1,3 @@
+type: component
+spec: spec.yaml
+categories: ["Foundational Models", "Finetune"]
diff --git a/assets/training/finetune_acft_image/components/pipeline_components/medimage_insight_ft/spec.yaml b/assets/training/finetune_acft_image/components/pipeline_components/medimage_insight_ft/spec.yaml
new file mode 100644
index 0000000000..da19e53035
--- /dev/null
+++ b/assets/training/finetune_acft_image/components/pipeline_components/medimage_insight_ft/spec.yaml
@@ -0,0 +1,150 @@
+$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
+name: medimage_insight_ft_pipeline
+version: 0.0.1
+type: pipeline
+display_name: Medical Image Insight Classification Adapter Pipeline
+description: Pipeline component to generate MedImageInsight embeddings for medical images and finetune a classification adapter on them.
+
+inputs:
+  instance_type_preprocess:
+    type: string
+    optional: true
+    default: Standard_d12_v2
+    description: Instance type to be used for the preprocess component in case of serverless compute, e.g. standard_d12_v2.
+      The parameter compute_preprocess must be set to 'serverless' for instance_type_preprocess to be used.
+  instance_type_finetune:
+    type: string
+    optional: true
+    default: Standard_nc24rs_v3
+    description: Instance type to be used for the finetune component in case of serverless compute, e.g. standard_nc24rs_v3.
+      The parameter compute_finetune must be set to 'serverless' for instance_type_finetune to be used.
+  compute_preprocess:
+    type: string
+    optional: true
+    default: serverless
+    description: Compute to be used for preprocess, e.g. provide 'FT-Cluster' if your
+      compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
+      If a compute cluster name is provided, the instance_type field will be ignored and the respective cluster will be used.
+  compute_finetune:
+    type: string
+    optional: true
+    default: serverless
+    description: Compute to be used for finetune, e.g. provide 'FT-Cluster' if your
+      compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
+      If a compute cluster name is provided, the instance_type field will be ignored and the respective cluster will be used.
+  mlflow_model_path:
+    type: uri_folder
+    optional: false
+    description: Path to the MLflow model to be used.
+
+  zeroshot_path:
+    type: uri_folder
+    optional: false
+    description: Path to the folder containing the zeroshot image data.
+    mode: rw_mount
+
+  test_train_split_csv_path:
+    type: uri_folder
+    optional: false
+    description: Path to the folder containing the test-train split CSV files.
+    mode: rw_mount
+
+  train_dataloader_batch_size:
+    type: integer
+    min: 1
+    default: 8
+    optional: true
+    description: Batch size for the training dataloader.
+
+  validation_dataloader_batch_size:
+    type: integer
+    min: 1
+    default: 1
+    optional: true
+    description: Batch size for the validation dataloader.
+
+  train_dataloader_workers:
+    type: integer
+    min: 0
+    default: 2
+    optional: true
+    description: Number of workers for the training dataloader.
+
+  validation_dataloader_workers:
+    type: integer
+    min: 0
+    default: 2
+    optional: true
+    description: Number of workers for the validation dataloader.
+
+  output_classes:
+    type: integer
+    min: 1
+    default: 5
+    optional: true
+    description: Number of output classes.
+
+  hidden_dimensions:
+    type: integer
+    min: 1
+    default: 512
+    optional: true
+    description: Number of hidden dimensions.
+
+  input_channels:
+    type: integer
+    min: 1
+    default: 1024
+    optional: true
+    description: Number of input channels.
+
+  learning_rate:
+    type: number
+    default: 0.0003
+    optional: true
+    description: Learning rate for training.
+
+  max_epochs:
+    type: integer
+    min: 1
+    default: 10
+    optional: true
+    description: Maximum number of epochs for training.
+
+outputs:
+  output_model_path:
+    type: uri_folder
+    description: Path to save the output model.
+    mode: rw_mount
+
+jobs:
+  medical_image_embedding_datapreprocessing:
+    type: command
+    component: azureml:medical_image_embedding_datapreprocessing:0.0.1
+    compute: '${{parent.inputs.compute_preprocess}}'
+    resources:
+      instance_type: '${{parent.inputs.instance_type_preprocess}}'
+    inputs:
+      mlflow_model_path: '${{parent.inputs.mlflow_model_path}}'
+      zeroshot_path: '${{parent.inputs.zeroshot_path}}'
+      test_train_split_csv_path: '${{parent.inputs.test_train_split_csv_path}}'
+  medimage_adapter_finetune:
+    type: command
+    component: azureml:medimage_adapter_finetune:0.0.3
+    compute: '${{parent.inputs.compute_finetune}}'
+    resources:
+      instance_type: '${{parent.inputs.instance_type_finetune}}'
+    inputs:
+      train_data_path: '${{parent.jobs.medical_image_embedding_datapreprocessing.outputs.output_train_pkl}}'
+      validation_data_path: '${{parent.jobs.medical_image_embedding_datapreprocessing.outputs.output_validation_pkl}}'
+      train_dataloader_batch_size: '${{parent.inputs.train_dataloader_batch_size}}'
+      validation_dataloader_batch_size: '${{parent.inputs.validation_dataloader_batch_size}}'
+      train_dataloader_workers: '${{parent.inputs.train_dataloader_workers}}'
+      validation_dataloader_workers: '${{parent.inputs.validation_dataloader_workers}}'
+      output_classes: '${{parent.inputs.output_classes}}'
+      hidden_dimensions: '${{parent.inputs.hidden_dimensions}}'
+      input_channels: '${{parent.inputs.input_channels}}'
+      learning_rate: '${{parent.inputs.learning_rate}}'
+      max_epochs: '${{parent.inputs.max_epochs}}'
+    outputs:
+      output_model_path: '${{parent.outputs.output_model_path}}'
\ No newline at end of file
diff --git a/assets/training/finetune_acft_image/components/preprocess/image_embedding/asset.yaml b/assets/training/finetune_acft_image/components/preprocess/image_embedding/asset.yaml
new file mode 100644
index 0000000000..49b5bc2c41
--- /dev/null
+++ b/assets/training/finetune_acft_image/components/preprocess/image_embedding/asset.yaml
@@ -0,0 +1,3 @@
+type: component
+spec: spec.yaml
+categories: ["Foundational Models", "Embedding"]
diff --git a/assets/training/finetune_acft_image/components/preprocess/image_embedding/spec.yaml b/assets/training/finetune_acft_image/components/preprocess/image_embedding/spec.yaml
new file mode 100644
index 0000000000..6b4e502caa
--- /dev/null
+++ b/assets/training/finetune_acft_image/components/preprocess/image_embedding/spec.yaml
@@ -0,0 +1,46 @@
+$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
+name: medical_image_embedding_datapreprocessing
+version: 0.0.1
+type: command
+
+is_deterministic: True
+
+display_name: Embedding Generation for Medical Images
+description: Generates embeddings for medical images. See [docs](https://aka.ms/azureml/components/medical_image_embedding_datapreprocessing) to learn more.
+
+environment: azureml:/subscriptions/dbd697c3-ef40-488f-83e6-5ad4dfb78f9b/resourceGroups/rdondera/providers/Microsoft.MachineLearningServices/workspaces/validatr/environments/medimage-embedding-generation/versions/5
+
+code: ../../../src/medimage_insight_adapter_preprocess
+
+inputs:
+  zeroshot_path:
+    type: uri_folder
+    optional: false
+    description: Path to the folder containing the zeroshot image data.
+    mode: rw_mount
+
+  test_train_split_csv_path:
+    type: uri_folder
+    optional: false
+    description: Path to the folder containing the test-train split CSV files.
+    mode: rw_mount
+
+  mlflow_model_path:
+    type: uri_folder
+    optional: false
+    description: Path to the MLflow model to be imported.
+    mode: ro_mount
+
+outputs:
+  output_train_pkl:
+    type: uri_folder
+    description: Folder where the training embeddings PKL file is written.
+    mode: rw_mount
+
+  output_validation_pkl:
+    type: uri_folder
+    description: Folder where the validation embeddings PKL file is written.
+    mode: rw_mount
+
+command: >-
+  python medimage_datapreprocess.py
+  --task_name "MedEmbedding"
+  --zeroshot_path "${{inputs.zeroshot_path}}"
+  --test_train_split_csv_path "${{inputs.test_train_split_csv_path}}"
+  --output_train_pkl "${{outputs.output_train_pkl}}"
+  --output_validation_pkl "${{outputs.output_validation_pkl}}"
+  --mlflow_model_path "${{inputs.mlflow_model_path}}"
diff --git a/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/asset.yaml b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/asset.yaml
new file mode 100644
index 0000000000..1bb5c7d5dc
--- /dev/null
+++ b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/asset.yaml
@@ -0,0 +1,6 @@
+name: acft-medimageinsight-embedding
+version: auto
+type: environment
+spec: spec.yaml
+extra_config: environment.yaml
+categories: ["PyTorch", "Training"]
diff --git a/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/context/Dockerfile b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/context/Dockerfile
new file mode 100644
index 0000000000..83e3983deb
--- /dev/null
+++ b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/context/Dockerfile
@@ -0,0 +1,12 @@
+# PTCA image
+FROM mcr.microsoft.com/aifx/acpt/stable-ubuntu2004-cu118-py310-torch222:biweekly.202409.3
+
+USER root
+RUN apt-get -y update
+
+# Install unzip
+RUN apt-get -y install unzip
+
+# Install required packages from pypi
+COPY requirements.txt .
+RUN pip install -r requirements.txt --no-cache-dir
\ No newline at end of file
diff --git a/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/context/requirements.txt b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/context/requirements.txt
new file mode 100644
index 0000000000..385476d070
--- /dev/null
+++ b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/context/requirements.txt
@@ -0,0 +1,25 @@
+azureml-acft-common-components==0.0.65
+azureml-acft-contrib-hf-nlp==0.0.65
+https://automlcesdkdataresources.blob.core.windows.net/wheels/classification_demo-0.1.2-py3-none-any.whl
+mlflow==2.14.3
+cloudpickle==2.2.1
+colorama==0.4.6
+einops==0.8.0
+ftfy==6.3.1
+fvcore==0.1.5.post20221221
+jinja2==3.1.5
+mup==1.0.0
+numpy==1.23.5
+packaging==24.2
+pandas==2.2.3
+psutil==6.1.1
+pyyaml==6.0.2
+requests==2.32.3
+scikit-learn==1.5.2
+scipy==1.13.1
+sentencepiece==0.2.0
+tenacity==9.0.0
+timm==1.0.13
+tornado==6.4.2
+SimpleITK~=2.4.0
+transformers==4.16.2
\ No newline at end of file
diff --git a/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/environment.yaml b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/environment.yaml
new file mode 100644
index 0000000000..4c37ddb078
--- /dev/null
+++ b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/environment.yaml
@@ -0,0 +1,12 @@
+image:
+  name: azureml/curated/acft-medimageinsight-embedding
+  os: linux
+  context:
+    dir: context
+    dockerfile: Dockerfile
+    template_files:
+      - Dockerfile
+      - requirements.txt
+  publish:
+    location: mcr
+    visibility: public
diff --git a/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/spec.yaml b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/spec.yaml
new file mode 100644
index 0000000000..027ea18a08
--- /dev/null
+++ b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_embedding/spec.yaml
@@ -0,0 +1,15 @@
+$schema: https://azuremlschemas.azureedge.net/latest/environment.schema.json
+
+description: Environment used by MedImageInsight Embedding generation component
+
+name: "{{asset.name}}"
+version: "{{asset.version}}"
+
+build:
+  path: "{{image.context.path}}"
+  dockerfile_path: "{{image.dockerfile.path}}"
+
+os_type: linux
+
+tags:
+  Preview: ""
diff --git a/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/asset.yaml b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/asset.yaml
new file mode 100644
index 0000000000..dd73485e56
--- /dev/null
+++ b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/asset.yaml
@@ -0,0 +1,6 @@
+name: acft-medimageinsight-image-gpu
+version: auto
+type: environment
+spec: spec.yaml
+extra_config: environment.yaml
+categories: ["PyTorch", "Training"]
diff --git a/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/context/Dockerfile b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/context/Dockerfile
new file mode 100644
index 0000000000..bb81be9c42
--- /dev/null
+++ b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/context/Dockerfile
@@ -0,0 +1,12 @@
+# PTCA image
+FROM mcr.microsoft.com/aifx/acpt/stable-ubuntu2204-cu118-py310-torch222:latest
+
+USER root
+RUN apt-get -y update
+
+# Install unzip
+RUN apt-get -y install unzip
+
+# Install required packages from pypi
+COPY requirements.txt .
+RUN pip install -r requirements.txt --no-cache-dir
diff --git a/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/context/requirements.txt b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/context/requirements.txt
new file mode 100644
index 0000000000..af98b251e3
--- /dev/null
+++ b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/context/requirements.txt
@@ -0,0 +1,29 @@
+azureml-acft-common-components==0.0.61
+azureml-acft-contrib-hf-nlp==0.0.61
+opencv-python
+pyyaml
+json_tricks
+yacs
+scikit-learn
+pandas
+timm==0.4.12
+numpy==1.22.2
+einops
+fvcore
+transformers==4.16.2
+sentencepiece
+ftfy
+regex
+vision-datasets==0.2.7
+tenacity
+requests
+azure-storage-blob==12.11.0
+deepspeed==0.6.3
+GPUtil
+mup==1.0.0
+mlflow
+azureml-mlflow
+--extra-index-url https://download.pytorch.org/whl/cu118
+torchvision==0.12.0
+torch==1.11.0
+https://automlcesdkdataresources.blob.core.windows.net/wheels/MainzVision-0.3.0-py3-none-any.whl
+https://automlcesdkdataresources.blob.core.windows.net/wheels/MainzTrain-0.2.0-py3-none-any.whl
diff --git a/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/environment.yaml b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/environment.yaml
new file mode 100644
index 0000000000..10185e7ce0
--- /dev/null
+++ b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/environment.yaml
@@ -0,0 +1,12 @@
+image:
+  name: azureml/curated/acft-medimageinsight-image-gpu
+  os: linux
+  context:
+    dir: context
+    dockerfile: Dockerfile
+    template_files:
+      - Dockerfile
+      - requirements.txt
+  publish:
+    location: mcr
+    visibility: public
diff --git a/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/spec.yaml b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/spec.yaml
new file mode 100644
index 0000000000..a3fdb3f882
--- /dev/null
+++ b/assets/training/finetune_acft_image/environments/acft_image_medimageinsight_finetune/spec.yaml
@@ -0,0 +1,15 @@
+$schema: https://azuremlschemas.azureedge.net/latest/environment.schema.json
+
+description: Environment used by MedImageInsight Finetune components
+
+name: "{{asset.name}}"
+version: "{{asset.version}}"
+
+build:
+  path: "{{image.context.path}}"
+  dockerfile_path: "{{image.dockerfile.path}}"
+
+os_type: linux
+
+tags:
+  Preview: ""
diff --git a/assets/training/finetune_acft_image/src/medimage_insight_adapter_finetune/medimage_train.py b/assets/training/finetune_acft_image/src/medimage_insight_adapter_finetune/medimage_train.py
new file mode 100644
index 0000000000..92439280d2
--- /dev/null
+++ b/assets/training/finetune_acft_image/src/medimage_insight_adapter_finetune/medimage_train.py
@@ -0,0 +1,269 @@
+"""Finetune a classification adapter on MedImageInsight image embeddings."""
+import argparse
+from azureml.acft.common_components import get_logger_app, set_logging_parameters, LoggingLiterals
+from azureml.acft.contrib.hf import VERSION, PROJECT_NAME
+from azureml.acft.contrib.hf.nlp.constants.constants import LOGS_TO_BE_FILTERED_IN_APPINSIGHTS
+import pandas as pd
+import torch
+import os
+from classification_demo.adaptor_training import training
+import SimpleITK as sitk
+
+# Suppress SimpleITK warnings
+sitk.ProcessObject_SetGlobalWarningDisplay(False)
+
+
+COMPONENT_NAME = "ACFT-MedImage-Classification-Training"
+logger = get_logger_app("azureml.acft.contrib.hf.scripts.src.train.classification_adaptor_train")
+TRAIN_EMBEDDING_FILE_NAME = "train_embeddings.pkl"
+VALIDATION_EMBEDDING_FILE_NAME = "validation_embeddings.pkl"
+
+
+def get_parser():
+    """
+    Add arguments and return the parser. Here we add all the arguments for all the tasks.
+
+    Those arguments that are not relevant for the input task should be ignored.
+    """
+    parser = argparse.ArgumentParser(description='Process medical images and get embeddings', allow_abbrev=False)
+
+    parser.add_argument(
+        "--task_name",
+        type=str,
+        required=True,
+        help="The name of the task to be executed",
+    )
+    parser.add_argument(
+        '--train_data_path',
+        type=str,
+        required=True,
+        help='The path to the training data.'
+    )
+    parser.add_argument(
+        '--validation_data_path',
+        type=str,
+        required=True,
+        help='The path to the validation data.'
+    )
+    parser.add_argument(
+        '--train_dataloader_batch_size',
+        type=int,
+        required=True,
+        help='Batch size for the training dataloader.'
+    )
+    parser.add_argument(
+        '--validation_dataloader_batch_size',
+        type=int,
+        required=True,
+        help='Batch size for the validation dataloader.'
+    )
+    parser.add_argument(
+        '--train_dataloader_workers',
+        type=int,
+        required=True,
+        help='Number of workers for the training dataloader.'
+    )
+    parser.add_argument(
+        '--validation_dataloader_workers',
+        type=int,
+        required=True,
+        help='Number of workers for the validation dataloader.'
+    )
+    parser.add_argument(
+        '--output_classes',
+        type=int,
+        required=True,
+        help='Number of output classes.'
+    )
+    parser.add_argument(
+        '--hidden_dimensions',
+        type=int,
+        required=True,
+        help='Number of hidden dimensions.'
+    )
+    parser.add_argument(
+        '--input_channels',
+        type=int,
+        required=True,
+        help='Number of input channels.'
+    )
+    parser.add_argument(
+        '--learning_rate',
+        type=float,
+        required=True,
+        help='Learning rate for the model.'
+    )
+    parser.add_argument(
+        '--max_epochs',
+        type=int,
+        required=True,
+        help='Maximum number of epochs for training.'
+    )
+    parser.add_argument(
+        '--output_model_path',
+        type=str,
+        required=True,
+        help='Path to save the output model.'
+    )
+
+    return parser
+
+
+def load_data(train_data_path: str, validation_data_path: str, train_file_name: str,
+              validation_file_name: str) -> tuple[pd.DataFrame, pd.DataFrame]:
+    """
+    Load the training and validation data from the provided folder paths.
+
+    Args:
+        train_data_path (str): The path to the folder containing the training data file.
+        validation_data_path (str): The path to the folder containing the validation data file.
+        train_file_name (str): The name of the training data file.
+        validation_file_name (str): The name of the validation data file.
+
+    Returns:
+        tuple[pd.DataFrame, pd.DataFrame]: DataFrames containing the training and validation data.
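+
+    Example (illustrative paths; the file names match the preprocess component outputs):
+        train_df, val_df = load_data("/mnt/out/train", "/mnt/out/validation",
+                                     "train_embeddings.pkl", "validation_embeddings.pkl")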
+ """ + + train_data_file = os.path.join(train_data_path, train_file_name) + validation_data_file = os.path.join(validation_data_path, validation_file_name) + train_data = pd.read_pickle(train_data_file) + validation_data = pd.read_pickle(validation_data_file) + return train_data, validation_data + + +def initialize_model(args: argparse.Namespace) -> torch.nn.Module: + """ + Initialize the model with the provided arguments. + + Args: + args (argparse.Namespace): Parsed command line arguments. + + Returns: + torch.nn.Module: Initialized model. + """ + return training.create_model( + in_channels=args.input_channels, + hidden_dim=args.hidden_dimensions, + num_class=args.output_classes + ) + + +def prepare_dataloaders( + train_data: pd.DataFrame, + validation_data: pd.DataFrame, + args: argparse.Namespace +) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: + """ + Prepare the dataloaders for training and validation datasets. + + Args: + train_data (pd.DataFrame): DataFrame containing the training data. + validation_data (pd.DataFrame): DataFrame containing the validation data. + args (argparse.Namespace): Parsed command line arguments. + + Returns: + tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: + Dataloaders for the training and validation datasets. + """ + train_samples = { + "features": train_data["features"].tolist(), + "img_name": train_data["Name"].tolist(), + "labels": train_data["Label"].tolist(), + } + val_samples = { + "features": validation_data["features"].tolist(), + "img_name": validation_data["Name"].tolist(), + "labels": validation_data["Label"].tolist(), + } + train_dataloader = training.create_data_loader( + train_samples, + csv=train_data, + mode="train", + batch_size=args.train_dataloader_batch_size, + num_workers=args.train_dataloader_workers, + pin_memory=True + ) + validation_dataloader = training.create_data_loader( + val_samples, + csv=validation_data, + mode="val", + batch_size=args.validation_dataloader_batch_size, + num_workers=args.validation_dataloader_workers, + pin_memory=True + ) + return train_dataloader, validation_dataloader + + +def train_model( + train_dataloader: torch.utils.data.DataLoader, + validation_dataloader: torch.utils.data.DataLoader, + model: torch.nn.Module, + args: argparse.Namespace +) -> tuple[float, float]: + """ + Train the model using the provided dataloaders and arguments. + + Args: + train_dataloader (torch.utils.data.DataLoader): Dataloader for the training data. + validation_dataloader (torch.utils.data.DataLoader): Dataloader for the validation data. + model (torch.nn.Module): The model to be trained. + args (argparse.Namespace): Parsed command line arguments. + + Returns: + tuple[float, float]: Best accuracy and best AUC achieved during training. 
+ """ + learning_rate = args.learning_rate + loss_function_ts = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) + output_dir = args.output_model_path + training.create_output_directory(output_dir) + + best_accuracy, best_auc = training.trainer( + train_dataloader, + validation_dataloader, + model, + loss_function_ts, + optimizer, + epochs=int(args.max_epochs), + root_dir=output_dir, + ) + return best_accuracy, best_auc + + +def main(): + parser = get_parser() + args, _ = parser.parse_known_args() + logger.info("Parsed arguments: %s", args) + + set_logging_parameters( + task_type=args.task_name, + acft_custom_dimensions={ + LoggingLiterals.PROJECT_NAME: PROJECT_NAME, + LoggingLiterals.PROJECT_VERSION_NUMBER: VERSION, + LoggingLiterals.COMPONENT_NAME: COMPONENT_NAME + }, + azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS, + ) + train_data, validation_data = load_data(args.train_data_path, args.validation_data_path, + TRAIN_EMBEDDING_FILE_NAME, VALIDATION_EMBEDDING_FILE_NAME) + model = initialize_model(args) + train_dataloader, validation_dataloader = prepare_dataloaders(train_data, validation_data, args) + best_accuracy, best_auc = train_model(train_dataloader, validation_dataloader, model, args) + print(f"Best Accuracy of the Adaptor: {best_accuracy:.4f}") + print(f"Best AUC of the Adaptor: {best_auc:.4f}") + + +if __name__ == "__main__": + main() + +# Example command to run this script: +# python medimage_train.py --task_name "AdapterTrain" --train_data_path "/home/healthcare-ai/train_merged.pkl" --validation_data_path "/home/healthcare-ai/val_merged.pkl" --train_dataloader_batch_size 8 --validation_dataloader_batch_size 1 --train_dataloader_workers 2 --validation_dataloader_workers 2 --output_classes 5 --hidden_dimensions 512 --input_channels 1024 --learning_rate 0.0003 --max_epochs 10 --output_model_path "/home/healthcare-ai/" diff --git a/assets/training/finetune_acft_image/src/medimage_insight_adapter_preprocess/medimage_datapreprocess.py b/assets/training/finetune_acft_image/src/medimage_insight_adapter_preprocess/medimage_datapreprocess.py new file mode 100644 index 0000000000..285335858a --- /dev/null +++ b/assets/training/finetune_acft_image/src/medimage_insight_adapter_preprocess/medimage_datapreprocess.py @@ -0,0 +1,280 @@ +import argparse +from azureml.acft.common_components import get_logger_app, set_logging_parameters, LoggingLiterals +from azureml.acft.common_components.utils.error_handling.exceptions import ACFTValidationException +from azureml.acft.common_components.utils.error_handling.error_definitions import ACFTUserError +from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import ( + swallow_all_exceptions, +) +from azureml._common._error_definition.azureml_error import AzureMLError + +from azureml.acft.contrib.hf import VERSION, PROJECT_NAME +from azureml.acft.contrib.hf.nlp.constants.constants import LOGS_TO_BE_FILTERED_IN_APPINSIGHTS +import mlflow +import pandas as pd +import torch +import os +from classification_demo.MedImageInsight import medimageinsight_package +from classification_demo.adaptor_training import training +import matplotlib.pyplot as plt +import SimpleITK as sitk +from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score + +# Suppress SimpleITK warnings +sitk.ProcessObject_SetGlobalWarningDisplay(False) + + +COMPONENT_NAME = "ACFT-MedImage-Embedding-Generator" +TRAIN_EMBEDDING_FILE_NAME = "train_embeddings.pkl" 
+VALIDATION_EMBEDDING_FILE_NAME = "validation_embeddings.pkl"
+
+
+logger = get_logger_app("azureml.acft.contrib.hf.scripts.src.process_embedding.embeddings_generator")
+
+
+def get_parser():
+    """
+    Add arguments and return the parser. Here we add all the arguments for all the tasks.
+
+    Those arguments that are not relevant for the input task should be ignored.
+    """
+    parser = argparse.ArgumentParser(description='Process medical images and get embeddings', allow_abbrev=False)
+
+    parser.add_argument(
+        '--zeroshot_path',
+        type=str,
+        required=True,
+        help='The path to the zeroshot dataset'
+    )
+    parser.add_argument(
+        '--mlflow_model_path',
+        type=str,
+        required=True,
+        help='The path to the MLflow model'
+    )
+    parser.add_argument(
+        '--test_train_split_csv_path',
+        type=str,
+        required=True,
+        help='The path to the folder containing the test/train split CSV files'
+    )
+    parser.add_argument(
+        "--task_name",
+        type=str,
+        required=True,
+        help="The name of the task to be executed",
+    )
+    parser.add_argument(
+        "--output_train_pkl",
+        type=str,
+        help="Output train PKL file path",
+    )
+    parser.add_argument(
+        "--output_validation_pkl",
+        type=str,
+        help="Output validation PKL file path",
+    )
+
+    return parser
+
+
+def load_csv_files(test_train_split_csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
+    """
+    Load training and validation CSV files into DataFrames.
+
+    This function loads the training and validation CSV files from the specified path
+    and returns them as pandas DataFrames.
+
+    Args:
+        test_train_split_csv_path (str): The path to the directory containing the test/train split CSV files.
+
+    Returns:
+        tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training and validation DataFrames.
+    """
+    train_csv_path = f"{test_train_split_csv_path}/adaptor_tutorial_train_split.csv"
+    val_csv_path = f"{test_train_split_csv_path}/adaptor_tutorial_test_split.csv"
+    logger.info("Train CSV path: %s", train_csv_path)
+    logger.info("Validation CSV path: %s", val_csv_path)
+
+    train_df = pd.read_csv(train_csv_path)
+    val_df = pd.read_csv(val_csv_path)
+    logger.info("Loaded training and validation CSV files into DataFrames")
+    return train_df, val_df
+
+
+def create_features_dataframe(image_embedding_dict: dict) -> pd.DataFrame:
+    """
+    Create a DataFrame from image embeddings.
+
+    This function creates a DataFrame from the provided image embedding dictionary.
+    The DataFrame contains two columns: "Name" and "features".
+
+    Args:
+        image_embedding_dict (dict): A dictionary containing image embeddings.
+
+    Returns:
+        pd.DataFrame: A DataFrame containing the image features.
+    """
+    df_features = pd.DataFrame(
+        {
+            "Name": list(image_embedding_dict.keys()),
+            "features": [v["image_feature"] for v in image_embedding_dict.values()],
+        }
+    )
+    logger.info("Created DataFrame for image features")
+    return df_features
+
+
+def generate_image_embeddings(medimageinsight: medimageinsight_package, zeroshot_path: str) -> dict:
+    """
+    Generate image embeddings using the MedImageInsight package.
+
+    This function generates image embeddings for the provided zeroshot path using the MedImageInsight package.
+
+    Args:
+        medimageinsight (medimageinsight_package): An instance of the MedImageInsight package.
+        zeroshot_path (str): The path to the zeroshot data.
+
+    Returns:
+        dict: A dictionary containing the image embeddings.
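+
+    Example of the expected return shape (assumed from how the dictionary is
+    consumed by create_features_dataframe):
+        {"image_001.png": {"image_feature": [0.12, -0.03, ...]}, ...}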
+ """ + image_embedding_dict, _ = medimageinsight.generate_embeddings( + data={"image": zeroshot_path, "text": None, "params": {"get_scaling_factor": False}} + ) + logger.info("Generated embeddings for images") + return image_embedding_dict + + +def initialize_medimageinsight(mlflow_model_path: str) -> medimageinsight_package: + """ + Initialize the MedImageInsight package. + + This function initializes the MedImageInsight package using the provided MLflow model path. + + Args: + mlflow_model_path (str): The path to the MLflow model. + + Returns: + medimageinsight_package: An instance of the MedImageInsight package. + """ + medimageinsight = medimageinsight_package( + option="run_local", + mlflow_model_path=mlflow_model_path, + ) + logger.info("Initialized MedImageInsight package") + return medimageinsight + + +def merge_dataframes(train_df: pd.DataFrame, val_df: pd.DataFrame, + df_features: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]: + """ + Merge training and validation DataFrames with features DataFrame. + + This function merges the provided training and validation DataFrames with the features DataFrame + based on the "Name" column. + + Args: + train_df (pd.DataFrame): The training DataFrame. + val_df (pd.DataFrame): The validation DataFrame. + df_features (pd.DataFrame): The features DataFrame containing image features. + + Returns: + tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the merged training and validation DataFrames. + """ + train_merged = pd.merge(train_df, df_features, on="Name", how="inner") + val_merged = pd.merge(val_df, df_features, on="Name", how="inner") + logger.info("Merged training and validation DataFrames with features DataFrame") + return train_merged, val_merged + + +def save_merged_dataframes(train_merged: pd.DataFrame, val_merged: pd.DataFrame, output_train_pkl_path: str, + output_validation_pkl_path: str, train_embedding_file_name: str, + validation_embedding_file_name: str) -> None: + """ Save merged DataFrames to PKL files. + + This function saves the provided training and validation merged DataFrames + to the specified PKL file paths with the given file names. It also creates + the directories if they do not exist. + + Args: + train_merged (pd.DataFrame): The merged training DataFrame to be saved. + val_merged (pd.DataFrame): The merged validation DataFrame to be saved. + output_train_pkl_path (str): The directory path where the training PKL file will be saved. + output_validation_pkl_path (str): The directory path where the validation PKL file will be saved. + train_embedding_file_name (str): The file name for the training PKL file. + validation_embedding_file_name (str): The file name for the validation PKL file. + Returns: + None + """ + os.makedirs(output_train_pkl_path, exist_ok=True) + os.makedirs(output_validation_pkl_path, exist_ok=True) + + train_merged.to_pickle(os.path.join(output_train_pkl_path, train_embedding_file_name)) + val_merged.to_pickle(os.path.join(output_validation_pkl_path, validation_embedding_file_name)) + logger.info("Saved merged DataFrames to PKL files") + + +def process_embeddings(args): + """ + Process medical image embeddings and save the results to PKL files. + This function initializes the medimageinsight object, generates image embeddings, + creates a features dataframe, loads train and validation PKL files, merges the dataframes, + and saves the merged dataframes to specified output PKL files. 
+    Args:
+        args (Namespace): A namespace object containing the following attributes:
+            - mlflow_model_path (str): The path to the MLflow model.
+            - zeroshot_path (str): The path to the zeroshot data.
+            - output_train_pkl (str): The path to save the output training PKL file.
+            - output_validation_pkl (str): The path to save the output validation PKL file.
+            - test_train_split_csv_path (str): The path to the test/train split CSV files.
+
+    Returns:
+        None
+    """
+    model_path = args.mlflow_model_path
+    zeroshot_path = args.zeroshot_path
+    output_train_pkl = args.output_train_pkl
+    output_validation_pkl = args.output_validation_pkl
+    test_train_split_csv_path = args.test_train_split_csv_path
+    logger.info("Zeroshot path: %s", zeroshot_path)
+    logger.info("Test/train split CSV path: %s", test_train_split_csv_path)
+    medimageinsight = initialize_medimageinsight(model_path)
+    image_embedding_dict = generate_image_embeddings(medimageinsight, zeroshot_path)
+    df_features = create_features_dataframe(image_embedding_dict)
+
+    train_df, val_df = load_csv_files(test_train_split_csv_path)
+    train_merged, val_merged = merge_dataframes(train_df, val_df, df_features)
+
+    save_merged_dataframes(train_merged, val_merged, output_train_pkl, output_validation_pkl,
+                           TRAIN_EMBEDDING_FILE_NAME, VALIDATION_EMBEDDING_FILE_NAME)
+    logger.info("Processing medical images and getting embeddings completed")
+
+
+def main():
+    """Parse arguments, set up logging and run embedding generation."""
+    parser = get_parser()
+    args, _ = parser.parse_known_args()
+    logger.info("Parsed arguments: %s", args)
+
+    set_logging_parameters(
+        task_type=args.task_name,
+        acft_custom_dimensions={
+            LoggingLiterals.PROJECT_NAME: PROJECT_NAME,
+            LoggingLiterals.PROJECT_VERSION_NUMBER: VERSION,
+            LoggingLiterals.COMPONENT_NAME: COMPONENT_NAME
+        },
+        azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
+    )
+    logger.info("Logging parameters set")
+
+    process_embeddings(args)
+
+
+if __name__ == '__main__':
+    main()
+
+# Example command to run this script:
+# python medimage_datapreprocess.py --task_name "MedEmbedding" --mlflow_model_path "/mnt/model/MedImageInsight/mlflow_model_folder" --zeroshot_path "/home/healthcare-ai/medimageinsight-zeroshot/" --test_train_split_csv_path "/home/healthcare-ai/medimageinsight/classification_demo/data_input/" --output_train_pkl "/home/healthcare-ai/" --output_validation_pkl "/home/healthcare-ai/"
diff --git a/assets/training/finetune_acft_image/src/medimage_insight_finetune/medimage_embedding_finetune.py b/assets/training/finetune_acft_image/src/medimage_insight_finetune/medimage_embedding_finetune.py
new file mode 100644
index 0000000000..d8e28dee3c
--- /dev/null
+++ b/assets/training/finetune_acft_image/src/medimage_insight_finetune/medimage_embedding_finetune.py
@@ -0,0 +1,403 @@
+"""Finetune the MedImageInsight embedding model with MainzTrain."""
+import argparse
+from azureml.acft.common_components import get_logger_app, set_logging_parameters, LoggingLiterals
+from azureml.acft.contrib.hf import VERSION, PROJECT_NAME
+from azureml.acft.contrib.hf.nlp.constants.constants import LOGS_TO_BE_FILTERED_IN_APPINSIGHTS
+import faulthandler
+import os
+import re
+import torch
+import yaml
+from typing import Any, Dict, List, Tuple
+
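+# MainzTrain is supplied by the MainzTrain wheel installed in the finetune
+# environment (see that environment's requirements.txt).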
+from MainzTrain.Trainers.MainzTrainer import MainzTrainer
+from MainzTrain.Utils.Timing import Timer
+
+COMPONENT_NAME = "ACFT-MedImage-Embedding-Finetune"
+logger = get_logger_app("azureml.acft.contrib.hf.scripts.src.train.medimage_embedding_finetune")
+
+
+def add_env_parser_to_yaml() -> None:
+    """
+    Add the ability to resolve environment variables to the yaml SafeLoader.
+
+    Environment variables in the form "${VAR_NAME}" are resolved as strings.
+    If VAR_NAME is not set in the environment, the name itself is used as the value.
+
+    E.g.:
+    config:
+      username: admin
+      password: ${SERVICE_PASSWORD}
+      service: https://${SERVICE_HOST}/service
+    """
+    loader = yaml.SafeLoader
+    env_pattern = re.compile(r".*?\${(.*?)}.*?")
+
+    def env_constructor(loader: yaml.Loader, node: yaml.Node) -> str:
+        value = loader.construct_scalar(node)
+        for group in env_pattern.findall(value):
+            value = value.replace(f"${{{group}}}", os.environ.get(group, group))
+        return value
+
+    yaml.add_implicit_resolver("!ENV", env_pattern, Loader=loader)
+    yaml.add_constructor("!ENV", env_constructor, Loader=loader)
+
+
+def load_config_dict_to_opt(opt: Dict[str, Any], config_dict: Dict[str, Any],
+                            splitter: str = '.', log_new: bool = False) -> None:
+    """
+    Load the key, value pairs from config_dict to opt, overriding existing values in opt
+    if there are any.
+
+    Args:
+        opt (Dict[str, Any]): The dictionary to be updated with values from config_dict.
+        config_dict (Dict[str, Any]): The dictionary containing configuration key-value pairs to load into opt.
+        splitter (str, optional): The delimiter used to split keys in config_dict. Defaults to '.'.
+        log_new (bool, optional): If True, logs new keys added to opt. Defaults to False.
+
+    Raises:
+        TypeError: If config_dict is not a dictionary.
+        AssertionError: If the structure of keys in config_dict does not match the expected format.
+
+    Returns:
+        None
+    """
+    if not isinstance(config_dict, dict):
+        raise TypeError("Config must be a Python dictionary")
+    for k, v in config_dict.items():
+        k_parts = k.split(splitter)
+        pointer = opt
+        for k_part in k_parts[:-1]:
+            if '[' in k_part and ']' in k_part:
+                # for the format "a.b[0][1].c: d"
+                k_part_splits = k_part.split('[')
+                k_part = k_part_splits.pop(0)
+                pointer = pointer[k_part]
+                for i in k_part_splits:
+                    assert i[-1] == ']'
+                    pointer = pointer[int(i[:-1])]
+            else:
+                if k_part not in pointer:
+                    pointer[k_part] = {}
+                pointer = pointer[k_part]
+            assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
+        if '[' in k_parts[-1] and ']' in k_parts[-1]:
+            k_part_splits = k_parts[-1].split('[')
+            k_part = k_part_splits.pop(0)
+            pointer = pointer[k_part]
+            for i in k_part_splits[:-1]:
+                assert i[-1] == ']'
+                pointer = pointer[int(i[:-1])]
+            assert k_part_splits[-1][-1] == ']'
+            ori_value = pointer[int(k_part_splits[-1][:-1])]
+            pointer[int(k_part_splits[-1][:-1])] = v
+        else:
+            ori_value = pointer.get(k_parts[-1])
+            pointer[k_parts[-1]] = v
+        if ori_value:
+            logger.info(f"Overrode {k} from {ori_value} to {v}")
+        elif log_new:
+            logger.info(f"Added {k}: {v}")
+
+
+def load_opt_from_config_files(conf_files: List[str]) -> Dict[str, Any]:
+    """
+    Load opt from the config files; settings in later files can override those in earlier files.
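+
+    E.g. with conf_files=["base.yaml", "override.yaml"] (illustrative names), a key
+    defined in both files takes its value from override.yaml.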
+
+    Args:
+        conf_files (list): a list of config file paths
+
+    Returns:
+        dict: a dictionary of opt settings
+    """
+    opt = {}
+    for conf_file in conf_files:
+        with open(conf_file, encoding='utf-8') as f:
+            # NOTE: deliberately not yaml.safe_load; the MainzTrain configs may rely
+            # on full YAML loading, and the config files are trusted component assets.
+            config_dict = yaml.unsafe_load(f)
+
+        load_config_dict_to_opt(opt, config_dict)
+
+    return opt
+
+
+def get_parser() -> argparse.ArgumentParser:
+    """
+    Add arguments and return the parser. Here we add all the arguments for all the tasks.
+
+    Those arguments that are not relevant for the input task should be ignored.
+    """
+    parser = argparse.ArgumentParser(description='Process medical images and get embeddings', allow_abbrev=False)
+    parser.add_argument(
+        "--task_name",
+        type=str,
+        required=True,
+        help="The name of the task to be executed",
+    )
+    parser.add_argument(
+        '--log_every',
+        type=int,
+        default=10,
+        help='Log every n steps.'
+    )
+    parser.add_argument(
+        '--resume',
+        action='store_true',
+        help='Resume training from checkpoint.'
+    )
+    parser.add_argument(
+        '--reset_data_loader',
+        action='store_false',
+        help='Reset data loader.'
+    )
+    parser.add_argument(
+        '--fp16',
+        action='store_true',
+        help='Use FP16 precision.'
+    )
+    parser.add_argument(
+        '--zero_stage',
+        type=int,
+        default=0,
+        help='ZeRO optimization stage.'
+    )
+    parser.add_argument(
+        '--deepspeed',
+        action='store_false',
+        help='Use DeepSpeed optimization.'
+    )
+    parser.add_argument(
+        '--save_per_optim_steps',
+        type=int,
+        default=100,
+        help='Save checkpoint every n optimization steps.'
+    )
+    parser.add_argument(
+        '--eval_per_optim_steps',
+        type=int,
+        default=100,
+        help='Evaluate every n optimization steps.'
+    )
+    parser.add_argument(
+        '--grad_clipping',
+        type=float,
+        default=1.0,
+        help='Gradient clipping value.'
+    )
+    parser.add_argument(
+        '--set_sampler_epoch',
+        action='store_false',
+        help='Set sampler epoch.'
+    )
+    parser.add_argument(
+        '--verbose',
+        action='store_true',
+        help='Enable verbose logging.'
+    )
+    parser.add_argument(
+        '--workers',
+        type=int,
+        default=6,
+        help='Number of workers.'
+    )
+    parser.add_argument(
+        '--pin_memory',
+        action='store_true',
+        help='Pin memory in data loader.'
+    )
+    parser.add_argument(
+        '--dataset_root',
+        type=str,
+        help='Root directory of the dataset.'
+    )
+    parser.add_argument(
+        '--eval_image_tsv',
+        type=str,
+        help='Path to evaluation image TSV file.'
+    )
+    parser.add_argument(
+        '--eval_text_tsv',
+        type=str,
+        help='Path to evaluation text TSV file.'
+    )
+    parser.add_argument(
+        '--image_tsv',
+        type=str,
+        help='Path to training image TSV file.'
+    )
+    parser.add_argument(
+        '--text_tsv',
+        type=str,
+        help='Path to training text TSV file.'
+    )
+    parser.add_argument(
+        '--label_file',
+        type=str,
+        help='Path to label file.'
+    )
+    parser.add_argument(
+        '--binary_metrics',
+        type=int,
+        default=1,
+        help='Use binary metrics.'
+    )
+    parser.add_argument(
+        '--cweight_file',
+        type=str,
+        help='Path to class weight file.'
+    )
+    parser.add_argument(
+        '--zs_mode',
+        type=int,
+        default=2,
+        help='Zero-shot mode.'
+    )
+    parser.add_argument(
+        '--zs_weight',
+        type=float,
+        default=1.0,
+        help='Zero-shot weight.'
+    )
+    parser.add_argument(
+        '--knn',
+        type=int,
+        default=200,
+        help='Number of nearest neighbors for KNN.'
+    )
+    parser.add_argument(
+        '--eval_zip_file',
+        type=str,
+        help='Path to evaluation zip file.'
+    )
+    parser.add_argument(
+        '--eval_zip_map_file',
+        type=str,
+        help='Path to evaluation zip map file.'
+    )
+    parser.add_argument(
+        '--eval_label_file',
+        type=str,
+        help='Path to evaluation label file.'
+    )
+    parser.add_argument(
+        '--batch_size_per_gpu',
+        type=int,
+        default=2,
+        help='Batch size per GPU.'
+    )
+    parser.add_argument(
+        '--max_num_epochs',
+        type=int,
+        default=10000,
+        help='Maximum number of epochs.'
+    )
+    parser.add_argument(
+        '--gradient_accumulate_step',
+        type=int,
+        default=1,
+        help='Number of gradient accumulation steps.'
+    )
+    parser.add_argument(
+        '--save_dir',
+        type=str,
+        help='Directory to save the output.'
+    )
+    parser.add_argument(
+        '--conf_files',
+        nargs='+',
+        required=True,
+        help='Path(s) to the MainzTrain config file(s).'
+    )
+
+    return parser
+
+
+def load_opt_command(cmdline_args: argparse.Namespace) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    """
+    Load and combine command line arguments with configuration options.
+
+    This function processes command line arguments, loads configuration options
+    from the specified configuration files, and combines them into a single dictionary.
+
+    Args:
+        cmdline_args (argparse.Namespace): The command line arguments passed to the script.
+
+    Returns:
+        Tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing the combined options dictionary
+        and the processed command line arguments dictionary.
+    """
+    add_env_parser_to_yaml()
+    opt = load_opt_from_config_files(cmdline_args.conf_files)
+    cmdline_args = vars(cmdline_args)
+    cmdline_args = {k.upper() if k != 'conf_files' else k: v for k, v in cmdline_args.items()}
+
+    # This also combines cmdline_args into the opt dictionary, overriding config file values.
+    load_config_dict_to_opt(opt, cmdline_args)
+
+    logger.info("Command line arguments:")
+    for key, value in cmdline_args.items():
+        logger.info(f"{key}: {value}")
+
+    return opt, cmdline_args
+
+
+def main(args: List[str] = None) -> None:
+    '''
+    Main execution point for MedImageInsight embedding finetuning.
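+
+    The command is currently hard-coded to 'train'; the 'evaluate' and
+    'train-and-evaluate' branches below are only reached if that value is changed.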
+    '''
+    logger.info('MainzTrain started')
+    parser = get_parser()
+    args = parser.parse_args(args)
+    set_logging_parameters(
+        task_type=args.task_name,
+        acft_custom_dimensions={
+            LoggingLiterals.PROJECT_NAME: PROJECT_NAME,
+            LoggingLiterals.PROJECT_VERSION_NUMBER: VERSION,
+            LoggingLiterals.COMPONENT_NAME: COMPONENT_NAME
+        },
+        azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
+    )
+    opt, _ = load_opt_command(args)
+    command = 'train'
+
+    if opt.get('SAVE_TIMER_LOG', False):
+        Timer.setEnabled(True)
+
+    trainer = MainzTrainer(opt)
+
+    if opt.get('DEBUG_DUMP_TRACEBACKS_INTERVAL', 0) > 0:
+        timeout = opt['DEBUG_DUMP_TRACEBACKS_INTERVAL']
+        traceback_dir = trainer.log_folder if trainer.log_folder is not None else trainer.save_folder
+        traceback_file = os.path.join(traceback_dir, f"tracebacks_{opt['rank']}.txt")
+        faulthandler.dump_traceback_later(timeout, repeat=True, file=open(traceback_file, 'w'))
+
+    splits = opt.get('EVALUATION_SPLITS', ["dev", "test"])
+
+    logger.info(f"Running command: {command}")
+    with torch.autograd.profiler.profile(
+        use_cuda=True,
+        enabled=opt.get('AUTOGRAD_PROFILER', False) and opt['rank'] == 0
+    ) as prof:
+        if command == "train":
+            trainer.train()
+        elif command == "evaluate":
+            trainer.eval(splits=splits)
+        elif command == 'train-and-evaluate':
+            best_checkpoint_path = trainer.train()
+            opt['PYLEARN_MODEL'] = best_checkpoint_path
+            trainer.eval(splits=splits)
+        else:
+            raise ValueError(f"Unknown command: {command}")
+
+    if opt.get('AUTOGRAD_PROFILER', False):
+        logger.info(prof.key_averages().table(sort_by="cuda_time_total"))
+        logger.info(prof.total_average())
+
+    if opt.get('SAVE_TIMER_LOG', False):
+        timer_log_dir = trainer.log_folder if trainer.log_folder is not None else trainer.save_folder
+        timer_log_file = os.path.join(timer_log_dir, f"timer_log_{opt['rank']}.txt")
+        Timer.timer_report(timer_log_file)
+
+
+if __name__ == "__main__":
+    main()
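+
+# Example command to run this script (illustrative paths; the config file name is an assumption):
+# python medimage_embedding_finetune.py --task_name "EmbeddingFinetune" --conf_files /mnt/config/medimageinsight_ft.yaml --save_dir /mnt/output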