Skip to content

Commit

Permalink
Cleanup deprecated compilation cache APIs.
Browse files Browse the repository at this point in the history
Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.

Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.

Testing: revised unit tests, test workload.
PiperOrigin-RevId: 596733252
  • Loading branch information
Scenic Authors committed Jan 13, 2024
1 parent a23d9da commit faa897b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
26 changes: 12 additions & 14 deletions scenic/projects/lang4video/main_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,15 @@
models.ALL_MODELS['image_text'] = ImageTextModel

trainers.ALL_TRAINERS['visual_text_trainer'] = visual_text_trainer.train
trainers.ALL_TRAINERS[
'visual_text_with_text_pretraining_trainer'] = visual_text_with_text_pretraining_trainer.train
trainers.ALL_TRAINERS[
'zero_shot_classification_trainer'] = zero_shot_classification_trainer.evaluate
trainers.ALL_TRAINERS[
'zero_shot_text_to_visual_retrieval_trainer'] = zero_shot_text_to_visual_retrieval_trainer.evaluate
trainers.ALL_TRAINERS['visual_text_with_text_pretraining_trainer'] = (
visual_text_with_text_pretraining_trainer.train
)
trainers.ALL_TRAINERS['zero_shot_classification_trainer'] = (
zero_shot_classification_trainer.evaluate
)
trainers.ALL_TRAINERS['zero_shot_text_to_visual_retrieval_trainer'] = (
zero_shot_text_to_visual_retrieval_trainer.evaluate
)



Expand Down Expand Up @@ -81,14 +84,9 @@ def main(rng: Optional[jnp.ndarray], config: ml_collections.ConfigDict,

if (config.get('use_jax_compilation_cache', True) and
hasattr(jax.devices()[0].client, 'runtime_type')):
if compilation_cache.is_initialized():
logging.info('JAX compilation cache already initialized.')
else:
jax_cache_dir = os.path.join(workdir, 'jax_cache', 'ttl=30d')
logging.info('JAX compilation cache path: %s', jax_cache_dir)
compilation_cache.initialize_cache(jax_cache_dir)
else:
logging.info('JAX compilation cache not initialized.')
jax_cache_dir = os.path.join(workdir, 'jax_cache', 'ttl=30d')
logging.info('JAX compilation cache path: %s', jax_cache_dir)
compilation_cache.set_cache_dir(jax_cache_dir)

model_cls = models.get_model_cls(config.model_name)
assert model_cls is ImageTextModel
Expand Down
2 changes: 1 addition & 1 deletion scenic/projects/owl_vit/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def main(argv: Sequence[str]) -> None:
# it unavailable to JAX.
tf.config.experimental.set_visible_devices([], 'GPU')

compilation_cache.initialize_cache('/tmp/jax_compilation_cache')
compilation_cache.set_cache_dir('/tmp/jax_compilation_cache')

config_name = os.path.splitext(os.path.basename(FLAGS.config))[0]

Expand Down

0 comments on commit faa897b

Please sign in to comment.