forked from onnx/onnx-mlir
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDialectBuilder.hpp
502 lines (412 loc) · 19.8 KB
/
DialectBuilder.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
//===---- DialectBuilder.hpp - Helper functions for MLIR dialects -----===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
// This file contains helper functions for building MLIR operations.
//
//===----------------------------------------------------------------------===//
#pragma once
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "src/Dialect/Mlir/IndexExpr.hpp"
namespace onnx_mlir {
/// Common base class of all the dialect builders in this file.
///
/// A dialect builder bundles a reference to an mlir::OpBuilder together with
/// an mlir::Location. Any dialect builder can be constructed from another one
/// (through the copy constructor below); both then share the same OpBuilder
/// and location.
///
/// The OpBuilder is held by reference, so it must outlive every builder that
/// refers to it. Move construction and both assignment operators are deleted.
struct DialectBuilder {
  DialectBuilder(mlir::OpBuilder &b, mlir::Location loc) : b(b), loc(loc) {}
  DialectBuilder(const DialectBuilder &db) : b(db.b), loc(db.loc) {}
  virtual ~DialectBuilder() = default;

  DialectBuilder(DialectBuilder &&) = delete;
  DialectBuilder &operator=(const DialectBuilder &) = delete;
  // Conventional move-assignment signature: takes a non-const rvalue and
  // returns an lvalue reference.
  DialectBuilder &operator=(DialectBuilder &&) = delete;

  /// Returns the OpBuilder this builder wraps.
  mlir::OpBuilder &getBuilder() const { return b; }
  /// Returns the location stored in this builder.
  mlir::Location getLoc() const { return loc; }

protected:
  mlir::OpBuilder &b;
  mlir::Location loc;
};
//===----------------------------------------------------------------------===//
// Math Builder
//===----------------------------------------------------------------------===//
/// Helper struct to build simple arithmetic quantities with minimal type
/// inference support. The code is adapted to use the DialectBuilder
/// superclass, which makes it possible to build one dialect builder from
/// another dialect builder.
//===----------------------------------------------------------------------===//
// Original code for MathBuilder is copied from LLVM MLIR Utils.cpp
// Modified here to add operations, add super class.
// License added here for this class for completeness.
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===----------------------------------------------------------------------===//
struct MathBuilder final : DialectBuilder {
  MathBuilder(mlir::OpBuilder &b, mlir::Location loc)
      : DialectBuilder(b, loc) {}
  MathBuilder(const DialectBuilder &db) : DialectBuilder(db) {}

  //--- Two-operand operations ----------------------------------------------
  mlir::Value andi(mlir::Value lhs, mlir::Value rhs) const;
  mlir::Value ori(mlir::Value lhs, mlir::Value rhs) const;
  mlir::Value add(mlir::Value lhs, mlir::Value rhs) const;
  mlir::Value sub(mlir::Value lhs, mlir::Value rhs) const;
  mlir::Value mul(mlir::Value lhs, mlir::Value rhs) const;
  mlir::Value div(mlir::Value lhs, mlir::Value rhs) const;
  mlir::Value min(mlir::Value lhs, mlir::Value rhs) const;
  mlir::Value max(mlir::Value lhs, mlir::Value rhs) const;
  mlir::Value pow(mlir::Value base, mlir::Value exp) const;

  //--- Single-operand operations -------------------------------------------
  mlir::Value exp(mlir::Value val) const;
  mlir::Value exp2(mlir::Value val) const;
  mlir::Value log2(mlir::Value val) const;
  mlir::Value sqrt(mlir::Value val) const;

  //--- Comparisons and selection -------------------------------------------
  mlir::Value sgt(mlir::Value lhs, mlir::Value rhs) const;
  mlir::Value sge(mlir::Value lhs, mlir::Value rhs) const;
  mlir::Value slt(mlir::Value lhs, mlir::Value rhs) const;
  mlir::Value sle(mlir::Value lhs, mlir::Value rhs) const;
  mlir::Value eq(mlir::Value lhs, mlir::Value rhs) const;
  mlir::Value neq(mlir::Value lhs, mlir::Value rhs) const;
  mlir::Value select(mlir::Value cmp, mlir::Value lhs, mlir::Value rhs) const;

