ValueError: sampler should be an instance of torch.utils.data.Sampler, #3

Open
bbrattoli opened this issue Mar 11, 2019 · 5 comments

@bbrattoli

Dear vadimkantorov,

Thank you for publishing this nice repo; it's very well written.
I'm running
"python train.py --dataset cub2011 --model margin --base resnet50"
with PyTorch 1.0.1 and Python 3.6, but it crashes with the following error:

Traceback (most recent call last):
  File "train.py", line 71, in <module>
    loader_train = torch.utils.data.DataLoader(dataset_train, sampler = adapt_sampler(opts.batch, dataset_train, opts.sampler), num_workers = opts.threads, batch_size = opts.batch, drop_last = True, pin_memory = True)
  File "/export/home/bbrattol/anaconda2/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 805, in __init__
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
  File "/export/home/bbrattol/anaconda2/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 146, in __init__
    .format(sampler))
ValueError: sampler should be an instance of torch.utils.data.Sampler, but got sampler=<__main__. object at 0x7f199b08a9b0>

I guess it has something to do with the newer PyTorch version. Could you help me get it running?

Thanks
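A minimal reproduction of why the check fails, using only the code quoted in this thread: the original adapt_sampler builds an instance of a class named '' whose only base is object. It defines __iter__ (and __len__), so it behaves like a sampler, but it is not a torch.utils.data.Sampler instance, and that is exactly what BatchSampler checks in the traceback. The empty class name also explains the odd `<__main__. object at ...>` in the error message. `toy_batches` is a hypothetical stand-in for opts.sampler, assumed to return a list of index batches.

```python
import itertools


def toy_batches(batch, dataset):
    """Hypothetical stand-in for opts.sampler: consecutive chunks of indices."""
    indices = list(range(len(dataset)))
    return [indices[i:i + batch] for i in range(0, len(indices), batch)]


# The original one-liner from train.py, as quoted in this thread.
adapt_sampler = lambda batch, dataset, sampler, **kwargs: type('', (), dict(
    __len__=dataset.__len__,
    __iter__=lambda _: itertools.chain.from_iterable(sampler(batch, dataset, **kwargs))))()

s = adapt_sampler(2, list('abcde'), toy_batches)
print(repr(s))          # e.g. <__main__. object at 0x...>: the class name is ''
print(type(s).__mro__)  # the anonymous class and object only; no Sampler base
print(list(s))          # [0, 1, 2, 3, 4]: it still iterates like a sampler
```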

@vadimkantorov
Owner

Could you try to modify:

adapt_sampler = lambda batch, dataset, sampler, **kwargs: type('', (), dict(__len__ = dataset.__len__, __iter__ = lambda _: itertools.chain.from_iterable(sampler(batch, dataset, **kwargs))))()

to read instead:

adapt_sampler = lambda batch, dataset, sampler, **kwargs: type('', (torch.utils.data.Sampler,), dict(__len__ = dataset.__len__, __iter__ = lambda _: itertools.chain.from_iterable(sampler(batch, dataset, **kwargs))))()

?

Please let me know if it works and don't hesitate to send a PR.

@bbrattoli
Author

Now I get this error:

Traceback (most recent call last):
  File "train.py", line 77, in <module>
    loader_train = torch.utils.data.DataLoader(dataset_train, sampler = adapt_sampler(opts.batch, dataset_train, opts.sampler), num_workers = opts.threads, batch_size = opts.batch, drop_last = True, pin_memory = True)
  File "train.py", line 71, in <lambda>
    dict(__len__ = dataset.__len__, __iter__ = lambda _: itertools.chain.from_iterable(
TypeError: type.__new__() argument 2 must be tuple, not type

I know it's just a missing parameter, but I don't understand what this piece of code does, so please help! :D
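For anyone puzzled by the same one-liner: the three-argument form type(name, bases, namespace) creates a class at runtime, the trailing () instantiates it, and `lambda _:` is simply a method whose unused self parameter is named `_`. The sketch below is a simplified stand-in, not the repo's code: `batches` is a hypothetical hard-coded list in place of whatever sampler(batch, dataset) returns, and __len__ is hard-coded instead of using dataset.__len__. It shows the dynamic class next to an equivalent ordinary class statement.

```python
import itertools

# Hypothetical stand-in for what sampler(batch, dataset) returns: a list of index batches.
batches = [[0, 1], [2, 3]]

# type(name, bases, namespace) creates a class at runtime; the trailing () makes an instance.
# The name is '' and bases is the empty tuple (), so the only base class is object.
# `lambda _: ...` is a method whose `self` argument is named `_` and ignored.
obj = type('', (), dict(
    __len__=lambda _: 4,
    __iter__=lambda _: itertools.chain.from_iterable(batches),  # flattens [[0, 1], [2, 3]]
))()


# The same thing written as an ordinary class statement.
class Flattened(object):
    def __len__(self):
        return 4

    def __iter__(self):
        return itertools.chain.from_iterable(batches)


print(len(obj), list(obj))                  # 4 [0, 1, 2, 3]
print(len(Flattened()), list(Flattened()))  # 4 [0, 1, 2, 3]
```

itertools.chain.from_iterable turns the batches into one flat stream of indices; DataLoader (called with batch_size = opts.batch) then regroups them into batches.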

@vadimkantorov
Owner

vadimkantorov commented Mar 11, 2019

Just checking: are you using (torch.utils.data.Sampler,) and not (torch.utils.data.Sampler)? The only difference is the trailing comma, but it's important.
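Why the comma matters, as a small sketch: parentheses alone only group an expression, and it is the trailing comma that makes a 1-tuple. type() requires its second argument (the bases) to be a tuple, so passing a bare class raises the TypeError from the traceback above. Here `Base = object` is a stand-in for torch.utils.data.Sampler; any class behaves the same way. Once the comma is in place, the bases argument is a tuple, so this particular TypeError should no longer come from that call.

```python
Base = object  # stand-in for torch.utils.data.Sampler; any class behaves the same way

print(type((Base)))   # <class 'type'>: the parentheses are only grouping
print(type((Base,)))  # <class 'tuple'>: the trailing comma makes a 1-tuple

error = None
try:
    type('', (Base), {})  # bases passed as a bare class
except TypeError as exc:
    error = exc
    print('TypeError:', exc)  # the same kind of error as in the traceback above

Fixed = type('', (Base,), {})  # bases passed as a 1-tuple
print(Fixed.__bases__)  # (<class 'object'>,)
```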

@bbrattoli
Author

It gives me the same error in both cases.

@vadimkantorov
Owner

vadimkantorov commented Mar 12, 2019

Sorry @bbrattoli, I don't have time to look at this in detail these days. I'll post an update here once I've checked what's going on. Meanwhile, the way to go is to define your own Sampler subclass instead of relying on the hacky dynamic class creation in my adapt_sampler.
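A minimal sketch of such a subclass, assuming (as the adapt_sampler lambda suggests) that opts.sampler is a callable taking (batch, dataset, **kwargs) and returning an iterable of index batches. The class name FlattenedBatchSampler and the helper toy_batches are hypothetical. It keeps adapt_sampler's __iter__ and __len__ behaviour and only adds a proper Sampler base, so BatchSampler's isinstance check passes. It deliberately does not call Sampler.__init__, because that signature has differed across PyTorch versions (some expect a data_source argument). The guarded import is only there so the sketch also runs where PyTorch is not installed.

```python
import itertools

try:
    from torch.utils.data import Sampler
except ImportError:  # fallback so the sketch also runs without PyTorch installed
    Sampler = object


class FlattenedBatchSampler(Sampler):
    """Hypothetical named replacement for adapt_sampler.

    batch_sampler_fn(batch, dataset, **kwargs) is assumed to return an iterable of
    index batches; this sampler yields those indices one at a time.
    """

    def __init__(self, batch, dataset, batch_sampler_fn, **kwargs):
        # Sampler.__init__ is intentionally not called (its signature varies by version).
        self.batch = batch
        self.dataset = dataset
        self.batch_sampler_fn = batch_sampler_fn
        self.kwargs = kwargs

    def __iter__(self):
        batches = self.batch_sampler_fn(self.batch, self.dataset, **self.kwargs)
        return itertools.chain.from_iterable(batches)

    def __len__(self):
        # Mirrors adapt_sampler, which reports the dataset length.
        return len(self.dataset)


def toy_batches(batch, dataset):
    """Hypothetical stand-in for opts.sampler: consecutive chunks of indices."""
    indices = list(range(len(dataset)))
    return [indices[i:i + batch] for i in range(0, len(indices), batch)]


sampler = FlattenedBatchSampler(2, list('abcde'), toy_batches)
print(len(sampler), list(sampler))  # 5 [0, 1, 2, 3, 4]
```

Under these assumptions, the DataLoader call in train.py would stay the same except that it receives sampler = FlattenedBatchSampler(opts.batch, dataset_train, opts.sampler) in place of the adapt_sampler(...) call.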
