diff --git a/ltsm/data_pipeline/data_pipeline.py b/ltsm/data_pipeline/data_pipeline.py index 0a7a24a..2b021bd 100644 --- a/ltsm/data_pipeline/data_pipeline.py +++ b/ltsm/data_pipeline/data_pipeline.py @@ -51,8 +51,6 @@ def run(self): - Evaluating the model on test datasets and logging metrics. """ logging.info(self.args) - - model = self.model_manager.create_model() # Training settings training_args = TrainingArguments( @@ -75,6 +73,12 @@ def run(self): train_dataset, eval_dataset, test_datasets, _ = get_datasets(self.args) train_dataset, eval_dataset= HF_Dataset(train_dataset), HF_Dataset(eval_dataset) + + if self.args.model == 'PatchTST' or self.args.model == 'DLinear': + # Set seq_len to the full input length (raw sequence plus prompt) so the patch count matches the actual model input + self.model_manager.args.seq_len = train_dataset[0]["input_data"].size()[0] + + model = self.model_manager.create_model() trainer = Trainer( model=model, @@ -140,14 +144,30 @@ def get_args(): parser.add_argument('--d_ff', type=int, default=512, help='dimension of fcn') parser.add_argument('--dropout', type=float, default=0.2, help='dropout') parser.add_argument('--enc_in', type=int, default=1, help='encoder input size') + parser.add_argument('--dec_in', type=int, default=7, help='decoder input size') parser.add_argument('--c_out', type=int, default=862, help='output size') parser.add_argument('--patch_size', type=int, default=16, help='patch size') parser.add_argument('--pretrain', type=int, default=1, help='is pretrain') parser.add_argument('--local_pretrain', type=str, default="None", help='local pretrain weight') parser.add_argument('--freeze', type=int, default=1, help='is model weight frozen') - parser.add_argument('--model', type=str, default='model', help='model name, , options:[LTSM, LTSM_WordPrompt, LTSM_Tokenizer]') + parser.add_argument('--model', type=str, default='model', help='model name, options:[LTSM, LTSM_WordPrompt, LTSM_Tokenizer, DLinear, PatchTST, Informer]') parser.add_argument('--stride', type=int, default=8, help='stride') parser.add_argument('--tmax', type=int, default=10, help='tmax') + parser.add_argument('--embed', type=str, default='timeF', + help='time features encoding, options:[timeF, fixed, learned]') + parser.add_argument('--activation', type=str, default='gelu', help='activation') + parser.add_argument('--output_attention', action='store_true', help='whether to output attention in encoder') + parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data') + parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average') + parser.add_argument('--factor', type=int, default=1, help='attn factor') + parser.add_argument('--distil', action='store_false', + help='whether to use distilling in encoder, using this argument means not using distilling', + default=True) + parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers') + parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers') + parser.add_argument('--embed_type', type=int, default=0, help='0: default 1: value embedding + temporal embedding + positional embedding 2: value embedding + temporal embedding 3: value embedding + positional embedding 4: value embedding') + parser.add_argument('--freq', type=str, default='h', + help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you
can also use more detailed freq like 15min or 3h') # Training Settings parser.add_argument('--eval', type=int, default=0, help='evaluation') @@ -163,6 +184,20 @@ def get_args(): parser.add_argument('--lradj', type=str, default='type1', help='learning rate adjustment type') parser.add_argument('--patience', type=int, default=3, help='early stopping patience') parser.add_argument('--gradient_accumulation_steps', type=int, default=64, help='gradient accumulation steps') + + + # PatchTST + parser.add_argument('--fc_dropout', type=float, default=0.05, help='fully connected dropout') + parser.add_argument('--head_dropout', type=float, default=0.0, help='head dropout') + parser.add_argument('--patch_len', type=int, default=16, help='patch length') + parser.add_argument('--padding_patch', default='end', help='None: None; end: padding on the end') + parser.add_argument('--revin', type=int, default=1, help='RevIN; True 1 False 0') + parser.add_argument('--affine', type=int, default=0, help='RevIN-affine; True 1 False 0') + parser.add_argument('--subtract_last', type=int, default=0, help='0: subtract mean; 1: subtract last') + parser.add_argument('--decomposition', type=int, default=0, help='decomposition; True 1 False 0') + parser.add_argument('--kernel_size', type=int, default=25, help='decomposition-kernel') + parser.add_argument('--individual', type=int, default=0, help='individual head; True 1 False 0') + args, unknown = parser.parse_known_args() return args diff --git a/ltsm/layers/Embed.py b/ltsm/layers/Embed.py new file mode 100644 index 0000000..a170bc7 --- /dev/null +++ b/ltsm/layers/Embed.py @@ -0,0 +1,165 @@ +# code from https://github.com/yuqinie98/PatchTST, with minor modifications +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import weight_norm +import math + + +class PositionalEmbedding(nn.Module): + def __init__(self, d_model, max_len=5000): + super(PositionalEmbedding, self).__init__() + # Compute the positional encodings once in log space. 
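+        # Standard sinusoidal encoding: pe[pos, 2i] = sin(pos * div_term[i]) and pe[pos, 2i+1] = cos(pos * div_term[i]),
+        # with div_term[i] = exp(-2i * ln(10000) / d_model); stored as a non-trainable buffer and sliced to the
+        # input length in forward().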
+ pe = torch.zeros(max_len, d_model).float() + pe.require_grad = False + + position = torch.arange(0, max_len).float().unsqueeze(1) + div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, x): + return self.pe[:, :x.size(1)] + + +class TokenEmbedding(nn.Module): + def __init__(self, c_in, d_model): + super(TokenEmbedding, self).__init__() + padding = 1 if torch.__version__ >= '1.5.0' else 2 + self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, + kernel_size=3, padding=padding, padding_mode='circular', bias=False) + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu') + + def forward(self, x): + x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) + return x + + +class FixedEmbedding(nn.Module): + def __init__(self, c_in, d_model): + super(FixedEmbedding, self).__init__() + + w = torch.zeros(c_in, d_model).float() + w.require_grad = False + + position = torch.arange(0, c_in).float().unsqueeze(1) + div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() + + w[:, 0::2] = torch.sin(position * div_term) + w[:, 1::2] = torch.cos(position * div_term) + + self.emb = nn.Embedding(c_in, d_model) + self.emb.weight = nn.Parameter(w, requires_grad=False) + + def forward(self, x): + return self.emb(x).detach() + + +class TemporalEmbedding(nn.Module): + def __init__(self, d_model, embed_type='fixed', freq='h'): + super(TemporalEmbedding, self).__init__() + + minute_size = 4 + hour_size = 24 + weekday_size = 7 + day_size = 32 + month_size = 13 + + Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding + if freq == 't': + self.minute_embed = Embed(minute_size, d_model) + self.hour_embed = Embed(hour_size, d_model) + self.weekday_embed = Embed(weekday_size, d_model) + self.day_embed = Embed(day_size, d_model) + self.month_embed = Embed(month_size, d_model) + + def forward(self, x): + x = x.long() + + minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, 'minute_embed') else 0. 
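+        # x (the time-mark tensor) is expected to hold calendar features ordered [month, day, weekday, hour, minute];
+        # the minute embedding only exists when freq == 't', hence the 0. fallback above.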
+ hour_x = self.hour_embed(x[:, :, 3]) + weekday_x = self.weekday_embed(x[:, :, 2]) + day_x = self.day_embed(x[:, :, 1]) + month_x = self.month_embed(x[:, :, 0]) + + return hour_x + weekday_x + day_x + month_x + minute_x + + +class TimeFeatureEmbedding(nn.Module): + def __init__(self, d_model, embed_type='timeF', freq='h'): + super(TimeFeatureEmbedding, self).__init__() + + freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3} + d_inp = freq_map[freq] + self.embed = nn.Linear(d_inp, d_model, bias=False) + + def forward(self, x): + return self.embed(x) + + +class DataEmbedding(nn.Module): + def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): + super(DataEmbedding, self).__init__() + + self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) + self.position_embedding = PositionalEmbedding(d_model=d_model) + self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, + freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( + d_model=d_model, embed_type=embed_type, freq=freq) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x) + return self.dropout(x) + + +class DataEmbedding_wo_pos(nn.Module): + def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): + super(DataEmbedding_wo_pos, self).__init__() + + self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) + self.position_embedding = PositionalEmbedding(d_model=d_model) + self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, + freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( + d_model=d_model, embed_type=embed_type, freq=freq) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = self.value_embedding(x) + self.temporal_embedding(x_mark) + return self.dropout(x) + +class DataEmbedding_wo_pos_temp(nn.Module): + def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): + super(DataEmbedding_wo_pos_temp, self).__init__() + + self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) + self.position_embedding = PositionalEmbedding(d_model=d_model) + self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, + freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( + d_model=d_model, embed_type=embed_type, freq=freq) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = self.value_embedding(x) + return self.dropout(x) + +class DataEmbedding_wo_temp(nn.Module): + def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): + super(DataEmbedding_wo_temp, self).__init__() + + self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) + self.position_embedding = PositionalEmbedding(d_model=d_model) + self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, + freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( + d_model=d_model, embed_type=embed_type, freq=freq) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = self.value_embedding(x) + self.position_embedding(x) + return self.dropout(x) \ No newline at end of file diff --git a/ltsm/layers/PatchTST_backbone.py b/ltsm/layers/PatchTST_backbone.py new file mode 100644 index 0000000..8ddef87 --- /dev/null +++ b/ltsm/layers/PatchTST_backbone.py @@ -0,0 +1,379 @@ +# code from https://github.com/yuqinie98/PatchTST, with minor modifications 
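+# Overall flow: (optional) RevIN normalization -> unfold the series into overlapping patches of
+# length patch_len with the given stride -> channel-independent Transformer encoder (TSTiEncoder)
+# -> Flatten_Head projecting d_model * patch_num features to target_window -> (optional) RevIN
+# denormalization.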
+__all__ = ['PatchTST_backbone'] + +# Cell +from typing import Callable, Optional +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F +import numpy as np + +#from collections import OrderedDict +from ltsm.layers.PatchTST_layers import * +from ltsm.layers.RevIN import RevIN + +# Cell +class PatchTST_backbone(nn.Module): + def __init__(self, c_in:int, context_window:int, target_window:int, patch_len:int, stride:int, max_seq_len:Optional[int]=1024, + n_layers:int=3, d_model=128, n_heads=16, d_k:Optional[int]=None, d_v:Optional[int]=None, + d_ff:int=256, norm:str='BatchNorm', attn_dropout:float=0., dropout:float=0., act:str="gelu", key_padding_mask:bool='auto', + padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False, + pe:str='zeros', learn_pe:bool=True, fc_dropout:float=0., head_dropout = 0, padding_patch = None, + pretrain_head:bool=False, head_type = 'flatten', individual = False, revin = True, affine = True, subtract_last = False, + verbose:bool=False, **kwargs): + + super().__init__() + + # RevIn + self.revin = revin + if self.revin: self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last) + + # Patching + self.patch_len = patch_len + self.stride = stride + self.padding_patch = padding_patch + patch_num = int((context_window - patch_len)/stride + 1) + if padding_patch == 'end': # can be modified to general case + self.padding_patch_layer = nn.ReplicationPad1d((0, stride)) + patch_num += 1 + + # Backbone + self.backbone = TSTiEncoder(c_in, patch_num=patch_num, patch_len=patch_len, max_seq_len=max_seq_len, + n_layers=n_layers, d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, + attn_dropout=attn_dropout, dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var, + attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn, + pe=pe, learn_pe=learn_pe, verbose=verbose, **kwargs) + + # Head + self.head_nf = d_model * patch_num + self.n_vars = c_in + self.pretrain_head = pretrain_head + self.head_type = head_type + self.individual = individual + + if self.pretrain_head: + self.head = self.create_pretrain_head(self.head_nf, c_in, fc_dropout) # custom head passed as a partial func with all its kwargs + elif head_type == 'flatten': + self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, target_window, head_dropout=head_dropout) + + + def forward(self, z): # z: [bs x nvars x seq_len] + # norm + if self.revin: + z = z.permute(0,2,1) + z = self.revin_layer(z, 'norm') + z = z.permute(0,2,1) + + # do patching + if self.padding_patch == 'end': + z = self.padding_patch_layer(z) + z = z.unfold(dimension=-1, size=self.patch_len, step=self.stride) # z: [bs x nvars x patch_num x patch_len] + z = z.permute(0,1,3,2) # z: [bs x nvars x patch_len x patch_num] + + # model + z = self.backbone(z) # z: [bs x nvars x d_model x patch_num] + z = self.head(z) # z: [bs x nvars x target_window] + + # denorm + if self.revin: + z = z.permute(0,2,1) + z = self.revin_layer(z, 'denorm') + z = z.permute(0,2,1) + return z + + def create_pretrain_head(self, head_nf, vars, dropout): + return nn.Sequential(nn.Dropout(dropout), + nn.Conv1d(head_nf, vars, 1) + ) + + +class Flatten_Head(nn.Module): + def __init__(self, individual, n_vars, nf, target_window, head_dropout=0): + super().__init__() + + self.individual = individual + self.n_vars = n_vars + + if self.individual: + self.linears = nn.ModuleList() + 
self.dropouts = nn.ModuleList() + self.flattens = nn.ModuleList() + for i in range(self.n_vars): + self.flattens.append(nn.Flatten(start_dim=-2)) + self.linears.append(nn.Linear(nf, target_window)) + self.dropouts.append(nn.Dropout(head_dropout)) + else: + self.flatten = nn.Flatten(start_dim=-2) + self.linear = nn.Linear(nf, target_window) + self.dropout = nn.Dropout(head_dropout) + + def forward(self, x): # x: [bs x nvars x d_model x patch_num] + if self.individual: + x_out = [] + for i in range(self.n_vars): + z = self.flattens[i](x[:,i,:,:]) # z: [bs x d_model * patch_num] + z = self.linears[i](z) # z: [bs x target_window] + z = self.dropouts[i](z) + x_out.append(z) + x = torch.stack(x_out, dim=1) # x: [bs x nvars x target_window] + else: + x = self.flatten(x) + x = self.linear(x) + x = self.dropout(x) + return x + + + + +class TSTiEncoder(nn.Module): #i means channel-independent + def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024, + n_layers=3, d_model=128, n_heads=16, d_k=None, d_v=None, + d_ff=256, norm='BatchNorm', attn_dropout=0., dropout=0., act="gelu", store_attn=False, + key_padding_mask='auto', padding_var=None, attn_mask=None, res_attention=True, pre_norm=False, + pe='zeros', learn_pe=True, verbose=False, **kwargs): + + + super().__init__() + + self.patch_num = patch_num + self.patch_len = patch_len + + # Input encoding + q_len = patch_num + self.W_P = nn.Linear(patch_len, d_model) # Eq 1: projection of feature vectors onto a d-dim vector space + self.seq_len = q_len + + # Positional encoding + self.W_pos = positional_encoding(pe, learn_pe, q_len, d_model) + + # Residual dropout + self.dropout = nn.Dropout(dropout) + + # Encoder + self.encoder = TSTEncoder(q_len, d_model, n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, dropout=dropout, + pre_norm=pre_norm, activation=act, res_attention=res_attention, n_layers=n_layers, store_attn=store_attn) + + + def forward(self, x) -> Tensor: # x: [bs x nvars x patch_len x patch_num] + + n_vars = x.shape[1] + # Input encoding + x = x.permute(0,1,3,2) # x: [bs x nvars x patch_num x patch_len] + x = self.W_P(x) # x: [bs x nvars x patch_num x d_model] + + u = torch.reshape(x, (x.shape[0]*x.shape[1],x.shape[2],x.shape[3])) # u: [bs * nvars x patch_num x d_model] + u = self.dropout(u + self.W_pos) # u: [bs * nvars x patch_num x d_model] + + # Encoder + z = self.encoder(u) # z: [bs * nvars x patch_num x d_model] + z = torch.reshape(z, (-1,n_vars,z.shape[-2],z.shape[-1])) # z: [bs x nvars x patch_num x d_model] + z = z.permute(0,1,3,2) # z: [bs x nvars x d_model x patch_num] + + return z + + + +# Cell +class TSTEncoder(nn.Module): + def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None, + norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu', + res_attention=False, n_layers=1, pre_norm=False, store_attn=False): + super().__init__() + + self.layers = nn.ModuleList([TSTEncoderLayer(q_len, d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, + attn_dropout=attn_dropout, dropout=dropout, + activation=activation, res_attention=res_attention, + pre_norm=pre_norm, store_attn=store_attn) for i in range(n_layers)]) + self.res_attention = res_attention + + def forward(self, src:Tensor, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None): + output = src + scores = None + if self.res_attention: + for mod in self.layers: output, scores = mod(output, prev=scores, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + return output + else: + for mod in 
self.layers: output = mod(output, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + return output + + + +class TSTEncoderLayer(nn.Module): + def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=256, store_attn=False, + norm='BatchNorm', attn_dropout=0, dropout=0., bias=True, activation="gelu", res_attention=False, pre_norm=False): + super().__init__() + assert not d_model%n_heads, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})" + d_k = d_model // n_heads if d_k is None else d_k + d_v = d_model // n_heads if d_v is None else d_v + + # Multi-Head attention + self.res_attention = res_attention + self.self_attn = _MultiheadAttention(d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout, proj_dropout=dropout, res_attention=res_attention) + + # Add & Norm + self.dropout_attn = nn.Dropout(dropout) + if "batch" in norm.lower(): + self.norm_attn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2)) + else: + self.norm_attn = nn.LayerNorm(d_model) + + # Position-wise Feed-Forward + self.ff = nn.Sequential(nn.Linear(d_model, d_ff, bias=bias), + get_activation_fn(activation), + nn.Dropout(dropout), + nn.Linear(d_ff, d_model, bias=bias)) + + # Add & Norm + self.dropout_ffn = nn.Dropout(dropout) + if "batch" in norm.lower(): + self.norm_ffn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2)) + else: + self.norm_ffn = nn.LayerNorm(d_model) + + self.pre_norm = pre_norm + self.store_attn = store_attn + + + def forward(self, src:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None) -> Tensor: + + # Multi-Head attention sublayer + if self.pre_norm: + src = self.norm_attn(src) + ## Multi-Head attention + if self.res_attention: + src2, attn, scores = self.self_attn(src, src, src, prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + else: + src2, attn = self.self_attn(src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + if self.store_attn: + self.attn = attn + ## Add & Norm + src = src + self.dropout_attn(src2) # Add: residual connection with residual dropout + if not self.pre_norm: + src = self.norm_attn(src) + + # Feed-forward sublayer + if self.pre_norm: + src = self.norm_ffn(src) + ## Position-wise Feed-Forward + src2 = self.ff(src) + ## Add & Norm + src = src + self.dropout_ffn(src2) # Add: residual connection with residual dropout + if not self.pre_norm: + src = self.norm_ffn(src) + + if self.res_attention: + return src, scores + else: + return src + + + + +class _MultiheadAttention(nn.Module): + def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False): + """Multi Head Attention Layer + Input shape: + Q: [batch_size (bs) x max_q_len x d_model] + K, V: [batch_size (bs) x q_len x d_model] + mask: [q_len x q_len] + """ + super().__init__() + d_k = d_model // n_heads if d_k is None else d_k + d_v = d_model // n_heads if d_v is None else d_v + + self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v + + self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias) + self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias) + self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias) + + # Scaled Dot-Product Attention (multiple heads) + self.res_attention = res_attention + self.sdp_attn = _ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout, res_attention=self.res_attention, lsa=lsa) + + # Poject output + self.to_out = nn.Sequential(nn.Linear(n_heads * 
d_v, d_model), nn.Dropout(proj_dropout)) + + + def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None, + key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None): + + bs = Q.size(0) + if K is None: K = Q + if V is None: V = Q + + # Linear (+ split in multiple heads) + q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2) # q_s : [bs x n_heads x max_q_len x d_k] + k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1) # k_s : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3) + v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2) # v_s : [bs x n_heads x q_len x d_v] + + # Apply Scaled Dot-Product Attention (multiple heads) + if self.res_attention: + output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + else: + output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + # output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len] + + # back to the original inputs dimensions + output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v] + output = self.to_out(output) + + if self.res_attention: return output, attn_weights, attn_scores + else: return output, attn_weights + + +class _ScaledDotProductAttention(nn.Module): + r"""Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer + (Realformer: Transformer likes residual attention by He et al, 2020) and locality self sttention (Vision Transformer for Small-Size Datasets + by Lee et al, 2021)""" + + def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False): + super().__init__() + self.attn_dropout = nn.Dropout(attn_dropout) + self.res_attention = res_attention + head_dim = d_model // n_heads + self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa) + self.lsa = lsa + + def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None): + ''' + Input shape: + q : [bs x n_heads x max_q_len x d_k] + k : [bs x n_heads x d_k x seq_len] + v : [bs x n_heads x seq_len x d_v] + prev : [bs x n_heads x q_len x seq_len] + key_padding_mask: [bs x seq_len] + attn_mask : [1 x seq_len x seq_len] + Output shape: + output: [bs x n_heads x q_len x d_v] + attn : [bs x n_heads x q_len x seq_len] + scores : [bs x n_heads x q_len x seq_len] + ''' + + # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence + attn_scores = torch.matmul(q, k) * self.scale # attn_scores : [bs x n_heads x max_q_len x q_len] + + # Add pre-softmax attention scores from the previous layer (optional) + if prev is not None: attn_scores = attn_scores + prev + + # Attention mask (optional) + if attn_mask is not None: # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len + if attn_mask.dtype == torch.bool: + attn_scores.masked_fill_(attn_mask, -np.inf) + else: + attn_scores += attn_mask + + # Key padding mask (optional) + if key_padding_mask is not None: # mask with shape [bs x q_len] (only when max_w_len == q_len) + attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf) + + # normalize the attention weights + attn_weights = 
F.softmax(attn_scores, dim=-1) # attn_weights : [bs x n_heads x max_q_len x q_len] + attn_weights = self.attn_dropout(attn_weights) + + # compute the new values given the attention weights + output = torch.matmul(attn_weights, v) # output: [bs x n_heads x max_q_len x d_v] + + if self.res_attention: return output, attn_weights, attn_scores + else: return output, attn_weights \ No newline at end of file diff --git a/ltsm/layers/PatchTST_layers.py b/ltsm/layers/PatchTST_layers.py new file mode 100644 index 0000000..d168313 --- /dev/null +++ b/ltsm/layers/PatchTST_layers.py @@ -0,0 +1,122 @@ +# code from https://github.com/yuqinie98/PatchTST, with minor modifications +__all__ = ['Transpose', 'get_activation_fn', 'moving_avg', 'series_decomp', 'PositionalEncoding', 'SinCosPosEncoding', 'Coord2dPosEncoding', 'Coord1dPosEncoding', 'positional_encoding'] + +import torch +from torch import nn +import math + +class Transpose(nn.Module): + def __init__(self, *dims, contiguous=False): + super().__init__() + self.dims, self.contiguous = dims, contiguous + def forward(self, x): + if self.contiguous: return x.transpose(*self.dims).contiguous() + else: return x.transpose(*self.dims) + + +def get_activation_fn(activation): + if callable(activation): return activation() + elif activation.lower() == "relu": return nn.ReLU() + elif activation.lower() == "gelu": return nn.GELU() + raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable') + + +# decomposition + +class moving_avg(nn.Module): + """ + Moving average block to highlight the trend of time series + """ + def __init__(self, kernel_size, stride): + super(moving_avg, self).__init__() + self.kernel_size = kernel_size + self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) + + def forward(self, x): + # padding on the both ends of time series + front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) + end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) + x = torch.cat([front, x, end], dim=1) + x = self.avg(x.permute(0, 2, 1)) + x = x.permute(0, 2, 1) + return x + + +class series_decomp(nn.Module): + """ + Series decomposition block + """ + def __init__(self, kernel_size): + super(series_decomp, self).__init__() + self.moving_avg = moving_avg(kernel_size, stride=1) + + def forward(self, x): + moving_mean = self.moving_avg(x) + res = x - moving_mean + return res, moving_mean + + + +# pos_encoding + +def PositionalEncoding(q_len, d_model, normalize=True): + pe = torch.zeros(q_len, d_model) + position = torch.arange(0, q_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + if normalize: + pe = pe - pe.mean() + pe = pe / (pe.std() * 10) + return pe + +SinCosPosEncoding = PositionalEncoding + +def Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False): + x = .5 if exponential else 1 + i = 0 + for i in range(100): + cpe = 2 * (torch.linspace(0, 1, q_len).reshape(-1, 1) ** x) * (torch.linspace(0, 1, d_model).reshape(1, -1) ** x) - 1 + pv(f'{i:4.0f} {x:5.3f} {cpe.mean():+6.3f}', verbose) + if abs(cpe.mean()) <= eps: break + elif cpe.mean() > eps: x += .001 + else: x -= .001 + i += 1 + if normalize: + cpe = cpe - cpe.mean() + cpe = cpe / (cpe.std() * 10) + return cpe + +def Coord1dPosEncoding(q_len, exponential=False, normalize=True): + cpe = (2 * (torch.linspace(0, 1, q_len).reshape(-1, 1)**(.5 if exponential 
else 1)) - 1) + if normalize: + cpe = cpe - cpe.mean() + cpe = cpe / (cpe.std() * 10) + return cpe + +def positional_encoding(pe, learn_pe, q_len, d_model): + # Positional encoding + if pe == None: + W_pos = torch.empty((q_len, d_model)) # pe = None and learn_pe = False can be used to measure impact of pe + nn.init.uniform_(W_pos, -0.02, 0.02) + learn_pe = False + elif pe == 'zero': + W_pos = torch.empty((q_len, 1)) + nn.init.uniform_(W_pos, -0.02, 0.02) + elif pe == 'zeros': + W_pos = torch.empty((q_len, d_model)) + nn.init.uniform_(W_pos, -0.02, 0.02) + elif pe == 'normal' or pe == 'gauss': + W_pos = torch.zeros((q_len, 1)) + torch.nn.init.normal_(W_pos, mean=0.0, std=0.1) + elif pe == 'uniform': + W_pos = torch.zeros((q_len, 1)) + nn.init.uniform_(W_pos, a=0.0, b=0.1) + elif pe == 'lin1d': W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True) + elif pe == 'exp1d': W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True) + elif pe == 'lin2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True) + elif pe == 'exp2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True) + elif pe == 'sincos': W_pos = PositionalEncoding(q_len, d_model, normalize=True) + else: raise ValueError(f"{pe} is not a valid pe (positional encoder. Available types: 'gauss'=='normal', \ + 'zeros', 'zero', uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None.)") + return nn.Parameter(W_pos, requires_grad=learn_pe) \ No newline at end of file diff --git a/ltsm/layers/RevIN.py b/ltsm/layers/RevIN.py new file mode 100644 index 0000000..21402c1 --- /dev/null +++ b/ltsm/layers/RevIN.py @@ -0,0 +1,63 @@ +# code from https://github.com/ts-kim/RevIN, with minor modifications + +import torch +import torch.nn as nn + +class RevIN(nn.Module): + def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False): + """ + :param num_features: the number of features or channels + :param eps: a value added for numerical stability + :param affine: if True, RevIN has learnable affine parameters + """ + super(RevIN, self).__init__() + self.num_features = num_features + self.eps = eps + self.affine = affine + self.subtract_last = subtract_last + if self.affine: + self._init_params() + + def forward(self, x, mode:str): + if mode == 'norm': + self._get_statistics(x) + x = self._normalize(x) + elif mode == 'denorm': + x = self._denormalize(x) + else: raise NotImplementedError + return x + + def _init_params(self): + # initialize RevIN params: (C,) + self.affine_weight = nn.Parameter(torch.ones(self.num_features)) + self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) + + def _get_statistics(self, x): + dim2reduce = tuple(range(1, x.ndim-1)) + if self.subtract_last: + self.last = x[:,-1,:].unsqueeze(1) + else: + self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() + self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach() + + def _normalize(self, x): + if self.subtract_last: + x = x - self.last + else: + x = x - self.mean + x = x / self.stdev + if self.affine: + x = x * self.affine_weight + x = x + self.affine_bias + return x + + def _denormalize(self, x): + if self.affine: + x = x - self.affine_bias + x = x / (self.affine_weight + self.eps*self.eps) + x = x * self.stdev + if self.subtract_last: + x = x + self.last + else: + x = x + self.mean + return x \ No newline at end of file diff --git a/ltsm/layers/SelfAttention_Family.py b/ltsm/layers/SelfAttention_Family.py new file 
mode 100644 index 0000000..80aeb49 --- /dev/null +++ b/ltsm/layers/SelfAttention_Family.py @@ -0,0 +1,167 @@ +# code from https://github.com/yuqinie98/PatchTST, with minor modifications +import torch +import torch.nn as nn +import torch.nn.functional as F + +import matplotlib.pyplot as plt + +import numpy as np +import math +from math import sqrt +from ltsm.utils.masking import TriangularCausalMask, ProbMask +import os + + +class FullAttention(nn.Module): + def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): + super(FullAttention, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attn_mask): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1. / sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return (V.contiguous(), A) + else: + return (V.contiguous(), None) + + +class ProbAttention(nn.Module): + def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): + super(ProbAttention, self).__init__() + self.factor = factor + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) + # Q [B, H, L, D] + B, H, L_K, E = K.shape + _, _, L_Q, _ = Q.shape + + # calculate the sampled Q_K + K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) + index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q + K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] + Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze() + + # find the Top_k query with sparisty measurement + M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) + M_top = M.topk(n_top, sorted=False)[1] + + # use the reduced Q to calculate Q_K + Q_reduce = Q[torch.arange(B)[:, None, None], + torch.arange(H)[None, :, None], + M_top, :] # factor*ln(L_q) + Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k + + return Q_K, M_top + + def _get_initial_context(self, V, L_Q): + B, H, L_V, D = V.shape + if not self.mask_flag: + # V_sum = V.sum(dim=-2) + V_sum = V.mean(dim=-2) + contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone() + else: # use mask + assert (L_Q == L_V) # requires that L_Q == L_V, i.e. 
for self-attention only + contex = V.cumsum(dim=-2) + return contex + + def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): + B, H, L_V, D = V.shape + + if self.mask_flag: + attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) + scores.masked_fill_(attn_mask.mask, -np.inf) + + attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) + + context_in[torch.arange(B)[:, None, None], + torch.arange(H)[None, :, None], + index, :] = torch.matmul(attn, V).type_as(context_in) + if self.output_attention: + attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device) + attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn + return (context_in, attns) + else: + return (context_in, None) + + def forward(self, queries, keys, values, attn_mask): + B, L_Q, H, D = queries.shape + _, L_K, _, _ = keys.shape + + queries = queries.transpose(2, 1) + keys = keys.transpose(2, 1) + values = values.transpose(2, 1) + + U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k) + u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) + + U_part = U_part if U_part < L_K else L_K + u = u if u < L_Q else L_Q + + scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u) + + # add scale factor + scale = self.scale or 1. / sqrt(D) + if scale is not None: + scores_top = scores_top * scale + # get the context + context = self._get_initial_context(values, L_Q) + # update the context with selected top_k queries + context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask) + + return context.contiguous(), attn + + +class AttentionLayer(nn.Module): + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None): + super(AttentionLayer, self).__init__() + + d_keys = d_keys or (d_model // n_heads) + d_values = d_values or (d_model // n_heads) + + self.inner_attention = attention + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_model, d_keys * n_heads) + self.value_projection = nn.Linear(d_model, d_values * n_heads) + self.out_projection = nn.Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + + def forward(self, queries, keys, values, attn_mask): + B, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + queries = self.query_projection(queries).view(B, L, H, -1) + keys = self.key_projection(keys).view(B, S, H, -1) + values = self.value_projection(values).view(B, S, H, -1) + + out, attn = self.inner_attention( + queries, + keys, + values, + attn_mask + ) + out = out.view(B, L, -1) + + return self.out_projection(out), attn \ No newline at end of file diff --git a/ltsm/layers/Transformer_EncDec.py b/ltsm/layers/Transformer_EncDec.py new file mode 100644 index 0000000..95b4964 --- /dev/null +++ b/ltsm/layers/Transformer_EncDec.py @@ -0,0 +1,132 @@ +# code from https://github.com/yuqinie98/PatchTST, with minor modifications +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvLayer(nn.Module): + def __init__(self, c_in): + super(ConvLayer, self).__init__() + self.downConv = nn.Conv1d(in_channels=c_in, + out_channels=c_in, + kernel_size=3, + padding=2, + padding_mode='circular') + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.downConv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.maxPool(x) + x = 
x.transpose(1, 2) + return x + + +class EncoderLayer(nn.Module): + def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): + super(EncoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.attention = attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attn_mask=None): + new_x, attn = self.attention( + x, x, x, + attn_mask=attn_mask + ) + x = x + self.dropout(new_x) + + y = x = self.norm1(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm2(x + y), attn + + +class Encoder(nn.Module): + def __init__(self, attn_layers, conv_layers=None, norm_layer=None): + super(Encoder, self).__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None + self.norm = norm_layer + + def forward(self, x, attn_mask=None): + # x [B, L, D] + attns = [] + if self.conv_layers is not None: + for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): + x, attn = attn_layer(x, attn_mask=attn_mask) + x = conv_layer(x) + attns.append(attn) + x, attn = self.attn_layers[-1](x) + attns.append(attn) + else: + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask) + attns.append(attn) + + if self.norm is not None: + x = self.norm(x) + + return x, attns + + +class DecoderLayer(nn.Module): + def __init__(self, self_attention, cross_attention, d_model, d_ff=None, + dropout=0.1, activation="relu"): + super(DecoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.self_attention = self_attention + self.cross_attention = cross_attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, cross, x_mask=None, cross_mask=None): + x = x + self.dropout(self.self_attention( + x, x, x, + attn_mask=x_mask + )[0]) + x = self.norm1(x) + + x = x + self.dropout(self.cross_attention( + x, cross, cross, + attn_mask=cross_mask + )[0]) + + y = x = self.norm2(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm3(x + y) + + +class Decoder(nn.Module): + def __init__(self, layers, norm_layer=None, projection=None): + super(Decoder, self).__init__() + self.layers = nn.ModuleList(layers) + self.norm = norm_layer + self.projection = projection + + def forward(self, x, cross, x_mask=None, cross_mask=None): + for layer in self.layers: + x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) + + if self.norm is not None: + x = self.norm(x) + + if self.projection is not None: + x = self.projection(x) + return x \ No newline at end of file diff --git a/ltsm/models/DLinear.py b/ltsm/models/DLinear.py new file mode 100644 index 0000000..e0f2847 --- /dev/null +++ b/ltsm/models/DLinear.py @@ -0,0 +1,94 @@ +# code from https://github.com/yuqinie98/PatchTST, with minor 
modifications +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from transformers import PreTrainedModel +from .ltsm_base import LTSMConfig + +class DLinear(PreTrainedModel): + """ + Decomposition-Linear + """ + config_class = LTSMConfig + + def __init__(self, config: LTSMConfig, **kwargs): + super().__init__(config) + self.seq_len = config.seq_len + self.pred_len = config.pred_len + + # Decompsition Kernel Size + kernel_size = 25 + self.decompsition = series_decomp(kernel_size) + self.individual = config.individual + self.channels = config.enc_in + + if self.individual: + self.Linear_Seasonal = nn.ModuleList() + self.Linear_Trend = nn.ModuleList() + + for i in range(self.channels): + self.Linear_Seasonal.append(nn.Linear(self.seq_len,self.pred_len)) + self.Linear_Trend.append(nn.Linear(self.seq_len,self.pred_len)) + + # Use this two lines if you want to visualize the weights + # self.Linear_Seasonal[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) + # self.Linear_Trend[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) + else: + self.Linear_Seasonal = nn.Linear(self.seq_len,self.pred_len) + self.Linear_Trend = nn.Linear(self.seq_len,self.pred_len) + + # Use this two lines if you want to visualize the weights + # self.Linear_Seasonal.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) + # self.Linear_Trend.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) + + def forward(self, x: Tensor): + # x: [Batch, Input length, Channel] + seasonal_init, trend_init = self.decompsition(x) + seasonal_init, trend_init = seasonal_init.permute(0,2,1), trend_init.permute(0,2,1) + if self.individual: + seasonal_output = torch.zeros([seasonal_init.size(0),seasonal_init.size(1),self.pred_len],dtype=seasonal_init.dtype).to(seasonal_init.device) + trend_output = torch.zeros([trend_init.size(0),trend_init.size(1),self.pred_len],dtype=trend_init.dtype).to(trend_init.device) + for i in range(self.channels): + seasonal_output[:,i,:] = self.Linear_Seasonal[i](seasonal_init[:,i,:]) + trend_output[:,i,:] = self.Linear_Trend[i](trend_init[:,i,:]) + else: + seasonal_output = self.Linear_Seasonal(seasonal_init) + trend_output = self.Linear_Trend(trend_init) + + x = seasonal_output + trend_output + return x.permute(0,2,1) # to [Batch, Output length, Channel] + + +class moving_avg(nn.Module): + """ + Moving average block to highlight the trend of time series + """ + def __init__(self, kernel_size, stride): + super(moving_avg, self).__init__() + self.kernel_size = kernel_size + self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) + + def forward(self, x): + # padding on the both ends of time series + front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) + end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) + x = torch.cat([front, x, end], dim=1) + x = self.avg(x.permute(0, 2, 1)) + x = x.permute(0, 2, 1) + return x + + +class series_decomp(nn.Module): + """ + Series decomposition block + """ + def __init__(self, kernel_size): + super(series_decomp, self).__init__() + self.moving_avg = moving_avg(kernel_size, stride=1) + + def forward(self, x): + moving_mean = self.moving_avg(x) + res = x - moving_mean + return res, moving_mean \ No newline at end of file diff --git a/ltsm/models/Informer.py b/ltsm/models/Informer.py new file mode 100644 index 0000000..5b6c6c8 --- /dev/null +++ 
b/ltsm/models/Informer.py @@ -0,0 +1,106 @@ +# code from https://github.com/yuqinie98/PatchTST, with minor modifications +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F +from ltsm.utils.masking import TriangularCausalMask, ProbMask +from ltsm.layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer +from ltsm.layers.SelfAttention_Family import FullAttention, ProbAttention, AttentionLayer +from ltsm.layers.Embed import DataEmbedding,DataEmbedding_wo_pos,DataEmbedding_wo_temp,DataEmbedding_wo_pos_temp +import numpy as np +from transformers import PreTrainedModel +from .ltsm_base import LTSMConfig + +class Informer(PreTrainedModel): + """ + Informer with Propspare attention in O(LlogL) complexity + """ + config_class = LTSMConfig + + def __init__(self, config: LTSMConfig, **kwargs): + super().__init__(config) + self.pred_len = config.pred_len + self.output_attention = config.output_attention + + # Embedding + if config.embed_type == 0: + self.enc_embedding = DataEmbedding(config.enc_in, config.d_model, config.embed, config.freq, + config.dropout) + self.dec_embedding = DataEmbedding(config.dec_in, config.d_model, config.embed, config.freq, + config.dropout) + elif config.embed_type == 1: + self.enc_embedding = DataEmbedding(config.enc_in, config.d_model, config.embed, config.freq, + config.dropout) + self.dec_embedding = DataEmbedding(config.dec_in, config.d_model, config.embed, config.freq, + config.dropout) + elif config.embed_type == 2: + self.enc_embedding = DataEmbedding_wo_pos(config.enc_in, config.d_model, config.embed, config.freq, + config.dropout) + self.dec_embedding = DataEmbedding_wo_pos(config.dec_in, config.d_model, config.embed, config.freq, + config.dropout) + + elif config.embed_type == 3: + self.enc_embedding = DataEmbedding_wo_temp(config.enc_in, config.d_model, config.embed, config.freq, + config.dropout) + self.dec_embedding = DataEmbedding_wo_temp(config.dec_in, config.d_model, config.embed, config.freq, + config.dropout) + elif config.embed_type == 4: + self.enc_embedding = DataEmbedding_wo_pos_temp(config.enc_in, config.d_model, config.embed, config.freq, + config.dropout) + self.dec_embedding = DataEmbedding_wo_pos_temp(config.dec_in, config.d_model, config.embed, config.freq, + config.dropout) + # Encoder + self.encoder = Encoder( + [ + EncoderLayer( + AttentionLayer( + ProbAttention(False, config.factor, attention_dropout=config.dropout, + output_attention=config.output_attention), + config.d_model, config.n_heads), + config.d_model, + config.d_ff, + dropout=config.dropout, + activation=config.activation + ) for l in range(config.e_layers) + ], + [ + ConvLayer( + config.d_model + ) for l in range(config.e_layers - 1) + ] if config.distil else None, + norm_layer=torch.nn.LayerNorm(config.d_model) + ) + # Decoder + self.decoder = Decoder( + [ + DecoderLayer( + AttentionLayer( + ProbAttention(True, config.factor, attention_dropout=config.dropout, output_attention=False), + config.d_model, config.n_heads), + AttentionLayer( + ProbAttention(False, config.factor, attention_dropout=config.dropout, output_attention=False), + config.d_model, config.n_heads), + config.d_model, + config.d_ff, + dropout=config.dropout, + activation=config.activation, + ) + for l in range(config.d_layers) + ], + norm_layer=torch.nn.LayerNorm(config.d_model), + projection=nn.Linear(config.d_model, config.c_out, bias=True) + ) + + def forward(self, x_enc: Tensor, x_mark_enc: Tensor, x_dec: Tensor, x_mark_dec: Tensor, 
+ enc_self_mask: Tensor=None, dec_self_mask: Tensor=None, dec_enc_mask: Tensor=None): + + enc_out = self.enc_embedding(x_enc, x_mark_enc) + enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) + + dec_out = self.dec_embedding(x_dec, x_mark_dec) + dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) + + if self.output_attention: + return dec_out[:, -self.pred_len:, :], attns + else: + return dec_out[:, -self.pred_len:, :] # [B, L, D] \ No newline at end of file diff --git a/ltsm/models/PatchTST.py b/ltsm/models/PatchTST.py new file mode 100644 index 0000000..971a773 --- /dev/null +++ b/ltsm/models/PatchTST.py @@ -0,0 +1,51 @@ +# code from https://github.com/yuqinie98/PatchTST, with minor modifications +import torch +from torch import Tensor + +from .ltsm_base import LTSMConfig +from ltsm.layers.PatchTST_backbone import PatchTST_backbone +from ltsm.layers.PatchTST_layers import series_decomp +from transformers import PreTrainedModel + +class PatchTST(PreTrainedModel): + config_class = LTSMConfig + + def __init__(self, config: LTSMConfig, **kwargs): + super().__init__(config) + + self.decomposition = config.decomposition + if self.decomposition: + self.decomp_module = series_decomp(config.kernel_size) + self.model_trend = PatchTST_backbone(config.enc_in, + config.seq_len, + config.pred_len, + config.patch_len, + config.stride, + **kwargs) + self.model_res = PatchTST_backbone(config.enc_in, + config.seq_len, + config.pred_len, + config.patch_len, + config.stride, + **kwargs) + else: + self.model = PatchTST_backbone(config.enc_in, + config.seq_len, + config.pred_len, + config.patch_len, + config.stride, + **kwargs) + + def forward(self, x: Tensor): + if self.decomposition: + res_init, trend_init = self.decomp_module(x) + res_init, trend_init = res_init.permute(0, 2, 1), trend_init.permute(0, 2, 1) # [Batch, Channel, Input length] + res = self.model_res(res_init) + trend = self.model_trend(trend_init) + x = res + trend + x = x.permute(0, 2, 1) # [Batch, Input length, Channel] + else: + x = x.permute(0, 2, 1) # [Batch, Channel, Input length] + x = self.model(x) + x = x.permute(0, 2, 1) # [Batch, Input length, Channel] + return x \ No newline at end of file diff --git a/ltsm/models/utils.py b/ltsm/models/utils.py index 93dc1c5..71e1d67 100644 --- a/ltsm/models/utils.py +++ b/ltsm/models/utils.py @@ -133,14 +133,21 @@ def get_model(config): elif config.model == 'LTSM_Tokenizer': from .ltsm_ts_tokenizer import LTSM_Tokenizer model = LTSM_Tokenizer(config) - else: + elif config.model == 'LTSM': from .ltsm_stat_model import LTSM if config.local_pretrain == "None": model = LTSM(config) else: model_config = PretrainedConfig.from_pretrained(config.local_pretrain) model = LTSM.from_pretrained(config.local_pretrain, model_config) - - - return model - + elif config.model == 'PatchTST': + from .PatchTST import PatchTST + model = PatchTST(config) + elif config.model == 'DLinear': + from .DLinear import DLinear + model = DLinear(config) + elif config.model == 'Informer': + from .Informer import Informer + model = Informer(config) + + return model \ No newline at end of file diff --git a/ltsm/utils/masking.py b/ltsm/utils/masking.py new file mode 100644 index 0000000..d238ec1 --- /dev/null +++ b/ltsm/utils/masking.py @@ -0,0 +1,26 @@ +# code from https://github.com/yuqinie98/PatchTST, with minor modifications +import torch + +class TriangularCausalMask(): + def __init__(self, B, L, device="cpu"): + mask_shape = [B, 1, L, L] + with torch.no_grad(): + self._mask = 
torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) + + @property + def mask(self): + return self._mask + + +class ProbMask(): + def __init__(self, B, H, L, index, scores, device="cpu"): + _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) + _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) + indicator = _mask_ex[torch.arange(B)[:, None, None], + torch.arange(H)[None, :, None], + index, :].to(device) + self._mask = indicator.view(scores.shape).to(device) + + @property + def mask(self): + return self._mask \ No newline at end of file diff --git a/tests/data_pipeline/data_pipeline_test.py b/tests/data_pipeline/data_pipeline_test.py index b926332..2a3a13f 100644 --- a/tests/data_pipeline/data_pipeline_test.py +++ b/tests/data_pipeline/data_pipeline_test.py @@ -12,6 +12,7 @@ def mock_args(): #Fixture for creating mock arguments arg_dict = { + 'model': 'LTSM', 'data_path':'./datasets', 'prompt_data_path':'./prompt_bank', 'output_dir': './output', diff --git a/tests/model/DLinear_test.py b/tests/model/DLinear_test.py new file mode 100644 index 0000000..cbdf674 --- /dev/null +++ b/tests/model/DLinear_test.py @@ -0,0 +1,61 @@ +import pytest +from ltsm.models import get_model, LTSMConfig +from transformers import PreTrainedModel +import torch +import numpy as np + +@pytest.fixture +def config(tmp_path): + data_path = tmp_path / "test.csv" + prompt_data_path = tmp_path / "prompt_normalize_split" + prompt_data_path.mkdir() + OUTPUT_PATH = data_path / "output" + + config = { + "data_path": str(data_path), + "model": "DLinear", + "model_name_or_path": "gpt2-medium", + "pred_len": 96, + "gradient_accumulation_steps": 64, + "test_data_path_list": [str(data_path)], + "prompt_data_path": str(prompt_data_path), + "enc_in": 1, + "seq_len": 336+133, # Equal to the sequence length + the length of prompt + "train_epochs": 1000, + "patience": 10, + "lradj": 'TST', + "pct_start": 0.2, + "freeze": 0, + "itr": 1, + "batch_size": 32, + "learning_rate": 1e-3, + "downsample_rate": 20, + "output_dir": str(OUTPUT_PATH), + "eval": 0, + "individual": 0, + } + return LTSMConfig(**config) + +def test_model_initialization(config): + model = get_model(config) + assert model is not None + assert isinstance(model, PreTrainedModel) + + +def test_parameter_count(config): + model = get_model(config) + param_count = sum([p.numel() for p in model.parameters() if p.requires_grad]) + + expected_param_count = 2*(config.seq_len*config.pred_len + config.pred_len) + + assert param_count == expected_param_count + +def test_forward_output_shape(config): + torch.set_default_dtype(torch.float64) + model = get_model(config) + batch_size = 32 + channel = 16 + input_length = config.seq_len + input = torch.tensor(np.zeros((batch_size, input_length, channel))) + output = model(input) + assert output.size() == torch.Size([batch_size, config.pred_len, channel]) \ No newline at end of file diff --git a/tests/model/Informer_test.py b/tests/model/Informer_test.py new file mode 100644 index 0000000..1b3395c --- /dev/null +++ b/tests/model/Informer_test.py @@ -0,0 +1,131 @@ +import pytest +from ltsm.models import get_model, LTSMConfig +from transformers import PreTrainedModel +import torch +import numpy as np + +@pytest.fixture +def config(tmp_path): + data_path = tmp_path / "test.csv" + prompt_data_path = tmp_path / "prompt_normalize_split" + prompt_data_path.mkdir() + OUTPUT_PATH = data_path / "output" + + config = { + "data_path": str(data_path), + "model": "Informer", + 
"model_name_or_path": "gpt2-medium", + "pred_len": 96, + "gradient_accumulation_steps": 64, + "test_data_path_list": [str(data_path)], + "prompt_data_path": str(prompt_data_path), + "enc_in": 1, + "e_layers": 3, + "d_layers": 1, + "n_heads": 16, + "d_model": 128, + "d_ff": 256, + "dropout": 0.2, + "fc_dropout": 0.2, + "head_dropout": 0, + "seq_len": 336+133, # Equal to the sequence length + the length of prompt + "patch_len": 16, + "stride": 8, + "des": 'Exp', + "train_epochs": 1000, + "patience": 10, + "lradj": 'TST', + "pct_start": 0.2, + "freeze": 0, + "itr": 1, + "batch_size": 32, + "learning_rate": 1e-3, + "downsample_rate": 20, + "output_dir": str(OUTPUT_PATH), + "eval": 0, + "fc_dropout": 0.05, + "head_dropout": 0.0, + "patch_len": 16, + "padding_patch": 'end', + "revin": 1, + "affine": 0, + "subtract_last": 0, + "decomposition": 0, + "kernel_size": 25, + "individual": 0, + "output_attention": 0, + "freq": "h", + "embed": "timeF", + "factor": 1, + "c_out": 862, + "distil": True, + "embed_type": 0, + "dec_in": 7, + "activation": "gelu" + } + return LTSMConfig(**config) + +def test_model_initialization(config): + model = get_model(config) + assert model is not None + assert isinstance(model, PreTrainedModel) + +def test_parameter_count(config): + model = get_model(config) + param_count = sum([p.numel() for p in model.parameters() if p.requires_grad]) + + # Encoder Embedding parameter count + expected_param_count = config.d_model*config.enc_in*3 + 4*config.d_model + + # Decoder Embedding parameter count + expected_param_count += config.d_model*config.dec_in*3 + 4*config.d_model + + # Encoder parameter count + # Encoder layer Conv + encoder_param_count = 2*config.d_model*config.d_ff + config.d_model + config.d_ff + # Encoder Layer Norm + encoder_param_count += 4*config.d_model + # Attention Layer + encoder_param_count += 4*(config.d_model*config.d_model + config.d_model) + # Multiply by number of encoder layers + encoder_param_count *= config.e_layers + + # Conv layer + encoder_param_count += (config.e_layers-1)*(config.d_model*config.d_model*3 + 3*config.d_model) + # Layer Norm + encoder_param_count += 2*config.d_model + + expected_param_count += encoder_param_count + + # Decoder layer parameter count + # Decoder Conv layers + decoder_param_count = 2*config.d_model*config.d_ff + config.d_model + config.d_ff + # Decoder Layer Norm + decoder_param_count += 6*config.d_model + # Attention Layer + decoder_param_count += 8*(config.d_model*config.d_model + config.d_model) + # Multiply by number of decoder layers + decoder_param_count *= config.d_layers + + # Layer Norm parameter count + decoder_param_count += 2*config.d_model + + # Projection layer parameter count + decoder_param_count += config.d_model*config.c_out+config.c_out + + expected_param_count += decoder_param_count + + assert param_count == expected_param_count + + +def test_forward_output_shape(config): + torch.set_default_dtype(torch.float64) + model = get_model(config) + batch_size = 32 + input_length = config.seq_len + input = torch.tensor(np.zeros((batch_size, input_length, config.enc_in))) + input_mark = torch.tensor(np.zeros((batch_size, input_length, 4))) + dec_inp = torch.tensor(np.zeros((batch_size, input_length, config.dec_in))) + dec_mark = torch.tensor(np.zeros((batch_size, input_length, 4))) + output = model(input, input_mark, dec_inp, dec_mark) + assert output.size() == torch.Size([batch_size, config.pred_len, config.c_out]) \ No newline at end of file diff --git a/tests/model/PatchTST_test.py 
diff --git a/tests/model/PatchTST_test.py b/tests/model/PatchTST_test.py
new file mode 100644
index 0000000..010a194
--- /dev/null
+++ b/tests/model/PatchTST_test.py
@@ -0,0 +1,97 @@
+import pytest
+from ltsm.models import get_model, LTSMConfig
+from transformers import PreTrainedModel
+import torch
+import numpy as np
+
+@pytest.fixture
+def config(tmp_path):
+    data_path = tmp_path / "test.csv"
+    prompt_data_path = tmp_path / "prompt_normalize_split"
+    prompt_data_path.mkdir()
+    OUTPUT_PATH = data_path / "output"
+
+    config = {
+        "data_path": str(data_path),
+        "model": "PatchTST",
+        "model_name_or_path": "gpt2-medium",
+        "pred_len": 96,
+        "gradient_accumulation_steps": 64,
+        "test_data_path_list": [str(data_path)],
+        "prompt_data_path": str(prompt_data_path),
+        "enc_in": 1,
+        "e_layers": 3,
+        "n_heads": 16,
+        "d_model": 128,
+        "d_ff": 256,
+        "dropout": 0.2,
+        "seq_len": 336+133, # Equal to the sequence length + the length of prompt
+        "patch_len": 16,
+        "stride": 8,
+        "des": 'Exp',
+        "train_epochs": 1000,
+        "patience": 10,
+        "lradj": 'TST',
+        "pct_start": 0.2,
+        "freeze": 0,
+        "itr": 1,
+        "batch_size": 32,
+        "learning_rate": 1e-3,
+        "downsample_rate": 20,
+        "output_dir": str(OUTPUT_PATH),
+        "eval": 0,
+        "fc_dropout": 0.05,
+        "head_dropout": 0.0,
+        "padding_patch": 'end',
+        "revin": 1,
+        "affine": 0,
+        "subtract_last": 0,
+        "decomposition": 0,
+        "kernel_size": 25,
+        "individual": 0,
+    }
+    return LTSMConfig(**config)
+
+def test_model_initialization(config):
+    model = get_model(config)
+    assert model is not None
+    assert isinstance(model, PreTrainedModel)
+
+
+def test_parameter_count(config):
+    model = get_model(config)
+    param_count = sum([p.numel() for p in model.parameters() if p.requires_grad])
+
+    patch_num = int((config.seq_len - config.patch_len) / config.stride + 1)
+    # multi-head self-attention parameter count (W_Q, W_K, W_V, to_out)
+    expected_param_count = 4*(config.d_model * config.d_model + config.d_model)
+    # feed-forward nn parameter count
+    expected_param_count += 2*config.d_model*config.d_ff + config.d_model + config.d_ff
+    # layer norm parameter count
+    expected_param_count += 4*config.d_model
+
+    # multiply by number of encoder layers
+    expected_param_count *= config.e_layers
+
+    # Input encoding parameter count
+    expected_param_count += config.patch_len*config.d_model + config.d_model
+
+    # Positional encoding parameter count
+    expected_param_count += patch_num*config.d_model
+
+    # RevIn parameter count
+    expected_param_count += 2
+
+    # Flatten Head parameter count
+    expected_param_count += config.d_model*patch_num*config.pred_len + config.pred_len
+
+    assert param_count == expected_param_count
+
+def test_forward_output_shape(config):
+    model = get_model(config)
+    batch_size = 32
+    channel = 16
+    input_length = config.seq_len
+    input = torch.tensor(np.zeros((batch_size, input_length, channel))).float()
+    output = model(input)
+    assert output.size() == torch.Size([batch_size, config.pred_len, channel])
\ No newline at end of file
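The PatchTST expectation above hinges on the derived patch count. With the fixture's seq_len = 336+133, patch_len = 16 and stride = 8, the worked numbers are as follows (a sketch; the names are illustrative):

    seq_len, patch_len, stride = 336 + 133, 16, 8
    patch_num = int((seq_len - patch_len) / stride + 1)  # int(57.625) -> 57 patches
    d_model, pred_len = 128, 96
    positional = patch_num * d_model                          # 57 * 128 = 7296
    flatten_head = d_model * patch_num * pred_len + pred_len  # 700416 weights + 96 biases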
+OUTPUT_PATH="output/patchtst_lr${lr}_loraFalse_down${downsample_rate}_freeze${freeze}_e${epoch}_pred${pred_len}/" +echo "Current OUTPUT_PATH: ${OUTPUT_PATH}" + +for pred_len in 96 192 336 720 +do + CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main_ltsm.py \ + --data_path ${TRAIN} \ + --model PatchTST \ + --model_name_or_path gpt2-medium \ + --pred_len ${pred_len} \ + --gradient_accumulation_steps 64 \ + --test_data_path_list ${TEST} \ + --prompt_data_path ${PROMPT} \ + --enc_in 1 \ + --e_layers 3 \ + --n_heads 16 \ + --d_model 128 \ + --d_ff 256 \ + --dropout 0.2\ + --fc_dropout 0.2\ + --head_dropout 0\ + --seq_len 336\ + --patch_len 16\ + --stride 8\ + --des 'Exp' \ + --train_epochs ${epoch}\ + --patience 10\ + --lradj 'TST'\ + --pct_start 0.2\ + --freeze ${freeze} \ + --itr 1 --batch_size 32 --learning_rate ${lr}\ + --downsample_rate ${downsample_rate} \ + --output_dir ${OUTPUT_PATH}\ + --eval 0 +done \ No newline at end of file