Fix missing shape specification
patrick-kidger committed Nov 10, 2023
1 parent 48a27f9 commit f666790
Showing 1 changed file with 1 addition and 1 deletion.
docs/tricks.md (1 addition, 1 deletion)
@@ -34,7 +34,7 @@ And here is how to replace the `weight` of every linear layer in some arbitrary
 def trunc_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
     out, in_ = weight.shape
     stddev = math.sqrt(1 / in_)
-    return stddev * jax.random.truncated_normal(key, lower=-2, upper=2)
+    return stddev * jax.random.truncated_normal(key, shape=(out, in_), lower=-2, upper=2)
 
 def init_linear_weight(model, init_fn, key):
     is_linear = lambda x: isinstance(x, eqx.nn.Linear)
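
For context: without the `shape` argument, `jax.random.truncated_normal` samples a single scalar, so the old `trunc_init` returned a scalar rather than an `(out, in_)` array matching the weight it is meant to replace. A minimal sketch of the before/after behaviour (not part of the commit; the layer dimensions below are made up for illustration):

    import math
    import jax

    key = jax.random.PRNGKey(0)
    out, in_ = 4, 8              # hypothetical layer dimensions
    stddev = math.sqrt(1 / in_)

    # Before the fix: no shape argument, so a scalar sample is returned.
    before = stddev * jax.random.truncated_normal(key, lower=-2, upper=2)
    # After the fix: the sample matches the weight's shape.
    after = stddev * jax.random.truncated_normal(key, shape=(out, in_), lower=-2, upper=2)

    print(before.shape)  # ()
    print(after.shape)   # (4, 8)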