  //--- Constants -----------------------------------------------------------
  mlir::Value constant(mlir::Type type, double val) const;
  mlir::Value constantIndex(int64_t val) const;
  /// Emits a "negative infinity" constant of the given type. Supported types:
  /// F16, F32, F64, Int8, Int16, Int32, Int64. Float types get the negation of
  /// positive infinity; integer types get the minimum representable value.
  mlir::Value negativeInf(mlir::Type type) const;
  /// Emits a "positive infinity" constant of the given type. Supported types:
  /// F16, F32, F64, Int8, Int16, Int32, Int64. Integer types get the maximum
  /// representable value.
  mlir::Value positiveInf(mlir::Type type) const;

  //--- Casts ---------------------------------------------------------------
  // Casts between bool/int/float/index elementary types. Signed and index
  // values are not converted to unsigned.
  mlir::Value cast(mlir::Type destType, mlir::Value val) const;
  mlir::Value castToIndex(mlir::Value val) const;

  //--- Index arithmetic ----------------------------------------------------
  // Adds `offsets` to the trailing (least significant) entries of `indices`.
  // Example: indices (i, j, k, l) with offsets (K, L) produce
  // (i, j, k+K, l+L) in `computedIndices`.
  void addOffsetToLeastSignificant(mlir::ValueRange indices,
      mlir::ValueRange offsets,
      llvm::SmallVectorImpl<mlir::Value> &computedIndices) const;
  void addOffsetToLeastSignificant(mlir::ArrayRef<IndexExpr> indices,
      mlir::ValueRange offsets,
      llvm::SmallVectorImpl<mlir::Value> &computedIndices) const;

private:
  // Comparison helpers, one per predicate family (integer vs. float).
  mlir::Value createArithCmp(
      mlir::Value lhs, mlir::Value rhs, mlir::arith::CmpIPredicate pred) const;
  mlir::Value createArithCmp(
      mlir::Value lhs, mlir::Value rhs, mlir::arith::CmpFPredicate pred) const;
  mlir::Value castToSignless(mlir::Value source, int64_t width) const;
  mlir::Value castToUnsigned(mlir::Value source, int64_t width) const;
};
//===----------------------------------------------------------------------===//
// MemRef Builder with added support for aligned memory
//===----------------------------------------------------------------------===//
struct MemRefBuilder final : DialectBuilder {
  MemRefBuilder(mlir::OpBuilder &b, mlir::Location loc)
      : DialectBuilder(b, loc) {}
  MemRefBuilder(const DialectBuilder &db) : DialectBuilder(db) {}

  //--- Heap allocation (memref::AllocOp) -----------------------------------
  // Overloads taking `dynSymbols` supply values for dynamic dimensions.
  mlir::memref::AllocOp alloc(mlir::MemRefType type) const;
  mlir::memref::AllocOp alloc(
      mlir::MemRefType type, mlir::ValueRange dynSymbols) const;
  // NOTE(review): the meaning of the default `align = -1` is decided by the
  // implementation (presumably "use the default alignment") - confirm.
  mlir::memref::AllocOp alignedAlloc(
      mlir::MemRefType type, int64_t align = -1) const;
  mlir::memref::AllocOp alignedAlloc(mlir::MemRefType type,
      mlir::ValueRange dynSymbols, int64_t align = -1) const;

  //--- Stack allocation (memref::AllocaOp) ---------------------------------
  // An alloca reserves memory in the stack frame of the currently executing
  // function; the memory is released automatically when that function
  // returns. Prefer emitting alloca instructions outside of loops.
  mlir::memref::AllocaOp alloca(mlir::MemRefType type) const;
  mlir::memref::AllocaOp alignedAlloca(
      mlir::MemRefType type, int64_t align = -1) const;

  //--- Deallocation --------------------------------------------------------
  mlir::memref::DeallocOp dealloc(mlir::Value val) const;

  //--- Casts and views -----------------------------------------------------
  mlir::memref::CastOp cast(
      mlir::Value input, mlir::MemRefType outputType) const;
  mlir::Value reinterpretCast(
      mlir::Value input, llvm::SmallVectorImpl<IndexExpr> &outputDims) const;

  //--- Shape queries -------------------------------------------------------
  mlir::Value dim(mlir::Value val, int64_t index) const;
  mlir::Value dim(mlir::Value val, mlir::Value index) const;
};

// Default alignment attribute applied to memory allocations: 16 bytes, which
// suits most systems.
static constexpr int64_t gDefaultAllocAlign = 16;
//===----------------------------------------------------------------------===//
// Structured Control Flow (SCF) Builder
//===----------------------------------------------------------------------===//
struct SCFBuilder final : DialectBuilder {
  SCFBuilder(mlir::OpBuilder &b, mlir::Location loc) : DialectBuilder(b, loc) {}
  SCFBuilder(const DialectBuilder &db) : DialectBuilder(db) {}

