From c2eedfe406bf7c8851c393825c09280a7e329242 Mon Sep 17 00:00:00 2001
From: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com>
Date: Sat, 3 Jun 2023 17:10:25 +0300
Subject: [PATCH] Fix/small improvements (#18)

* Fix formatting for arguments description
* LoRA bias now depends on the bias flag from the config
* LRSchedulerBase class is now abstract
---
 src/data/downloader.py                    | 2 +-
 src/model/gpt_language_model/peft/lora.py | 7 ++++++-
 src/model/lr_schedulers.py                | 6 +++---
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/data/downloader.py b/src/data/downloader.py
index 718d2f5..7a12097 100644
--- a/src/data/downloader.py
+++ b/src/data/downloader.py
@@ -14,7 +14,7 @@ def download(config: DictConfig, *, override_if_exists: bool = True) -> Path:
         omegaconf's dictionary with three keys: url, folder and filename
         url is from where to download the file
         folder - in which folder to put the downloaded file
-    override_if_exists: bool
+    override_if_exists : bool
         if True will download even if file with such name already exists
 
     Raises
diff --git a/src/model/gpt_language_model/peft/lora.py b/src/model/gpt_language_model/peft/lora.py
index 8a507af..d40e490 100644
--- a/src/model/gpt_language_model/peft/lora.py
+++ b/src/model/gpt_language_model/peft/lora.py
@@ -226,6 +226,11 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
         | query        | key        | value       |
         --------------------------------------------
 
+        Parameters
+        ----------
+        x : torch.Tensor
+            tensor with weights update that needs to be padded with zeros
+
         Returns
         -------
         torch.Tensor
@@ -505,7 +510,7 @@ def __init__(
             enable_lora=[True, False, True],
             fan_in_fan_out=False,
             merge_weights=True,
-            bias=False,
+            bias=bias,
         )
 
 
diff --git a/src/model/lr_schedulers.py b/src/model/lr_schedulers.py
index dd6b19c..73c37ff 100644
--- a/src/model/lr_schedulers.py
+++ b/src/model/lr_schedulers.py
@@ -1,11 +1,11 @@
 import math
-from abc import abstractclassmethod
+from abc import ABC, abstractmethod
 from typing import Optional
 
 import torch
 
 
-class LRSchedulerBase:
+class LRSchedulerBase(ABC):
     def __init__(self, optimizer: torch.optim.Optimizer) -> None:
         """Create base class for custom learning rate schedulers.
 
@@ -17,7 +17,7 @@ def __init__(self, optimizer: torch.optim.Optimizer) -> None:
         super().__init__()
         self.optimizer = optimizer
 
-    @abstractclassmethod
+    @abstractmethod
     def _get_lr(self, iteration: int) -> float:
         pass
 
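
Notes on the changes above (illustrative sketches, not part of the diff):

For context on the zero_pad method whose docstring is extended here: in a merged-QKV
LoRA layer with enable_lora=[True, False, True], the low-rank update covers only the
query and value slices, so it must be scattered into the full qkv width with zeros in
the key positions. A minimal sketch of that padding, assuming illustrative names and
shapes rather than the repo's exact implementation:

    import torch

    def zero_pad_sketch(x: torch.Tensor, enable_lora=(True, False, True)) -> torch.Tensor:
        # x holds the LoRA update only for the enabled groups (query and value here),
        # concatenated along the last dimension; disabled groups (key) receive zeros.
        group_size = x.shape[-1] // sum(enable_lora)   # width of one q/k/v slice
        out_features = group_size * len(enable_lora)   # full q|k|v width
        result = x.new_zeros(*x.shape[:-1], out_features)
        # boolean mask marking the output columns that belong to enabled groups
        mask = torch.tensor(enable_lora).repeat_interleave(group_size)
        result[..., mask] = x
        return result

    zero_pad_sketch(torch.ones(1, 4)).shape  # torch.Size([1, 6]); key columns stay zero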
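
The bias=bias change matters because the merged LoRA layer ultimately behaves like an
nn.Linear: hard-coding bias=False silently dropped bias parameters even when the model
config requested them. A hypothetical before/after check, where build_qkv is a stand-in
for constructing the repo's MergedLinear and only the bias wiring is shown:

    import torch.nn as nn

    def build_qkv(embed_dim: int, bias: bool) -> nn.Linear:
        # stand-in for lora.MergedLinear, which subclasses nn.Linear
        return nn.Linear(embed_dim, 3 * embed_dim, bias=bias)

    qkv = build_qkv(64, bias=True)
    assert qkv.bias is not None  # before the fix this was always None for LoRA models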
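
The scheduler change replaces abstractclassmethod, deprecated since Python 3.3 and
intended for class methods anyway, with @abstractmethod on a class that now inherits
ABC. Without ABC among the bases, @abstractmethod is inert and the base class can be
instantiated; with it, Python enforces the contract. A small, stripped-down
demonstration (the real base class also takes an optimizer in __init__; ConstantLR is
a hypothetical subclass for illustration):

    from abc import ABC, abstractmethod

    class LRSchedulerBase(ABC):
        @abstractmethod
        def _get_lr(self, iteration: int) -> float:
            ...

    try:
        LRSchedulerBase()
    except TypeError as err:
        print(err)  # can't instantiate abstract class LRSchedulerBase ...

    class ConstantLR(LRSchedulerBase):
        def _get_lr(self, iteration: int) -> float:
            return 1e-3

    ConstantLR()  # fine once _get_lr is implemented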