Fix/small improvements (#18)
* Fix formatting of the arguments' descriptions
* LoRA bias now depends on the bias from the config
* LRSchedulerBase class is now an abstract one
Andrei-Aksionov authored Jun 3, 2023
1 parent fb208fb commit c2eedfe
Showing 3 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/data/downloader.py
@@ -14,7 +14,7 @@ def download(config: DictConfig, *, override_if_exists: bool = True) -> Path:
         omegaconf's dictionary with three keys: url, folder and filename
         url is from where to download the file
         folder - in which folder to put the downloaded file
-    override_if_exists: bool
+    override_if_exists : bool
         if True will download even if file with such name already exists
 
     Raises
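For reference, the one-space change above aligns the parameter line with the NumPy docstring convention, in which a parameter's name and type are separated by " : ". A minimal sketch of the resulting docstring section, paraphrasing the surrounding context from the diff:

def download(config, *, override_if_exists: bool = True):
    """Download a file described by the config.

    Parameters
    ----------
    config : DictConfig
        omegaconf's dictionary with three keys: url, folder and filename
    override_if_exists : bool
        if True will download even if file with such name already exists
    """
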
7 changes: 6 additions & 1 deletion src/model/gpt_language_model/peft/lora.py
@@ -226,6 +226,11 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
         | query | key | value |
         ----------------------------------------
 
+        Parameters
+        ----------
+        x : torch.Tensor
+            tensor with weights update that needs to be padded with zeros
+
         Returns
         -------
         torch.Tensor
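
The new Parameters section documents what zero_pad receives. As a hedged illustration of what such padding does (a minimal standalone sketch, not the repository's exact implementation): with enable_lora=[True, False, True], the LoRA update covers only the query and value projections, so it has to be scattered into a tensor spanning the full merged |query|key|value| output, leaving zeros in key's slice:

import torch

embeddings_size = 4
enable_lora = [True, False, True]  # LoRA is applied to query and value, not key

# column indices of the merged q|k|v output that belong to LoRA-enabled parts
lora_ind = torch.tensor(
    [
        i * embeddings_size + j
        for i, enabled in enumerate(enable_lora)
        if enabled
        for j in range(embeddings_size)
    ]
)

x = torch.randn(2, embeddings_size * sum(enable_lora))  # update for query+value only
padded = torch.zeros(2, embeddings_size * len(enable_lora))  # full q|k|v width
padded[:, lora_ind] = x  # key's columns remain zero
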
@@ -505,7 +510,7 @@ def __init__(
             enable_lora=[True, False, True],
             fan_in_fan_out=False,
             merge_weights=True,
-            bias=False,
+            bias=bias,
         )
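
This is the change the second bullet of the commit message refers to: the LoRA layer's bias flag is now forwarded from the constructor argument (presumably sourced from the model config) instead of being hard-coded to False. A hedged sketch of the effect, using loralib's MergedLinear as a stand-in for the repository's own layer (the config_bias name is illustrative):

import loralib as lora

config_bias = True  # hypothetical value read from the model config

qkv = lora.MergedLinear(  # merged query/key/value projection with LoRA
    in_features=128,
    out_features=3 * 128,
    r=8,
    enable_lora=[True, False, True],
    fan_in_fan_out=False,
    merge_weights=True,
    bias=config_bias,  # previously bias=False; now follows the config
)
assert (qkv.bias is not None) == config_bias
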


6 changes: 3 additions & 3 deletions src/model/lr_schedulers.py
@@ -1,11 +1,11 @@
 import math
-from abc import abstractclassmethod
+from abc import ABC, abstractmethod
 from typing import Optional
 
 import torch
 
 
-class LRSchedulerBase:
+class LRSchedulerBase(ABC):
     def __init__(self, optimizer: torch.optim.Optimizer) -> None:
         """Create base class for custom learning rate schedulers.
@@ -17,7 +17,7 @@ def __init__(self, optimizer: torch.optim.Optimizer) -> None:
         super().__init__()
         self.optimizer = optimizer
 
-    @abstractclassmethod
+    @abstractmethod
     def _get_lr(self, iteration: int) -> float:
         pass
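
Replacing abstractclassmethod (deprecated since Python 3.3) with abstractmethod and inheriting from ABC makes the contract enforceable: Python refuses to instantiate LRSchedulerBase, or any subclass that forgets to override _get_lr. A minimal sketch of the behavior (the optimizer argument is omitted for brevity):

from abc import ABC, abstractmethod

class LRSchedulerBase(ABC):
    @abstractmethod
    def _get_lr(self, iteration: int) -> float:
        ...

try:
    LRSchedulerBase()  # now raises at construction time, not at call time
except TypeError as err:
    print(err)  # Can't instantiate abstract class LRSchedulerBase ...
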

