diff --git a/dinov2/layers/block.py b/dinov2/layers/block.py index 68c6f45a9..638907474 100644 --- a/dinov2/layers/block.py +++ b/dinov2/layers/block.py @@ -27,7 +27,7 @@ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None try: if XFORMERS_ENABLED: - from xformers.ops import scaled_index_add as _scaled_index_add, index_select_cat as _index_select_cat + from xformers.ops import fmha, scaled_index_add as _scaled_index_add, index_select_cat as _index_select_cat def scaled_index_add(input, index, source, scaling, alpha): is_proper_embed_dim = input.shape[-1] % 256 == 0