diff --git a/bigcn.py b/bigcn.py index 9a7e35f..d96f906 100644 --- a/bigcn.py +++ b/bigcn.py @@ -84,7 +84,7 @@ def forward(self, data, x, adj, x_t, adj_t, clustering): return res -def run_model(input_data, params=None, clustering=False, verbose=False, device=False): +def run_model(input_data, params=None, clustering=False, verbose=False, device=None): """Run model input_data: gene expression matrix @@ -92,6 +92,9 @@ def run_model(input_data, params=None, clustering=False, verbose=False, device=F clustering: whether to add batch normalized data """ + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + params = { "dropout1": 0.3, "dropout2": 0.1, @@ -114,8 +117,6 @@ def run_model(input_data, params=None, clustering=False, verbose=False, device=F "optimizer": "Adam", "clustering": True, } - if device: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") x, adj = get_data(input_data) x_t, adj_t = get_data(input_data.T)