Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DSLX:TS] Fixes ability to use xN based types as bit slices, use for std::{min,max} #1825

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions docs_src/tutorials/intro_to_parametrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

This tutorial demonstrates how types and functions can be parameterized to
enable them to work on data of different formats and layouts, e.g., for a
function `foo` to work on both u16 and u32 data types, and anywhere in between.
function `foo` to work on both `u16` and `u32` data types, and anywhere in
between.

It's recommended that you're familiar with the concepts in the previous
tutorial,
Expand All @@ -11,7 +12,7 @@ before following this tutorial.

## Simple parametrics

Consider the simple example of the `umax` function
Consider the simple example of a `umax` function -- similar to the `max` function
[in the DSLX standard library](https://github.com/google/xls/tree/main/xls/dslx/stdlib/std.x):

```dslx
Expand Down Expand Up @@ -40,10 +41,12 @@ infer them:
Explicit specification:

```dslx
import std;
fn umax<N: u32>(x: uN[N], y: uN[N]) -> uN[N] {
if x > y { x } else { y }
}

fn foo(a: u32, b: u16) -> u64 {
std::umax<u32:64>(a as u64, b as u64)
umax<u32:64>(a as u64, b as u64)
}
```

Expand All @@ -53,10 +56,12 @@ are.
Parametric inference:

```dslx
import std;
fn umax<N: u32>(x: uN[N], y: uN[N]) -> uN[N] {
if x > y { x } else { y }
}

fn foo(a: u32, b: u16) -> u64 {
std::umax(a as u64, b as u64)
umax(a as u64, b as u64)
}
```

Expand Down
11 changes: 7 additions & 4 deletions xls/dslx/ir_convert/function_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -960,12 +960,15 @@ absl::Status FunctionConverter::HandleBuiltinCheckedCast(
int64_t old_bit_count,
std::get<InterpValue>(input_bit_count_ctd.value()).GetBitValueViaSign());

if (dynamic_cast<ArrayType*>(output_type.get()) != nullptr ||
dynamic_cast<ArrayType*>(input_type.get()) != nullptr) {
std::optional<BitsLikeProperties> output_bits_like =
GetBitsLike(*output_type);
std::optional<BitsLikeProperties> input_bits_like = GetBitsLike(*input_type);

if (!output_bits_like.has_value() || !input_bits_like.has_value()) {
return IrConversionErrorStatus(
node->span(),
absl::StrFormat("CheckedCast to and from array "
"is not currently supported for IR conversion; "
absl::StrFormat("CheckedCast is only supported for bits-like types in "
"IR conversion; "
"attempted checked cast from: %s to: %s",
input_type->ToString(), output_type->ToString()),
file_table());
Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/stdlib/apfloat.x
Original file line number Diff line number Diff line change
Expand Up @@ -1503,7 +1503,7 @@ fn test_fp_lt_2() {
fn to_signed_or_unsigned_int<RESULT_SZ: u32, RESULT_SIGNED: bool, EXP_SZ: u32, FRACTION_SZ: u32>
(x: APFloat<EXP_SZ, FRACTION_SZ>) -> xN[RESULT_SIGNED][RESULT_SZ] {
const WIDE_FRACTION: u32 = FRACTION_SZ + u32:1;
const MAX_FRACTION_SZ: u32 = std::umax(RESULT_SZ, WIDE_FRACTION);
const MAX_FRACTION_SZ: u32 = std::max(RESULT_SZ, WIDE_FRACTION);

const INT_MIN = if RESULT_SIGNED {
(uN[MAX_FRACTION_SZ]:1 << (RESULT_SZ - u32:1)) // or rather, its negative.
Expand Down
135 changes: 65 additions & 70 deletions xls/dslx/stdlib/std.x
Original file line number Diff line number Diff line change
Expand Up @@ -90,74 +90,68 @@ fn unsigned_max_value_test() {
assert_eq(u32:0xffffffff, unsigned_max_value<u32:32>());
}

// Returns the maximum of two signed integers.
pub fn smax<N: u32>(x: sN[N], y: sN[N]) -> sN[N] { if x > y { x } else { y } }
// Returns the maximum of two (signed or unsigned) integers.
pub fn max<S: bool, N: u32>(x: xN[S][N], y: xN[S][N]) -> xN[S][N] { if x > y { x } else { y } }

#[test]
fn smax_test() {
assert_eq(s2:0, smax(s2:0, s2:0));
assert_eq(s2:1, smax(s2:-1, s2:1));
assert_eq(s7:-3, smax(s7:-3, s7:-6));
fn max_test_signed() {
assert_eq(s2:0, max(s2:0, s2:0));
assert_eq(s2:1, max(s2:-1, s2:1));
assert_eq(s7:-3, max(s7:-3, s7:-6));
}

// Returns the maximum of two unsigned integers.
pub fn umax<N: u32>(x: uN[N], y: uN[N]) -> uN[N] { if x > y { x } else { y } }

#[test]
fn umax_test() {
assert_eq(u1:1, umax(u1:1, u1:0));
assert_eq(u1:1, umax(u1:1, u1:1));
assert_eq(u2:3, umax(u2:3, u2:2));
fn max_test_unsigned() {
assert_eq(u1:1, max(u1:1, u1:0));
assert_eq(u1:1, max(u1:1, u1:1));
assert_eq(u2:3, max(u2:3, u2:2));
}

// Returns the maximum of two signed integers.
pub fn smin<N: u32>(x: sN[N], y: sN[N]) -> sN[N] { if x < y { x } else { y } }
// Returns the minimum of two (signed or unsigned) integers.
pub fn min<S: bool, N: u32>(x: xN[S][N], y: xN[S][N]) -> xN[S][N] { if x < y { x } else { y } }

#[test]
fn smin_test() {
assert_eq(s1:0, smin(s1:0, s1:0));
assert_eq(s1:-1, smin(s1:0, s1:1));
assert_eq(s1:-1, smin(s1:1, s1:0));
assert_eq(s1:-1, smin(s1:1, s1:1));

assert_eq(s2:-2, smin(s2:0, s2:-2));
assert_eq(s2:-1, smin(s2:0, s2:-1));
assert_eq(s2:0, smin(s2:0, s2:0));
assert_eq(s2:0, smin(s2:0, s2:1));
fn min_test_unsigned() {
assert_eq(u1:0, min(u1:1, u1:0));
assert_eq(u1:1, min(u1:1, u1:1));
assert_eq(u2:2, min(u2:3, u2:2));
}

assert_eq(s2:-2, smin(s2:1, s2:-2));
assert_eq(s2:-1, smin(s2:1, s2:-1));
assert_eq(s2:0, smin(s2:1, s2:0));
assert_eq(s2:1, smin(s2:1, s2:1));
#[test]
fn min_test_signed() {
assert_eq(s1:0, min(s1:0, s1:0));
assert_eq(s1:-1, min(s1:0, s1:1));
assert_eq(s1:-1, min(s1:1, s1:0));
assert_eq(s1:-1, min(s1:1, s1:1));

assert_eq(s2:-2, smin(s2:-2, s2:-2));
assert_eq(s2:-2, smin(s2:-2, s2:-1));
assert_eq(s2:-2, smin(s2:-2, s2:0));
assert_eq(s2:-2, smin(s2:-2, s2:1));
assert_eq(s2:-2, min(s2:0, s2:-2));
assert_eq(s2:-1, min(s2:0, s2:-1));
assert_eq(s2:0, min(s2:0, s2:0));
assert_eq(s2:0, min(s2:0, s2:1));

assert_eq(s2:-2, smin(s2:-1, s2:-2));
assert_eq(s2:-1, smin(s2:-1, s2:-1));
assert_eq(s2:-1, smin(s2:-1, s2:0));
assert_eq(s2:-1, smin(s2:-1, s2:1));
}
assert_eq(s2:-2, min(s2:1, s2:-2));
assert_eq(s2:-1, min(s2:1, s2:-1));
assert_eq(s2:0, min(s2:1, s2:0));
assert_eq(s2:1, min(s2:1, s2:1));

// Returns the minimum of two unsigned integers.
pub fn umin<N: u32>(x: uN[N], y: uN[N]) -> uN[N] { if x < y { x } else { y } }
assert_eq(s2:-2, min(s2:-2, s2:-2));
assert_eq(s2:-2, min(s2:-2, s2:-1));
assert_eq(s2:-2, min(s2:-2, s2:0));
assert_eq(s2:-2, min(s2:-2, s2:1));

#[test]
fn umin_test() {
assert_eq(u1:0, umin(u1:1, u1:0));
assert_eq(u1:1, umin(u1:1, u1:1));
assert_eq(u2:2, umin(u2:3, u2:2));
assert_eq(s2:-2, min(s2:-1, s2:-2));
assert_eq(s2:-1, min(s2:-1, s2:-1));
assert_eq(s2:-1, min(s2:-1, s2:0));
assert_eq(s2:-1, min(s2:-1, s2:1));
}

// Returns unsigned add of x (N bits) and y (M bits) as a max(N,M)+1 bit value.
pub fn uadd<N: u32, M: u32, R: u32 = {umax(N, M) + u32:1}>(x: uN[N], y: uN[M]) -> uN[R] {
pub fn uadd<N: u32, M: u32, R: u32 = {max(N, M) + u32:1}>(x: uN[N], y: uN[M]) -> uN[R] {
(x as uN[R]) + (y as uN[R])
}

// Returns signed add of x (N bits) and y (M bits) as a max(N,M)+1 bit value.
pub fn sadd<N: u32, M: u32, R: u32 = {umax(N, M) + u32:1}>(x: sN[N], y: sN[M]) -> sN[R] {
pub fn sadd<N: u32, M: u32, R: u32 = {max(N, M) + u32:1}>(x: sN[N], y: sN[M]) -> sN[R] {
(x as sN[R]) + (y as sN[R])
}

Expand Down Expand Up @@ -773,7 +767,7 @@ fn test_to_unsigned() {
// let result : (bool, u16) = uadd_with_overflow<u32:16>(x, y);
//
pub fn uadd_with_overflow
<V: u32, N: u32, M: u32, MAX_N_M: u32 = {umax(N, M)}, MAX_N_M_V: u32 = {umax(MAX_N_M, V)}>
<V: u32, N: u32, M: u32, MAX_N_M: u32 = {max(N, M)}, MAX_N_M_V: u32 = {max(MAX_N_M, V)}>
(x: uN[N], y: uN[M]) -> (bool, uN[V]) {

let x_extended = widening_cast<uN[MAX_N_M_V + u32:1]>(x);
Expand Down Expand Up @@ -801,47 +795,48 @@ fn test_uadd_with_overflow() {
}

// Extract bits given a fixed-point integer with a constant offset.
// i.e. let x_extended = x as uN[max(unsigned_sizeof(x) + fixed_shift, to_exclusive)];
// (x_extended << fixed_shift)[from_inclusive:to_exclusive]
// i.e. let x_extended = x as uN[max(unsigned_sizeof(x) + FIXED_SHIFT, TO_EXCLUSIVE)];
// (x_extended << FIXED_SHIFT)[FROM_INCLUSIVE:TO_EXCLUSIVE]
//
// This function behaves as-if x has reasonably infinite precision so that
// the result is zero-padded if from_inclusive or to_exclusive are out of
// the result is zero-padded if FROM_INCLUSIVE or TO_EXCLUSIVE are out of
// range of the original x's bitwidth.
//
// If to_exclusive <= from_exclusive, the result will be a zero-bit uN[0].
// If TO_EXCLUSIVE <= FROM_INCLUSIVE, the result will be a zero-bit uN[0].
pub fn extract_bits
<from_inclusive: u32, to_exclusive: u32, fixed_shift: u32, N: u32,
extract_width: u32 = {smax(s32:0, to_exclusive as s32 - from_inclusive as s32) as u32}>
(x: uN[N]) -> uN[extract_width] {
if to_exclusive <= from_inclusive {
uN[extract_width]:0
<FROM_INCLUSIVE: u32, TO_EXCLUSIVE: u32, FIXED_SHIFT: u32, N: u32,
EXTRACT_WIDTH: u32 = {max(s32:0, TO_EXCLUSIVE as s32 - FROM_INCLUSIVE as s32) as u32}>
(x: uN[N]) -> uN[EXTRACT_WIDTH] {
if TO_EXCLUSIVE <= FROM_INCLUSIVE {
uN[EXTRACT_WIDTH]:0
} else {
// With a non-zero fixed width, all lower bits of index < fixed_shift are
// are zero.
let lower_bits =
uN[checked_cast<u32>(smax(s32:0, fixed_shift as s32 - from_inclusive as s32))]:0;
uN[checked_cast<u32>(max(s32:0, FIXED_SHIFT as s32 - FROM_INCLUSIVE as s32))]:0;

// Based on the input of N bits and a fixed shift, there are an effective
// count of N + fixed_shift known bits. All bits of index >
// N + fixed_shift - 1 are zero's.
const UPPER_BIT_COUNT = checked_cast<u32>(
smax(s32:0, N as s32 + fixed_shift as s32 - to_exclusive as s32 - s32:1));
let upper_bits = uN[UPPER_BIT_COUNT]:0;
max(s32:0, N as s32 + FIXED_SHIFT as s32 - TO_EXCLUSIVE as s32 - s32:1));
const UPPER_BITS = uN[UPPER_BIT_COUNT]:0;

if fixed_shift < from_inclusive {
if FIXED_SHIFT < FROM_INCLUSIVE {
// The bits extracted start within or after the middle span.
// upper_bits ++ middle_bits
let middle_bits = upper_bits ++
x[smin(from_inclusive as s32 - fixed_shift as s32, N as s32)
:smin(to_exclusive as s32 - fixed_shift as s32, N as s32)];
(upper_bits ++ middle_bits) as uN[extract_width]
} else if fixed_shift <= to_exclusive {
const FROM: s32 = min(FROM_INCLUSIVE as s32 - FIXED_SHIFT as s32, N as s32);
const TO: s32 = min(TO_EXCLUSIVE as s32 - FIXED_SHIFT as s32, N as s32);
let middle_bits = UPPER_BITS ++ x[FROM:TO];
(UPPER_BITS ++ middle_bits) as uN[EXTRACT_WIDTH]
} else if FIXED_SHIFT <= TO_EXCLUSIVE {
// The bits extracted start within the fixed_shift span.
let middle_bits = x[0:smin(to_exclusive as s32 - fixed_shift as s32, N as s32)];
const TO: s32 = min(TO_EXCLUSIVE as s32 - FIXED_SHIFT as s32, N as s32);
let middle_bits = x[0:TO];

(upper_bits ++ middle_bits ++ lower_bits) as uN[extract_width]
(UPPER_BITS ++ middle_bits ++ lower_bits) as uN[EXTRACT_WIDTH]
} else {
uN[extract_width]:0
uN[EXTRACT_WIDTH]:0
}
}
}
Expand Down Expand Up @@ -928,7 +923,7 @@ pub fn umul_with_overflow
<V: u32, N: u32, M: u32, N_lower_bits: u32 = {N >> u32:1},
N_upper_bits: u32 = {N - N_lower_bits}, M_lower_bits: u32 = {M >> u32:1},
M_upper_bits: u32 = {M - M_lower_bits},
Min_N_M_lower_bits: u32 = {umin(N_lower_bits, M_lower_bits)}, N_Plus_M: u32 = {N + M}>
Min_N_M_lower_bits: u32 = {min(N_lower_bits, M_lower_bits)}, N_Plus_M: u32 = {N + M}>
(x: uN[N], y: uN[M]) -> (bool, uN[V]) {
// Break x and y into two halves.
// x = x1 ++ x0,
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ dslx_lang_test(name = "xn_type_equivalence")

dslx_lang_test(name = "xn_signedness_properties")

dslx_lang_test(name = "xn_slice_bounds")

dslx_lang_test(name = "xn_widening_cast")

dslx_lang_test(
name = "parametric_shift",
# TODO(leary): 2023-08-14 Runs into "cannot translate zero length bitvector
Expand Down
4 changes: 2 additions & 2 deletions xls/dslx/tests/errors/error_modules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,8 +873,8 @@ def test_equals_rhs_undefined_nameref(self):

def test_umin_type_mismatch(self):
stderr = self._run('xls/dslx/tests/errors/umin_type_mismatch.x')
self.assertIn('umin_type_mismatch.x:21:12-21:27', stderr)
self.assertIn('XlsTypeError: uN[N] vs uN[8]', stderr)
self.assertIn('umin_type_mismatch.x:21:13-21:28', stderr)
self.assertIn('saw: 42; then: 8', stderr)

def test_diag_block_with_trailing_semi(self):
stderr = self._run(
Expand Down
4 changes: 2 additions & 2 deletions xls/dslx/tests/errors/spawn_wrong_argc.x
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub proc foo {
config () { () }

next(state: ()) {
std::umin(u32:1, u32:2);
std::min(u32:1, u32:2);
()
}
}
Expand All @@ -37,7 +37,7 @@ proc test_case {
}

next(state: ()) {
std::umin(u32:1, u32:2);
std::min(u32:1, u32:2);
let tok = send(join(), terminator, true);
()
}
Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/tests/errors/umin_type_mismatch.x
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ const MY_U32 = u42:42;
const MY_U8 = u8:42;

fn f() -> u32 {
std::umin(MY_U32, MY_U8)
std::min(MY_U32, MY_U8)
}
33 changes: 33 additions & 0 deletions xls/dslx/tests/xn_slice_bounds.x
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2025 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

const S = true;
const N = u32:32;

type MyS32 = xN[S][N];

fn from_to(x: u32) -> u8 { x[MyS32:0:MyS32:8] }

fn to(x: u32) -> u8 { x[:MyS32:8] }

fn from(x: u32) -> u8 { x[MyS32:-8:] }

fn main(x: u32) -> u8[3] { [from_to(x), to(x), from(x)] }

#[test]
fn test_main() {
assert_eq(from_to(u32:0x12345678), u8:0x78);
assert_eq(to(u32:0x12345678), u8:0x78);
assert_eq(from(u32:0x12345678), u8:0x12);
}
17 changes: 17 additions & 0 deletions xls/dslx/tests/xn_widening_cast.x
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright 2025 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import std;

fn main(x: u7, y: u7) -> u32 { widening_cast<u32>(std::max(x, y)) }
Loading
Loading