Skip to content

Commit

Permalink
frontend: Support generic method calls with type args
Browse files Browse the repository at this point in the history
  • Loading branch information
dinfuehr committed Feb 12, 2025
1 parent d572a59 commit 709457d
Show file tree
Hide file tree
Showing 12 changed files with 235 additions and 35 deletions.
74 changes: 65 additions & 9 deletions dora-bytecode/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,18 @@ pub fn display_ty_without_type_params(prog: &Program, ty: &BytecodeType) -> Stri
printer.string()
}

pub fn fmt_ty<'a>(
prog: &'a Program,
ty: &'a BytecodeType,
type_params: TypeParamMode<'a>,
) -> BytecodeTypePrinter<'a> {
BytecodeTypePrinter {
prog,
type_params,
ty: ty.clone(),
}
}

pub fn fmt_ty_with_type_params<'a>(
prog: &'a Program,
ty: &'a BytecodeType,
Expand All @@ -151,6 +163,30 @@ pub fn fmt_ty_with_type_params<'a>(
}
}

pub fn fmt_ty_array<'a>(
prog: &'a Program,
array: &'a BytecodeTypeArray,
type_params: TypeParamMode<'a>,
) -> BytecodeTypeArrayPrinter<'a> {
BytecodeTypeArrayPrinter {
prog,
type_params,
array,
}
}

pub fn fmt_trait_ty<'a>(
prog: &'a Program,
trait_ty: &'a BytecodeTraitType,
type_params: TypeParamMode<'a>,
) -> BytecodeTraitTypePrinter<'a> {
BytecodeTraitTypePrinter {
prog,
type_params,
trait_ty,
}
}

pub fn display_ty_with_type_params(
prog: &Program,
ty: &BytecodeType,
Expand All @@ -166,7 +202,7 @@ pub fn fmt_trait_ty_with_type_params<'a>(
) -> BytecodeTraitTypePrinter<'a> {
BytecodeTraitTypePrinter {
prog,
type_params,
type_params: TypeParamMode::TypeParams(type_params),
trait_ty,
}
}
Expand All @@ -179,7 +215,8 @@ pub fn display_trait_ty_with_type_params(
fmt_trait_ty_with_type_params(prog, trait_ty, type_params).string()
}

enum TypeParamMode<'a> {
#[derive(Clone)]
pub enum TypeParamMode<'a> {
None,
Unknown,
TypeParams(&'a TypeParamData),
Expand Down Expand Up @@ -296,9 +333,32 @@ impl<'a> std::fmt::Display for BytecodeTypePrinter<'a> {
}
}

pub struct BytecodeTypeArrayPrinter<'a> {
prog: &'a Program,
type_params: TypeParamMode<'a>,
array: &'a BytecodeTypeArray,
}

impl<'a> std::fmt::Display for BytecodeTypeArrayPrinter<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "[")?;

let mut first = true;
for ty in self.array.iter() {
if !first {
write!(f, ", ")?;
}
write!(f, "{}", fmt_ty(&self.prog, &ty, self.type_params.clone()))?;
first = false;
}

write!(f, "]")
}
}

pub struct BytecodeTraitTypePrinter<'a> {
prog: &'a Program,
type_params: &'a TypeParamData,
type_params: TypeParamMode<'a>,
trait_ty: &'a BytecodeTraitType,
}

Expand All @@ -325,11 +385,7 @@ impl<'a> BytecodeTraitTypePrinter<'a> {
write!(fmt, ", ")?;
}

write!(
fmt,
"{}",
fmt_ty_with_type_params(&self.prog, &ty, self.type_params)
)?;
write!(fmt, "{}", fmt_ty(&self.prog, &ty, self.type_params.clone()))?;
first = false;
}

