Skip to content

Commit

Permalink
update script to use s3
Browse files Browse the repository at this point in the history
  • Loading branch information
RobotSail committed Oct 25, 2024
1 parent accaaae commit 8038986
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 44 deletions.
29 changes: 26 additions & 3 deletions .github/workflows/print-loss.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,37 @@ jobs:

- name: Install dependencies
run: |
python -m pip install --upgrade pip
# python -m pip install --upgrade pip
pip install -r requirements-dev.txt
- name: Generate test data
run: |
printf '{"total_loss": 2.0}\n{"total_loss": 1.8023}\n{"total_loss": 1.52324}\n{"total_loss": 1.3234}' > test-log.jsonl
ls -al
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ vars.AWS_REGION }}
- name: Print loss
run: |
python scripts/create-loss-graph.py markdown --log-file test-log.jsonl --output-file 'results.md'
cat 'results.md' >> "${GITHUB_STEP_SUMMARY}"
python scripts/create-loss-graph.py image --log-file test-log.jsonl --output-file 'output.png'
- name: Try to upload to s3
run: |
echo "$(which aws)"
bucket_name='os-ci-loss-curve-test'
printf 'curious what this has: %s\n' "${{ github.event.pull_request.commits }}"
printf 'curious what this has: %s\n' "${{ github.event.pull_request.id }}"
printf 'curious what this has: %s\n' "${{ github.event.pull_request.head.sha }}"
base_path="pulls/${{ github.event.pull_request.number }}/curve.png"
s3_path="${bucket_name}/${base_path}"
aws s3 cp output.png "s3://${s3_path}"
printf 'Test: us-east-2\n'
url="https://${bucket_name}.s3.${{vars.AWS_REGION }}.amazonaws.com/${base_path}"
printf 'url is: %s\n' "${url}"
printf '### Test performance:\n\n![Loss curve](%s)\n' "${url}" >> "${GITHUB_STEP_SUMMARY}"
3 changes: 2 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

-r requirements.txt
# blot this out
# -r requirements.txt

pre-commit>=3.0.4,<5.0
pylint>=2.16.2,<4.0
Expand Down
110 changes: 70 additions & 40 deletions scripts/create-loss-graph.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
from argparse import ArgumentParser, Namespace
from base64 import b64encode
from io import BytesIO
from argparse import ArgumentParser
from base64 import b64decode
import os
from pathlib import Path
import json
from typing import List

# Third Party
from matplotlib import pyplot as plt

ENV_VAR_NAME = "LOSS_DATA"

def create_b64_data(log_file: Path) -> str:

def render_image(loss_data: List[float], outfile: Path) -> str:
log_file = Path(args.log_file)
if not log_file.exists():
raise FileNotFoundError(f'log file "{args.log_file}" does not exist')
Expand All @@ -30,56 +33,83 @@ def create_b64_data(log_file: Path) -> str:
plt.ylabel("Loss")
plt.title("Training performance over fixed dataset")

buf = BytesIO()
plt.savefig(buf, format="png")
buf.seek(0)

imgb64 = b64encode(buf.read()).decode("utf-8")
return imgb64


def create_md_file(b64_data: str, output_file: Path | None):
content = f"""## Training Performance\n
plt.savefig(outfile, format="png")


def test_visibility():
    """Debug probe: dump the process environment and try a few ways of
    reading the AWS region, to see what a GitHub runner exposes to Python."""
    # Show every environment variable name visible to this process.
    print(os.environ.keys())

    separator = "===================================================================================================="
    print(separator)
    # NOTE(review): this is a literal string, not a substituted workflow
    # expression — os.getenv is expected to report "Not found" here.
    name = "${ vars.AWS_REGION }"
    print(f"Environment variable name: {name}")
    print(f'Environment variable value: {os.getenv(name, "Not found")}')
    print(separator)

    import subprocess

    # Ask a shell to expand the expression; ${{ ... }} is not valid shell
    # syntax, so stdout is presumably empty — confirm in the job log.
    proc = subprocess.run(
        ["printf '%s' ${{ vars.AWS_REGION }}"], shell=True, capture_output=True
    )
    print(f"Result: {proc.stdout.decode()}")


def read_loss_data(src: str, log_file: Path | None = None) -> List[float]:
match src:
case "env":
data = os.getenv(ENV_VAR_NAME, None)
if not data:
raise ValueError(f"Environment variable {ENV_VAR_NAME} not set")
# decode the base64 data
data = b64decode(data)
case "file":
if not log_file:
raise ValueError("log_file must be provided when source is file")
if not log_file.exists():
raise FileNotFoundError(f"Log file {log_file} does not exist")
if log_file.is_dir():
raise ValueError(f"Log file {log_file} is a directory")
with open(log_file, "r") as f:
data = f.read()
case _:
raise ValueError(f"Unknown source: {src}")

![Training Performance](data:image/png;base64,{b64_data})
"""
if not output_file:
print(content)
else:
output_file.write_text(content, encoding="utf-8")
# select the loss data
contents = [json.loads(l) for l in data.splitlines()]
loss_data = [item["total_loss"] for item in contents if "total_loss" in item]

if not loss_data:
raise ValueError("Loss data is empty")

def main(args: Namespace):
imgb64 = create_b64_data(args.log_file)
# ensure that the loss data is valid
if not all(isinstance(l, float) for l in loss_data):
raise ValueError("Loss data must be a list of floats")

output_file = Path(args.output_file) if args.output_file else None
if output_file:
output_file.write_text(imgb64, encoding="utf-8")
else:
# just print the file without including a newline, this way it can be piped
print(imgb64, end="")
return loss_data


if __name__ == "__main__":
parser = ArgumentParser()
subparsers = parser.add_subparsers(dest="command", required=True)

image_parser = subparsers.add_parser("image")
image_parser.add_argument("--log-file", type=str, required=True)
image_parser.add_argument("--output-file", type=str, default=None)

markdown_parser = subparsers.add_parser("markdown")
markdown_parser.add_argument("--log-file", type=str, required=True)
markdown_parser.add_argument("--output-file", type=str, default=None)
image_parser.add_argument(
"--source",
choices=["file", "env"],
default="file",
help="Source of the log file to read the loss data from. If file is selected, then we will read from the given file. If env is selected, we will read from the LOSS_DATA environment variable. If writing to env, the result should be a base64-encoded JSONL file.",
)
image_parser.add_argument("--log-file", type=str, default=None)
image_parser.add_argument("--output-file", type=str, required=True)

args = parser.parse_args()
match args.command:
case "image":
print("creating image")
main(args)
case "markdown":
print("creating md file")
b64_data = create_b64_data(log_file=Path(args.log_file))
create_md_file(b64_data=b64_data, output_file=Path(args.output_file))
test_visibility()
loss_data = read_loss_data(src=args.source, log_file=Path(args.log_file))
render_image(loss_data=loss_data, outfile=Path(args.output_file))
case _:
raise ValueError(f"Unknown command: {args.command}")

0 comments on commit 8038986

Please sign in to comment.