Skip to content

Commit

Permalink
change numpy to bytearray as buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 8, 2015
1 parent 6942980 commit a4de0eb
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions wrapper/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,12 @@ def ctypes2numpy(cptr, length, dtype):
def ctypes2buffer(cptr, length):
if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
raise RuntimeError('expected char pointer')
res = np.zeros(length, dtype='uint8')
if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
res = bytearray(length)
rptr = (ctypes.c_char * length).from_buffer(res)
if not ctypes.memmove(rptr, cptr, length):
raise RuntimeError('memmove failed')
return res

def c_str(string):
return ctypes.c_char_p(string.encode('utf-8'))

Expand Down Expand Up @@ -886,7 +887,7 @@ def __getstate__(self):
def __setstate__(self, state):
bst = state["_Booster"]
if bst is not None:
state["_Booster"] = Booster(model_file=booster)
state["_Booster"] = Booster(model_file=bst)
self.__dict__.update(state)

def booster(self):
Expand Down

0 comments on commit a4de0eb

Please sign in to comment.