Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hadar/reduction from storage #745

Merged
merged 34 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
d02cc1e
recover from mont branch
HadarIngonyama Jan 12, 2025
1b58dc4
remove extra math from field
HadarIngonyama Jan 13, 2025
2e34b68
how was this not formatted?
HadarIngonyama Jan 13, 2025
ba51110
Merge remote-tracking branch 'origin/main' into hadar/base-math
HadarIngonyama Jan 13, 2025
5d04ac8
fmt
HadarIngonyama Jan 13, 2025
90d2d7d
fffffmmmmttt
HadarIngonyama Jan 13, 2025
c1d532a
small fix
HadarIngonyama Jan 13, 2025
5d4fc30
small field fix
HadarIngonyama Jan 14, 2025
a42e08e
Merge remote-tracking branch 'origin/main' into hadar/reduce-from-sto…
HadarIngonyama Jan 15, 2025
337390e
first try
HadarIngonyama Jan 16, 2025
4527d93
first try
HadarIngonyama Jan 16, 2025
5a97dc0
first version works
HadarIngonyama Jan 16, 2025
960c6a7
secon version working
HadarIngonyama Jan 16, 2025
c928abb
adding params, from 3 working
HadarIngonyama Jan 17, 2025
4855b7d
bug fix - asymmetric mult 64
HadarIngonyama Jan 20, 2025
ef57d44
add mod sqr subs
HadarIngonyama Jan 22, 2025
67cdc90
add comments
HadarIngonyama Jan 22, 2025
c080d37
Merge remote-tracking branch 'origin/main' into hadar/basic-from-storage
HadarIngonyama Jan 22, 2025
47024ce
Merge remote-tracking branch 'origin/main' into hadar/basic-from-storage
HadarIngonyama Jan 22, 2025
c47a740
fix bugs
HadarIngonyama Jan 22, 2025
e007b43
added test
HadarIngonyama Jan 22, 2025
a406106
temp
HadarIngonyama Jan 23, 2025
12b2aa0
Merge remote-tracking branch 'origin/main' into hadar/reduce-from-sto…
HadarIngonyama Jan 23, 2025
7ad41f1
Merge branch 'hadar/reduce-from-storage' into hadar/basic-from-storage
HadarIngonyama Jan 23, 2025
e861c2e
merge
HadarIngonyama Jan 23, 2025
d7bfa3c
stark 252 works
HadarIngonyama Jan 27, 2025
fab2c90
all the fields work
HadarIngonyama Jan 27, 2025
3d91859
formatting
HadarIngonyama Jan 27, 2025
8efd7a7
small fix
HadarIngonyama Jan 27, 2025
db2a8cc
small fix
HadarIngonyama Jan 28, 2025
2e66dbb
spelling
HadarIngonyama Feb 3, 2025
1db8797
adding a non-template version
HadarIngonyama Feb 3, 2025
c805ca6
Merge remote-tracking branch 'origin/main' into hadar/basic-from-storage
HadarIngonyama Feb 3, 2025
8dee5b6
small change
HadarIngonyama Feb 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion icicle/include/icicle/fields/field.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class Field
static HOST_DEVICE_INLINE Field inv_log_size(uint32_t logn)
{
if (logn == 0) { return Field{CONFIG::one}; }
base_math::inv_log_size_err(logn, CONFIG::omegas_count);
base_math::index_err(logn, CONFIG::omegas_count); // check if the requested size is within the valid range
HadarIngonyama marked this conversation as resolved.
Show resolved Hide resolved
storage_array<CONFIG::omegas_count, TLC> const inv = CONFIG::inv;
return Field{inv.storages[logn - 1]};
}
Expand Down Expand Up @@ -239,6 +239,20 @@ class Field
}
}

HadarIngonyama marked this conversation as resolved.
Show resolved Hide resolved
// access precomputed values for first step of the from storage function (see below)
static HOST_DEVICE_INLINE Field get_reduced_digit_for_storage_reduction(int i)
{
storage_array<CONFIG::reduced_digits_count, TLC> const reduced_digits = CONFIG::reduced_digits;
return Field{reduced_digits.storages[i]};
}

// access precomputed values for second step of the from storage function (see below)
static HOST_DEVICE_INLINE storage<2 * TLC + 2> get_mod_sub_for_storage_reduction(int i)
{
storage_array<CONFIG::mod_subs_count, 2 * TLC + 2> const mod_subs = CONFIG::mod_subs;
return mod_subs.storages[i];
}

template <unsigned NLIMBS, bool CARRY_OUT>
static constexpr HOST_DEVICE_INLINE uint32_t
add_limbs(const storage<NLIMBS>& xs, const storage<NLIMBS>& ys, storage<NLIMBS>& rs)
Expand Down Expand Up @@ -275,6 +289,16 @@ class Field
return rv;
}

