diff --git a/openstl/modules/wast_modules.py b/openstl/modules/wast_modules.py index ca76b504..14b1d928 100644 --- a/openstl/modules/wast_modules.py +++ b/openstl/modules/wast_modules.py @@ -7,7 +7,7 @@ from timm.models._efficientnet_blocks import SqueezeExcite, InvertedResidual # version adaptation for PyTorch > 1.7.1 -IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.'))) > (1, 7, 1) +IS_HIGH_VERSION = tuple(map(lambda x: int(float(x)), torch.__version__.split('+')[0].split('.')[0:3])) > (1, 7, 1) if IS_HIGH_VERSION: import torch.fft