Revise script/task to refine logic for updating model #46

Merged: 11 commits, Feb 6, 2025
src/htr2hpc/api_client.py (2 changes: 1 addition & 1 deletion)
@@ -551,7 +551,7 @@ def download_file(

outfile = save_location / filename
# report on filename and size based on content-length header
logger.info(
logger.debug(
f"Saving as {filename} ({humanize.naturalsize(content_length)})"
)
with outfile.open("wb") as filehandle:
src/htr2hpc/tasks.py (68 changes: 48 additions & 20 deletions)
@@ -78,7 +78,9 @@ def start_remote_training(

# script output is stored in result.stdout/result.stderr
# add output to task report
task_report.append(f"remote script output:\n\n{result.stdout}")
task_report.append(
f"\n\nremote script output:\n\n{result.stdout}\n\n{result.stderr}\n\n"
)
if "Slurm job was cancelled" in result.stdout:
task_report.cancel("(slurm cancellation)")
# notify the user of the error
@@ -164,6 +166,14 @@ def segtrain(
task_group = TaskGroup.objects.get(pk=task_group_pk)
task_report = task_group.taskreport_set.first()

# if the model is older than the task group, then we infer that
# overwrite was requested on the form (update an existing model)
model_overwrite = model.version_created_at < task_group.created_at
if model_overwrite:
logger.debug(
f"Inferring model overwrite requested based on model/task creation dates (model:{model.version_created_at} task:{task_group.created_at})"
)

# mark the model as being in training
# would be nice if the script could handle, but that field is listed
# as read only in the api
@@ -204,9 +214,15 @@ def segtrain(
# model is technically optional for this task but it should
# always be passed in by escriptorium calling code
if model_pk:
arg_options.append(f"--model {model_pk}")
# eScriptorium behavior is to create a new model that will be
# updated after training, so if we have a model we always want --update
arg_options.append(f"--model {model_pk} --update")
# updated after training, so if we have a model we always want to update
# the model; but when overwriting an existing model, only update if improved
if model_overwrite:
arg_options.append("--update-if-improved")
else:
arg_options.append("--update")

opts = " ".join(arg_options)

cmd = f"htr2hpc-train segmentation {site_url} {outdir} {opts}"
@@ -224,7 +240,9 @@
if success:
# check for case where training completed but model did not improve.
# i.e., no new model was uploaded or cloned model is still parent file
if model.file is None or model.file == model.parent.file:
if model.file is None or (
model.parent is not None and model.file == model.parent.file
):
user.notify(
"Training completed but did not result in an improved model",
id="training-warning",
@@ -240,9 +258,8 @@
},
)

# delete the empty model unless it is a pre-existing one
# (i.e., overwrite was requested)
if task_group.created_at < model.version_created_at:
# delete the original model unless overwrite was requested
if not model_overwrite:
model.delete()

else:
@@ -261,10 +278,9 @@
# if training did not succeed:

# escriptorium task deletes the model if there is an error;
# we want to do that, but check if the model was created after
# this task started so we don't delete a pre-existing model
# we want to do that, but don't delete a pre-existing model
# when overwrite was requested
if model.file is None or task_group.created_at < model.version_created_at:
if not model_overwrite:
model.delete()
return

@@ -321,6 +337,14 @@ def train(
task_group = TaskGroup.objects.get(pk=task_group_pk)
task_report = task_group.taskreport_set.first()

# if the model is older than the task group, then we infer that
# overwrite was requested on the form (update an existing model)
model_overwrite = model.version_created_at < task_group.created_at
if model_overwrite:
logger.debug(
f"Inferring model overwrite requested based on model/task creation dates (model:{model.version_created_at} task:{task_group.created_at})"
)

# mark the model as being in training
# would be nice if the script could handle, but that field is listed
# as read only in the api
@@ -357,9 +381,14 @@ def train(
# model is technically optional for this task but it should
# always be passed in by escriptorium calling code
if model_pk:
arg_options.append(f"--model {model_pk}")
# eScriptorium behavior is to create a new model that will be
# updated after training, so if we have a model we always want --update
arg_options.append(f"--model {model_pk} --update")
# updated after training, so if we have a model we always want to update
# the model; but when overwriting an existing model, only update if improved
if model_overwrite:
arg_options.append("--update-if-improved")
else:
arg_options.append("--update")

opts = " ".join(arg_options)

@@ -380,7 +409,9 @@

# check for case where training completed but model did not improve.
# i.e., no new model was uploaded or cloned model is still parent file
if model.file is None or model.file == model.parent.file:
if model.file is None or (
model.parent is not None and model.file == model.parent.file
):
user.notify(
"Training completed but did not result in an improved model",
id="training-warning",
@@ -396,9 +427,8 @@
},
)

# delete the empty model unless it is a pre-existing one
# (i.e., overwrite was requested)
if task_group.created_at < model.version_created_at:
# delete the original model unless overwrite was requested
if not model_overwrite:
model.delete()

else:
@@ -416,10 +446,8 @@

else:
# escriptorium task deletes the model if there is an error;
# we want to do that, but check if the model was created after
# this task started so we don't delete a pre-existing model
# when overwrite was requested
if model.file is None or task_group.created_at < model.version_created_at:
# we want to do that, unless overwrite of an existing model was requested
if not model_overwrite:
model.delete()
return

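For reference, the branching above in segtrain and train produces training commands along these lines (the site URL, output directory, and model id are illustrative placeholders; other options are omitted):

    # model predates the task group: overwrite was requested, so only update if training improves on it
    htr2hpc-train segmentation https://escriptorium.example.edu /tmp/htr2hpc-segtrain --model 42 --update-if-improved

    # model was freshly created by eScriptorium for this run: always update it with the best result
    htr2hpc-train segmentation https://escriptorium.example.edu /tmp/htr2hpc-segtrain --model 42 --update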
src/htr2hpc/train/data.py (60 changes: 36 additions & 24 deletions)
@@ -112,9 +112,8 @@ def get_segmentation_data(
for line in part.lines
]

logger.info(f"Document {document_id} part {part_id}: {len(baselines)} baselines")

logger.info(
logger.debug(f"Document {document_id} part {part_id}: {len(baselines)} baselines")
logger.debug(
f"Document {document_id} part {part_id}: {len(part.regions)} regions, {len(regions.keys())} block types"
)

@@ -239,37 +238,49 @@ def get_training_data(
def get_best_model(
model_dir: pathlib.Path, original_model: pathlib.Path = None
) -> pathlib.Path | None:
# kraken should normally identify the best model for us
best = list(model_dir.glob("*_best.mlmodel"))
# if one was found, return it
if best:
print(f"Using kraken identified best model {best[0].name}")
return best[0]

# if not, try to find one based on accuracy metadata
"""Find the best model in the specified `model_dir` directory.
By default, looks for a file named `*_best.mlmodel`. If no best model
is found by filename, looks for best model based on accuracy score
in kraken metadata. When `original_model` is specified, accuracy
must be better than the original to be considered 'best'.
"""
best_accuracy = 0
# when original model is specified, initialize
# best accuracy value from that model
print(f"Looking for best model by accuracy")
if original_model:
best = original_model
best_accuracy = get_model_accuracy(original_model)
print(
f"Must be better than original model {original_model.name} accuracy {best_accuracy:0.3f}"
)
for model in model_dir.glob("*.mlmodel"):
accuracy = get_model_accuracy(model)
print(f"model: {model.name} accuracy: {accuracy:0.3f}")
# if accuracy is better than our current best, this model is new best
# kraken should normally identify the best model for us
best = list(model_dir.glob("*_best.mlmodel"))
# if one was found, return it
if best:
accuracy = get_model_accuracy(best[0])
if accuracy > best_accuracy:
best = model
best_accuracy = accuracy
print(f"Using kraken identified best model {best[0].name}")
return best[0]
else:
print("Training did not improve on original model")

# if we found a model better than the original, return it
if best and best != original_model:
return best
if best == original_model:
print("Training did not improve on original model")
# if not, try to find one based on accuracy metadata
else:
if original_model:
best = original_model
print(f"Looking for best model by accuracy")
for model in model_dir.glob("*.mlmodel"):
accuracy = get_model_accuracy(model)
print(f"model: {model.name} accuracy: {accuracy:0.3f}")
# if accuracy is better than our current best, this model is new best
if accuracy > best_accuracy:
best = model
best_accuracy = accuracy

# if we found a model better than the original, return it
if best and best != original_model:
return best
if best == original_model:
print("Training did not improve on original model")


def upload_models(
@@ -282,6 +293,7 @@ def upload_models(

# segtrain creates models based on modelname with _0, _1, _2 ... _49
# sort numerically on the latter portion of the name
# NOTE: this older logic breaks with new -q early option that creates a _best model
modelfiles = sorted(
model_dir.glob("*.mlmodel"), key=lambda path: int(path.stem.split("_")[-1])
)
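A minimal sketch of how the revised get_best_model is expected to be used (paths are hypothetical; it relies on the existing get_model_accuracy helper called above):

    import pathlib
    from htr2hpc.train.data import get_best_model

    # overwrite case: only a model that beats the original's accuracy counts as best,
    # so the function may return None and nothing gets uploaded
    best = get_best_model(
        pathlib.Path("training_output"), original_model=pathlib.Path("existing.mlmodel")
    )
    if best is None:
        print("Training did not improve on original model; nothing to upload")

    # default case: prefer kraken's *_best.mlmodel, otherwise pick the highest-accuracy file
    best = get_best_model(pathlib.Path("training_output"))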
src/htr2hpc/train/run.py (55 changes: 45 additions & 10 deletions)
@@ -2,6 +2,7 @@
import argparse
import os
import sys
from enum import Enum

import logging
import pathlib
@@ -43,6 +44,16 @@ class JobCancelled(Exception):
"Custom exception for when slurm job was cancelled"


class UpdateMode(Enum):
NEVER = 0
ALWAYS = 1
IF_IMPROVED = 2

def __bool__(self):
# override so boolean check of never will evaluate to false
return self != UpdateMode.NEVER


@dataclass
class TrainingManager:
base_url: str
Expand All @@ -55,7 +66,7 @@ class TrainingManager:
parts: Optional[intspan] = None
model_id: Optional[int] = None
task_report_id: Optional[int] = None
update: bool = False
update: UpdateMode = UpdateMode.NEVER
transcription_id: Optional[int] = None
existing_data: bool = False
show_progress: bool = True
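The __bool__ override on UpdateMode keeps existing truthiness checks working unchanged, such as the "if args.update:" validation in main: only NEVER is falsy. A small illustration, assuming the enum as defined above:

    update = UpdateMode.NEVER
    assert not update    # behaves like the old update=False default

    for mode in (UpdateMode.ALWAYS, UpdateMode.IF_IMPROVED):
        assert mode      # both update modes are truthy, so --model is still required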
@@ -173,8 +184,12 @@ def monitor_slurm_job(self, job_id):
print(f"Job output is in {job_output}")

if self.task_report_id is not None:
with open(job_output) as job_output_file:
slurm_output = job_output_file.read()
try:
with open(job_output) as job_output_file:
slurm_output = job_output_file.read()
except FileNotFoundError:
print(f"File {job_output} not found.")
slurm_output = ""

# get current task report so we can add to messages
task_report = self.api.task_details(self.task_report_id)
@@ -249,14 +264,20 @@ def upload_best(self):
else:
model_id = None

abs_model_file = self.model_file.absolute() if self.model_file else None
# in certain cases we only want to upload the model to
# eScriptorium if it has improved on the original model;
# pass in original model for minimum accuracy comparison
# when update mode is update-if-improved
compare_model_file = None
if self.update == UpdateMode.IF_IMPROVED and self.model_file:
compare_model_file = self.model_file.absolute()

best_model = upload_best_model(
self.api,
self.output_modelfile.parent,
self.training_mode,
model_id=model_id,
original_model=abs_model_file,
original_model=compare_model_file,
)
if best_model:
# TODO: revise message to include info about created/updated model id ##
@@ -330,12 +351,24 @@ def main():
type=int,
dest="model_id",
)
parser.add_argument(
update_group = parser.add_mutually_exclusive_group()
update_group.add_argument(
"-u",
"--update",
help="Update the specified model with the best model from training (requires --model)",
action="store_true",
default=False,
dest="update",
default=UpdateMode.NEVER,
action="store_const",
const=UpdateMode.ALWAYS,
required=False,
)
update_group.add_argument(
"--update-if-improved",
help="Update the specified model with the best model from training ONLY if improved on original",
dest="update",
action="store_const",
const=UpdateMode.IF_IMPROVED,
required=False,
)
parser.add_argument(
"--model-name",
@@ -401,6 +434,8 @@
)
args = parser.parse_args()
# validate argument combinations

# when update or update-if-improved is specified, model is required
if args.update:
error_messages = []
if not args.model_id:
@@ -435,8 +470,8 @@
args.work_dir.mkdir()

logging.basicConfig(encoding="utf-8", level=logging.WARN)
logger_upscope = logging.getLogger("htr2hpc")
# logger_upscope.setLevel(logging.DEBUG)
logger_local = logging.getLogger("htr2hpc")
logger_local.setLevel(logging.INFO)
# output kraken logging details to confirm binary data looks ok
logger_kraken = logging.getLogger("kraken")
# logger_kraken.setLevel(logging.INFO)