template <unsigned NLIMBS>
static HOST_INLINE storage<NLIMBS> rand_storage(unsigned non_zero_limbs = NLIMBS)
{
std::uniform_int_distribution<unsigned> distribution;
storage<NLIMBS> value{};
for (unsigned i = 0; i < non_zero_limbs; i++)
value.limbs[i] = distribution(rand_generator);
return value;
}

// NOTE this function is used for test and examples - it assumed it is executed on a single-thread (no two threads
// accessing rand_generator at the same time)
static HOST_INLINE Field rand_host()
Expand Down Expand Up @@ -369,6 +393,54 @@ class Field
xs.limbs_storage, get_m(), get_modulus(), get_modulus<2>(), get_neg_modulus())};
}

/* This function receives a storage object (currently supports up to 576 bits) and reduces it to a field element
between 0 and p. This is done using 3 steps:
1. Splitting the number into TLC sized digits - xs = x_i * p_i = x_i * 2^(TLC*32*i).
p_i are precomputed modulo p and so the first step is performed multiplying by p_i and accumultaing.
At the end of this step the number is reduced from NLIMBS to 2*TLC+1 (assuming less than 2^32 additions).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

at the code bellow 2*TLC+2 limbs are assigned

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will clarify

2. The second step subtracts a single precomputed multiple of p in ordr to reduce the number into the range 0<x<2^(2n)
where n is the modulus bit count. This step makes use of a look-up table that looks at the top bits of the number (it
is enough to look at the bits from 2^(2n-1) and above).
3. The final step is the regular barrett reduction that reduces from the range 0<x<2^(2n) down to 0<x<p. */
template <unsigned NLIMBS>
static constexpr HOST_DEVICE_INLINE Field from(const storage<NLIMBS>& xs)
{
static_assert(NLIMBS * 32 <= 576); // for now we support up to 576 bits
storage<2 * TLC + 2> rs = {}; // we use 2*TLC+2 and not 2*TLC+1 because for now we don't support an odd number of
// limbs in the storage struct
int constexpr size = NLIMBS / TLC;
// first reduction step:
for (int i = 0; i < size; i++) // future optimization - because we assume a maximum value for size anyway, this loop
// can be unrolled with potential performance benefits
{
const Field& xi = *reinterpret_cast<const Field*>(xs.limbs + i * TLC); // use casting instead of copying
Field pi = get_reduced_digit_for_storage_reduction(i); // use precomputed values - pi = 2^(TLC*32*i) % p
storage<2 * TLC + 2> temp = {};
storage<2 * TLC>& temp_storage = *reinterpret_cast<storage<2 * TLC>*>(temp.limbs);
base_math::template multiply_raw<TLC>(xi.limbs_storage, pi.limbs_storage, temp_storage); // multiplication
base_math::template add_sub_limbs<2 * TLC + 2, false, false>(rs, temp, rs); // accumulation
}
int constexpr extra_limbs = NLIMBS - TLC * size;
if constexpr (extra_limbs > 0) { // handle the extra limbs (when TLC does not divide NLIMBS)
const storage<extra_limbs>& xi = *reinterpret_cast<const storage<extra_limbs>*>(xs.limbs + size * TLC);
Field pi = get_reduced_digit_for_storage_reduction(size);
storage<2 * TLC + 2> temp = {};
storage<extra_limbs + TLC>& temp_storage = *reinterpret_cast<storage<extra_limbs + TLC>*>(temp.limbs);
base_math::template multiply_raw<extra_limbs, TLC>(xi, pi.limbs_storage, temp_storage); // multiplication
base_math::template add_sub_limbs<2 * TLC + 2, false, false>(rs, temp, rs); // accumulation
}
// second reduction step: - an alternative for this step would be to use the barret reduction straight away but with
// a larger value of m.
unsigned constexpr msbits_count = 2 * TLC * 32 - (2 * NBITS - 1);
unsigned top_bits = (rs.limbs[2 * TLC] << msbits_count) + (rs.limbs[2 * TLC - 1] >> (32 - msbits_count));
base_math::template add_sub_limbs<2 * TLC + 2, true, false>(
rs, get_mod_sub_for_storage_reduction(top_bits),
rs); // subtracting the precomputed multiple of p from the look-up table
// third and final step:
storage<2 * TLC>& res = *reinterpret_cast<storage<2 * TLC>*>(rs.limbs);
return reduce(Wide{res}); // finally, use barret reduction
}

HOST_DEVICE Field& operator=(Field const& other)
{
#pragma unroll
Expand Down
19 changes: 10 additions & 9 deletions icicle/include/icicle/fields/host_math.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,13 @@ namespace host_math {
static HOST_INLINE void multiply_raw_64(const uint64_t* a, const uint64_t* b, uint64_t* r)
{
#pragma unroll
for (unsigned j = 0; j < NLIMBS_A / 2; j++) {
for (unsigned i = 0; i < NLIMBS_B / 2; i++) {
uint64_t carry = 0;
#pragma unroll
for (unsigned i = 0; i < NLIMBS_B / 2; i++) {
for (unsigned j = 0; j < NLIMBS_A / 2; j++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you swap the loops?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to fix a bug that was inserted in an earlier PR..

r[j + i] = host_math::madc_cc_64(a[j], b[i], r[j + i], carry);
}
r[NLIMBS_A / 2 + j] = carry;
r[NLIMBS_A / 2 + i] = carry;
}
}

Expand All @@ -256,8 +256,8 @@ namespace host_math {
multiply_raw(const storage<NLIMBS_A>& as, const storage<NLIMBS_B>& bs, storage<NLIMBS_A + NLIMBS_B>& rs)
{
static_assert(
(NLIMBS_A % 2 == 0 || NLIMBS_A == 1) && (NLIMBS_B % 2 == 0 || NLIMBS_B == 1),
"odd number of limbs is not supported\n");
((NLIMBS_A % 2 == 0 || NLIMBS_A == 1) && (NLIMBS_B % 2 == 0 || NLIMBS_B == 1)) || USE_32,
"odd number of limbs is not supported for 64 bit multiplication\n");
if constexpr (USE_32) {
multiply_raw_32<NLIMBS_A, NLIMBS_B>(as, bs, rs);
return;
Expand Down Expand Up @@ -363,12 +363,13 @@ namespace host_math {
{
return std::memcmp(xs.limbs, ys.limbs, NLIMBS * sizeof(xs.limbs[0])) == 0;
}
static constexpr void inv_log_size_err(uint32_t logn, uint32_t omegas_count)
// this function checks if the given index is within the array range
static constexpr void index_err(uint32_t index, uint32_t max_index)
HadarIngonyama marked this conversation as resolved.
Show resolved Hide resolved
{
if (logn > omegas_count)
if (index > max_index)
THROW_ICICLE_ERR(
icicle::eIcicleError::INVALID_ARGUMENT,
"Field: Invalid inv index" + std::to_string(logn) + ">" + std::to_string(omegas_count));
icicle::eIcicleError::INVALID_ARGUMENT, "Field: index out of range: given index -" + std::to_string(index) +
"> max index - " + std::to_string(max_index));
}

template <unsigned NLIMBS>
Expand Down
28 changes: 28 additions & 0 deletions icicle/include/icicle/fields/params_gen.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,31 @@ namespace params_gen {
}
return invs;
}

// This function generates the precomputed values for the second step of the from storage function in the field class.
// However, it is not used due to constexpr computation overflow. Instead, the values are currently generated by a
// python script and appear explicitly in the field config struct.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason for leaving it then? Do you plan on fixing the overflow issue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the reason to leave it is to preserve the logic, because the python script is not in icicle, and maybe we will find a way to fix the overflow yes

template <unsigned NLIMBS, unsigned mod_subs_count, unsigned mod_bit_count>
constexpr storage_array<mod_subs_count, 2 * NLIMBS + 2> get_modulus_subs(const storage<NLIMBS>& modulus)
{
storage_array<mod_subs_count, 2 * NLIMBS + 2> mod_subs = {};
unsigned constexpr bit_shift = 2 * mod_bit_count - 1;
mod_subs.storages[0] = {0};
for (int i = 1; i < mod_subs_count; i++) {
storage<2 * NLIMBS + 2> temp = {};
storage<NLIMBS> rs = {};
storage<NLIMBS + 2> mod_sub_factor = {};
temp.limbs[0] = i;
storage<2 * NLIMBS + 2> candidate = host_math::template left_shift<2 * NLIMBS + 2, bit_shift>(temp);
host_math::template integer_division<2 * NLIMBS + 2, NLIMBS, NLIMBS + 2, true>( // find the closest multiple of p
// to subtract.
candidate, modulus, mod_sub_factor, rs);
storage<2 * NLIMBS + 2> temp2 = {};
host_math::template multiply_raw<NLIMBS + 2, NLIMBS, true>(mod_sub_factor, modulus, temp2);
mod_subs.storages[i] = temp2;
}
return mod_subs;
}
} // namespace params_gen

#define PARAMS(modulus) \
Expand All @@ -119,6 +144,9 @@ namespace params_gen {
static constexpr unsigned num_of_reductions = \
params_gen::template num_of_reductions<limbs_count, 2 * modulus_bit_count>(modulus, m);

#define MOD_SQR_SUBS() \
static constexpr unsigned mod_subs_count = reduced_digits_count << (limbs_count * 32 + 1 - modulus_bit_count);

#define TWIDDLES(modulus, rou) \
static constexpr unsigned omegas_count = params_gen::template two_adicity<limbs_count>(modulus); \
static constexpr storage_array<omegas_count, limbs_count> inv = \
Expand Down
Loading
Loading