Skip to content

Commit

Permalink
Implement translation of actions
Browse files Browse the repository at this point in the history
Signed-off-by: Anton Korobeynikov <[email protected]>
  • Loading branch information
asl committed Feb 10, 2025
1 parent d773b6f commit 48e845e
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 32 deletions.
2 changes: 1 addition & 1 deletion include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def ActionOp : P4HIR_Op<"action", [
}

/// Returns attribute name used to store directions
llvm::StringRef getDirectionAttrName() const { return "p4hir.dir"; }
static llvm::StringRef getDirectionAttrName() { return "p4hir.dir"; }

/// Return the `i`th argument direction.
ParamDirection getArgumentDirection(unsigned i) {
Expand Down
9 changes: 5 additions & 4 deletions lib/Dialect/P4HIR/P4HIR_Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,16 +338,17 @@ mlir::Region *P4HIR::ActionOp::getCallableRegion() { return &getBody(); }
void P4HIR::ActionOp::build(OpBuilder &builder, OperationState &result, llvm::StringRef name,
P4HIR::ActionType type, ArrayRef<NamedAttribute> attrs,
ArrayRef<DictionaryAttr> argAttrs) {
result.addRegion();
result.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name));
result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(type));

result.attributes.append(attrs.begin(), attrs.end());
if (argAttrs.empty()) return;

function_interface_impl::addArgAndResultAttrs(
builder, result, argAttrs,
/*resultAttrs=*/std::nullopt, getArgAttrsAttrName(result.name), builder.getStringAttr(""));

auto *region = result.addRegion();
Block &first = region->emplaceBlock();
for (auto argType : type.getInputs()) first.addArgument(argType, result.location);
}

