Skip to content

Commit

Permalink
Download srs kzg commitments on validator init (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
HudsonGraeme authored Feb 10, 2025
1 parent 29c6880 commit 51acea2
Showing 1 changed file with 28 additions and 30 deletions.
58 changes: 28 additions & 30 deletions neurons/utils/pre_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import cli_parser
from constants import IGNORED_MODEL_HASHES
from execution_layer.circuit import ProofSystem

from functools import partial
from collections import OrderedDict
Expand Down Expand Up @@ -56,12 +55,14 @@ def run_shared_preflight_checks(role: Optional[Roles] = None):
Exception: If any of the pre-flight checks fail.
"""

preflight_checks = OrderedDict({
"Syncing model files": partial(sync_model_files, role=role),
"Ensuring Node.js version": ensure_nodejs_version,
"Checking SnarkJS installation": ensure_snarkjs_installed,
"Checking EZKL installation": ensure_ezkl_installed,
})
preflight_checks = OrderedDict(
{
"Syncing model files": partial(sync_model_files, role=role),
"Ensuring Node.js version": ensure_nodejs_version,
"Checking SnarkJS installation": ensure_snarkjs_installed,
"Checking EZKL installation": ensure_ezkl_installed,
}
)

bt.logging.info(" PreFlight | Running pre-flight checks")

Expand Down Expand Up @@ -176,6 +177,26 @@ def sync_model_files(role: Optional[Roles] = None):
MODEL_DIR = os.path.join(os.path.dirname(__file__), "..", "deployment_layer")
SYNC_LOG_PREFIX = " SYNC | "

loop = asyncio.get_event_loop()
for logrows in range(1, 26):
if os.path.exists(
os.path.join(os.path.expanduser("~"), ".ezkl", "srs", f"kzg{logrows}.srs")
):
bt.logging.info(
f"{SYNC_LOG_PREFIX}SRS for logrows={logrows} already exists, skipping..."
)
continue

try:
loop.run_until_complete(download_srs(logrows))
bt.logging.info(
f"{SYNC_LOG_PREFIX}Successfully downloaded SRS for logrows={logrows}"
)
except Exception as e:
bt.logging.error(
f"{SYNC_LOG_PREFIX}Failed to download SRS for logrows={logrows}: {e}"
)

for model_hash in os.listdir(MODEL_DIR):
if not model_hash.startswith("model_"):
continue
Expand Down Expand Up @@ -203,29 +224,6 @@ def sync_model_files(role: Optional[Roles] = None):
SYNC_LOG_PREFIX + f"Failed to parse JSON from {metadata_file}"
)
continue
# If it's an EZKL model, we'll try to download the SRS files
if metadata.get("proof_system") == ProofSystem.EZKL:
ezkl_settings_file = os.path.join(MODEL_DIR, model_hash, "settings.json")
if not os.path.isfile(ezkl_settings_file):
bt.logging.error(
f"{SYNC_LOG_PREFIX}Settings file not found at {ezkl_settings_file} for {model_hash}. Skipping sync."
)
continue

try:
with open(ezkl_settings_file, "r", encoding="utf-8") as f:
logrows = json.load(f).get("run_args", {}).get("logrows")
if logrows:
loop = asyncio.get_event_loop()
loop.run_until_complete(download_srs(logrows))
bt.logging.info(
f"{SYNC_LOG_PREFIX}Successfully downloaded SRS for logrows={logrows}"
)
except (json.JSONDecodeError, subprocess.CalledProcessError) as e:
bt.logging.error(
f"{SYNC_LOG_PREFIX}Failed to process settings or download SRS: {e}"
)
continue

external_files = metadata.get("external_files", {})
for key, url in external_files.items():
Expand Down

0 comments on commit 51acea2

Please sign in to comment.