diff --git a/.github/workflows/print-loss.yml b/.github/workflows/print-loss.yml
new file mode 100644
index 00000000..55737432
--- /dev/null
+++ b/.github/workflows/print-loss.yml
@@ -0,0 +1,53 @@
+name: Print Loss
+
+on:
+  pull_request:
+    types: [opened, reopened]
+    branches:
+      - main
+
+# Least-privilege token: checkout needs to read the repository and the final
+# step needs write access to pull requests in order to post its comment.
+permissions:
+  contents: read
+  pull-requests: write
+
+jobs:
+  print-loss:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.x'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements-dev.txt
+
+      - name: Generate test data
+        run: |
+          printf '{"total_loss": 2.0}\n{"total_loss": 1.8023}\n{"total_loss": 1.52324}\n{"total_loss": 1.3234}' > test-log.jsonl
+          ls -al
+
+      - name: Print loss
+        run: |
+          python scripts/create-loss-graph.py --log-file test-log.jsonl
+
+      - name: Print a comment
+        uses: actions/github-script@v7
+        with:
+          github-token: ${{secrets.GITHUB_TOKEN}}
+          script: |
+            // Await the request so a failed API call fails this step.
+            await github.rest.issues.createComment({
+              issue_number: context.issue.number,
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              body: 'Thank you for your contribution! Please make sure to review our contribution guidelines.'
+            })
diff --git a/requirements-dev.txt b/requirements-dev.txt
index a0dff1ed..26c9141a 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -11,3 +11,7 @@
 ipython
 ipykernel
 jupyter
+
+# for printing out the loss
+matplotlib
+numpy
diff --git a/scripts/create-loss-graph.py b/scripts/create-loss-graph.py
new file mode 100644
index 00000000..f4cc6b1e
--- /dev/null
+++ b/scripts/create-loss-graph.py
@@ -0,0 +1,80 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Render the ``total_loss`` values of a JSONL training log as a base64 PNG."""
+
+# Standard
+from argparse import ArgumentParser, Namespace
+from base64 import b64encode
+from io import BytesIO
+from pathlib import Path
+import json
+
+# Third Party
+from matplotlib import pyplot as plt
+
+
+def main(args: Namespace) -> None:
+    """Plot the loss values from ``args.log_file`` and emit a base64 PNG.
+
+    Args:
+        args: Parsed CLI arguments. ``log_file`` is the path of a JSONL file;
+            every record that has a ``total_loss`` key contributes one point.
+            ``output_file`` optionally names a file that receives the
+            base64-encoded PNG; when unset the text is printed to stdout.
+
+    Raises:
+        FileNotFoundError: If ``args.log_file`` does not exist.
+        RuntimeError: If ``args.log_file`` exists but is not a regular file.
+    """
+    log_file = Path(args.log_file)
+    if not log_file.exists():
+        raise FileNotFoundError(f'log file "{args.log_file}" does not exist')
+
+    if not log_file.is_file():
+        raise RuntimeError(f'log file cannot be a directory: "{log_file}"')
+
+    # Blank lines (such as a trailing empty line) are skipped so they never
+    # reach json.loads, which would reject them and abort the whole run.
+    with log_file.open("r", encoding="utf-8") as infile:
+        contents = [json.loads(line) for line in infile if line.strip()]
+
+    loss_data = [item["total_loss"] for item in contents if "total_loss" in item]
+
+    # create the plot
+    fig = plt.figure()
+    plt.plot(loss_data)
+    plt.xlabel("Steps")
+    plt.ylabel("Loss")
+    plt.title("Training performance over fixed dataset")
+
+    buf = BytesIO()
+    plt.savefig(buf, format="png")
+    # Close the figure so pyplot does not keep it alive after rendering.
+    plt.close(fig)
+
+    imgb64 = b64encode(buf.getvalue()).decode("utf-8")
+
+    if args.output_file:
+        Path(args.output_file).write_text(imgb64, encoding="utf-8")
+    else:
+        # just print the file without including a newline, this way it can be piped
+        print(imgb64, end="")
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser(
+        description="Plot total_loss from a JSONL log as a base64-encoded PNG."
+    )
+    parser.add_argument(
+        "--log-file",
+        type=str,
+        required=True,
+        help="Path to the JSONL log file to read.",
+    )
+    parser.add_argument(
+        "--output-file",
+        type=str,
+        default=None,
+        help="Write the base64 PNG here instead of printing it to stdout.",
+    )
+    args = parser.parse_args()
+    main(args)