forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpickle.h
140 lines (122 loc) · 4.58 KB
/
pickle.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#pragma once
#include <ATen/core/ivalue.h>
#include <c10/util/ArrayRef.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/unpickler.h>
namespace torch::jit {

/// Pickle `ivalue`, streaming the serialized bytes to `writer` instead of
/// returning them.
///
/// `writer` is invoked with a pointer to a chunk of serialized bytes and the
/// chunk's length, and is expected to consume those bytes (for example, by
/// appending them to a buffer or writing them to a file).
///
/// `tensor_table` has the same meaning as in the `std::vector<char>`-returning
/// overload of `torch::jit::pickle`.
///
/// See `torch::jit::pickle` for more details.
TORCH_API void pickle(
    std::function<void(const char* data_start, size_t data_len)> writer,
    const IValue& ivalue,
    std::vector<at::Tensor>* tensor_table = nullptr);

/// Save a `torch::IValue` in a format compatible with Python's `pickle` module.
///
/// If present, `tensor_table` is a pointer to a table in which tensors that
/// are contained within `ivalue` are stored, and the bytes returned by the
/// pickler will only include references to these tensors in the table. This can
/// be used to keep the binary blob size small. The same table must then be
/// handed back to `torch::jit::unpickle` to resolve those references.
/// If not provided, tensors are stored in the same byte stream as the pickle
/// data, similar to `torch.save()` in eager Python.
///
/// Pickled values can be loaded in Python and C++:
/// \rst
/// .. code-block:: cpp
///
///   torch::IValue float_value(2.3);
///
///   // Tensors contained in the value are written to `tensor_table`, and
///   // only references to them end up in `data`.
///   std::vector<at::Tensor> tensor_table;
///   auto data = torch::jit::pickle(float_value, &tensor_table);
///
///   torch::IValue ivalue = torch::jit::unpickle(
///       data.data(), data.size(), /*type_resolver=*/nullptr, tensor_table);
///
/// .. code-block:: python
///
///   values = torch.load('data.pkl')
///   print(values)
///
/// \endrst
///
/// To produce bytes intended to be loaded with `torch.load`, prefer
/// `torch::jit::pickle_save`.
TORCH_API std::vector<char> pickle(
    const IValue& ivalue,
    std::vector<at::Tensor>* tensor_table = nullptr);

/// Save a `torch::IValue` in a format that can be loaded by both
/// `torch::jit::pickle_load` in C++ and `torch.load` in Python.
TORCH_API std::vector<char> pickle_save(const IValue& ivalue);

/// Deserialize a `torch::IValue` from bytes produced by either
/// `torch::jit::pickle_save` in C++ or `torch.save` in Python.
TORCH_API IValue pickle_load(const std::vector<char>& data);

/// Deserialize a `torch::IValue` from bytes produced by either
/// `torch::jit::pickle_save` in C++ or `torch.save` in Python.
///
/// Like `pickle_load`, but takes the bytes as a `std::string_view` and
/// supports payloads that contain custom class objects.
TORCH_API IValue pickle_load_obj(std::string_view data);

/// Decode a pickled `torch::IValue` whose bytes are pulled on demand from
/// `reader`.
///
/// `reader` is called with a destination buffer and the number of bytes
/// requested. It should fill the buffer with the next bytes of the pickled
/// binary, remember where it last read, and return the number of bytes it
/// actually read.
///
/// `type_resolver` is used to resolve any JIT type based on its type string.
/// `tensor_table` supplies the tensors referenced by the pickled data when it
/// was written with a tensor table. `type_parser` and `obj_loader` are
/// optional customization hooks; by default they are
/// `Unpickler::defaultTypeParser` and no object loader, respectively.
///
/// See `torch::jit::pickle` for details.
TORCH_API IValue unpickle(
    std::function<size_t(char*, size_t)> reader,
    TypeResolver type_resolver,
    c10::ArrayRef<at::Tensor> tensor_table,
    c10::TypePtr (*type_parser)(const std::string&) =
        Unpickler::defaultTypeParser,
    ObjLoader obj_loader = nullptr);

/// Decode a chunk of memory containing pickled data into a `torch::IValue`.
///
/// `data` points to `size` bytes of pickled data. If any `torch::IValue`s in
/// the pickled data are `Object`s, then a `type_resolver` function must be
/// provided. `tensor_table` supplies the tensors referenced by the data when
/// it was pickled with a tensor table.
///
/// See `torch::jit::pickle` for details.
TORCH_API IValue unpickle(
    const char* data,
    size_t size,
    TypeResolver type_resolver = nullptr,
    c10::ArrayRef<at::Tensor> tensor_table = {},
    c10::TypePtr (*type_parser)(const std::string&) =
        Unpickler::defaultTypeParser);

/// Decode a chunk of memory containing pickled data into a `torch::IValue`,
/// using `obj_loader` as the object-loading hook.
///
/// Otherwise identical to the `unpickle(data, size, type_resolver,
/// tensor_table, type_parser)` overload: if any `torch::IValue`s in the
/// pickled data are `Object`s, then a `type_resolver` function must be
/// provided.
///
/// See `torch::jit::pickle` for details.
TORCH_API IValue unpickle(
    const char* data,
    size_t size,
    ObjLoader obj_loader,
    TypeResolver type_resolver = nullptr,
    c10::ArrayRef<at::Tensor> tensor_table = {},
    c10::TypePtr (*type_parser)(const std::string&) =
        Unpickler::defaultTypeParser);

// The reader adapters below are only available in non-mobile builds.
#ifndef C10_MOBILE

/// `caffe2::serialize::ReadAdapterInterface` over an in-memory byte buffer.
///
/// The reader takes ownership of `data` (it is moved into the reader), so the
/// bytes remain valid for the reader's whole lifetime.
class VectorReader : public caffe2::serialize::ReadAdapterInterface {
 public:
  VectorReader(std::vector<char> data) : data_(std::move(data)) {}

  /// Total number of bytes held by the reader.
  size_t size() const override {
    return data_.size();
  }

  /// NOTE(review): defined out of line. Presumably copies up to `n` bytes
  /// starting at offset `pos` into `buf`, returns the number copied, and uses
  /// `what` only to label errors; confirm against the implementation.
  size_t read(uint64_t pos, void* buf, size_t n, const char* what)
      const override;

 private:
  std::vector<char> data_;
};

/// `caffe2::serialize::ReadAdapterInterface` over a borrowed byte range.
///
/// Non-owning: only a `std::string_view` is stored, so the caller must keep
/// the viewed bytes alive and unmodified for as long as the reader is used.
class StringViewReader : public caffe2::serialize::ReadAdapterInterface {
 public:
  StringViewReader(std::string_view data) : data_(data) {}

  /// Total number of bytes in the viewed range.
  size_t size() const override {
    return data_.size();
  }

  /// NOTE(review): defined out of line. Presumably copies up to `n` bytes
  /// starting at offset `pos` into `buf`, returns the number copied, and uses
  /// `what` only to label errors; confirm against the implementation.
  size_t read(uint64_t pos, void* buf, size_t n, const char* what)
      const override;

 private:
  std::string_view data_;
};

#endif
} // namespace torch::jit