From 3c2b6d2350ed957452307d799ab69f424efabb64 Mon Sep 17 00:00:00 2001
From: Chris Leary <cdleary@openai.com>
Date: Thu, 2 Jan 2025 17:14:56 -0800
Subject: [PATCH 1/2] [DSLX] Fix ability to use xN as bit slice bounds.

---
 docs_src/tutorials/intro_to_parametrics.md  |  17 ++-
 xls/dslx/ir_convert/function_converter.cc   |  11 +-
 xls/dslx/stdlib/apfloat.x                   |   2 +-
 xls/dslx/stdlib/std.x                       | 135 ++++++++++----------
 xls/dslx/tests/BUILD                        |   2 +
 xls/dslx/tests/errors/error_modules_test.py |   4 +-
 xls/dslx/tests/errors/spawn_wrong_argc.x    |   4 +-
 xls/dslx/tests/errors/umin_type_mismatch.x  |   2 +-
 xls/dslx/tests/xn_slice_bounds.x            |  33 +++++
 xls/dslx/type_system/deduce.cc              |  52 ++++----
 xls/dslx/type_system/type.cc                |  10 ++
 xls/dslx/type_system/type.h                 |  17 +++
 xls/dslx/type_system/type_info.h            |  14 +-
 xls/modules/zstd/dec_mux.x                  |   2 +-
 xls/modules/zstd/memory/axi_reader.x        |   2 +-
 xls/modules/zstd/memory/axi_writer.x        |   2 +-
 xls/modules/zstd/sequence_executor.x        |   2 +-
 xls/modules/zstd/zstd_dec.x                 |   2 +-
 18 files changed, 193 insertions(+), 120 deletions(-)
 create mode 100644 xls/dslx/tests/xn_slice_bounds.x

diff --git a/docs_src/tutorials/intro_to_parametrics.md b/docs_src/tutorials/intro_to_parametrics.md
index 9a57d51ac5..b458f21f07 100644
--- a/docs_src/tutorials/intro_to_parametrics.md
+++ b/docs_src/tutorials/intro_to_parametrics.md
@@ -2,7 +2,8 @@
 
 This tutorial demonstrates how types and functions can be parameterized to
 enable them to work on data of different formats and layouts, e.g., for a
-function `foo` to work on both u16 and u32 data types, and anywhere in between.
+function `foo` to work on both `u16` and `u32` data types, and anywhere in
+between.
 
 It's recommended that you're familiar with the concepts in the previous
 tutorial,
@@ -11,7 +12,7 @@ before following this tutorial.
 
 ## Simple parametrics
 