Expand All @@ -343,7 +399,7 @@ impl<'a> BytecodeTraitTypePrinter<'a> {
fmt,
"{}={}",
alias.name,
fmt_ty_with_type_params(&self.prog, &ty, self.type_params)
fmt_ty(&self.prog, &ty, self.type_params.clone())
)?;
first = false;
}
Expand Down
3 changes: 2 additions & 1 deletion dora-bytecode/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ mod tests;
pub use data::*;
pub use display::{
display_fct, display_ty, display_ty_array, display_ty_with_type_params,
display_ty_without_type_params, module_path, module_path_name,
display_ty_without_type_params, fmt_trait_ty, fmt_ty, fmt_ty_array, module_path,
module_path_name, TypeParamMode,
};
pub use dumper::{dump, dump_stdout};
pub use program::{
Expand Down
1 change: 1 addition & 0 deletions dora-bytecode/src/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ pub struct StructField {
#[derive(Debug, Decode, Encode)]
pub struct TypeParamData {
pub names: Vec<String>,
pub container_count: usize,
pub bounds: Vec<TypeParamBound>,
}

Expand Down
15 changes: 15 additions & 0 deletions dora-bytecode/src/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,21 @@ impl BytecodeTypeArray {
BytecodeTypeArray::new(types)
}

pub fn connect(&self, rhs: &BytecodeTypeArray) -> BytecodeTypeArray {
let mut types = self.0.to_vec();
types.append(&mut rhs.to_vec());
BytecodeTypeArray::new(types)
}

pub fn split(&self, elements: usize) -> (BytecodeTypeArray, BytecodeTypeArray) {
let (first, second) = self.0.split_at(elements);

(
BytecodeTypeArray::new(first.to_vec()),
BytecodeTypeArray::new(second.to_vec()),
)
}

pub fn is_concrete_type(&self) -> bool {
self.0.iter().all(|ty| ty.is_concrete_type())
}
Expand Down
8 changes: 7 additions & 1 deletion dora-frontend/src/program_emitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,13 @@ fn create_type_params(sa: &Sema, type_params: &TypeParamDefinition) -> TypeParam
})
.collect();

TypeParamData { names, bounds }
let container_count = type_params.container_type_params();

TypeParamData {
names,
container_count,
bounds,
}
}

