Skip to content

Commit

Permalink
ready for test
Browse files Browse the repository at this point in the history
  • Loading branch information
SkalskiP committed Sep 4, 2024
1 parent 36251e5 commit 5f025af
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 82 deletions.
70 changes: 70 additions & 0 deletions maestro/trainer/common/utils/metrics_tracing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from collections import defaultdict
from typing import Dict, Tuple, List

import matplotlib.pyplot as plt


class MetricsTracker:

Expand All @@ -26,3 +29,70 @@ def get_metric_values(
if with_index:
return self._metrics[metric]
return [value[2] for value in self._metrics[metric]]


def aggregate_by_epoch(metric_values: List[Tuple[int, int, float]]) -> Dict[int, float]:
    """Average metric values per epoch.

    Args:
        metric_values: Triples of ``(epoch, step, value)`` as stored by
            ``MetricsTracker``.

    Returns:
        Mapping from epoch number to the mean of all values recorded in
        that epoch.
    """
    grouped: Dict[int, List[float]] = defaultdict(list)
    for epoch, _step, value in metric_values:
        grouped[epoch].append(value)
    return {epoch: sum(values) / len(values) for epoch, values in grouped.items()}


def save_metric_plots(
    training_tracker: MetricsTracker,
    validation_tracker: MetricsTracker,
    output_dir: str
) -> None:
    """Save one PNG per metric comparing training and validation curves.

    For every metric known to either tracker, per-step values are averaged
    by epoch and plotted against the epoch index. Each figure is written to
    ``{output_dir}/{metric}_plot.png``; ``output_dir`` must already exist.

    Args:
        training_tracker: Tracker holding per-step training metric values.
        validation_tracker: Tracker holding per-step validation metric values.
        output_dir: Directory the plot files are written into.
    """
    training_metrics = training_tracker.describe_metrics()
    validation_metrics = validation_tracker.describe_metrics()
    all_metrics = set(training_metrics + validation_metrics)

    for metric in all_metrics:
        plt.figure(figsize=(8, 6))

        if metric in training_metrics:
            training_values = training_tracker.get_metric_values(
                metric=metric, with_index=True)
            training_avg_values = aggregate_by_epoch(training_values)
            training_epochs = sorted(training_avg_values.keys())
            training_vals = [training_avg_values[epoch] for epoch in training_epochs]
            # BUG FIX: pyplot.plot takes x/y positionally; x=/y= keywords are
            # forwarded to Line2D and raise, so nothing was ever plotted.
            plt.plot(
                training_epochs,
                training_vals,
                label=f'Training {metric}',
                marker='o',
                linestyle='-',
                color='blue'
            )

        if metric in validation_metrics:
            validation_values = validation_tracker.get_metric_values(
                metric=metric, with_index=True)
            validation_avg_values = aggregate_by_epoch(validation_values)
            validation_epochs = sorted(validation_avg_values.keys())
            validation_vals = [
                validation_avg_values[epoch]
                for epoch
                in validation_epochs
            ]
            plt.plot(
                validation_epochs,
                validation_vals,
                label=f'Validation {metric}',
                marker='o',
                linestyle='--',
                color='orange'
            )

        plt.title(f'{metric.capitalize()} over Epochs')
        plt.xlabel('Epoch')
        plt.ylabel(f'{metric.capitalize()} Value')
        plt.legend()
        plt.grid(True)
        plt.savefig(f'{output_dir}/{metric}_plot.png')
        plt.close()
43 changes: 0 additions & 43 deletions maestro/trainer/models/florence_2/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import re
from typing import List, Tuple

import matplotlib.pyplot as plt
import cv2
import numpy as np
import torch
Expand All @@ -13,7 +12,6 @@

from maestro.trainer.common.data_loaders.datasets import DetectionDataset
from maestro.trainer.common.utils.file_system import save_json
from maestro.trainer.common.utils.metrics_tracing import MetricsTracker
from maestro.trainer.models.florence_2.data_loading import prepare_detection_dataset


Expand Down Expand Up @@ -188,44 +186,3 @@ def dump_visualised_samples(
concatenated = cv2.hconcat([target_image, prediction_image])
target_image_path = os.path.join(target_dir, image_name)
cv2.imwrite(target_image_path, concatenated)


def summarise_training_metrics(
    training_metrics_tracker: MetricsTracker,
    validation_metrics_tracker: MetricsTracker,
    training_dir: str,
) -> None:
    """Render metric plots for both the train and valid splits."""
    for tracker, split in (
        (training_metrics_tracker, "train"),
        (validation_metrics_tracker, "valid"),
    ):
        summarise_metrics(metrics_tracker=tracker, training_dir=training_dir, split_name=split)


def summarise_metrics(
    metrics_tracker: MetricsTracker,
    training_dir: str,
    split_name: str,
) -> None:
    """Plot every tracked metric for one split and save the figures.

    One PNG per metric is written to
    ``{training_dir}/metrics/{split_name}/metric_{name}_plot.png``. The
    x-axis is the raw sample index; tick labels mark the first sample of
    each epoch so the plot reads in epochs.
    """
    plots_dir_path = os.path.join(training_dir, "metrics", split_name)
    os.makedirs(plots_dir_path, exist_ok=True)
    for metric_name in metrics_tracker.describe_metrics():
        plot_path = os.path.join(plots_dir_path, f"metric_{metric_name}_plot.png")
        plt.clf()
        entries = metrics_tracker.get_metric_values(
            metric=metric_name,
            with_index=True,
        )
        xs = np.arange(0, len(entries))
        # Place one tick wherever the epoch number (entry[0]) changes.
        tick_positions, tick_labels = [], []
        last_epoch = None
        for x, entry in zip(xs, entries):
            if entry[0] != last_epoch:
                tick_positions.append(x)
                tick_labels.append(entry[0])
                last_epoch = entry[0]
        ys = [entry[2] for entry in entries]
        plt.scatter(xs, ys, marker="x")
        plt.plot(xs, ys, linestyle="dashed", linewidth=0.3)
        plt.title(f"Value of {metric_name} for {split_name} set")
        plt.xticks(tick_positions, labels=tick_labels)
        plt.xlabel("Epochs")
        plt.savefig(plot_path, dpi=120)
        plt.clf()
76 changes: 38 additions & 38 deletions maestro/trainer/models/florence_2/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor, get_scheduler

from maestro.trainer.common.configuration.env import CUDA_DEVICE_ENV, DEFAULT_CUDA_DEVICE
from maestro.trainer.common.configuration.env import CUDA_DEVICE_ENV, \
DEFAULT_CUDA_DEVICE
from maestro.trainer.common.utils.leaderboard import CheckpointsLeaderboard
from maestro.trainer.common.utils.metrics_tracing import MetricsTracker
from maestro.trainer.common.utils.metrics_tracing import MetricsTracker, \
save_metric_plots
from maestro.trainer.common.utils.reproducibility import make_it_reproducible
from maestro.trainer.models.florence_2.data_loading import prepare_data_loaders
from maestro.trainer.models.florence_2.metrics import prepare_detection_training_summary, summarise_training_metrics
from maestro.trainer.models.florence_2.metrics import prepare_detection_training_summary
from maestro.trainer.models.paligemma.training import LoraInitLiteral


Expand Down Expand Up @@ -102,41 +104,39 @@ def train(configuration: TrainingConfiguration) -> None:
validation_metrics_tracker=validation_metrics_tracker,
)

return training_metrics_tracker, validation_metrics_tracker

# best_model_path = checkpoints_leaderboard.get_best_model()
# print(f"Loading best model from {best_model_path}")
# processor, model = load_model(
# model_id_or_path=best_model_path,
# )
# if test_loader is not None:
# run_validation_epoch(
# processor=processor,
# model=model,
# loader=test_loader,
# epoch_number=None,
# configuration=configuration,
# title="Test",
# )
# best_model_dir = os.path.join(configuration.training_dir, "best_model")
# print(f"Saving best model: {best_model_dir}")
# model.save_pretrained(best_model_dir)
# processor.save_pretrained(best_model_dir)
# summarise_training_metrics(
# training_metrics_tracker=training_metrics_tracker,
# validation_metrics_tracker=validation_metrics_tracker,
# training_dir=configuration.training_dir,
# )
# for split_name in ["valid", "test"]:
# prepare_detection_training_summary(
# processor=processor,
# model=model,
# dataset_location=configuration.dataset_location,
# split_name=split_name,
# training_dir=configuration.training_dir,
# num_samples_to_visualise=configuration.num_samples_to_visualise,
# device=configuration.device,
# )
best_model_path = checkpoints_leaderboard.get_best_model()
print(f"Loading best model from {best_model_path}")
processor, model = load_model(
model_id_or_path=best_model_path,
)
if test_loader is not None:
run_validation_epoch(
processor=processor,
model=model,
loader=test_loader,
epoch_number=None,
configuration=configuration,
title="Test",
)
best_model_dir = os.path.join(configuration.training_dir, "best_model")
print(f"Saving best model: {best_model_dir}")
model.save_pretrained(best_model_dir)
processor.save_pretrained(best_model_dir)
save_metric_plots(
training_tracker=training_metrics_tracker,
validation_tracker=validation_metrics_tracker,
output_dir=configuration.training_dir,
)
for split_name in ["valid", "test"]:
prepare_detection_training_summary(
processor=processor,
model=model,
dataset_location=configuration.dataset_location,
split_name=split_name,
training_dir=configuration.training_dir,
num_samples_to_visualise=configuration.num_samples_to_visualise,
device=configuration.device,
)


def load_model(
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ torch~=2.4.0
accelerate~=0.33.0
sentencepiece~=0.2.0
peft~=0.12.0
flash-attn~=2.6.3
flash-attn~=2.6.3 # builds only on CUDA platforms; not installable on macOS
einops~=0.8.0
timm~=1.0.9

0 comments on commit 5f025af

Please sign in to comment.