Commit: Merge pull request #25 from joshhan619/ltsm-stack
Baseline model implementation and unit tests: code looks good to me
Showing 17 changed files with 1,692 additions and 8 deletions.
# code from https://github.com/yuqinie98/PatchTST, with minor modifications
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import math
class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.requires_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Only the sequence length of x is read; returns (1, seq_len, d_model).
        return self.pe[:, :x.size(1)]
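A minimal usage sketch (illustrative, not part of the diff; batch size, sequence length, and d_model are assumed values):

    pe = PositionalEmbedding(d_model=512)
    x = torch.randn(32, 96, 512)   # (batch, seq_len, d_model); values are ignored
    out = pe(x)                    # -> (1, 96, 512), broadcastable over the batch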
class TokenEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        padding = 1 if torch.__version__ >= '1.5.0' else 2
        # Circular 1-D convolution along the time axis, mapping c_in input
        # variates to d_model channels at each time step.
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        # (batch, seq_len, c_in) -> (batch, seq_len, d_model)
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x
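A usage sketch (illustrative; shapes assumed):

    tok = TokenEmbedding(c_in=7, d_model=512)
    x = torch.randn(32, 96, 7)     # (batch, seq_len, n_variates)
    out = tok(x)                   # -> (32, 96, 512)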
class FixedEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        # Sinusoidal table used as a frozen lookup embedding.
        w = torch.zeros(c_in, d_model).float()
        w.requires_grad = False

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        w[:, 1::2] = torch.cos(position * div_term)

        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        return self.emb(x).detach()
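Usage sketch (illustrative):

    emb = FixedEmbedding(c_in=24, d_model=512)
    idx = torch.randint(0, 24, (32, 96))   # integer category ids
    out = emb(idx)                         # -> (32, 96, 512), detached from autograd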
class TemporalEmbedding(nn.Module):
    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super(TemporalEmbedding, self).__init__()

        minute_size = 4    # assumes minute marks arrive pre-bucketed (e.g. 15-minute steps)
        hour_size = 24
        weekday_size = 7
        day_size = 32
        month_size = 13

        Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
        if freq == 't':
            self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        x = x.long()

        # Expected column layout of x: [month, day, weekday, hour, (minute)].
        minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, 'minute_embed') else 0.
        hour_x = self.hour_embed(x[:, :, 3])
        weekday_x = self.weekday_embed(x[:, :, 2])
        day_x = self.day_embed(x[:, :, 1])
        month_x = self.month_embed(x[:, :, 0])

        return hour_x + weekday_x + day_x + month_x + minute_x
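Usage sketch (illustrative; with freq='h' only the first four mark columns are read):

    temp = TemporalEmbedding(d_model=512, embed_type='fixed', freq='h')
    x_mark = torch.stack([
        torch.randint(1, 13, (32, 96)),   # month
        torch.randint(1, 32, (32, 96)),   # day of month
        torch.randint(0, 7, (32, 96)),    # weekday
        torch.randint(0, 24, (32, 96)),   # hour
    ], dim=-1)
    out = temp(x_mark)                    # -> (32, 96, 512)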
class TimeFeatureEmbedding(nn.Module):
    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()

        # Number of real-valued time features per step for each frequency.
        freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        d_inp = freq_map[freq]
        self.embed = nn.Linear(d_inp, d_model, bias=False)

    def forward(self, x):
        return self.embed(x)
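Usage sketch (illustrative):

    tfe = TimeFeatureEmbedding(d_model=512, freq='h')   # freq_map['h'] == 4 input features
    x_mark = torch.randn(32, 96, 4)                     # real-valued time features
    out = tfe(x_mark)                                   # -> (32, 96, 512)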
class DataEmbedding(nn.Module):
    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
                                                    freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
            d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
        return self.dropout(x)
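Usage sketch (illustrative; with embed_type='timeF' and freq='h', x_mark carries 4 real-valued features per step):

    emb = DataEmbedding(c_in=7, d_model=512, embed_type='timeF', freq='h')
    x = torch.randn(32, 96, 7)        # raw series values
    x_mark = torch.randn(32, 96, 4)   # time features
    out = emb(x, x_mark)              # -> (32, 96, 512)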
class DataEmbedding_wo_pos(nn.Module):
    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding_wo_pos, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        # Note: position_embedding is constructed but not used in forward().
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
                                                    freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
            d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        x = self.value_embedding(x) + self.temporal_embedding(x_mark)
        return self.dropout(x)
class DataEmbedding_wo_pos_temp(nn.Module):
    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding_wo_pos_temp, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        # Note: position_embedding and temporal_embedding are constructed but
        # not used in forward().
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
                                                    freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
            d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        x = self.value_embedding(x)
        return self.dropout(x)
class DataEmbedding_wo_temp(nn.Module):
    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding_wo_temp, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        # Note: temporal_embedding is constructed but not used in forward().
        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
                                                    freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
            d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        x = self.value_embedding(x) + self.position_embedding(x)
        return self.dropout(x)
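A quick comparison of the four variants, plus one usage sketch (illustrative; shapes assumed):

    # Terms summed before dropout:
    #   DataEmbedding              value + temporal + positional
    #   DataEmbedding_wo_pos       value + temporal
    #   DataEmbedding_wo_pos_temp  value only
    #   DataEmbedding_wo_temp      value + positional
    emb = DataEmbedding_wo_temp(c_in=7, d_model=512)
    x = torch.randn(32, 96, 7)
    x_mark = torch.randn(32, 96, 4)
    out = emb(x, x_mark)   # -> (32, 96, 512); x_mark is accepted but ignored here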