-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Robodummy #57
base: main
Are you sure you want to change the base?
Robodummy #57
Changes from 14 commits
5b8c573
04e3fd9
73e4c47
2dec996
3d66cbc
7bf8f81
d157d9d
39fb936
52d0114
83316c2
8814e98
eb19bb3
8220b95
35af962
ae97dbd
3b8b7d7
84d61cb
28640d6
997c4eb
d355deb
a9cfecd
fd1eb97
cba7191
f265faf
fcce3f2
4daa134
6be7a82
f38692a
27c4f43
d1a1cdc
064c27c
4c6d9cb
3940c55
5da186d
f45d619
3db31e9
79a77da
6bd9d7e
055d08f
3eec9c8
8e907e3
c6ddc1d
617a4df
39a674d
9050f93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,4 +7,6 @@ tiktorch/.idea | |
tiktorch/__pycache/ | ||
/#wrapper.py# | ||
/.#wrapper.py# | ||
.py~ | ||
.py~ | ||
*.nn | ||
*.hdf | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as f | ||
from sklearn.metrics import mean_squared_error | ||
import zipfile | ||
import h5py | ||
import z5py | ||
from z5py.converter import convert_from_h5 | ||
from torch.autograd import Variable | ||
from collections import OrderedDict | ||
import yaml | ||
import logging | ||
from tensorboardX import SummaryWriter | ||
from tiktorch.server import TikTorchServer | ||
from tiktorch.rpc import Client, Server, InprocConnConf | ||
from tiktorch.rpc_interface import INeuralNetworkAPI | ||
from tiktorch.types import NDArray, NDArrayBatch | ||
from mr_robot.utils import tile_image | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe sort the import statements a little don't mix import... and from... too much There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
# Tiling parameters shared across the module.
patch_size = 16  # edge length (pixels) of each square tile cut from the image
img_dim = 32  # presumably the training image edge length — TODO confirm (unused in this file)
|
||
|
||
class MrRobot:
    """Runs predictions on the model and feeds the worst-performing patch
    back for training.

    The order in which patches are fed back is decided by the given
    strategy; the robot can change strategies as training progresses.

    Args:
        path_to_config_file (string): path to the robot configuration file
            to load necessary variables
        strategy (BaseStrategy): strategy to follow (at least initially)
    """

    def __init__(self, path_to_config_file, strategy):
        # start the tiktorch server that hosts the model
        self.new_server = TikTorchServer()
        self.strategy = strategy

        # safe_load: the config is plain data, no arbitrary object
        # construction needed (yaml.load without a Loader is deprecated).
        # A distinct handle name avoids shadowing torch.nn.functional as f.
        with open(path_to_config_file, mode="r") as config_file:
            self.base_config = yaml.safe_load(config_file)

        self.data_file = z5py.File(self.base_config.pop("raw_data_path"))

        # start/end coordinates of every tile the image is cut into
        self.tile_indices = tile_image(self.base_config["training"]["training_shape"], patch_size)
        self.input_shape = list(self.base_config["training"]["training_shape"])
        # one slice per dimension; the last two are overwritten per patch in _predict
        self.slicer = [slice(0, i) for i in self.input_shape]

        self.iterations_max = self.base_config.pop("max_robo_iterations")
        self.iterations_done = 0
        self.tensorboard_writer = SummaryWriter()
        self.logger = logging.getLogger(__name__)

    def _load_model(self):
        """Read the model source and its binary state out of the configured
        zip archive and load them into the server.
        """
        # NOTE(review): only zip archives are supported; loading from an
        # unzipped folder is not implemented yet
        archive = zipfile.ZipFile(self.base_config["data_dir"]["path_to_zip"], "r")
        model = archive.read(self.base_config["data_dir"]["path_in_zip_to_model"])
        binary_state = archive.read(self.base_config["data_dir"]["path_in_zip_to_state"])

        # remove keys tiktorch does not understand before handing the config over
        self.base_config.pop("data_dir")

        self.new_server.load_model(self.base_config, model, binary_state, b"", ["cpu"])
        self.logger.info("model loaded")

    def _resume(self):
        """Resume training on the server."""
        self.new_server.resume_training()
        self.logger.info("training resumed")

    def _predict(self):
        """Run prediction on the whole set of patches and record each
        patch's loss with the strategy.
        """
        self.strategy.patched_data.clear()
        # maps a patch's (row, col) start coordinates to its index in tile_indices
        self.patch_id = dict()

        for patch_index, tile in enumerate(self.tile_indices):
            self.slicer[-1] = slice(tile[0][1], tile[1][1])
            self.slicer[-2] = slice(tile[0][0], tile[1][0])
            new_slicer = tuple(self.slicer)
            self.patch_id[tile[0]] = patch_index
            pred_output = self.new_server.forward(NDArray(self.data_file["volume"][new_slicer]))
            pred_output = pred_output.result()
            # NOTE(review): calling the strategy's 'private' _loss externally;
            # consider exposing a public method on BaseStrategy instead
            self.strategy._loss(
                pred_output, self.data_file[self.base_config["labelled_data_path"]][new_slicer], new_slicer
            )

        self.logger.info("prediction run")

    def stop(self):
        """Return True while the robot should keep iterating.

        Despite its name this is a *continue* predicate: it returns True
        (and counts one iteration) until 'iterations_max' iterations have
        been performed, then returns False. Kept as-is because _run relies
        on these semantics.
        """
        if self.iterations_done > self.iterations_max:
            return False
        else:
            self.iterations_done += 1
            return True

    def _run(self):
        """Feed patches to tiktorch (add to the training data).

        Fetches patches in the order decided by the strategy, removes each
        from the list of tile indices and feeds it to tiktorch.
        """
        while self.stop():
            self._predict()

            # log the average loss over all patches for this iteration
            total_loss = sum(pair[0] for pair in self.strategy.patched_data)
            avg = total_loss / float(len(self.strategy.patched_data))
            self.tensorboard_writer.add_scalar("avg_loss", avg, self.iterations_done)

            self.strategy.rearrange()
            slicer = self.strategy.get_next_patch()
            self.tile_indices.pop(self.patch_id[(slicer[-2].start, slicer[-1].start)])
            self._add(slicer)
            self._resume()

        self.terminate()

    def _add(self, slicer):
        """Send the selected patch (input and label) to the server as new
        training data, tagged with its (row, col) start position.
        """
        new_input = NDArray(self.data_file["volume"][slicer].astype(float), (slicer[-2].start, slicer[-1].start))
        new_label = NDArray(
            self.data_file[self.base_config["labelled_data_path"]][slicer].astype(float),
            (slicer[-2].start, slicer[-1].start),
        )
        self.new_server.update_training_data(NDArrayBatch([new_input]), NDArrayBatch([new_label]))

    def dense_annotate(self, x, y, label, image):
        """Annotate the worst patch (not implemented yet)."""
        raise NotImplementedError()

    def terminate(self):
        """Flush the tensorboard log and shut the server down."""
        self.tensorboard_writer.close()
        self.new_server.shutdown()
