Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
latentvector committed Apr 3, 2024
1 parent 7e0eec1 commit 0af1da2
Show file tree
Hide file tree
Showing 22 changed files with 337 additions and 377 deletions.
172 changes: 91 additions & 81 deletions commune/module/module.py

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions commune/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@ def namespace(cls, search=None,
network, netuid = network.split('.')
else:
netuid = netuid or 0

if c.is_digit(netuid):
if c.is_int(netuid):
netuid = int(netuid)

return c.module(network)().namespace(search=search,
update=update,
max_age=max_age,
Expand Down
44 changes: 29 additions & 15 deletions commune/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@
# Do whatever you want with this code
# Dont pull up with your homies if it dont work.
import numpy as np
import torch
import msgpack
import msgpack_numpy
from typing import Tuple, List, Union, Optional

from typing import *
from copy import deepcopy
from munch import Munch

import commune as c
import json
Expand Down Expand Up @@ -58,6 +55,8 @@ def resolve_value(self, x):
v_type = type(x)
if v_type in [dict, list, tuple, set]:
new_value = self.serialize(x, mode=None)
elif v_type in [int, float, str, bool]:
new_value = x
else:
# GET THE TYPE OF THE VALUE
str_v_type = self.get_type_str(data=x)
Expand All @@ -66,8 +65,7 @@ def resolve_value(self, x):
x = getattr(self, f'serialize_{str_v_type}')(data=x)
new_value = {'data': x, 'data_type': str_v_type, 'serialized': True}
else:
# SERIALIZE MODE OFF
new_value = x
new_value = {"success": False, "error": f"Type {str_v_type} not supported"}

return new_value

Expand All @@ -88,6 +86,10 @@ def deserialize(self, x) -> object:
if x.startswith('{') or x.startswith('['):
x = self.str2dict(x)
else:
if c.is_int(x):
x = int(x)
elif c.is_float(x):
x = float(x)
return x

is_single = isinstance(x,dict) and all([k in x for k in ['data', 'data_type', 'serialized']])
Expand Down Expand Up @@ -155,16 +157,22 @@ def deserialize_munch(self, data: bytes) -> 'Munch':
return self.dict2munch(self.str2dict(data))

def dict2bytes(self, data:dict) -> bytes:
import msgpack
data_json_str = json.dumps(data)
data_json_bytes = msgpack.packb(data_json_str)
return data_json_bytes

def dict2str(self, data:dict) -> bytes:
data_json_str = json.dumps(data)
return data_json_str

def str2dict(self, data:str) -> bytes:
if isinstance(data, bytes):
data = data.decode('utf-8')
return json.loads(data)
if isinstance(data, str):
data = json.loads(data)
assert isinstance(data, dict), f"data must be a dict, not {type(data)}"
return data

@classmethod
def hex2str(cls, x, **kwargs):
Expand All @@ -181,6 +189,7 @@ def str2hex(cls, x, **kwargs):


def bytes2dict(self, data:bytes) -> dict:
import msgpack
json_object_bytes = msgpack.unpackb(data)
return json.loads(json_object_bytes)

Expand All @@ -190,20 +199,21 @@ def bytes2dict(self, data:bytes) -> dict:
def torch2bytes(self, data:'torch.Tensor')-> bytes:
return self.numpy2bytes(self.torch2numpy(data))

def torch2numpy(self, data:'torch.Tensor')-> np.ndarray:
def torch2numpy(self, data:'torch.Tensor')-> 'np.ndarray':
if data.requires_grad:
data = data.detach()
data = data.cpu().numpy()
return data


def numpy2bytes(self, data:np.ndarray)-> bytes:
import msgpack_numpy
import msgpack
output = msgpack.packb(data, default=msgpack_numpy.encode)
return output

def bytes2torch(self, data:bytes, ) -> 'torch.Tensor':
import torch
numpy_object = self.bytes2numpy(data)

int64_workaround = bool(numpy_object.dtype == np.int64)
if int64_workaround:
numpy_object = numpy_object.astype(np.float64)
Expand All @@ -213,18 +223,20 @@ def bytes2torch(self, data:bytes, ) -> 'torch.Tensor':
return torch_object

def bytes2numpy(self, data:bytes) -> np.ndarray:
import msgpack_numpy
import msgpack
output = msgpack.unpackb(data, object_hook=msgpack_numpy.decode)
return output


def deserialize_torch(self, data: dict) -> torch.Tensor:
def deserialize_torch(self, data: dict) -> 'torch.Tensor':
from safetensors.torch import load
if isinstance(data, str):
data = self.str2bytes(data)
data = load(data)
return data['data']

def serialize_torch(self, data: torch.Tensor) -> 'DataBlock':
def serialize_torch(self, data: 'torch.Tensor') -> 'DataBlock':
from safetensors.torch import save
output = save({'data':data})
return self.bytes2str(output)
Expand All @@ -238,12 +250,11 @@ def deserialize_numpy(self, data: bytes) -> 'np.ndarray':
data = self.str2bytes(data)
return self.bytes2numpy(data)



def get_type_str(self, data):
'''
## Documentation for get_type_str function
### Purpose
The purpose of this function is to determine and return the data type of the input given to it in string format. It supports identification of various data types including Munch, Tensor, ndarray, and DataFrame.
Expand Down Expand Up @@ -282,13 +293,15 @@ def get_type_str(self, data):

@classmethod
def test_serialize(cls):
import torch
module = Serializer()
data = {'bro': {'fam': torch.ones(2,2), 'bro': [torch.ones(1,1)]}}
proto = module.serialize(data)
module.deserialize(proto)

@classmethod
def test_deserialize(cls):
import torch
module = Serializer()

t = c.time()
Expand All @@ -301,6 +314,7 @@ def test_deserialize(cls):

@classmethod
def test(cls, size=1):
import torch
self = cls()
stats = {}
data = {'bro': {'fam': torch.randn(size,size), 'bro': [np.ones((2,1))]}}
Expand Down
2 changes: 1 addition & 1 deletion commune/server/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def verify(self, input:dict) -> dict:
return {'success': True, 'msg': f'is verified admin'}


whitelist = list(set(self.module.whitelist + c.helper_functions))
whitelist = list(set(self.module.whitelist + c.whitelist))
blacklist = self.module.blacklist

assert fn in whitelist , f"Function {fn} not in whitelist={whitelist}"
Expand Down
7 changes: 6 additions & 1 deletion commune/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def __init__(
self.chunk_size = chunk_size
self.timeout = timeout
self.public = public

if module == None:
module = 'module'

if isinstance(module, str):
module = c.module(module)()

Expand Down Expand Up @@ -282,7 +284,10 @@ def test_serving(cls):
module = c.serve(module_name, wait_for_server=True)
module = c.connect(module_name)
module.put("hey",1)
assert module.get("hey") == 1, f"get failed {module.get('hey')}"
c.kill(module_name)
return {'success': True, 'msg': 'server test passed'}


@classmethod
def test_serving_with_different_key(cls):
Expand Down
Loading

0 comments on commit 0af1da2

Please sign in to comment.