Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return Result in accessors instead of panic #286

Merged
merged 5 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions macro/src/dialect/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ impl<'a> OperationField<'a> {
let (param_type, return_type) = {
if ac.is_unit() {
(quote! { bool }, quote! { bool })
} else if ac.is_optional() {
(quote! { #kind_type<'c> }, quote! { Option<#kind_type<'c>> })
} else {
(quote! { #kind_type<'c> }, quote! { #kind_type<'c> })
(
quote! { #kind_type<'c> },
quote! { Result<#kind_type<'c>, ::melior::Error> },
)
}
};
let sanitized = sanitize_name_snake(name);
Expand Down Expand Up @@ -108,7 +109,7 @@ impl<'a> OperationField<'a> {
} else {
(
quote! { ::melior::ir::Region<'c> },
quote! { ::melior::ir::RegionRef<'c, '_> },
quote! { Result<::melior::ir::RegionRef<'c, '_>, ::melior::Error> },
)
}
};
Expand Down Expand Up @@ -142,7 +143,7 @@ impl<'a> OperationField<'a> {
} else {
(
quote! { &::melior::ir::Block<'c> },
quote! { ::melior::ir::BlockRef<'c, '_> },
quote! { Result<::melior::ir::BlockRef<'c, '_>, ::melior::Error> },
)
}
};
Expand Down Expand Up @@ -201,16 +202,23 @@ impl<'a> OperationField<'a> {
if tc.is_optional() {
(
quote! { #param_kind_type },
quote! { Option<#return_kind_type> },
quote! { Result<#return_kind_type, ::melior::Error> },
)
} else {
(
quote! { &[#param_kind_type] },
quote! { impl Iterator<Item = #return_kind_type> },
if let VariadicKind::AttrSized {} = variadic_info {
quote! { Result<impl Iterator<Item = #return_kind_type>, ::melior::Error> }
} else {
quote! { impl Iterator<Item = #return_kind_type> }
},
)
}
} else {
(param_kind_type, return_kind_type)
(
param_kind_type,
quote!(Result<#return_kind_type, ::melior::Error>),
)
}
};

Expand Down
65 changes: 27 additions & 38 deletions macro/src/dialect/operation/accessors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ impl<'a> OperationField<'a> {
.variadic_info
.as_ref()
.expect("operands and results need variadic info");
let error_variant = match &self.kind {
FieldKind::Operand(_) => quote!(OperandNotFound),
FieldKind::Result(_) => quote!(ResultNotFound),
_ => unreachable!(),
};
let name = self.name;

Some(match variadic_kind {
VariadicKind::Simple {
Expand All @@ -32,9 +38,9 @@ impl<'a> OperationField<'a> {
// elements.
quote! {
if self.operation.#count() < #len {
None
Err(::melior::Error::#error_variant(#name))
} else {
self.operation.#kind_ident(#index).ok()
self.operation.#kind_ident(#index)
}
}
} else {
Expand All @@ -49,16 +55,14 @@ impl<'a> OperationField<'a> {
} else if *seen_variable_length {
// Single element after variable length group
// Compute the length of that variable group and take the next element
let error = format!("operation should have this {}", kind);
quote! {
let group_length = self.operation.#count() - #len + 1;
self.operation.#kind_ident(#index + group_length - 1).expect(#error)
self.operation.#kind_ident(#index + group_length - 1)
}
} else {
// All elements so far are singular
let error = format!("operation should have this {}", kind);
quote! {
self.operation.#kind_ident(#index).expect(#error)
self.operation.#kind_ident(#index)
}
}
}
Expand All @@ -67,7 +71,6 @@ impl<'a> OperationField<'a> {
num_preceding_simple,
num_preceding_variadic,
} => {
let error = format!("operation should have this {}", kind);
let compute_start_length = quote! {
let total_var_len = self.operation.#count() - #num_variable_length + 1;
let group_len = total_var_len / #num_variable_length;
Expand All @@ -79,47 +82,42 @@ impl<'a> OperationField<'a> {
}
} else {
quote! {
self.operation.#kind_ident(start).expect(#error)
self.operation.#kind_ident(start)
}
};

quote! { #compute_start_length #get_elements }
}
VariadicKind::AttrSized {} => {
let error = format!("operation should have this {}", kind);
let attribute_name = format!("{}_segment_sizes", kind);
let attribute_missing_error =
format!("operation has {} attribute", attribute_name);
let compute_start_length = quote! {
let attribute =
::melior::ir::attribute::DenseI32ArrayAttribute::<'c>::try_from(
self.operation
.attribute(#attribute_name)
.expect(#attribute_missing_error)
).expect("is a DenseI32ArrayAttribute");
.attribute(#attribute_name)?
)?;
let start = (0..#index)
.map(|index| attribute.element(index)
.expect("has segment size"))
.map(|index| attribute.element(index))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.sum::<i32>() as usize;
let group_len = attribute
.element(#index)
.expect("has segment size") as usize;
let group_len = attribute.element(#index)? as usize;
};
let get_elements = if !constraint.is_variable_length() {
quote! {
self.operation.#kind_ident(start).expect(#error)
self.operation.#kind_ident(start)
}
} else if constraint.is_optional() {
quote! {
if group_len == 0 {
None
Err(::melior::Error::#error_variant(#name))
} else {
self.operation.#kind_ident(start).ok()
self.operation.#kind_ident(start)
}
}
} else {
quote! {
self.operation.#plural().skip(start).take(group_len)
Ok(self.operation.#plural().skip(start).take(group_len))
}
};

Expand All @@ -140,7 +138,7 @@ impl<'a> OperationField<'a> {
}
} else {
quote! {
self.operation.successor(#index).expect("operation should have this successor")
self.operation.successor(#index)
}
})
}
Expand All @@ -155,30 +153,21 @@ impl<'a> OperationField<'a> {
}
} else {
quote! {
self.operation.region(#index).expect("operation should have this region")
self.operation.region(#index)
}
})
}
FieldKind::Attribute(constraint) => {
let name = &self.name;
let attribute_error = format!("operation should have attribute {}", name);
let type_error = format!("{} should be a {}", name, constraint.storage_type());

Some(if constraint.is_unit() {
quote! { self.operation.attribute(#name).is_some() }
} else if constraint.is_optional() {
quote! {
self.operation
.attribute(#name)
.map(|attribute| attribute.try_into().expect(#type_error))
}
} else {
quote! {
self.operation
.attribute(#name)
.expect(#attribute_error)
.attribute(#name)?
.try_into()
.expect(#type_error)
.map_err(::melior::Error::from)
}
})
}
Expand All @@ -192,7 +181,7 @@ impl<'a> OperationField<'a> {

if constraint.is_unit() || constraint.is_optional() {
Some(quote! {
let _ = self.operation.remove_attribute(#name);
self.operation.remove_attribute(#name)
})
} else {
None
Expand Down Expand Up @@ -239,7 +228,7 @@ impl<'a> OperationField<'a> {
let ident = sanitize_name_snake(&format!("remove_{}", self.name));
self.remover_impl().map_or(quote!(), |body| {
quote! {
pub fn #ident(&mut self) {
pub fn #ident(&mut self) -> Result<(), ::melior::Error> {
#body
}
}
Expand Down
6 changes: 3 additions & 3 deletions macro/tests/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ fn simple() {
location,
);

assert_eq!(op.lhs(), block.argument(0).unwrap().into());
assert_eq!(op.rhs(), block.argument(1).unwrap().into());
assert_eq!(op.lhs().unwrap(), block.argument(0).unwrap().into());
assert_eq!(op.rhs().unwrap(), block.argument(1).unwrap().into());
assert_eq!(op.operation().operand_count(), 2);
}

Expand All @@ -48,7 +48,7 @@ fn variadic_after_single() {
location,
);

assert_eq!(op.first(), block.argument(0).unwrap().into());
assert_eq!(op.first().unwrap(), block.argument(0).unwrap().into());
assert_eq!(op.others().next(), Some(block.argument(2).unwrap().into()));
assert_eq!(op.others().nth(1), Some(block.argument(1).unwrap().into()));
assert_eq!(op.operation().operand_count(), 3);
Expand Down
4 changes: 2 additions & 2 deletions macro/tests/region.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn single() {
region_test::single(r1, location)
};

assert!(op.default_region().first_block().is_some());
assert!(op.default_region().unwrap().first_block().is_some());
}

#[test]
Expand Down Expand Up @@ -51,7 +51,7 @@ fn variadic_after_single() {

assert_eq!(op.operation().to_string(), op2.operation().to_string());

assert!(op.default_region().first_block().is_none());
assert!(op.default_region().unwrap().first_block().is_none());
assert_eq!(op.other_regions().count(), 2);
assert!(op.other_regions().next().unwrap().first_block().is_some());
assert!(op.other_regions().nth(1).unwrap().first_block().is_none());
Expand Down
15 changes: 15 additions & 0 deletions melior/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
convert::Infallible,
error,
fmt::{self, Display, Formatter},
str::Utf8Error,
Expand All @@ -15,13 +16,15 @@ pub enum Error {
value: String,
},
InvokeFunction,
OperandNotFound(&'static str),
OperationResultExpected(String),
PositionOutOfBounds {
name: &'static str,
value: String,
index: usize,
},
ParsePassPipeline(String),
ResultNotFound(&'static str),
RunPass,
TypeExpected(&'static str, String),
UnknownDiagnosticSeverity(u32),
Expand All @@ -44,6 +47,9 @@ impl Display for Error {
write!(formatter, "element of {type} type expected: {value}")
}
Self::InvokeFunction => write!(formatter, "failed to invoke JIT-compiled function"),
Self::OperandNotFound(name) => {
write!(formatter, "operand {name} not found")
}
Self::OperationResultExpected(value) => {
write!(formatter, "operation result expected: {value}")
}
Expand All @@ -53,6 +59,9 @@ impl Display for Error {
Self::PositionOutOfBounds { name, value, index } => {
write!(formatter, "{name} position {index} out of bounds: {value}")
}
Self::ResultNotFound(name) => {
write!(formatter, "result {name} not found")
}
Self::RunPass => write!(formatter, "failed to run pass"),
Self::TypeExpected(r#type, actual) => {
write!(formatter, "{type} type expected: {actual}")
Expand All @@ -74,3 +83,9 @@ impl From<Utf8Error> for Error {
Self::Utf8(error)
}
}

impl From<Infallible> for Error {
fn from(_: Infallible) -> Self {
unreachable!()
}
}
9 changes: 5 additions & 4 deletions melior/src/ir/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,18 +184,19 @@ impl<'c> Operation<'c> {
}

/// Gets a attribute with the given name.
pub fn attribute(&self, name: &str) -> Option<Attribute<'c>> {
pub fn attribute(&self, name: &str) -> Result<Attribute<'c>, Error> {
unsafe {
Attribute::from_option_raw(mlirOperationGetAttributeByName(
self.raw,
StringRef::from(name).to_raw(),
))
}
.ok_or(Error::AttributeNotFound(name.into()))
}

/// Checks if the operation has a attribute with the given name.
pub fn has_attribute(&self, name: &str) -> bool {
self.attribute(name).is_some()
self.attribute(name).is_ok()
}

/// Sets the attribute with the given name to the given attribute.
Expand Down Expand Up @@ -547,14 +548,14 @@ mod tests {
assert!(operation.has_attribute("foo"));
assert_eq!(
operation.attribute("foo").map(|a| a.to_string()),
Some("\"bar\"".into())
Ok("\"bar\"".into())
);
assert!(operation.remove_attribute("foo").is_ok());
assert!(operation.remove_attribute("foo").is_err());
operation.set_attribute("foo", &StringAttribute::new(&context, "foo").into());
assert_eq!(
operation.attribute("foo").map(|a| a.to_string()),
Some("\"foo\"".into())
Ok("\"foo\"".into())
);
assert_eq!(
operation.attributes().next(),
Expand Down