Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce typed symbol references #758

Merged
merged 9 commits into from
Jan 13, 2025
89 changes: 88 additions & 1 deletion include/vast/Dialect/Core/CommonAttrConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,91 @@

include "mlir/IR/CommonAttrConstraints.td"

#endif // VAST_DIALECT_CORE_COMMON_ATTR_CONSTRAINTS_TD
//
// Attributes for symbol references
//

def Core_VarSymbolRefAttr : Attr<
CPred< "::llvm::isa< ::vast::core::VarSymbolRefAttr >($_self)" >,
"variable symbol reference attribute"
> {
let storageType = [{ ::vast::core::VarSymbolRefAttr }];
let returnType = [{ ::llvm::StringRef }];
let valueType = NoneType;
let constBuilderCall =
"::mlir::SymbolRefAttr::get($_builder.getContext(), $0)";
let convertFromStorage = "$_self.getValue()";
}

def Core_TypeSymbolRefAttr : Attr<
CPred< "::llvm::isa< ::vast::core::TypeSymbolRefAttr >($_self)" >,
"type symbol reference attribute"
> {
let storageType = [{ ::vast::core::TypeSymbolRefAttr }];
let returnType = [{ ::llvm::StringRef }];
let valueType = NoneType;
let constBuilderCall =
"::mlir::SymbolRefAttr::get($_builder.getContext(), $0)";
let convertFromStorage = "$_self.getValue()";
}

def Core_FuncSymbolRefAttr : Attr<
CPred< "::llvm::isa< ::vast::core::FuncSymbolRefAttr >($_self)" >,
"function symbol reference attribute"
> {
let storageType = [{ ::vast::core::FuncSymbolRefAttr }];
let returnType = [{ ::llvm::StringRef }];
let valueType = NoneType;
let constBuilderCall =
"::mlir::SymbolRefAttr::get($_builder.getContext(), $0)";
let convertFromStorage = "$_self.getValue()";
}

def Core_LabelSymbolRefAttr : Attr<
CPred< "::llvm::isa< ::vast::core::LabelSymbolRefAttr >($_self)" >,
"label symbol reference attribute"
> {
let storageType = [{ ::vast::core::LabelSymbolRefAttr }];
let returnType = [{ ::llvm::StringRef }];
let valueType = NoneType;
let constBuilderCall =
"::mlir::SymbolRefAttr::get($_builder.getContext(), $0)";
let convertFromStorage = "$_self.getValue()";
}

def Core_EnumConstantSymbolRefAttr : Attr<
CPred< "::llvm::isa< ::vast::core::EnumConstantSymbolRefAttr >($_self)" >,
"enum constant symbol reference attribute"
> {
let storageType = [{ ::vast::core::EnumConstantSymbolRefAttr }];
let returnType = [{ ::llvm::StringRef }];
let valueType = NoneType;
let constBuilderCall =
"::mlir::SymbolRefAttr::get($_builder.getContext(), $0)";
let convertFromStorage = "$_self.getValue()";
}

def Core_MemberVarSymbolRefAttr : Attr<
CPred< "::llvm::isa< ::vast::core::MemberVarSymbolRefAttr >($_self)" >,
"member variable symbol reference attribute"
> {
let storageType = [{ ::vast::core::MemberVarSymbolRefAttr }];
let returnType = [{ ::llvm::StringRef }];
let valueType = NoneType;
let constBuilderCall =
"::mlir::SymbolRefAttr::get($_builder.getContext(), $0)";
let convertFromStorage = "$_self.getValue()";
}

def Core_ElaboratedTypeSymbolRefAttr : Attr<
CPred< "::llvm::isa< ::vast::core::ElaboratedTypeSymbolRefAttr >($_self)" >,
"elaborated type symbol reference attribute"
> {
let storageType = [{ ::vast::core::ElaboratedTypeSymbolRefAttr }];
let returnType = [{ ::llvm::StringRef }];
let valueType = NoneType;
let constBuilderCall =
"::mlir::SymbolRefAttr::get($_builder.getContext(), $0)";
let convertFromStorage = "$_self.getValue()";
}
#endif // VAST_DIALECT_CORE_COMMON_ATTR_CONSTRAINTS_TD
27 changes: 27 additions & 0 deletions include/vast/Dialect/Core/CoreAttributes.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022-present, Trail of Bits, Inc.

Check notice on line 1 in include/vast/Dialect/Core/CoreAttributes.hpp

View workflow job for this annotation

GitHub Actions / cpp-linter (19, 22.04)

Run clang-format on include/vast/Dialect/Core/CoreAttributes.hpp

