diff --git a/.github/workflows/print-loss.yml b/.github/workflows/print-loss.yml index f517967e..74514f44 100644 --- a/.github/workflows/print-loss.yml +++ b/.github/workflows/print-loss.yml @@ -31,15 +31,5 @@ jobs: ls -al - name: Print loss run: | - python scripts/create-loss-graph.py --log-file test-log.jsonl - - name: Print a comment - uses: actions/github-script@v7 - with: - github-token: ${{secrets.GITHUB_TOKEN}} - script: | - github.rest.issues.createComment({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - body: 'Thank you for your contribution! Please make sure to review our contribution guidelines.' - }) \ No newline at end of file + python scripts/create-loss-graph.py markdown --log-file test-log.jsonl --output-file 'results.md' + cat 'results.md' >> "${GITHUB_STEP_SUMMARY}" diff --git a/README.md b/README.md index 645583f3..3d2de24c 100644 --- a/README.md +++ b/README.md @@ -332,3 +332,5 @@ run_training( train_args=training_args, ) ``` + +Test update \ No newline at end of file diff --git a/scripts/create-loss-graph.py b/scripts/create-loss-graph.py index f4cc6b1e..cf173902 100644 --- a/scripts/create-loss-graph.py +++ b/scripts/create-loss-graph.py @@ -10,7 +10,7 @@ from matplotlib import pyplot as plt -def main(args: Namespace): +def create_b64_data(log_file: Path) -> str: log_file = Path(args.log_file) if not log_file.exists(): raise FileNotFoundError(f'log file "{args.log_file}" does not exist') @@ -35,6 +35,22 @@ def main(args: Namespace): buf.seek(0) imgb64 = b64encode(buf.read()).decode("utf-8") + return imgb64 + + +def create_md_file(b64_data: str, output_file: Path | None): + content = f"""## Training Performance\n + +![Training Performance](data:image/png;base64,{b64_data}) +""" + if not output_file: + print(content) + else: + output_file.write_text(content, encoding="utf-8") + + +def main(args: Namespace): + imgb64 = create_b64_data(args.log_file) output_file = Path(args.output_file) if args.output_file else None if output_file: @@ -46,7 +62,24 @@ def main(args: Namespace): if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("--log-file", type=str) - parser.add_argument("--output-file", type=str, default=None) + subparsers = parser.add_subparsers(dest="command", required=True) + + image_parser = subparsers.add_parser("image") + image_parser.add_argument("--log-file", type=str, required=True) + image_parser.add_argument("--output-file", type=str, default=None) + + markdown_parser = subparsers.add_parser("markdown") + markdown_parser.add_argument("--log-file", type=str, required=True) + markdown_parser.add_argument("--output-file", type=str, default=None) + args = parser.parse_args() - main(args) + match args.command: + case "image": + print("creating image") + main(args) + case "markdown": + print("creating md file") + b64_data = create_b64_data(log_file=Path(args.log_file)) + create_md_file(b64_data=b64_data, output_file=Path(args.output_file)) + case _: + raise ValueError(f"Unknown command: {args.command}")