-Consider the simple example of the `umax` function
+Consider the simple example of a `umax` function -- similar to the `max` function
 [in the DSLX standard library](https://github.com/google/xls/tree/main/xls/dslx/stdlib/std.x):
 
 ```dslx
@@ -40,10 +41,12 @@ infer them:
 Explicit specification:
 
 ```dslx
-import std;
+fn umax<N: u32>(x: uN[N], y: uN[N]) -> uN[N] {
+  if x > y { x } else { y }
+}
 
 fn foo(a: u32, b: u16) -> u64 {
-  std::umax<u32:64>(a as u64, b as u64)
+  umax<u32:64>(a as u64, b as u64)
 }
 ```
 
@@ -53,10 +56,12 @@ are.
 Parametric inference:
 
 ```dslx
-import std;
+fn umax<N: u32>(x: uN[N], y: uN[N]) -> uN[N] {
+  if x > y { x } else { y }
+}
 
 fn foo(a: u32, b: u16) -> u64 {
-  std::umax(a as u64, b as u64)
+  umax(a as u64, b as u64)
 }
 ```
 
diff --git a/xls/dslx/ir_convert/function_converter.cc b/xls/dslx/ir_convert/function_converter.cc
index 09a7c8bb6c..13a28fcd4c 100644
--- a/xls/dslx/ir_convert/function_converter.cc
+++ b/xls/dslx/ir_convert/function_converter.cc
@@ -960,12 +960,15 @@ absl::Status FunctionConverter::HandleBuiltinCheckedCast(
       int64_t old_bit_count,
       std::get<InterpValue>(input_bit_count_ctd.value()).GetBitValueViaSign());
 
-  if (dynamic_cast<ArrayType*>(output_type.get()) != nullptr ||
-      dynamic_cast<ArrayType*>(input_type.get()) != nullptr) {
+  std::optional<BitsLikeProperties> output_bits_like =
+      GetBitsLike(*output_type);
+  std::optional<BitsLikeProperties> input_bits_like = GetBitsLike(*input_type);
+
+  if (!output_bits_like.has_value() || !input_bits_like.has_value()) {
     return IrConversionErrorStatus(
         node->span(),
-        absl::StrFormat("CheckedCast to and from array "
-                        "is not currently supported for IR conversion; "
+        absl::StrFormat("CheckedCast is only supported for bits-like types in "
+                        "IR conversion; "
                         "attempted checked cast from: %s to: %s",
                         input_type->ToString(), output_type->ToString()),
         file_table());
diff --git a/xls/dslx/stdlib/apfloat.x b/xls/dslx/stdlib/apfloat.x
index 81785f94bb..acae738e94 100644
--- a/xls/dslx/stdlib/apfloat.x
+++ b/xls/dslx/stdlib/apfloat.x
@@ -1503,7 +1503,7 @@ fn test_fp_lt_2() {
 fn to_signed_or_unsigned_int<RESULT_SZ: u32, RESULT_SIGNED: bool, EXP_SZ: u32, FRACTION_SZ: u32>
     (x: APFloat<EXP_SZ, FRACTION_SZ>) -> xN[RESULT_SIGNED][RESULT_SZ] {
     const WIDE_FRACTION: u32 = FRACTION_SZ + u32:1;
-    const MAX_FRACTION_SZ: u32 = std::umax(RESULT_SZ, WIDE_FRACTION);
+    const MAX_FRACTION_SZ: u32 = std::max(RESULT_SZ, WIDE_FRACTION);
 
     const INT_MIN = if RESULT_SIGNED {
         (uN[MAX_FRACTION_SZ]:1 << (RESULT_SZ - u32:1))  // or rather, its negative.
diff --git a/xls/dslx/stdlib/std.x b/xls/dslx/stdlib/std.x
index 01cf454896..18f835db6e 100644
--- a/xls/dslx/stdlib/std.x
+++ b/xls/dslx/stdlib/std.x
@@ -90,74 +90,68 @@ fn unsigned_max_value_test() {
     assert_eq(u32:0xffffffff, unsigned_max_value<u32:32>());
 }
 
-// Returns the maximum of two signed integers.
-pub fn smax<N: u32>(x: sN[N], y: sN[N]) -> sN[N] { if x > y { x } else { y } }
+// Returns the maximum of two (signed or unsigned) integers.
+pub fn max<S: bool, N: u32>(x: xN[S][N], y: xN[S][N]) -> xN[S][N] { if x > y { x } else { y } }
 
 #[test]
-fn smax_test() {
-    assert_eq(s2:0, smax(s2:0, s2:0));
-    assert_eq(s2:1, smax(s2:-1, s2:1));
-    assert_eq(s7:-3, smax(s7:-3, s7:-6));
+fn max_test_signed() {
+    assert_eq(s2:0, max(s2:0, s2:0));
+    assert_eq(s2:1, max(s2:-1, s2:1));
+    assert_eq(s7:-3, max(s7:-3, s7:-6));
 }
 
-// Returns the maximum of two unsigned integers.
-pub fn umax<N: u32>(x: uN[N], y: uN[N]) -> uN[N] { if x > y { x } else { y } }
-
 #[test]
-fn umax_test() {
-    assert_eq(u1:1, umax(u1:1, u1:0));
-    assert_eq(u1:1, umax(u1:1, u1:1));
-    assert_eq(u2:3, umax(u2:3, u2:2));
+fn max_test_unsigned() {
+    assert_eq(u1:1, max(u1:1, u1:0));
+    assert_eq(u1:1, max(u1:1, u1:1));
+    assert_eq(u2:3, max(u2:3, u2:2));
 }
 
-// Returns the maximum of two signed integers.
-pub fn smin<N: u32>(x: sN[N], y: sN[N]) -> sN[N] { if x < y { x } else { y } }
+// Returns the minimum of two (signed or unsigned) integers.
+pub fn min<S: bool, N: u32>(x: xN[S][N], y: xN[S][N]) -> xN[S][N] { if x < y { x } else { y } }
 
 #[test]
-fn smin_test() {
-    assert_eq(s1:0, smin(s1:0, s1:0));
-    assert_eq(s1:-1, smin(s1:0, s1:1));
-    assert_eq(s1:-1, smin(s1:1, s1:0));
-    assert_eq(s1:-1, smin(s1:1, s1:1));
-
-    assert_eq(s2:-2, smin(s2:0, s2:-2));
-    assert_eq(s2:-1, smin(s2:0, s2:-1));
-    assert_eq(s2:0, smin(s2:0, s2:0));
-    assert_eq(s2:0, smin(s2:0, s2:1));
+fn min_test_unsigned() {
+    assert_eq(u1:0, min(u1:1, u1:0));
+    assert_eq(u1:1, min(u1:1, u1:1));
+    assert_eq(u2:2, min(u2:3, u2:2));
+}
 
-    assert_eq(s2:-2, smin(s2:1, s2:-2));
-    assert_eq(s2:-1, smin(s2:1, s2:-1));
-    assert_eq(s2:0, smin(s2:1, s2:0));
-    assert_eq(s2:1, smin(s2:1, s2:1));
+#[test]
+fn min_test_signed() {
+    assert_eq(s1:0, min(s1:0, s1:0));
+    assert_eq(s1:-1, min(s1:0, s1:1));
+    assert_eq(s1:-1, min(s1:1, s1:0));
+    assert_eq(s1:-1, min(s1:1, s1:1));
 
-    assert_eq(s2:-2, smin(s2:-2, s2:-2));
-    assert_eq(s2:-2, smin(s2:-2, s2:-1));
-    assert_eq(s2:-2, smin(s2:-2, s2:0));
-    assert_eq(s2:-2, smin(s2:-2, s2:1));
+    assert_eq(s2:-2, min(s2:0, s2:-2));
+    assert_eq(s2:-1, min(s2:0, s2:-1));
+    assert_eq(s2:0, min(s2:0, s2:0));
+    assert_eq(s2:0, min(s2:0, s2:1));
 
-    assert_eq(s2:-2, smin(s2:-1, s2:-2));
-    assert_eq(s2:-1, smin(s2:-1, s2:-1));
-    assert_eq(s2:-1, smin(s2:-1, s2:0));
-    assert_eq(s2:-1, smin(s2:-1, s2:1));
-}
+    assert_eq(s2:-2, min(s2:1, s2:-2));
+    assert_eq(s2:-1, min(s2:1, s2:-1));
+    assert_eq(s2:0, min(s2:1, s2:0));
+    assert_eq(s2:1, min(s2:1, s2:1));
 
-// Returns the minimum of two unsigned integers.
-pub fn umin<N: u32>(x: uN[N], y: uN[N]) -> uN[N] { if x < y { x } else { y } }
+    assert_eq(s2:-2, min(s2:-2, s2:-2));
+    assert_eq(s2:-2, min(s2:-2, s2:-1));
+    assert_eq(s2:-2, min(s2:-2, s2:0));
+    assert_eq(s2:-2, min(s2:-2, s2:1));
 
-#[test]
-fn umin_test() {
-    assert_eq(u1:0, umin(u1:1, u1:0));
-    assert_eq(u1:1, umin(u1:1, u1:1));
-    assert_eq(u2:2, umin(u2:3, u2:2));
+    assert_eq(s2:-2, min(s2:-1, s2:-2));
+    assert_eq(s2:-1, min(s2:-1, s2:-1));
+    assert_eq(s2:-1, min(s2:-1, s2:0));
+    assert_eq(s2:-1, min(s2:-1, s2:1));
 }
 
 // Returns unsigned add of x (N bits) and y (M bits) as a max(N,M)+1 bit value.
-pub fn uadd<N: u32, M: u32, R: u32 = {umax(N, M) + u32:1}>(x: uN[N], y: uN[M]) -> uN[R] {
+pub fn uadd<N: u32, M: u32, R: u32 = {max(N, M) + u32:1}>(x: uN[N], y: uN[M]) -> uN[R] {
     (x as uN[R]) + (y as uN[R])
 }
 
 // Returns signed add of x (N bits) and y (M bits) as a max(N,M)+1 bit value.
-pub fn sadd<N: u32, M: u32, R: u32 = {umax(N, M) + u32:1}>(x: sN[N], y: sN[M]) -> sN[R] {
+pub fn sadd<N: u32, M: u32, R: u32 = {max(N, M) + u32:1}>(x: sN[N], y: sN[M]) -> sN[R] {
     (x as sN[R]) + (y as sN[R])
 }
 
@@ -773,7 +767,7 @@ fn test_to_unsigned() {
 //  let result : (bool, u16) = uadd_with_overflow<u32:16>(x, y);
 //
 pub fn uadd_with_overflow
-    <V: u32, N: u32, M: u32, MAX_N_M: u32 = {umax(N, M)}, MAX_N_M_V: u32 = {umax(MAX_N_M, V)}>
+    <V: u32, N: u32, M: u32, MAX_N_M: u32 = {max(N, M)}, MAX_N_M_V: u32 = {max(MAX_N_M, V)}>
     (x: uN[N], y: uN[M]) -> (bool, uN[V]) {
 
     let x_extended = widening_cast<uN[MAX_N_M_V + u32:1]>(x);
@@ -801,47 +795,48 @@ fn test_uadd_with_overflow() {
 }
 
 // Extract bits given a fixed-point integer with a constant offset.
-//   i.e. let x_extended = x as uN[max(unsigned_sizeof(x) + fixed_shift, to_exclusive)];
-//        (x_extended << fixed_shift)[from_inclusive:to_exclusive]
+//   i.e. let x_extended = x as uN[max(unsigned_sizeof(x) + FIXED_SHIFT, TO_EXCLUSIVE)];
+//        (x_extended << FIXED_SHIFT)[FROM_INCLUSIVE:TO_EXCLUSIVE]
 //
 // This function behaves as-if x has reasonably infinite precision so that
-// the result is zero-padded if from_inclusive or to_exclusive are out of
+// the result is zero-padded if FROM_INCLUSIVE or TO_EXCLUSIVE are out of
 // range of the original x's bitwidth.
 //
-// If to_exclusive <= from_exclusive, the result will be a zero-bit uN[0].
+// If TO_EXCLUSIVE <= FROM_INCLUSIVE, the result will be a zero-bit uN[0].
 pub fn extract_bits
-    <from_inclusive: u32, to_exclusive: u32, fixed_shift: u32, N: u32,
-     extract_width: u32 = {smax(s32:0, to_exclusive as s32 - from_inclusive as s32) as u32}>
-    (x: uN[N]) -> uN[extract_width] {
-    if to_exclusive <= from_inclusive {
-        uN[extract_width]:0
+    <FROM_INCLUSIVE: u32, TO_EXCLUSIVE: u32, FIXED_SHIFT: u32, N: u32,
+     EXTRACT_WIDTH: u32 = {max(s32:0, TO_EXCLUSIVE as s32 - FROM_INCLUSIVE as s32) as u32}>
+    (x: uN[N]) -> uN[EXTRACT_WIDTH] {
+    if TO_EXCLUSIVE <= FROM_INCLUSIVE {
+        uN[EXTRACT_WIDTH]:0
     } else {
         // With a non-zero fixed width, all lower bits of index < fixed_shift are
         // are zero.
         let lower_bits =
-            uN[checked_cast<u32>(smax(s32:0, fixed_shift as s32 - from_inclusive as s32))]:0;
+            uN[checked_cast<u32>(max(s32:0, FIXED_SHIFT as s32 - FROM_INCLUSIVE as s32))]:0;
 
         // Based on the input of N bits and a fixed shift, there are an effective
         // count of N + fixed_shift known bits.  All bits of index >
         // N + fixed_shift - 1 are zero's.
         const UPPER_BIT_COUNT = checked_cast<u32>(
-            smax(s32:0, N as s32 + fixed_shift as s32 - to_exclusive as s32 - s32:1));
-        let upper_bits = uN[UPPER_BIT_COUNT]:0;
+            max(s32:0, N as s32 + FIXED_SHIFT as s32 - TO_EXCLUSIVE as s32 - s32:1));
+        const UPPER_BITS = uN[UPPER_BIT_COUNT]:0;
 
-        if fixed_shift < from_inclusive {
+        if FIXED_SHIFT < FROM_INCLUSIVE {
             // The bits extracted start within or after the middle span.
             //  upper_bits ++ middle_bits
-            let middle_bits = upper_bits ++
-                              x[smin(from_inclusive as s32 - fixed_shift as s32, N as s32)
-                                :smin(to_exclusive as s32 - fixed_shift as s32, N as s32)];
-            (upper_bits ++ middle_bits) as uN[extract_width]
-        } else if fixed_shift <= to_exclusive {
+            const FROM: s32 = min(FROM_INCLUSIVE as s32 - FIXED_SHIFT as s32, N as s32);
+            const TO: s32 = min(TO_EXCLUSIVE as s32 - FIXED_SHIFT as s32, N as s32);
+            let middle_bits = UPPER_BITS ++ x[FROM:TO];
+            (UPPER_BITS ++ middle_bits) as uN[EXTRACT_WIDTH]
+        } else if FIXED_SHIFT <= TO_EXCLUSIVE {
             // The bits extracted start within the fixed_shift span.
-            let middle_bits = x[0:smin(to_exclusive as s32 - fixed_shift as s32, N as s32)];
+            const TO: s32 = min(TO_EXCLUSIVE as s32 - FIXED_SHIFT as s32, N as s32);
+            let middle_bits = x[0:TO];
 
-            (upper_bits ++ middle_bits ++ lower_bits) as uN[extract_width]
+            (UPPER_BITS ++ middle_bits ++ lower_bits) as uN[EXTRACT_WIDTH]
         } else {
-            uN[extract_width]:0
+            uN[EXTRACT_WIDTH]:0
         }
     }
 }
@@ -928,7 +923,7 @@ pub fn umul_with_overflow
     <V: u32, N: u32, M: u32, N_lower_bits: u32 = {N >> u32:1},
      N_upper_bits: u32 = {N - N_lower_bits}, M_lower_bits: u32 = {M >> u32:1},
      M_upper_bits: u32 = {M - M_lower_bits},
-     Min_N_M_lower_bits: u32 = {umin(N_lower_bits, M_lower_bits)}, N_Plus_M: u32 = {N + M}>
+     Min_N_M_lower_bits: u32 = {min(N_lower_bits, M_lower_bits)}, N_Plus_M: u32 = {N + M}>
     (x: uN[N], y: uN[M]) -> (bool, uN[V]) {
     // Break x and y into two halves.
     // x = x1 ++ x0,
diff --git a/xls/dslx/tests/BUILD b/xls/dslx/tests/BUILD
index 18ff9fbe86..e391052f27 100644
--- a/xls/dslx/tests/BUILD
+++ b/xls/dslx/tests/BUILD
@@ -185,6 +185,8 @@ dslx_lang_test(name = "xn_type_equivalence")
 
 dslx_lang_test(name = "xn_signedness_properties")
 
+dslx_lang_test(name = "xn_slice_bounds")
+
 dslx_lang_test(
     name = "parametric_shift",
     # TODO(leary): 2023-08-14 Runs into "cannot translate zero length bitvector
diff --git a/xls/dslx/tests/errors/error_modules_test.py b/xls/dslx/tests/errors/error_modules_test.py
index 0c2efd9c19..a97f496233 100644
--- a/xls/dslx/tests/errors/error_modules_test.py
+++ b/xls/dslx/tests/errors/error_modules_test.py
@@ -873,8 +873,8 @@ def test_equals_rhs_undefined_nameref(self):
 
   def test_umin_type_mismatch(self):
     stderr = self._run('xls/dslx/tests/errors/umin_type_mismatch.x')
-    self.assertIn('umin_type_mismatch.x:21:12-21:27', stderr)
-    self.assertIn('XlsTypeError: uN[N] vs uN[8]', stderr)
+    self.assertIn('umin_type_mismatch.x:21:13-21:28', stderr)
+    self.assertIn('saw: 42; then: 8', stderr)
 
   def test_diag_block_with_trailing_semi(self):
     stderr = self._run(
diff --git a/xls/dslx/tests/errors/spawn_wrong_argc.x b/xls/dslx/tests/errors/spawn_wrong_argc.x
index d3ec922987..8b2242c4d7 100644
--- a/xls/dslx/tests/errors/spawn_wrong_argc.x
+++ b/xls/dslx/tests/errors/spawn_wrong_argc.x
@@ -19,7 +19,7 @@ pub proc foo {
   config () { () }
 
   next(state: ()) {
-    std::umin(u32:1, u32:2);
+    std::min(u32:1, u32:2);
     ()
   }
 }
@@ -37,7 +37,7 @@ proc test_case {
   }
 
   next(state: ()) {
-    std::umin(u32:1, u32:2);
+    std::min(u32:1, u32:2);
     let tok = send(join(), terminator, true);
     ()
   }
diff --git a/xls/dslx/tests/errors/umin_type_mismatch.x b/xls/dslx/tests/errors/umin_type_mismatch.x
index 6ebfa0f207..129f5f2dc0 100644
--- a/xls/dslx/tests/errors/umin_type_mismatch.x
+++ b/xls/dslx/tests/errors/umin_type_mismatch.x
@@ -18,5 +18,5 @@ const MY_U32 = u42:42;
 const MY_U8 = u8:42;
 
 fn f() -> u32 {
-  std::umin(MY_U32, MY_U8)
+    std::min(MY_U32, MY_U8)
 }
diff --git a/xls/dslx/tests/xn_slice_bounds.x b/xls/dslx/tests/xn_slice_bounds.x
new file mode 100644
index 0000000000..a584121e5b
--- /dev/null
+++ b/xls/dslx/tests/xn_slice_bounds.x
@@ -0,0 +1,33 @@
+// Copyright 2025 The XLS Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+const S = true;
+const N = u32:32;
+
+type MyS32 = xN[S][N];
+
+fn from_to(x: u32) -> u8 { x[MyS32:0:MyS32:8] }
+
+fn to(x: u32) -> u8 { x[:MyS32:8] }
+
+fn from(x: u32) -> u8 { x[MyS32:-8:] }
+
+fn main(x: u32) -> u8[3] { [from_to(x), to(x), from(x)] }
+
+#[test]
+fn test_main() {
+    assert_eq(from_to(u32:0x12345678), u8:0x78);
+    assert_eq(to(u32:0x12345678), u8:0x78);
+    assert_eq(from(u32:0x12345678), u8:0x12);
+}
diff --git a/xls/dslx/type_system/deduce.cc b/xls/dslx/type_system/deduce.cc
index a90e301313..1aca1cf91e 100644
--- a/xls/dslx/type_system/deduce.cc
+++ b/xls/dslx/type_system/deduce.cc
@@ -1243,48 +1243,52 @@ static absl::StatusOr<std::unique_ptr<Type>> DeduceSliceType(
   ctx->type_info()->AddSliceStartAndWidth(slice, fn_parametric_env, saw);
 
   // Make sure the start and end types match and that the limit fits.
-  std::unique_ptr<Type> start_type;
-  std::unique_ptr<Type> limit_type;
+  std::optional<BitsLikeProperties> start_bits_like;
+  std::optional<BitsLikeProperties> limit_bits_like;
   if (slice->start() == nullptr && slice->limit() == nullptr) {
-    start_type = BitsType::MakeS32();
-    limit_type = BitsType::MakeS32();
+    start_bits_like.emplace(
+        BitsLikeProperties{.is_signed = TypeDim::CreateBool(true),
+                           .size = TypeDim::CreateU32(32)});
+    limit_bits_like.emplace(
+        BitsLikeProperties{.is_signed = TypeDim::CreateBool(true),
+                           .size = TypeDim::CreateU32(32)});
   } else if (slice->start() != nullptr && slice->limit() == nullptr) {
-    XLS_ASSIGN_OR_RETURN(BitsType * tmp,
-                         ctx->type_info()->GetItemAs<BitsType>(slice->start()));
-    start_type = tmp->CloneToUnique();
-    limit_type = start_type->CloneToUnique();
+    std::optional<Type*> start_type = ctx->type_info()->GetItem(slice->start());
+    XLS_RET_CHECK(start_type.has_value());
+    start_bits_like = GetBitsLike(*start_type.value());
+    limit_bits_like.emplace(Clone(start_bits_like.value()));
   } else if (slice->start() == nullptr && slice->limit() != nullptr) {
-    XLS_ASSIGN_OR_RETURN(BitsType * tmp,
-                         ctx->type_info()->GetItemAs<BitsType>(slice->limit()));
-    limit_type = tmp->CloneToUnique();
-    start_type = limit_type->CloneToUnique();
+    std::optional<Type*> limit_type = ctx->type_info()->GetItem(slice->limit());
+    XLS_RET_CHECK(limit_type.has_value());
+    limit_bits_like = GetBitsLike(*limit_type.value());
+    start_bits_like.emplace(Clone(limit_bits_like.value()));
   } else {
-    XLS_ASSIGN_OR_RETURN(BitsType * tmp,
-                         ctx->type_info()->GetItemAs<BitsType>(slice->start()));
-    start_type = tmp->CloneToUnique();
-    XLS_ASSIGN_OR_RETURN(tmp,
-                         ctx->type_info()->GetItemAs<BitsType>(slice->limit()));
-    limit_type = tmp->CloneToUnique();
+    std::optional<Type*> start_type = ctx->type_info()->GetItem(slice->start());
+    XLS_RET_CHECK(start_type.has_value());
+    start_bits_like = GetBitsLike(*start_type.value());
+
+    std::optional<Type*> limit_type = ctx->type_info()->GetItem(slice->limit());
+    XLS_RET_CHECK(limit_type.has_value());
+    limit_bits_like = GetBitsLike(*limit_type.value());
   }
 
-  if (*start_type != *limit_type) {
+  if (*start_bits_like != *limit_bits_like) {
     return TypeInferenceErrorStatus(
-        node->span(), limit_type.get(),
+        node->span(), nullptr,
         absl::StrFormat(
             "Slice limit type (%s) did not match slice start type (%s).",
-            limit_type->ToString(), start_type->ToString()),
+            ToTypeString(*limit_bits_like), ToTypeString(*start_bits_like)),
         ctx->file_table());
   }
-  XLS_ASSIGN_OR_RETURN(TypeDim type_width_dim, start_type->GetTotalBitCount());
+  const TypeDim& type_width_dim = start_bits_like->size;
   XLS_ASSIGN_OR_RETURN(int64_t type_width, type_width_dim.GetAsInt64());
   if (Bits::MinBitCountSigned(saw.start + saw.width) > type_width) {
     return TypeInferenceErrorStatus(
-        node->span(), limit_type.get(),
+        node->span(), nullptr,
         absl::StrFormat("Slice limit does not fit in index type: %d.",
                         saw.start + saw.width),
         ctx->file_table());
   }
-
   return std::make_unique<BitsType>(/*signed=*/false, saw.width);
 }
 
diff --git a/xls/dslx/type_system/type.cc b/xls/dslx/type_system/type.cc
index 97d1171e69..a63e28fed2 100644
--- a/xls/dslx/type_system/type.cc
+++ b/xls/dslx/type_system/type.cc
@@ -1117,6 +1117,16 @@ bool IsBitsLike(const Type& t) {
          IsArrayOfBitsConstructor(t);
 }
 
+std::string ToTypeString(const BitsLikeProperties& properties) {
+  if (properties.is_signed.IsParametric()) {
+    return absl::StrFormat("xN[%s][%s]", properties.is_signed.ToString(),
+                           properties.size.ToString());
+  }
+  bool is_signed = properties.is_signed.GetAsBool().value();
+  return absl::StrFormat("%sN[%s]", is_signed ? "s" : "u",
+                         properties.size.ToString());
+}
+
 std::optional<BitsLikeProperties> GetBitsLike(const Type& t) {
   if (auto* bits_type = dynamic_cast<const BitsType*>(&t);
       bits_type != nullptr) {
diff --git a/xls/dslx/type_system/type.h b/xls/dslx/type_system/type.h
index 91cd44a4e2..4d1fd75ea6 100644
--- a/xls/dslx/type_system/type.h
+++ b/xls/dslx/type_system/type.h
@@ -608,6 +608,8 @@ class StructTypeBase : public Type {
 // things like type comparisons
 class StructType : public StructTypeBase {
  public:
+  static std::string GetDebugName() { return "StructType"; }
+
   StructType(std::vector<std::unique_ptr<Type>> members,
              const StructDef& struct_def,
              absl::flat_hash_map<std::string, TypeDim>
@@ -751,6 +753,8 @@ class TupleType : public Type {
 // These will nest in the case of multidimensional arrays.
 class ArrayType : public Type {
  public:
+  static std::string GetDebugName() { return "ArrayType"; }
+
   ArrayType(std::unique_ptr<Type> element_type, const TypeDim& size);
 
   absl::Status Accept(TypeVisitor& v) const override {
@@ -878,6 +882,8 @@ class BitsConstructorType : public Type {
 // respectively.
 class BitsType : public Type {
  public:
+  static std::string GetDebugName() { return "BitsType"; }
+
   static std::unique_ptr<BitsType> MakeU64() {
     return std::make_unique<BitsType>(false, 64);
   }
@@ -941,6 +947,8 @@ class BitsType : public Type {
 // Represents a function type with params and a return type.
 class FunctionType : public Type {
  public:
+  static std::string GetDebugName() { return "FunctionType"; }
+
   FunctionType(std::vector<std::unique_ptr<Type>> params,
                std::unique_ptr<Type> return_type)
       : params_(std::move(params)), return_type_(std::move(return_type)) {
@@ -1106,6 +1114,15 @@ struct BitsLikeProperties {
   TypeDim size;
 };
 
+// Returns a string representation of the BitsLikeProperties that looks similar
+// to a corresponding BitsType.
+std::string ToTypeString(const BitsLikeProperties& properties);
+
+inline BitsLikeProperties Clone(const BitsLikeProperties& properties) {
+  return BitsLikeProperties{.is_signed = properties.is_signed.Clone(),
+                            .size = properties.size.Clone()};
+}
+
 inline bool operator==(const BitsLikeProperties& a,
                        const BitsLikeProperties& b) {
   return a.is_signed == b.is_signed && a.size == b.size;
diff --git a/xls/dslx/type_system/type_info.h b/xls/dslx/type_system/type_info.h
index 9938ccdc31..f4b692a708 100644
--- a/xls/dslx/type_system/type_info.h
+++ b/xls/dslx/type_system/type_info.h
@@ -30,6 +30,7 @@
 #include "absl/status/statusor.h"
 #include "absl/strings/str_format.h"
 #include "absl/types/variant.h"
+#include "xls/common/status/ret_check.h"
 #include "xls/dslx/frontend/ast.h"
 #include "xls/dslx/frontend/pos.h"
 #include "xls/dslx/interp_value.h"
@@ -403,6 +404,9 @@ class TypeInfo {
 
 template <typename T>
 inline absl::StatusOr<T*> TypeInfo::GetItemAs(const AstNode* key) const {
+  static_assert(std::is_base_of<Type, T>::value,
+                "T must be a subclass of Type");
+
   std::optional<Type*> t = GetItem(key);
   if (!t.has_value()) {
     return absl::NotFoundError(
@@ -411,11 +415,11 @@ inline absl::StatusOr<T*> TypeInfo::GetItemAs(const AstNode* key) const {
   }
   DCHECK(t.value() != nullptr);
   auto* target = dynamic_cast<T*>(t.value());
-  if (target == nullptr) {
-    return absl::FailedPreconditionError(absl::StrFormat(
-        "AST node (%s) @ %s did not have expected Type subtype.",
-        key->GetNodeTypeName(), SpanToString(key->GetSpan(), file_table())));
-  }
+  XLS_RET_CHECK(target != nullptr) << absl::StreamFormat(
+      "AST node `%s` @ %s did not have expected `xls::dslx::Type` subtype; "
+      "want: %s got: %s",
+      key->ToString(), SpanToString(key->GetSpan(), file_table()),
+      T::GetDebugName(), t.value()->GetDebugTypeName());
   return target;
 }
 
diff --git a/xls/modules/zstd/dec_mux.x b/xls/modules/zstd/dec_mux.x
index 59778ff304..a877cc4aa5 100644
--- a/xls/modules/zstd/dec_mux.x
+++ b/xls/modules/zstd/dec_mux.x
@@ -117,7 +117,7 @@ pub proc DecoderMux {
         let all_valid = state.raw_data_valid && state.rle_data_valid && state.compressed_data_valid;
 
         let state = if (any_valid) {
-            let min_id = std::umin(std::umin(rle_id, raw_id), compressed_id);
+            let min_id = std::min(std::min(rle_id, raw_id), compressed_id);
             trace_fmt!("DecoderMux: rle_id: {}, raw_id: {}, compressed_id: {}", rle_id, raw_id, compressed_id);
             trace_fmt!("DecoderMux: min_id: {}", min_id);
 
diff --git a/xls/modules/zstd/memory/axi_reader.x b/xls/modules/zstd/memory/axi_reader.x
index 02ea24aff7..43504eea30 100644
--- a/xls/modules/zstd/memory/axi_reader.x
+++ b/xls/modules/zstd/memory/axi_reader.x
@@ -139,7 +139,7 @@ pub proc AxiReader<
 
         let bytes_to_max_burst = MAX_AXI_BURST_BYTES - aligned_offset as Length;
         let bytes_to_4k = common::bytes_to_4k_boundary(state.tran_addr);
-        let tran_len = std::umin(state.tran_len, std::umin(bytes_to_4k, bytes_to_max_burst));
+        let tran_len = std::min(state.tran_len, std::min(bytes_to_4k, bytes_to_max_burst));
         let (req_low_lane, req_high_lane) = common::get_lanes<DATA_W_DIV8>(state.tran_addr, tran_len);
 
         let adjusted_tran_len = aligned_offset as Addr + tran_len;
diff --git a/xls/modules/zstd/memory/axi_writer.x b/xls/modules/zstd/memory/axi_writer.x
index 982c444bcf..2f62307731 100644
--- a/xls/modules/zstd/memory/axi_writer.x
+++ b/xls/modules/zstd/memory/axi_writer.x
@@ -164,7 +164,7 @@ pub proc AxiWriter<
                 }
             },
             Fsm::TRANSFER_LENGTH => {
-                let tran_len = std::umin(state.transfer_data.length, std::umin(state.bytes_to_4k, state.bytes_to_max_axi_burst));
+                let tran_len = std::min(state.transfer_data.length, std::min(state.bytes_to_4k, state.bytes_to_max_axi_burst));
                 State {
                     fsm: Fsm::CALC_NEXT_TRANSFER,
                     transaction_len: tran_len,
diff --git a/xls/modules/zstd/sequence_executor.x b/xls/modules/zstd/sequence_executor.x
index 422185d236..a1fea91d50 100644
--- a/xls/modules/zstd/sequence_executor.x
+++ b/xls/modules/zstd/sequence_executor.x
@@ -482,7 +482,7 @@ fn sequence_packet_to_read_reqs
     -> (ram::ReadReq<RAM_ADDR_WIDTH, RAM_NUM_PARTITIONS>[RAM_NUM], RamOrder[RAM_NUM], SequenceExecutorPacket, bool) {
     type ReadReq = ram::ReadReq<RAM_ADDR_WIDTH, RAM_NUM_PARTITIONS>;
 
-    let max_len = std::umin(seq.length as u32, std::umin(RAM_NUM, hb_len));
+    let max_len = std::min(seq.length as u32, std::min(RAM_NUM, hb_len));
 
     let (next_seq, next_seq_valid) = if seq.length > max_len as CopyOrMatchLength {
         (
diff --git a/xls/modules/zstd/zstd_dec.x b/xls/modules/zstd/zstd_dec.x
index 259361de8f..0f9fac906e 100644
--- a/xls/modules/zstd/zstd_dec.x
+++ b/xls/modules/zstd/zstd_dec.x
@@ -179,7 +179,7 @@ fn feed_block_decoder(state: ZstdDecoderState) -> (bool, BlockDataPacket, ZstdDe
     trace_fmt!("zstd_dec: feed_block_decoder: buffer_length_bytes: {}", buffer_length_bytes);
     let data_width_bytes = (DATA_WIDTH >> 3) as BlockSize;
     trace_fmt!("zstd_dec: feed_block_decoder: data_width_bytes: {}", data_width_bytes);
-    let remaining_bytes_to_send_now = std::umin(remaining_bytes_to_send, data_width_bytes);
+    let remaining_bytes_to_send_now = std::min(remaining_bytes_to_send, data_width_bytes);
     trace_fmt!("zstd_dec: feed_block_decoder: remaining_bytes_to_send_now: {}", remaining_bytes_to_send_now);
     if (buffer_length_bytes >= remaining_bytes_to_send_now as u32) {
         let remaining_bits_to_send_now = (remaining_bytes_to_send_now as u32) << 3;

From 9e985ac251ee96737cbce5b1a4fd6d9afa63dfce Mon Sep 17 00:00:00 2001
From: Chris Leary <cdleary@openai.com>
Date: Thu, 2 Jan 2025 18:55:34 -0800
Subject: [PATCH 2/2] Add test for widening cast of max now that it produces an
 xN.

---
 xls/dslx/tests/BUILD                         |  2 ++
 xls/dslx/tests/xn_widening_cast.x            | 17 +++++++++++++++
 xls/dslx/type_system/typecheck_invocation.cc | 22 +++++++++++---------
 3 files changed, 31 insertions(+), 10 deletions(-)
 create mode 100644 xls/dslx/tests/xn_widening_cast.x

diff --git a/xls/dslx/tests/BUILD b/xls/dslx/tests/BUILD
index e391052f27..18b70cc45f 100644
--- a/xls/dslx/tests/BUILD
+++ b/xls/dslx/tests/BUILD
@@ -187,6 +187,8 @@ dslx_lang_test(name = "xn_signedness_properties")
 
 dslx_lang_test(name = "xn_slice_bounds")
 
+dslx_lang_test(name = "xn_widening_cast")
+
 dslx_lang_test(
     name = "parametric_shift",
     # TODO(leary): 2023-08-14 Runs into "cannot translate zero length bitvector
diff --git a/xls/dslx/tests/xn_widening_cast.x b/xls/dslx/tests/xn_widening_cast.x
new file mode 100644
index 0000000000..ef06644bc9
--- /dev/null
+++ b/xls/dslx/tests/xn_widening_cast.x
@@ -0,0 +1,17 @@
+// Copyright 2025 The XLS Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import std;
+
+fn main(x: u7, y: u7) -> u32 { widening_cast<u32>(std::max(x, y)) }
diff --git a/xls/dslx/type_system/typecheck_invocation.cc b/xls/dslx/type_system/typecheck_invocation.cc
index 197100ecde..b16797d97b 100644
--- a/xls/dslx/type_system/typecheck_invocation.cc
+++ b/xls/dslx/type_system/typecheck_invocation.cc
@@ -98,10 +98,12 @@ static absl::Status TypecheckIsAcceptableWideningCast(DeduceCtx* ctx,
   XLS_RET_CHECK(maybe_from_type.has_value());
   XLS_RET_CHECK(maybe_to_type.has_value());
 
-  BitsType* from = dynamic_cast<BitsType*>(maybe_from_type.value());
-  BitsType* to = dynamic_cast<BitsType*>(maybe_to_type.value());
+  std::optional<BitsLikeProperties> from_bits_like =
+      GetBitsLike(*maybe_from_type.value());
+  std::optional<BitsLikeProperties> to_bits_like =
+      GetBitsLike(*maybe_to_type.value());
 
-  if (from == nullptr || to == nullptr) {
+  if (!from_bits_like.has_value() || !to_bits_like.has_value()) {
     return ctx->TypeMismatchError(
         node->span(), from_expr, *maybe_from_type.value(), node,
         *maybe_to_type.value(),
@@ -110,13 +112,13 @@ static absl::Status TypecheckIsAcceptableWideningCast(DeduceCtx* ctx,
                         maybe_to_type.value()->ToErrorString()));
   }
 
-  bool signed_input = from->is_signed();
-  bool signed_output = to->is_signed();
+  XLS_ASSIGN_OR_RETURN(bool signed_input,
+                       from_bits_like->is_signed.GetAsBool());
+  XLS_ASSIGN_OR_RETURN(bool signed_output, to_bits_like->is_signed.GetAsBool());
 
   XLS_ASSIGN_OR_RETURN(int64_t old_bit_count,
-                       from->GetTotalBitCount().value().GetAsInt64());
-  XLS_ASSIGN_OR_RETURN(int64_t new_bit_count,
-                       to->GetTotalBitCount().value().GetAsInt64());
+                       from_bits_like->size.GetAsInt64());
+  XLS_ASSIGN_OR_RETURN(int64_t new_bit_count, to_bits_like->size.GetAsInt64());
 
   bool can_cast =
       ((signed_input == signed_output) && (new_bit_count >= old_bit_count)) ||
@@ -128,8 +130,8 @@ static absl::Status TypecheckIsAcceptableWideningCast(DeduceCtx* ctx,
         *maybe_to_type.value(),
         absl::StrFormat("Can not cast from type %s (%d bits) to"
                         " %s (%d bits) with widening_cast",
-                        from->ToString(), old_bit_count, to->ToString(),
-                        new_bit_count));
+                        ToTypeString(from_bits_like.value()), old_bit_count,
+                        ToTypeString(to_bits_like.value()), new_bit_count));
   }
 
   return absl::OkStatus();