diff --git a/multiviewdata/torchdatasets/mediamill.py b/multiviewdata/torchdatasets/mediamill.py index 850edc4..9b248d7 100644 --- a/multiviewdata/torchdatasets/mediamill.py +++ b/multiviewdata/torchdatasets/mediamill.py @@ -46,7 +46,6 @@ def convert_target(target_str): train_file, valid_file, test_file = [os.path.join(root, 'mediamill_' + ds + '.libsvm') for ds in ['train', 'valid', 'test']] - print() @property def raw_folder(self) -> str: diff --git a/multiviewdata/torchdatasets/twitter.py b/multiviewdata/torchdatasets/twitter.py index ef0cb98..7ff2167 100644 --- a/multiviewdata/torchdatasets/twitter.py +++ b/multiviewdata/torchdatasets/twitter.py @@ -109,7 +109,7 @@ def raw_folder(self) -> str: return os.path.join(self.root, self.__class__.__name__, "raw") def __getitem__(self, index): - batch = {"index": index.astype(np.float32)} + batch = {"index": index} batch["userid"] = self.ids[index].astype(np.float32) batch["views"] = [view[index].astype(np.float32) for view in self.views] return batch diff --git a/multiviewdata/torchdatasets/wiw.py b/multiviewdata/torchdatasets/wiw.py index 9f55866..f1bd276 100644 --- a/multiviewdata/torchdatasets/wiw.py +++ b/multiviewdata/torchdatasets/wiw.py @@ -71,7 +71,7 @@ def __len__(self): return len(self.dataset) def __getitem__(self, index): - batch = {"index": index.astype(np.float32)} + batch = {"index": index} batch["views"] = [ self.dataset["%06d" % index][feat + "_feats"][()].astype(np.float32) for feat in self.feats diff --git a/multiviewdata/torchdatasets/xrmb.py b/multiviewdata/torchdatasets/xrmb.py index e64db31..4dd3e7c 100644 --- a/multiviewdata/torchdatasets/xrmb.py +++ b/multiviewdata/torchdatasets/xrmb.py @@ -84,7 +84,7 @@ def __getitem__(self, index): self.dataset["view_1"][index].astype(np.float32), self.dataset["view_2"][index].astype(np.float32), ), - "index": index.astype(np.float32), + "index": index, } def _check_exists(self) -> bool: