-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 94f08d2
Showing
38 changed files
with
68,670 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
checkpoints/ | ||
dataset_csvs/ | ||
# src/*.sh | ||
|
||
tests/data/ | ||
*.pt | ||
*.pth | ||
*.bin | ||
.DS_Store | ||
*mgb* | ||
DEBUG* | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
pip-wheel-metadata/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
sync.sh | ||
gpu1sync.sh | ||
.idea | ||
**/._* | ||
**/*DS_* | ||
**.jsonl | ||
src/sbatch | ||
src/misc | ||
.vscode | ||
src/debug | ||
core.* |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
include conch/open_clip_custom/model_configs/*.json | ||
include conch/open_clip_custom/tokenizers/*.json |
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from PIL import Image | ||
import os | ||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
class TileClassificationDataset(Dataset): | ||
def __init__(self, | ||
df, | ||
data_source = None, | ||
img_transforms = None, | ||
index_col = 'image_name', | ||
subdir_col = None, | ||
target_col = 'label', | ||
target_transforms = None, | ||
label_map = None, | ||
dummy_size = 0): | ||
|
||
self.label_map = label_map | ||
self.data_source = data_source | ||
self.index_col = index_col | ||
self.target_col = target_col | ||
self.subdir_col = subdir_col | ||
self.img_transforms = img_transforms | ||
self.target_transforms = target_transforms | ||
self.data = df | ||
self.dummy_size = dummy_size | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
def get_ids(self, ids): | ||
return self.data.loc[ids, self.index_col] | ||
|
||
def get_labels(self, ids): | ||
return self.data.loc[ids, self.target_col] | ||
|
||
def __getitem__(self, idx): | ||
img_name = self.get_ids(idx) | ||
label = self.get_labels(idx) | ||
|
||
if self.label_map is not None: | ||
label = self.label_map[label] | ||
if self.target_transforms is not None: | ||
label = self.target_transforms(label) | ||
|
||
if self.dummy_size > 0: | ||
img = torch.rand(3, self.dummy_size, self.dummy_size) | ||
else: | ||
if self.data_source is not None: | ||
if self.subdir_col is not None: | ||
subdir = self.data.loc[idx, self.subdir_col] | ||
if not isinstance(subdir, str): | ||
subdir = "" | ||
img_path = os.path.join(self.data_source, subdir, img_name) | ||
else: | ||
img_path = os.path.join(self.data_source, img_name) | ||
else: | ||
img_path = img_name | ||
img = Image.open(img_path).convert('RGB') | ||
if self.img_transforms is not None: | ||
img = self.img_transforms(img) | ||
|
||
return {'img': img, 'label': label} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import numpy as np | ||
import pickle | ||
|
||
class AverageMeter(object): | ||
"""Computes and stores the average and current value""" | ||
def __init__(self, name = 'unk', fmt=':f'): | ||
self.name = name | ||
self.fmt = fmt | ||
self.reset() | ||
|
||
def reset(self): | ||
self.val = 0 | ||
self.avg = 0 | ||
self.sum = 0 | ||
self.count = 0 | ||
|
||
def update(self, val, n=1): | ||
self.val = val | ||
self.sum += val * n | ||
self.count += n | ||
self.avg = self.sum / self.count | ||
|
||
def __str__(self): | ||
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' | ||
return fmtstr.format(**self.__dict__) | ||
|
||
def save_pkl(filename, save_object): | ||
writer = open(filename,'wb') | ||
pickle.dump(save_object, writer) | ||
writer.close() | ||
|
||
def merge_dict(main_dict, new_dict, value_fn = None): | ||
""" | ||
Merge new_dict into main_dict. If a key exists in both dicts, the values are appended. | ||
Else, the key-value pair is added. | ||
If value_fn is not None, it is applied to each item in each value in new_dict before merging. | ||
Args: | ||
main_dict: main dict | ||
new_dict: new dict | ||
value_fn: function to apply to each item in each value in new_dict before merging | ||
""" | ||
if value_fn is None: | ||
value_fn = lambda x: x | ||
for key, value in new_dict.items(): | ||
if not isinstance(value, list): | ||
value = [value] | ||
value = [value_fn(v) for v in value] | ||
if key in main_dict: | ||
main_dict[key] = main_dict[key] + value | ||
else: | ||
main_dict[key] = value | ||
return main_dict | ||
|
||
def aggregate_array(arr, agg): | ||
arr = np.array(arr) | ||
if agg == 'mean': | ||
return arr.mean() | ||
elif agg == 'std': | ||
return arr.std() | ||
elif agg == 'median': | ||
return np.median(arr) | ||
elif agg == 'max': | ||
return arr.max() | ||
elif agg == 'min': | ||
return arr.min() | ||
elif agg == 'sum': | ||
return arr.sum() | ||
else: | ||
raise NotImplementedError | ||
|
Oops, something went wrong.