  /// Builds an "if `cond` then ... [else ...]" construct; the else part is
  /// optional (`elseFn` defaults to null). In contrast to a general scf::if,
  /// the construct produces no result, and its yields are inserted
  /// automatically. Each callback receives an SCFBuilder to emit its body.
  void ifThenElse(mlir::Value cond,
      mlir::function_ref<void(SCFBuilder &createSCF)> thenFn,
      mlir::function_ref<void(SCFBuilder &createSCF)> elseFn = nullptr) const;

  /// Emits a yield.
  void yield() const;
};
//===----------------------------------------------------------------------===//
// Vector Builder
//===----------------------------------------------------------------------===//
struct VectorBuilder final : DialectBuilder {
  VectorBuilder(mlir::OpBuilder &b, mlir::Location loc)
      : DialectBuilder(b, loc) {}
  VectorBuilder(const DialectBuilder &db) : DialectBuilder(db) {}

  //--- Target information --------------------------------------------------
  // Machine SIMD vector length for an element type, a vector type, or a
  // vector value. Useful to guide some optimizations.
  int64_t getMachineVectorLength(const mlir::Type &elementType) const;
  int64_t getMachineVectorLength(const mlir::VectorType &vecType) const;
  int64_t getMachineVectorLength(mlir::Value vecValue) const;

  //--- Memory accesses -----------------------------------------------------
  // Overloads taking `offsets`: when offsets has fewer entries than indices,
  // the offsets are added to the least significant dimensions.
  mlir::Value load(mlir::VectorType vecType, mlir::Value memref,
      mlir::ValueRange indices = {}) const;
  mlir::Value load(mlir::VectorType vecType, mlir::Value memref,
      mlir::ValueRange indices, mlir::ValueRange offsets) const;
  mlir::Value loadIE(mlir::VectorType vecType, mlir::Value memref,
      llvm::ArrayRef<IndexExpr> indices, mlir::ValueRange offsets) const;
  void store(
      mlir::Value val, mlir::Value memref, mlir::ValueRange indices = {}) const;
  void store(mlir::Value val, mlir::Value memref, mlir::ValueRange indices,
      mlir::ValueRange offsets) const;
  void storeIE(mlir::Value val, mlir::Value memref,
      llvm::ArrayRef<IndexExpr> indices, mlir::ValueRange offsets) const;

  //--- Elementary vector operations ----------------------------------------
  mlir::Value broadcast(mlir::VectorType vecType, mlir::Value val) const;
  mlir::Value shuffle(mlir::Value lhs, mlir::Value rhs,
      llvm::SmallVectorImpl<int64_t> &mask) const;
  mlir::Value fma(mlir::Value lhs, mlir::Value rhs, mlir::Value acc) const;

  //--- Composite functions -------------------------------------------------
  mlir::Value mergeHigh(mlir::Value lhs, mlir::Value rhs, int64_t step) const;
  mlir::Value mergeLow(mlir::Value lhs, mlir::Value rhs, int64_t step) const;
  // NOTE(review): unlike every other public method, multiReduction is not
  // const; making it const requires updating its out-of-line definition too.
  void multiReduction(llvm::SmallVectorImpl<mlir::Value> &inputVecArray,
      llvm::SmallVectorImpl<mlir::Value> &outputVecArray);

private:
  bool isPowerOf2(uint64_t num) const;
  uint64_t getLengthOf1DVector(mlir::Value vec) const;
};
//===----------------------------------------------------------------------===//
// Affine Builder
//===----------------------------------------------------------------------===//
/// Affine-dialect builder, parameterized by the operation types LOAD_OP and
/// STORE_OP (presumably used by the load*/store* methods - confirm in
/// DialectBuilder.hpp.inc, which holds the template definitions).
template <class LOAD_OP, class STORE_OP>
struct GenericAffineBuilder final : DialectBuilder {
  GenericAffineBuilder(mlir::OpBuilder &b, mlir::Location loc)
      : DialectBuilder(b, loc) {}
  GenericAffineBuilder(const DialectBuilder &db) : DialectBuilder(db) {}