||
|
||
class BaseStrategy:
    """Base class for patch-selection strategies.

    Computes and stores the loss per patch; subclasses decide the order in
    which patches are handed back to the robot.

    Args:
        path_to_config_file (string): path to the robot configuration file
    """

    def __init__(self, path_to_config_file):
        # safe_load: plain-data config, avoids arbitrary object construction
        with open(path_to_config_file, mode="r") as config_file:
            self.base_config = yaml.safe_load(config_file)

        # list of (loss, slicer) tuples, one per predicted patch
        self.patched_data = []
        # name of a torch.nn loss class, e.g. "MSELoss"
        self.loss_fn = self.base_config["training"]["loss_criterion_config"]["method"]
        self.logger = logging.getLogger(__name__)

    def _loss(self, pred_output, target, slicer):
        """Compute the loss between prediction and target with the
        configured criterion and record it together with the patch slicer.

        Args:
            pred_output: output predicted by the model; must provide
                as_numpy() returning an np.ndarray
            target (np.ndarray): ground truth
            slicer (tuple): tuple of slice objects, one per dimension
        """
        criterion_class = getattr(nn, self.loss_fn, None)
        # was: .format(method) — 'method' is undefined in this scope and
        # would raise a NameError on the failure path
        assert criterion_class is not None, "Criterion {} not found.".format(self.loss_fn)
        criterion = criterion_class()
        curr_loss = criterion(
            torch.from_numpy(pred_output.as_numpy().astype(np.float32)), torch.from_numpy(target.astype(np.float32))
        )
        self.patched_data.append((curr_loss, slicer))

    def get_next_patch(self):
        raise NotImplementedError()

    def rearrange(self):
        # was: NotADirectoryError — clearly a typo for NotImplementedError
        raise NotImplementedError()
