forked from instructlab/training
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
265 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -335,3 +335,5 @@ run_training( | |
train_args=training_args, | ||
) | ||
``` | ||
|
||
Test update |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,52 +1,186 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# Standard | ||
from argparse import ArgumentParser, Namespace | ||
from base64 import b64encode | ||
from io import BytesIO | ||
from argparse import ArgumentParser | ||
from pathlib import Path | ||
from subprocess import run | ||
from typing import Dict, List | ||
import json | ||
|
||
# Third Party | ||
from matplotlib import pyplot as plt | ||
from pydantic import BaseModel | ||
|
||
|
||
def main(args: Namespace): | ||
log_file = Path(args.log_file) | ||
if not log_file.exists(): | ||
raise FileNotFoundError(f'log file "{args.log_file}" does not exist') | ||
|
||
if not log_file.is_file(): | ||
raise RuntimeError(f'log file cannot be a directory: "{log_file}"') | ||
|
||
with open(Path(log_file), "r", encoding="utf-8") as infile: | ||
contents = [json.loads(l) for l in infile.read().splitlines()] | ||
class Arguments(BaseModel): | ||
log_file: str | None = None | ||
output_file: str | ||
aws_region: str | ||
bucket_name: str | ||
base_branch: str | ||
pr_number: str | ||
head_sha: str | ||
origin_repository: str | ||
|
||
loss_data = [item["total_loss"] for item in contents if "total_loss" in item] | ||
|
||
def render_image(loss_data: List[float], outfile: Path) -> str: | ||
# create the plot | ||
plt.figure() | ||
plt.plot(loss_data) | ||
plt.xlabel("Steps") | ||
plt.ylabel("Loss") | ||
plt.title("Training performance over fixed dataset") | ||
|
||
buf = BytesIO() | ||
plt.savefig(buf, format="png") | ||
buf.seek(0) | ||
if outfile.exists(): | ||
outfile.unlink() | ||
|
||
plt.savefig(outfile, format="png") | ||
|
||
|
||
def contents_from_file(log_file: Path) -> List[Dict]: | ||
if not log_file.exists(): | ||
raise FileNotFoundError(f"Log file {log_file} does not exist") | ||
if log_file.is_dir(): | ||
raise ValueError(f"Log file {log_file} is a directory") | ||
with open(log_file, "r") as f: | ||
return [json.loads(l) for l in f.read().splitlines()] | ||
|
||
|
||
def read_loss_data(log_file: Path) -> List[float]: | ||
if not log_file: | ||
raise ValueError("log_file must be provided when source is file") | ||
contents = contents_from_file(log_file) | ||
|
||
# select the loss data | ||
loss_data = [item["total_loss"] for item in contents if "total_loss" in item] | ||
|
||
if not loss_data: | ||
raise ValueError("Loss data is empty") | ||
|
||
# ensure that the loss data is valid | ||
if not all(isinstance(l, float) for l in loss_data): | ||
raise ValueError("Loss data must be a list of floats") | ||
|
||
imgb64 = b64encode(buf.read()).decode("utf-8") | ||
return loss_data | ||
|
||
output_file = Path(args.output_file) if args.output_file else None | ||
if output_file: | ||
output_file.write_text(imgb64, encoding="utf-8") | ||
|
||
def write_to_s3( | ||
file: Path, | ||
bucket_name: str, | ||
destination: str, | ||
): | ||
if not file.exists(): | ||
raise RuntimeError(f"File {file} does not exist") | ||
|
||
s3_path = f"s3://{bucket_name}/{destination}" | ||
results = run( | ||
["aws", "s3", "cp", str(file), s3_path], capture_output=True, check=True | ||
) | ||
if results.returncode != 0: | ||
raise RuntimeError(f"failed to upload to s3: {results.stderr.decode('utf-8')}") | ||
else: | ||
# just print the file without including a newline, this way it can be piped | ||
print(imgb64, end="") | ||
print(results.stdout.decode("utf-8")) | ||
|
||
|
||
def get_destination_path(base_ref: str, pr_number: str, head_sha: str): | ||
return f"pulls/{base_ref}/{pr_number}/{head_sha}/loss-graph.png" | ||
|
||
|
||
def write_md_file( | ||
output_file: Path, url: str, pr_number: str, head_sha: str, origin_repository: str | ||
): | ||
commit_url = f"https://github.com/{origin_repository}/commit/{head_sha}" | ||
md_template = f""" | ||
# Loss Graph for PR {args.pr_number} ([{args.head_sha[:7]}]({commit_url})) | ||
![Loss Graph]({url}) | ||
""" | ||
output_file.write_text(md_template, encoding="utf-8") | ||
|
||
|
||
def get_url(bucket_name: str, destination: str, aws_region: str) -> str: | ||
return f"https://{bucket_name}.s3.{aws_region}.amazonaws.com/{destination}" | ||
|
||
|
||
def main(args: Arguments): | ||
# first things first, we create the png file to upload to S3 | ||
log_file = Path(args.log_file) | ||
loss_data = read_loss_data(log_file=log_file) | ||
output_image = Path("/tmp/loss-graph.png") | ||
output_file = Path(args.output_file) | ||
render_image(loss_data=loss_data, outfile=output_image) | ||
destination_path = get_destination_path( | ||
base_ref=args.base_branch, pr_number=args.pr_number, head_sha=args.head_sha | ||
) | ||
write_to_s3( | ||
file=output_image, bucket_name=args.bucket_name, destination=destination_path | ||
) | ||
s3_url = get_url( | ||
bucket_name=args.bucket_name, | ||
destination=destination_path, | ||
aws_region=args.aws_region, | ||
) | ||
write_md_file( | ||
output_file=output_file, | ||
url=s3_url, | ||
pr_number=args.pr_number, | ||
head_sha=args.head_sha, | ||
origin_repository=args.origin_repository, | ||
) | ||
print(f"Loss graph uploaded to '{s3_url}'") | ||
print(f"Markdown file written to '{output_file}'") | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = ArgumentParser() | ||
parser.add_argument("--log-file", type=str) | ||
parser.add_argument("--output-file", type=str, default=None) | ||
|
||
parser.add_argument( | ||
"--log-file", | ||
type=str, | ||
required=True, | ||
help="The log file to read the loss data from.", | ||
) | ||
parser.add_argument( | ||
"--output-file", | ||
type=str, | ||
required=True, | ||
help="The output file where the resulting markdown will be written.", | ||
) | ||
parser.add_argument( | ||
"--aws-region", | ||
type=str, | ||
required=True, | ||
help="S3 region to which the bucket belongs.", | ||
) | ||
parser.add_argument( | ||
"--bucket-name", type=str, required=True, help="The S3 bucket name" | ||
) | ||
parser.add_argument( | ||
"--base-branch", | ||
type=str, | ||
required=True, | ||
help="The base branch being merged to.", | ||
) | ||
parser.add_argument("--pr-number", type=str, required=True, help="The PR number") | ||
parser.add_argument( | ||
"--head-sha", type=str, required=True, help="The head SHA of the PR" | ||
) | ||
parser.add_argument( | ||
"--origin-repository", | ||
type=str, | ||
required=True, | ||
help="The repository to which the originating branch belongs to.", | ||
) | ||
|
||
args = parser.parse_args() | ||
main(args) | ||
|
||
arguments = Arguments( | ||
log_file=args.log_file, | ||
output_file=args.output_file, | ||
aws_region=args.aws_region, | ||
bucket_name=args.bucket_name, | ||
base_branch=args.base_branch, | ||
pr_number=args.pr_number, | ||
head_sha=args.head_sha, | ||
origin_repository=args.origin_repository, | ||
) | ||
main(arguments) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
#!/bin/bash | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
set -eo pipefail | ||
|
||
# This script exists to upload data to s3 and render the final markdown | ||
# file for the results of the benchmarking. | ||
|
||
function upload_to_s3() { | ||
local -r bucket_name="$1" | ||
local -r file_path="$2" | ||
local -r destination_path="$3" | ||
|
||
local -r bucket_path="s3://${bucket_name}/${destination_path}" | ||
printf 'Uploading result to S3: %s\n' "${bucket_path}" | ||
if [[ ! -f "${file_path}" ]]; then | ||
echo "Error: File '${file_path}' does not exist." | ||
exit 1 | ||
fi | ||
aws s3 cp "${file_path}" "${bucket_path}" | ||
} | ||
|
||
################################################################################ | ||
# Returns the path to where we'll be uploading the loss.png file to. | ||
# Currently, the format is in the form of: | ||
# pulls/<base_branch>/<pr_number>/<sha>/loss.png | ||
# This way, a single PR can have multiple runs and we can keep track of them. | ||
# Globals: | ||
# github (read-only) - The github context | ||
# Arguments: | ||
# None | ||
# Returns: | ||
# (string) The path to where we'll be uploading the loss.png file to. | ||
################################################################################ | ||
function get_s3_path() { | ||
printf 'pulls/%s/%s/%s/loss.png' "${{ github.event.pull_request.base.ref }}" "${{ github.event.pull_request.number }}" "${{ github.event.pull_request.head.sha}}" | ||
} | ||
|
||
function export_results() { | ||
local -r img_url="$1" | ||
printf '### Test performance:\n\n![Loss curve](%s)\n' "${img_url}" >> "${GITHUB_STEP_SUMMARY}" | ||
} | ||
|
||
function main() { | ||
local -r output_path=$(get_s3_path) | ||
local -r bucket_name='os-ci-loss-curve-test' | ||
local -r access_region="${{ vars.AWS_REGION }}" | ||
local -r input_file='./loss.png' | ||
local -r final_url="https://${bucket_name}.s3.${access_region}.amazonaws.com/${output_path}" | ||
|
||
printf 'Uploading image "%s" to bucket "%s" at output path "%s"\n' "${input_file}" "${bucket_name}" "${output_path}" | ||
upload_to_s3 "${bucket_name}" "${input_file}" "${output_path}" | ||
|
||
printf 'Final url should be: "%s"\n' "${final_url}" | ||
export_results "${final_url}" | ||
} | ||
|
||
main |