Skip to content

Commit

Permalink
feat: Add English language parsing of failure messages
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Chernicoff <[email protected]>
  • Loading branch information
mchernicoff committed Jan 22, 2025
1 parent b35bf4c commit 89a4432
Show file tree
Hide file tree
Showing 16 changed files with 293 additions and 45 deletions.
52 changes: 27 additions & 25 deletions hipcheck/src/policy_exprs/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,45 +363,45 @@ impl Env<'_> {
let mut env = Env::empty();

// Comparison functions.
env.add_fn("gt", gt, 2, ty_comp);
env.add_fn("lt", lt, 2, ty_comp);
env.add_fn("gte", gte, 2, ty_comp);
env.add_fn("lte", lte, 2, ty_comp);
env.add_fn("eq", eq, 2, ty_comp);
env.add_fn("neq", neq, 2, ty_comp);
env.add_fn("gt", "greater than", gt, 2, ty_comp);
env.add_fn("lt", "less than", lt, 2, ty_comp);
env.add_fn("gte", "greater than or equal to", gte, 2, ty_comp);
env.add_fn("lte", "less than or equal to", lte, 2, ty_comp);
env.add_fn("eq", "equal to", eq, 2, ty_comp);
env.add_fn("neq", "not equal to", neq, 2, ty_comp);

// Math functions.
env.add_fn("add", add, 2, ty_arithmetic_binary_ops);
env.add_fn("sub", sub, 2, ty_arithmetic_binary_ops);
env.add_fn("divz", divz, 2, ty_divz);
env.add_fn("add", "plus", add, 2, ty_arithmetic_binary_ops);
env.add_fn("sub", "minus", sub, 2, ty_arithmetic_binary_ops);
env.add_fn("divz", "divided by", divz, 2, ty_divz);

// Additional datetime math functions
env.add_fn("duration", duration, 2, ty_duration);
env.add_fn("duration", "minus", duration, 2, ty_duration);

// Logical functions.
env.add_fn("and", and, 2, ty_bool_binary);
env.add_fn("or", or, 2, ty_bool_binary);
env.add_fn("not", not, 1, ty_bool_unary);
env.add_fn("and", "and", and, 2, ty_bool_binary);
env.add_fn("or", "or", or, 2, ty_bool_binary);
env.add_fn("not", "not", not, 1, ty_bool_unary);

// Array math functions.
env.add_fn("max", max, 1, ty_from_first_arr);
env.add_fn("min", min, 1, ty_from_first_arr);
env.add_fn("avg", avg, 1, ty_avg);
env.add_fn("median", median, 1, ty_from_first_arr);
env.add_fn("count", count, 1, ty_count);
env.add_fn("max", "the maximum of", max, 1, ty_from_first_arr);
env.add_fn("min", "the minimum of", min, 1, ty_from_first_arr);
env.add_fn("avg", "the mean of", avg, 1, ty_avg);
env.add_fn("median", "the median of", median, 1, ty_from_first_arr);
env.add_fn("count", "the number of elements in", count, 1, ty_count);

// Array logic functions.
env.add_fn("all", all, 1, ty_higher_order_bool_fn);
env.add_fn("nall", nall, 1, ty_higher_order_bool_fn);
env.add_fn("some", some, 1, ty_higher_order_bool_fn);
env.add_fn("none", none, 1, ty_higher_order_bool_fn);
env.add_fn("all", "all of", all, 1, ty_higher_order_bool_fn);
env.add_fn("nall", "not all of", nall, 1, ty_higher_order_bool_fn);
env.add_fn("some", "some of", some, 1, ty_higher_order_bool_fn);
env.add_fn("none", "none of", none, 1, ty_higher_order_bool_fn);

// Array higher-order functions.
env.add_fn("filter", filter, 2, ty_filter);
env.add_fn("foreach", foreach, 2, ty_foreach);
env.add_fn("filter", "filtered on", filter, 2, ty_filter);
env.add_fn("foreach", "each to be", foreach, 2, ty_foreach);

// Debugging functions.
env.add_fn("dbg", dbg, 1, ty_inherit_first);
env.add_fn("dbg", "debugging", dbg, 1, ty_inherit_first);