|
||
|
||
class Strategy1(BaseStrategy):
    """Feeds patches back in descending order of their loss.

    Args:
        path_to_config_file (string): path to the robot configuration file
    """

    def __init__(self, path_to_config_file):
        super().__init__(path_to_config_file)

    def rearrange(self):
        """Sort the recorded patches by loss, highest first.

        Sorting with an explicit float key avoids comparing the (loss,
        slicer) tuples directly: on equal losses the default tuple sort
        would fall through to comparing slice objects, which raises
        TypeError.
        """
        self.patched_data.sort(key=lambda entry: float(entry[0]), reverse=True)

    def get_next_patch(self):
        """Return the slicer of the patch with the highest loss."""
        return self.patched_data[0][1]
|
||
|
||
class Strategy2(BaseStrategy):
    """Placeholder for a second patch-selection strategy (not implemented)."""

    def __init__(self, path_to_config_file):
        # was: def __init__(): — missing 'self' and the required config
        # path, so the class could not be instantiated at all
        super().__init__(path_to_config_file)
|
||
|
||
class Strategy3(BaseStrategy):
    """Placeholder for a third patch-selection strategy (not implemented)."""

    def __init__(self, path_to_config_file):
        # was: def __init__(): — missing 'self' and the required config
        # path, so the class could not be instantiated at all
        super().__init__(path_to_config_file)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# base config for robot | ||
