Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
Merge pull request #132 from Borda/ci/precommit
Browse files Browse the repository at this point in the history
ci: update precommit
  • Loading branch information
justusschock authored Feb 22, 2023
2 parents c82dc00 + 8bd22f1 commit 3d459fb
Show file tree
Hide file tree
Showing 12 changed files with 67 additions and 73 deletions.
16 changes: 8 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,26 @@ repos:
- id: check-docstring-first
- id: detect-private-key

#- repo: https://github.com/asottile/pyupgrade
# rev: v2.23.1
# hooks:
# - id: pyupgrade
# args: [--py36-plus]
# name: Upgrade code
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
hooks:
- id: pyupgrade
args: [--py38-plus]
name: Upgrade code

- repo: https://github.com/asottile/yesqa
rev: v1.4.0
hooks:
- id: yesqa

- repo: https://github.com/psf/black
rev: 22.12.0
rev: 23.1.0
hooks:
- id: black
name: Format code

- repo: https://github.com/PyCQA/isort
rev: 5.11.4
rev: 5.12.0
hooks:
- id: isort
name: imports
Expand Down
23 changes: 11 additions & 12 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
Expand Down Expand Up @@ -298,7 +297,7 @@

nbsphinx_kernel_name = "python3"

github_path = r"https://github.com/%s/%s/blob/master/notebooks/{{ env.doc2path(env.docname, base=None) }}" % (
github_path = r"https://github.com/{}/{}/blob/master/notebooks/{{{{ env.doc2path(env.docname, base=None) }}}}".format(
github_user,
github_repo,
)
Expand All @@ -321,21 +320,21 @@
.. raw:: html
<div class="pytorch-call-to-action-links">
<a href="%s">
<a href="{}">
<div id="google-colab-link">
<img class="call-to-action-img" src="_static/images/pytorch-colab.svg"/>
<div class="call-to-action-desktop-view">Run in Google Colab</div>
<div class="call-to-action-mobile-view">Colab</div>
</div>
</a>
<a href="%s" download>
<a href="{}" download>
<div id="download-notebook-link">
<img class="call-to-action-notebook-img" src="_static/images/pytorch-download.svg"/>
<div class="call-to-action-desktop-view">Download Notebook</div>
<div class="call-to-action-mobile-view">Notebook</div>
</div>
</a>
<a href="%s">
<a href="{}">
<div id="github-view-link">
<img class="call-to-action-img" src="_static/images/pytorch-github.svg"/>
<div class="call-to-action-desktop-view">View on GitHub</div>
Expand All @@ -344,7 +343,7 @@
</a>
</div>
""" % (
""".format(
colab_path,
r"{{ env.doc2path(env.docname, base=None) }}",
github_path,
Expand All @@ -356,22 +355,22 @@
.. raw:: html
<div class="pytorch-call-to-action-links">
<a href="%s">
<a href="{}">
<div id="google-colab-link">
<img class="call-to-action-img" src="_static/images/pytorch-colab.svg"/>
<div class="call-to-action-desktop-view">Run in Google Colab</div>
<div class="call-to-action-mobile-view">Colab</div>
</div>
</a>
<a href="%s">
<a href="{}">
<div id="github-view-link">
<img class="call-to-action-img" src="_static/images/pytorch-github.svg"/>
<div class="call-to-action-desktop-view">View on GitHub</div>
<div class="call-to-action-mobile-view">GitHub</div>
</div>
</a>
</div>
""" % (
""".format(
colab_path,
github_path,
)
Expand All @@ -389,14 +388,14 @@
# https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule

MOCK_REQUIRE_PACKAGES = []
with open(os.path.join(PATH_ROOT, "requirements", "install.txt"), "r") as fp:
with open(os.path.join(PATH_ROOT, "requirements", "install.txt")) as fp:
for ln in fp.readlines():
found = [ln.index(ch) for ch in list(",=<>#") if ch in ln]
pkg = ln[: min(found)] if found else ln
if pkg.rstrip():
MOCK_REQUIRE_PACKAGES.append(pkg.rstrip())

with open(os.path.join(PATH_ROOT, "requirements", "install_async.txt"), "r") as fp:
with open(os.path.join(PATH_ROOT, "requirements", "install_async.txt")) as fp:
for ln in fp.readlines():
found = [ln.index(ch) for ch in list(",=<>#") if ch in ln]
pkg = ln[: min(found)] if found else ln
Expand Down Expand Up @@ -443,7 +442,7 @@ def find_source():
# do mapping from latest tags to master
branch = {"latest": "master", "stable": "master"}.get(branch, branch)
filename = "/".join([branch] + filename.split("/")[1:])
return "https://github.com/%s/%s/blob/%s" % (github_user, github_repo, filename)
return f"https://github.com/{github_user}/{github_repo}/blob/{filename}"


autodoc_member_order = "groupwise"
Expand Down
18 changes: 9 additions & 9 deletions rising/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
[c] + args, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=(subprocess.PIPE if hide_stderr else None)
)
break
except EnvironmentError:
except OSError:
e = sys.exc_info()[1]
if e.errno == errno.ENOENT:
continue
Expand All @@ -90,7 +90,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
return None, None
else:
if verbose:
print("unable to find command, tried %s" % (commands,))
print(f"unable to find command, tried {commands}")
return None, None
stdout = p.communicate()[0].strip()
if sys.version_info[0] >= 3:
Expand Down Expand Up @@ -127,7 +127,7 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
root = os.path.dirname(root) # up a level

if verbose:
print("Tried directories %s but none started with prefix %s" % (str(rootdirs), parentdir_prefix))
print(f"Tried directories {str(rootdirs)} but none started with prefix {parentdir_prefix}")
raise NotThisMethod("rootdir doesn't start with parentdir_prefix")


Expand All @@ -140,7 +140,7 @@ def git_get_keywords(versionfile_abs):
# _version.py.
keywords = {}
try:
f = open(versionfile_abs, "r")
f = open(versionfile_abs)
for line in f.readlines():
if line.strip().startswith("git_refnames ="):
mo = re.search(r'=\s*"(.*)"', line)
Expand All @@ -155,7 +155,7 @@ def git_get_keywords(versionfile_abs):
if mo:
keywords["date"] = mo.group(1)
f.close()
except EnvironmentError:
except OSError:
pass
return keywords

Expand All @@ -179,11 +179,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
if verbose:
print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
refs = set([r.strip() for r in refnames.strip("()").split(",")])
refs = {r.strip() for r in refnames.strip("()").split(",")}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)])
tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)}
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
Expand All @@ -192,7 +192,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
tags = set([r for r in refs if re.search(r"\d", r)])
tags = {r for r in refs if re.search(r"\d", r)}
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
Expand Down Expand Up @@ -285,7 +285,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if verbose:
fmt = "tag '%s' doesn't start with prefix '%s'"
print(fmt % (full_tag, tag_prefix))
pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (full_tag, tag_prefix)
pieces["error"] = f"tag '{full_tag}' doesn't start with prefix '{tag_prefix}'"
return pieces
pieces["closest-tag"] = full_tag[len(tag_prefix) :]

