Skip to content

Commit

Permalink
Support early return from loop.
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyalesokhin-starkware committed Feb 13, 2025
1 parent e03c8e1 commit 607a9ea
Show file tree
Hide file tree
Showing 7 changed files with 439 additions and 30 deletions.
5 changes: 5 additions & 0 deletions corelib/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,11 @@ pub const fn assert(cond: bool, err_code: felt252) {
}
}

pub enum LoopResult<N, E> {
Normal: N,
EarlyReturn: E,
}

pub mod hash;

pub mod keccak;
Expand Down
18 changes: 16 additions & 2 deletions crates/cairo-lang-lowering/src/lower/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use cairo_lang_semantic::expr::fmt::ExprFormatter;
use cairo_lang_semantic::items::enm::SemanticEnumEx;
use cairo_lang_semantic::items::imp::ImplLookupContext;
use cairo_lang_semantic::usage::Usages;
use cairo_lang_semantic::{ConcreteVariant, TypeId};
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_utils::Intern;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
Expand Down Expand Up @@ -123,6 +124,19 @@ impl<'db> EncapsulatingLoweringContext<'db> {
}
}

#[derive(Clone)]
pub struct LoopEarlyReturnInfo {
pub ret_ty: TypeId,
pub normal_return_variant: ConcreteVariant,
pub early_return_variant: ConcreteVariant,
}

pub struct LoopContext {
/// loop expression needed for recursive calls in `continue`
pub loop_expr_id: semantic::ExprId,
pub early_return_info: Option<LoopEarlyReturnInfo>,
}