fn create_struct_fields(sa: &Sema, struct_: &StructDefinition) -> Vec<StructField> {
Expand Down
57 changes: 35 additions & 22 deletions dora-frontend/src/typeck/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,32 +129,45 @@ fn check_expr_call_generic_static_method(
let trait_method = ck.sa.fct(trait_method_id);

let tp = SourceType::TypeParam(tp_id);
let fct_type_params = trait_ty.type_params.connect(&pure_fct_type_params);
let combined_fct_type_params = trait_ty.type_params.connect(&pure_fct_type_params);

check_args_compatible_fct2(ck, trait_method, arguments, |ty| {
specialize_ty_for_generic(ck.sa, ty, tp_id, &trait_ty, &fct_type_params, &tp)
});

let call_type = CallType::GenericStaticMethod(
tp_id,
trait_ty.trait_id,
trait_method_id,
trait_ty.type_params.clone(),
);
ck.analysis.map_calls.insert(e.id, Arc::new(call_type));

let return_type = specialize_ty_for_generic(
if check_type_params(
ck.sa,
trait_method.return_type(),
tp_id,
&trait_ty,
&fct_type_params,
&tp,
);
ck.element,
&ck.type_param_definition,
trait_method,
&combined_fct_type_params,
ck.file_id,
e.span,
|ty| specialize_ty_for_generic(ck.sa, ty, tp_id, &trait_ty, &combined_fct_type_params, &tp),
) {
check_args_compatible_fct2(ck, trait_method, arguments, |ty| {
specialize_ty_for_generic(ck.sa, ty, tp_id, &trait_ty, &combined_fct_type_params, &tp)
});

let call_type = CallType::GenericStaticMethod(
tp_id,
trait_ty.trait_id,
trait_method_id,
combined_fct_type_params.clone(),
);
ck.analysis.map_calls.insert(e.id, Arc::new(call_type));

ck.analysis.set_ty(e.id, return_type.clone());
let return_type = specialize_ty_for_generic(
ck.sa,
trait_method.return_type(),
tp_id,
&trait_ty,
&combined_fct_type_params,
&tp,
);

return_type
ck.analysis.set_ty(e.id, return_type.clone());

return_type
} else {
SourceType::Error
}
}

fn check_expr_call_expr(
Expand Down
63 changes: 63 additions & 0 deletions dora-frontend/src/typeck/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,69 @@ fn trait_method_call_with_function_type_params_invalid_param() {
);
}

#[test]
fn static_trait_method_call_with_function_type_params() {
ok("
trait Foo {
static fn bar[T](x: T);
}
fn f[T: Foo](x: T) {
T::bar[Int](1);
}
");
}

#[test]
fn static_trait_method_call_with_function_type_params_and_missing_params() {
err(
"
trait Foo {
static fn bar[T](x: T);
}
fn f[T: Foo](x: T) {
T::bar(1);
}
",
(7, 13),
ErrorMessage::WrongNumberTypeParams(1, 0),
);

err(
"
trait Foo[X] {
static fn bar[T](x: T);
}
fn f[T: Foo[Int]](x: T) {
T::bar(1);
}
",
(7, 13),
ErrorMessage::WrongNumberTypeParams(1, 0),
);
}

#[test]
fn static_trait_method_call_with_function_type_params_invalid_param() {
err(
"
trait Foo {
static fn bar[T: Bar](x: T);
}
trait Bar {}
fn f[T: Foo](x: T) {
T::bar[Int](1);
}
",
(9, 13),
ErrorMessage::TypeNotImplementingTrait("Int64".into(), "Bar".into()),
);
}

#[test]
fn test_generic_ctor_without_type_params() {
err(
Expand Down
12 changes: 10 additions & 2 deletions dora-runtime/src/cannon/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2657,17 +2657,25 @@ impl<'a> CannonCodeGen<'a> {
FunctionKind::Trait(trait_id) => trait_id,
_ => unreachable!(),
};

let type_params = self.specialize_ty_array(&type_params);
assert!(type_params.is_concrete_type());
let (trait_type_params, pure_fct_type_params) =
type_params.split(fct.type_params.container_count);

let trait_ty = BytecodeTraitType {
trait_id,
type_params: type_params.clone(),
type_params: trait_type_params,
bindings: Vec::new(),
};

let (callee_id, type_params) = find_trait_impl(self.vm, trait_fct_id, trait_ty, ty);
let (callee_id, container_bindings) = find_trait_impl(self.vm, trait_fct_id, trait_ty, ty);

let pos = self.bytecode.offset_location(self.current_offset.to_u32());
let arguments = self.argument_stack.drain(..).collect::<Vec<_>>();

let type_params = container_bindings.connect(&pure_fct_type_params);

if is_static {
self.emit_invoke_static_or_intrinsic(dest, callee_id, type_params, arguments, pos);
} else {
Expand Down
1 change: 1 addition & 0 deletions dora-runtime/src/vm/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub fn find_trait_impl(

let type_param_data = TypeParamData {
names: Vec::new(),
container_count: 0,
bounds: Vec::new(),
};

Expand Down
1 change: 1 addition & 0 deletions dora-runtime/src/vm/specialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ pub fn specialize_ty(

let type_param_data = TypeParamData {
names: Vec::new(),
container_count: 0,
bounds: Vec::new(),
};

Expand Down
18 changes: 18 additions & 0 deletions tests/generic/generic-fct-type-args1.dora
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
trait Foo {
fn bar[T](x: T);
}

class Bar

impl Foo for Bar {
fn bar[T](x: T) {}
}

fn myf[T: Foo](x: T) {
x.bar[Int](1);
}

fn main() {
let bar = Bar();
myf[Bar](bar);
}
17 changes: 17 additions & 0 deletions tests/generic/generic-static-fct-type-args1.dora
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
trait Foo {
static fn bar[T](x: T);
}

class Bar

impl Foo for Bar {
static fn bar[T](x: T) {}
}

fn myf[T: Foo]() {
T::bar[Int](1);
}

fn main() {
myf[Bar]();
}

0 comments on commit 709457d

Please sign in to comment.