From 709457d406f7cf05e6bdbac3bcbbbebf22e9656a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dominik=20Inf=C3=BChr?= Date: Wed, 12 Feb 2025 20:39:13 +0100 Subject: [PATCH] frontend: Support generic method calls with type args --- dora-bytecode/src/display.rs | 74 ++++++++++++++++--- dora-bytecode/src/lib.rs | 3 +- dora-bytecode/src/program.rs | 1 + dora-bytecode/src/ty.rs | 15 ++++ dora-frontend/src/program_emitter.rs | 8 +- dora-frontend/src/typeck/call.rs | 57 ++++++++------ dora-frontend/src/typeck/tests.rs | 63 ++++++++++++++++ dora-runtime/src/cannon/codegen.rs | 12 ++- dora-runtime/src/vm/impls.rs | 1 + dora-runtime/src/vm/specialize.rs | 1 + tests/generic/generic-fct-type-args1.dora | 18 +++++ .../generic-static-fct-type-args1.dora | 17 +++++ 12 files changed, 235 insertions(+), 35 deletions(-) create mode 100644 tests/generic/generic-fct-type-args1.dora create mode 100644 tests/generic/generic-static-fct-type-args1.dora diff --git a/dora-bytecode/src/display.rs b/dora-bytecode/src/display.rs index ab0520ec8..739dbe564 100644 --- a/dora-bytecode/src/display.rs +++ b/dora-bytecode/src/display.rs @@ -139,6 +139,18 @@ pub fn display_ty_without_type_params(prog: &Program, ty: &BytecodeType) -> Stri printer.string() } +pub fn fmt_ty<'a>( + prog: &'a Program, + ty: &'a BytecodeType, + type_params: TypeParamMode<'a>, +) -> BytecodeTypePrinter<'a> { + BytecodeTypePrinter { + prog, + type_params, + ty: ty.clone(), + } +} + pub fn fmt_ty_with_type_params<'a>( prog: &'a Program, ty: &'a BytecodeType, @@ -151,6 +163,30 @@ pub fn fmt_ty_with_type_params<'a>( } } +pub fn fmt_ty_array<'a>( + prog: &'a Program, + array: &'a BytecodeTypeArray, + type_params: TypeParamMode<'a>, +) -> BytecodeTypeArrayPrinter<'a> { + BytecodeTypeArrayPrinter { + prog, + type_params, + array, + } +} + +pub fn fmt_trait_ty<'a>( + prog: &'a Program, + trait_ty: &'a BytecodeTraitType, + type_params: TypeParamMode<'a>, +) -> BytecodeTraitTypePrinter<'a> { + BytecodeTraitTypePrinter { + prog, + type_params, + trait_ty, + } +} + pub fn display_ty_with_type_params( prog: &Program, ty: &BytecodeType, @@ -166,7 +202,7 @@ pub fn fmt_trait_ty_with_type_params<'a>( ) -> BytecodeTraitTypePrinter<'a> { BytecodeTraitTypePrinter { prog, - type_params, + type_params: TypeParamMode::TypeParams(type_params), trait_ty, } } @@ -179,7 +215,8 @@ pub fn display_trait_ty_with_type_params( fmt_trait_ty_with_type_params(prog, trait_ty, type_params).string() } -enum TypeParamMode<'a> { +#[derive(Clone)] +pub enum TypeParamMode<'a> { None, Unknown, TypeParams(&'a TypeParamData), @@ -296,9 +333,32 @@ impl<'a> std::fmt::Display for BytecodeTypePrinter<'a> { } } +pub struct BytecodeTypeArrayPrinter<'a> { + prog: &'a Program, + type_params: TypeParamMode<'a>, + array: &'a BytecodeTypeArray, +} + +impl<'a> std::fmt::Display for BytecodeTypeArrayPrinter<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "[")?; + + let mut first = true; + for ty in self.array.iter() { + if !first { + write!(f, ", ")?; + } + write!(f, "{}", fmt_ty(&self.prog, &ty, self.type_params.clone()))?; + first = false; + } + + write!(f, "]") + } +} + pub struct BytecodeTraitTypePrinter<'a> { prog: &'a Program, - type_params: &'a TypeParamData, + type_params: TypeParamMode<'a>, trait_ty: &'a BytecodeTraitType, } @@ -325,11 +385,7 @@ impl<'a> BytecodeTraitTypePrinter<'a> { write!(fmt, ", ")?; } - write!( - fmt, - "{}", - fmt_ty_with_type_params(&self.prog, &ty, self.type_params) - )?; + write!(fmt, "{}", fmt_ty(&self.prog, &ty, self.type_params.clone()))?; first = false; } @@ -343,7 +399,7 @@ impl<'a> BytecodeTraitTypePrinter<'a> { fmt, "{}={}", alias.name, - fmt_ty_with_type_params(&self.prog, &ty, self.type_params) + fmt_ty(&self.prog, &ty, self.type_params.clone()) )?; first = false; } diff --git a/dora-bytecode/src/lib.rs b/dora-bytecode/src/lib.rs index aeaf98254..cfcc49487 100644 --- a/dora-bytecode/src/lib.rs +++ b/dora-bytecode/src/lib.rs @@ -12,7 +12,8 @@ mod tests; pub use data::*; pub use display::{ display_fct, display_ty, display_ty_array, display_ty_with_type_params, - display_ty_without_type_params, module_path, module_path_name, + display_ty_without_type_params, fmt_trait_ty, fmt_ty, fmt_ty_array, module_path, + module_path_name, TypeParamMode, }; pub use dumper::{dump, dump_stdout}; pub use program::{ diff --git a/dora-bytecode/src/program.rs b/dora-bytecode/src/program.rs index bb811262c..b5caa4100 100644 --- a/dora-bytecode/src/program.rs +++ b/dora-bytecode/src/program.rs @@ -106,6 +106,7 @@ pub struct StructField { #[derive(Debug, Decode, Encode)] pub struct TypeParamData { pub names: Vec, + pub container_count: usize, pub bounds: Vec, } diff --git a/dora-bytecode/src/ty.rs b/dora-bytecode/src/ty.rs index c19268769..993135320 100644 --- a/dora-bytecode/src/ty.rs +++ b/dora-bytecode/src/ty.rs @@ -230,6 +230,21 @@ impl BytecodeTypeArray { BytecodeTypeArray::new(types) } + pub fn connect(&self, rhs: &BytecodeTypeArray) -> BytecodeTypeArray { + let mut types = self.0.to_vec(); + types.append(&mut rhs.to_vec()); + BytecodeTypeArray::new(types) + } + + pub fn split(&self, elements: usize) -> (BytecodeTypeArray, BytecodeTypeArray) { + let (first, second) = self.0.split_at(elements); + + ( + BytecodeTypeArray::new(first.to_vec()), + BytecodeTypeArray::new(second.to_vec()), + ) + } + pub fn is_concrete_type(&self) -> bool { self.0.iter().all(|ty| ty.is_concrete_type()) } diff --git a/dora-frontend/src/program_emitter.rs b/dora-frontend/src/program_emitter.rs index 0153536d9..3412b99a4 100644 --- a/dora-frontend/src/program_emitter.rs +++ b/dora-frontend/src/program_emitter.rs @@ -347,7 +347,13 @@ fn create_type_params(sa: &Sema, type_params: &TypeParamDefinition) -> TypeParam }) .collect(); - TypeParamData { names, bounds } + let container_count = type_params.container_type_params(); + + TypeParamData { + names, + container_count, + bounds, + } } fn create_struct_fields(sa: &Sema, struct_: &StructDefinition) -> Vec { diff --git a/dora-frontend/src/typeck/call.rs b/dora-frontend/src/typeck/call.rs index bd83977b9..8700cbd56 100644 --- a/dora-frontend/src/typeck/call.rs +++ b/dora-frontend/src/typeck/call.rs @@ -129,32 +129,45 @@ fn check_expr_call_generic_static_method( let trait_method = ck.sa.fct(trait_method_id); let tp = SourceType::TypeParam(tp_id); - let fct_type_params = trait_ty.type_params.connect(&pure_fct_type_params); + let combined_fct_type_params = trait_ty.type_params.connect(&pure_fct_type_params); - check_args_compatible_fct2(ck, trait_method, arguments, |ty| { - specialize_ty_for_generic(ck.sa, ty, tp_id, &trait_ty, &fct_type_params, &tp) - }); - - let call_type = CallType::GenericStaticMethod( - tp_id, - trait_ty.trait_id, - trait_method_id, - trait_ty.type_params.clone(), - ); - ck.analysis.map_calls.insert(e.id, Arc::new(call_type)); - - let return_type = specialize_ty_for_generic( + if check_type_params( ck.sa, - trait_method.return_type(), - tp_id, - &trait_ty, - &fct_type_params, - &tp, - ); + ck.element, + &ck.type_param_definition, + trait_method, + &combined_fct_type_params, + ck.file_id, + e.span, + |ty| specialize_ty_for_generic(ck.sa, ty, tp_id, &trait_ty, &combined_fct_type_params, &tp), + ) { + check_args_compatible_fct2(ck, trait_method, arguments, |ty| { + specialize_ty_for_generic(ck.sa, ty, tp_id, &trait_ty, &combined_fct_type_params, &tp) + }); + + let call_type = CallType::GenericStaticMethod( + tp_id, + trait_ty.trait_id, + trait_method_id, + combined_fct_type_params.clone(), + ); + ck.analysis.map_calls.insert(e.id, Arc::new(call_type)); - ck.analysis.set_ty(e.id, return_type.clone()); + let return_type = specialize_ty_for_generic( + ck.sa, + trait_method.return_type(), + tp_id, + &trait_ty, + &combined_fct_type_params, + &tp, + ); - return_type + ck.analysis.set_ty(e.id, return_type.clone()); + + return_type + } else { + SourceType::Error + } } fn check_expr_call_expr( diff --git a/dora-frontend/src/typeck/tests.rs b/dora-frontend/src/typeck/tests.rs index a10b12436..63127f8a3 100644 --- a/dora-frontend/src/typeck/tests.rs +++ b/dora-frontend/src/typeck/tests.rs @@ -1119,6 +1119,69 @@ fn trait_method_call_with_function_type_params_invalid_param() { ); } +#[test] +fn static_trait_method_call_with_function_type_params() { + ok(" + trait Foo { + static fn bar[T](x: T); + } + + fn f[T: Foo](x: T) { + T::bar[Int](1); + } + "); +} + +#[test] +fn static_trait_method_call_with_function_type_params_and_missing_params() { + err( + " + trait Foo { + static fn bar[T](x: T); + } + + fn f[T: Foo](x: T) { + T::bar(1); + } + ", + (7, 13), + ErrorMessage::WrongNumberTypeParams(1, 0), + ); + + err( + " + trait Foo[X] { + static fn bar[T](x: T); + } + + fn f[T: Foo[Int]](x: T) { + T::bar(1); + } + ", + (7, 13), + ErrorMessage::WrongNumberTypeParams(1, 0), + ); +} + +#[test] +fn static_trait_method_call_with_function_type_params_invalid_param() { + err( + " + trait Foo { + static fn bar[T: Bar](x: T); + } + + trait Bar {} + + fn f[T: Foo](x: T) { + T::bar[Int](1); + } + ", + (9, 13), + ErrorMessage::TypeNotImplementingTrait("Int64".into(), "Bar".into()), + ); +} + #[test] fn test_generic_ctor_without_type_params() { err( diff --git a/dora-runtime/src/cannon/codegen.rs b/dora-runtime/src/cannon/codegen.rs index 84fdf4e8d..fbff00ef2 100644 --- a/dora-runtime/src/cannon/codegen.rs +++ b/dora-runtime/src/cannon/codegen.rs @@ -2657,17 +2657,25 @@ impl<'a> CannonCodeGen<'a> { FunctionKind::Trait(trait_id) => trait_id, _ => unreachable!(), }; + + let type_params = self.specialize_ty_array(&type_params); + assert!(type_params.is_concrete_type()); + let (trait_type_params, pure_fct_type_params) = + type_params.split(fct.type_params.container_count); + let trait_ty = BytecodeTraitType { trait_id, - type_params: type_params.clone(), + type_params: trait_type_params, bindings: Vec::new(), }; - let (callee_id, type_params) = find_trait_impl(self.vm, trait_fct_id, trait_ty, ty); + let (callee_id, container_bindings) = find_trait_impl(self.vm, trait_fct_id, trait_ty, ty); let pos = self.bytecode.offset_location(self.current_offset.to_u32()); let arguments = self.argument_stack.drain(..).collect::>(); + let type_params = container_bindings.connect(&pure_fct_type_params); + if is_static { self.emit_invoke_static_or_intrinsic(dest, callee_id, type_params, arguments, pos); } else { diff --git a/dora-runtime/src/vm/impls.rs b/dora-runtime/src/vm/impls.rs index f59c4349b..4bc730960 100644 --- a/dora-runtime/src/vm/impls.rs +++ b/dora-runtime/src/vm/impls.rs @@ -14,6 +14,7 @@ pub fn find_trait_impl( let type_param_data = TypeParamData { names: Vec::new(), + container_count: 0, bounds: Vec::new(), }; diff --git a/dora-runtime/src/vm/specialize.rs b/dora-runtime/src/vm/specialize.rs index 98afe27d5..68413c1d1 100644 --- a/dora-runtime/src/vm/specialize.rs +++ b/dora-runtime/src/vm/specialize.rs @@ -703,6 +703,7 @@ pub fn specialize_ty( let type_param_data = TypeParamData { names: Vec::new(), + container_count: 0, bounds: Vec::new(), }; diff --git a/tests/generic/generic-fct-type-args1.dora b/tests/generic/generic-fct-type-args1.dora new file mode 100644 index 000000000..46db5e09e --- /dev/null +++ b/tests/generic/generic-fct-type-args1.dora @@ -0,0 +1,18 @@ +trait Foo { + fn bar[T](x: T); +} + +class Bar + +impl Foo for Bar { + fn bar[T](x: T) {} +} + +fn myf[T: Foo](x: T) { + x.bar[Int](1); +} + +fn main() { + let bar = Bar(); + myf[Bar](bar); +} diff --git a/tests/generic/generic-static-fct-type-args1.dora b/tests/generic/generic-static-fct-type-args1.dora new file mode 100644 index 000000000..d24a8f8a4 --- /dev/null +++ b/tests/generic/generic-static-fct-type-args1.dora @@ -0,0 +1,17 @@ +trait Foo { + static fn bar[T](x: T); +} + +class Bar + +impl Foo for Bar { + static fn bar[T](x: T) {} +} + +fn myf[T: Foo]() { + T::bar[Int](1); +} + +fn main() { + myf[Bar](); +}