Skip to content

Commit

Permalink
Merge branch 'IrreducibleOSS:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
arthurpaulino authored Jan 31, 2025
2 parents 0cf7bb6 + f1d3e11 commit 2be3fe5
Show file tree
Hide file tree
Showing 152 changed files with 2,579 additions and 1,689 deletions.
5 changes: 4 additions & 1 deletion crates/circuits/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ version.workspace = true
edition.workspace = true
authors.workspace = true

[lints]
workspace = true

[dependencies]
binius_core = { path = "../core" }
binius_field = { path = "../field" }
Expand All @@ -24,4 +27,4 @@ bumpalo.workspace = true
[dev-dependencies]
binius_hal = { path = "../hal" }
groestl_crypto = { package = "groestl", version = "0.10.1" }
sha2 = { version = "0.10.8", features = ["compress"] }
sha2 = { version = "0.10.8", features = ["compress"] }
100 changes: 75 additions & 25 deletions crates/circuits/src/builder/constraint_system.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
// Copyright 2024-2025 Irreducible Inc.

use core::iter::IntoIterator;
use std::{cell::RefCell, rc::Rc};
use std::{cell::RefCell, collections::HashMap, rc::Rc};

use anyhow::anyhow;
use anyhow::{anyhow, ensure};
use binius_core::{
constraint_system::{
channel::{ChannelId, Flush, FlushDirection},
Expand All @@ -14,9 +13,12 @@ use binius_core::{
ProjectionVariant, ShiftVariant,
},
polynomial::MultivariatePoly,
transparent::step_down::StepDown,
witness::MultilinearExtensionIndex,
};
use binius_field::{as_packed_field::PackScalar, underlier::UnderlierType, TowerField};
use binius_field::{
as_packed_field::PackScalar, underlier::UnderlierType, BinaryField1b, TowerField,
};
use binius_math::ArithExpr;
use binius_utils::bail;

Expand All @@ -32,6 +34,7 @@ where
constraints: ConstraintSetBuilder<F>,
non_zero_oracle_ids: Vec<OracleId>,
flushes: Vec<Flush>,
step_down_dedup: HashMap<(usize, usize), OracleId>,
witness: Option<witness::Builder<'arena, U, F>>,
next_channel_id: ChannelId,
namespace_path: Vec<String>,
Expand Down Expand Up @@ -68,9 +71,9 @@ where
table_constraints,
non_zero_oracle_ids: self.non_zero_oracle_ids,
oracles: Rc::into_inner(self.oracles)
.ok_or(anyhow!(
"Failed to build ConstraintSystem: references still exist to oracles"
))?
.ok_or_else(|| {
anyhow!("Failed to build ConstraintSystem: references still exist to oracles")
})?
.into_inner(),
flushes: self.flushes,
})
Expand All @@ -95,49 +98,96 @@ where
direction: FlushDirection,
channel_id: ChannelId,
count: usize,
oracle_ids: impl IntoIterator<Item = OracleId>,
) {
self.flushes.push(Flush {
channel_id,
direction,
count,
oracles: oracle_ids.into_iter().collect(),
multiplicity: 1,
})
oracle_ids: impl IntoIterator<Item = OracleId> + Clone,
) -> anyhow::Result<()>
where
U: PackScalar<BinaryField1b>,
{
self.flush_with_multiplicity(direction, channel_id, count, oracle_ids, 1)
}

pub fn flush_with_multiplicity(
&mut self,
direction: FlushDirection,
channel_id: ChannelId,
count: usize,
oracle_ids: impl IntoIterator<Item = OracleId> + Clone,
multiplicity: u64,
) -> anyhow::Result<()>
where
U: PackScalar<BinaryField1b>,
{
let n_vars = self.log_rows(oracle_ids.clone())?;

let selector = if let Some(&selector) = self.step_down_dedup.get(&(n_vars, count)) {
selector
} else {
let step_down = StepDown::new(n_vars, count)?;
let selector = self.add_transparent(
format!("internal step_down {count}-{n_vars}"),
step_down.clone(),
)?;

if let Some(witness) = self.witness() {
step_down.populate(witness.new_column::<BinaryField1b>(selector).packed());
}

self.step_down_dedup.insert((n_vars, count), selector);
selector
};

self.flush_custom(direction, channel_id, selector, oracle_ids, multiplicity)
}

