diff --git a/tests/data/io/test_batch_files.py b/tests/data/io/test_batch_files.py
new file mode 100644
index 0000000..920693a
--- /dev/null
+++ b/tests/data/io/test_batch_files.py
@@ -0,0 +1,34 @@
+from yann.data.storage.batch_files import BatchWriter
+import torch
+import numpy as np
+import pytest
+
+
+def test_batch_write_kwargs():
+  with BatchWriter('asfasd') as write:
+    for n in range(20):
+      write.batch(
+        ids=list(range(10)),
+        targets=np.random.randn(10),
+        outputs=torch.rand(10, 12),
+      )
+
+
+def test_batch_write_args():
+  with BatchWriter('asfasd', names=('id', 'target', 'output')) as write:
+    for n in range(20):
+      write.batch(
+        list(range(10)),
+        np.random.randn(10),
+        torch.rand(10, 12),
+      )
+
+  with pytest.raises(ValueError, match='names and encoders must be the same length'):
+    bw = BatchWriter('asfsd', names=(1, 2, 3), encoders=(1, 2))
+
+
+def test_meta():
+  BatchWriter(path=lambda x: 'foo', meta={
+    'checkpoint_id': 'asfads',
+    'dataset': 'MNIST'
+  })
diff --git a/tests/data/storage/__init__.py b/tests/data/storage/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/data/storage/test_batch_files.py b/tests/data/storage/test_batch_files.py
new file mode 100644
index 0000000..05a9e9d
--- /dev/null
+++ b/tests/data/storage/test_batch_files.py
@@ -0,0 +1,66 @@
+from yann.data.storage.batch_files import BatchWriter, PartitionedBatchWriter, BatchReader
+import numpy as np
+import torch
+import pathlib
+
+import yann
+
+
+def same_type(*items):
+  return all(type(x) == type(items[0]) for x in items)
+
+
+def test_pickle(tmp_path: pathlib.Path):
+  path = tmp_path / 'batches.pkl'
+  w = BatchWriter(path, names=('ids', 'targets', 'outputs', 'paths'))
+
+  batches = []
+
+  for i in range(10):
+    batches.append((
+      list(range(10)),
+      torch.zeros(10, 12),
+      torch.rand(10, 12),
+      [f"{i}-{n}.jpg" for n in range(10)]
+    ))
+
+    w.batch(*batches[-1])
+
+  w.close()
+
+  assert path.exists()
+  # assert path.stat().st_size > 400
+
+  assert w.meta_path.exists()
+
+  assert w.path == path
+
+  loaded_batches = yann.load(w.path)
+
+  # the saved file holds one buffer per name, each with 10 batches
+  assert len(loaded_batches['ids']) == 10
+
+
+def test_use_case(tmpdir):
+  # stand-in model; a bare nn.Module has no forward and is not callable
+  model = torch.nn.Identity()
+
+  w = BatchWriter(tmpdir / 'MNIST-preds.pkl')
+
+  iw = BatchWriter(tmpdir / 'inputs.pkl', names=('inputs', 'targets'))
+
+  for inputs, targets in iw.through(yann.batches('MNIST', size=32, workers=10, transform=())):
+    preds = model(inputs)
+    w.batch(
+      targets=targets,
+      preds=preds
+    )
+  w.close()
+  iw.close()
+
+  processed = 0
+  correct = 0
+  r = BatchReader(w.path)
+  for batch in r.batches():
+    processed += len(batch['targets'])
+    correct += sum(batch['targets'] == batch['preds'])
diff --git a/tests/data/storage/test_parquet.py b/tests/data/storage/test_parquet.py
new file mode 100644
index 0000000..fdb3a56
--- /dev/null
+++ b/tests/data/storage/test_parquet.py
@@ -0,0 +1,16 @@
+from yann.data.storage.parquet import BatchParquetFileWriter
+
+import torch
+
+
+def test_parquet_batch_writer(tmpdir):
+  path = tmpdir / 'test.parquet'
+  with BatchParquetFileWriter(path) as write:
+    for i in range(10):
+      write.batch(
+        ids=list(range(10)),
+        labels=torch.ones(10, 12)
+      )
+
+  assert path.exists()
diff --git a/yann/data/io/__init__.py b/yann/data/io/__init__.py
index 685d951..1914e45 100644
--- a/yann/data/io/__init__.py
+++ b/yann/data/io/__init__.py
@@ -6,41 +6,57 @@
 import csv
 import gzip
 from pathlib import Path
 
-
+import torch
 
 class Loader:
   """
   gs://bucket/file.th
   ./foo/**/*.jpg
-
-  Args:
-    path:
-
-  Returns:
   """
-  def __call__(self, path, **kwargs):
+
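+  # dispatch on format: the file extension (or an explicit `format=`)
+  # selects the loader method of the same name, e.g. 'model.th' -> Loader.th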
+  def __call__(self, path, format=None, deserialize=None, filesystem=None, **kwargs):
     path = Path(path)
-    if hasattr(self, path.suffix):
-      return getattr(self, path.suffix)(**kwargs)
+    format = format or path.suffix[1:]
+    if hasattr(self, format):
+      return getattr(self, format)(str(path), **kwargs)
+    raise ValueError(f'File format not supported ({format})')
+
+  def th(self, path, **kwargs):
+    return torch.load(path, **kwargs)
 
-  def csv(self):
-    pass
+  def json(self, path, **kwargs):
+    return load_json(path, **kwargs)
 
-  def json(self):
-    pass
+  def pickle(self, path, **kwargs):
+    return load_pickle(path, **kwargs)
 
-  def jsonlines(self):
-    pass
+  pkl = pickle
 
 
 load = Loader()
 
 
 class Saver:
-  def __call__(self, x, path, **kwargs):
-    pass
+  def __call__(
+    self, x, path, format=None, serialize=None, filesystem=None, **kwargs
+  ):
+    path = Path(path)
+    format = format or path.suffix[1:]
+    if hasattr(self, format):
+      return getattr(self, format)(x, path, **kwargs)
+    raise ValueError(f'File format not supported ({format})')
+
+  def th(self, x, path, **kwargs):
+    return torch.save(x, path, **kwargs)
+
+  def json(self, x, path, **kwargs):
+    return save_json(x, path, **kwargs)
+
+  def pickle(self, x, path, **kwargs):
+    return save_pickle(x, path, **kwargs)
+
+  pkl = pickle
 
 
 save = Saver()
@@ -95,6 +111,12 @@ def untar(path):
     tar.extractall()
 
 
+def unzip(path, dest):
+  import zipfile
+  with zipfile.ZipFile(path, 'r') as f:
+    f.extractall(dest)
+
+
 def iter_csv(path, header=True, tuples=True, sep=',', quote='"', **kwargs):
   with open(path) as f:
     reader = csv.reader(f, delimiter=sep, quotechar=quote, **kwargs)
@@ -122,4 +144,3 @@ def write_csv(data, path, header=None):
       writer.writerow(header)
     for row in data:
       writer.writerow(row)
-
diff --git a/yann/data/storage/batch_files.py b/yann/data/storage/batch_files.py
new file mode 100644
index 0000000..80a24cd
--- /dev/null
+++ b/yann/data/storage/batch_files.py
@@ -0,0 +1,183 @@
+from collections import defaultdict
+import os.path
+from pathlib import Path
+import datetime
+import yann
+
+import torch
+from ...utils import fully_qualified_name, timestr
+from ...utils.ids import memorable_id
+
+"""
+TODO:
+  - pickle
+  - parquet
+  - numpy
+  - csv
+  - json lines
+  - hdf5
+  - tfrecords
+  - lmdb
+"""
+
+
+class BatchWriter:
+  def __init__(self, path, encoders=None, names=None, meta=None):
+    self.path = path
+
+    if isinstance(encoders, (list, tuple)):
+      if not names:
+        raise ValueError('names must be provided if encoders are passed as a tuple')
+      if len(encoders) != len(names):
+        raise ValueError('names and encoders must be the same length if provided as tuples')
+      encoders = dict(zip(names, encoders))
+
+    self.encoders = encoders or {}
+    self.names = names
+    self.writer = None
+    self.buffers = defaultdict(list)
+
+    self.meta = meta or {}
+    self.meta['encoders'] = {
+      k: {
+        'path': fully_qualified_name(v),
+        'name': getattr(v, '__name__', None)
+      } for k, v in self.encoders.items()
+    }
+    self.meta['time_created'] = timestr()
+    self.meta['write_id'] = memorable_id()
+    self.meta['path'] = str(self.path)
+    self.save_meta()
+
+  def encode_batch(self, *args, **kwargs):
+    if args:
+      if not self.names:
+        raise ValueError('names must be provided to write positional batches')
+      items = zip(self.names, args)
+    else:
+      items = kwargs.items()
+
+    data = {}
+    for k, v in items:
+      if self.encoders and k in self.encoders:
+        v = self.encoders[k](v)
+      elif torch.is_tensor(v):
+        v = v.detach().cpu()
+
+      data[k] = v
+    return data
+
+  def batch(self, *args, **kwargs):
+    data = self.encode_batch(*args, **kwargs)
+    for k, v in data.items():
+      self.buffers[k].append(v)
+
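+  # note: batches accumulate in memory until flush()/close();
+  # PartitionedBatchWriter below flushes every `batches_per_file` batches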
+  def through(self, batches):
+    for b in batches:
+      if isinstance(b, (tuple, list)):
+        self.batch(*b)
+      else:
+        self.batch(**b)
+      yield b
+
+  def all(self, batches):
+    for b in batches:
+      if isinstance(b, (tuple, list)):
+        self.batch(*b)
+      else:
+        self.batch(**b)
+
+  def collate(self, buffers):
+    return buffers
+
+  def flush(self):
+    self._write()
+    self._wipe_buffers()
+
+  @property
+  def meta_path(self) -> Path:
+    # a callable path has no fixed parent dir, fall back to the working dir
+    parent = Path(self.path).parent if not callable(self.path) else Path('.')
+    return parent / 'writer-meta.json'
+
+  def save_meta(self):
+    yann.save(self.meta, self.meta_path)
+
+  def close(self):
+    self.flush()
+    if self.writer and hasattr(self.writer, 'close'):
+      self.writer.close()
+
+  def __enter__(self):
+    return self
+
+  def __exit__(self, exc_type, exc_value, traceback):
+    self.close()
+
+  def _wipe_buffers(self):
+    for k in self.buffers:
+      self.buffers[k] = []
+
+  def _write(self):
+    Path(self.path).parent.mkdir(parents=True, exist_ok=True)
+    collated = self.collate(self.buffers)
+    self._save(dict(collated), self.path)
+
+  def _save(self, data, path):
+    yann.save(data, path)
+
+  def _num_buffered_batches(self):
+    return len(next(iter(self.buffers.values()), ()))
+
+
+class BatchStreamWriter(BatchWriter):
+  pass
+
+
+class PartitionedBatchWriter(BatchWriter):
+  def __init__(self, path, batches_per_file=256, encoders=None, names=None, meta=None):
+    super().__init__(path, encoders=encoders, names=names, meta=meta)
+
+    self.part = 0
+    self.batches_per_file = batches_per_file
+
+  def batch(self, *args, **kwargs):
+    super().batch(*args, **kwargs)
+    if self._num_buffered_batches() >= self.batches_per_file:
+      self.flush()
+
+  def get_part_path(self, part):
+    if callable(self.path):
+      return self.path(part=part, batches=self.buffers)
+    elif '{' in self.path and '}' in self.path:
+      return self.path.format(
+        part=part,
+        time=datetime.datetime.utcnow()
+      )
+    else:
+      name, ext = os.path.splitext(self.path)
+      return f"{name}-{part}{ext}"
+
+  def _write(self):
+    path = self.get_part_path(self.part)
+    Path(path).parent.mkdir(parents=True, exist_ok=True)
+    collated = self.collate(self.buffers)
+    self._save(dict(collated), path)
+    # advance to the next partition so each flush gets its own file
+    self.part += 1
+
+
+class BatchReader:
+  def __init__(self, path):
+    pass
+
+  def batches(self):
+    pass
+
+  def samples(self):
+    pass
+
+  def __iter__(self):
+    return self.batches()
+
+
+def writer() -> BatchWriter:
+  raise NotImplementedError()
+
+
+def reader() -> BatchReader:
+  raise NotImplementedError()
diff --git a/yann/data/storage/parquet.py b/yann/data/storage/parquet.py
index 2f1d87e..8e76e8e 100644
--- a/yann/data/storage/parquet.py
+++ b/yann/data/storage/parquet.py
@@ -2,6 +2,7 @@
 import pyarrow as pa
 import pandas as pd
 
+from .batch_files import BatchWriter
 
 def write_parquet(dest, data, columns=None, **kwargs):
   if isinstance(data, dict):
@@ -20,4 +21,35 @@
 
 
 def read_parquet():
+  pass
+
+
+class BatchParquetFileWriter(BatchWriter):
+  def __init__(self, path, schema=None, encoders=None, names=None, meta=None, **writer_args):
+    super().__init__(
+      path=path,
+      encoders=encoders,
+      names=names,
+      meta=meta
+    )
+
+    self.path = path
+    self.schema = schema
+    self._writer_args = writer_args
+
+    # will determine schema on first write if not provided
+    if self.schema:
+      self.writer = pq.ParquetWriter(path, self.schema, **self._writer_args)
+
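+  # each flush appends the buffered batches to the open file as a new
+  # row group via ParquetWriter.write_table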
+  def _write(self):
+    df = pd.DataFrame(self.collate(self.buffers))
+    table = pa.Table.from_pandas(df)
+    if self.writer is None:
+      self.schema = table.schema
+      self.writer = pq.ParquetWriter(self.path, self.schema, **self._writer_args)
+    self.writer.write_table(table)
+
+
+class BatchParquetDatasetWriter(BatchWriter):
+  pass
diff --git a/yann/modules/conv/attention.py b/yann/modules/conv/attention.py
new file mode 100644
index 0000000..ae0ce6e
--- /dev/null
+++ b/yann/modules/conv/attention.py
@@ -0,0 +1,26 @@
+from torch import nn
+
+
+class EfficientChannelAttention(nn.Module):
+  """
+  Efficient Channel Attention, https://github.com/BangguWu/ECANet
+  """
+  def __init__(self, kernel_size=3):
+    super().__init__()
+    self.pool = nn.AdaptiveAvgPool2d(1)
+    self.conv = nn.Conv1d(
+      1, 1,
+      kernel_size=kernel_size,
+      padding=(kernel_size - 1) // 2,
+      bias=False
+    )
+    self.sigmoid = nn.Sigmoid()
+
+  def forward(self, input):
+    # squeeze: global average pool (N, C, H, W) -> (N, C, 1, 1)
+    x = self.pool(input)
+    # excite: 1d conv across the channel dimension, treating channels as a sequence
+    x = self.conv(x.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
+    x = self.sigmoid(x)
+
+    return input * x.expand_as(input)
+
+
+ECA = EfficientChannelAttention
diff --git a/yann/modules/loss.py b/yann/modules/loss.py
index 5fbed9b..b2caea4 100644
--- a/yann/modules/loss.py
+++ b/yann/modules/loss.py
@@ -194,5 +194,7 @@ def forward(self, *input, **kwargs):
     return _reduce(loss, reduction=self.reduction, reduce=self.reduce)
 
 
+MultiTaskLoss = CombinedLoss
+
 # class MultiTaskLoss(_Loss):
 #   pass
diff --git a/yann/utils/decorators.py b/yann/utils/decorators.py
index 1cf7e66..9e05af7 100644
--- a/yann/utils/decorators.py
+++ b/yann/utils/decorators.py
@@ -10,3 +10,70 @@ def __get__(self, obj, cls):
     val = self.method(obj)
     setattr(obj, self.name, val)
     return val
+
+
+class RobustFunction:
+  def __init__(self, func, exceptions=Exception, default=None):
+    self.func = func
+    self.exceptions = exceptions
+    self.default = default
+
+  def __call__(self, *args, **kwargs):
+    if not self.exceptions:
+      return self.func(*args, **kwargs)
+    try:
+      r = self.func(*args, **kwargs)
+    except self.exceptions:
+      r = self.default
+    return r
+
+
+def robust(x=None, exceptions=Exception, default=None):
+  if callable(x):
+    return RobustFunction(func=x, exceptions=exceptions, default=default)
+  else:
+    def decorator(x):
+      return RobustFunction(func=x, exceptions=exceptions, default=default)
+
+    return decorator
+
+
+class FunctionTracker:
+  def __init__(self, func, sanitize=None):
+    self.func = func
+    self.history = []
+    self.sanitize = sanitize
+
+  def __call__(self, *args, **kwargs):
+    r = self.func(*args, **kwargs)
+    if self.sanitize:
+      r = self.sanitize(r)
+    self.history.append(r)
+    return r
+
+
+def track(x=None, sanitize=None):
+  if callable(x):
+    return FunctionTracker(func=x, sanitize=sanitize)
+  else:
+    def decorator(x):
+      return FunctionTracker(func=x, sanitize=sanitize)
+
+    return decorator
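+
+
+# usage sketch (function names here are illustrative):
+#
+#   @robust(exceptions=ValueError, default=0)
+#   def parse(x):
+#     return int(x)
+#
+#   parse('12')    # -> 12
+#   parse('nope')  # -> 0 instead of raising
+#
+#   @track
+#   def double(x):
+#     return x * 2
+#
+#   double(3)
+#   double.history  # -> [6]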