Skip to content
This repository has been archived by the owner on Apr 4, 2023. It is now read-only.

Commit

Permalink
Automatic merge of updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ChWick committed Jun 22, 2021
1 parent f5cb7d8 commit e6bd554
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 2 deletions.
Binary file modified joss/paper.pdf
Binary file not shown.
26 changes: 25 additions & 1 deletion tfaip/predict/predictorbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from tfaip.data.pipeline.datagenerator import DataGenerator
from tfaip.data.pipeline.datapipeline import DataPipeline
from tfaip.device.device_config import DeviceConfig, distribute_strategy
from tfaip.predict.raw_predictor import RawPredictor
from tfaip.trainer.callbacks.benchmark_callback import BenchmarkResults
from tfaip.util.multiprocessing.parallelmap import tqdm_wrapper
from tfaip.util.profiling import MeasureTime
Expand Down Expand Up @@ -95,6 +96,27 @@ def _load_model(self, model: Union[str, keras.Model]):

return model

def raw(self) -> RawPredictor:
"""Create a raw predictor from this predictor that allows to asynchronly predict raw samples.
Usage:
Either call
with predictor.raw() as raw_pred:
raw_pred(sample1)
raw_pred(sample2)
...
or
raw_pred = predictor.raw().__enter__()
raw_pred(sample1)
raw_pred(sample2)
"""
return RawPredictor(self)

def predict(self, params: DataGeneratorParams) -> Iterable[Sample]:
"""
Predict a DataGenerator based on its params.
Expand Down Expand Up @@ -134,6 +156,7 @@ def predict_raw(self, inputs: Iterable[Any], *, size=None) -> Iterable[Sample]:
- predict_pipeline
- predict_dataset
- predict
- raw
"""
if size is None:
# Automatically compute the size (number of samples)
Expand All @@ -159,10 +182,11 @@ def create_data_generator(self) -> DataGenerator:
return RawGenerator(mode=self.mode, params=self.generator_params)

pipeline = RawInputsPipeline(
DataPipelineParams(mode=PipelineMode.PREDICTION),
self.params.pipeline,
self._data,
DataGeneratorParams(),
)

return self.predict_pipeline(pipeline)

def predict_pipeline(self, pipeline: DataPipeline) -> Iterable[Sample]:
Expand Down
68 changes: 68 additions & 0 deletions tfaip/predict/raw_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import logging
from contextlib import ExitStack
from queue import Queue
from threading import Thread
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from tfaip.predict.predictorbase import PredictorBase

logger = logging.getLogger(__name__)


class StopSignal:
...


class RawPredictorThread(Thread):
def __init__(self, in_queue: Queue, out_queue: Queue, predictor: "PredictorBase"):
super().__init__(daemon=True)
self.predictor = predictor
self.in_queue = in_queue
self.out_queue = out_queue

def run(self) -> None:
def generator():
while True:
sample = self.in_queue.get()
if isinstance(sample, StopSignal):
break
yield sample

for sample in self.predictor.predict_raw(generator(), size=5):
self.out_queue.put(sample)


class RawPredictorCaller:
def __init__(self, in_queue, out_queue):
self.in_queue = in_queue
self.out_queue = out_queue

def __call__(self, sample):
self.in_queue.put(sample)
return self.out_queue.get()


class RawPredictor:
"""Utility class to allow for raw prediction but keeping the internal queues open."""

def __init__(self, predictor: "PredictorBase"):
self.predictor = predictor
if self.predictor.params.pipeline.batch_size != 1:
logger.warning(f"Raw prediction via threading requires batch size == 1. Automatically setting.")
self.predictor.params.pipeline.batch_size = 1
if not self.predictor.params.silent:
logger.warning(f"Consider setting predictor to silent by setting predictor_params.silent = True.")

self.exit_stack = ExitStack()
self.in_queue = Queue(10)
self.out_queue = Queue(10)
self.thread = RawPredictorThread(self.in_queue, self.out_queue, predictor)

def __enter__(self):
self.thread.start()
return RawPredictorCaller(self.in_queue, self.out_queue)

def __exit__(self, exc_type, exc_val, exc_tb):
self.in_queue.put(StopSignal())
self.thread.join()
2 changes: 1 addition & 1 deletion tfaip/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
# You should have received a copy of the GNU General Public License along with
# tfaip. If not, see http://www.gnu.org/licenses/.
# ==============================================================================
__version__ = "1.2.1"
__version__ = "1.2.2"

0 comments on commit e6bd554

Please sign in to comment.