pub fn flush_custom(
&mut self,
direction: FlushDirection,
channel_id: ChannelId,
selector: OracleId,
oracle_ids: impl IntoIterator<Item = OracleId>,
multiplicity: u64,
) {
) -> anyhow::Result<()> {
let oracles = oracle_ids.into_iter().collect::<Vec<_>>();
let log_rows = self.log_rows(oracles.iter().copied())?;
ensure!(
log_rows == self.log_rows([selector])?,
"Selector {} n_vars does not match flush {:?}",
selector,
oracles
);

self.flushes.push(Flush {
channel_id,
direction,
count,
oracles: oracle_ids.into_iter().collect(),
selector,
oracles,
multiplicity,
})
});

Ok(())
}

pub fn send(
&mut self,
channel_id: ChannelId,
count: usize,
oracle_ids: impl IntoIterator<Item = OracleId>,
) {
oracle_ids: impl IntoIterator<Item = OracleId> + Clone,
) -> anyhow::Result<()>
where
U: PackScalar<BinaryField1b>,
{
self.flush(FlushDirection::Push, channel_id, count, oracle_ids)
}

pub fn receive(
&mut self,
channel_id: ChannelId,
count: usize,
oracle_ids: impl IntoIterator<Item = OracleId>,
) {
oracle_ids: impl IntoIterator<Item = OracleId> + Clone,
) -> anyhow::Result<()>
where
U: PackScalar<BinaryField1b>,
{
self.flush(FlushDirection::Pull, channel_id, count, oracle_ids)
}

Expand Down Expand Up @@ -334,7 +384,7 @@ where
pub fn log_rows(
&self,
oracle_ids: impl IntoIterator<Item = OracleId>,
) -> Result<usize, anyhow::Error> {
) -> anyhow::Result<usize> {
let mut oracle_ids = oracle_ids.into_iter();
let oracles = self.oracles.borrow();
let Some(first_id) = oracle_ids.next() else {
Expand Down
8 changes: 4 additions & 4 deletions crates/circuits/src/builder/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ where
pub fn build(self) -> Result<MultilinearExtensionIndex<'arena, U, FW>, Error> {
let mut result = MultilinearExtensionIndex::new();
let entries = Rc::into_inner(self.entries)
.ok_or(anyhow!("Failed to build. There are still entries refs. Make sure there are no pending column insertions."))?
.ok_or_else(|| anyhow!("Failed to build. There are still entries refs. Make sure there are no pending column insertions."))?
.into_inner()
.into_iter()
.enumerate()
Expand All @@ -172,7 +172,7 @@ impl<'arena, U: PackScalar<FS>, FS: TowerField> WitnessEntry<'arena, U, FS> {
WithUnderlier::from_underliers_ref(self.data)
}

pub fn repacked<FW>(&self) -> WitnessEntry<'arena, U, FW>
pub const fn repacked<FW>(&self) -> WitnessEntry<'arena, U, FW>
where
FW: TowerField + ExtensionField<FS>,
U: PackScalar<FW>,
Expand All @@ -184,14 +184,14 @@ impl<'arena, U: PackScalar<FS>, FS: TowerField> WitnessEntry<'arena, U, FS> {
}
}

pub fn low_rows(&self) -> usize {
pub const fn low_rows(&self) -> usize {
self.log_rows
}
}

impl<'arena, U: PackScalar<FS> + Pod, FS: TowerField> WitnessEntry<'arena, U, FS> {
#[inline]
pub fn as_slice<T: Pod>(&self) -> &'arena [T] {
pub const fn as_slice<T: Pod>(&self) -> &'arena [T] {
must_cast_slice(self.data)
}
}
Expand Down
10 changes: 5 additions & 5 deletions crates/circuits/src/collatz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub struct Collatz {
}

