Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Oct 26, 2024
1 parent d573274 commit a7740f1
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 13 deletions.
74 changes: 70 additions & 4 deletions stemflow/model/AdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import warnings
from functools import partial
from itertools import repeat
import tarfile
from pathlib import Path
import shutil

#
from multiprocessing import Lock, Pool, Process, cpu_count, shared_memory
Expand Down Expand Up @@ -264,10 +267,10 @@ def __init__(
# X. miscellaneous
self.lazy_loading = lazy_loading
self.lazy_loading_dir = lazy_loading_dir
if self.lazy_loading:
if self.lazy_loading_dir is None:
saving_code = int(np.random.uniform(1, 1e8))
self.lazy_loading_dir = f'./tmp_{saving_code}'
if self.lazy_loading_dir is None:
saving_code = int(np.random.uniform(1, 1e8))
self.lazy_loading_dir = f'./tmp_{saving_code}'
self.lazy_loading_dir = str(Path(self.lazy_loading_dir.rstrip('/\\')))

if not verbosity == 0:
self.verbosity = 1
Expand Down Expand Up @@ -1214,6 +1217,69 @@ def assign_feature_importances_by_points(
out_ = pd.concat([Sample_ST_df, mean_feature_importances_across_ensembles], axis=1).dropna()
return out_

@staticmethod
def load(tar_gz_file, target_lazyloading_path=None, remove_original_file=False):

if target_lazyloading_path is None:
saving_code = int(np.random.uniform(1, 1e8))
target_lazyloading_path = f'./tmp_{saving_code}'
target_lazyloading_path = str(Path(target_lazyloading_path.rstrip('/\\')))

file = tarfile.open(tar_gz_file)
file.extractall(target_lazyloading_path, filter=tarfile.data_filter)
file.close()

with open(os.path.join(target_lazyloading_path, 'model.pkl'), 'rb') as f:
model = pickle.load(f)

if model.lazy_loading:
# then this is lazy loading
if not len(os.listdir(target_lazyloading_path))>1:
raise FileExistsError('Your model is not a lazy_loading model, but more than 1 files/folders are found in the .tar.gz file?')
else:
_, basename = os.path.split(Path(model.lazy_loading_dir))
new_lazy_loading_path = os.path.join(target_lazyloading_path, basename)
model.set_params(lazy_loading_dir=new_lazy_loading_path)
model.model_dict.directory = new_lazy_loading_path

if remove_original_file:
os.remove(tar_gz_file)

return model

def save(self, tar_gz_file, remove_temporary_file = True):
if not os.path.exists(self.lazy_loading_dir):
os.makedirs(self.lazy_loading_dir, exist_ok=False)

if self.lazy_loading:
ensemble_ids = list(self.model_dict.ensemble_models.keys())
for current_in_memory_ensemble in ensemble_ids:
self.model_dict.dump_ensemble(current_in_memory_ensemble)

# check all ensemble on disk
for ensemble_id in range(self.ensemble_fold):
if not f'ensemble_{ensemble_id}_dict.pkl' in os.listdir(self.lazy_loading_dir):
raise FileNotFoundError(f'Ensemble models file ensemble_{ensemble_id}_dict.pkl is missing in lazyloading directory {self.lazy_loading_dir}!')

#
path, basename = os.path.split(Path(tar_gz_file.rstrip('/\\')))

# temporary save the model using pickle
model_path = os.path.join(path, 'model.pkl')
with open(model_path, 'wb') as f:
pickle.dump(self, f)

# save the main model class and potentially lazyloading pieces to the tar.gz file
with tarfile.open(tar_gz_file, "w:gz") as tar:
tar.add(model_path)
if self.lazy_loading:
for pieces in os.listdir(self.lazy_loading_dir):
tar.add(os.path.join(self.lazy_loading_dir, pieces))

if remove_temporary_file:
os.remove(model_path)
if self.lazy_loading:
shutil.rmtree(self.lazy_loading_dir)

class AdaSTEMClassifier(AdaSTEM):
"""AdaSTEM model Classifier interface
Expand Down
2 changes: 2 additions & 0 deletions stemflow/model/Hurdle.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ def fit(self, X_train: Union[pd.core.frame.DataFrame, np.ndarray], y_train: Sequ
np.array(X_train[X_train["y_train"] > 0].iloc[:, -1]),
verbosity=1,
)

return self

def predict(
self,
Expand Down
9 changes: 0 additions & 9 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,6 @@ def test_AdaSTEMRegressor():
assert importances_by_points.shape[0] > 0
assert importances_by_points.shape[1] == len(x_names) + 3

#
eval = AdaSTEM.eval_STEM_res("AAA", pred_df.y_true, pred_df.y_pred) # this should not work
all_nan = True
for eval_metric in eval:
if not np.isnan(eval[eval_metric]):
all_nan = False
break
assert all_nan

# score
score_df = model.score(X_test, y_test)

Expand Down

0 comments on commit a7740f1

Please sign in to comment.