Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into fi…
Browse files Browse the repository at this point in the history
…x-for-demo

Signed-off-by: Anna Gringauze <[email protected]>
  • Loading branch information
annagrin committed Feb 13, 2025
2 parents e178987 + de02e57 commit a93ba6e
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 14 deletions.
41 changes: 39 additions & 2 deletions include/cudaq/Optimizer/Dialect/CC/CCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1459,8 +1459,8 @@ def cc_CallCallableOp : CCOp<"call_callable", [CallOpInterface]> {
}];
}

def cc_CallIndirectCallableOp : CCOp<"call_indirect_callable",
[CallOpInterface]> {
def cc_CallIndirectCallableOp :
CCOp<"call_indirect_callable", [CallOpInterface]> {
let summary = "Call a C++ callable, unresolved, at run-time.";
let description = [{
This effectively connects a call from one kernel to another kernel, which
Expand Down Expand Up @@ -1649,6 +1649,43 @@ def cc_CallableClosureOp : CCOp<"callable_closure", [Pure]> {
}];
}

def cc_VarargCallOp :
CCOp<"call_vararg", [CallOpInterface, SymbolUserOpInterface]> {
let summary = "Create a call to an llvm.func with variadic arguments.";
let description = [{
This operation lets us create a call to an LLVMIR FuncOp with variadic
arguments without the restriction that all the arguments have to be
converted to LLVMIR types first. These conversions are just code bloat and
make the code harder to read.
}];

let arguments = (ins
FlatSymbolRefAttr:$callee,
Variadic<AnyType>:$args
);
let results = (outs Variadic<AnyType>);

let assemblyFormat = [{
$callee `(` $args `)` `:` functional-type(operands, results) attr-dict
}];

let extraClassDeclaration = [{
operand_range getArgOperands() {
return {arg_operand_begin(), arg_operand_end()};
}

operand_iterator arg_operand_begin() { return operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }

/// Return the callee of this operation.
mlir::CallInterfaceCallable getCallableForCallee() {
return getCalleeAttr();
}

mlir::LogicalResult verifySymbolUses(mlir::SymbolTableCollection &);
}];
}

def cc_CreateStringLiteralOp : CCOp<"string_literal"> {
let summary = "Create a constant string literal.";
let description = [{
Expand Down
35 changes: 25 additions & 10 deletions lib/Optimizer/CodeGen/CCToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -710,18 +710,33 @@ class UndefOpPattern : public ConvertOpToLLVMPattern<cudaq::cc::UndefOp> {
return success();
}
};

class VarargCallPattern
: public ConvertOpToLLVMPattern<cudaq::cc::VarargCallOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(cudaq::cc::VarargCallOp vcall, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> types;
for (auto ty : vcall.getResultTypes())
types.push_back(getTypeConverter()->convertType(ty));
rewriter.replaceOpWithNewOp<LLVM::CallOp>(vcall, types, vcall.getCallee(),
adaptor.getArgs());
return success();
}
};
} // namespace

void cudaq::opt::populateCCToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.insert<AddressOfOpPattern, AllocaOpPattern, CallableClosureOpPattern,
CallableFuncOpPattern, CallCallableOpPattern,
CallIndirectCallableOpPattern, CastOpPattern,
ComputePtrOpPattern, CreateStringLiteralOpPattern,
ExtractValueOpPattern, FuncToPtrOpPattern, GlobalOpPattern,
InsertValueOpPattern, InstantiateCallableOpPattern,
LoadOpPattern, OffsetOfOpPattern, PoisonOpPattern,
SizeOfOpPattern, StdvecDataOpPattern, StdvecInitOpPattern,
StdvecSizeOpPattern, StoreOpPattern, UndefOpPattern>(
typeConverter);
patterns.insert<
AddressOfOpPattern, AllocaOpPattern, CallableClosureOpPattern,
CallableFuncOpPattern, CallCallableOpPattern,
CallIndirectCallableOpPattern, CastOpPattern, ComputePtrOpPattern,
CreateStringLiteralOpPattern, ExtractValueOpPattern, FuncToPtrOpPattern,
GlobalOpPattern, InsertValueOpPattern, InstantiateCallableOpPattern,
LoadOpPattern, OffsetOfOpPattern, PoisonOpPattern, SizeOfOpPattern,
StdvecDataOpPattern, StdvecInitOpPattern, StdvecSizeOpPattern,
StoreOpPattern, UndefOpPattern, VarargCallPattern>(typeConverter);
}
4 changes: 2 additions & 2 deletions lib/Optimizer/CodeGen/ConvertToQIRAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1150,8 +1150,8 @@ struct QuantumGatePattern : public OpConversionPattern<OP> {
args.append(opTargs.begin(), opTargs.end());

// Call the generalized version of the gate invocation.
rewriter.create<LLVM::CallOp>(loc, TypeRange{},
cudaq::opt::NVQIRGeneralizedInvokeAny, args);
rewriter.create<cudaq::cc::VarargCallOp>(
loc, TypeRange{}, cudaq::opt::NVQIRGeneralizedInvokeAny, args);
return forwardOrEraseOp();
}

Expand Down
44 changes: 44 additions & 0 deletions lib/Optimizer/Dialect/CC/CCOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2347,6 +2347,50 @@ LogicalResult cudaq::cc::UnwindReturnOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// VarargCallOp
//===----------------------------------------------------------------------===//

LogicalResult
cudaq::cc::VarargCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Check that the callee attribute was specified.
auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
if (!fnAttr)
return emitOpError("requires a 'callee' symbol reference attribute");
LLVM::LLVMFuncOp fn =
symbolTable.lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(*this, fnAttr);
if (!fn)
return emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid LLVM function";

// Verify that the operand and result types match the callee.
auto fnType = fn.getFunctionType();
if (fnType.getNumParams() > getNumOperands())
return emitOpError("incorrect number of operands for callee");

for (unsigned i = 0, e = fnType.getNumParams(); i != e; ++i)
if (getOperand(i).getType() != fnType.getParams()[i]) {
return emitOpError("operand type mismatch: expected operand type ")
<< fnType.getParams()[i] << ", but provided "
<< getOperand(i).getType() << " for operand number " << i;
}

if (fnType.getReturnType() == LLVM::LLVMVoidType::get(getContext()) &&
getNumResults() == 0)
return success();

if (getNumResults() > 1)
return emitOpError("wrong number of result types: ") << getNumResults();

if (getResult(1).getType() != fnType.getReturnType()) {
auto diag = emitOpError("result type mismatch ");
diag.attachNote() << " op result types: " << getResultTypes();
diag.attachNote() << "function result types: " << fnType.getReturnType();
return diag;
}
return success();
}

