Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bootloader #9

Merged
merged 7 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# python virtual environment
venv/

# Generated by Cargo
# will have compiled files and executables
debug/
Expand Down
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,14 @@ libp2p = { version = "0.53.2", features = [
] }
libsecp256k1 = "0.7.1"
num-bigint = "0.4.4"
proptest = "1.4.0"
proptest-derive = "0.4.0"
rand = "0.8.5"
serde = "1.0.197"
serde_json = "1.0.115"
starknet = "0.9.0"
serde_with = "3.7.0"
starknet = "0.10.0"
starknet-crypto = "0.6.2"
strum = { version = "0.26", features = ["derive"] }
tempfile = "3.10.1"
thiserror = "1.0.58"
Expand Down
32 changes: 25 additions & 7 deletions cairo/bootloader/hash_program.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
import argparse
import json

from starkware.cairo.common.hash_chain import compute_hash_chain
from starkware.cairo.lang.compiler.program import Program, ProgramBase
from starkware.cairo.lang.version import __version__
from starkware.cairo.lang.vm.crypto import get_crypto_lib_context_manager, poseidon_hash_many
from starkware.cairo.lang.vm.crypto import (
get_crypto_lib_context_manager,
poseidon_hash_many,
)
from starkware.python.utils import from_bytes


def compute_program_hash_chain(program: ProgramBase, use_poseidon: bool, bootloader_version=0):
def compute_program_hash_chain(
program: ProgramBase, use_poseidon: bool, bootloader_version=0
):
"""
Computes a hash chain over a program, including the length of the data chain.
"""
builtin_list = [from_bytes(builtin.encode("ascii")) for builtin in program.builtins]
# The program header below is missing the data length, which is later added to the data_chain.
program_header = [bootloader_version, program.main, len(program.builtins)] + builtin_list
program_header = [
bootloader_version,
program.main,
len(program.builtins),
] + builtin_list
data_chain = program_header + program.data