env
}
Expand All @@ -423,6 +423,7 @@ impl Env<'_> {
pub fn add_fn(
&mut self,
name: &str,
english: &str,
op: Op,
expected_args: usize,
ty_checker: TypeChecker,
Expand All @@ -431,6 +432,7 @@ impl Env<'_> {
name.to_owned(),
Binding::Fn(FunctionDef {
name: name.to_owned(),
english: english.to_owned(),
expected_args,
ty_checker,
op,
Expand Down
3 changes: 3 additions & 0 deletions hipcheck/src/policy_exprs/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ pub enum Error {
#[error("expression returned '{0:?}', not a boolean")]
DidNotReturnBool(Expr),

#[error("evaluation of inner expression returned '{0:?}', not a primitive")]
BadReturnType(Expr),

#[error("tried to call unknown function '{0}'")]
UnknownFunction(String),

Expand Down
1 change: 1 addition & 0 deletions hipcheck/src/policy_exprs/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ pub type TypeChecker = fn(&[Type]) -> Result<ReturnableType>;
#[derive(Clone, PartialEq, Debug, Eq)]
pub struct FunctionDef {
pub name: String,
pub english: String,
pub expected_args: usize,
pub ty_checker: TypeChecker,
pub op: Op,
Expand Down
209 changes: 207 additions & 2 deletions hipcheck/src/policy_exprs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,17 @@ impl FromStr for Expr {
}

/// Evaluates `deke` expressions.
#[cfg(test)]
pub struct Executor {
env: Env<'static>,
}

#[cfg(test)]
impl Executor {
/// Create an `Executor` with the standard set of functions defined.
pub fn std() -> Self {
Executor { env: Env::std() }
}

#[cfg(test)]
/// Run a `deke` program.
pub fn run(&self, raw_program: &str, context: &Value) -> Result<bool> {
match self.parse_and_eval(raw_program, context)? {
Expand Down Expand Up @@ -156,6 +155,212 @@ impl ExprMutator for Env<'_> {
}
}

/// Return an English language explanation for what a failing plugin was expected to see and what it saw instead
pub fn parse_failing_expr_to_english(
input: &Expr,
message: &str,
value: &Option<Value>,
) -> Result<String> {
// Create a standard environment, with its list of functions and their English descriptions
let env = Env::std();
// Store that environment and the plugin explanation message in a struct for English parsing
let english = English {
env,
message: message.to_string(),
};

// Check that the "top level" of the policy expression is a function, then recursively parse that function into an English language description of why the plugin failed
if let Expr::Function(func) = input {
// Recursively parse the top level function to English
let english_expr = english.visit_function(func)?;

// Get the function's args
let args = &func.args;

// Confirm that the outermost function has two arguments
if args.len() != 2 {
return Err(Error::MissingArgs);
}

// Get whichever of the function's arguments is **not** a primitive (i.e. the top level expected value) for evaluation
let inner = match (&args[0], &args[1]) {
(&Expr::Primitive(_), inner) => inner,
(inner, &Expr::Primitive(_)) => inner,
_ => return Err(Error::MissingArgs),
};

// Evaluate that argument using the value returned by the plugin to see what the top level operator is comparing the expected value to
let inner_value = match value {
Some(context) => {
format!(
"it was {}",
match Executor::std().parse_and_eval(&inner.to_string(), context)? {
Expr::Primitive(prim) => english.visit_primitive(&prim)?,
_ => return Err(Error::BadReturnType(inner.clone())),
}
)
}
None => "no value was returned by the query".to_string(),
};

return Ok(format!("Expected {english_expr} but {inner_value}"));
}

Err(Error::MissingIdent)
}

/// Struct that contains a basic environment, with its English function descriptions, and a plugin explanation message.
pub struct English<'a> {
env: Env<'a>,
message: String,
}

// Trait implementation to return English descriptions of an Expr
impl ExprVisitor<Result<String>> for English<'_> {
/// Parse a function expression into an English string
fn visit_function(&self, func: &Function) -> Result<String> {
let env = &self.env;

// Get the function operator from the list of functions in the environment
let ident = &func.ident;
let fn_name = ident.to_string();

let function_def = match env.get(&fn_name) {
Some(binding) => match binding {
Binding::Fn(function_def) => function_def,
_ => {
return Err(Error::UnknownFunction(format!(
"Given function name {} is not a function",
fn_name
)))
}
},
_ => {
return Err(Error::UnknownFunction(format!(
"Given function name {} not found in list of functions",
fn_name
)))
}
};

// Convert theoperator to English, with additional phrasing specific to comparison operators in a function
let operator = match function_def.name.as_ref() {
"gt" | "lt" | "gte" | "lte" | "eq" | "ne" => format!("to be {}", function_def.english),
_ => function_def.english,
};

// Get the number of args the function should have
let expected_args = function_def.expected_args;

// Get the funciton's args
let args = &func.args;

// Check for an invalid number of arguments
if args.len() < expected_args {
return Err(Error::NotEnoughArgs {
name: fn_name,
expected: expected_args,
given: args.len(),
});
}
if args.len() > expected_args {
return Err(Error::TooManyArgs {
name: fn_name,
expected: expected_args,
given: args.len(),
});
}

if args.len() == 2 {
// If there are two arguments, parse a function comparing a pair of some combination of primitives,
// JSON pointers, nested functions (including lambdas in the first position), or arrays (in the second position) to English
if matches!(args[0], Expr::Array(_)) || matches!(args[1], Expr::Lambda(_)) {
return Err(Error::BadType("English::visit_function()"));
}
let argument_1 = self.visit_expr(&args[0])?;
let argument_2 = self.visit_expr(&args[1])?;

Ok(format!("{} {} {}", argument_1, operator, argument_2))
} else {
// If there is one argument, parse a function operating on an array, JSON pointer, or a nested function to English
if matches!(args[0], Expr::Lambda(_)) {
return Err(Error::BadType("English::visit_function()"));
}
let argument = self.visit_expr(&args[0])?;

Ok(format!("{} {}", operator, argument))
}
}

/// Parse a lambda expression into an English string
fn visit_lambda(&self, func: &Lambda) -> Result<String> {
let env = &self.env;

// Get the lambda function from the lambda
let function = &func.body;
//Get the lambda's function operator from the list of functions in the environment
let ident = &function.ident;
let fn_name = ident.to_string();

let function_def = match env.get(&fn_name) {
Some(binding) => match binding {
Binding::Fn(function_def) => function_def,
_ => {
return Err(Error::UnknownFunction(format!(
"Given function name {} is not a function",
fn_name
)))
}
},
_ => {
return Err(Error::UnknownFunction(format!(
"Given function name {} not found in list of functions",
fn_name
)))
}
};

// Convert the operator to English
let operator = function_def.english;

// Get the lambda function's argument and parse it to English
// Note: The useful arugment for a lambda function is the *second* argument
let args = &function.args;
let argument = self.visit_expr(&args[1])?;

Ok(format!("\"{} {}\"", operator, argument))
}

// Parse a primitive type expression to English
fn visit_primitive(&self, prim: &Primitive) -> Result<String> {
match prim {
Primitive::Bool(true) => Ok("true".to_string()),
Primitive::Bool(false) => Ok("false".to_string()),
Primitive::Int(i) => Ok(i.to_string()),
Primitive::Float(f) => Ok(f.to_string()),
Primitive::DateTime(dt) => Ok(dt.to_string()),
Primitive::Span(span) => Ok(span.to_string()),
_ => Err(Error::BadType("English::visit_primitive()")),
}
}

// Parse a primitive type array expression to English
fn visit_array(&self, arr: &Array) -> Result<String> {
let english_elts = arr
.elts
.iter()
.map(|p| self.visit_primitive(p).unwrap())
.collect::<Vec<String>>()
.join(",");
Ok(format!("the array [{}]", english_elts))
}

// Parse a JSON pointer expression into English by returning the explanation message for the plugin
fn visit_json_pointer(&self, _func: &JsonPointer) -> Result<String> {
Ok(self.message.clone())
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading

0 comments on commit 89a4432

Please sign in to comment.