From 8c630b25625076654b12b86b0d6f03bae16bf376 Mon Sep 17 00:00:00 2001 From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 13 Dec 2024 17:21:43 +0100 Subject: [PATCH 1/4] allow multi-gpu sharding --- entropix/weights.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/entropix/weights.py b/entropix/weights.py index de1b426..5670757 100644 --- a/entropix/weights.py +++ b/entropix/weights.py @@ -90,7 +90,10 @@ def load_weights( model_params.head_dim, ) # print(name, weight.shape, sharding._to_xla_hlo_sharding(weight.ndim)) - w[name] = jax.device_put(weight, sharding) + if weight.ndim == 1: + w[name] = jax.device_put_replicated(weight, jax.devices()) + else: + w[name] = jax.device_put(weight, sharding) for i in range(model_params.n_layers): layer_weights.append( From 21c7303ecbb54c0a6f41086ab092b330cd1a1e19 Mon Sep 17 00:00:00 2001 From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 14 Dec 2024 12:28:50 +0100 Subject: [PATCH 2/4] always use sharding --- entropix/weights.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/entropix/weights.py b/entropix/weights.py index 5670757..96bbde1 100644 --- a/entropix/weights.py +++ b/entropix/weights.py @@ -90,10 +90,12 @@ def load_weights( model_params.head_dim, ) # print(name, weight.shape, sharding._to_xla_hlo_sharding(weight.ndim)) - if weight.ndim == 1: - w[name] = jax.device_put_replicated(weight, jax.devices()) - else: - w[name] = jax.device_put(weight, sharding) + if weight.ndim == 0: + weight = jnp.stack([weight] * jax.device_count()) + elif weight.ndim == 1: + weight = jnp.concatenate([weight] * jax.device_count()) + w[name] = jax.device_put(weight, sharding) + for i in range(model_params.n_layers): layer_weights.append( From 87fcfc84cf9835a69ddebe225b8a5844ec2598d0 Mon Sep 17 00:00:00 2001 From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 14 Dec 2024 15:02:03 +0100 Subject: [PATCH 3/4] Update weights.py --- entropix/weights.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/entropix/weights.py b/entropix/weights.py index 96bbde1..de66985 100644 --- a/entropix/weights.py +++ b/entropix/weights.py @@ -92,8 +92,6 @@ def load_weights( # print(name, weight.shape, sharding._to_xla_hlo_sharding(weight.ndim)) if weight.ndim == 0: weight = jnp.stack([weight] * jax.device_count()) - elif weight.ndim == 1: - weight = jnp.concatenate([weight] * jax.device_count()) w[name] = jax.device_put(weight, sharding) From b5aa9d8ad10dbc34bd3420c2f7aefe5c622e6e40 Mon Sep 17 00:00:00 2001 From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 14 Dec 2024 15:38:41 +0100 Subject: [PATCH 4/4] Update weights.py --- entropix/weights.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/entropix/weights.py b/entropix/weights.py index de66985..5983e8d 100644 --- a/entropix/weights.py +++ b/entropix/weights.py @@ -79,7 +79,7 @@ def load_weights( for file in ckpt_dir.glob("*.npy"): name = ".".join(str(file).split("/")[-1].split(".")[:-1]) weight = jnp.load(file=file, mmap_mode="r", allow_pickle=True) - partition_spec = create_partition_spec(name) + partition_spec = jax.sharding.PartitionSpec() if weight.ndim == 0 else create_partition_spec(name) sharding = NamedSharding(mesh, partition_spec) if any(lyr in name for lyr in ["wq", "wk", "wv", "wo", "w1", "w2", "w3"]): weight = weight.T @@ -90,11 +90,8 @@ def load_weights( model_params.head_dim, ) # print(name, weight.shape, sharding._to_xla_hlo_sharding(weight.ndim)) - if weight.ndim == 0: - weight = jnp.stack([weight] * jax.device_count()) w[name] = jax.device_put(weight, sharding) - for i in range(model_params.n_layers): layer_weights.append( LayerWeights(