update script to use s3

RobotSail committed Oct 25, 2024
1 parent cef0f5a commit 24a5658

Showing 5 changed files with 265 additions and 41 deletions.
57 changes: 43 additions & 14 deletions .github/workflows/print-loss.yml
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: Apache-2.0
 name: Print Loss
 
 on:
@@ -22,24 +23,52 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
+          # python -m pip install --upgrade pip
           pip install -r requirements-dev.txt
       - name: Generate test data
         run: |
           printf '{"total_loss": 2.0}\n{"total_loss": 1.8023}\n{"total_loss": 1.52324}\n{"total_loss": 1.3234}' > test-log.jsonl
           ls -al
       - name: Print loss
         run: |
           python scripts/create-loss-graph.py --log-file test-log.jsonl
-      - name: Print a comment
-        uses: actions/github-script@v7
-        with:
-          github-token: ${{secrets.GITHUB_TOKEN}}
-          script: |
-            github.rest.issues.createComment({
-              issue_number: context.issue.number,
-              owner: context.repo.owner,
-              repo: context.repo.repo,
-              body: 'Thank you for your contribution! Please make sure to review our contribution guidelines.'
-            })
+      - name: Upload loss
+        uses: actions/upload-artifact@v4
+        with:
+          name: training-log.jsonl
+          path: test-log.jsonl
+          retention-days: 1
+          overwrite: true
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2
+        with:
+          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          aws-region: ${{ vars.AWS_REGION }}
+      - name: Download loss data
+        id: download-logs
+        uses: actions/download-artifact@v4
+        with:
+          name: training-log.jsonl
+          path: downloaded-data
+
+      - name: Try to upload to s3
+        run: |
+          echo "$(which aws)"
+          bucket_name='os-ci-loss-curve-test'
+          output_file='./test.md'
+          python scripts/create-loss-graph.py \
+            --log-file "${{ steps.download-logs.outputs.download-path }}/test-log.jsonl" \
+            --output-file "${output_file}" \
+            --aws-region "${{ vars.AWS_REGION }}" \
+            --bucket-name "${bucket_name}" \
+            --base-branch "${{ github.event.pull_request.base.ref }}" \
+            --pr-number "${{ github.event.pull_request.number }}" \
+            --head-sha "${{ github.event.pull_request.head.sha }}" \
+            --origin-repository "${{ github.repository }}"
+          cat "${output_file}" >> "${GITHUB_STEP_SUMMARY}"
+          echo "test 1: https://github.com/${{ github.repository }}/commit/${{ github.sha }}"
+          echo "test 2: ${{ github.event.pull_request.html_url }}/commits/${{ github.sha }}"
2 changes: 2 additions & 0 deletions README.md
@@ -335,3 +335,5 @@ run_training(
     train_args=training_args,
 )
 ```
+
+Test update
3 changes: 2 additions & 1 deletion requirements-dev.txt
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
--r requirements.txt
+# blot this out
+# -r requirements.txt
 
 pre-commit>=3.0.4,<5.0
 pylint>=2.16.2,<4.0
186 changes: 160 additions & 26 deletions scripts/create-loss-graph.py
@@ -1,52 +1,186 @@
 # SPDX-License-Identifier: Apache-2.0
 # Standard
-from argparse import ArgumentParser, Namespace
-from base64 import b64encode
-from io import BytesIO
+from argparse import ArgumentParser
 from pathlib import Path
+from subprocess import run
+from typing import Dict, List
 import json
 
 # Third Party
 from matplotlib import pyplot as plt
+from pydantic import BaseModel
 
 
-def main(args: Namespace):
-    log_file = Path(args.log_file)
-    if not log_file.exists():
-        raise FileNotFoundError(f'log file "{args.log_file}" does not exist')
-
-    if not log_file.is_file():
-        raise RuntimeError(f'log file cannot be a directory: "{log_file}"')
-
-    with open(Path(log_file), "r", encoding="utf-8") as infile:
-        contents = [json.loads(l) for l in infile.read().splitlines()]
-
-    loss_data = [item["total_loss"] for item in contents if "total_loss" in item]
-
+class Arguments(BaseModel):
+    log_file: str | None = None
+    output_file: str
+    aws_region: str
+    bucket_name: str
+    base_branch: str
+    pr_number: str
+    head_sha: str
+    origin_repository: str
+
+
+def render_image(loss_data: List[float], outfile: Path) -> None:
     # create the plot
     plt.figure()
     plt.plot(loss_data)
     plt.xlabel("Steps")
     plt.ylabel("Loss")
     plt.title("Training performance over fixed dataset")
 
-    buf = BytesIO()
-    plt.savefig(buf, format="png")
-    buf.seek(0)
-
-    imgb64 = b64encode(buf.read()).decode("utf-8")
-
-    output_file = Path(args.output_file) if args.output_file else None
-    if output_file:
-        output_file.write_text(imgb64, encoding="utf-8")
-    else:
-        # just print the file without including a newline, this way it can be piped
-        print(imgb64, end="")
+    if outfile.exists():
+        outfile.unlink()
+
+    plt.savefig(outfile, format="png")
+
+
+def contents_from_file(log_file: Path) -> List[Dict]:
+    if not log_file.exists():
+        raise FileNotFoundError(f"Log file {log_file} does not exist")
+    if log_file.is_dir():
+        raise ValueError(f"Log file {log_file} is a directory")
+    with open(log_file, "r") as f:
+        return [json.loads(l) for l in f.read().splitlines()]
+
+
+def read_loss_data(log_file: Path) -> List[float]:
+    if not log_file:
+        raise ValueError("log_file must be provided when source is file")
+    contents = contents_from_file(log_file)
+
+    # select the loss data
+    loss_data = [item["total_loss"] for item in contents if "total_loss" in item]
+
+    if not loss_data:
+        raise ValueError("Loss data is empty")
+
+    # ensure that the loss data is valid
+    if not all(isinstance(l, float) for l in loss_data):
+        raise ValueError("Loss data must be a list of floats")
+
+    return loss_data
+
+
+def write_to_s3(
+    file: Path,
+    bucket_name: str,
+    destination: str,
+):
+    if not file.exists():
+        raise RuntimeError(f"File {file} does not exist")
+
+    s3_path = f"s3://{bucket_name}/{destination}"
+    results = run(
+        ["aws", "s3", "cp", str(file), s3_path], capture_output=True, check=True
+    )
+    if results.returncode != 0:
+        raise RuntimeError(f"failed to upload to s3: {results.stderr.decode('utf-8')}")
+
+    print(results.stdout.decode("utf-8"))
+
+
+def get_destination_path(base_ref: str, pr_number: str, head_sha: str) -> str:
+    return f"pulls/{base_ref}/{pr_number}/{head_sha}/loss-graph.png"
+
+
+def write_md_file(
+    output_file: Path, url: str, pr_number: str, head_sha: str, origin_repository: str
+):
+    commit_url = f"https://github.com/{origin_repository}/commit/{head_sha}"
+    md_template = f"""
+# Loss Graph for PR {pr_number} ([{head_sha[:7]}]({commit_url}))
+![Loss Graph]({url})
+"""
+    output_file.write_text(md_template, encoding="utf-8")
+
+
+def get_url(bucket_name: str, destination: str, aws_region: str) -> str:
+    return f"https://{bucket_name}.s3.{aws_region}.amazonaws.com/{destination}"
+
+
+def main(args: Arguments):
+    # first things first, we create the png file to upload to S3
+    log_file = Path(args.log_file)
+    loss_data = read_loss_data(log_file=log_file)
+    output_image = Path("/tmp/loss-graph.png")
+    output_file = Path(args.output_file)
+    render_image(loss_data=loss_data, outfile=output_image)
+    destination_path = get_destination_path(
+        base_ref=args.base_branch, pr_number=args.pr_number, head_sha=args.head_sha
+    )
+    write_to_s3(
+        file=output_image, bucket_name=args.bucket_name, destination=destination_path
+    )
+    s3_url = get_url(
+        bucket_name=args.bucket_name,
+        destination=destination_path,
+        aws_region=args.aws_region,
+    )
+    write_md_file(
+        output_file=output_file,
+        url=s3_url,
+        pr_number=args.pr_number,
+        head_sha=args.head_sha,
+        origin_repository=args.origin_repository,
+    )
+    print(f"Loss graph uploaded to '{s3_url}'")
+    print(f"Markdown file written to '{output_file}'")
 
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("--log-file", type=str)
-    parser.add_argument("--output-file", type=str, default=None)
+    parser.add_argument(
+        "--log-file",
+        type=str,
+        required=True,
+        help="The log file to read the loss data from.",
+    )
+    parser.add_argument(
+        "--output-file",
+        type=str,
+        required=True,
+        help="The output file where the resulting markdown will be written.",
+    )
+    parser.add_argument(
+        "--aws-region",
+        type=str,
+        required=True,
+        help="S3 region to which the bucket belongs.",
+    )
+    parser.add_argument(
+        "--bucket-name", type=str, required=True, help="The S3 bucket name"
+    )
+    parser.add_argument(
+        "--base-branch",
+        type=str,
+        required=True,
+        help="The base branch being merged into.",
+    )
+    parser.add_argument("--pr-number", type=str, required=True, help="The PR number")
+    parser.add_argument(
+        "--head-sha", type=str, required=True, help="The head SHA of the PR"
+    )
+    parser.add_argument(
+        "--origin-repository",
+        type=str,
+        required=True,
+        help="The repository to which the originating branch belongs.",
+    )
 
     args = parser.parse_args()
-    main(args)
+
+    arguments = Arguments(
+        log_file=args.log_file,
+        output_file=args.output_file,
+        aws_region=args.aws_region,
+        bucket_name=args.bucket_name,
+        base_branch=args.base_branch,
+        pr_number=args.pr_number,
+        head_sha=args.head_sha,
+        origin_repository=args.origin_repository,
+    )
+    main(arguments)
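
For a concrete sense of where a given run's graph lands, here is how the script's get_destination_path and get_url helpers compose; the branch, PR number, SHA, and region below are hypothetical values chosen only for illustration.

```python
# Reproduces the script's path/URL helpers with hypothetical inputs.
def get_destination_path(base_ref: str, pr_number: str, head_sha: str) -> str:
    return f"pulls/{base_ref}/{pr_number}/{head_sha}/loss-graph.png"

def get_url(bucket_name: str, destination: str, aws_region: str) -> str:
    return f"https://{bucket_name}.s3.{aws_region}.amazonaws.com/{destination}"

destination = get_destination_path("main", "123", "abc1234")  # hypothetical PR values
print(destination)
# -> pulls/main/123/abc1234/loss-graph.png
print(get_url("os-ci-loss-curve-test", destination, "us-east-1"))  # hypothetical region
# -> https://os-ci-loss-curve-test.s3.us-east-1.amazonaws.com/pulls/main/123/abc1234/loss-graph.png
```

Keying objects by base branch, PR number, and head SHA means a single PR can accumulate one graph per push without overwriting earlier runs.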
58 changes: 58 additions & 0 deletions scripts/render-results.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+# SPDX-License-Identifier: Apache-2.0
+
+set -eo pipefail
+
+# This script exists to upload data to s3 and render the final markdown
+# file for the results of the benchmarking.
+
+function upload_to_s3() {
+    local -r bucket_name="$1"
+    local -r file_path="$2"
+    local -r destination_path="$3"
+
+    local -r bucket_path="s3://${bucket_name}/${destination_path}"
+    printf 'Uploading result to S3: %s\n' "${bucket_path}"
+    if [[ ! -f "${file_path}" ]]; then
+        echo "Error: File '${file_path}' does not exist."
+        exit 1
+    fi
+    aws s3 cp "${file_path}" "${bucket_path}"
+}
+
+################################################################################
+# Returns the path to where we'll be uploading the loss.png file.
+# Currently, the format is:
+#   pulls/<base_branch>/<pr_number>/<sha>/loss.png
+# This way, a single PR can have multiple runs and we can keep track of them.
+# Globals:
+#   github (read-only) - The github context
+# Arguments:
+#   None
+# Returns:
+#   (string) The path to where we'll be uploading the loss.png file.
+################################################################################
+function get_s3_path() {
+    printf 'pulls/%s/%s/%s/loss.png' "${{ github.event.pull_request.base.ref }}" "${{ github.event.pull_request.number }}" "${{ github.event.pull_request.head.sha }}"
+}
+
+function export_results() {
+    local -r img_url="$1"
+    printf '### Test performance:\n\n![Loss curve](%s)\n' "${img_url}" >> "${GITHUB_STEP_SUMMARY}"
+}
+
+function main() {
+    local -r output_path=$(get_s3_path)
+    local -r bucket_name='os-ci-loss-curve-test'
+    local -r access_region="${{ vars.AWS_REGION }}"
+    local -r input_file='./loss.png'
+    local -r final_url="https://${bucket_name}.s3.${access_region}.amazonaws.com/${output_path}"
+
+    printf 'Uploading image "%s" to bucket "%s" at output path "%s"\n' "${input_file}" "${bucket_name}" "${output_path}"
+    upload_to_s3 "${bucket_name}" "${input_file}" "${output_path}"
+
+    printf 'Final url should be: "%s"\n' "${final_url}"
+    export_results "${final_url}"
+}
+
+main
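
Both the Python script and this shell helper shell out to the AWS CLI for the upload itself. A boto3-based equivalent (not what this commit uses; just a sketch of the same operation, assuming boto3 is installed and credentials resolve via the standard AWS credential chain) would look like:

```python
# Sketch only: boto3 alternative to `aws s3 cp`; this commit shells out to the
# AWS CLI instead. Assumes boto3 is installed and credentials are available.
import boto3

def upload_to_s3(file_path: str, bucket_name: str, destination: str) -> None:
    s3 = boto3.client("s3")
    # Equivalent of: aws s3 cp <file_path> s3://<bucket_name>/<destination>
    s3.upload_file(file_path, bucket_name, destination)

# Hypothetical key, mirroring the pulls/<branch>/<pr>/<sha>/loss.png scheme.
upload_to_s3("./loss.png", "os-ci-loss-curve-test", "pulls/main/123/abc1234/loss.png")
```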
