This repository has been archived by the owner on Apr 4, 2023. It is now read-only.

Commit dd771eb: Automatic merge of updates
ChWick committed Jan 18, 2021 (1 parent: d22de6d)
Showing 30 changed files with 268 additions and 156 deletions.
2 changes: 1 addition & 1 deletion test/base/trainer/test_accum_optimizer.py
@@ -21,7 +21,7 @@
 
 from test.util.store_logs_callback import StoreLogsCallback
 from tfaip.base import TrainerParams
-from tfaip.scenario.tutorial.full.data import DataParams
+from tfaip.scenario.tutorial.full.data.data_params import DataParams
 from tfaip.scenario.tutorial.full.scenario import TutorialScenario
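Note: this and the next three test files swap the package-level import for the defining module, apparently because tfaip.scenario.tutorial.full.data is now a package that no longer re-exports DataParams.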
2 changes: 1 addition & 1 deletion test/base/trainer/test_early_stopping.py
@@ -20,7 +20,7 @@
 from tensorflow.python.keras.backend import clear_session
 
 from tfaip.base import TrainerParams
-from tfaip.scenario.tutorial.full.data import DataParams
+from tfaip.scenario.tutorial.full.data.data_params import DataParams
 from tfaip.scenario.tutorial.full.scenario import TutorialScenario
 import logging
 logging.basicConfig(level=logging.DEBUG)
2 changes: 1 addition & 1 deletion test/base/trainer/test_ema.py
@@ -21,7 +21,7 @@
 
 from test.util.store_logs_callback import StoreLogsCallback
 from tfaip.base import TrainerParams
-from tfaip.scenario.tutorial.full.data import DataParams
+from tfaip.scenario.tutorial.full.data.data_params import DataParams
 from tfaip.scenario.tutorial.full.scenario import TutorialScenario
2 changes: 1 addition & 1 deletion test/base/trainer/test_multiple_lav_lists.py
@@ -22,7 +22,7 @@
 
 from test.util.store_logs_callback import StoreLogsCallback
 from tfaip.base import TrainerParams
-from tfaip.scenario.tutorial.full.data import DataParams
+from tfaip.scenario.tutorial.full.data.data_params import DataParams
 from tfaip.scenario.tutorial.full.scenario import TutorialScenario
5 changes: 3 additions & 2 deletions test/scenario/tutorial/test_tutorial_full.py
@@ -21,7 +21,8 @@
 
 from test.util.training import resume_training, single_train_iter, lav_test_case, warmstart_training_test_case
 from tfaip.base.data.databaseparams import DataGeneratorParams
-from tfaip.scenario.tutorial.full.data import DataParams, Data
+from tfaip.scenario.tutorial.full.data.data import Data
+from tfaip.scenario.tutorial.full.data.data_params import DataParams
 from tfaip.scenario.tutorial.full.scenario import TutorialScenario
@@ -66,7 +67,7 @@ def check(data):
 
 
 class TestTutorialTrain(unittest.TestCase):
-    def setUp(self) -> None:
+    def tearDown(self) -> None:
         clear_session()
 
     def test_single_train_iter(self):
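Moving clear_session() from setUp to tearDown releases each test's Keras graph right after the test finishes, instead of just before the next one starts, and it runs even when the test fails. A minimal sketch of the unittest lifecycle this relies on (standard library only, no TensorFlow):

import unittest

class LifecycleDemo(unittest.TestCase):
    events = []

    def setUp(self):
        self.events.append('setUp')      # runs before every test method

    def tearDown(self):
        self.events.append('tearDown')   # runs after every test method, even if it failed

    def test_something(self):
        self.events.append('test')

# Per-test order is always: setUp -> test -> tearDown
if __name__ == '__main__':
    unittest.main()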
1 change: 0 additions & 1 deletion test/scripts/test_lav.py
@@ -46,7 +46,6 @@ def test_multi_lav_tutorial(self):
             '--data_params', 'train.batch_size=2',
         ])
         check_call(['tfaip-multi-lav',
-                    '--scenario', 'tutorial.full',
                     '--export_dirs', os.path.join(d, 'best'), os.path.join(d, 'best'),
                     '--data', 'limit=10',
         ])
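The --scenario argument is dropped from the tfaip-multi-lav invocation here, presumably because the scenario can now be inferred from the exported model directories passed via --export_dirs.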
14 changes: 10 additions & 4 deletions test/util/training.py
@@ -17,6 +17,7 @@
 # ==============================================================================
 import json
 import os
+import sys
 import tempfile
 import time
 import unittest
@@ -31,7 +32,11 @@
 from tfaip.util.random import set_global_random_seed
 
 
-def warmstart_training_test_case(test: unittest.TestCase, scenario, scenario_params: ScenarioBaseParams, debug=True,
+debug_test = sys.flags.debug
+
+
+def warmstart_training_test_case(test: unittest.TestCase, scenario, scenario_params: ScenarioBaseParams,
+                                 debug=debug_test,
                                  delta=None):
     # First train a normal iteration and store the results of metrics and losses with a fixed seed
     # Then reload the model as warmstart, train an epoch but with a learning rate of 0
@@ -77,7 +82,7 @@ def warmstart_training_test_case(test: unittest.TestCase, scenario, scenario_par
         test.assertAlmostEqual(v, initial_logs[k], delta=delta)
 
 
-def single_train_iter(test: unittest.TestCase, scenario, scenario_params: ScenarioBaseParams, debug=True):
+def single_train_iter(test: unittest.TestCase, scenario, scenario_params: ScenarioBaseParams, debug=debug_test):
     scenario_params.debug_graph_construction = debug
     scenario_params.debug_graph_n_examples = 1
     trainer_params = TrainerParams(
@@ -95,7 +100,7 @@ def single_train_iter(test: unittest.TestCase, scenario, scenario_params: Scenar
     trainer.train()
 
 
-def lav_test_case(test: unittest.TestCase, scenario: Type[ScenarioBase], scenario_params, debug=True,
+def lav_test_case(test: unittest.TestCase, scenario: Type[ScenarioBase], scenario_params, debug=False,
                   delta=None):
     with tempfile.TemporaryDirectory() as tmp_dir:
         trainer_params = TrainerParams(
@@ -147,7 +152,7 @@ def lav_test_case(test: unittest.TestCase, scenario: Type[ScenarioBase], scenari
         test.assertAlmostEqual(bs1_results[k], bs5_results[k], delta=delta, msg=f"on key {k}")
 
 
-def resume_training(test: unittest.TestCase, scenario, scenario_params, delta=None):
+def resume_training(test: unittest.TestCase, scenario, scenario_params, delta=None, debug=debug_test):
     # simulate by setting epochs to 1, then loading the trainer_params and setting epochs to 2
     with tempfile.TemporaryDirectory() as tmp_dir:
         store_logs_callback = StoreLogsCallback()
@@ -156,6 +161,7 @@ def resume_training(test: unittest.TestCase, scenario, scenario_params, delta=No
             epochs=1,
             samples_per_epoch=scenario_params.data_params.train.batch_size,
             skip_model_load_test=True,  # not required in this test
+            force_eager=debug,
             export_final=False,
             export_best=False,
             scenario_params=scenario_params,
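Across this file, debug_test now follows the interpreter's debug flag instead of a hard-coded True, so heavyweight debugging (graph construction checks, eager execution) only kicks in when explicitly requested. A minimal sketch of the mechanism (standard library only):

import sys

# sys.flags.debug is 1 when the interpreter was started with `python -d`,
# and 0 otherwise, so e.g. `python -d -m unittest ...` enables the debug
# defaults in these test helpers while a normal run leaves them off.
debug_test = bool(sys.flags.debug)
print('debug mode:', debug_test)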
2 changes: 1 addition & 1 deletion tfaip/__init__.py
@@ -15,4 +15,4 @@
 # You should have received a copy of the GNU General Public License along with
 # tfaip. If not, see http://www.gnu.org/licenses/.
 # ==============================================================================
-__version__ = "1.0.0"
+__version__ = "1.0.1"
2 changes: 1 addition & 1 deletion tfaip/base/data/databaseparams.py
@@ -22,7 +22,7 @@
 
 from dataclasses_json import dataclass_json
 
-from tfaip.base.data.pipeline.datapipeline import SamplePipelineParams
+from tfaip.base.data.pipeline.sample.params import SamplePipelineParams
 from tfaip.util.argumentparser import dc_meta
 
 logger = logging.getLogger(__name__)
35 changes: 35 additions & 0 deletions tfaip/base/data/pipeline/datagenerator.py
@@ -0,0 +1,35 @@
+from abc import ABC, abstractmethod
+from random import shuffle
+from typing import Iterable, List
+
+from tfaip.base import DataGeneratorParams
+from tfaip.base.data.pipeline.definitions import PipelineMode, Sample
+
+
+class DataGenerator(ABC):
+    def __init__(self, mode: PipelineMode, params: 'DataGeneratorParams'):
+        params.validate()
+        self.mode = mode
+        self.params = params
+
+    @abstractmethod
+    def __len__(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def generate(self) -> Iterable[Sample]:
+        raise NotImplementedError
+
+
+class RawDataGenerator(DataGenerator):
+    def __init__(self, raw_data: List[Sample], mode: PipelineMode, params: 'DataGeneratorParams'):
+        super(RawDataGenerator, self).__init__(mode, params)
+        self.raw_data = raw_data
+
+    def __len__(self):
+        return len(self.raw_data)
+
+    def generate(self) -> Iterable[Sample]:
+        if self.mode == PipelineMode.Training:
+            shuffle(self.raw_data)
+        return self.raw_data
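RawDataGenerator wraps an in-memory list of Samples and reshuffles it on every generate() call in training mode. A hedged usage sketch (the Sample constructor fields and the bare DataGeneratorParams() default are assumptions, not taken from this diff):

from tfaip.base import DataGeneratorParams
from tfaip.base.data.pipeline.datagenerator import RawDataGenerator
from tfaip.base.data.pipeline.definitions import PipelineMode, Sample

# Hypothetical in-memory samples; real scenarios use their own input/target structures.
samples = [Sample(inputs={'x': i}, targets={'y': 2 * i}) for i in range(4)]

gen = RawDataGenerator(samples, PipelineMode.Training, DataGeneratorParams())
assert len(gen) == 4
for sample in gen.generate():  # order is reshuffled on each call in Training mode
    print(sample.inputs, sample.targets)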
93 changes: 8 additions & 85 deletions tfaip/base/data/pipeline/datapipeline.py
@@ -18,18 +18,15 @@
 import copy
 import gc
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, field
 from functools import partial
-from random import shuffle
 from typing import TYPE_CHECKING, List, Iterable, Optional, Callable, Type
 import logging
 
-from dataclasses_json import dataclass_json
-
-from tfaip.base.data.pipeline.dataprocessor import DataProcessorFactory, SequenceProcessor, DataProcessor
-from tfaip.base.data.pipeline.definitions import Sample, PipelineMode, DataProcessorFactoryParams, \
-    GENERAL_PROCESSOR
-from tfaip.base.data.pipeline.parallelpipeline import ParallelDataProcessorPipeline
+from tfaip.base.data.pipeline.datagenerator import DataGenerator, RawDataGenerator
+from tfaip.base.data.pipeline.dataprocessor import SequenceProcessor, DataProcessor
+from tfaip.base.data.pipeline.definitions import Sample, PipelineMode
+from tfaip.base.data.pipeline.sample.params import SamplePipelineParams
+from tfaip.base.data.pipeline.sample.processorpipeline import SampleProcessorPipeline, ParallelSampleProcessingPipeline
 from tfaip.base.data.pipeline.tfdatasetgenerator import TFDatasetGenerator
 from tfaip.util.multiprocessing.join import JoinableHolder
@@ -41,78 +38,6 @@
 logger = logging.getLogger(__name__)
 
 
-class SampleConsumer:
-    pass
-
-
-def create_processor_fn(factory: DataProcessorFactory, processors: List[DataProcessorFactoryParams], params, mode: PipelineMode) -> SequenceProcessor:
-    return factory.create_sequence(processors, params, mode)
-
-
-@dataclass_json
-@dataclass
-class SamplePipelineParams:
-    run_parallel: bool = True
-    sample_processors: List[DataProcessorFactoryParams] = field(default_factory=list)
-
-
-class SampleProcessorPipeline:
-    def __init__(self, data_pipeline: 'DataPipeline', processor_fn: Optional[Callable[[], SequenceProcessor]] = None):
-        self.data_pipeline = data_pipeline
-        self.create_processor_fn = processor_fn
-
-    def apply(self, samples: Iterable[Sample]) -> Iterable[Sample]:
-        if not self.create_processor_fn:
-            for sample in samples:
-                yield sample
-        else:
-            processor = self.create_processor_fn()
-            for sample in samples:
-                r = processor.apply_on_sample(sample)
-                if r is not None:
-                    yield r
-
-
-class ParallelSampleProcessingPipeline(SampleProcessorPipeline):
-    def apply(self, samples: Iterable[Sample]) -> Iterable[Sample]:
-        parallel_pipeline = ParallelDataProcessorPipeline(self.data_pipeline, samples,
-                                                          create_processor_fn=self.create_processor_fn,
-                                                          auto_repeat_input=False)
-        for x in parallel_pipeline.output_generator():
-            yield x
-
-        parallel_pipeline.join()
-
-
-class DataGenerator(ABC):
-    def __init__(self, mode: PipelineMode, params: 'DataGeneratorParams'):
-        params.validate()
-        self.mode = mode
-        self.params = params
-
-    @abstractmethod
-    def __len__(self):
-        raise NotImplementedError
-
-    @abstractmethod
-    def generate(self) -> Iterable[Sample]:
-        raise NotImplementedError
-
-
-class RawDataGenerator(DataGenerator):
-    def __init__(self, raw_data: List[Sample], mode: PipelineMode, params: 'DataGeneratorParams'):
-        super(RawDataGenerator, self).__init__(mode, params)
-        self.raw_data = raw_data
-
-    def __len__(self):
-        return len(self.raw_data)
-
-    def generate(self) -> Iterable[Sample]:
-        if self.mode == PipelineMode.Training:
-            shuffle(self.raw_data)
-        return self.raw_data
-
-
 def _create_sequence_processor_fn(factory, *args) -> Callable[[], SequenceProcessor]:
     return factory.create_sequence(*args)
@@ -161,7 +86,9 @@ def auto_batch(self):
     def create_data_generator(self) -> DataGenerator:
         raise NotImplementedError
 
-    def flat_input_processors(self, preload=False, non_preloadable_params=[]) -> List[DataProcessor]:
+    def flat_input_processors(self, preload=False, non_preloadable_params=None) -> List[DataProcessor]:
+        if non_preloadable_params is None:
+            non_preloadable_params = []
         factory = self.data.__class__.data_processor_factory()
         params: SamplePipelineParams = self._input_processors
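The signature change above fixes Python's mutable-default-argument pitfall: a literal [] default is created once at function definition time and then shared by every call that omits the argument, so values appended for one caller leak into the next. A self-contained illustration:

def broken(acc=[]):          # one list object, created at def time, shared by all calls
    acc.append(1)
    return acc

def fixed(acc=None):         # fresh list per call unless the caller provides one
    if acc is None:
        acc = []
    acc.append(1)
    return acc

assert broken() == [1]
assert broken() == [1, 1]    # state leaked from the previous call
assert fixed() == [1]
assert fixed() == [1]        # independent calls stay independent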

@@ -208,9 +135,6 @@ def create_output_pipeline(self) -> Optional[SampleProcessorPipeline]:
             return SampleProcessorPipeline(self, self._sequence_processor_fn(params))
         return SampleProcessorPipeline(self)
 
-    def create_data_consumer(self) -> SampleConsumer:
-        return SampleConsumer()
-
     def __enter__(self):
         from tfaip.base.data.pipeline.runningdatapipeline import RunningDataPipeline
         return RunningDataPipeline(self)
@@ -254,7 +178,6 @@ def __init__(self,
         super(RawDataPipeline, self).__init__(mode, data_base, generator_params, input_processors, output_processors)
         self.samples = samples
 
-
     def to_mode(self, mode: PipelineMode) -> 'DataPipeline':
         return self.__class__(self.samples, mode, self.data, self.generator_params, self._input_processors, self._output_processors)
Empty file (presumably the new tfaip/base/data/pipeline/sample/__init__.py).
13 changes: 13 additions & 0 deletions tfaip/base/data/pipeline/sample/params.py
@@ -0,0 +1,13 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from dataclasses_json import dataclass_json
+
+from tfaip.base.data.pipeline.definitions import DataProcessorFactoryParams
+
+
+@dataclass_json
+@dataclass
+class SamplePipelineParams:
+    run_parallel: bool = True
+    sample_processors: List[DataProcessorFactoryParams] = field(default_factory=list)
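With SamplePipelineParams moved into its own module, downstream code such as databaseparams.py (above) imports it from the new path, and dataclasses_json keeps it serializable. A small sketch (default values taken from the class definition above; the round-trip relies on the to_json/from_json methods that @dataclass_json generates):

from tfaip.base.data.pipeline.sample.params import SamplePipelineParams

params = SamplePipelineParams()                      # run_parallel=True, no processors
restored = SamplePipelineParams.from_json(params.to_json())
assert restored == params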
40 changes: 40 additions & 0 deletions tfaip/base/data/pipeline/sample/processorpipeline.py
@@ -0,0 +1,40 @@
+from typing import List, Optional, Callable, Iterable, TYPE_CHECKING
+
+from tfaip.base.data.pipeline.dataprocessor import DataProcessorFactory, SequenceProcessor
+from tfaip.base.data.pipeline.definitions import DataProcessorFactoryParams, Sample, PipelineMode
+from tfaip.base.data.pipeline.parallelpipeline import ParallelDataProcessorPipeline
+
+if TYPE_CHECKING:
+    from tfaip.base.data.pipeline.datapipeline import DataPipeline
+
+
+def create_processor_fn(factory: DataProcessorFactory, processors: List[DataProcessorFactoryParams], params, mode: PipelineMode) -> SequenceProcessor:
+    return factory.create_sequence(processors, params, mode)
+
+
+class SampleProcessorPipeline:
+    def __init__(self, data_pipeline: 'DataPipeline', processor_fn: Optional[Callable[[], SequenceProcessor]] = None):
+        self.data_pipeline = data_pipeline
+        self.create_processor_fn = processor_fn
+
+    def apply(self, samples: Iterable[Sample]) -> Iterable[Sample]:
+        if not self.create_processor_fn:
+            for sample in samples:
+                yield sample
+        else:
+            processor = self.create_processor_fn()
+            for sample in samples:
+                r = processor.apply_on_sample(sample)
+                if r is not None:
+                    yield r
+
+
+class ParallelSampleProcessingPipeline(SampleProcessorPipeline):
+    def apply(self, samples: Iterable[Sample]) -> Iterable[Sample]:
+        parallel_pipeline = ParallelDataProcessorPipeline(self.data_pipeline, samples,
+                                                          create_processor_fn=self.create_processor_fn,
+                                                          auto_repeat_input=False)
+        for x in parallel_pipeline.output_generator():
+            yield x
+
+        parallel_pipeline.join()
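SampleProcessorPipeline.apply streams samples lazily and silently drops any sample the processor maps to None, which is how processors filter out invalid samples. A toy illustration of that contract (DropOdd is a hypothetical stand-in, not a tfaip class):

from typing import Iterable, Optional

class DropOdd:
    # Mimics SequenceProcessor.apply_on_sample: return the (possibly
    # transformed) sample, or None to remove it from the stream.
    def apply_on_sample(self, sample: int) -> Optional[int]:
        return sample if sample % 2 == 0 else None

def apply(samples: Iterable[int], processor=DropOdd()) -> Iterable[int]:
    for s in samples:
        r = processor.apply_on_sample(s)
        if r is not None:
            yield r

assert list(apply(range(6))) == [0, 2, 4]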
12 changes: 7 additions & 5 deletions tfaip/base/lav/multilav.py
@@ -43,6 +43,7 @@ def __init__(self,
                  data_gen_params: DataGeneratorParams,
                  predictor_fn: Callable[[List[str], PredictorParams], MultiModelPredictor],
                  evaluator: Evaluator,
+                 predictor_params: PredictorParams
                  ):
         assert params.model_path
         self._params = params
@@ -51,6 +52,10 @@
         self._evaluator = evaluator
         self.device_config = DeviceConfig(self._params.device_params)
         self.benchmark_results = PredictorBenchmarkResults()
+        self.predictor_params = predictor_params
+        predictor_params.silent = True
+        predictor_params.progress_bar = True
+        predictor_params.include_targets = True
 
     @distribute_strategy
     def run(self,
@@ -60,11 +65,8 @@ def run(self,
         if callbacks is None:
             callbacks = []
 
-        predictor_params = PredictorParams(self._params.device_params,
-                                           silent=True, progress_bar=True, run_eagerly=run_eagerly,
-                                           include_targets=True,
-                                           )
-        predictor = self._predictor_fn(self._params.model_path, predictor_params)
+        self.predictor_params.run_eagerly = run_eagerly
+        predictor = self._predictor_fn(self._params.model_path, self.predictor_params)
         lav_pipeline = predictor.data.get_pipeline(PipelineMode.Evaluation, self._data_gen_params)
 
         for cb in callbacks:
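The multilav change is a dependency-injection cleanup: PredictorParams is now supplied by the caller, __init__ pins the fields MultiLAV always needs, and run() only toggles the per-run run_eagerly flag. A self-contained sketch of the pattern (the names mirror the diff, but the classes below are illustrative stubs, not tfaip's API):

from dataclasses import dataclass

@dataclass
class PredictorParams:
    silent: bool = False
    progress_bar: bool = False
    include_targets: bool = False
    run_eagerly: bool = True

class MultiLAV:
    def __init__(self, predictor_params: PredictorParams):
        # Invariant settings are applied once, at construction time.
        self.predictor_params = predictor_params
        predictor_params.silent = True
        predictor_params.progress_bar = True
        predictor_params.include_targets = True

    def run(self, run_eagerly: bool = False):
        # Only the per-run flag is decided here; the rest was injected.
        self.predictor_params.run_eagerly = run_eagerly
        return self.predictor_params

params = MultiLAV(PredictorParams()).run()
assert params.silent and params.include_targets and not params.run_eagerly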