Skip to content

Commit

Permalink
refactor!: Start VM with initialized Sponge state
Browse files Browse the repository at this point in the history
Previously, it was necessary to initialize Triton VM's Sponge state
manually before any other Sponge instruction could be used. Now,
instructions `sponge_absorb`, `sponge_absorb_mem`, and `sponge_squeeze`
can be executed without first executing instruction `sponge_init`.
  • Loading branch information
jan-ferdinand committed May 29, 2024
1 parent 72238cf commit 09521e3
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 93 deletions.
8 changes: 4 additions & 4 deletions specification/src/arithmetization-overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ Before automatic degree lowering:
| [OpStackTable](operational-stack-table.md) | 3 | 0 | 5 | 0 | 4 |
| [RamTable](random-access-memory-table.md) | 7 | 0 | 12 | 1 | 5 |
| [JumpStackTable](jump-stack-table.md) | 6 | 0 | 6 | 0 | 4 |
| [HashTable](hash-table.md) | 22 | 45 | 47 | 2 | 9 |
| [HashTable](hash-table.md) | 22 | 45 | 48 | 2 | 10 |
| [CascadeTable](cascade-table.md) | 2 | 1 | 3 | 0 | 4 |
| [LookupTable](lookup-table.md) | 3 | 1 | 4 | 1 | 3 |
| [U32Table](u32-table.md) | 1 | 15 | 22 | 2 | 12 |
| [Grand Cross-Table Argument](table-linking.md) | 0 | 0 | 0 | 14 | 1 |
| **TOTAL** | **79** | **76** | **178** | **23** | **19** |
| **TOTAL** | **79** | **76** | **179** | **23** | **19** |

After automatically lowering degree to 4:

Expand All @@ -54,10 +54,10 @@ After automatically lowering degree to 4:
| [OpStackTable](operational-stack-table.md) | 3 | 0 | 5 | 0 |
| [RamTable](random-access-memory-table.md) | 7 | 0 | 13 | 1 |
| [JumpStackTable](jump-stack-table.md) | 6 | 0 | 6 | 0 |
| [HashTable](hash-table.md) | 22 | 52 | 84 | 2 |
| [HashTable](hash-table.md) | 22 | 52 | 85 | 2 |
| [CascadeTable](cascade-table.md) | 2 | 1 | 3 | 0 |
| [LookupTable](lookup-table.md) | 3 | 1 | 4 | 1 |
| [U32Table](u32-table.md) | 1 | 26 | 34 | 2 |
| [Grand Cross-Table Argument](table-linking.md) | 0 | 0 | 0 | 14 |
| **TOTAL** | **158** | **152** | **356** | **46** |
| **TOTAL** | **158** | **152** | **358** | **46** |
<!-- auto-gen info stop -->
3 changes: 0 additions & 3 deletions triton-vm/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,6 @@ pub enum InstructionError {
#[error("division by 0 is impossible")]
DivisionByZero,

#[error("the Sponge state must be initialized before it can be used")]
SpongeNotInitialized,

#[error("the logarithm of 0 does not exist")]
LogarithmOfZero,

Expand Down
98 changes: 61 additions & 37 deletions triton-vm/src/table/hash_table.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use itertools::izip;
use itertools::Itertools;
use ndarray::*;
use num_traits::Zero;
Expand Down Expand Up @@ -435,14 +436,15 @@ impl ExtHashTable {
let if_ci_is_sponge_init_then_round_number_is_0 =
if_ci_is_sponge_init_then_.clone() * round_number.clone();

let if_ci_is_sponge_init_then_rate_is_0 = (10..=15).map(|state_index| {
// the rate being zero is guaranteed by the evaluation argument
let if_ci_is_sponge_init_then_capacity_is_0 = (RATE..STATE_SIZE).map(|state_index| {
let state_element = base_row(Self::state_column_by_index(state_index));
if_ci_is_sponge_init_then_.clone() * state_element
});

let if_mode_is_hash_and_round_no_is_0_then_ = round_number_is_not_0 * mode_is_not_hash;
let if_mode_is_hash_and_round_no_is_0_then_states_10_through_15_are_1 =
(10..=15).map(|state_index| {
let if_mode_is_hash_and_round_no_is_0_then_states_10_through_15_are_1 = (RATE..STATE_SIZE)
.map(|state_index| {
let state_element = base_row(Self::state_column_by_index(state_index));
if_mode_is_hash_and_round_no_is_0_then_.clone() * (state_element - constant(1))
});
Expand Down Expand Up @@ -538,7 +540,7 @@ impl ExtHashTable {
if_state_3_hi_limbs_are_all_1_then_state_3_lo_limbs_are_all_0,
];

constraints.extend(if_ci_is_sponge_init_then_rate_is_0);
constraints.extend(if_ci_is_sponge_init_then_capacity_is_0);
constraints.extend(if_mode_is_hash_and_round_no_is_0_then_states_10_through_15_are_1);

for round_constant_column_idx in 0..NUM_ROUND_CONSTANTS {
Expand Down Expand Up @@ -711,6 +713,14 @@ impl ExtHashTable {
]
.map(challenge);

assert_eq!(STATE_SIZE, state_current.len());
assert_eq!(STATE_SIZE, state_next.len());
assert_eq!(STATE_SIZE, state_weights.len());
assert_eq!(
STATE_SIZE,
hash_function_round_correctly_performs_update.len()
);

let round_number_is_not_num_rounds =
Self::round_number_deselector(circuit_builder, &round_number, NUM_ROUNDS);

Expand All @@ -735,10 +745,16 @@ impl ExtHashTable {
* Self::select_mode(circuit_builder, &mode_next, HashTableMode::ProgramHashing)
* (compressed_digest - expected_program_digest);

let if_mode_is_program_hashing_and_next_mode_is_sponge_then_ci_next_is_sponge_init =
// correct transition of the rate is handled by the evaluation argument
let randomized_sum_of_capacity_next = state_next[RATE..]
.iter()
.zip_eq(&state_weights[RATE..])
.map(|(state_next, weight)| weight.clone() * state_next.clone())
.sum();
let if_mode_is_program_hashing_and_next_mode_is_sponge_then_capacity_resets =
Self::mode_deselector(circuit_builder, &mode, HashTableMode::ProgramHashing)
* Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::Sponge)
* (ci_next.clone() - opcode_sponge_init.clone());
* randomized_sum_of_capacity_next;

let if_round_number_is_not_max_and_ci_is_not_sponge_init_then_ci_doesnt_change =
(round_number.clone() - constant(NUM_ROUNDS as u64))
Expand All @@ -765,36 +781,43 @@ impl ExtHashTable {
Self::mode_deselector(circuit_builder, &mode, HashTableMode::Pad)
* Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad);

let difference_of_capacity_registers = state_current[RATE..]
.iter()
.zip_eq(state_next[RATE..].iter())
.map(|(current, next)| next.clone() - current.clone())
.collect_vec();
let randomized_sum_of_capacity_differences = state_weights[RATE..]
.iter()
.zip_eq(difference_of_capacity_registers)
.map(|(weight, state_difference)| weight.clone() * state_difference)
.sum::<ConstraintCircuitMonad<_>>();

let capacity_doesnt_change_at_section_start_when_program_hashing_or_absorbing =
let randomized_sum_of_capacity_differences = izip!(
&state_weights[RATE..],
&state_current[RATE..],
&state_next[RATE..]
)
.map(|(weight, current, next)| weight.clone() * (next.clone() - current.clone()))
.sum::<ConstraintCircuitMonad<_>>();

// If the next round number is 0, then
// - the next mode is either `Hash` or `Pad`, or
// - the next instruction is `sponge_init`, or
// - the capacity does not change, or
// - the current mode is `ProgramHashing` and (!) the next mode is `sponge`.
//
// The “and” marked with an exclamation mark requires above constraint to be
// expressed through two polynomials.
let capacity_doesnt_change_at_section_start_when_program_hashing_or_ =
Self::round_number_deselector(circuit_builder, &round_number_next, 0)
* Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash)
* Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad)
* (ci_next.clone() - opcode_sponge_init.clone())
* randomized_sum_of_capacity_differences.clone();

let difference_of_state_registers = state_current
.iter()
.zip_eq(state_next.iter())
.map(|(current, next)| next.clone() - current.clone())
.collect_vec();
let randomized_sum_of_state_differences = state_weights
.iter()
.zip_eq(difference_of_state_registers.iter())
.map(|(weight, state_difference)| weight.clone() * state_difference.clone())
.sum();
let if_round_number_next_is_0_and_ci_next_is_squeeze_then_state_doesnt_change =
let capacity_doesnt_change_at_section_start_when_program_hashing_or_absorbing_0 =
capacity_doesnt_change_at_section_start_when_program_hashing_or_.clone()
* Self::select_mode(circuit_builder, &mode, HashTableMode::ProgramHashing);
let capacity_doesnt_change_at_section_start_when_program_hashing_or_absorbing_1 =
capacity_doesnt_change_at_section_start_when_program_hashing_or_
* Self::select_mode(circuit_builder, &mode_next, HashTableMode::Sponge);

let randomized_sum_of_state_differences =
izip!(&state_weights, &state_current, &state_next)
.map(|(weight, curr, next)| weight.clone() * (next.clone() - curr.clone()))
.sum();
let if_next_round_is_0_and_ci_next_is_squeeze_in_sponge_section_then_state_doesnt_change =
Self::round_number_deselector(circuit_builder, &round_number_next, 0)
* Self::select_mode(circuit_builder, &mode, HashTableMode::ProgramHashing)
* Self::instruction_deselector(circuit_builder, &ci_next, SpongeSqueeze)
* randomized_sum_of_state_differences;

Expand All @@ -807,7 +830,7 @@ impl ExtHashTable {
let tip5_input = state_next[..RATE].to_owned();
let compressed_row_from_processor = tip5_input
.into_iter()
.zip_eq(state_weights[..RATE].iter())
.zip_eq(&state_weights[..RATE])
.map(|(state, weight)| weight.clone() * state)
.sum();

Expand All @@ -831,7 +854,7 @@ impl ExtHashTable {
let hash_digest = state_next[..DIGEST_LENGTH].to_owned();
let compressed_row_hash_digest = hash_digest
.into_iter()
.zip_eq(state_weights[..DIGEST_LENGTH].iter())
.zip_eq(&state_weights[..DIGEST_LENGTH])
.map(|(state, weight)| weight.clone() * state)
.sum();
let running_evaluation_hash_digest_updates = running_evaluation_hash_digest_next
Expand All @@ -848,7 +871,7 @@ impl ExtHashTable {
// The running evaluation for “Sponge” updates correctly.
let compressed_row_next = state_weights[..RATE]
.iter()
.zip_eq(state_next[..RATE].iter())
.zip_eq(&state_next[..RATE])
.map(|(weight, st_next)| weight.clone() * st_next.clone())
.sum();
let running_evaluation_sponge_has_accumulated_ci = running_evaluation_sponge_next.clone()
Expand Down Expand Up @@ -900,14 +923,15 @@ impl ExtHashTable {
next_mode_is_padding_mode_or_round_number_is_num_rounds_or_increments_by_one,
receive_chunk_of_instructions_iff_next_mode_is_prog_hashing_and_next_round_number_is_0,
if_mode_changes_from_program_hashing_then_current_digest_is_expected_program_digest,
if_mode_is_program_hashing_and_next_mode_is_sponge_then_ci_next_is_sponge_init,
if_mode_is_program_hashing_and_next_mode_is_sponge_then_capacity_resets,
if_round_number_is_not_max_and_ci_is_not_sponge_init_then_ci_doesnt_change,
if_round_number_is_not_max_and_ci_is_not_sponge_init_then_mode_doesnt_change,
if_mode_is_sponge_then_mode_next_is_sponge_or_hash_or_pad,
if_mode_is_hash_then_mode_next_is_hash_or_pad,
if_mode_is_pad_then_mode_next_is_pad,
capacity_doesnt_change_at_section_start_when_program_hashing_or_absorbing,
if_round_number_next_is_0_and_ci_next_is_squeeze_then_state_doesnt_change,
capacity_doesnt_change_at_section_start_when_program_hashing_or_absorbing_0,
capacity_doesnt_change_at_section_start_when_program_hashing_or_absorbing_1,
if_next_round_is_0_and_ci_next_is_squeeze_in_sponge_section_then_state_doesnt_change,
running_evaluation_hash_input_is_updated_correctly,
running_evaluation_hash_digest_is_updated_correctly,
running_evaluation_sponge_is_updated_correctly,
Expand Down Expand Up @@ -1692,7 +1716,7 @@ impl HashTable {
let compressed_row = |row: ArrayView1<BFieldElement>| -> XFieldElement {
rate_registers(row)
.iter()
.zip_eq(state_weights.iter())
.zip_eq(state_weights)
.map(|(&state, &weight)| weight * state)
.sum()
};
Expand Down Expand Up @@ -1758,7 +1782,7 @@ impl HashTable {
if in_hash_mode && in_last_round {
let compressed_digest: XFieldElement = rate_registers(row)[..DIGEST_LENGTH]
.iter()
.zip_eq(state_weights[..DIGEST_LENGTH].iter())
.zip_eq(&state_weights[..DIGEST_LENGTH])
.map(|(&state, &weight)| weight * state)
.sum();
hash_digest_running_evaluation = hash_digest_running_evaluation
Expand Down
Loading

0 comments on commit 09521e3

Please sign in to comment.