forked from instructlab/training
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from RobotSail/add-log-exporting-2
create script to output loss curve
- Loading branch information
Showing
3 changed files
with
100 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Smoke-tests scripts/create-loss-graph.py on a tiny synthetic training log
# and leaves a thank-you comment on the pull request.
name: Print Loss

on:
  pull_request:
    types: [opened, reopened]
    branches:
      - main

# Declare the token scopes explicitly instead of relying on the repository
# default: checkout needs read access to the contents, and posting the PR
# comment needs write access to pull requests.
# NOTE(review): for pull requests opened from forks GitHub still hands out a
# read-only GITHUB_TOKEN, so the comment step is expected to fail there.
permissions:
  contents: read
  pull-requests: write

jobs:
  print-loss:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.x'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements-dev.txt

      # Four JSONL records with a decreasing "total_loss", enough for the
      # script to have something to plot.
      - name: Generate test data
        run: |
          printf '{"total_loss": 2.0}\n{"total_loss": 1.8023}\n{"total_loss": 1.52324}\n{"total_loss": 1.3234}' > test-log.jsonl
          ls -al

      # Without --output-file the script prints the base64-encoded PNG to stdout.
      - name: Print loss
        run: |
          python scripts/create-loss-graph.py --log-file test-log.jsonl

      - name: Print a comment
        uses: actions/github-script@v7
        with:
          github-token: ${{secrets.GITHUB_TOKEN}}
          # Await the request so that a rejected API call fails this step
          # instead of being dropped as an unhandled promise.
          script: |
            await github.rest.issues.createComment({
              issue_number: context.issue.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: 'Thank you for your contribution! Please make sure to review our contribution guidelines.'
            })
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,3 +11,7 @@ ipython | |
ipykernel | ||
jupyter | ||
|
||
|
||
# for plotting the loss curve (used by scripts/create-loss-graph.py) | ||
matplotlib | ||
numpy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# Standard | ||
from argparse import ArgumentParser, Namespace | ||
from base64 import b64encode | ||
from io import BytesIO | ||
from pathlib import Path | ||
import json | ||
|
||
# Third Party | ||
from matplotlib import pyplot as plt | ||
|
||
|
||
def main(args: Namespace) -> None:
    """Plot the loss curve of a JSONL training log as a base64-encoded PNG.

    Every non-blank line of the log is parsed as JSON. The ``total_loss``
    value of each JSON object that has one is plotted in file order; lines
    that are blank, are not JSON objects, or lack the key are skipped.

    Args:
        args: Parsed command-line arguments. ``args.log_file`` is the path
            of the JSONL log to read. ``args.output_file`` is the path the
            base64 text is written to; when it is falsy the text is printed
            to stdout without a trailing newline.

    Raises:
        FileNotFoundError: If ``args.log_file`` does not exist.
        RuntimeError: If ``args.log_file`` exists but is not a regular file.
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message is prefixed with the log path and 1-based line number.
    """
    log_file = Path(args.log_file)
    if not log_file.exists():
        raise FileNotFoundError(f'log file "{args.log_file}" does not exist')

    if not log_file.is_file():
        raise RuntimeError(f'log file cannot be a directory: "{log_file}"')

    loss_data = []
    with open(log_file, "r", encoding="utf-8") as infile:
        for lineno, line in enumerate(infile, start=1):
            # Logs often end with (or contain) empty lines; these are not
            # records and must not abort the whole run.
            if not line.strip():
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError as err:
                # Same exception type as before, but pointing at the
                # offending line of the log.
                raise json.JSONDecodeError(
                    f"{log_file}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            # Only JSON objects can carry a "total_loss" entry.
            if isinstance(item, dict) and "total_loss" in item:
                loss_data.append(item["total_loss"])

    # create the plot
    fig = plt.figure()
    try:
        plt.plot(loss_data)
        plt.xlabel("Steps")
        plt.ylabel("Loss")
        plt.title("Training performance over fixed dataset")

        buf = BytesIO()
        plt.savefig(buf, format="png")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

    imgb64 = b64encode(buf.getvalue()).decode("utf-8")

    if args.output_file:
        Path(args.output_file).write_text(imgb64, encoding="utf-8")
    else:
        # just print the file without including a newline, this way it can be piped
        print(imgb64, end="")
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and hand them to main().
    parser = ArgumentParser(
        description="Render the total_loss values of a JSONL log as a base64-encoded PNG."
    )
    # main() builds a Path from --log-file right away, so make argparse reject
    # a missing value with a usage error instead of letting it through as None.
    parser.add_argument(
        "--log-file",
        type=str,
        required=True,
        help="path of the JSONL log to read",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default=None,
        help="file to write the base64 text to (default: print to stdout)",
    )
    args = parser.parse_args()
    main(args)