Expand Down
2 changes: 1 addition & 1 deletion rising/interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
class AbstractMixin(object):
class AbstractMixin:
"""
This class implements an interface which handles non processed arguments.
Subclass all classes which mixin additional methods and attributes
Expand Down
8 changes: 4 additions & 4 deletions rising/loading/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from contextlib import contextmanager
from functools import partial
from typing import Any, Callable, Generator, Iterator, Mapping, Optional, Sequence, Union
from typing import Any, Callable, Generator, Iterator, Mapping, Optional, Protocol, Sequence, Union, runtime_checkable

import torch
from threadpoolctl import threadpool_limits
Expand All @@ -11,7 +11,7 @@
from torch.utils.data._utils.collate import default_convert
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter as __MultiProcessingDataLoaderIter
from torch.utils.data.dataloader import _SingleProcessDataLoaderIter as __SingleProcessDataLoaderIter
from typing_extensions import Protocol, Self, runtime_checkable
from typing_extensions import Self

try:
import numpy as np
Expand Down Expand Up @@ -370,7 +370,7 @@ def patch_dataset(loader: DataLoader) -> Generator:
loader._DataLoader__initialized = True


class BatchTransformer(object):
class BatchTransformer:
"""
A callable wrapping the collate_fn to enable transformations on a
batch-basis.
Expand Down Expand Up @@ -424,7 +424,7 @@ def __call__(self, *args, **kwargs) -> Any:
return batch


