Commit e122f48
refactor: organize with module and add logs
Yuanhong-Lan committed Sep 23, 2024
1 parent 79b93b4 commit e122f48
Showing 5 changed files with 262 additions and 237 deletions.
44 changes: 23 additions & 21 deletions ai/embedding_util.py
@@ -6,26 +6,28 @@
 from numpy import ndarray
 
 
-def cos_similarity(a: ndarray, b: ndarray) -> float:
-    assert a.ndim == 1, f"The dim of a should be 1, but was {a.ndim}"
-    assert a.shape == b.shape, f"The shape of a and b should be the same, but was {a.shape} and {b.shape}"
-    dot_product = np.dot(a, b)
-    norm_a = np.linalg.norm(a)
-    norm_b = np.linalg.norm(b)
-    return round(dot_product / (norm_a * norm_b), 5)
+class EmbeddingUtil:
+    @classmethod
+    def cos_similarity(cls, a: ndarray, b: ndarray) -> float:
+        assert a.ndim == 1, f"The dim of a should be 1, but was {a.ndim}"
+        assert a.shape == b.shape, f"The shape of a and b should be the same, but was {a.shape} and {b.shape}"
+        dot_product = np.dot(a, b)
+        norm_a = np.linalg.norm(a)
+        norm_b = np.linalg.norm(b)
+        return round(dot_product / (norm_a * norm_b), 5)
 
+    @classmethod
+    def cos_similarity_normalized(cls, a: ndarray, b: ndarray):
+        temp = cls.cos_similarity(a, b)
+        upper_bound = 1
+        lower_bound = -1
+        return (temp - lower_bound) / (upper_bound - lower_bound)
 
-def cos_similarity_normalized(a: ndarray, b: ndarray):
-    temp = cos_similarity(a, b)
-    upper_bound = 1
-    lower_bound = -1
-    return (temp - lower_bound) / (upper_bound - lower_bound)
-
-
-def self_cos_similarty_compare(embedding_list):
-    n = len(embedding_list)
-    res = np.zeros((n, n))
-    for i in range(n):
-        for j in range(n):
-            res[i][j] = cos_similarity_normalized(embedding_list[i], embedding_list[j])
-    return np.round(res, 4)
+    @classmethod
+    def self_cos_similarity_compare(cls, embedding_list):
+        n = len(embedding_list)
+        res = np.zeros((n, n))
+        for i in range(n):
+            for j in range(n):
+                res[i][j] = cls.cos_similarity_normalized(embedding_list[i], embedding_list[j])
+        return np.round(res, 4)
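
Note: the refactor only moves the three helpers into an EmbeddingUtil class (and fixes the old self_cos_similarty_compare typo); the math is unchanged. Cosine similarity a·b / (|a||b|) lies in [-1, 1], and the normalized variant rescales it linearly to [0, 1]. A minimal usage sketch of the new classmethod API follows; the toy vectors are made up, and the import path is an assumption based on the file location:

import numpy as np

from ai.embedding_util import EmbeddingUtil  # assumes the repo root is on PYTHONPATH

# Two toy embedding vectors (made-up values, for illustration only).
a = np.array([1.0, 0.0, 1.0])
b = np.array([0.5, 0.5, 0.0])

print(EmbeddingUtil.cos_similarity(a, b))             # 0.5, raw cosine in [-1, 1]
print(EmbeddingUtil.cos_similarity_normalized(a, b))  # 0.75, rescaled to [0, 1]

# Pairwise normalized-similarity matrix over a list of embeddings;
# the diagonal is 1.0 (every vector compared with itself).
print(EmbeddingUtil.self_cos_similarity_compare([a, b]))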
30 changes: 17 additions & 13 deletions ai/testing_util.py
@@ -10,17 +10,21 @@
 from android_testing_utils.log import my_logger
 
 
-def eliminate_randomness(seed):
-    my_logger.hint(my_logger.LogLevel.WARNING, "TestingUtil", True,
-                   f" #### Eliminate randomness with seed {seed} ####")
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    np.random.seed(seed)
-    random.seed(seed)
-    torch.backends.cudnn.deterministic = True
+class TestingUtil:
+    @classmethod
+    def eliminate_randomness(cls, seed):
+        my_logger.auto_hint(
+            my_logger.LogLevel.WARNING, cls, True, f" #### Eliminate randomness with seed {seed} ####"
+        )
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        np.random.seed(seed)
+        random.seed(seed)
+        torch.backends.cudnn.deterministic = True
 
-
-def model_info(model, input_shape):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    backbone = model.to(device)
-    summary(backbone, input_shape)
+    @classmethod
+    def model_info(cls, model, input_shape):
+        my_logger.auto_hint(my_logger.LogLevel.INFO, cls, True, f" #### Model Summary ####")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        backbone = model.to(device)
+        summary(backbone, input_shape)
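
Note: eliminate_randomness is the standard seeding recipe for reproducible PyTorch runs; each call pins a different RNG (CPU, CUDA, NumPy, Python's random), and the cudnn flag trades some speed for deterministic kernel selection. A self-contained sketch of the same recipe, without the project's my_logger dependency (the function name and seed value are illustrative):

import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    # Pin every RNG that typical PyTorch training code touches.
    torch.manual_seed(seed)           # CPU generator
    torch.cuda.manual_seed_all(seed)  # all CUDA generators (no-op without a GPU)
    np.random.seed(seed)
    random.seed(seed)
    # Force deterministic cuDNN kernel selection, at some performance cost.
    torch.backends.cudnn.deterministic = True


set_seed(42)
print(torch.rand(3))  # identical output on every run with the same seed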
18 changes: 12 additions & 6 deletions monitor/decorator.py
@@ -14,8 +14,10 @@ def wrapper(*args, **kwargs):
         before = time.time()
         res = function(*args, **kwargs)
         after = time.time()
-        my_logger.hint(my_logger.LogLevel.INFO, "TimeCountDecorator", False,
-                       f"Function {function.__name__} cost time: {round(after-before, 4)}s")
+        my_logger.auto_hint(
+            my_logger.LogLevel.INFO, "TimeCountDecorator", False,
+            f"Function {function.__name__} cost time: {round(after-before, 4)}s"
+        )
         return res
     return wrapper
 
@@ -26,8 +28,10 @@ def wrapper(self, *args, **kwargs):
         before = time.time()
         res = function(self, *args, **kwargs)
         after = time.time()
-        my_logger.hint(my_logger.LogLevel.INFO, "TimeCountDecorator", False,
-                       f"Function {function.__name__} cost time: {round(after-before, 2)}s")
+        my_logger.auto_hint(
+            my_logger.LogLevel.INFO, "TimeCountDecorator", False,
+            f"Function {function.__name__} cost time: {round(after-before, 2)}s"
+        )
         return res
     return wrapper
 
@@ -38,7 +42,9 @@ def wrapper(*args, **kwargs):
         before = time.time()
         res = function(*args, **kwargs)
         after = time.time()
-        my_logger.hint(my_logger.LogLevel.INFO, "TimeCountDecorator", False,
-                       f"Function {function.__name__} cost time: {round(after-before, 2)}s")
+        my_logger.auto_hint(
+            my_logger.LogLevel.INFO, "TimeCountDecorator", False,
+            f"Function {function.__name__} cost time: {round(after-before, 2)}s"
+        )
         return res
     return wrapper
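
Note: the monitor/decorator.py hunks only show the wrapper bodies, not the enclosing decorator definitions. A self-contained sketch of the timing-decorator pattern these wrappers sit inside, with print standing in for my_logger; the name time_count is an assumption, since the real decorator names fall outside the visible hunks:

import functools
import time


def time_count(function):
    # Print stands in for my_logger here; the committed code logs via auto_hint.
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        before = time.time()
        res = function(*args, **kwargs)
        after = time.time()
        print(f"Function {function.__name__} cost time: {round(after - before, 2)}s")
        return res
    return wrapper


@time_count
def slow_add(x, y):
    time.sleep(0.1)
    return x + y


slow_add(1, 2)  # prints roughly: Function slow_add cost time: 0.1s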