//===----------------------------------------------------------------------===//
// Generated logic
//===----------------------------------------------------------------------===//
Expand Down
17 changes: 17 additions & 0 deletions test/Quake/roundtrip-ops.qke
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,23 @@ func.func @indirect_callable1(%arg : !cc.indirect_callable<() -> ()>) {
// CHECK: return
// CHECK: }

func.func @varargs_test() {
%1 = arith.constant 12 : i32
%2 = cc.undef !cc.ptr<none>
cc.call_vararg @my_variadic(%1, %2) : (i32, !cc.ptr<none>) -> ()
return
}

llvm.func @my_variadic(i32, ...)

// CHECK-LABEL: func.func @varargs_test() {
// CHECK: %[[VAL_0:.*]] = arith.constant 12 : i32
// CHECK: %[[VAL_1:.*]] = cc.undef !cc.ptr<none>
// CHECK: cc.call_vararg @my_variadic(%[[VAL_0]], %[[VAL_1]]) : (i32, !cc.ptr<none>) -> ()
// CHECK: return
// CHECK: }
// CHECK: llvm.func @my_variadic(i32, ...)

func.func @indirect_callable2(%arg : !cc.indirect_callable<(i32) -> i64>) -> i64 {
%cst = arith.constant 4 : i32
%0 = cc.call_indirect_callable %arg, %cst : (!cc.indirect_callable<(i32) -> i64>, i32) -> i64
Expand Down

0 comments on commit a93ba6e

Please sign in to comment.