Skip to content

Commit

Permalink
Merge pull request #3 from RobotSail/add-log-exporting-2
Browse files Browse the repository at this point in the history
create script to output loss curve
  • Loading branch information
RobotSail authored Oct 16, 2024
2 parents 68b4b76 + 9956e21 commit 2b62889
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 0 deletions.
44 changes: 44 additions & 0 deletions .github/workflows/print-loss.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Smoke-tests scripts/create-loss-graph.py against a tiny synthetic training
# log and leaves a greeting comment on the pull request that triggered it.
name: Print Loss

on:
  pull_request:
    types: [opened, reopened]
    branches:
      - main

# Least-privilege token: the job only reads the checked-out code, except for
# the final step, which creates a comment on the pull request and therefore
# needs write access to it. Without this block the comment step fails with
# 403 when the repository/organisation default token permission is read-only.
# NOTE(review): for pull requests opened from forks the token stays read-only
# regardless of this block, so the comment step will still fail there.
permissions:
  contents: read
  pull-requests: write

jobs:
  print-loss:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.x'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements-dev.txt

      # Four JSONL records with a decreasing "total_loss", enough for the
      # script to draw a curve.
      - name: Generate test data
        run: |
          printf '{"total_loss": 2.0}\n{"total_loss": 1.8023}\n{"total_loss": 1.52324}\n{"total_loss": 1.3234}' > test-log.jsonl
          ls -al

      # No --output-file is given, so the base64-encoded PNG goes to the job log.
      - name: Print loss
        run: |
          python scripts/create-loss-graph.py --log-file test-log.jsonl

      - name: Print a comment
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          # Await the request so an API failure fails this step instead of
          # surfacing as an unhandled promise rejection.
          script: |
            await github.rest.issues.createComment({
              issue_number: context.issue.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: 'Thank you for your contribution! Please make sure to review our contribution guidelines.'
            })
4 changes: 4 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ ipython
ipykernel
jupyter


# for plotting the loss curve (scripts/create-loss-graph.py)
matplotlib
numpy
52 changes: 52 additions & 0 deletions scripts/create-loss-graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
from argparse import ArgumentParser, Namespace
from base64 import b64encode
from io import BytesIO
from pathlib import Path
import json

# Third Party
from matplotlib import pyplot as plt


def main(args: Namespace):
    """Render the loss curve of a JSONL training log as a base64-encoded PNG.

    Every non-blank line of ``args.log_file`` is parsed as one JSON document.
    The ``total_loss`` value of each JSON object that has that key is plotted
    in file order; records without it (or that are not objects) are skipped.
    The resulting PNG is base64-encoded and either written to
    ``args.output_file`` or, when that is unset, printed to stdout without a
    trailing newline so the output can be piped.

    Args:
        args: parsed CLI arguments providing ``log_file`` and ``output_file``.

    Raises:
        FileNotFoundError: if ``args.log_file`` does not exist.
        RuntimeError: if ``args.log_file`` exists but is not a regular file.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    log_file = Path(args.log_file)
    if not log_file.exists():
        raise FileNotFoundError(f'log file "{args.log_file}" does not exist')

    if not log_file.is_file():
        raise RuntimeError(f'log file cannot be a directory: "{log_file}"')

    with log_file.open("r", encoding="utf-8") as infile:
        # Iterate physical lines (split on newlines only) and skip blank ones:
        # a stray empty line must not abort the run, and str.splitlines()
        # would also split on characters such as U+2028 that may legally
        # appear inside a JSON string.
        contents = [json.loads(line) for line in infile if line.strip()]

    # Only JSON objects can carry a "total_loss" key; ignore anything else.
    loss_data = [
        item["total_loss"]
        for item in contents
        if isinstance(item, dict) and "total_loss" in item
    ]

    # create the plot
    buf = BytesIO()
    fig = plt.figure()
    try:
        plt.plot(loss_data)
        plt.xlabel("Steps")
        plt.ylabel("Loss")
        plt.title("Training performance over fixed dataset")
        plt.savefig(buf, format="png")
    finally:
        # pyplot keeps figures alive until they are closed explicitly
        plt.close(fig)

    imgb64 = b64encode(buf.getvalue()).decode("utf-8")

    output_file = Path(args.output_file) if args.output_file else None
    if output_file:
        output_file.write_text(imgb64, encoding="utf-8")
    else:
        # just print the file without including a newline, this way it can be piped
        print(imgb64, end="")


if __name__ == "__main__":
    # CLI entry point: parse the arguments and hand them to main().
    parser = ArgumentParser(
        description="Plot the total_loss values of a JSONL log as a base64-encoded PNG."
    )
    # Required: without a log file there is nothing to plot, and omitting it
    # previously surfaced as an opaque TypeError from Path(None).
    parser.add_argument(
        "--log-file",
        type=str,
        required=True,
        help="path to the JSONL log file to read",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default=None,
        help="file to write the base64-encoded PNG to (default: print to stdout)",
    )
    args = parser.parse_args()
    main(args)

0 comments on commit 2b62889

Please sign in to comment.