#pragma once

#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ATen/core/ivalue.h>
#include <c10/core/ScalarType.h>

namespace torch::jit::mobile::nnc {
// Specify the requirements on an input tensor.
// TODO: support input tensor with dynamic shape (PR #54982)
struct TORCH_API InputSpec {
  InputSpec() = default;

  // Deserialize the spec from an IValue.
  explicit InputSpec(const c10::IValue& value);

  // Serialize the spec into an IValue.
  [[nodiscard]] c10::IValue serialize() const;

  // Check whether the input tensor adheres to the spec.
  [[nodiscard]] bool validate(const at::Tensor& input) const;

  std::vector<int64_t> sizes_;
  c10::ScalarType dtype_{c10::ScalarType::Undefined};
};
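
// Example (illustrative sketch, not part of this header): validating a
// runtime input against a spec. `validate` is expected to check the input
// against the sizes and dtype recorded ahead-of-time.
//
//   InputSpec spec;
//   spec.sizes_ = {1, 3, 224, 224};
//   spec.dtype_ = c10::ScalarType::Float;
//   at::Tensor input = at::rand({1, 3, 224, 224});
//   TORCH_CHECK(spec.validate(input), "input does not match the spec");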
// Specify the sizes/dtype/etc. of an output tensor, so that the output can
// be preallocated before the kernel runs.
// TODO: support the case where the kernel allocates output tensors
// dynamically.
struct TORCH_API OutputSpec {
  OutputSpec() = default;

  // Deserialize the spec from an IValue.
  explicit OutputSpec(const c10::IValue& value);

  // Serialize the spec into an IValue.
  [[nodiscard]] c10::IValue serialize() const;

  // Allocate an output tensor in accordance with the spec.
  [[nodiscard]] at::Tensor allocate() const;

  std::vector<int64_t> sizes_;
  c10::ScalarType dtype_{c10::ScalarType::Undefined};
  std::optional<double> qscale_;
  std::optional<int64_t> qzero_;
};
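
// Example (illustrative sketch): preallocating an output. For quantized
// outputs, `qscale_`/`qzero_` are expected to carry the affine quantization
// parameters of the allocated tensor.
//
//   OutputSpec spec;
//   spec.sizes_ = {1, 1000};
//   spec.dtype_ = c10::ScalarType::Float;
//   at::Tensor out = spec.allocate(); // a {1, 1000} float tensor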
// Hold the temporary buffers / states needed during the execution.
struct TORCH_API ExecutionState {
  ExecutionState() = default;
  ExecutionState(const ExecutionState&) = delete;
  ExecutionState(ExecutionState&&) = default;
  ExecutionState& operator=(const ExecutionState&) = delete;
  ExecutionState& operator=(ExecutionState&&) = default;

  // Preallocated buffers needed by the NNC kernel.
  std::vector<c10::DataPtr> preallocations_;

  // The NNC kernel expects the following arguments layout:
  //   input tensor 1
  //   ...
  //   input tensor INPUT_NUM
  //   output tensor 1
  //   ...
  //   output tensor OUTPUT_NUM
  //   parameter tensor 1
  //   ...
  //   parameter tensor PARAM_NUM
  //   temporary buffer 1
  //   ...
  //   temporary buffer BUFFER_NUM
  std::vector<void*> arguments_;
};
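
// Example (illustrative): for a kernel with 2 inputs, 1 output, 1 parameter
// and 1 temporary buffer, `arguments_` holds 5 pointers in this order:
//
//   arguments_[0..1]  -> data pointers of the 2 input tensors
//   arguments_[2]     -> data pointer of the preallocated output tensor
//   arguments_[3]     -> data pointer of the parameter tensor (e.g. a weight)
//   arguments_[4]     -> raw pointer into the preallocated temporary buffer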
// Specify how to allocate temporary buffers at initialization.
struct TORCH_API MemoryPlan {
  MemoryPlan() = default;

  explicit MemoryPlan(const c10::IValue& value);

  [[nodiscard]] c10::IValue serialize() const;

  void allocate(ExecutionState* state) const;

  std::vector<int64_t> buffer_sizes_;
};
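
// Example (illustrative sketch): allocating scratch buffers for a run.
// `allocate` is expected to fill `state->preallocations_` with one buffer
// per entry of `buffer_sizes_`.
//
//   MemoryPlan plan;
//   plan.buffer_sizes_ = {1024, 4096};
//   ExecutionState state;
//   plan.allocate(&state); // state.preallocations_.size() == 2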
// Location of a symbolic shape among dimensions of the inputs.
struct TORCH_API SymbolicShapePosition {
  SymbolicShapePosition() = default;
  SymbolicShapePosition(int64_t input_idx, int64_t dim_idx)
      : input_idx_(input_idx), dim_idx_(dim_idx) {}

  int64_t input_idx_;
  int64_t dim_idx_;
};
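
// Example (illustrative): SymbolicShapePosition{0, 1} says the symbolic
// dimension can be read off at runtime as inputs[0].sizes()[1], i.e. the
// second dimension of the first input tensor.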
// Represents a compiled NNC function which has a 1-1 correspondence with a
// `Method` (e.g. `forward`). It's similar to torch::jit::mobile::Function.
class TORCH_API Function {
 public:
  explicit Function() = default;

  // Deserialize from an IValue that is generated by the 'serialize()' method.
  explicit Function(const c10::IValue& value);

  // Serialize into an IValue.
  c10::IValue serialize() const;

  // Execute the compiled NNC function.
  c10::impl::GenericList run(const c10::impl::GenericList& inputs) const;

  // The name of the function as specified in the model code.
  c10::QualifiedName name() const {
    return name_;
  }

  void set_name(const c10::QualifiedName& name) {
    name_ = name;
  }

  // The unique id of the generated NNC kernel corresponding to the function.
  const std::string& nnc_kernel_id() const {
    return nnc_kernel_id_;
  }

  void set_nnc_kernel_id(const std::string& name) {
    nnc_kernel_id_ = name;
  }

  // The parameters (e.g. weights / bias tensors) to be passed to the generated
  // NNC kernel.
  const c10::impl::GenericList& parameters() const {
    return parameters_;
  }

  void set_parameters(const c10::impl::GenericList& parameters) {
    parameters_ = parameters;
  }

  const std::vector<InputSpec>& input_specs() const {
    return input_specs_;
  }

  void set_input_specs(const std::vector<InputSpec>& input_specs) {
    input_specs_ = input_specs;
  }

  const std::vector<OutputSpec>& output_specs() const {
    return output_specs_;
  }

  void set_output_specs(const std::vector<OutputSpec>& output_specs) {
    output_specs_ = output_specs;
  }

  const MemoryPlan& memory_plan() const {
    return memory_plan_;
  }

  void set_memory_plan(const MemoryPlan& memory_plan) {
    memory_plan_ = memory_plan;
  }

  const std::vector<SymbolicShapePosition>& sym_shape_positions() const {
    return sym_shape_positions_;
  }

  void set_sym_shape_positions(
      const std::vector<SymbolicShapePosition>& sym_shape_pos) {
    sym_shape_positions_ = sym_shape_pos;
  }

 private:
  void init_execution_state() const;

  c10::QualifiedName name_;
  std::string nnc_kernel_id_;
  c10::impl::GenericList parameters_{at::AnyType::get()};
  std::vector<InputSpec> input_specs_;
  std::vector<OutputSpec> output_specs_;
  std::vector<SymbolicShapePosition> sym_shape_positions_;
  MemoryPlan memory_plan_;
  mutable std::unique_ptr<ExecutionState> execution_state_;
};
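
// Example (illustrative sketch of how the pieces fit together; in practice a
// Function is deserialized from the model file rather than built by hand, and
// the kernel id below is hypothetical):
//
//   Function f;
//   f.set_name(c10::QualifiedName("forward"));
//   f.set_nnc_kernel_id("model:v1:forward"); // hypothetical kernel id
//   f.set_input_specs({/* one InputSpec per input tensor */});
//   f.set_output_specs({/* one OutputSpec per output tensor */});
//
//   c10::impl::GenericList inputs{at::AnyType::get()};
//   inputs.push_back(at::rand({1, 3, 224, 224}));
//   c10::impl::GenericList outputs = f.run(inputs);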
// CompilationUnit consists of a set of compiled NNC functions. It has a 1-1
// correspondence with a `Module`.
// It's similar to torch::jit::mobile::CompilationUnit.
class TORCH_API CompilationUnit {
 public:
  CompilationUnit() = default;
  CompilationUnit(const CompilationUnit&) = delete;
  CompilationUnit(CompilationUnit&&) = default;
  CompilationUnit& operator=(const CompilationUnit&) = delete;
  CompilationUnit& operator=(CompilationUnit&&) = default;

  // Deserialize from an IValue that is generated by the 'serialize()' method.
  explicit CompilationUnit(const c10::IValue& value);

  // Serialize all registered functions into an IValue. The IValue will be
  // saved into the compiled TorchScript model file ahead-of-time on the host,
  // and will be deserialized at runtime on the target device.
  [[nodiscard]] c10::IValue serialize() const;

  // Execute a registered function.
  [[nodiscard]] c10::impl::GenericList run(
      const c10::QualifiedName& function_name,
      const c10::impl::GenericList& inputs) const;

  // Register a function to the compilation unit.
  void register_function(std::unique_ptr<Function> fn);

 private:
  [[nodiscard]] Function* find_function(const c10::QualifiedName& qn) const;

  std::unordered_map<c10::QualifiedName, std::unique_ptr<Function>> functions_;
};
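
// Example (illustrative sketch): round-tripping a compilation unit. On the
// host, functions are registered and serialized; on the target device, the
// unit is rebuilt from the serialized IValue and executed by qualified name.
//
//   // Host (ahead-of-time):
//   CompilationUnit cu;
//   cu.register_function(std::make_unique<Function>(/* ... */));
//   c10::IValue packed = cu.serialize();
//
//   // Target device (runtime):
//   CompilationUnit loaded(packed);
//   c10::impl::GenericList inputs{at::AnyType::get()};
//   inputs.push_back(at::rand({1, 3, 224, 224}));
//   auto outputs = loaded.run(c10::QualifiedName("forward"), inputs);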
} // namespace torch::jit::mobile::nnc