impl Collatz {
pub fn new(x0: u32) -> Self {
pub const fn new(x0: u32) -> Self {
Self {
x0,
evens: vec![],
Expand Down Expand Up @@ -82,10 +82,10 @@ impl Collatz {
let half = arithmetic::u32::half(builder, "half", even, arithmetic::Flags::Checked)?;

let even_packed = arithmetic::u32::packed(builder, "even_packed", even)?;
builder.receive(channel, count, [even_packed]);
builder.receive(channel, count, [even_packed])?;

let half_packed = arithmetic::u32::packed(builder, "half_packed", half)?;
builder.send(channel, count, [half_packed]);
builder.send(channel, count, [half_packed])?;

Ok(())
}
Expand Down Expand Up @@ -127,11 +127,11 @@ impl Collatz {
)?;

let odd_packed = arithmetic::u32::packed(builder, "odd_packed", odd)?;
builder.receive(channel, count, [odd_packed]);
builder.receive(channel, count, [odd_packed])?;

let triple_plus_one_packed =
arithmetic::u32::packed(builder, "triple_plus_one_packed", triple_plus_one)?;
builder.send(channel, count, [triple_plus_one_packed]);
builder.send(channel, count, [triple_plus_one_packed])?;

Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions crates/circuits/src/lasso/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use anyhow::Ok;
use binius_core::oracle::OracleId;
use binius_field::{
as_packed_field::{PackScalar, PackedType},
ExtensionField, PackedFieldIndexable, TowerField,
BinaryField1b, ExtensionField, PackedFieldIndexable, TowerField,
};
use itertools::Itertools;

Expand Down Expand Up @@ -53,7 +53,7 @@ impl LookupBatch {
builder: &mut ConstraintSystemBuilder<U, F>,
) -> Result<(), anyhow::Error>
where
U: PackScalar<FC> + PackScalar<F>,
U: PackScalar<FC> + PackScalar<F> + PackScalar<BinaryField1b>,
PackedType<U, FC>: PackedFieldIndexable,
FC: TowerField,
F: ExtensionField<FC> + TowerField,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ where
return Ok((carry_out, sum_arr));
}

builder.push_namespace(name.clone());
builder.push_namespace(name);

let (lower_half_x, upper_half_x) = Level::split(x_in);
let (lower_half_y, upper_half_y) = Level::split(y_in);
Expand Down
19 changes: 10 additions & 9 deletions crates/circuits/src/lasso/lasso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use binius_core::{constraint_system::channel::ChannelId, oracle::OracleId};
use binius_field::{
as_packed_field::{PackScalar, PackedType},
underlier::UnderlierType,
ExtensionField, PackedFieldIndexable, TowerField,
BinaryField1b, ExtensionField, PackedFieldIndexable, TowerField,
};
use itertools::{izip, Itertools};

Expand All @@ -21,7 +21,7 @@ pub fn lasso<U, F, FC>(
channel: ChannelId,
) -> Result<()>
where
U: UnderlierType + PackScalar<F> + PackScalar<FC>,
U: UnderlierType + PackScalar<F> + PackScalar<FC> + PackScalar<BinaryField1b>,
F: TowerField + ExtensionField<FC> + From<FC>,
PackedType<U, FC>: PackedFieldIndexable,
FC: TowerField,
Expand Down Expand Up @@ -119,19 +119,20 @@ where
let oracles_prefix_t = lookup_t.as_ref().iter().copied();

// populate table using initial timestamps
builder.send(channel, 1 << t_log_rows, oracles_prefix_t.clone().chain([lookup_o]));
builder.send(channel, 1 << t_log_rows, oracles_prefix_t.clone().chain([lookup_o]))?;

// for every value looked up, pull using current timestamp and push with incremented timestamp
izip!(lookups_u, lookups_r, lookups_w, n_lookups).for_each(
|(lookup_u, lookup_r, lookup_w, &n_lookup)| {
izip!(lookups_u, lookups_r, lookups_w, n_lookups).try_for_each(
|(lookup_u, lookup_r, lookup_w, &n_lookup)| -> Result<()> {
let oracle_prefix_u = lookup_u.as_ref().iter().copied();
builder.receive(channel, n_lookup, oracle_prefix_u.clone().chain([lookup_r]));
builder.send(channel, n_lookup, oracle_prefix_u.chain([lookup_w]));
builder.receive(channel, n_lookup, oracle_prefix_u.clone().chain([lookup_r]))?;
builder.send(channel, n_lookup, oracle_prefix_u.chain([lookup_w]))?;
Ok(())
},
);
)?;

// depopulate table using final timestamps
builder.receive(channel, 1 << t_log_rows, oracles_prefix_t.chain([lookup_f]));
builder.receive(channel, 1 << t_log_rows, oracles_prefix_t.chain([lookup_f]))?;

Ok(())
}
8 changes: 4 additions & 4 deletions crates/circuits/src/lasso/lookups/u8_arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ where
PackedType<U, B32>: PackedFieldIndexable,
F: TowerField + BinaryField + ExtensionField<B8> + ExtensionField<B16> + ExtensionField<B32>,
{
builder.push_namespace(name.clone());
builder.push_namespace(name);

let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_MUL, B32::TOWER_LEVEL);

Expand Down Expand Up @@ -64,7 +64,7 @@ where
PackedType<U, B32>: PackedFieldIndexable,
F: TowerField + BinaryField + ExtensionField<B8> + ExtensionField<B16> + ExtensionField<B32>,
{
builder.push_namespace(name.clone());
builder.push_namespace(name);

let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_ADD, B32::TOWER_LEVEL);

Expand Down Expand Up @@ -106,7 +106,7 @@ where
PackedType<U, B32>: PackedFieldIndexable,
F: TowerField + BinaryField + ExtensionField<B8> + ExtensionField<B16> + ExtensionField<B32>,
{
builder.push_namespace(name.clone());
builder.push_namespace(name);

let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_ADD, B32::TOWER_LEVEL);

Expand Down Expand Up @@ -150,7 +150,7 @@ where
PackedType<U, B32>: PackedFieldIndexable,
F: TowerField + BinaryField + ExtensionField<B8> + ExtensionField<B16> + ExtensionField<B32>,
{
builder.push_namespace(name.clone());
builder.push_namespace(name);

let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_DCI, B32::TOWER_LEVEL);

Expand Down
1 change: 0 additions & 1 deletion crates/circuits/src/lasso/u32add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ where
+ PackScalar<FOutput>,
PackedType<U, B32>: PackedFieldIndexable,
PackedType<U, B8>: PackedFieldIndexable,
PackedType<U, B32>: PackedFieldIndexable,
B8: ExtensionField<FInput> + ExtensionField<FOutput>,
F: TowerField
+ ExtensionField<B32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ where
PackedType<U, B32>: PackedFieldIndexable,
F: TowerField + BinaryField + ExtensionField<B8> + ExtensionField<B16> + ExtensionField<B32>,
{
builder.push_namespace(name.clone());
builder.push_namespace(name);

let sum = builder.add_committed("sum", log_size, B8::TOWER_LEVEL);

Expand Down
2 changes: 1 addition & 1 deletion crates/circuits/src/lasso/u8add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ where
PackedType<U, B32>: PackedFieldIndexable,
F: TowerField + BinaryField + ExtensionField<B8> + ExtensionField<B16> + ExtensionField<B32>,
{
builder.push_namespace(name.clone());
builder.push_namespace(name);

let sum = builder.add_committed("sum", log_size, B8::TOWER_LEVEL);

Expand Down
Loading

0 comments on commit 2be3fe5

Please sign in to comment.