class SampleTransformer(object):
class SampleTransformer:
"""
A dataset wrapper applying transforms to each retrieved sample of the
dataset
Expand Down
3 changes: 1 addition & 2 deletions rising/transforms/functional/intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,7 @@ def random_inversion(
minv: float = 0.0,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:

if torch.rand((1)) < prob_inversion:
if torch.rand(1) < prob_inversion:
# Inversion of curve
out = maxv + minv - data
else:
Expand Down
5 changes: 1 addition & 4 deletions rising/transforms/functional/painting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
def local_pixel_shuffle(
data: torch.Tensor, n: int = -1, block_size: tuple = (0, 0, 0), rel_block_size: float = 0.1
) -> torch.Tensor:

batch_size, channels, img_rows, img_cols, img_deps = data.size()

if n < 0:
Expand Down Expand Up @@ -42,10 +41,9 @@ def local_pixel_shuffle(


def random_inpainting(data: torch.Tensor, n: int = 5, maxv: float = 1.0, minv: float = 0.0) -> torch.Tensor:

batch_size, channels, img_rows, img_cols, img_deps = data.size()

while n > 0 and torch.rand((1)) < 0.95:
while n > 0 and torch.rand(1) < 0.95:
for b in range(batch_size):
block_size_x = torch.randint(img_rows // 10, img_rows // 4, (1,))
block_size_y = torch.randint(img_rows // 10, img_rows // 4, (1,))
Expand All @@ -64,7 +62,6 @@ def random_inpainting(data: torch.Tensor, n: int = 5, maxv: float = 1.0, minv: f


def random_outpainting(data: torch.Tensor, maxv: float = 1.0, minv: float = 0.0) -> torch.Tensor:

batch_size, channels, img_rows, img_cols, img_deps = data.size()

out = torch.rand(data.size()) * (maxv - minv) + minv
Expand Down
8 changes: 4 additions & 4 deletions rising/transforms/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
padding_mode: str = "zero",
keys: Sequence = ("data",),
grad: bool = False,
**kwargs
**kwargs,
):
"""
Args:
Expand Down Expand Up @@ -85,7 +85,7 @@ def get_conv(dim) -> Callable:
elif dim == 3:
return torch.nn.functional.conv3d
else:
raise TypeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim))
raise TypeError(f"Only 1, 2 and 3 dimensions are supported. Received {dim}.")

def create_kernel(self) -> torch.Tensor:
"""
Expand Down Expand Up @@ -130,7 +130,7 @@ def __init__(
padding_mode: str = "reflect",
keys: Sequence = ("data",),
grad: bool = False,
**kwargs
**kwargs,
):
"""
Args:
Expand Down Expand Up @@ -161,7 +161,7 @@ def __init__(
padding_mode=padding_mode,
keys=keys,
grad=grad,
**kwargs
**kwargs,
)

def create_kernel(self) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion tests/loading/test_collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_numpy_collate_bool(self):
def test_numpy_collate_str(self):
# Should be a no-op
arr = ["a", "b", "c"]
self.assertTrue((arr == numpy_collate(arr)))
self.assertTrue(arr == numpy_collate(arr))

@unittest.skipIf(np is None, "numpy is not available")
def test_numpy_collate_ndarray(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/loading/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def test_pseudo_batch_dim_sequence(self):
self.assertTrue(torch.allclose(unbatched[idx], input_sequence[idx]))

def test_pseudo_batch_dim_custom_obj(self):
class Foo(object):
class Foo:
self.bar = 5.0

transformer = SampleTransformer(self.dset)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from rising import AbstractMixin


class Abstract(object):
class Abstract:
def __init__(self, **kwargs):
super().__init__()
self.abstract = True


class AbstractForward(object):
class AbstractForward:
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.abstract = True
Expand Down
Loading

0 comments on commit 3d459fb

Please sign in to comment.