File include/vast/Dialect/Core/CoreAttributes.hpp does not conform to Custom style guidelines. (lines 42, 43, 44, 45, 46, 47, 48)

#pragma once

Expand Down Expand Up @@ -36,4 +36,31 @@
Printer &printer, mlir::Operation *op, core::StorageClassAttr storage_class, core::TSClassAttr thread_storage_class
);

//
// Symbol Reference Attributes
//
struct VarSymbolRefAttr : mlir::FlatSymbolRefAttr {};
struct TypeSymbolRefAttr : mlir::FlatSymbolRefAttr {};
struct FuncSymbolRefAttr : mlir::FlatSymbolRefAttr {};
struct LabelSymbolRefAttr : mlir::FlatSymbolRefAttr {};
struct EnumConstantSymbolRefAttr : mlir::FlatSymbolRefAttr {};
struct MemberVarSymbolRefAttr : mlir::FlatSymbolRefAttr {};
struct ElaboratedTypeSymbolRefAttr : mlir::FlatSymbolRefAttr {};

using var_symbol_ref_attr = VarSymbolRefAttr;
using type_symbol_ref_attr = TypeSymbolRefAttr;
using func_symbol_ref_attr = FuncSymbolRefAttr;
using label_symbol_ref_attr = LabelSymbolRefAttr;
using enum_constant_symbol_ref_attr = EnumConstantSymbolRefAttr;
using member_var_symbol_ref_attr = MemberVarSymbolRefAttr;
using elaborated_type_symbol_ref_attr = ElaboratedTypeSymbolRefAttr;

} // namespace vast::core

MLIR_DECLARE_EXPLICIT_TYPE_ID(vast::core::VarSymbolRefAttr);
MLIR_DECLARE_EXPLICIT_TYPE_ID(vast::core::TypeSymbolRefAttr);
MLIR_DECLARE_EXPLICIT_TYPE_ID(vast::core::FuncSymbolRefAttr);
MLIR_DECLARE_EXPLICIT_TYPE_ID(vast::core::LabelSymbolRefAttr);
MLIR_DECLARE_EXPLICIT_TYPE_ID(vast::core::EnumConstantSymbolRefAttr);
MLIR_DECLARE_EXPLICIT_TYPE_ID(vast::core::MemberVarSymbolRefAttr);
MLIR_DECLARE_EXPLICIT_TYPE_ID(vast::core::ElaboratedTypeSymbolRefAttr);
29 changes: 5 additions & 24 deletions include/vast/Dialect/Core/SymbolTable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ VAST_UNRELAX_WARNINGS
#include "vast/Util/Common.hpp"
#include "vast/Util/TypeList.hpp"

#include "vast/Dialect/Core/CoreAttributes.hpp"
#include "vast/Dialect/Core/Interfaces/SymbolInterface.hpp"

#include <gap/coro/generator.hpp>
Expand Down Expand Up @@ -166,36 +167,16 @@ namespace vast::core {
return gmw::operations(symbol_table_op);
}

// Note: These are adapted from mlir::SymbolTable,
// try to keep them consistent when updating, or note the differences.
//
// Get an iterator range for all of the uses, for any symbol, that are
// nested within the given operation 'from'. This does not traverse into
// any nested symbol tables.
static symbol_use_range get_direct_symbol_uses(operation from);
static symbol_use_range get_direct_symbol_uses(region_ptr from);

// Get all of the uses of the given symbol that are nested within the given
// operation 'from'. This does not traverse into any nested symbol tables.
static symbol_use_range get_direct_symbol_uses(operation symbol, operation from);
static symbol_use_range get_direct_symbol_uses(string_attr symbol, operation from);
static symbol_use_range get_direct_symbol_uses(operation symbol, region_ptr from);
static symbol_use_range get_direct_symbol_uses(string_attr symbol, region_ptr from);

// Get an iterator range for all of the uses, for any symbol, that are
// nested within the given operation 'from'. In contrast to
// mlir::SymbolTable::getSymbolUses and `get_direct_symbol_uses` this
// function traverses into nested symbol tables.
static symbol_use_range get_symbol_uses(operation from);
static symbol_use_range get_symbol_uses(region_ptr from);
static symbol_use_range get_direct_symbol_uses(operation symbol, operation scope);
static symbol_use_range get_direct_symbol_uses(operation symbol, region_ptr scope);

// Get all of the uses of the given symbol that are nested within the given
// operation 'from'. In contrast to mlir::SymbolTable::getSymbolUses, this
// function traverses into nested symbol tables.
static symbol_use_range get_symbol_uses(operation symbol, operation from);
static symbol_use_range get_symbol_uses(string_attr symbol, operation from);
static symbol_use_range get_symbol_uses(operation symbol, region_ptr from);
static symbol_use_range get_symbol_uses(string_attr symbol, region_ptr from);
static symbol_use_range get_symbol_uses(operation symbol, operation scope);
static symbol_use_range get_symbol_uses(operation symbol, region_ptr scope);