void P4HIR::ActionOp::print(OpAsmPrinter &p) {
Expand Down Expand Up @@ -419,7 +420,7 @@ ParseResult P4HIR::ActionOp::parse(OpAsmParser &parser, OperationState &state) {
getArgAttrsAttrName(state.name),
builder.getStringAttr(""));

// Parse the action body. We need to strip out !p4hir.param wrappers types
// Parse the action body.
auto *body = state.addRegion();
ParseResult parseResult = parser.parseRegion(*body, arguments, /*enableNameShadowing=*/false);
if (failed(parseResult)) return failure();
Expand Down
17 changes: 17 additions & 0 deletions test/Translate/Ops/action.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s

// CHECK-LABEL: p4hir.action @foo(%arg0: !p4hir.bit<16> {p4hir.dir = #p4hir<dir in>}, %arg1: !p4hir.ref<!p4hir.int<10>> {p4hir.dir = #p4hir<dir inout>}, %arg2: !p4hir.ref<!p4hir.bit<16>> {p4hir.dir = #p4hir<dir out>}, %arg3: !p4hir.bit<16> {p4hir.dir = #p4hir<dir undir>})
// CHECK: p4hir.return
action foo(in bit<16> arg1, inout int<10> arg2, out bit<16> arg3, bit<16> arg4) {
bit<16> x = arg1;
arg3 = x;
if (arg1 == arg4) {
arg2 = arg2 + 1;
}
return;
}

// CHECK-LABEL: p4hir.action @bar(%arg0: !p4hir.bit<16> {p4hir.dir = #p4hir<dir undir>}) {
// CHECK: p4hir.return
action bar(bit<16> arg1) {
}
4 changes: 2 additions & 2 deletions test/Translate/Ops/assign.p4
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ action assign() {
res = lhs + rhs;
}

// CHECK-LABEL: module
// CHECK-NEXT: %[[VAL_0:.*]] = p4hir.alloca !p4hir.bit<10> ["res"] : !p4hir.ref<!p4hir.bit<10>>
// CHECK-LABEL: p4hir.action @assign()
// CHECK: %[[VAL_0:.*]] = p4hir.alloca !p4hir.bit<10> ["res"] : !p4hir.ref<!p4hir.bit<10>>
// CHECK: %[[VAL_1:.*]] = p4hir.const #p4hir.int<1> : !p4hir.bit<10>
// CHECK: %[[VAL_2:.*]] = p4hir.cast(%[[VAL_1]] : !p4hir.bit<10>) : !p4hir.bit<10>
// CHECK: %[[VAL_3:.*]] = p4hir.alloca !p4hir.bit<10> ["lhs", init] : !p4hir.ref<!p4hir.bit<10>>
Expand Down
5 changes: 3 additions & 2 deletions test/Translate/Ops/binop.p4
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ action int_binops() {
int<10> r13 = lhs ^ rhs;
}

// CHECK-LABEL: module
// CHECK-NEXT: %[[VAL_0:.*]] = p4hir.alloca !p4hir.bit<10> ["res"] : !p4hir.ref<!p4hir.bit<10>>
// CHECK-LABEL: p4hir.action @bit_binops()
// CHECK: %[[VAL_0:.*]] = p4hir.alloca !p4hir.bit<10> ["res"] : !p4hir.ref<!p4hir.bit<10>>
// CHECK: %[[VAL_1:.*]] = p4hir.const #p4hir.int<1> : !p4hir.bit<10>
// CHECK: %[[VAL_2:.*]] = p4hir.cast(%[[VAL_1]] : !p4hir.bit<10>) : !p4hir.bit<10>
// CHECK: %[[VAL_3:.*]] = p4hir.alloca !p4hir.bit<10> ["lhs", init] : !p4hir.ref<!p4hir.bit<10>>
Expand Down Expand Up @@ -139,6 +139,7 @@ action int_binops() {
// CHECK: %[[VAL_67:.*]] = p4hir.binop(xor, %[[VAL_65]], %[[VAL_66]]) : !p4hir.bit<10>
// CHECK: %[[VAL_68:.*]] = p4hir.alloca !p4hir.bit<10> ["r14", init] : !p4hir.ref<!p4hir.bit<10>>
// CHECK: p4hir.store %[[VAL_67]], %[[VAL_68]] : !p4hir.bit<10>, !p4hir.ref<!p4hir.bit<10>>
// CHECK-LABEL: p4hir.action @int_binops()
// CHECK: %[[VAL_69:.*]] = p4hir.alloca !p4hir.int<10> ["res"] : !p4hir.ref<!p4hir.int<10>>
// CHECK: %[[VAL_70:.*]] = p4hir.const #p4hir.int<1> : !p4hir.int<10>
// CHECK: %[[VAL_71:.*]] = p4hir.cast(%[[VAL_70]] : !p4hir.int<10>) : !p4hir.int<10>
Expand Down
4 changes: 2 additions & 2 deletions test/Translate/Ops/cmp.p4
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py

// CHECK-LABEL: module
// CHECK-NEXT: %[[VAL_0:.*]] = p4hir.alloca !p4hir.bool ["res"] : !p4hir.ref<!p4hir.bool>
// CHECK-LABEL: p4hir.action @cmp()
// CHECK: %[[VAL_0:.*]] = p4hir.alloca !p4hir.bool ["res"] : !p4hir.ref<!p4hir.bool>
// CHECK: %[[VAL_1:.*]] = p4hir.const #p4hir.int<1> : !p4hir.bit<10>
// CHECK: %[[VAL_2:.*]] = p4hir.cast(%[[VAL_1]] : !p4hir.bit<10>) : !p4hir.bit<10>
// CHECK: %[[VAL_3:.*]] = p4hir.alloca !p4hir.bit<10> ["lhs", init] : !p4hir.ref<!p4hir.bit<10>>
Expand Down
2 changes: 1 addition & 1 deletion test/Translate/Ops/scope.p4
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s

// CHECK-LABEL: module
// CHECK-LABEL: p4hir.action @scope()
action scope() {
bool res;
// Outer alloca
Expand Down
4 changes: 2 additions & 2 deletions test/Translate/Ops/unop.p4
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py

// CHECK-LABEL: module
// CHECK-NEXT: %[[VAL_0:.*]] = p4hir.const #p4hir.bool<true> : !p4hir.bool
// CHECK-LABEL: p4hir.action @foo()
// CHECK: %[[VAL_0:.*]] = p4hir.const #p4hir.bool<true> : !p4hir.bool
// CHECK: %[[VAL_1:.*]] = p4hir.alloca !p4hir.bool ["b0", init] : !p4hir.ref<!p4hir.bool>
// CHECK: p4hir.store %[[VAL_0]], %[[VAL_1]] : !p4hir.bool, !p4hir.ref<!p4hir.bool>
// CHECK: %[[VAL_2:.*]] = p4hir.const #p4hir.int<255> : !p4hir.int<32>
Expand Down
4 changes: 2 additions & 2 deletions test/Translate/Ops/variables.p4
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ action foo() {
bit<8> b10 = (bit<8>)b8;
}

// CHECK-LABEL: module
// CHECK-NEXT: %[[VAL_0:.*]] = p4hir.const #p4hir.int<255> : !p4hir.bit<32>
// CHECK-LABEL: p4hir.action @foo()
// CHECK: %[[VAL_0:.*]] = p4hir.const #p4hir.int<255> : !p4hir.bit<32>
// CHECK: %[[VAL_1:.*]] = p4hir.alloca !p4hir.bit<32> ["b0", init] : !p4hir.ref<!p4hir.bit<32>>
// CHECK: p4hir.store %[[VAL_0]], %[[VAL_1]] : !p4hir.bit<32>, !p4hir.ref<!p4hir.bit<32>>
// CHECK: %[[VAL_2:.*]] = p4hir.const #p4hir.int<255> : !p4hir.int<32>
Expand Down
136 changes: 120 additions & 16 deletions tools/p4mlir-translate/translate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@

#include "frontends/common/resolveReferences/resolveReferences.h"
#include "frontends/p4/typeMap.h"
#include "ir/ir-generated.h"
#include "ir/ir.h"
#include "ir/visitor.h"
#include "lib/big_int.h"
#include "lib/indent.h"
#include "lib/log.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "p4mlir/Dialect/P4HIR/P4HIR_Attrs.h"
#include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.h"
#include "p4mlir/Dialect/P4HIR/P4HIR_Ops.h"
Expand Down Expand Up @@ -118,9 +122,11 @@ class P4TypeConverter : public P4::Inspector {
}

bool preorder(const P4::IR::Type_Name *name) override;
bool preorder(const P4::IR::Type_Action *act) override;

mlir::Type getType() { return type; }
mlir::Type getType() const { return type; }
bool setType(const P4::IR::Type *type, mlir::Type mlirType);
mlir::Type convert(const P4::IR::Type *type);

private:
P4HIRConverter &converter;
Expand Down Expand Up @@ -224,16 +230,14 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext {
node = resolvePath(pe->path, false)->checkedTo<P4::IR::Declaration>();
}

if (const auto *decl = node->to<P4::IR::Declaration_Variable>()) {
// Getting value out of variable involves involves a load.
auto alloca = p4Values.lookup(decl);
BUG_CHECK(alloca, "expected %1% (aka %2%) to be converted", node, dbp(node));
return builder.create<P4HIR::LoadOp>(getLoc(builder, node), alloca);
}
auto val = p4Values.lookup(node);
BUG_CHECK(val, "expected %1% (aka %2%) to be converted", node, dbp(node));

if (auto val = p4Values.lookup(node)) return val;
if (mlir::isa<P4HIR::ReferenceType>(val.getType()))
// Getting value out of variable involves a load.
return builder.create<P4HIR::LoadOp>(getLoc(builder, node), val);

BUG("expected %1% (aka %2%) to be converted", node, dbp(node));
return val;
}

mlir::Value setValue(const P4::IR::Node *node, mlir::Value value) {
Expand Down Expand Up @@ -264,13 +268,7 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext {
}

bool preorder(const P4::IR::P4Program *) override { return true; }
bool preorder(const P4::IR::P4Action *a) override {
// We cannot simply visit each node of the top-level block as
// ResolutionContext would not be able to resolve declarations there
// (sic!)
visit(a->body);
return false;
}
bool preorder(const P4::IR::P4Action *a) override;
bool preorder(const P4::IR::BlockStatement *block) override {
// If this is a top-level block where scope is implied (e.g. function,
// action, certain statements) do not create explicit scope.
Expand Down Expand Up @@ -333,6 +331,7 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext {

HANDLE_IN_POSTORDER(Cast)
HANDLE_IN_POSTORDER(Declaration_Variable)
HANDLE_IN_POSTORDER(ReturnStatement)

#undef HANDLE_IN_POSTORDER

Expand Down Expand Up @@ -387,12 +386,37 @@ bool P4TypeConverter::preorder(const P4::IR::Type_Name *name) {
return false;
}

bool P4TypeConverter::preorder(const P4::IR::Type_Action *type) {
if ((this->type = converter.findType(type))) return false;

ConversionTracer trace("TypeConverting ", type);
llvm::SmallVector<mlir::Type, 4> argTypes;

BUG_CHECK(type->returnType == nullptr, "actions should not have return type set");
CHECK_NULL(type->parameters);

for (const auto *p : type->parameters->parameters) {
mlir::Type type = convert(p->type);
argTypes.push_back(p->hasOut() ? P4HIR::ReferenceType::get(type) : type);
}

auto mlirType = P4HIR::ActionType::get(converter.context(), argTypes);
return setType(type, mlirType);
}

bool P4TypeConverter::setType(const P4::IR::Type *type, mlir::Type mlirType) {
this->type = mlirType;
converter.setType(type, mlirType);
return false;
}

mlir::Type P4TypeConverter::convert(const P4::IR::Type *type) {
if ((this->type = converter.findType(type))) return getType();

visit(type);
return getType();
}

mlir::Value P4HIRConverter::resolveReference(const P4::IR::Node *node) {
// If this is a PathExpression, resolve it
if (const auto *pe = node->to<P4::IR::PathExpression>()) {
Expand Down Expand Up @@ -588,6 +612,8 @@ bool P4HIRConverter::preorder(const P4::IR::AssignmentStatement *assign) {
}

bool P4HIRConverter::preorder(const P4::IR::LOr *lor) {
ConversionTracer trace("Converting ", lor);

// Lower a || b into a ? true : b
visit(lor->left);

Expand All @@ -606,6 +632,8 @@ bool P4HIRConverter::preorder(const P4::IR::LOr *lor) {
}

bool P4HIRConverter::preorder(const P4::IR::LAnd *land) {
ConversionTracer trace("Converting ", land);

// Lower a && b into a ? b : false
visit(land->left);

Expand All @@ -624,6 +652,8 @@ bool P4HIRConverter::preorder(const P4::IR::LAnd *land) {
}

bool P4HIRConverter::preorder(const P4::IR::IfStatement *ifs) {
ConversionTracer trace("Converting ", ifs);

// Materialize condition first
visit(ifs->condition);

Expand All @@ -641,6 +671,80 @@ bool P4HIRConverter::preorder(const P4::IR::IfStatement *ifs) {
return false;
}

bool P4HIRConverter::preorder(const P4::IR::P4Action *act) {
ConversionTracer trace("Converting ", act);

// FIXME: Get rid of typeMap: ensure action knows its type
auto actType = mlir::cast<P4HIR::ActionType>(getOrCreateType(typeMap->getType(act, true)));
const auto &params = act->getParameters()->parameters;

// Create attributes for directions
llvm::SmallVector<mlir::DictionaryAttr, 4> argAttrs;
for (const auto *p : params) {
P4HIR::ParamDirection dir;
switch (p->direction) {
case P4::IR::Direction::None:
dir = P4HIR::ParamDirection::None;
break;
case P4::IR::Direction::In:
dir = P4HIR::ParamDirection::In;
break;
case P4::IR::Direction::Out:
dir = P4HIR::ParamDirection::Out;
break;
case P4::IR::Direction::InOut:
dir = P4HIR::ParamDirection::InOut;
break;
};

mlir::NamedAttribute dirAttr(
mlir::StringAttr::get(context(), P4HIR::ActionOp::getDirectionAttrName()),
P4HIR::ParamDirectionAttr::get(context(), dir));

argAttrs.emplace_back(mlir::DictionaryAttr::get(context(), dirAttr));
}
assert(actType.getNumInputs() == argAttrs.size() && "invalid parameter conversion");

auto action =
builder.create<P4HIR::ActionOp>(getLoc(builder, act), act->name.string_view(), actType,
llvm::ArrayRef<mlir::NamedAttribute>(), argAttrs);

// Iterate over parameters again binding parameter values to arguments of first BB
auto &body = action.getBody();

assert(body.getNumArguments() == params.size() && "invalid parameter conversion");
for (auto [param, bodyArg] : llvm::zip(params, body.getArguments())) setValue(param, bodyArg);

// We cannot simply visit each node of the top-level block as
// ResolutionContext would not be able to resolve declarations there
// (sic!)
{
mlir::OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(&body.front());
visit(act->body);

// Check if body's last block is not terminated.
mlir::Block &b = body.back();
if (!b.mightHaveTerminator()) {
builder.setInsertionPointToEnd(&b);
builder.create<P4HIR::ReturnOp>(getEndLoc(builder, act));
}
}

return false;
}

void P4HIRConverter::postorder(const P4::IR::ReturnStatement *ret) {
// TODO: ReturnOp is a terminator, so it cannot be in the middle of block;
// ensure nothing is created afterwards
if (ret->expression) {
auto retVal = getValue(ret->expression);
builder.create<P4HIR::ReturnOp>(getLoc(builder, ret), retVal);
} else {
builder.create<P4HIR::ReturnOp>(getLoc(builder, ret));
}
}

} // namespace

namespace P4::P4MLIR {
Expand Down

0 comments on commit 48e845e

Please sign in to comment.