Commit 88994c1: add_bias

Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 committed Oct 3, 2024
1 parent 9e7b69c commit 88994c1
Showing 2 changed files with 6 additions and 0 deletions.

@@ -68,6 +68,7 @@ def get_moe_dolomite_tp_state_dict(
             tensor_parallel_word_embeddings=tensor_parallel_word_embeddings,
         )
     )
+
     return state_dict


Second changed file:

@@ -109,6 +109,7 @@ def unshard_tensor_parallel_state_dicts(
             check_correctness=check_correctness,
         )
     )
+
     return output_state_dict


@@ -136,6 +137,8 @@ def _concatenate_tensors_from_scattermoe(
 def _get_moe(
     tensor_parallel_state_dicts: list[dict], config: MoEDolomiteConfig, prefix: str, check_correctness: bool
 ) -> dict:
+    assert not config.add_bias
+
     output = {
         prefix
         + "gate.weight": _get_once_from_state_dicts_with_check(
@@ -159,6 +162,8 @@ def _get_moe(


 def _fix_moe_weights(config: MoEDolomiteConfig, state_dict: dict, tensor_parallel_size: int, prefix: str) -> dict:
+    assert not config.add_bias
+
     if is_glu(config.activation_function):
         for layer_idx in range(config.n_layer):
             key = f"{prefix}transformer.h.{layer_idx}.mlp.c_fc.weight"