"""
$URL: svn+ssh://svn.mems-exchange.org/repos/trunk/durus/serialize.py $
$Id: serialize.py 31144 2008-09-17 13:46:42Z dbinger $
"""
from durus.persistent import call_if_persistent
from durus.utils import int4_to_str, str_to_int4, join_bytes, BytesIO
from durus.utils import Pickler, Unpickler, loads, dumps, as_bytes
from types import MethodType
from zlib import compress, decompress, error as zlib_error
import struct
import sys

# When true, the state part of each record is zlib-compressed before it
# is written (see ObjectWriter.get_state()).
WRITE_COMPRESSED_STATE_PICKLES = True

# Protocol 2 is the highest pickle protocol understood by both Python 2
# and Python 3, so stored records stay readable from either.
PICKLE_PROTOCOL = 2

def pack_record(oid, data, refs):
    """(oid:str, data:str, refs:str) -> record:str
    """
    return join_bytes([oid, int4_to_str(len(data)), data, refs])
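
# For illustration: a record is laid out as
#   8-byte oid | 4-byte big-endian len(data) | data | packed refs
# so unpack_record() below can recover the three parts by position.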

def unpack_record(record):
    """(record:str) -> oid:str, data:str, refs:str
    The inverse of pack_record().
    """
    oid = record[:8]
    data_length = str_to_int4(record[8:12])
    data_end = 12 + data_length
    data = record[12:data_end]
    refs = record[data_end:]
    return oid, data, refs

def split_oids(s):
    """(s:str) -> [str]
    s is a packed string of oids. Return a list of oid strings.
    """
    if not s:
        return []
    num, extra = divmod(len(s), 8)
    assert extra == 0, s
    fmt = '8s' * num
    return list(struct.unpack('>' + fmt, s))
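
# For example (a sketch): sixteen bytes split into two 8-byte oids.
#   split_oids(as_bytes('\x00' * 16)) == [as_bytes('\x00' * 8)] * 2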

NEWLINE = as_bytes('\n')

def extract_class_name(record):
    # The data part of a record starts with the pickled class, and a
    # protocol-2 class pickle has the form
    #   \x80\x02cmodule\nClassName\n...
    # so the class name is the second newline-separated field.
    try:
        oid, state, refs = unpack_record(record)
        return state.split(NEWLINE, 2)[1]
    except IndexError:
        return "?"

if sys.version < "3":
    # Bind function a to instance b; Python 2's MethodType also
    # requires a class argument.
    def method(a, b):
        return MethodType(a, b, object)
else:
    def method(a, b):
        return MethodType(a, b)

class ObjectWriter (object):
    """
    Serializes objects for storage in the database.

    The client is responsible for calling the close() method to avoid
    leaking memory. The ObjectWriter uses a Pickler internally, and
    Pickler objects do not participate in garbage collection.
    """

    def __init__(self, connection):
        self.sio = BytesIO()
        self.pickler = Pickler(self.sio, PICKLE_PROTOCOL)
        self.pickler.persistent_id = method(
            call_if_persistent, self._persistent_id)
        self.objects_found = []
        self.refs = set() # populated by _persistent_id()
        self.connection = connection

    def close(self):
        # see ObjectWriter.__doc__
        # Explicitly break cycle involving pickler
        self.pickler.persistent_id = int
        self.pickler = None

    def _persistent_id(self, obj):
        """(PersistentBase) -> (oid:str, klass:type)
        This is called on PersistentBase instances during pickling.
        """
        if obj._p_oid is None:
            obj._p_oid = self.connection.new_oid()
            obj._p_connection = self.connection
            self.objects_found.append(obj)
        elif obj._p_connection is not self.connection:
            raise ValueError(
                "Reference to %r has a different connection." % obj)
        self.refs.add(obj._p_oid)
        return obj._p_oid, type(obj)

    def gen_new_objects(self, obj):
        def once(obj):
            raise RuntimeError('gen_new_objects() already called.')
        self.gen_new_objects = once
        yield obj # The modified object is also a "new" object.
        for obj in self.objects_found:
            yield obj

    def get_state(self, obj):
        self.sio.seek(0) # recycle BytesIO instance
        self.sio.truncate()
        self.pickler.clear_memo()
        self.pickler.dump(type(obj))
        self.refs.clear()
        position = self.sio.tell()
        self.pickler.dump(obj.__getstate__())
        uncompressed = self.sio.getvalue()
        pickled_type = uncompressed[:position]
        pickled_state = uncompressed[position:]
        if WRITE_COMPRESSED_STATE_PICKLES:
            state = compress(pickled_state)
        else:
            state = pickled_state
        data = pickled_type + state
        self.refs.discard(obj._p_oid)
        return data, join_bytes(self.refs)
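
# Typical use (a sketch, assuming an open durus Connection named
# `connection` and a modified persistent object `obj`; this mirrors
# how a commit would use the writer, but is not copied from one):
#
#     writer = ObjectWriter(connection)
#     try:
#         for o in writer.gen_new_objects(obj):
#             data, refs = writer.get_state(o)
#             record = pack_record(o._p_oid, data, refs)
#             # ... hand the record to the storage ...
#     finally:
#         writer.close()  # break the pickler reference cycle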

# The first byte of a zlib stream produced with the default settings
# (0x78); used by ObjectReader.get_state() to guess whether a stored
# state pickle is compressed.
COMPRESSED_START_BYTE = compress(dumps({}, 2))[0]

class ObjectReader (object):
    """
    Recreates objects from the stored records produced by ObjectWriter.
    """

    def __init__(self, connection):
        self.connection = connection
        self.load_count = 0

    def _get_unpickler(self, file):
        connection = self.connection
        get_instance = connection.get_cache().get_instance
        def persistent_load(oid_klass):
            oid, klass = oid_klass
            return get_instance(oid, klass, connection)
        unpickler = Unpickler(file)
        unpickler.persistent_load = persistent_load
        return unpickler

    def get_ghost(self, data):
        # Build an instance whose state has not been loaded yet: only
        # the class pickle at the start of the data is read.
        klass = loads(data)
        instance = klass.__new__(klass)
        instance._p_set_status_ghost()
        return instance

    def get_state(self, data, load=True):
        self.load_count += 1
        s = BytesIO()
        s.write(data)
        s.seek(0)
        unpickler = self._get_unpickler(s)
        klass = unpickler.load()
        position = s.tell()
        if data[s.tell()] == COMPRESSED_START_BYTE:
            # This is almost certainly a compressed pickle.
            try:
                decompressed = decompress(data[position:])
            except zlib_error:
                pass # let the unpickler try anyway.
            else:
                s.write(decompressed)
                s.seek(position)
        if load:
            return unpickler.load()
        else:
            return s.read()

    def get_state_pickle(self, data):
        return self.get_state(data, load=False)

    def get_load_count(self):
        return self.load_count
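
if __name__ == '__main__':
    # A minimal self-check of the record helpers (an illustrative
    # sketch added here; not part of the original module). It needs no
    # connection or storage, only the utilities imported above.
    oid = as_bytes('\x00' * 8)
    refs = as_bytes('\x00' * 7 + '\x01') + as_bytes('\x00' * 7 + '\x02')
    record = pack_record(oid, as_bytes('DATA'), refs)
    assert unpack_record(record) == (oid, as_bytes('DATA'), refs)
    assert split_oids(refs) == [as_bytes('\x00' * 7 + '\x01'),
                                as_bytes('\x00' * 7 + '\x02')]
    print('record round trip OK')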