max_robo_iterations: 5 | ||
model_class_name: DUNet2D | ||
model_init_kwargs: {in_channels: 1, out_channels: 1} | ||
training: { | ||
training_shape: [1, 32, 32], | ||
batch_size: 1, | ||
loss_criterion_config: {"method": "MSELoss"}, | ||
optimizer_config: {"method": "Adam"}, | ||
num_iterations_done: 1, | ||
} | ||
validation: {} | ||
dry_run: {"skip": True, "shrinkage": [0, 0, 0]} | ||
|
||
data_dir: { | ||
path_to_zip: "D:/Machine Learning/tiktorch/tests/data/CREMI_DUNet_pretrained_new.zip", | ||
path_in_zip_to_model: "CREMI_DUNet_pretrained_new/model.py", | ||
path_in_zip_to_state: "CREMI_DUNet_pretrained_new/state.nn", | ||
} | ||
|
||
raw_data_path: "D:/Machine Learning/tiktorch/mr_robot/train.n5" | ||
labelled_data_path: "volumes/labels/neuron_ids" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# utility functions for the robot | ||
import numpy as np | ||
from scipy.ndimage import convolve | ||
|
||
# ref: https://github.com/constantinpape/vis_tools/blob/master/vis_tools/edges.py#L5
def make_edges3d(segmentation):
    """Make a 3d edge volume from a 3d segmentation.

    Returns a boolean array of the same shape that is True wherever the
    segmentation value changes along any of the three axes.
    """
    # NOTE: adding one makes sure there is no zero in the segmentation
    shifted = segmentation + 1
    kernel = np.array([-1.0, 0.0, 1.0])
    gradients = [
        convolve(shifted, kernel.reshape(shape))
        for shape in ((3, 1, 1), (1, 3, 1), (1, 1, 3))
    ]
    return sum(g ** 2 for g in gradients) > 0
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it seems to me that image tiling could nicely be implemented for n dimensions. Maybe have a look at https://github.com/ilastik/lazyflow/blob/dfbb450989d4f790f5b19170383b777fb88be0e8/lazyflow/roi.py#L473 for some inspiration |
||
# create patches
def tile_image(image_shape, tile_size):
    """Cut the last two dimensions of the input image into square tiles of
    edge length 'tile_size'.

    Returns a list of ((row_start, col_start), (row_end, col_end)) index
    pairs, one per tile. When a dimension is not an exact multiple of
    tile_size, extra tiles anchored at the far edge are appended, so
    border tiles overlap their neighbours instead of being dropped.
    The tile order (grid tiles, then bottom row, then right column, then
    corner) is preserved because callers index into the returned list.

    Args:
        image_shape (tuple): shape of the input n-dimensional image
        tile_size (int): edge length of each square tile

    Raises:
        AssertionError: if the image is smaller than the tile size
    """
    assert image_shape[-1] >= tile_size and image_shape[-2] >= tile_size, "image too small for this tile size"

    tiles = []
    w, h = image_shape[-2], image_shape[-1]

    # regular grid of non-overlapping tiles
    for wsi in range(0, w - tile_size + 1, tile_size):
        for hsi in range(0, h - tile_size + 1, tile_size):
            tiles.append(((wsi, hsi), (wsi + tile_size, hsi + tile_size)))

    # extra row of tiles flush with the bottom edge
    if h % tile_size != 0:
        for wsi in range(0, w - tile_size + 1, tile_size):
            tiles.append(((wsi, h - tile_size), (wsi + tile_size, h)))

    # extra column of tiles flush with the right edge
    if w % tile_size != 0:
        for hsi in range(0, h - tile_size + 1, tile_size):
            tiles.append(((w - tile_size, hsi), (w, hsi + tile_size)))

    # bottom-right corner tile when both dimensions have a remainder
    if w % tile_size != 0 and h % tile_size != 0:
        tiles.append(((w - tile_size, h - tile_size), (w, h)))

    return tiles
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import pytest | ||
from mr_robot.mr_robot import MrRobot, Strategy1 | ||
from mr_robot.utils import tile_image | ||
from tiktorch.server import TikTorchServer | ||
|
||
|
||
def test_tile_image():
    """Check tiling for divisible, non-divisible and too-small shapes."""
    # image dims not a multiple of the patch size: compare the exact tile
    # coordinates, not only the count (reviewer request)
    tiled_indices = tile_image((1, 48, 48), 32)
    assert tiled_indices == [
        ((0, 0), (32, 32)),
        ((0, 16), (32, 48)),
        ((16, 0), (48, 32)),
        ((16, 16), (48, 48)),
    ]

    tiled_indices = tile_image((3, 71, 71), 23)
    assert len(tiled_indices) == 16

    # when image too small for the patch
    with pytest.raises(AssertionError):
        tile_image((1, 48, 48), 64)
|
||
|
||
def test_MrRobot():
    # Smoke test: construct the robot, load the model and run the full loop.
    # NOTE(review): absolute Windows paths make this test machine-specific;
    # presumably it should use a path relative to the repo — TODO confirm
    strategy = Strategy1("D:/Machine Learning/tiktorch/mr_robot/robot_config.yml")
    robo = MrRobot("D:/Machine Learning/tiktorch/mr_robot/robot_config.yml", strategy)
    assert isinstance(robo, MrRobot)
    assert isinstance(robo.new_server, TikTorchServer)
    # training_shape as declared in robot_config.yml
    assert robo.input_shape == [1, 32, 32]
    assert isinstance(robo.slicer, list)
    robo._load_model()
    # robo._resume()
    # robo._predict()
    # assert len(strategy.patched_data) == 4
    robo._run()
    # robo.terminate()
    # print(robo.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is no need to ignore
`.nn`
and `.hdf`
files (as there are none in the repo). Please remove these entries.