  //--- Memory accesses -----------------------------------------------------
  // Overloads taking `offsets`: when offsets has fewer entries than indices,
  // the offsets are added to the least significant dimensions.
  mlir::Value load(mlir::Value memref, mlir::ValueRange indices = {}) const;
  mlir::Value load(mlir::Value memref, mlir::ValueRange indices,
      mlir::ValueRange offsets) const;
  mlir::Value loadIE(mlir::Value memref, llvm::ArrayRef<IndexExpr> indices,
      mlir::ValueRange offsets) const;
  void store(
      mlir::Value val, mlir::Value memref, mlir::ValueRange indices = {}) const;
  void store(mlir::Value val, mlir::Value memref, mlir::ValueRange indices,
      mlir::ValueRange offsets) const;
  void storeIE(mlir::Value val, mlir::Value memref,
      llvm::ArrayRef<IndexExpr> indices, mlir::ValueRange offsets) const;

  //--- Loops ---------------------------------------------------------------
  // Single loop: builderFn receives this builder and one mlir::Value
  // (presumably the loop index).
  void forIE(IndexExpr lb, IndexExpr ub, int64_t step,
      mlir::function_ref<void(GenericAffineBuilder &, mlir::Value)> builderFn)
      const;
  // Loop nest described by parallel lists of lower bounds, upper bounds and
  // steps; builderFn receives this builder and a range of values.
  void forIE(llvm::SmallVectorImpl<IndexExpr> &lbs,
      llvm::SmallVectorImpl<IndexExpr> &ubs,
      llvm::SmallVectorImpl<int64_t> &steps,
      mlir::function_ref<void(GenericAffineBuilder &, mlir::ValueRange)>
          builderFn) const;

  //--- Conditionals --------------------------------------------------------
  // If-then-else whose blocks take no arguments.
  void ifThenElse(IndexExprScope &scope,
      llvm::SmallVectorImpl<IndexExpr> &conditions,
      mlir::function_ref<void(GenericAffineBuilder &createAffine)> thenFn,
      mlir::function_ref<void(GenericAffineBuilder &createAffine)> elseFn)
      const;

  void yield() const;

private:
  // Helper for the loop-nest forIE overload.
  void recursionForIE(llvm::SmallVectorImpl<IndexExpr> &lbs,
      llvm::SmallVectorImpl<IndexExpr> &ubs,
      llvm::SmallVectorImpl<int64_t> &steps,
      llvm::SmallVectorImpl<mlir::Value> &loopIndices,
      mlir::function_ref<void(GenericAffineBuilder &, mlir::ValueRange)>
          builderFn) const;
  // Helper to append code to an existing block.
  void appendToBlock(mlir::Block *block,
      mlir::function_ref<void(mlir::ValueRange)> builderFn) const;
};

// AffineBuilder performs memory operations with the affine load and store
// ops. AffineBuilderKrnlMem, defined later, performs them with the Krnl load
// and store ops instead. AffineBuilderKrnlMem is the recommended choice when
// lowering the Krnl dialect into the affine dialect.
using AffineBuilder =
    GenericAffineBuilder<mlir::AffineLoadOp, mlir::AffineStoreOp>;
//===----------------------------------------------------------------------===//
// LLVM Builder
//===----------------------------------------------------------------------===//
/// Builder for operations of the MLIR LLVM dialect. Most methods are thin
/// wrappers over the LLVM-dialect operation named in the comment preceding
/// them (e.g. `// GEPOp`). The mlir::LLVM types used below come from
/// "mlir/Dialect/LLVMIR/LLVMDialect.h", which this header must include.
struct LLVMBuilder final : DialectBuilder {
  /// Callback that emits code through the given LLVMBuilder.
  using voidFuncRef = mlir::function_ref<void(LLVMBuilder &createLLVM)>;
  /// Callback that emits code through the given LLVMBuilder and returns a
  /// value.
  using valueFuncRef = mlir::function_ref<mlir::Value(LLVMBuilder &createLLVM)>;

  LLVMBuilder(mlir::OpBuilder &b, mlir::Location loc)
      : DialectBuilder(b, loc) {}
  LLVMBuilder(const DialectBuilder &db) : DialectBuilder(db) {}

  // AddressOfOp
  mlir::Value addressOf(mlir::LLVM::GlobalOp op) const;

  // AllocaOp
  mlir::Value _alloca(
      mlir::Type resultType, mlir::Value size, int64_t alignment) const;