if use_poseidon:
Expand All @@ -23,8 +31,12 @@ def compute_program_hash_chain(program: ProgramBase, use_poseidon: bool, bootloa


def main():
parser = argparse.ArgumentParser(description="A tool to compute the hash of a cairo program")
parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}")
parser = argparse.ArgumentParser(
description="A tool to compute the hash of a cairo program"
)
parser.add_argument(
"-v", "--version", action="version", version=f"%(prog)s {__version__}"
)
parser.add_argument(
"--program",
type=argparse.FileType("r"),
Expand All @@ -48,7 +60,13 @@ def main():

with get_crypto_lib_context_manager(args.flavor):
program = Program.Schema().load(json.load(args.program))
print(hex(compute_program_hash_chain(program=program, use_poseidon=args.use_poseidon)))
print(
hex(
compute_program_hash_chain(
program=program, use_poseidon=args.use_poseidon
)
)
)


if __name__ == "__main__":
Expand Down
58 changes: 41 additions & 17 deletions cairo/bootloader/objects.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,68 @@
import dataclasses
from abc import abstractmethod
from dataclasses import field
from typing import ClassVar, Dict, List, Optional, Type

import marshmallow
import marshmallow.fields as mfields
from typing import List, Optional
import marshmallow_dataclass
from marshmallow_oneofschema import OneOfSchema

from starkware.cairo.lang.compiler.program import Program, ProgramBase, StrippedProgram
from starkware.cairo.lang.compiler.program import ProgramBase, StrippedProgram
from starkware.cairo.lang.vm.cairo_pie import CairoPie
from starkware.starkware_utils.marshmallow_dataclass_fields import additional_metadata
from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass


class TaskSpec(ValidatedMarshmallowDataclass):
@abstractmethod
def load_task(self) -> "Task":
"""
Returns the corresponding task.
"""


class Task:
@abstractmethod
def get_program(self) -> ProgramBase:
"""
Returns the task's Cairo program.
"""


@dataclasses.dataclass(frozen=True)
class Job(Task):
reward: int
num_of_steps: int
class CairoPieTask(Task):
cairo_pie: CairoPie
registry_address: bytearray
public_key: bytearray
signature: bytearray
use_poseidon: bool

def get_program(self) -> StrippedProgram:
return self.cairo_pie.program


@dataclasses.dataclass(frozen=True)
class JobData(Task):
reward: int
num_of_steps: int
cairo_pie_compressed: List[int]
registry_address: str

def load_task(self) -> "CairoPieTask":
return CairoPieTask(
cairo_pie=CairoPie.deserialize(bytes(self.cairo_pie_compressed)),
use_poseidon=True,
)


@dataclasses.dataclass(frozen=True)
class Job(Task):
job_data: JobData
public_key: List[int]
signature: List[int]

def load_task(self) -> "CairoPieTask":
return self.job_data.load_task()


@marshmallow_dataclass.dataclass(frozen=True)
class SimpleBootloaderInput(ValidatedMarshmallowDataclass):
identity: bytearray
identity: str
job: Job

fact_topologies_path: Optional[str]

# If true, the bootloader will put all the outputs in a single page, ignoring the
# tasks' fact topologies.
single_page: bool
single_page: bool
2 changes: 1 addition & 1 deletion cairo/bootloader/recursive_with_poseidon/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
BITWISE_BUILTIN,
POSEIDON_BUILTIN,
]
)
)
16 changes: 2 additions & 14 deletions cairo/bootloader/recursive_with_poseidon/execute_task.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -144,23 +144,17 @@ func execute_task{builtin_ptrs: BuiltinData*, self_range_check_ptr}(
%{
from bootloader.objects import (
CairoPieTask,
RunProgramTask,
Task,
)
from bootloader.utils import (
load_cairo_pie,
prepare_output_runner,
)

assert isinstance(task, Task)
n_builtins = len(task.get_program().builtins)
new_task_locals = {}
if isinstance(task, RunProgramTask):
new_task_locals['program_input'] = task.program_input
new_task_locals['WITH_BOOTLOADER'] = True

vm_load_program(task.program, program_address)
elif isinstance(task, CairoPieTask):

if isinstance(task, CairoPieTask):
ret_pc = ids.ret_pc_label.instruction_offset_ - ids.call_task.instruction_offset_ + pc
load_cairo_pie(
task=task.cairo_pie, memory=memory, segments=segments,
Expand All @@ -169,10 +163,6 @@ func execute_task{builtin_ptrs: BuiltinData*, self_range_check_ptr}(
else:
raise NotImplementedError(f'Unexpected task type: {type(task).__name__}.')

output_runner_data = prepare_output_runner(
task=task,
output_builtin=output_builtin,
output_ptr=ids.pre_execution_builtin_ptrs.output)
vm_enter_scope(new_task_locals)
%}

Expand Down Expand Up @@ -243,8 +233,6 @@ func execute_task{builtin_ptrs: BuiltinData*, self_range_check_ptr}(
fact_topologies.append(get_task_fact_topology(
output_size=output_end - output_start,
task=task,
output_builtin=output_builtin,
output_runner_data=output_runner_data,
))
%}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func run_simple_bootloader{
local task_range_check_ptr;

%{
n_tasks = len(simple_bootloader_input.tasks)
n_tasks = 1
memory[ids.output_ptr] = n_tasks

# Task range checks are located right after simple bootloader validation range checks, and
Expand Down Expand Up @@ -65,7 +65,6 @@ func run_simple_bootloader{
// Call execute_tasks.
let (__fp__, _) = get_fp_and_pc();

%{ tasks = simple_bootloader_input.tasks %}
let builtin_ptrs = &builtin_ptrs_before;
let self_range_check_ptr = range_check_ptr;
with builtin_ptrs, self_range_check_ptr {
Expand Down Expand Up @@ -141,8 +140,7 @@ func execute_tasks{builtin_ptrs: BuiltinData*, self_range_check_ptr}(
from bootloader.objects import Task

# Pass current task to execute_task.
task_id = len(simple_bootloader_input.tasks) - ids.n_tasks
task = simple_bootloader_input.tasks[task_id].load_task()
task = simple_bootloader_input.job.load_task()
%}
tempvar use_poseidon = nondet %{ 1 if task.use_poseidon else 0 %};
// Call execute_task to execute the current task.
Expand Down
Loading
Loading