update weights (#144)
* change msa data source handling

* load new weights

* move COMPONENT_URL to const

* add 1536 to AVAILABLE_MODEL_SIZES

* pt2 -> pt

* mask with none
arogozhnikov authored Nov 5, 2024
1 parent e835b6e commit d6ccd41
Showing 5 changed files with 44 additions and 26 deletions.
37 changes: 24 additions & 13 deletions chai_lab/chai1.py
@@ -93,10 +93,19 @@ class UnsupportedInputError(RuntimeError):
     pass
 
 
-def load_exported(comp_key: str, device: torch.device) -> torch.nn.Module:
+class ModuleWrapper:
+    def __init__(self, jit_module):
+        self.jit_module = jit_module
+
+    def forward(self, crop_size: int, **kw):
+        return getattr(self.jit_module, f"forward_{crop_size}")(**kw)
+
+
+def load_exported(comp_key: str, device: torch.device) -> ModuleWrapper:
     torch.jit.set_fusion_strategy([("STATIC", 0), ("DYNAMIC", 0)])
     local_path = chai1_component(comp_key)
-    exported_program = torch.export.load(local_path)
-    return exported_program.module().to(device)
+    # specifying map_location=... doesn't load weights properly
+    return ModuleWrapper(torch.jit.load(local_path).to(device))
+
 
 # %%
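
For context, a minimal, self-contained sketch of the forward_{crop_size} dispatch that ModuleWrapper performs. The toy module below is an assumption for illustration only; the real exported components are TorchScript modules that expose one forward_<size> method per supported crop size.

import torch

class ModuleWrapper:  # mirrors the wrapper added in this commit
    def __init__(self, jit_module):
        self.jit_module = jit_module

    def forward(self, crop_size: int, **kw):
        # dispatch to the size-specialized entry point, e.g. forward_512
        return getattr(self.jit_module, f"forward_{crop_size}")(**kw)

class _ToySized(torch.nn.Module):
    # hypothetical stand-in for an exported component, not the real weights
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    @torch.jit.export
    def forward_256(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @torch.jit.export
    def forward_512(self, x: torch.Tensor) -> torch.Tensor:
        return x * 4

wrapper = ModuleWrapper(torch.jit.script(_ToySized()))
out = wrapper.forward(crop_size=512, x=torch.ones(3))  # routes to forward_512
assert torch.equal(out, torch.full((3,), 4.0))
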
@@ -405,22 +414,20 @@ def run_folding_on_context(
     ## Load exported models
     ##
 
-    # Model is size-specific
-    model_size = min(x for x in AVAILABLE_MODEL_SIZES if n_actual_tokens <= x)
+    _, _, model_size = msa_mask.shape
+    assert model_size in AVAILABLE_MODEL_SIZES
 
-    feature_embedding = load_exported(f"{model_size}/feature_embedding.pt2", device)
-    token_input_embedder = load_exported(
-        f"{model_size}/token_input_embedder.pt2", device
-    )
-    trunk = load_exported(f"{model_size}/trunk.pt2", device)
-    diffusion_module = load_exported(f"{model_size}/diffusion_module.pt2", device)
-    confidence_head = load_exported(f"{model_size}/confidence_head.pt2", device)
+    feature_embedding = load_exported("feature_embedding.pt", device)
+    token_input_embedder = load_exported("token_embedder.pt", device)
+    trunk = load_exported("trunk.pt", device)
+    diffusion_module = load_exported("diffusion_module.pt", device)
+    confidence_head = load_exported("confidence_head.pt", device)
 
     ##
     ## Run the features through the feature embedder
     ##
 
-    embedded_features = feature_embedding.forward(**features)
+    embedded_features = feature_embedding.forward(crop_size=model_size, **features)
     token_single_input_feats = embedded_features["TOKEN"]
     token_pair_input_feats, token_pair_structure_input_feats = embedded_features[
         "TOKEN_PAIR"
@@ -448,6 +455,7 @@ def run_folding_on_context(
         block_indices_w=block_indices_w,
         atom_single_mask=atom_single_mask,
         atom_token_indices=atom_token_indices,
+        crop_size=model_size,
     )
     token_single_initial_repr, token_single_structure_input, token_pair_initial_repr = (
         token_input_embedder_outputs
@@ -473,6 +481,7 @@
         template_input_masks=template_input_masks,
         token_single_mask=token_single_mask,
         token_pair_mask=token_pair_mask,
+        crop_size=model_size,
     )
     # We won't be using the trunk anymore; remove it from memory
     del trunk
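
A quick illustration of the new size handling earlier in this file's diff: the crop size is now read off the padded MSA mask instead of being recomputed from the raw token count. Shapes below are toy values; batch and MSA depth are arbitrary.

import torch

from chai_lab.data.collate.utils import AVAILABLE_MODEL_SIZES

# collation has already padded the token axis to a bucket size
msa_mask = torch.ones(1, 16, 384, dtype=torch.bool)  # (batch, depth, tokens)
_, _, model_size = msa_mask.shape
assert model_size in AVAILABLE_MODEL_SIZES  # 384 is one of the exported sizes
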
@@ -502,6 +511,7 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
             atom_noised_coords=atom_noised_coords.float(),
             noise_sigma=noise_sigma.float(),
             atom_token_indices=atom_token_indices,
+            crop_size=model_size,
         )
 
     num_diffn_samples = 5  # Fixed at export time
@@ -587,6 +597,7 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
             token_reference_atom_index=token_reference_atom_index,
             atom_token_index=atom_token_indices,
             atom_within_token_index=atom_within_token_index,
+            crop_size=model_size,
         )
         for s in range(num_diffn_samples)
     ]
3 changes: 2 additions & 1 deletion chai_lab/data/collate/utils.py
@@ -10,7 +10,7 @@
 
 # static graph is exported for different n_tokens,
 # we pad to the closest one
-AVAILABLE_MODEL_SIZES = [256, 384, 512, 768, 1024, 2048]
+AVAILABLE_MODEL_SIZES = [256, 384, 512, 768, 1024, 1536, 2048]
 
 
 @dataclass(frozen=True)
@@ -34,5 +34,6 @@ def get_pad_sizes(contexts: list[AllAtomStructureContext]) -> PadSizes:
     max_n_atoms = max(context.num_atoms for context in contexts)
     n_atoms = 23 * n_tokens
     assert max_n_atoms <= n_atoms
+    assert n_atoms % 32 == 0
 
     return PadSizes(n_tokens=n_tokens, n_atoms=n_atoms)
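
As a sanity check on the bucketing, a short sketch: the bucket list is copied from this diff, and pick_bucket is a hypothetical helper mirroring the min(...) expression that previously lived in chai1.py. The new % 32 assertion holds because every bucket size is itself a multiple of 32.

AVAILABLE_MODEL_SIZES = [256, 384, 512, 768, 1024, 1536, 2048]

def pick_bucket(n_actual_tokens: int) -> int:
    # pad to the smallest exported static graph that fits the input
    return min(x for x in AVAILABLE_MODEL_SIZES if n_actual_tokens <= x)

for n in (100, 384, 1100):
    n_tokens = pick_bucket(n)
    n_atoms = 23 * n_tokens
    assert n_atoms % 32 == 0  # 23 * (multiple of 32) is a multiple of 32
    print(f"{n} tokens -> pad to {n_tokens} tokens, {n_atoms} atom slots")
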
16 changes: 9 additions & 7 deletions chai_lab/data/features/generators/msa.py
@@ -213,10 +213,8 @@ class MSADataSourceGenerator(FeatureGenerator):
 
     def __init__(
         self,
-        num_classes: int = 6,  # TODO how this works with chai1?
+        num_classes: int = 6,  # chai1: 5 classes + mask val
     ):
-        assert num_classes == max(msa_dataset_source_to_int.values()) + 1
-
         super().__init__(
             ty=FeatureType.MSA,
             encoding_ty=EncodingType.ONE_HOT,
@@ -237,8 +235,12 @@ def _generate(
         msa_mask: Bool[Tensor, "batch depth tokens"],
         msa_sequence_source: UInt8[Tensor, "batch depth tokens"],
     ) -> Tensor:
-        msa_sequence_source = msa_sequence_source.masked_fill(
-            ~msa_mask, self.num_classes
-        )
-
+        from chai_lab.data.parsing.msas.data_source import MSADataSource
+
+        query = msa_dataset_source_to_int[MSADataSource.QUERY]
+        none = msa_dataset_source_to_int[MSADataSource.NONE]
+        # chai-1 specific: replace QUERY with NONE
+        msa_sequence_source[msa_sequence_source.eq(query)] = none
+        # use none for masking.
+        msa_sequence_source = msa_sequence_source.masked_fill(~msa_mask, none)
         return self.make_feature(data=msa_sequence_source.unsqueeze(-1))
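
A toy walkthrough of the remap-and-mask step above. QUERY = 5 follows the mapping in data_source.py below; NONE = 0 is an assumption here, since its actual integer code is not shown in this diff.

import torch

QUERY, NONE = 5, 0  # NONE's real code is assumed

src = torch.tensor([[5, 2, 1], [3, 5, 2]], dtype=torch.uint8)
mask = torch.tensor([[True, True, False], [True, False, True]])

src[src.eq(QUERY)] = NONE           # chai-1 specific: QUERY -> NONE
src = src.masked_fill(~mask, NONE)  # padded positions also map to NONE
print(src)  # tensor([[0, 2, 0], [3, 0, 2]], dtype=torch.uint8)
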
2 changes: 1 addition & 1 deletion chai_lab/data/parsing/msas/data_source.py
@@ -74,7 +74,7 @@ def encode_source_to_int(source: MSADataSource) -> int:
        MSADataSource.UNIPROT_N3: 3,
        MSADataSource.UNIREF90_N3: 2,
        MSADataSource.MGNIFY_N3: 1,
-       MSADataSource.QUERY: 5,  # in chai-1 remapped to none.
+       MSADataSource.QUERY: 5,  # in chai-1 remapped to none.
    }
 
    database_ids: set[str] = set(x.value for x in MSADataSource)
12 changes: 8 additions & 4 deletions chai_lab/utils/paths.py
@@ -63,15 +63,19 @@ def get_path(self) -> Path:
     path=downloads_path.joinpath("conformers_v1.apkl"),
 )
 
+COMPONENT_URL = (
+    "https://chaiassets.com/chai1-inference-depencencies/models_v2/{comp_key}"
+)
+
 
 def chai1_component(comp_key: str) -> Path:
     """
     Downloads exported model and stores it locally in repo/downloads
-    comp_key: e.g. '384/trunk.pt2'
+    comp_key: e.g. 'trunk.pt'
     """
-    assert comp_key.endswith(".pt2")
-    url = f"https://chaiassets.com/chai1-inference-depencencies/models/{comp_key}"
-    result = downloads_path.joinpath("models", comp_key)
+    assert comp_key.endswith(".pt")
+    url = COMPONENT_URL.format(comp_key=comp_key)
+    result = downloads_path.joinpath("models_v2", comp_key)
     download_if_not_exists(url, result)
 
     return result
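
To see how a component key resolves under the new constant, a small sketch. Here downloads_path stands in for the module's cache directory; it is shown as a plain Path purely for illustration.

from pathlib import Path

COMPONENT_URL = (
    "https://chaiassets.com/chai1-inference-depencencies/models_v2/{comp_key}"
)
downloads_path = Path("downloads")  # assumption: repo-local cache dir

comp_key = "trunk.pt"
url = COMPONENT_URL.format(comp_key=comp_key)
local = downloads_path.joinpath("models_v2", comp_key)
print(url)    # https://chaiassets.com/chai1-inference-depencencies/models_v2/trunk.pt
print(local)  # downloads/models_v2/trunk.pt
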