pub struct LoweringContext<'a, 'db> {
pub encapsulating_ctx: Option<&'a mut EncapsulatingLoweringContext<'db>>,
/// Variable allocator.
Expand All @@ -135,7 +149,7 @@ pub struct LoweringContext<'a, 'db> {
/// This it the generic function specialized with its own generic parameters.
pub concrete_function_id: ConcreteFunctionWithBodyId,
/// Current loop expression needed for recursive calls in `continue`
pub current_loop_expr_id: Option<semantic::ExprId>,
pub current_loop_ctx: Option<LoopContext>,
/// Current emitted diagnostics.
pub diagnostics: LoweringDiagnostics,
/// Lowered blocks of the function.
Expand All @@ -159,7 +173,7 @@ impl<'a, 'db> LoweringContext<'a, 'db> {
signature,
function_id,
concrete_function_id,
current_loop_expr_id: None,
current_loop_ctx: None,
diagnostics: LoweringDiagnostics::default(),
blocks: Default::default(),
})
Expand Down
151 changes: 134 additions & 17 deletions crates/cairo-lang-lowering/src/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use block_builder::BlockBuilder;
use cairo_lang_debug::DebugWithDb;
use cairo_lang_defs::diagnostic_utils::StableLocation;
use cairo_lang_diagnostics::{Diagnostics, Maybe};
use cairo_lang_semantic::corelib::{ErrorPropagationType, unwrap_error_propagation_type};
use cairo_lang_semantic::corelib::{
ErrorPropagationType, get_core_enum_concrete_variant, unwrap_error_propagation_type,
};
use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
use cairo_lang_semantic::items::imp::ImplLongId;
Expand Down Expand Up @@ -51,7 +53,7 @@ use crate::ids::{
FunctionLongId, FunctionWithBodyId, FunctionWithBodyLongId, GeneratedFunction,
GeneratedFunctionKey, LocationId, SemanticFunctionIdEx, Signature, parameter_as_member_path,
};
use crate::lower::context::{LoweringResult, VarRequest};
use crate::lower::context::{LoopContext, LoopEarlyReturnInfo, LoweringResult, VarRequest};
use crate::lower::generators::StructDestructure;
use crate::lower::lower_match::{
MatchArmWrapper, TupleInfo, lower_concrete_enum_match, lower_expr_match_tuple,
Expand Down Expand Up @@ -422,11 +424,12 @@ pub fn lower_loop_function(
encapsulating_ctx: &mut EncapsulatingLoweringContext<'_>,
function_id: FunctionWithBodyId,
loop_signature: Signature,
loop_expr_id: semantic::ExprId,
loop_ctx: LoopContext,
snapped_params: &OrderedHashMap<MemberPath, semantic::ExprVarMemberPath>,
) -> Maybe<FlatLowered> {
let loop_expr_id = loop_ctx.loop_expr_id;
let mut ctx = LoweringContext::new(encapsulating_ctx, function_id, loop_signature.clone())?;
let old_loop_expr_id = std::mem::replace(&mut ctx.current_loop_expr_id, Some(loop_expr_id));
let old_loop_ctx = std::mem::replace(&mut ctx.current_loop_ctx, Some(loop_ctx));

// Initialize builder.
let root_block_id = alloc_empty_block(&mut ctx);
Expand Down Expand Up @@ -492,7 +495,7 @@ pub fn lower_loop_function(

Ok(root_block_id)
})();
ctx.current_loop_expr_id = old_loop_expr_id;
ctx.current_loop_ctx = old_loop_ctx;

let blocks = root_ok
.map(|_| ctx.blocks.build().expect("Root block must exist."))
Expand Down Expand Up @@ -538,14 +541,28 @@ fn wrap_sealed_block_as_function(
}
_ => {
// Convert to a return.
let var_usage = expr.unwrap_or_else(|| {
let mut var_usage = expr.unwrap_or_else(|| {
generators::StructConstruct {
inputs: vec![],
ty: unit_ty(ctx.db.upcast()),
location,
}
.add(ctx, &mut builder.statements)
});

if let Some(LoopContext {
early_return_info: Some(LoopEarlyReturnInfo { normal_return_variant, .. }),
..
}) = &ctx.current_loop_ctx
{
var_usage = generators::EnumConstruct {
input: var_usage,
variant: normal_return_variant.clone(),
location,
}
.add(ctx, &mut builder.statements);
}

builder.ret(ctx, var_usage, location)
}
}
Expand Down Expand Up @@ -625,7 +642,7 @@ pub fn lowered_expr_to_block_scope_end(
})
}

/// Converts [`LoweringResult<LoweredExpr>`] into `BlockScopeEnd`.
/// Generates the lowering for an early return.
pub fn lower_early_return(
ctx: &mut LoweringContext<'_, '_>,
mut builder: BlockBuilder,
Expand Down Expand Up @@ -667,7 +684,7 @@ pub fn lower_statement(
ctx,
ctx.signature.clone(),
builder,
ctx.current_loop_expr_id.unwrap(),
ctx.current_loop_ctx.as_ref().unwrap().loop_expr_id,
stable_ptr.untyped(),
)?;
let ret_var = lowered_expr.as_var_usage(ctx, builder)?;
Expand All @@ -676,13 +693,28 @@ pub fn lower_statement(
semantic::Statement::Return(semantic::StatementReturn { expr_option, stable_ptr })
| semantic::Statement::Break(semantic::StatementBreak { expr_option, stable_ptr }) => {
log::trace!("Lowering a return | break statement.");
let ret_var = match expr_option {
let location = ctx.get_location(stable_ptr.untyped());
let mut ret_var = match expr_option {
None => {
let location = ctx.get_location(stable_ptr.untyped());
LoweredExpr::Tuple { exprs: vec![], location }.as_var_usage(ctx, builder)?
}
Some(expr) => lower_expr_to_var_usage(ctx, builder, *expr)?,
};

if matches!(stmt, semantic::Statement::Return(_)) {
if let Some(LoopContext {
early_return_info: Some(LoopEarlyReturnInfo { early_return_variant, .. }),
..
}) = &ctx.current_loop_ctx
{
ret_var = generators::EnumConstruct {
input: ret_var,
variant: early_return_variant.clone(),
location,
}
.add(ctx, &mut builder.statements);
}
}
return Err(LoweringFlowError::Return(ret_var, ctx.get_location(stable_ptr.untyped())));
}
semantic::Statement::Item(_) => {}
Expand Down Expand Up @@ -1412,8 +1444,34 @@ fn lower_expr_loop(
_ => unreachable!("Loop expression must be either loop, while or for."),
};

let semantic_db = ctx.db.upcast();

let usage = &ctx.usages.usages[&loop_expr_id];

let early_return_info = if usage.has_early_return {
let generic_args = vec![
GenericArgumentId::Type(return_type),
GenericArgumentId::Type(ctx.signature.return_type),
];
Some(LoopEarlyReturnInfo {
ret_ty: get_core_ty_by_name(semantic_db, "LoopResult".into(), generic_args.clone()),
normal_return_variant: get_core_enum_concrete_variant(
semantic_db,
"LoopResult",
generic_args.clone(),
"Normal",
),
early_return_variant: get_core_enum_concrete_variant(
semantic_db,
"LoopResult",
generic_args,
"EarlyReturn",
),
})
} else {
None
};

// Determine signature.
let params = usage
.usage
Expand All @@ -1437,13 +1495,14 @@ fn lower_expr_loop(
.collect_vec();
let extra_rets = usage.changes.iter().map(|(_, expr)| expr.clone()).collect_vec();

let loop_location = ctx.get_location(stable_ptr.untyped());
let loop_signature = Signature {
params,
extra_rets,
return_type,
return_type: if let Some(info) = &early_return_info { info.ret_ty } else { return_type },
implicits: vec![],
panicable: ctx.signature.panicable,
location: ctx.get_location(stable_ptr.untyped()),
location: loop_location,
};

// Get the function id.
Expand All @@ -1457,18 +1516,19 @@ fn lower_expr_loop(

// Generate the function.
let encapsulating_ctx = std::mem::take(&mut ctx.encapsulating_ctx).unwrap();

let loop_ctx = LoopContext { loop_expr_id, early_return_info: early_return_info.clone() };
let lowered = lower_loop_function(
encapsulating_ctx,
function,
loop_signature.clone(),
loop_expr_id,
loop_ctx,
&snap_usage,
)
.map_err(LoweringFlowError::Failed)?;
// TODO(spapini): Recursive call.
encapsulating_ctx.lowerings.insert(GeneratedFunctionKey::Loop(stable_ptr), lowered);
ctx.encapsulating_ctx = Some(encapsulating_ctx);
let old_loop_expr_id = std::mem::replace(&mut ctx.current_loop_expr_id, Some(loop_expr_id));
for snapshot_param in snap_usage.values() {
// if we have access to the real member we generate a snapshot, otherwise it should be
// accessible with `builder.get_snap_ref`
Expand All @@ -1482,10 +1542,67 @@ fn lower_expr_loop(
builder.update_ref(ctx, snapshot_param, original);
}
}
let call = call_loop_func(ctx, loop_signature, builder, loop_expr_id, stable_ptr.untyped());
let call_loop_expr =
call_loop_func(ctx, loop_signature, builder, loop_expr_id, stable_ptr.untyped());

let Some(LoopEarlyReturnInfo { ret_ty: _, normal_return_variant, early_return_variant }) =
early_return_info
else {
return call_loop_expr;
};

let loop_res = call_loop_expr?.as_var_usage(ctx, builder)?;

let normal_return_subscope = create_subscope(ctx, builder);
let normal_return_subscope_block_id = normal_return_subscope.block_id;
let normal_return_var_id = ctx.new_var(VarRequest { ty: return_type, location: loop_location });

let sealed_normal_return = lowered_expr_to_block_scope_end(
ctx,
normal_return_subscope,
Ok(LoweredExpr::AtVariable(VarUsage {
var_id: normal_return_var_id,
location: loop_location,
})),
)
.map_err(LoweringFlowError::Failed)?;

let early_return_subscope = create_subscope(ctx, builder);
let early_return_var_id =
ctx.new_var(VarRequest { ty: ctx.signature.return_type, location: loop_location });
let early_return_subscope_block_id = early_return_subscope.block_id;
let sealed_early_return = lower_early_return(
ctx,
early_return_subscope,
LoweredExpr::AtVariable(VarUsage { var_id: early_return_var_id, location: loop_location }),
loop_location,
)
.map_err(LoweringFlowError::Failed)?;

let match_info = MatchInfo::Enum(MatchEnumInfo {
concrete_enum_id: normal_return_variant.concrete_enum_id,
input: loop_res,
arms: vec![
MatchArm {
arm_selector: MatchArmSelector::VariantId(normal_return_variant),
block_id: normal_return_subscope_block_id,
var_ids: vec![normal_return_var_id],
},
MatchArm {
arm_selector: MatchArmSelector::VariantId(early_return_variant),
block_id: early_return_subscope_block_id,
var_ids: vec![early_return_var_id],
},
],
location: loop_location,
});

ctx.current_loop_expr_id = old_loop_expr_id;
call
builder.merge_and_end_with_match(
ctx,
match_info,
vec![sealed_normal_return, sealed_early_return],
loop_location,
)
}

/// Adds a call to an inner loop-generated function.
Expand Down
Loading

0 comments on commit 607a9ea

Please sign in to comment.