diff --git a/joss/paper.pdf b/joss/paper.pdf index ee8e792..df425e8 100644 Binary files a/joss/paper.pdf and b/joss/paper.pdf differ diff --git a/tfaip/predict/predictorbase.py b/tfaip/predict/predictorbase.py index afe0694..f07893c 100644 --- a/tfaip/predict/predictorbase.py +++ b/tfaip/predict/predictorbase.py @@ -30,6 +30,7 @@ from tfaip.data.pipeline.datagenerator import DataGenerator from tfaip.data.pipeline.datapipeline import DataPipeline from tfaip.device.device_config import DeviceConfig, distribute_strategy +from tfaip.predict.raw_predictor import RawPredictor from tfaip.trainer.callbacks.benchmark_callback import BenchmarkResults from tfaip.util.multiprocessing.parallelmap import tqdm_wrapper from tfaip.util.profiling import MeasureTime @@ -95,6 +96,27 @@ def _load_model(self, model: Union[str, keras.Model]): return model + def raw(self) -> RawPredictor: + """Create a raw predictor from this predictor that allows to asynchronly predict raw samples. + + Usage: + + Either call + + with predictor.raw() as raw_pred: + raw_pred(sample1) + raw_pred(sample2) + ... + + or + + raw_pred = predictor.raw().__enter__() + raw_pred(sample1) + raw_pred(sample2) + + """ + return RawPredictor(self) + def predict(self, params: DataGeneratorParams) -> Iterable[Sample]: """ Predict a DataGenerator based on its params. @@ -134,6 +156,7 @@ def predict_raw(self, inputs: Iterable[Any], *, size=None) -> Iterable[Sample]: - predict_pipeline - predict_dataset - predict + - raw """ if size is None: # Automatically compute the size (number of samples) @@ -159,10 +182,11 @@ def create_data_generator(self) -> DataGenerator: return RawGenerator(mode=self.mode, params=self.generator_params) pipeline = RawInputsPipeline( - DataPipelineParams(mode=PipelineMode.PREDICTION), + self.params.pipeline, self._data, DataGeneratorParams(), ) + return self.predict_pipeline(pipeline) def predict_pipeline(self, pipeline: DataPipeline) -> Iterable[Sample]: diff --git a/tfaip/predict/raw_predictor.py b/tfaip/predict/raw_predictor.py new file mode 100644 index 0000000..b6553f2 --- /dev/null +++ b/tfaip/predict/raw_predictor.py @@ -0,0 +1,68 @@ +import logging +from contextlib import ExitStack +from queue import Queue +from threading import Thread +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from tfaip.predict.predictorbase import PredictorBase + +logger = logging.getLogger(__name__) + + +class StopSignal: + ... + + +class RawPredictorThread(Thread): + def __init__(self, in_queue: Queue, out_queue: Queue, predictor: "PredictorBase"): + super().__init__(daemon=True) + self.predictor = predictor + self.in_queue = in_queue + self.out_queue = out_queue + + def run(self) -> None: + def generator(): + while True: + sample = self.in_queue.get() + if isinstance(sample, StopSignal): + break + yield sample + + for sample in self.predictor.predict_raw(generator(), size=5): + self.out_queue.put(sample) + + +class RawPredictorCaller: + def __init__(self, in_queue, out_queue): + self.in_queue = in_queue + self.out_queue = out_queue + + def __call__(self, sample): + self.in_queue.put(sample) + return self.out_queue.get() + + +class RawPredictor: + """Utility class to allow for raw prediction but keeping the internal queues open.""" + + def __init__(self, predictor: "PredictorBase"): + self.predictor = predictor + if self.predictor.params.pipeline.batch_size != 1: + logger.warning(f"Raw prediction via threading requires batch size == 1. Automatically setting.") + self.predictor.params.pipeline.batch_size = 1 + if not self.predictor.params.silent: + logger.warning(f"Consider setting predictor to silent by setting predictor_params.silent = True.") + + self.exit_stack = ExitStack() + self.in_queue = Queue(10) + self.out_queue = Queue(10) + self.thread = RawPredictorThread(self.in_queue, self.out_queue, predictor) + + def __enter__(self): + self.thread.start() + return RawPredictorCaller(self.in_queue, self.out_queue) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.in_queue.put(StopSignal()) + self.thread.join() diff --git a/tfaip/version.py b/tfaip/version.py index 776a2ac..0a5634c 100644 --- a/tfaip/version.py +++ b/tfaip/version.py @@ -15,4 +15,4 @@ # You should have received a copy of the GNU General Public License along with # tfaip. If not, see http://www.gnu.org/licenses/. # ============================================================================== -__version__ = "1.2.1" +__version__ = "1.2.2"