protected:

Expand Down
17 changes: 0 additions & 17 deletions include/vast/Dialect/HighLevel/HighLevelAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,6 @@ class HighLevel_SymbolReferringAttr<string name, string attr_mnemonic>
let assemblyFormat = "`<` $symbol `>`";
}

class HighLevel_FlatSymbolReferringAttr<string name, string attr_mnemonic>
: HighLevel_Attr< name, attr_mnemonic >
{
let parameters = (ins "::mlir::FlatSymbolRefAttr":$symbol);

let builders = [
AttrBuilderWithInferredContext<(ins "::mlir::FlatSymbolRefAttr":$symbol), [{
return get(symbol.getContext(), symbol);
}]>,
AttrBuilder<(ins "::mlir::StringRef":$symbol), [{
return get(mlir::FlatSymbolRefAttr::get($_ctxt, symbol));
}]>,
];

let assemblyFormat = "`<` $symbol `>`";
}

def HighLevel_AnnotationAttr : HighLevel_NameAttr< "Annotation", "annotation" >;
def HighLevel_FlattentAttr : HighLevel_Attr< "Flatten", "flatten" >;
def HighLevel_FormatAttr : HighLevel_NameAttr< "Format", "format" >;
Expand Down
10 changes: 5 additions & 5 deletions include/vast/Dialect/HighLevel/HighLevelOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def HighLevel_FieldDeclOp

def HighLevel_RecordMemberOp
: HighLevel_Op< "member" >
, Arguments<(ins AnyType:$record, FlatSymbolRefAttr:$field)>
, Arguments<(ins AnyType:$record, Core_MemberVarSymbolRefAttr:$field)>
, Results<(outs AnyType:$element)>
{
let summary = "VAST record element access operation";
Expand All @@ -358,7 +358,7 @@ def HighLevel_CallOp
DeclareOpInterfaceMethods<VastCallOpInterface, ["resolveCallable", "resolveCallableInTable"]>
] >
, Arguments<(ins
FlatSymbolRefAttr:$callee,
Core_FuncSymbolRefAttr:$callee,
Variadic<AnyType>:$argOperands
) >
, Results<(outs Variadic<AnyType>:$results)>
Expand Down Expand Up @@ -462,7 +462,7 @@ def HighLevel_ReturnOp

def HighLevel_DeclRefOp
: HighLevel_Op< "ref" >
, Arguments<(ins FlatSymbolRefAttr:$name)>
, Arguments<(ins Core_VarSymbolRefAttr:$name)>
, Results<(outs AnyType:$result)>
{
let summary = "VAST variable reference declaration";
Expand All @@ -473,7 +473,7 @@ def HighLevel_DeclRefOp

def HighLevel_FuncRefOp
: HighLevel_Op< "funcref" >
, Arguments<(ins FlatSymbolRefAttr:$function)>
, Arguments<(ins Core_FuncSymbolRefAttr:$function)>
, Results<(outs AnyType:$result)>
{
let summary = "VAST function reference declaration";
Expand All @@ -484,7 +484,7 @@ def HighLevel_FuncRefOp

def HighLevel_EnumRefOp
: HighLevel_Op< "enumref" >
, Arguments<(ins FlatSymbolRefAttr:$name)>
, Arguments<(ins Core_EnumConstantSymbolRefAttr:$name)>
, Results<(outs AnyType:$result)>
{
let summary = "VAST enum constant reference declaration";
Expand Down
8 changes: 8 additions & 0 deletions lib/vast/Dialect/Core/CoreAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,11 @@ namespace vast::core
}

} // namespace vast::core

MLIR_DEFINE_EXPLICIT_TYPE_ID(vast::core::VarSymbolRefAttr);
MLIR_DEFINE_EXPLICIT_TYPE_ID(vast::core::TypeSymbolRefAttr);
MLIR_DEFINE_EXPLICIT_TYPE_ID(vast::core::FuncSymbolRefAttr);
MLIR_DEFINE_EXPLICIT_TYPE_ID(vast::core::LabelSymbolRefAttr);
MLIR_DEFINE_EXPLICIT_TYPE_ID(vast::core::EnumConstantSymbolRefAttr);
MLIR_DEFINE_EXPLICIT_TYPE_ID(vast::core::MemberVarSymbolRefAttr);
MLIR_DEFINE_EXPLICIT_TYPE_ID(vast::core::ElaboratedTypeSymbolRefAttr);
Loading
Loading