Skip to content

Commit

Permalink
fix: pytest errors with double and float tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
joshhan619 committed Dec 3, 2024
1 parent aaed220 commit 56787cf
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion ltsm/data_pipeline/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def get_args():
parser.add_argument('--pretrain', type=int, default=1, help='is pretrain')
parser.add_argument('--local_pretrain', type=str, default="None", help='local pretrain weight')
parser.add_argument('--freeze', type=int, default=1, help='is model weight frozen')
parser.add_argument('--model', type=str, default='model', help='model name, , options:[LTSM, LTSM_WordPrompt, LTSM_Tokenizer]')
parser.add_argument('--model', type=str, default='model', help='model name, options:[LTSM, LTSM_WordPrompt, LTSM_Tokenizer, DLinear, PatchTST, Informer]')
parser.add_argument('--stride', type=int, default=8, help='stride')
parser.add_argument('--tmax', type=int, default=10, help='tmax')
parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
Expand Down
1 change: 1 addition & 0 deletions tests/data_pipeline/data_pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
def mock_args():
#Fixture for creating mock arguments
arg_dict = {
'model': 'LTSM',
'data_path':'./datasets',
'prompt_data_path':'./prompt_bank',
'output_dir': './output',
Expand Down
3 changes: 2 additions & 1 deletion tests/model/DLinear_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ def test_parameter_count(config):
assert param_count == expected_param_count

def test_forward_output_shape(config):
torch.set_default_dtype(torch.float64)
model = get_model(config)
batch_size = 32
channel = 16
input_length = config.seq_len
input = torch.tensor(np.zeros((batch_size, input_length, channel))).float()
input = torch.tensor(np.zeros((batch_size, input_length, channel)))
output = model(input)
assert output.size() == torch.Size([batch_size, config.pred_len, channel])
9 changes: 5 additions & 4 deletions tests/model/Informer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,13 @@ def test_parameter_count(config):


def test_forward_output_shape(config):
torch.set_default_dtype(torch.float64)
model = get_model(config)
batch_size = 32
input_length = config.seq_len
input = torch.tensor(np.zeros((batch_size, input_length, config.enc_in))).float()
input_mark = torch.tensor(np.zeros((batch_size, input_length, 4))).float()
dec_inp = torch.tensor(np.zeros((batch_size, input_length, config.dec_in))).float()
dec_mark = torch.tensor(np.zeros((batch_size, input_length, 4))).float()
input = torch.tensor(np.zeros((batch_size, input_length, config.enc_in)))
input_mark = torch.tensor(np.zeros((batch_size, input_length, 4)))
dec_inp = torch.tensor(np.zeros((batch_size, input_length, config.dec_in)))
dec_mark = torch.tensor(np.zeros((batch_size, input_length, 4)))
output = model(input, input_mark, dec_inp, dec_mark)
assert output.size() == torch.Size([batch_size, config.pred_len, config.c_out])

0 comments on commit 56787cf

Please sign in to comment.