  // BitcastOp
  mlir::Value bitcast(mlir::Type type, mlir::Value val) const;
  mlir::Value bitcastI8Ptr(mlir::Value val) const;
  mlir::Value bitcastI8PtrPtr(mlir::Value val) const;

  // BrOp
  void br(
      llvm::ArrayRef<mlir::Value> destOperands, mlir::Block *destBlock) const;

  // CallOp: the callee is given either by name or by symbol reference.
  mlir::Value call(mlir::ArrayRef<mlir::Type> resultTypes,
      llvm::StringRef funcName, mlir::ArrayRef<mlir::Value> inputs) const;
  mlir::Value call(mlir::ArrayRef<mlir::Type> resultTypes,
      mlir::FlatSymbolRefAttr funcSymbol,
      mlir::ArrayRef<mlir::Value> inputs) const;

  // CondBrOp
  void condBr(mlir::Value cond, mlir::Block *trueBlock,
      llvm::ArrayRef<mlir::Value> trueOperands, mlir::Block *falseBlock,
      llvm::ArrayRef<mlir::Value> falseOperands) const;

  // ConstantOp (integer and floating-point payloads).
  mlir::Value constant(mlir::Type type, int64_t val) const;
  mlir::Value constant(mlir::Type type, double val) const;

  // ExtractValueOp
  mlir::Value extractValue(mlir::Type resultType, mlir::Value container,
      llvm::ArrayRef<int64_t> position) const;

  // FuncOp
  mlir::LLVM::LLVMFuncOp func(llvm::StringRef name, mlir::Type type) const;

  // GEPOp
  mlir::Value getElemPtr(mlir::Type resultType, mlir::Value base,
      llvm::ArrayRef<mlir::Value> indices) const;

  // GlobalOp
  mlir::LLVM::GlobalOp globalOp(mlir::Type resultType, bool isConstant,
      mlir::LLVM::Linkage linkage, llvm::StringRef name, mlir::Attribute attr,
      uint64_t alignment = 0) const;

  // ICmpOp
  mlir::Value icmp(
      mlir::LLVM::ICmpPredicate cond, mlir::Value lhs, mlir::Value rhs) const;

  // InsertValueOp
  mlir::Value insertValue(mlir::Type resultType, mlir::Value container,
      mlir::Value val, llvm::ArrayRef<int64_t> position) const;

  // LoadOp
  mlir::Value load(mlir::Value addr) const;

  // NullOp
  mlir::Value null(mlir::Type type) const;
  mlir::Value nullI8Ptr() const;

  // ReturnOp
  void _return(mlir::Value val) const;

  // StoreOp
  void store(mlir::Value val, mlir::Value addr) const;

  //===--------------------------------------------------------------------===//
  // Helper functions
  //===--------------------------------------------------------------------===//

  // Get or insert a function declaration at the beginning of the module.
  mlir::FlatSymbolRefAttr getOrInsertSymbolRef(mlir::ModuleOp module,
      llvm::StringRef symName, mlir::Type resultType,
      llvm::ArrayRef<mlir::Type> operandTypes, bool isVarArg = false) const;

