diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py index 63ba6876..3ba73974 100644 --- a/src/petals/server/block_utils.py +++ b/src/petals/server/block_utils.py @@ -32,7 +32,7 @@ def get_block_size( dtype is not None and quant_type is not None ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations' - with init_empty_weights(include_buffers=True): + with init_empty_weights(include_buffers=False): block = get_model_block(config) n_params = sum(param.numel() for param in block.parameters())