Skip to content

Commit

Permalink
modify device variable
Browse files Browse the repository at this point in the history
  • Loading branch information
inoue0426 committed Jun 15, 2024
1 parent 3572e13 commit 3df466a
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions bigcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,17 @@ def forward(self, data, x, adj, x_t, adj_t, clustering):
return res


def run_model(input_data, params=None, clustering=False, verbose=False, device=False):
def run_model(input_data, params=None, clustering=False, verbose=False, device=None):
"""Run model
input_data: gene expression matrix
params: hyperparameters
clustering: whether to add batch normalized data
"""

if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

params = {
"dropout1": 0.3,
"dropout2": 0.1,
Expand All @@ -114,8 +117,6 @@ def run_model(input_data, params=None, clustering=False, verbose=False, device=F
"optimizer": "Adam",
"clustering": True,
}
if device:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x, adj = get_data(input_data)
x_t, adj_t = get_data(input_data.T)
Expand Down

0 comments on commit 3df466a

Please sign in to comment.