  /// Generate LLVM-dialect code shaped like "if then with optional else".
  /// `cond` is a callback that emits and returns the condition value;
  /// `thenFn` and `elseFn` emit the bodies (`elseFn` may be null).
  /// The generated code looks like:
  /// ```
  ///   llvm.condBr cond, ^thenBlock, ^elseBlock
  /// ^thenBlock:
  ///   thenBody
  /// ^elseBlock:
  ///   elseBody
  /// ^mainBlock:
  ///   ...
  /// ```
  void ifThenElse(valueFuncRef cond, voidFuncRef thenFn,
      voidFuncRef elseFn = nullptr) const;
};
//===----------------------------------------------------------------------===//
// Multi Dialect Builder
//===----------------------------------------------------------------------===//
/*
Instead of creating multiple builders, e.g.
KrnlBuilder createKrnl(rewriter, loc);
MathBuilder createMath(createKrnl);
MemRefBuilder createMemRef(createKrnl);
createKrnl.defineLoop(1);
createMath.add(i1, i2);
createMemRef.alloca(type);
We can create a single builder composed of multiple types
MultiDialectBuilder<KrnlBuilder, MathBuilder, MemRefBuilder>
create(rewriter, loc);
create.krnl.defineLoop(1);
create.math.add(i1, i2);
create.mem.alloca(type);
Types that can be used here are:
* AffineBuilder, access field with affine
* AffineBuilderKrnlMem, access field with affineKMem
* KrnlBuilder, access field with krnl
* LLVMBuilder, access field with llvm
* MathBuilder, access field with math
* MemRefBuilder, access field with mem
* ONNXBuilder, access field with onnx
* SCFBuilder, access field with scf
* VectorBuilder, access field with vec
*/
// Anchor class: the primary template, which ends the recursive inheritance
// chain formed by the MultiDialectBuilder specializations. It owns no builder,
// so its constructor arguments are deliberately unused (their names are
// commented out to avoid -Wunused-parameter warnings).
// NOTE: a type list whose first type has no specialization also resolves to
// this empty class, which silently drops the remaining builder types.
template <class... Ts>
struct MultiDialectBuilder {
  MultiDialectBuilder(mlir::OpBuilder & /*b*/, mlir::Location /*loc*/) {}
  MultiDialectBuilder(const DialectBuilder & /*db*/) {}
};
// Recursive class specialized for MathBuilder refereed to as math.
template <class... Ts>
struct MultiDialectBuilder<MathBuilder, Ts...> : MultiDialectBuilder<Ts...> {
MultiDialectBuilder(mlir::OpBuilder &b, mlir::Location loc)
: MultiDialectBuilder<Ts...>(b, loc), math(b, loc) {}
MultiDialectBuilder(const DialectBuilder &db)
: MultiDialectBuilder<Ts...>(db), math(db) {}
MathBuilder math;
};
// Recursive class specialized for MemRefBuilder refereed to as mem.
template <class... Ts>
struct MultiDialectBuilder<MemRefBuilder, Ts...> : MultiDialectBuilder<Ts...> {
MultiDialectBuilder(mlir::OpBuilder &b, mlir::Location loc)
: MultiDialectBuilder<Ts...>(b, loc), mem(b, loc) {}
MultiDialectBuilder(const DialectBuilder &db)
: MultiDialectBuilder<Ts...>(db), mem(db) {}
MemRefBuilder mem;
};
// Recursive class specialized for AffineBuilder refereed to as affine.
template <class... Ts>
struct MultiDialectBuilder<AffineBuilder, Ts...> : MultiDialectBuilder<Ts...> {
MultiDialectBuilder(mlir::OpBuilder &b, mlir::Location loc)
: MultiDialectBuilder<Ts...>(b, loc), affine(b, loc) {}
MultiDialectBuilder(const DialectBuilder &db)
: MultiDialectBuilder<Ts...>(db), affine(db) {}
AffineBuilder affine;
};
// Recursive class specialized for SCFBuilder refereed to as scf.
template <class... Ts>
struct MultiDialectBuilder<SCFBuilder, Ts...> : MultiDialectBuilder<Ts...> {
MultiDialectBuilder(mlir::OpBuilder &b, mlir::Location loc)
: MultiDialectBuilder<Ts...>(b, loc), scf(b, loc) {}
MultiDialectBuilder(const DialectBuilder &db)
: MultiDialectBuilder<Ts...>(db), scf(db) {}
SCFBuilder scf;
};
// Recursive class specialized for VectorBuilder refereed to as vec.
template <class... Ts>
struct MultiDialectBuilder<VectorBuilder, Ts...> : MultiDialectBuilder<Ts...> {
MultiDialectBuilder(mlir::OpBuilder &b, mlir::Location loc)
: MultiDialectBuilder<Ts...>(b, loc), vec(b, loc) {}
MultiDialectBuilder(const DialectBuilder &db)
: MultiDialectBuilder<Ts...>(db), vec(db) {}
VectorBuilder vec;
};
// Recursive class specialized for LLVMBuilder refereed to as llvm.
template <class... Ts>
struct MultiDialectBuilder<LLVMBuilder, Ts...> : MultiDialectBuilder<Ts...> {
MultiDialectBuilder(mlir::OpBuilder &b, mlir::Location loc)
: MultiDialectBuilder<Ts...>(b, loc), llvm(b, loc) {}
MultiDialectBuilder(const DialectBuilder &db)
: MultiDialectBuilder<Ts...>(db), llvm(db) {}
LLVMBuilder llvm;
};
// Include template implementations.
#include "DialectBuilder.hpp.inc"
} // namespace onnx_mlir