Skip to content
This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Commit

Permalink
Add inference mode support to TabularDatamodule
Browse files Browse the repository at this point in the history
  • Loading branch information
manujosephv committed Nov 28, 2023
1 parent 44ef97a commit 390bb3b
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 26 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,4 @@ examples/test_save/
.ruff_cache/
tests/.datasets/
test.py
lightning_logs/
43 changes: 23 additions & 20 deletions src/pytorch_tabular/tabular_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class TabularDatamodule(pl.LightningDataModule):
class CACHE_MODES(Enum):
    # Strategy for how the datamodule keeps prepared datasets between
    # setup() and the dataloader calls.
    MEMORY = "memory"  # datasets held as attributes in RAM
    DISK = "disk"  # datasets saved to cache_dir and loaded back with torch.load
    INFERENCE = "inference"  # no datasets cached; datamodule is inference-only

def __init__(
self,
Expand Down Expand Up @@ -451,6 +452,10 @@ def _cache_dataset(self):
self.train_dataset = train_dataset
self.validation_dataset = validation_dataset
self.test_dataset = test_dataset
elif self.cache_mode is self.CACHE_MODES.INFERENCE:
self.train_dataset = None
self.validation_dataset = None
self.test_dataset = None
else:
raise ValueError(f"{self.cache_mode} is not a valid cache mode")

Expand Down Expand Up @@ -491,26 +496,22 @@ def setup(self, stage: Optional[str] = None) -> None:
self._fitted = True
self._cache_dataset()

# def inference_only_copy(self):
# """Creates a copy of the datamodule with the train and validation datasets removed.
# This is useful for inference only scenarios where we don't want to save the train and validation datasets.

# Returns:
# TabularDatamodule: A copy of the datamodule with the train and validation datasets removed.
# """
# if self._fitted:
# raise RuntimeError("Cannot create an inference only copy after setup has been called")
# return TabularDatamodule(
# train=None,
# validation=None,
# test=self.test,
# config=self.config,
# target_transform=self.target_transform_template,
# train_sampler=self.train_sampler,
# seed=self.seed,
# cache_data=self.cache_mode,
# copy_data=False,
# )
def inference_only_copy(self):
    """Return a shallow copy of this datamodule stripped of cached datasets.

    Useful when persisting the datamodule purely for inference, so the
    (potentially large) train/validation/test datasets are not serialized
    along with it.

    Returns:
        TabularDatamodule: A shallow copy with every cached dataset set to
            ``None`` and ``cache_mode`` switched to ``INFERENCE``.

    Raises:
        RuntimeError: If called before the datamodule has been fitted.
    """
    if not self._fitted:
        raise RuntimeError("Can create an inference only copy only after model is fitted")
    stripped = copy.copy(self)
    # Drop every cached dataset; the copy keeps the fitted transformers/config.
    for dataset_attr in ("train_dataset", "validation_dataset", "test_dataset"):
        setattr(stripped, dataset_attr, None)
    stripped.cache_mode = self.CACHE_MODES.INFERENCE
    return stripped

# adapted from gluonts
@classmethod
def time_features_from_frequency_str(cls, freq_str: str) -> List[str]:
Expand Down Expand Up @@ -715,6 +716,8 @@ def _load_dataset_from_cache(self, tag: str = "train"):
dataset = torch.load(self.cache_dir / f"{tag}_dataset")
except FileNotFoundError:
raise FileNotFoundError(f"{tag}_dataset not found in {self.cache_dir}. Please provide the data for {tag} dataloader")
elif self.cache_mode is self.CACHE_MODES.INFERENCE:
raise RuntimeError("Cannot load dataset in inference mode. Use `prepare_inference_dataloader` instead")
else:
raise ValueError(f"{self.cache_mode} is not a valid cache mode")
return dataset
Expand Down
14 changes: 9 additions & 5 deletions src/pytorch_tabular/tabular_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,29 +1302,33 @@ def save_datamodule(self, dir: str, inference_only:bool = False) -> None:
without data. This cannot be used for further training, but can be
used for inference. Defaults to False.
"""
# if inference_only:
if inference_only:
dm = self.datamodule.inference_only_copy()
else:
dm = self.datamodule

joblib.dump(self.datamodule, os.path.join(dir, "datamodule.sav"))
joblib.dump(dm, os.path.join(dir, "datamodule.sav"))

def save_config(self, dir: str) -> None:
    """Saves the config in the specified directory."""
    config_path = os.path.join(dir, "config.yml")
    with open(config_path, "w") as outfile:
        OmegaConf.save(self.config, outfile, resolve=True)

def save_model(self, dir: str) -> None:
def save_model(self, dir: str, inference_only:bool = False) -> None:
"""Saves the model and checkpoints in the specified directory.
Args:
dir (str): The path to the directory to save the model
drop_dataset (bool): Exclude the entire dataset to be storage friendly.
inference_only (bool): If True, will only save the inference
only version of the datamodule
"""
if os.path.exists(dir) and (os.listdir(dir)):
logger.warning("Directory is not empty. Overwriting the contents.")
for f in os.listdir(dir):
os.remove(os.path.join(dir, f))
os.makedirs(dir, exist_ok=True)
self.save_config(dir)
self.save_datamodule(dir)
self.save_datamodule(dir, inference_only=inference_only)
if hasattr(self.config, "log_target") and self.config.log_target is not None:
joblib.dump(self.logger, os.path.join(dir, "exp_logger.sav"))
if hasattr(self, "callbacks"):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def fake_metric(y_hat, y):
@pytest.mark.parametrize("custom_loss", [None, torch.nn.L1Loss()])
@pytest.mark.parametrize("custom_optimizer", [None, torch.optim.Adagrad])
@pytest.mark.parametrize("cache_data", ["memory", "disk"])
@pytest.mark.parametrize("inference_only", [True, False])
def test_save_load(
regression_data,
model_config_class,
Expand All @@ -81,6 +82,7 @@ def test_save_load(
custom_loss,
custom_optimizer,
cache_data,
inference_only,
tmp_path_factory,
):
(train, test, target) = regression_data
Expand Down Expand Up @@ -124,7 +126,7 @@ def test_save_load(
# sv_dir = tmpdir/"save_model"
# sv_dir.mkdir(exist_ok=True, parents=True)
sv_dir = tmp_path_factory.mktemp("saved_model")
tabular_model.save_model(str(sv_dir), drop_dataset=drop_data)
tabular_model.save_model(str(sv_dir), inference_only=inference_only)
new_mdl = TabularModel.load_from_checkpoint(str(sv_dir))
result_2 = new_mdl.evaluate(test)
assert (
Expand Down

0 comments on commit 390bb3b

Please sign in to comment.