From af1d12b1f418dbfe039c8644904a535f53308193 Mon Sep 17 00:00:00 2001 From: ncave <777696+ncave@users.noreply.github.com> Date: Sat, 20 Jan 2024 04:24:15 -0800 Subject: [PATCH] [Rust] Added support for generic comparers --- src/Fable.Cli/CHANGELOG.md | 2 + src/Fable.Transforms/Rust/Fable2Rust.fs | 15 +++-- .../src/System.Collections.Generic.fs | 40 +++++++++--- tests/Js/Main/ComparisonTests.fs | 21 +++++-- tests/Rust/tests/src/ComparisonTests.fs | 61 +++++++++++-------- 5 files changed, 96 insertions(+), 43 deletions(-) diff --git a/src/Fable.Cli/CHANGELOG.md b/src/Fable.Cli/CHANGELOG.md index 529a00a7c5..93fefbaad2 100644 --- a/src/Fable.Cli/CHANGELOG.md +++ b/src/Fable.Cli/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 #### Javascript * Fixed 'System.Collections.Generic.Queue' bug (by @PierreYvesR) +* Fixed instance calls for generic comparers (by @ncave) #### Python @@ -30,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Fixed generic try_catch closure trait (by @ncave) * Fixed `self` arg capture in methods (by @ncave) * Fixed 'System.Collections.Generic.Queue' bug (by @PierreYvesR) +* Added support for generic comparers (by @ncave) ## 4.9.0 - 2023-12-14 diff --git a/src/Fable.Transforms/Rust/Fable2Rust.fs b/src/Fable.Transforms/Rust/Fable2Rust.fs index 983f2bb7de..169c04ff00 100644 --- a/src/Fable.Transforms/Rust/Fable2Rust.fs +++ b/src/Fable.Transforms/Rust/Fable2Rust.fs @@ -980,12 +980,19 @@ module TypeInfo = genArgs : Rust.Ty = - let nameParts = - getAbstractClassImportName com ctx entRef |> splitNameParts - + let entName = getAbstractClassImportName com ctx entRef + let nameParts = entName |> splitNameParts let genArgsOpt = transformGenArgs com ctx genArgs let traitBound = mkTypeTraitGenericBound nameParts genArgsOpt - mkDynTraitTy [ traitBound ] + + match entRef.FullName with + | "System.Collections.Generic.Comparer`1" + | "System.Collections.Generic.EqualityComparer`1" -> + // some abstract classes are implemented as non-abstract + makeFullNamePathTy entName genArgsOpt + | _ -> + // abstract classes implemented as interfaces + mkDynTraitTy [ traitBound ] let (|HasEmitAttribute|_|) (ent: Fable.Entity) = ent.Attributes diff --git a/src/fable-library-rust/src/System.Collections.Generic.fs b/src/fable-library-rust/src/System.Collections.Generic.fs index c9436fcf54..2d766885ab 100644 --- a/src/fable-library-rust/src/System.Collections.Generic.fs +++ b/src/fable-library-rust/src/System.Collections.Generic.fs @@ -2,16 +2,36 @@ namespace System.Collections.Generic open Global_ -// type Comparer<'T when 'T: comparison>() = -// static member Default = Comparer<'T>() -// interface IComparer<'T> with -// member _.Compare(x, y) = LanguagePrimitives.GenericComparison x y - -// type EqualityComparer<'T when 'T: equality>() = -// static member Default = EqualityComparer<'T>() -// interface IEqualityComparer<'T> with -// member _.Equals(x, y) = LanguagePrimitives.GenericEquality x y -// member _.GetHashCode(x) = LanguagePrimitives.GenericHash x +type Comparer<'T when 'T: comparison>(comparison: 'T -> 'T -> int) = + + static member Default = Comparer<'T>(LanguagePrimitives.GenericComparison) + + static member Create(comparison) = Comparer<'T>(comparison) + + member _.Compare(x, y) = comparison x y + + interface IComparer<'T> with + member _.Compare(x, y) = comparison x y + +type EqualityComparer<'T when 'T: equality> + (equals: 'T -> 'T -> bool, getHashCode: 'T -> int) + = + + static member Default = + EqualityComparer<'T>( + LanguagePrimitives.GenericEquality, + LanguagePrimitives.GenericHash + ) + + static member Create(equals, getHashCode) = + EqualityComparer<'T>(equals, getHashCode) + + member _.Equals(x, y) = equals x y + member _.GetHashCode(x) = getHashCode x + + interface IEqualityComparer<'T> with + member _.Equals(x, y) = equals x y + member _.GetHashCode(x) = getHashCode x type Stack<'T when 'T: equality> private (initialContents: 'T[], initialCount) = let mutable contents = initialContents diff --git a/tests/Js/Main/ComparisonTests.fs b/tests/Js/Main/ComparisonTests.fs index 00baadbe97..7ecf1db49e 100644 --- a/tests/Js/Main/ComparisonTests.fs +++ b/tests/Js/Main/ComparisonTests.fs @@ -85,17 +85,17 @@ type FuzzyInt = x - 2 <= y && y <= x + 2 | _ -> false -let genericEquals (a:'T) (b:'T) : bool = +let genericEquals (a: 'T) (b: 'T) : bool = let cmp = EqualityComparer<'T>.Default - cmp.Equals(a,b) + cmp.Equals(a, b) -let genericHash (x:'T) : int = +let genericHash (x: 'T) : int = let cmp = EqualityComparer<'T>.Default cmp.GetHashCode(x) -let genericCompare (a:'T) (b:'T) : int = +let genericCompare (a: 'T) (b: 'T) : int = let cmp = Comparer<'T>.Default - cmp.Compare(a,b) + cmp.Compare(a, b) let tests = testList "Comparison" [ @@ -714,6 +714,11 @@ let tests = let distance: decimal = LanguagePrimitives.DecimalWithMeasure 1.0m distance |> equal 1.0m + testCase "EqualityComparer.Create works" <| fun () -> + let cmp = EqualityComparer<'T>.Create((<>), hash) + cmp.Equals(1, 1) |> equal false + cmp.Equals(1, 2) |> equal true + testCase "EqualityComparer.Equals works" <| fun () -> genericEquals 1 1 |> equal true genericEquals 1 2 |> equal false @@ -729,4 +734,10 @@ let tests = genericCompare 1 2 |> equal -1 genericCompare 2 1 |> equal 1 + testCase "Comparer.Create works" <| fun () -> + let cmp = Comparer<'T>.Create(fun x y -> -(compare x y)) + cmp.Compare(1, 1) |> equal 0 + cmp.Compare(1, 2) |> equal 1 + cmp.Compare(2, 1) |> equal -1 + ] \ No newline at end of file diff --git a/tests/Rust/tests/src/ComparisonTests.fs b/tests/Rust/tests/src/ComparisonTests.fs index 164672fa80..57f64c6957 100644 --- a/tests/Rust/tests/src/ComparisonTests.fs +++ b/tests/Rust/tests/src/ComparisonTests.fs @@ -81,17 +81,17 @@ type MyClass(v) = // x - 2 <= y && y <= x + 2 // | _ -> false -// let genericEquals (a:'T) (b:'T) : bool = -// let cmp = EqualityComparer<'T>.Default -// cmp.Equals(a,b) +let genericEquals<'T when 'T: equality> (a: 'T) (b: 'T) : bool = + let cmp = EqualityComparer<'T>.Default + cmp.Equals(a, b) -// let genericHash (x:'T) : int = -// let cmp = EqualityComparer<'T>.Default -// cmp.GetHashCode(x) +let genericHash<'T when 'T: equality> (x: 'T) : int = + let cmp = EqualityComparer<'T>.Default + cmp.GetHashCode(x) -// let genericCompare (a:'T) (b:'T) : int = -// let cmp = Comparer<'T>.Default -// cmp.Compare(a,b) +let genericCompare<'T when 'T: comparison> (a: 'T) (b: 'T) : int = + let cmp = Comparer<'T>.Default + cmp.Compare(a, b) [] let ``Typed array equality works`` () = @@ -761,20 +761,33 @@ let ``LanguagePrimitives.DecimalWithMeasure works`` () = let distance: decimal = LanguagePrimitives.DecimalWithMeasure 1.0m distance |> equal 1.0m -// [] -// let ``EqualityComparer.Equals works`` () = -// genericEquals 1 1 |> equal true -// genericEquals 1 2 |> equal false -// genericEquals "1" "1" |> equal true -// genericEquals "1" "2" |> equal false +[] +let ``EqualityComparer.Create works`` () = + let cmp = EqualityComparer<'T>.Create((<>), hash) + cmp.Equals(1, 1) |> equal false + cmp.Equals(1, 2) |> equal true -// [] -// let ``EqualityComparer.GetHashCode works`` () = -// genericHash 1 |> equal ((1).GetHashCode()) -// genericHash "1" |> equal ("1".GetHashCode()) +[] +let ``EqualityComparer.Equals works`` () = + genericEquals 1 1 |> equal true + genericEquals 1 2 |> equal false + genericEquals "1" "1" |> equal true + genericEquals "1" "2" |> equal false -// [] -// let ``Comparer.Compare works`` () = -// genericCompare 1 1 |> equal 0 -// genericCompare 1 2 |> equal -1 -// genericCompare 2 1 |> equal 1 +[] +let ``EqualityComparer.GetHashCode works`` () = + genericHash 1 |> equal ((1).GetHashCode()) + genericHash "1" |> equal ("1".GetHashCode()) + +[] +let ``Comparer.Compare works`` () = + genericCompare 1 1 |> equal 0 + genericCompare 1 2 |> equal -1 + genericCompare 2 1 |> equal 1 + +[] +let ``Comparer.Create works`` () = + let cmp = Comparer<'T>.Create(fun x y -> -(compare x y)) + cmp.Compare(1, 1) |> equal 0 + cmp.Compare(1, 2) |> equal 1 + cmp.Compare(2, 1) |> equal -1