diff --git a/.gitignore b/.gitignore index 6d091a80..59495b80 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ tiktorch/.idea tiktorch/__pycache/ /#wrapper.py# /.#wrapper.py# -.py~ \ No newline at end of file +.py~ +*.hdf diff --git a/environment.yml b/environment.yml index 24d0249b..83060973 100644 --- a/environment.yml +++ b/environment.yml @@ -13,6 +13,9 @@ dependencies: - inferno=v0.4.* - pyzmq=18.0.1 - pyyaml=3.13 + - seaborn + - z5py + - scikit-learn - pytest=4.3.0 - black - isort diff --git a/mr_robot/__init__.py b/mr_robot/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mr_robot/__main__.py b/mr_robot/__main__.py new file mode 100644 index 00000000..bda461de --- /dev/null +++ b/mr_robot/__main__.py @@ -0,0 +1,23 @@ +import argparse + +from mr_robot.mr_robot import MrRobot, strategies + + +parser = argparse.ArgumentParser() +parser.add_argument("-c", "--config", help="Path to config file", type=str, required=True) +parser.add_argument( + "-s", "--strategy", help="Annotation strategy to use", type=str, choices=strategies.keys(), required=True +) +parser.add_argument("-d", "--device", help="Which device to use", type=str, required=True, default="cpu") + + +def main(): + args = parser.parse_args() + robo = MrRobot(args.config, args.strategy, args.device) + robo._load_model() + robo._run() + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/mr_robot/annotator/annotate.py b/mr_robot/annotator/annotate.py new file mode 100644 index 00000000..36166b3f --- /dev/null +++ b/mr_robot/annotator/annotate.py @@ -0,0 +1,54 @@ +import numpy as np +import random + +from mr_robot.utils import get_coordinate + + +class Annotater: + """ The annotater class prvides methods for different labelling strategies, emulating a user + in some way + + Args: + annotation_percent (int): percentage of pixels in the patch to annotate + """ + + def __init__(self, annotation_percent): + self.annotation_percent = annotation_percent + # self.block_shape = block_shape + + def get_random_index(self, block_shape): + random_index = [] + for i in range(len(block_shape)): + random_index.append(np.random.randint(0, block_shape[i])) + + return tuple(random_index) + + def get_random_patch(self, block_shape): + rand_index = self.get_random_index(block_shape) + patch_dimension = [] + for i in range(len(block_shape)): + patch_dimension.append(np.random.randint(0, block_shape[i] - rand_index[i])) + + block = [] + for i in range(len(patch_dimension)): + block.append(slice(rand_index[i], rand_index[i] + patch_dimension[i])) + return tuple(block) + + def dense(self, label): + return label + + def random_sparse(self, label): + ret_label = np.zeros(label.shape) + for i in range(int(self.annotation_percent) * np.product(label.shape)): + index = self.get_random_index(label.shape) + ret_label[index] = label[index] + return ret_label + + def random_blob(self, label): + ret_label = np.ones(label.shape) + for i in range(int(self.annotation_percent) * np.product(label.shape)): + random_block = self.get_random_patch(label.shape) + ret_label[random_block] = label[random_block] + i += np.product(label[random_block].shape) + print("random blob annotated label:", np.unique(ret_label), "labels of actual return patch:", np.unique(label)) + return ret_label diff --git a/mr_robot/mr_robot.py b/mr_robot/mr_robot.py new file mode 100644 index 00000000..2f867731 --- /dev/null +++ b/mr_robot/mr_robot.py @@ -0,0 +1,394 @@ +import concurrent.futures as cf +import logging +import os +import random +import zipfile +from io import BytesIO +from typing import List + +import numpy as np +import torch +import torch.nn as nn + +import h5py +import yaml +import z5py +from mr_robot.strategies.strategy import * +from mr_robot.utils import ( + get_confusion_matrix, + get_coordinate, + integer_to_onehot, + make_plot, + plot_confusion_matrix, + tile_image, +) +from scipy import sparse +from sklearn.metrics import accuracy_score, f1_score +from tensorboardX import SummaryWriter +from tiktorch.models.dunet import DUNet +from tiktorch.rpc.utils import BatchedExecutor +from tiktorch.server.base import TikTorchServer +from tiktorch.types import Model, ModelState, NDArray, NDArrayBatch + +img_dim = 32 +batch_size = 48 + + +class MrRobot: + """ The robot class runs predictins on the model, and feeds the worst performing patch back for training. + The order in which patches are feed back is determined by the 'strategy'. The robot applies a given strategy, + adds new patches to the training data and logs the metrics to tensorboard + + Args: + path_to_config_file (string): path to the robot configuration file to + load necessary variables + strategy (string): strategy to follow (atleast intially) + """ + + def __init__(self, path_to_config_file, strategy_set, devices: List[str]) -> None: + + assert torch.cuda.device_count( + ) == 1, f"Device count is {torch.cuda.device_count()}" + # start the server + self.new_server = TikTorchServer() + self._devices = devices + + with open(path_to_config_file, mode="r") as f: + self.base_config = yaml.load(f) + + if self.base_config["data_dir"]["raw_data_base_folder"].endswith(".h5"): + self.raw_data_file = h5py.File( + self.base_config["data_dir"]["raw_data_base_folder"], 'r') + self.labelled_data_file = h5py.File( + self.base_config["data_dir"]["labelled_data_base_folder"], 'r') + self.validation_raw_file = h5py.File( + self.base_config["data_dir"]["validation_raw_base"], 'r') + self.validation_label_file = h5py.File( + self.base_config["data_dir"]["validation_label_base"], 'r') + else: + self.raw_data_file = z5py.File( + self.base_config["data_dir"]["raw_data_base_folder"], 'r') + self.labelled_data_file = z5py.File( + self.base_config["data_dir"]["labelled_data_base_folder"], 'r') + self.validation_raw_file = h5py.File( + self.base_config["data_dir"]["validation_raw_base"], 'r') + self.validation_label_file = h5py.File( + self.base_config["data_dir"]["validation_label_base"], 'r') + + image_shape = self.raw_data_file[self.base_config["data_dir"] + ["path_to_raw_data"]].shape + validation_data_shape = self.validation_raw_file[self.base_config["data_dir"] + ["validation_raw"]].shape + print(image_shape) + self.block_list = tile_image( + image_shape, self.base_config["training"]["training_shape"]) + self.validation_block_list = tile_image( + validation_data_shape, self.base_config["training"]["training_shape"]) + print("number of patches: %s" % len(self.block_list)) + print() + + strategy_objects = [] + for(strategy in strategy_set): + if strategy == "VideoLabelling": + videolabelling = VideoLabelling(self.base_config[strategy_params]["loss_fn"], self.base_config["class_dict"], self.raw_data_file, self.labelled_data_file, self.base_config["data_dir"], + self.base_config["strategy_params"]["labelling_strategy"], self.base_config[strategy_params]["annotatoin_percent"], self.base_config["training"]["training_shape"]) + strategy_objects.append(videolabelling) + else: + new_strategy_object = strategy(self.base_config[strategy_params]["loss_fn"], self.base_config["class_dict"], self.raw_data_file, self.labelled_data_file, + self.base_config["data_dir"], self.base_config["strategy_params"]["labelling_strategy"], self.base_config[strategy_params]["annotatoin_percent"]) + strategy_objects.append(new_strategy_object) + + self.strategy = StrategyAbstract(self.new_server, strategy_objects) + + self.iterations_max = self.base_config.pop("max_robo_iterations") + self.iterations_done = 0 + self.stats = { + "training_loss": [], + "training_accuracy": [], + "robo_predict_accuracy": [], + "f1_score": [], + "robo_predict_loss": [], + "validation_loss": [], + "validation_accuracy": [], + "training_iterations": [], + "number_of_patches": [], + "training_confusion": 0, + + } + mr_robot_folder = os.path.dirname( + os.path.dirname(os.path.abspath(__file__))) + self.tensorboard_writer = SummaryWriter(logdir=os.path.join( + mr_robot_folder, "tests", "robot", "robo_logs", "strat_classwise")) + self.patch_id = dict() + self.logger = logging.getLogger(__name__) + + def _load_model(self): + + if self.base_config["model_dir"]["path_to_folder"].endswith(".zip"): + archive = zipfile.ZipFile( + self.base_config["model_dir"]["path_to_folder"], "r") + model_file = archive.read( + self.base_config["model_dir"]["path_in_folder_to_model"]) + binary_state = archive.read( + self.base_config["model_dir"]["path_in_folder_to_state"]) + + else: + model_file = open( + os.path.join( + self.base_config["model_dir"]["path_to_folder"], + self.base_config["model_dir"]["path_in_folder_to_model"], + ) + ) + binary_state = open( + os.path.join( + self.base_config["model_dir"]["path_to_folder"], + self.base_config["model_dir"]["path_in_folder_to_state"], + ) + ) + + model = Model(code=model_file, config=self.base_config) + binary_state = ModelState(b"") + # binary_state.model_state = b"" + + self.new_server.load_model(model, binary_state, self._devices) + # self.tensorboard_writer.add_graph(DUNet(1,1),torch.from_numpy(self.raw_data_file[self.base_config["data_dir"]["path_to_raw_data"]][0]) ) + self.logger.info("model loaded") + + def _resume(self): + + self.new_server.resume_training() + # self.binary_state = self.new_server.get_model_state() + self.logger.info("training resumed") + + def _predict(self): + """ run prediction on the whole set of patches + """ + # self.strategy.patched_data.clear() + + x = 0 + prediction_list = [] + path_to_input = self.base_config["data_dir"]["path_to_raw_data"] + path_to_label = self.base_config["data_dir"]["path_to_labelled"] + + batch_maker = BatchedExecutor(batch_size=8) + for block in self.block_list: + # map each slicer with its corresponding index + is_patch_already_in_training_flag = self.assign_id(block, x) + if(is_patch_already_in_training_flag == True): + continue + # self.patch_id[block[0].start] = x + # pred_output = self.new_server.forward(NDArray(self.raw_data_file[path_to_input][block])) + prediction_list.append( + batch_maker.submit(self.new_server.forward, NDArray( + self.raw_data_file[path_to_input][block], x)) + ) + x += 1 + # self.pred_output = pred_output.result().as_numpy() + # print("hello") + # self.strategy.update_state(self.pred_output, self.labelled_data_file[path_to_label][block], block) + + for prediction in cf.as_completed(prediction_list): + block = self.block_list[prediction.result().id] + self.strategy.update_state( + prediction.result().as_numpy( + ), self.labelled_data_file[path_to_label][block], block, False + ) + + # self.logger.info("prediction run for iteration {}", self.iterations_done) + + def validate(self): + validation_list = [] + x = 0 + path_to_input = self.base_config["data_dir"]["validation_raw"] + path_to_label = self.base_config["data_dir"]["validation_label"] + + batch_maker = BatchedExecutor(batch_size=8) + for block in self.validation_block_list: + validation_list.append( + batch_maker.submit(self.new_server.forward, NDArray( + self.validation_raw_file[path_to_input][block], x)) + ) + x += 1 + for prediction in cf.as_completed(validation_list): + block = self.validation_block_list[prediction.result().id] + self.strategy.update_state( + prediction.result().as_numpy( + ), self.validation_label_file[path_to_label][block], block, True + ) + + def stop(self): + """ function which determines when the robot should stop + + currently, it stops after robot has completed 'iterations_max' number of iterations + """ + + if self.iterations_done > self.iterations_max: + return True + else: + self.iterations_done += 1 + return False + + def _run(self): + """ Feed patches to tiktorch (add to the training data) + + The function fetches the patches in order decided by the strategy, + removes it from the list of indices and feeds it to tiktorch + """ + + while not self.stop(): + print("robot running for iteration: %s" % self.iterations_done) + + self._predict() + self.validate() + curr_model_state = self.new_server.get_model_state() + self.stats["training_iterations"].append( + curr_model_state.num_iterations_done) + # self.stats["training_loss"].append(curr_model_state.loss) + self.stats["number_of_patches"].append( + (self.iterations_done - 1) * batch_size) + + self.write_to_tensorboard() + + data_batch = self.strategy.get_next_batch(batch_size) + + self._add(data_batch) + #self.remove_key([id for image, label, id in data_batch]) + print("waiting for training") + self.new_server.train_for(10).result() + print("training done!!") + + self.terminate() + + def _add(self, new_data_batch): + """ add a new batch of images to training data + + Args: + new_data_batch (list): list of tuples, where each tuple contains an image, its label and their block id + """ + assert new_data_batch is not None, "No data provided!" + print("adding new data batch:", len(new_data_batch)) + + new_inputs, new_labels = [], [] + for image, label, block_id in new_data_batch: + new_inputs.append(NDArray(image.astype(np.int32), block_id)) + new_labels.append(NDArray(label.astype(np.int32), block_id)) + + self.new_server.update_training_data( + NDArrayBatch(new_inputs), NDArrayBatch(new_labels)) + print("addition done!!") + + def write_to_tensorboard(self): + metric_data = self.strategy.get_metrics() + + for key, value in metric_data.items(): + if "confusion_matrix" not in key and key != "robo_predict_loss": + self.stats[key].append(value) + if key == "robo_predict_loss": + print(value.item()) + self.stats[key].append(value.item()) + + if(self.iterations_done == 1): + self.stats["training_accuracy"].append(0) + self.stats["training_loss"].append(0) + training_confusion_matrix = np.zeros((2, 2)) + + else: + path_to_input = self.base_config["data_dir"]["path_to_raw_data"] + path_to_label = self.base_config["data_dir"]["path_to_labelled"] + + training_block = [] + for key, value in self.patch_id.items(): + if value == -1: + training_shape = self.base_config["training"]["training_shape"] + block = tuple([slice(key[i], key[i] + training_shape[i]) + for i in range(len(training_shape))]) + print(block) + training_block.append(block) + + x = 0 + print("training data size:", len(training_block)) + train_prediction_list = [] + batch_maker = BatchedExecutor(batch_size=8) + for block in training_block: + image = self.raw_data_file[path_to_input][block] + print("training accuracy calc:", image.shape) + train_prediction_list.append( + batch_maker.submit( + self.new_server.forward, NDArray(image, x)) + ) + x += 1 + print() + train_accuracy, conf_matrix, train_loss = 0.0, 0, 0 + for prediction in cf.as_completed(train_prediction_list): + block = training_block[prediction.result().id] + label = self.labelled_data_file[path_to_label][block] + label[label == 2] = 0 + + pred_output = prediction.result().as_numpy() + + criterion_class = getattr(nn, "BCELoss", None) + criterion_class_obj = criterion_class(reduction="mean") + train_loss += criterion_class_obj( + torch.from_numpy(pred_output.astype(np.float32)), torch.from_numpy( + label.astype(np.float32)) + ) + + pred_output = pred_output.flatten().round().astype(np.int32) + target = label.flatten().round().astype(np.int32) + + train_accuracy += accuracy_score(target, pred_output) + conf_matrix += get_confusion_matrix( + pred_output, target, list(self.base_config["class_dict"].keys())) + + train_accuracy /= ((self.iterations_done - 1) * batch_size) + train_loss /= ((self.iterations_done - 1) * batch_size) + training_confusion_matrix = conf_matrix / \ + ((self.iterations_done - 1) * batch_size) + self.stats["training_accuracy"].append(train_accuracy) + self.stats["training_loss"].append(train_loss) + + training_confusion_matrix = plot_confusion_matrix( + training_confusion_matrix, self.base_config["class_dict"]) + loss_plot, accuracy_plot = make_plot(self.stats) + self.tensorboard_writer.add_figure("loss_plot", loss_plot) + self.tensorboard_writer.add_figure("accuracy_plot", accuracy_plot) + # self.tensorboard_writer.add_scalar("avg_loss", metric_data["avg_loss"], self.iterations_done) + # self.tensorboard_writer.add_scalar("avg_accuracy", metric_data["avg_accuracy"] * 100, self.iterations_done) + # self.tensorboard_writer.add_scalar("F1_score", metric_data["avg_f1_score"], self.iterations_done) + self.tensorboard_writer.add_figure( + "robo_confusion_matrix", metric_data["robo_confusion_matrix"], global_step=self.iterations_done + ) + self.tensorboard_writer.add_figure( + "validation_confusion_matrix", metric_data["validation_confusion_matrix"], global_step=self.iterations_done + ) + self.tensorboard_writer.add_figure( + "training_confusion_matrix", training_confusion_matrix, global_step=self.iterations_done + ) + + def assign_id(self, block, index): + id = get_coordinate(block) + if (id in self.patch_id and self.patch_id[id] == -1): + return True + + self.patch_id[id] = index + return False + + def remove_key(self, ids): + for id in ids: + self.patch_id[id] = -1 + + def terminate(self): + self.tensorboard_writer.close() + self.new_server.shutdown() + +""" +strategies = { + "highestloss": HighestLoss, + "strategyrandom": StrategyRandom, + "randomsparseannotate": RandomSparseAnnotate, + "densesparseannotate": DenseSparseAnnotate, + "classwiseloss": ClassWiseLoss, + "videolabelling": VideoLabelling, + "strategyabstract": StrategyAbstract, +} +""" diff --git a/mr_robot/robot_config.yml b/mr_robot/robot_config.yml new file mode 100644 index 00000000..ce2e5d43 --- /dev/null +++ b/mr_robot/robot_config.yml @@ -0,0 +1,41 @@ +# base config for robot +max_robo_iterations: 10 +model_class_name: DUNet2D +model_init_kwargs: {in_channels: 1, out_channels: 1} +training: { + training_shape: [1, 256, 256], + batch_size: 1, + loss_criterion_config: {"method": "NLLLoss", weight: [1.0, 1.0]}, + optimizer_config: {"method": "Adam", "lr": 1.0e-4}, + num_iterations_done: 0, +} +validation: {} +dry_run: {"skip": True, "shrinkage": [0, 0, 0]} + +model_dir: { + path_to_folder: "/home/psharma/psharma/repos/tiktorch/tests/data/CREMI_DUNet_pretrained_new.zip", + path_in_folder_to_model: "CREMI_DUNet_pretrained_new/model.py", + path_in_folder_to_state: "CREMI_DUNet_pretrained_new/state.nn", +} + +data_dir: { + raw_data_base_folder: "/home/psharma/psharma/repos/tiktorch/tests/data/Fluo_N2DH_train_raw.h5", + path_to_raw_data: "raw", + labelled_data_base_folder: "/home/psharma/psharma/repos/tiktorch/tests/data/Fluo_N2DH_train_label.h5", + path_to_labelled: "labelled", + validation_raw_base: "/home/psharma/psharma/repos/tiktorch/tests/data/Fluo_N2DH_validation_raw.h5", + validation_raw: "raw", + validation_label_base: "/home/psharma/psharma/repos/tiktorch/tests/data/Fluo_N2DH_validation_label.h5", + validation_label: "labelled" +} + +strategy_params: { + loss_fn: "BCELoss", + class_dict: { + 1: "background", + 0: "cell" + }, + + labelling_strategy: "dense", + annotation_percent: 0.6 +} diff --git a/mr_robot/strategies/strategy.py b/mr_robot/strategies/strategy.py new file mode 100644 index 00000000..eab5cfb7 --- /dev/null +++ b/mr_robot/strategies/strategy.py @@ -0,0 +1,682 @@ +import logging +import os + +import h5py +import numpy as np +import torch +import torch.nn as nn +import z5py +import random + +from scipy import sparse +from sklearn.metrics import accuracy_score, f1_score +from tensorboardX import SummaryWriter + +from tiktorch.server import TikTorchServer +from mr_robot.annotator.annotate import * +from mr_robot.utils import ( + get_confusion_matrix, + integer_to_onehot, + plot_confusion_matrix, + tile_image, + get_random_patch, + get_random_index, +) + + + + +def randomize(label, num_of_classes): + """ perform a random action on the label + Action: + -1: erase label + 0: retain current state (ignore) + x: update label to x, where x is class number + + Args: + label (np.ndarray): actual ground truth + num_of_classes (int): number of classes in the dataset + """ + + actions = [-1, 0] + [i for i in range(0, num_of_classes )] + volume = np.product(label.shape) + # print(volume, label.shape) + x = np.random.randint(0, volume) + while x: + index = get_random_index(label.shape) + label[index] = random.choice(actions) + x -= 1 + + return label + + +def user_simulator(raw_data_file, label_data_file, internal_paths, canvas_shape, num_of_classes): + """ mimic user annotation process by randomly taking a patch from the dataset and labelling it + labels can be added, updated or deleted + + Args: + raw_data_file(file pointer): pointer to folder containing raw data + lab_data_file(file pointer): pointer to folder containing labelled data + internal_paths (dictionary): paths inside base folders to raw and labelled data file + canvas_shape (tuple): shape of canvas + num_of_classes (int): number of classes in the dataset + """ + print(canvas_shape) + timesteps = np.random.randint(10, 20) + video = [] + for i in range(timesteps): + random_patch = get_random_patch(canvas_shape) + image, label = ( + raw_data_file[internal_paths["path_to_raw_data"]][random_patch], + label_data_file[internal_paths["path_to_labelled"]][random_patch], + ) + label = randomize(label, num_of_classes) + label[label>=2] =0 + video.append((image, label, random_patch)) + return video + + +class BaseStrategy: + def __init__( + self, loss_fn, class_dict, raw_data_file, labelled_data_file, paths, labelling_strategy, annotation_percent + ): + + self.patched_data = [] + self.loss_fn = loss_fn + self.strategy_metric = { + "robo_predict_accuracy": 0, + "f1_score": 0, + "robo_predict_loss": 0, + "validation_loss": 0, + "validation_accuracy": 0, + "robo_confusion_matrix": 0, + "validation_confusion_matrix": 0, + } + self.class_dict = class_dict + self.raw_data_file = raw_data_file + self.labelled_data_file = labelled_data_file + self.paths = paths + self.annotater = Annotater(annotation_percent) + self.labelling_strategy = labelling_strategy + self.validation_data_size = 0 + self.logger = logging.getLogger(__name__) + + def update_state(self, pred_output, target, block, validation_flag): + """ computes loss and accuracy corresponding to the output and target according to + the given loss function and update patch data + + Args: + predicted_output(np.ndarray) : output predicted by the model + target(np.ndarray): ground truth + block(tuple): tuple of slice objects, one per dimension, specifying the corresponding block + in the actual image + """ + # print("shape before one hot:", target.shape) + + if (validation_flag == False): + print("actual prediction:", pred_output, "label:", target) + else: + self.validation_data_size+=1 + + #pred_output[pred_output >= 0.5] = 1 + #pred_output[pred_output < 0.5] = 2 + target[target==2] = 0 + #target[target==1] = 1 + print("values in prediction:", np.unique(pred_output) ) + if (validation_flag == False): + print("after updating prediction:", pred_output) + criterion_class = getattr(nn, self.loss_fn, None) + assert criterion_class is not None, "Criterion {} not found.".format(method) + criterion_class_obj = criterion_class(reduction="sum") + + if pred_output.shape != target.shape: + print("shape mismatch!!") + target = integer_to_onehot(target) + pred_output = np.expand_dims(pred_output, axis=0) + + self.training_shape = target.shape + # print("output_shape:", target.shape, "predicion shape:", pred_output.shape) + curr_loss = criterion_class_obj( + torch.from_numpy(pred_output.astype(np.float32)), torch.from_numpy(target.astype(np.float32)) + ) + + if validation_flag is False: + self.patched_data.append((curr_loss, block)) + # print("output_shape:", target.shape) + pred_output = pred_output.flatten().round().astype(np.int32) + target = target.flatten().round().astype(np.int32) + + self.write_metric(pred_output, target, curr_loss, validation_flag) + + def write_metric(self, pred_output, target, curr_loss, validation_flag): + if validation_flag is True: + self.strategy_metric["validation_accuracy"] += accuracy_score(target, pred_output) + self.strategy_metric["validation_loss"] += curr_loss + self.strategy_metric["f1_score"] += f1_score(target, pred_output, average="weighted") + self.strategy_metric["validation_confusion_matrix"] += get_confusion_matrix( + pred_output, target, list(self.class_dict.keys()) + ) + + else: + self.strategy_metric["robo_confusion_matrix"] += get_confusion_matrix( + pred_output, target, list(self.class_dict.keys()) + ) + self.strategy_metric["robo_predict_accuracy"] += accuracy_score(target, pred_output) + self.strategy_metric["robo_predict_loss"] += curr_loss + # print(self.strategy_metric) + + def get_annotated_data(self, return_block_set): + return_data_set = [] + for block in return_block_set: + image, label = ( + self.raw_data_file[self.paths["path_to_raw_data"]][block], + self.labelled_data_file[self.paths["path_to_labelled"]][block], + ) + label[label ==2] = 0 + return_data_set.append( + (image, getattr(self.annotater, self.labelling_strategy)(label), get_coordinate(block)) + ) + + self.patched_data.clear() + return return_data_set + + def get_metrics(self): + print("metric before averaging", self.strategy_metric) + for key in self.strategy_metric.keys(): + if( "validation" in key or "f1" in key): + self.strategy_metric[key] /= self.validation_data_size + else: + self.strategy_metric[key] /= len(self.patched_data) + + if ("loss" in key): + self.strategy_metric[key] /= np.product(self.training_shape) + + #self.strategy_metric["validation_loss"] /= np.product(self.training_shape) + print("metric after averaging", self.strategy_metric) + import copy + + strategy_metric = copy.deepcopy(self.strategy_metric) + + # FIXME confision_matrix -> plotted_confusion_matrix + strategy_metric["robo_confusion_matrix"] = plot_confusion_matrix( + strategy_metric["robo_confusion_matrix"], self.class_dict + ) + strategy_metric["validation_confusion_matrix"] = plot_confusion_matrix( + strategy_metric["validation_confusion_matrix"], self.class_dict + ) + self.strategy_metric = self.strategy_metric.fromkeys(self.strategy_metric, 0) + print("reset done:") + return strategy_metric + + def get_next_batch(self): + raise NotImplementedError() + + def rearrange(self): + raise NotImplementedError() + + +class HighestLoss(BaseStrategy): + """ This strategy sorts the patches in descending order of their loss + + Args: + loss_fn (string): loss metric to be used + class_dict (dictionary): dictionary indicating the mapping between classes and their labels + """ + + def __init__( + self, loss_fn, class_dict, raw_data_file, labelled_data_file, paths, labelling_strategy, annotation_percent + ): + super().__init__( + loss_fn, class_dict, raw_data_file, labelled_data_file, paths, labelling_strategy, annotation_percent + ) + # self.patch_counter = -1 + + def rearrange(self): + """ rearranges the patches in descending order of their loss + """ + self.patched_data.sort(reverse=True) + + def get_next_batch(self, batch_size=1): + """ Feeds a batch of patches at a time to the robot in descending order of their loss + + Args: + batch_size (int): number of patches to be returned, defaults to 1 + """ + # self.patch_counter += 1 + assert len(self.patched_data) >= batch_size, "batch_size too big for current dataset" + + self.rearrange() + return_block_set = [block for loss, block in np.array(self.patched_data)[:batch_size]] + return super().get_annotated_data(return_block_set) + + +class StrategyRandom(BaseStrategy): + """ randomly selected a patch, or batch of patches + and returns them to the robot + + Args: + loss_fn (string): loss metric + class_dict (dictionary): dictionary indicating the mapping between classes and their labels + """ + + def __init__( + self, loss_fn, class_dict, raw_data_file, labelled_data_file, paths, labelling_strategy, annotation_percent + ): + super().__init__( + loss_fn, class_dict, raw_data_file, labelled_data_file, paths, labelling_strategy, annotation_percent + ) + + def rearrange(self): + pass + + def get_next_batch(self, batch_size=1): + """ returns a random set of patches + + Args: + batch_size (int): number of patches to return + """ + print("leftover dataset size:", len(self.patched_data)) + assert len(self.patched_data) >= batch_size, "batch_size too big for current dataset" + + rand_indices = np.random.randint(0, len(self.patched_data), size=batch_size) + # print(rand_indices, self.patched_data[rand_indices]) + return_block_set = [block for loss, block in np.array(self.patched_data)[rand_indices]] + return super().get_annotated_data(return_block_set) + + +class RandomSparseAnnotate(HighestLoss): + """ randomly annotate pixels in the labels. + This emulates a user who randomly annotates pixels evenly spread across the entire image + + Args: + loss_fn (string): loss metric + class_dict (dictionanry): dictionary indicating the mapping between classes and their labels + raw_data_file (h5py/z5py.File): pointer to base folder containing raw images + labelled_data_file (h5py/z5py.File): pointer to base folder containing labelled images + paths (dictionary): path inside base folders to raw images and their labels + """ + + def __init__( + self, loss_fn, class_dict, raw_data_file, labelled_data_file, paths, labelling_strategy, annotation_percent + ): + super().__init__( + loss_fn, class_dict, raw_data_file, labelled_data_file, paths, labelling_strategy, annotation_percent + ) + + def update_state(self, pred_output, target, block, training_iterations): + super().update_state(pred_output, target, block, training_iterations) + self.block_shape = target.shape + + def get_random_index(self): + random_index = [] + for i in range(len(self.block_shape)): + random_index.append(np.random.randint(0, self.block_shape[i])) + + return tuple(random_index) + + def get_next_batch(self, batch_size=1): + assert len(self.patched_data) >= batch_size, "batch_size too big for current dataset" + + return_block_set = [block for loss, block in np.array(self.patched_data)[:batch_size]] + return super().get_annotated_data(return_block_set) + return_data_set = [] + for block in return_block_set: + image, label = ( + self.raw_data_file[self.paths["path_to_raw_data"]][block], + self.labelled_data_file[self.paths["path_to_labelled"]][block], + ) + x = np.random.randint(0, np.product(self.block_shape)) + for i in range(x): + label[self.get_random_index()] = -1 + return_data_set.append((image, label, get_coordinate(block))) + + return return_data_set + + +class DenseSparseAnnotate(RandomSparseAnnotate): + """ sparsely annotate dense patches of labels. + This emulates a user who randomly annotates small patches sparsely spread across the entire image + + Args: + loss_fn (string): loss metric + class_dict (dictionanry): dictionary indicating the mapping between classes and their labels + raw_data_file (file pointer): pointer to base folder containing raw images + labelled_data_file (): pointer to base folder containing labelled images + paths (dictionary): path inside base folders to raw images and their labels + """ + + def __init__( + self, loss_fn, class_dict, raw_data_file, labelled_data_file, paths, labelling_strategy, annotation_percent + ): + super().__init__( + loss_fn, class_dict, raw_data_file, labelled_data_file, paths, labelling_strategy, annotation_percent + ) + + def get_random_patch(self): + rand_index = super().get_random_index() + patch_dimension = [] + for i in range(len(self.block_shape)): + patch_dimension.append(np.random.randint(0, self.block_shape[i] - rand_index[i])) + + block = [] + for i in range(len(patch_dimension)): + block.append(slice(rand_index[i], rand_index[i] + patch_dimension[i])) + return tuple(block) + + def get_next_batch(self, batch_size=1): + assert len(self.patched_data) >= batch_size, "batch_size too big for current dataset" + + return_block_set = [block for loss, block in np.array(self.patched_data)[:batch_size]] + return_data_set = [] + for block in return_block_set: + image, label = ( + self.raw_data_file[self.paths["path_to_raw_data"]][block], + self.labelled_data_file[self.paths["path_to_labelled"]][block], + ) + x, sparse_label = np.random.randint(0, np.product(self.block_shape)), np.full(label.shape, -1) + for i in range(x): + block = get_random_patch() + sparse_label[block] = label[block] + + return_data_set.append((image, sparse_label, get_coordinate(block))) + + return return_data_set + + +class ClassWiseLoss(BaseStrategy): + """ sorts patches according to classes with highest loss, patches with maximum + instances of this class are fed first + Assumptions: + 1. model output for multiclass claissfication will always be one hot encoded + 2. class labels are annotated using 1 based indexing + + Args: + loss_fn (string): loss function to use + class_dict (dictionary): dictionary indicating the mapping between classes and their labels + """ + + def __init__( + self, loss_fn, class_dict, raw_data_file, labelled_data_file, paths, labelling_strategy, annotation_percent + ): + super().__init__( + loss_fn, class_dict, raw_data_file, labelled_data_file, paths, labelling_strategy, annotation_percent + ) + self.num_classes = len(self.class_dict) + self.class_loss = [0] * self.num_classes + self.image_class_count = np.zeros((1, self.num_classes + 1)).astype(np.int32) + self.image_counter = 1 + self.image_id = dict() + + def update_state(self, pred_output, target, block, validation_flag): + """ + 1. calculate loss for given prediction and label + 2. map each image with a corresponding ID + + Args: + pred_output (numpy.ndarray): prediction + target (numpy.ndarray): actual label + block (tuple[slice]): tuple of slice objects, one per dimension, specifying the patch in the actual image + """ + if (validation_flag == False): + print("actual prediction:", pred_output, "label:", target) + else: + self.validation_data_size+=1 + + self.training_shape = pred_output.shape + #pred_output[pred_output >= 0.5] = 1 + #pred_output[pred_output < 0.5] = 2 + target[target==2] = 0 + #target[target==1] = 1 + print("values in prediction:", np.unique(pred_output) ) + if (validation_flag == False): + print("after updating prediction:", pred_output) + #pred_output[pred_output >= 0.5] = 2 + #pred_output[pred_output < 0.5] = 1 + criterion_class = getattr(nn, self.loss_fn, None) + assert criterion_class is not None, "Criterion {} not found.".format(method) + criterion_class_obj = criterion_class(reduction="none") + + if len(self.class_dict) > 2: + one_hot_target = integer_to_onehot(target) + else: + one_hot_target = np.expand_dims(target, axis=0) + + pred_output = np.expand_dims(pred_output, axis=0) + if(validation_flag == False): + np.vstack([self.image_class_count, np.zeros((1, self.num_classes + 1))]) + self.image_class_count[-1][0] = self.image_counter + self.image_id[self.image_counter] = block + self.image_counter += 1 + + indices = [0] * (len(target.shape)) + self.record_classes(0, target, indices) + + self.loss_matrix = criterion_class_obj( + torch.from_numpy(one_hot_target.astype(np.float32)), torch.from_numpy(pred_output.astype(np.float32)) + ) + if(validation_flag == False): + indices = [0] * (len(self.loss_matrix.shape)) + self.record_class_loss(2, indices, self.loss_matrix.shape) + + curr_total_loss = torch.sum(self.loss_matrix) + + if len(self.class_dict) > 2: + target = one_hot_target + super().write_metric( + pred_output.flatten().round().astype(np.int32), + target.flatten().round().astype(np.int32), + curr_total_loss, + validation_flag, + ) + + def record_classes(self, curr_dim, label, indices): + """ record the number of occurences of each class in a patch + + Args: + curr_dim (int): current dimension to index + label (numpy.ndarray): target label + indices (list): list of variables each representing the current state of index for the n dimension + """ + + if curr_dim + 1 == len(label.shape): + for i in range(label.shape[curr_dim]): + indices[curr_dim] = i + self.image_class_count[-1][label[tuple(indices)] + 1] += 1 + return + + for i in range(label.shape[curr_dim]): + indices[curr_dim] = i + self.record_classes(curr_dim + 1, label, indices) + + def record_class_loss(self, curr_dim, indices, output_shape): + """ record the loss class wise for the given image + + Args: + curr_dim (int): current dimension to index + indices (list): list of variables each representing the current state of index for the n dimension + output_shape (tuple): shape of loss matrix + """ + + if curr_dim + 1 == len(output_shape): + for i in range(output_shape[curr_dim]): + indices[curr_dim] = i + index = indices + index[1] = slice(0, output_shape[1]) + index = tuple(index) + # print("record_class_loss", index) + curr_losses = self.loss_matrix[index].numpy().tolist() + # print(curr_losses) + for j in range(len(curr_losses)): + self.class_loss[j] += curr_losses[j] + + return + + for i in range(output_shape[curr_dim]): + indices[curr_dim] = i + self.record_class_loss(curr_dim + 1, indices, output_shape) + + def rearrange(self): + """ rearrange the rows of the image_class_count matrix wrt to the class (column) with the highest loss + """ + self.image_class_count[self.image_class_count[:, np.argmax(self.class_loss) + 1].argsort()[::-1]] + + def get_next_batch(self, batch_size=1): + self.rearrange() + return_block_set = [ + self.image_id[image_number] + for image_number in [image_id for image_id in self.image_class_count[:batch_size, 0]] + ] + self.class_loss = [0] * self.num_classes + self.image_class_count = np.zeros((1, self.num_classes + 1)).astype(np.int32) + self.image_counter = 1 + self.image_id.clear() + return super().get_annotated_data(return_block_set) + + +class VideoLabelling(BaseStrategy): + """emulates user who randomly annotates/de-annotates/updates various patches at different timestamps + The strategy expects a series of operations, which are then performed and added to the canvas (sparse matrix) + The canvas state at each time step is added to the training data + """ + + def __init__( + self, + loss_fn, + class_dict, + raw_data_file, + labelled_data_file, + paths, + labelling_strategy, + annotation_percent, + training_shape, + ): + super().__init__( + loss_fn, class_dict, raw_data_file, labelled_data_file, paths, labelling_strategy, annotation_percent + ) + dataset_shape = list(self.raw_data_file[self.paths["path_to_raw_data"]].shape) + dataset_shape[0] = 1 + self.canvas_shape = tuple(dataset_shape) + self.video = [] + self.training_shape = training_shape + self.process_video() + + def rearrange(self): + pass + + def get_next_batch(self, batch_size=None): + if not self.video: + raise ValueError("no more annotations in the video available!") + return self.video.pop(0) + + def update_canvas(self, label, block, curr_dim, index_list): + """ update the canvas from the received labels + + UPDATE RULE: + if label is: + -1: erase current label from canvas + 0: retain previous label on canvas + x: update label on canvas, where x is new class label + + Args: + label (np.ndarray): label received from user + block (tuple): tuple specifying the block in the canvas + curr_dim (int): current dimension in the recursion step + index_list (list): list of indices specifying the current index of each dimension + """ + + if curr_dim + 1 == len(label.shape): + for i in range(label.shape[curr_dim]): + index_list[curr_dim] = i + index = tuple(index_list) + if label[index] == -1: + self.canvas[block][index] = 0 + elif label[index] != 0: + self.canvas[block][index] = label[index] + + return + + for i in range(label.shape[curr_dim]): + index_list[curr_dim] = i + self.update_canvas(label, block, curr_dim + 1, index_list) + + def paint(self, label, block): + """ paints the newly received label onto the canvas + """ + + index_list = [0] * len(label.shape) + self.update_canvas(label, block, 0, index_list) + + def process_video(self): + """ take the series of annotations performed by the user and resize them to trainable shape + """ + + self.canvas = np.zeros((self.canvas_shape)) + user_annotations = user_simulator( + self.raw_data_file, self.labelled_data_file, self.paths, self.canvas_shape, len(self.class_dict) + ) + + for image, label, block in user_annotations: + self.paint(label, block) + base_coordinate = get_coordinate(block) + + # resize image and label by zero padding if image size is less than training shape + image_shape = tuple([max(image.shape[i], self.training_shape[i]) for i in range(len(image.shape))]) + image.resize(image_shape, refcheck=False) + label.resize(image_shape, refcheck=False) + canvas_tiles = tile_image(image.shape, self.training_shape) + + # iterate over the tiled image and add the data to the list 'video'. + # Each timestep is added as a list + curr_timestep = [] + for tile in canvas_tiles: + local_id = get_coordinate(tile) + global_id = tuple([base_coordinate[i] + local_id[i] for i in range(len(base_coordinate))]) + curr_timestep.append((image[tile], label[tile], global_id)) + + self.video.append(curr_timestep) + + +class StrategyAbstract: + """ abstract strategy which is a combination of one or more basic strategies + + Args: + *args [(Any, iterations)]: list of strategy objects to applied in given order, for {iterations} number of times + """ + + def __init__(self, tikserver, *args): + self.strategies = args + self.index = 0 + self.tikserver = tikserver + # self.tiktorch_config = {"training": {"num_iterations_done": 0}} + self.patches_added = 0 + self.patches_max = self.strategies[0][1] + + def update_strategy(self, batch_size=1): + if self.index >= len(self.strategies) - 1: + return + + elif self.patches_added >= self.patches_max: + self.index += 1 + self.patches_added = 0 + self.patches_max = self.strategies[self.index][1] + print("curr strategy:", str(self.strategies[self.index][0])) + + else: + self.patches_added += batch_size + + def update_state(self, pred_output, target, block, validation_flag): + self.strategies[self.index][0].update_state(pred_output, target, block, validation_flag) + + def rearrange(self): + self.strategies[self.index][0].rearrange() + + def get_metrics(self): + return self.strategies[self.index][0].get_metrics() + + def get_next_batch(self, batch_size=1): + new_batch = self.strategies[self.index][0].get_next_batch(batch_size) + #assert len(new_batch) == batch_size + self.update_strategy(batch_size) + return new_batch + + diff --git a/mr_robot/utils.py b/mr_robot/utils.py new file mode 100644 index 00000000..a2c21714 --- /dev/null +++ b/mr_robot/utils.py @@ -0,0 +1,190 @@ +# utility functions for the robot +import numpy as np +from scipy.ndimage import convolve +import seaborn as sn +import pandas as pd +import matplotlib.pyplot as plt + +from sklearn.metrics import confusion_matrix + + +def read_config(path: str) -> dict: + with open(path, mode="r") as f: + conf = yaml.load(f) + + return conf + + +# ref: https://github.com/constantinpape/vis_tools/blob/master/vis_tools/edges.py#L5 +def make_edges3d(segmentation): + """ Make 3d edge volume from 3d segmentation + """ + # NOTE we add one here to make sure that we don't have zero in the segmentation + gz = convolve(segmentation + 1, np.array([-1.0, 0.0, 1.0]).reshape(3, 1, 1)) + gy = convolve(segmentation + 1, np.array([-1.0, 0.0, 1.0]).reshape(1, 3, 1)) + gx = convolve(segmentation + 1, np.array([-1.0, 0.0, 1.0]).reshape(1, 1, 3)) + return (gx ** 2 + gy ** 2 + gz ** 2) > 0 + + +n = 0 +block_list, idx_list, visited = [], [], {} + + +def recursive_chop(dim_number, arr_shape, block_shape): + global block_list, idx_list, visited + + if dim_number >= n: + return + + for i in range(0, arr_shape[dim_number], block_shape[dim_number]): + idx_list[dim_number] = i + recursive_chop(dim_number + 1, arr_shape, block_shape) + slice_list, visited_key = [], [] + + for j in range(n): + visited_key.append(idx_list[j]) + if idx_list[j] + block_shape[j] > arr_shape[j]: + slice_list.append(slice(arr_shape[j] - block_shape[j], arr_shape[j])) + else: + slice_list.append(slice(idx_list[j], idx_list[j] + block_shape[j])) + + visited_key = tuple(visited_key) + if visited.get(visited_key) == None: + visited[visited_key] = 1 + block_list.append(tuple(slice_list)) + + idx_list[dim_number] = 0 + + +def tile_image(arr_shape, block_shape): + """ + chops of blocks of given size from an array + + Args: + arr_shape(tuple): size of input array (ndarray) + block_shape (tuple): size of block to cut into (ndarray) + + Return type: list(tuple(slice()))- a list of tuples, one per block where each tuple has + n slice objects, one per dimension (n: number of dimensions) + """ + + assert len(arr_shape) == len(block_shape), "block shape not compatible with array shape" + for i in range(len(arr_shape)): + assert arr_shape[i] >= block_shape[i], "block shape not compatible with array shape" + + global n, idx_list, visited + n = len(arr_shape) + block_list.clear(), visited.clear() + idx_list = [0 for i in range(n)] + + recursive_chop(0, arr_shape, block_shape) + return block_list + + +def get_confusion_matrix(predictions, actual_labels, classes): + if predictions.shape != actual_labels.shape: + print("SHAPE MISMATCH", predictions.shape, actual_labels.shape) + return np.zeros(shape=(len(classes), len(classes))) + c_mat_arr = confusion_matrix(actual_labels, predictions, labels=classes) + c_mat_n = c_mat_arr / c_mat_arr.astype(np.float).sum(axis=1, keepdims=True) + return np.nan_to_num(c_mat_n) + + +def plot_confusion_matrix(c_mat_n, cls_dict): + pd_cm_n = pd.DataFrame( + c_mat_n, index=[str(i) for i in cls_dict.values()], columns=[str(i) for i in cls_dict.values()] + ) + + sn_plot_n = sn.heatmap(pd_cm_n, annot=True) + return sn_plot_n.figure + + +def integer_to_onehot(integer_maps): + # return np.stack( + # [integer_maps == integer for integer in range(integer_maps.min(), integer_maps.max() + 1)], axis=1 + # ).astype(np.uint8) + return np.stack([integer_maps == integer for integer in range(0, 14)], axis=1).astype(np.uint8) + + +def onehot_preds_to_integer(one_hot_preds): + return np.argmax(one_hot_preds, axis=0) + + +def get_coordinate(block): + """ return the starting co-ordinate of a block + + Args: + block(tuple): tuple of slice objects, one per dimension + """ + + coordinate = [] + for slice_ in block: + coordinate.append(slice_.start) + + return tuple(coordinate) + + +def get_random_index(canvas_shape): + random_index = [] + for i in range(len(canvas_shape)): + random_index.append(np.random.randint(0, canvas_shape[i])) + + return tuple(random_index) + + +def get_random_patch(canvas_shape): + rand_index = get_random_index(canvas_shape) + patch_dimension = [] + for i in range(len(canvas_shape)): + patch_dimension.append(np.random.randint(1, canvas_shape[i] - rand_index[i] + 1)) + + block = [] + for i in range(len(patch_dimension)): + block.append(slice(rand_index[i], rand_index[i] + patch_dimension[i])) + return tuple(block) + + +def make_plot(stats): + hack_array = [0]*len(stats["training_loss"]) + plt.figure(1) + fig1, ax1 = plt.subplots() + ax1.set_xlabel("number_of_patches") + ax1.set_ylabel("losses") + print("training_loss:", stats["training_loss"]) + print("robo_predict_loss:", stats["robo_predict_loss"]) + ax1.plot(stats["number_of_patches"], stats["training_loss"], color="yellow") + ax1.plot(stats["number_of_patches"], stats["robo_predict_loss"], color="tab:blue") + #ax1.plot(stats["number_of_patches"], stats["validation_loss"], color="tab:green") + ax1.tick_params(axis="y", labelcolor="tab:red") + + ax2 = ax1.twiny() # instantiate a second axes that shares the same x-axis + + ax2.set_xlabel("training iterations") # we already handled the x-label with ax1 + ax2.plot(stats["training_iterations"], hack_array, color="white") + ax2.tick_params(axis="y", labelcolor="tab:blue") + ax1.legend(["training loss", "robo prediction loss"], loc="upper right") + fig1.tight_layout() + # plt.savefig("abc") + + hack_array = [0]*len(stats["validation_accuracy"]) + plt.figure(2) + fig2, ax3 = plt.subplots() + ax3.set_xlabel("number_of_patches") + ax3.set_ylabel("accuracies") + ax3.plot(stats["number_of_patches"], stats["training_accuracy"], color="orange") + #ax3.plot(stats["number_of_patches"], stats["validation_accuracy"], color="tab:red") + ax3.plot(stats["number_of_patches"], stats["robo_predict_accuracy"], color="tab:blue") + #ax3.plot(stats["number_of_patches"], stats["f1_score"], color="tab:green") + ax3.tick_params(axis="y", labelcolor="tab:red") + + ax4 = ax3.twiny() # instantiate a second axes that shares the same x-axis + + ax4.set_xlabel("training iterations") # we already handled the x-label with ax1 + ax4.plot(stats["training_iterations"], hack_array) + ax4.tick_params(axis="y", labelcolor="tab:blue") + ax3.legend(["training", "robo prediction"], loc="lower right") + fig2.tight_layout() + # plt.savefig("def") + plt.close(fig1) + + return (fig1, fig2) diff --git a/tests/robot/test_highestloss.py b/tests/robot/test_highestloss.py new file mode 100644 index 00000000..c4611e39 --- /dev/null +++ b/tests/robot/test_highestloss.py @@ -0,0 +1,103 @@ +import os +import pytest +import numpy as np +import random +import matplotlib.pyplot as plt + +from mr_robot.mr_robot import MrRobot +from mr_robot.strategies.strategy import HighestLoss, ClassWiseLoss, StrategyRandom, StrategyAbstract +from mr_robot.utils import tile_image, get_confusion_matrix, make_plot +from tiktorch.server import TikTorchServer + +os.environ["CUDA_VISIBLE_DEVICES"] = "7" + + +def test_tile_image(): + # when image dim are multiple of patch size + tiled_indices = tile_image((2, 2, 2), (2, 2, 1)) + assert tiled_indices == [ + (slice(0, 2, None), slice(0, 2, None), slice(0, 1, None)), + (slice(0, 2, None), slice(0, 2, None), slice(1, 2, None)), + ] + + tiled_indices = tile_image((2, 2, 2), (1, 1, 1)) + assert tiled_indices == [ + (slice(0, 1, None), slice(0, 1, None), slice(0, 1, None)), + (slice(0, 1, None), slice(0, 1, None), slice(1, 2, None)), + (slice(0, 1, None), slice(1, 2, None), slice(0, 1, None)), + (slice(0, 1, None), slice(1, 2, None), slice(1, 2, None)), + (slice(1, 2, None), slice(0, 1, None), slice(0, 1, None)), + (slice(1, 2, None), slice(0, 1, None), slice(1, 2, None)), + (slice(1, 2, None), slice(1, 2, None), slice(0, 1, None)), + (slice(1, 2, None), slice(1, 2, None), slice(1, 2, None)), + ] + + # when image dimension are not multiple of patch size + tiled_indices = tile_image((5, 5), (3, 3)) + assert tiled_indices == [ + (slice(0, 3, None), slice(0, 3, None)), + (slice(0, 3, None), slice(2, 5, None)), + (slice(2, 5, None), slice(0, 3, None)), + (slice(2, 5, None), slice(2, 5, None)), + ] + + tiled_indices = tile_image((10, 2, 2, 2), (5, 2, 1, 1)) + assert tiled_indices == [ + (slice(0, 5, None), slice(0, 2, None), slice(0, 1, None), slice(0, 1, None)), + (slice(0, 5, None), slice(0, 2, None), slice(0, 1, None), slice(1, 2, None)), + (slice(0, 5, None), slice(0, 2, None), slice(1, 2, None), slice(0, 1, None)), + (slice(0, 5, None), slice(0, 2, None), slice(1, 2, None), slice(1, 2, None)), + (slice(5, 10, None), slice(0, 2, None), slice(0, 1, None), slice(0, 1, None)), + (slice(5, 10, None), slice(0, 2, None), slice(0, 1, None), slice(1, 2, None)), + (slice(5, 10, None), slice(0, 2, None), slice(1, 2, None), slice(0, 1, None)), + (slice(5, 10, None), slice(0, 2, None), slice(1, 2, None), slice(1, 2, None)), + ] + + # when image too small for the patch + with pytest.raises(AssertionError): + tiled_indices = tile_image((1, 48, 48), (1, 64, 32)) + tiled_indices = tile_image((64, 64), (2, 1, 1)) + + +def test_MrRobot(): + + robo = MrRobot("/home/psharma/psharma/repos/tiktorch/mr_robot/robot_config.yml", "strategyabstract", ["gpu:3"]) + assert isinstance(robo, MrRobot) + assert isinstance(robo.new_server, TikTorchServer) + + robo._load_model() + robo._run() + + +def test_get_confusion_matrix(): + predicted = np.array([1, 2, 3]) + actual = np.array([2, 1, 3]) + classes = [0, 1, 2, 3] + + expected = [[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]] + + res = get_confusion_matrix(predicted, actual, classes) + assert expected == res.tolist() + + +def test_make_plot(): + stats = { + "training_loss": [], + "robo_predict_accuracy": [], + "robo_predict_f1_score": [], + "robo_predict_loss": [], + "confusion_matrix": [], + "validation_loss": [], + "validation_accuracy": [], + "training_iterations": [], + "number_of_patches": [], + } + + for key in stats.keys(): + stats[key] = random.sample(range(0, 100), 10) + stats["training_iterations"] = [i for i in range(0, 100, 10)] + stats["number_of_patches"] = [i for i in range(0, 80, 8)] + + make_plot(stats) + # plt.savefig() + # plt.show() diff --git a/tests/robot/test_mr_robot.py b/tests/robot/test_mr_robot.py new file mode 100644 index 00000000..18240526 --- /dev/null +++ b/tests/robot/test_mr_robot.py @@ -0,0 +1,104 @@ +import os +import random + +import numpy as np + +import matplotlib.pyplot as plt +import pytest +from mr_robot.mr_robot import MrRobot +from mr_robot.strategies.strategy import ClassWiseLoss, HighestLoss, StrategyAbstract, StrategyRandom +from mr_robot.utils import get_confusion_matrix, make_plot, tile_image +from tiktorch.server import TikTorchServer + +os.environ["CUDA_VISIBLE_DEVICES"] = "7" + + +def test_tile_image(): + # when image dim are multiple of patch size + tiled_indices = tile_image((2, 2, 2), (2, 2, 1)) + assert tiled_indices == [ + (slice(0, 2, None), slice(0, 2, None), slice(0, 1, None)), + (slice(0, 2, None), slice(0, 2, None), slice(1, 2, None)), + ] + + tiled_indices = tile_image((2, 2, 2), (1, 1, 1)) + assert tiled_indices == [ + (slice(0, 1, None), slice(0, 1, None), slice(0, 1, None)), + (slice(0, 1, None), slice(0, 1, None), slice(1, 2, None)), + (slice(0, 1, None), slice(1, 2, None), slice(0, 1, None)), + (slice(0, 1, None), slice(1, 2, None), slice(1, 2, None)), + (slice(1, 2, None), slice(0, 1, None), slice(0, 1, None)), + (slice(1, 2, None), slice(0, 1, None), slice(1, 2, None)), + (slice(1, 2, None), slice(1, 2, None), slice(0, 1, None)), + (slice(1, 2, None), slice(1, 2, None), slice(1, 2, None)), + ] + + # when image dimension are not multiple of patch size + tiled_indices = tile_image((5, 5), (3, 3)) + assert tiled_indices == [ + (slice(0, 3, None), slice(0, 3, None)), + (slice(0, 3, None), slice(2, 5, None)), + (slice(2, 5, None), slice(0, 3, None)), + (slice(2, 5, None), slice(2, 5, None)), + ] + + tiled_indices = tile_image((10, 2, 2, 2), (5, 2, 1, 1)) + assert tiled_indices == [ + (slice(0, 5, None), slice(0, 2, None), slice(0, 1, None), slice(0, 1, None)), + (slice(0, 5, None), slice(0, 2, None), slice(0, 1, None), slice(1, 2, None)), + (slice(0, 5, None), slice(0, 2, None), slice(1, 2, None), slice(0, 1, None)), + (slice(0, 5, None), slice(0, 2, None), slice(1, 2, None), slice(1, 2, None)), + (slice(5, 10, None), slice(0, 2, None), slice(0, 1, None), slice(0, 1, None)), + (slice(5, 10, None), slice(0, 2, None), slice(0, 1, None), slice(1, 2, None)), + (slice(5, 10, None), slice(0, 2, None), slice(1, 2, None), slice(0, 1, None)), + (slice(5, 10, None), slice(0, 2, None), slice(1, 2, None), slice(1, 2, None)), + ] + + # when image too small for the patch + with pytest.raises(AssertionError): + tiled_indices = tile_image((1, 48, 48), (1, 64, 32)) + tiled_indices = tile_image((64, 64), (2, 1, 1)) + + +def test_MrRobot(): + + robo = MrRobot("/home/psharma/psharma/repos/tiktorch/mr_robot/robot_config.yml", [StrategyRandom, HighestLoss], ["gpu:3"]) + assert isinstance(robo, MrRobot) + assert isinstance(robo.new_server, TikTorchServer) + + robo._load_model() + robo._run() + + +def test_get_confusion_matrix(): + predicted = np.array([1, 2, 3]) + actual = np.array([2, 1, 3]) + classes = [0, 1, 2, 3] + + expected = [[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]] + + res = get_confusion_matrix(predicted, actual, classes) + assert expected == res.tolist() + + +def test_make_plot(): + stats = { + "training_loss": [], + "robo_predict_accuracy": [], + "robo_predict_f1_score": [], + "robo_predict_loss": [], + "confusion_matrix": [], + "validation_loss": [], + "validation_accuracy": [], + "training_iterations": [], + "number_of_patches": [], + } + + for key in stats.keys(): + stats[key] = random.sample(range(0, 100), 10) + stats["training_iterations"] = [i for i in range(0, 100, 10)] + stats["number_of_patches"] = [i for i in range(0, 80, 8)] + + make_plot(stats) + # plt.savefig() + # plt.show() diff --git a/tests/test_server/test_handler/test_training.py b/tests/test_server/test_handler/test_training.py index 04fbdb68..19643c0d 100644 --- a/tests/test_server/test_handler/test_training.py +++ b/tests/test_server/test_handler/test_training.py @@ -84,4 +84,30 @@ def test_training_in_proc(tiny_model_2d, log_queue): client.shutdown() +def test_train_for(tiny_model_2d): + config = tiny_model_2d["config"] + config["num_iterations_per_update"] = 10 + in_channels = config["input_channels"] + model = TinyConvNet2d(in_channels=in_channels) + training = TrainingProcess(config=config, model=model) + try: + data = TikTensorBatch( + [ + TikTensor(torch.zeros(in_channels, 15, 15), ((1,), (1,))), + TikTensor(torch.ones(in_channels, 9, 9), ((2,), (2,))), + ] + ) + labels = TikTensorBatch( + [ + TikTensor(torch.ones(in_channels, 15, 15, dtype=torch.uint8), ((1,), (1,))), + TikTensor(torch.full((in_channels, 9, 9), 2, dtype=torch.uint8), ((2,), (2,))), + ] + ) + training.set_devices([torch.device("cpu")]) + training.update_dataset("training", data, labels) + res = training.train_for(10).result() + finally: + training.shutdown() + + # def test_validation(tiny_model_2d): diff --git a/tiktorch/rpc_interface.py b/tiktorch/rpc_interface.py index 83326eeb..065e0c4a 100644 --- a/tiktorch/rpc_interface.py +++ b/tiktorch/rpc_interface.py @@ -70,3 +70,7 @@ def log(self, msg: str) -> None: @exposed def get_model_state(self) -> ModelState: raise NotImplementedError + + @exposed + def train_for(self, num_iterations: int) -> RPCFuture: + raise NotImplementedError diff --git a/tiktorch/server/base.py b/tiktorch/server/base.py index d4074678..03beff44 100644 --- a/tiktorch/server/base.py +++ b/tiktorch/server/base.py @@ -230,6 +230,9 @@ def forward(self, image: NDArray) -> RPCFuture[NDArray]: def update_training_data(self, data: NDArrayBatch, labels: NDArrayBatch) -> None: self.handler.update_training_data(TikTensorBatch(data), TikTensorBatch(labels)) + def train_for(self, num_iterations: int) -> RPCFuture: + return self.handler.train_for(num_iterations) + def update_validation_data(self, data: NDArrayBatch, labels: NDArrayBatch) -> None: return self.handler.update_validation_data(TikTensorBatch(data), TikTensorBatch(labels)) diff --git a/tiktorch/server/handler/datasets.py b/tiktorch/server/handler/datasets.py index e969356e..5091fe69 100644 --- a/tiktorch/server/handler/datasets.py +++ b/tiktorch/server/handler/datasets.py @@ -54,7 +54,7 @@ def update(self, images: TikTensorBatch, labels: TikTensorBatch) -> None: # self.data.update(zip(keys, values)) # update update counts for image, label in zip(images, labels): - assert image.id == label.id + assert image.id == label.id, (image.id, label.id) assert image.id is not None key = image.id self.update_counts[key] = self.update_counts.get(key, 0) + 1 diff --git a/tiktorch/server/handler/handler.py b/tiktorch/server/handler/handler.py index f9458f96..86077e0d 100644 --- a/tiktorch/server/handler/handler.py +++ b/tiktorch/server/handler/handler.py @@ -77,6 +77,10 @@ def resume_training(self) -> None: def pause_training(self) -> None: raise NotImplementedError + @exposed + def train_for(self, num_iterations: int) -> RPCFuture: + raise NotImplementedError + @exposed def update_training_data(self, data: TikTensorBatch, labels: TikTensorBatch) -> None: raise NotImplementedError @@ -503,6 +507,9 @@ def resume_training(self) -> None: if not self.training_devices: self.new_device_names.put("whatever_just_update_idle_because_this_is_not_a_tuple_nor_None") + def train_for(self, num_iterations: int) -> RPCFuture: + return self.training.train_for(num_iterations) + def pause_training(self) -> None: self.training.pause_training() diff --git a/tiktorch/server/handler/training.py b/tiktorch/server/handler/training.py index 2a2dcff3..9dd178c2 100644 --- a/tiktorch/server/handler/training.py +++ b/tiktorch/server/handler/training.py @@ -126,6 +126,10 @@ def pause_training(self) -> None: def get_idle(self) -> bool: raise NotImplementedError + @exposed + def train_for(self, num_iterations: int) -> RPCFuture: + raise NotImplementedError + @exposed def update_dataset(self, name: str, data: TikTensorBatch, labels: TikTensorBatch): raise NotImplementedError @@ -173,6 +177,7 @@ def __init__(self, config: dict, model: torch.nn.Module, optimizer_state: bytes self.logger.info("started") self.shutdown_event = threading.Event() self.idle = False + self._train_for = {} # Iteration number -> Future self.common_model = model self.model = copy.deepcopy(model) @@ -231,7 +236,9 @@ def __init__(self, config: dict, model: torch.nn.Module, optimizer_state: bytes log_scalars_every=(1, "iteration"), log_images_every=(1, "epoch"), ) - self.trainer.register_callback(self.end_of_training_iteration, trigger="end_of_training_iteration") + self.trainer.register_callback( + self.begin_of_training_iteration, trigger=self.trainer.callbacks.BEGIN_OF_TRAINING_ITERATION + ) # FIXME: End of training iteration not called by inferno trainer self.trainer.register_callback(self.end_of_validation_iteration, trigger="end_of_validation_iteration") self.trainer._iteration_count = self.config[TRAINING].get(NUM_ITERATIONS_DONE, 0) @@ -250,8 +257,11 @@ def __init__(self, config: dict, model: torch.nn.Module, optimizer_state: bytes def end_of_validation_iteration(self, trigger): pass # todo: return validation - def end_of_training_iteration(self, iteration_num, trigger): - pass + def begin_of_training_iteration(self, iteration_num, trigger): + end_iteration = self.trainer.iteration_count + 1 + res = self._train_for.pop(end_iteration, None) + if res: + res.set_result(None) def create_trainer_config(self) -> Dict: trainer_config = {} @@ -327,7 +337,8 @@ def _training_worker(self): try: self.trainer.fit() except Exception as e: - self.logger.debug(e, exc_info=True) + self.trainer.next_iteration() # XXX(m-novikov) + self.logger.debug("Exception during trainer fit", exc_info=True) self.logger.info( "Break training at %d/%d iterations", @@ -410,6 +421,14 @@ def shutdown(self) -> Shutdown: self.logger.debug("Shutdown complete") return Shutdown() + def train_for(self, num_iterations: int) -> RPCFuture: + res = RPCFuture() + self.config[TRAINING][NUM_ITERATIONS_MAX] += num_iterations + self._train_for[self.config[TRAINING][NUM_ITERATIONS_MAX]] = res + self.update_trainer_event.set() + self._pause_event.clear() + return res + def resume_training(self) -> None: self.logger.warning("RESUME") self._pause_event.clear()