From 3314bddc4b317cff80722ebe1776eaab253140c6 Mon Sep 17 00:00:00 2001 From: "Jason (Joo Seong) Jeong" Date: Sat, 13 Jul 2019 16:52:43 +0000 Subject: [PATCH 1/5] implement shared tensor mechanism --- benchmark.py | 29 +++++++++---- client.py | 2 +- config/r2p1d-whole.json | 2 +- control.py | 93 +++++++++++++++++++++++++++++++++++++++++ models/r2p1d/model.py | 20 +++++++++ runner.py | 76 +++++++++++++++++++++++++++++---- runner_model.py | 5 +++ 7 files changed, 208 insertions(+), 19 deletions(-) diff --git a/benchmark.py b/benchmark.py index d6cf92d..860e10c 100644 --- a/benchmark.py +++ b/benchmark.py @@ -131,7 +131,7 @@ def sanity_check(args): # change these if you want to use different client/loader/runner impls from rnb_logging import logmeta, logroot - from control import TerminationFlag, BenchmarkQueues + from control import TerminationFlag, BenchmarkQueues, BenchmarkTensors from client import * from runner import runner @@ -152,17 +152,21 @@ def sanity_check(args): parser.add_argument('-p', '--per_gpu_queue', help='Whether to place intermediate queues on each GPU', action='store_true') + parser.add_argument('-t', '--tensors_per_process', + help='Number of shared output tensors per process', + type=positive_int, default=100) args = parser.parse_args() print('Args:', args) sanity_check(args) - job_id = '%s-mi%d-b%d-v%d-qs%d-p%d' % (dt.today().strftime('%y%m%d_%H%M%S'), - args.mean_interval_ms, - args.batch_size, - args.videos, - args.queue_size, - args.per_gpu_queue) + job_id = '%s-mi%d-b%d-v%d-qs%d-p%d-t%d' % (dt.today().strftime('%y%m%d_%H%M%S'), + args.mean_interval_ms, + args.batch_size, + args.videos, + args.queue_size, + args.per_gpu_queue, + args.tensors_per_process) # do a quick pass through the pipeline to count the total number of runners with open(args.config_file_path, 'r') as f: @@ -210,6 +214,9 @@ def sanity_check(args): process_client = Process(target=client_impl, args=client_args) + # create BenchmarkTensors object for managing shared tensors between steps + benchmark_tensors = BenchmarkTensors(pipeline, args.tensors_per_process) + process_runner_list = [] for step_idx, step in enumerate(pipeline): is_final_step = step_idx == len(pipeline) - 1 @@ -224,6 +231,9 @@ def sanity_check(args): prev_queue, next_queue = benchmark_queues.get_tensor_queue(step_idx, gpu) + shared_input_tensors, shared_output_tensors = \ + benchmark_tensors.get_shared_tensors(step_idx, instance_idx) + # check the replica index of this particular runner, for this gpu # if this runner is the first, then give it index 0 replica_idx = replica_dict.get(gpu, 0) @@ -236,11 +246,12 @@ def sanity_check(args): process_runner = Process(target=runner, args=(prev_queue, next_queue, print_summary, - job_id, gpu, replica_idx, + job_id, gpu, replica_idx, instance_idx, global_inference_counter, args.videos, termination_flag, step_idx, sta_bar, fin_bar, - model), + model, shared_input_tensors, + shared_output_tensors), kwargs=step) replica_dict[gpu] = replica_idx + 1 diff --git a/client.py b/client.py index 20e5c4b..1fc009c 100644 --- a/client.py +++ b/client.py @@ -38,7 +38,7 @@ def poisson_client(video_path_iterator, filename_queue, beta, termination_flag, print('[WARNING] Filename queue is full. Aborting...') termination_flag.value = TerminationFlag.FILENAME_QUEUE_FULL break - time.sleep(exponential(float(beta) / 1000)) # milliseconds --> seconds + time.sleep(float(beta) / 1000) # milliseconds --> seconds # mark the end of the input stream # the loaders should exit by themselves, but we enqueue these markers just in diff --git a/config/r2p1d-whole.json b/config/r2p1d-whole.json index f6861e9..7659d24 100644 --- a/config/r2p1d-whole.json +++ b/config/r2p1d-whole.json @@ -8,7 +8,7 @@ }, { "model": "models.r2p1d.model.R2P1DRunner", - "gpus": [0], + "gpus": [0,1,2,3,4,5,6,7], "start_index": 1, "end_index": 5 } diff --git a/control.py b/control.py index 40b0376..619d4ee 100644 --- a/control.py +++ b/control.py @@ -1,3 +1,9 @@ +import torch +from collections import namedtuple +from torch.multiprocessing import Event +from utils.class_utils import load_class + + class TerminationFlag: """An enum class for representing various termination states.""" UNSET = -1 @@ -78,3 +84,90 @@ def get_tensor_queue(self, step_idx, gpu_idx): def get_filename_queue(self): return self.filename_queue + + +class TensorEvent: + """Basically a tuple of a torch.Tensor and multiprocessing.Event. + + The Tensor can be used as a "shared tensor" for passing intermediate tensors + across processes. + + The Event should be used to signal that the consumer process has finished + reading from the Tensor. When writing values to the Tensor, the producer + process should first check if the Tensor is free, by calling event.wait(). If + the Tensor is indeed free, then event.wait() will return immediately. If not, + then event.wait() will block until the consumer process calls event.set(). + Thus, the consumer should make sure that it calls event.set() AFTER the + Tensor's contents have been copied to a safe area, such as the consumer's own + local tensor. + """ + def __init__(self, shape, device, dtype=torch.float32): + self.tensor = torch.empty(*shape, dtype=dtype, device=device) + self.event = Event() + self.event.set() + + +class BenchmarkTensors: + """Manages intermediate tensors that are passed across steps in the benchmark. + + Args: + pipeline: The whole pipeline info parsed from the input configuration file + num_tensors_per_process: The number of shared output tensors that are given + to each process, for writing tensor values. A big value allows + processes to produce many tensors before having to block, but requires + a lot of GPU memory. A small value saves memory, but results in early + blocking. + """ + def __init__(self, pipeline, num_tensors_per_process): + # self.tensors is a 3-level list of TensorEvents, e.g., + # [ + # None, (the first step does not need shared input tensors) + # [ (shared tensors between step 0 & 1) + # [tensorEvent000, tensorEvent001, ...] (outputs of process 0 in step 0) + # [tensorEvent010, tensorEvent011, ...] (outputs of process 1 in step 0) + # [tensorEvent020, tensorEvent021, ...] (outputs of process 2 in step 0) + # ], + + # [ (shared tensors between step 1 & 2) + # [tensorEvent100, tensorEvent101, ...] (outputs of process 0 in step 1) + # [tensorEvent110, tensorEvent111, ...] (outputs of process 1 in step 1) + # [tensorEvent120, tensorEvent121, ...] (outputs of process 2 in step 1) + # ], + # ..., + # [None, None, ...] (the last step does not need shared output tensors) + # ] + self.tensors = [None] + + # we exclude the last step since the last step does need to output tensors + for step in pipeline[:-1]: + step_output_tensors = [] + + # load the model class to check the output tensor shape of this step + model_module_path = step['model'] + model_class = load_class(model_module_path) + shape = model_class.output_shape() + + for gpu in step['gpus']: + device = torch.device('cuda:%d' % gpu) + tensors = [TensorEvent(shape, device) + for _ in range(num_tensors_per_process)] + step_output_tensors.append(tensors) + + self.tensors.append(step_output_tensors) + + # add Nones as output placeholders for the last tsep + self.tensors.append([None for _ in pipeline[-1]['gpus']]) + + def get_shared_tensors(self, step_idx, instance_idx): + """Returns the shared input tensors and output tensors for a given process. + + The shared input tensors are returned as a 2-level list, containing the + output tensors of all processes of the previous step. On the other hand, + the output tensors are returned as a 1-level list, since this process does + not need to access the output tensors of other processes from the same step. + """ + return self.tensors[step_idx], self.tensors[step_idx + 1][instance_idx] + + +# An integer tuple for accessing tensors from BenchmarkTensors. +Signal = namedtuple('Signal', ['instance_idx', 'tensor_idx']) diff --git a/models/r2p1d/model.py b/models/r2p1d/model.py index 780bf4a..f595d03 100644 --- a/models/r2p1d/model.py +++ b/models/r2p1d/model.py @@ -68,6 +68,12 @@ def __init__(self, device, start_index=1, end_index=5, num_classes=400, layer_si def input_shape(self): return self.input_dict[self.start_index] + @staticmethod + def output_shape(): + # TODO: the output shape may not be (10, 400), depending on self.end_index + # need to change return value accordingly + return (10, 400) + def __call__(self, input): return self.model(input) @@ -134,6 +140,13 @@ def __call__(self, input): def __del__(self): self.loader.close() + def input_shape(self): + return None + + @staticmethod + def output_shape(): + return (10, 3, 8, 112, 112) + class R2P1DSingleStep(RunnerModel): """RunnerModel impl that contains all inference logic regarding R(2+1)D. @@ -199,3 +212,10 @@ def __call__(self, input): def __del__(self): self.loader.close() + + def input_shape(self): + return None + + @staticmethod + def output_shape(): + return (10, 400) diff --git a/runner.py b/runner.py index 10448a2..7daca66 100644 --- a/runner.py +++ b/runner.py @@ -3,10 +3,11 @@ NUM_EXIT_MARKERS = 10 NUM_SUMMARY_SKIPS = 10 def runner(input_queue, output_queue, print_summary, - job_id, g_idx, r_idx, global_inference_counter, num_videos, + job_id, g_idx, r_idx, instance_idx, + global_inference_counter, num_videos, termination_flag, step_idx, sta_bar, fin_bar, - model_module_path, + model_module_path, shared_input_tensors, shared_output_tensors, **model_kwargs): # PyTorch seems to have an issue with sharing modules between # multiple processes, so we just do the imports here and @@ -16,7 +17,7 @@ def runner(input_queue, output_queue, print_summary, from queue import Empty, Full from tqdm import tqdm from rnb_logging import logname, TimeCardSummary - from control import TerminationFlag + from control import TerminationFlag, Signal from utils.class_utils import load_class # Use our own CUDA stream to avoid synchronizing with other processes @@ -43,6 +44,16 @@ def runner(input_queue, output_queue, print_summary, # collect incoming time measurements for later logging time_card_summary = TimeCardSummary() + # keep track of the next position to write output tensors + shared_output_tensor_counter = 0 + + # Create placeholder tensor to copy values from shared input tensors. + # In case the model does not provide any tensor shape, then we assume + # the expected input is not a tensor and do not make any placeholders. + shape = model.input_shape() + if shape is not None: + input_placeholder = torch.empty(*shape, dtype=torch.float32).cuda() + sta_bar.wait() if print_summary: @@ -54,15 +65,48 @@ def runner(input_queue, output_queue, print_summary, if tpl is None: break - tensor, time_card = tpl + signal_or_input, time_card = tpl time_card.record('runner%d_start' % step_idx) - if isinstance(tensor, torch.Tensor) and tensor.device != device: - tensor = tensor.to(device=device) + if isinstance(signal_or_input, Signal): + # we need to copy values from the designated shared input tensor + instance_idx, tensor_idx = signal_or_input + tensor_event = shared_input_tensors[instance_idx][tensor_idx] + + # Under normal circumstances, the event should not be set yet. + # However, this may not be true if the job is terminating, in which + # case we immediately exit. + if tensor_event.event.is_set() and \ + termination_flag.value != TerminationFlag.UNSET: + break + + # This is basically a device-to-device memcpy if the source tensor + # is coming from a different device. If not, then this op becomes + # a memcpy within the same device. + input_placeholder.copy_(tensor_event.tensor) + + # release the shared tensor to be reused later + tensor_event.event.set() + + else: + # this process does not use the shared tensor mechanism + # simply use the input as-is + input_placeholder = signal_or_input + time_card.record('inference%d_start' % step_idx) - outputs = model(tensor) + outputs = model(input_placeholder) stream.synchronize() + + if shared_output_tensors is not None: + # we need to copy the results into a shared output tensor + tensor_event = shared_output_tensors[shared_output_tensor_counter] + + # check to see if the tensor has been released or not + tensor_event.event.wait() + tensor_event.tensor.copy_(outputs) + tensor_event.event.clear() + time_card.record('inference%d_finish' % step_idx) @@ -91,7 +135,16 @@ def runner(input_queue, output_queue, print_summary, # this is NOT the final step # pass on the intermediate tensor to the next step try: - output_queue.put_nowait((outputs, time_card)) + if shared_output_tensors is not None: + # pass a Signal object for accessing shared tensors + signal = Signal(instance_idx, shared_output_tensor_counter) + output_queue.put_nowait((signal, time_card)) + shared_output_tensor_counter = \ + (shared_output_tensor_counter + 1) \ + % len(shared_output_tensors) + else: + # no need to pass any signals, just enqueue outputs directly + output_queue.put_nowait((outputs, time_card)) except Full: print('[WARNING] Queue between runner step %d and %d is full. ' 'Aborting...' % (step_idx, step_idx+1)) @@ -108,6 +161,13 @@ def runner(input_queue, output_queue, print_summary, except Full: pass + if shared_input_tensors is not None: + # release all shared input tensors in case any process from the + # previous step is waiting for a tensor to be released + for instance_tensors in shared_input_tensors: + for protected_tensor in instance_tensors: + protected_tensor.event.set() + fin_bar.wait() if output_queue is not None: output_queue.cancel_join_thread() diff --git a/runner_model.py b/runner_model.py index 5744ab4..9d2a0d9 100644 --- a/runner_model.py +++ b/runner_model.py @@ -17,6 +17,11 @@ def input_shape(self): """Returns the expected shape of the input tensor to this model.""" raise NotImplementedError + @staticmethod + def output_shape(): + """Returns the expected shape of the output tensor of this model.""" + raise NotImplementedError + def __call__(self, input): """Perform inference on this model with the given input. From 21a479db230597fbd0549e02718c3028bb140ddf Mon Sep 17 00:00:00 2001 From: "Jason (Joo Seong) Jeong" Date: Sat, 13 Jul 2019 17:39:29 +0000 Subject: [PATCH 2/5] undo local changes --- client.py | 2 +- config/r2p1d-whole.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client.py b/client.py index 1fc009c..20e5c4b 100644 --- a/client.py +++ b/client.py @@ -38,7 +38,7 @@ def poisson_client(video_path_iterator, filename_queue, beta, termination_flag, print('[WARNING] Filename queue is full. Aborting...') termination_flag.value = TerminationFlag.FILENAME_QUEUE_FULL break - time.sleep(float(beta) / 1000) # milliseconds --> seconds + time.sleep(exponential(float(beta) / 1000)) # milliseconds --> seconds # mark the end of the input stream # the loaders should exit by themselves, but we enqueue these markers just in diff --git a/config/r2p1d-whole.json b/config/r2p1d-whole.json index 7659d24..f6861e9 100644 --- a/config/r2p1d-whole.json +++ b/config/r2p1d-whole.json @@ -8,7 +8,7 @@ }, { "model": "models.r2p1d.model.R2P1DRunner", - "gpus": [0,1,2,3,4,5,6,7], + "gpus": [0], "start_index": 1, "end_index": 5 } From 109749bd99626c44036597c1a1e74570a48b6d3d Mon Sep 17 00:00:00 2001 From: "Jason (Joo Seong) Jeong" Date: Mon, 15 Jul 2019 12:00:53 +0000 Subject: [PATCH 3/5] address comments from #58 --- benchmark.py | 18 ++++++------- client.py | 4 +-- config/r2p1d-whole.json | 2 +- control.py | 56 +++++++++++++++++++++++------------------ models/r2p1d/model.py | 21 +++++++++------- runner.py | 47 +++++++++++++++++++++------------- runner_model.py | 39 ++++++++++++++++++++++++++-- 7 files changed, 122 insertions(+), 65 deletions(-) diff --git a/benchmark.py b/benchmark.py index 860e10c..b8eb7e8 100644 --- a/benchmark.py +++ b/benchmark.py @@ -131,7 +131,7 @@ def sanity_check(args): # change these if you want to use different client/loader/runner impls from rnb_logging import logmeta, logroot - from control import TerminationFlag, BenchmarkQueues, BenchmarkTensors + from control import TerminationFlag, SharedQueues, SharedTensors from client import * from runner import runner @@ -195,10 +195,10 @@ def sanity_check(args): # (mean_interval_ms = 0 is a special case where all videos are put in queues at once) queue_size = args.queue_size if args.mean_interval_ms > 0 else args.videos + num_runners + 1 - # create BenchmarkQueues object for managing queues between steps - benchmark_queues = BenchmarkQueues(Queue, queue_size, pipeline, - args.per_gpu_queue) - filename_queue = benchmark_queues.get_filename_queue() + # create SharedQueues object for managing queues between steps + shared_queues = SharedQueues(Queue, queue_size, pipeline, + args.per_gpu_queue) + filename_queue = shared_queues.get_filename_queue() video_path_iterator = config['video_path_iterator'] @@ -214,8 +214,8 @@ def sanity_check(args): process_client = Process(target=client_impl, args=client_args) - # create BenchmarkTensors object for managing shared tensors between steps - benchmark_tensors = BenchmarkTensors(pipeline, args.tensors_per_process) + # create SharedTensors object for managing shared tensors between steps + shared_tensors = SharedTensors(pipeline, args.tensors_per_process) process_runner_list = [] for step_idx, step in enumerate(pipeline): @@ -229,10 +229,10 @@ def sanity_check(args): for instance_idx, gpu in enumerate(gpus): is_first_instance = instance_idx == 0 - prev_queue, next_queue = benchmark_queues.get_tensor_queue(step_idx, gpu) + prev_queue, next_queue = shared_queues.get_tensor_queue(step_idx, gpu) shared_input_tensors, shared_output_tensors = \ - benchmark_tensors.get_shared_tensors(step_idx, instance_idx) + shared_tensors.get_tensors(step_idx, instance_idx) # check the replica index of this particular runner, for this gpu # if this runner is the first, then give it index 0 diff --git a/client.py b/client.py index 20e5c4b..67e848a 100644 --- a/client.py +++ b/client.py @@ -33,7 +33,7 @@ def poisson_client(video_path_iterator, filename_queue, beta, termination_flag, time_card.record('enqueue_filename') try: - filename_queue.put_nowait((video_path, time_card)) + filename_queue.put_nowait(((None, video_path), time_card)) except Full: print('[WARNING] Filename queue is full. Aborting...') termination_flag.value = TerminationFlag.FILENAME_QUEUE_FULL @@ -81,7 +81,7 @@ def bulk_client(video_path_iterator, filename_queue, num_videos, termination_fla time_card.record('enqueue_filename') try: - filename_queue.put_nowait((video_path, time_card)) + filename_queue.put_nowait(((None, video_path), time_card)) except Full: print('[WARNING] Filename queue is full. Aborting...') termination_flag.value = TerminationFlag.FILENAME_QUEUE_FULL diff --git a/config/r2p1d-whole.json b/config/r2p1d-whole.json index f6861e9..7659d24 100644 --- a/config/r2p1d-whole.json +++ b/config/r2p1d-whole.json @@ -8,7 +8,7 @@ }, { "model": "models.r2p1d.model.R2P1DRunner", - "gpus": [0], + "gpus": [0,1,2,3,4,5,6,7], "start_index": 1, "end_index": 5 } diff --git a/control.py b/control.py index 619d4ee..f06fe59 100644 --- a/control.py +++ b/control.py @@ -12,7 +12,7 @@ class TerminationFlag: FRAME_QUEUE_FULL = 2 -class BenchmarkQueues: +class SharedQueues: """Manages intermediate queues that connect steps in the benchmark. Args: @@ -87,27 +87,28 @@ def get_filename_queue(self): class TensorEvent: - """Basically a tuple of a torch.Tensor and multiprocessing.Event. + """Basically a tuple of several torch.Tensors and a multiprocessing.Event. - The Tensor can be used as a "shared tensor" for passing intermediate tensors + The Tensors can be used as "shared tensors" for passing intermediate tensors across processes. The Event should be used to signal that the consumer process has finished - reading from the Tensor. When writing values to the Tensor, the producer - process should first check if the Tensor is free, by calling event.wait(). If - the Tensor is indeed free, then event.wait() will return immediately. If not, + reading from the Tensors. When writing values to Tensors, the producer + process should first check if Tensors are free, by calling event.wait(). If + the Tensors are indeed free, then event.wait() will return at once. If not, then event.wait() will block until the consumer process calls event.set(). Thus, the consumer should make sure that it calls event.set() AFTER the - Tensor's contents have been copied to a safe area, such as the consumer's own - local tensor. + Tensors' contents have been copied to a safe area, such as the consumer's own + local tensors. """ - def __init__(self, shape, device, dtype=torch.float32): - self.tensor = torch.empty(*shape, dtype=dtype, device=device) + def __init__(self, shapes, device, dtype=torch.float32): + self.tensors = tuple(torch.empty(*shape, dtype=dtype, device=device) + for shape in shapes) self.event = Event() self.event.set() -class BenchmarkTensors: +class SharedTensors: """Manages intermediate tensors that are passed across steps in the benchmark. Args: @@ -116,7 +117,10 @@ class BenchmarkTensors: to each process, for writing tensor values. A big value allows processes to produce many tensors before having to block, but requires a lot of GPU memory. A small value saves memory, but results in early - blocking. + blocking. Note that if a step outputs several tensors during each + iteration, then this class allocates separate memory for each tensor, + but still treats them as one tensor when comparing the count with + num_tensors_per_process. """ def __init__(self, pipeline, num_tensors_per_process): # self.tensors is a 3-level list of TensorEvents, e.g., @@ -138,27 +142,31 @@ def __init__(self, pipeline, num_tensors_per_process): # ] self.tensors = [None] - # we exclude the last step since the last step does need to output tensors + # we exclude the last step since the last step does not need output tensors for step in pipeline[:-1]: - step_output_tensors = [] - # load the model class to check the output tensor shape of this step model_module_path = step['model'] model_class = load_class(model_module_path) - shape = model_class.output_shape() + shapes = model_class.output_shape() + + if shapes is None: + # this step does not need shared output tensors + step_output_tensors = [None for _ in step(['gpus'])] - for gpu in step['gpus']: - device = torch.device('cuda:%d' % gpu) - tensors = [TensorEvent(shape, device) - for _ in range(num_tensors_per_process)] - step_output_tensors.append(tensors) + else: + step_output_tensors = [] + for gpu in step['gpus']: + device = torch.device('cuda:%d' % gpu) + tensors = [TensorEvent(shapes, device) + for _ in range(num_tensors_per_process)] + step_output_tensors.append(tensors) self.tensors.append(step_output_tensors) - # add Nones as output placeholders for the last tsep + # add Nones as output placeholders for the last step self.tensors.append([None for _ in pipeline[-1]['gpus']]) - def get_shared_tensors(self, step_idx, instance_idx): + def get_tensors(self, step_idx, instance_idx): """Returns the shared input tensors and output tensors for a given process. The shared input tensors are returned as a 2-level list, containing the @@ -169,5 +177,5 @@ def get_shared_tensors(self, step_idx, instance_idx): return self.tensors[step_idx], self.tensors[step_idx + 1][instance_idx] -# An integer tuple for accessing tensors from BenchmarkTensors. +# An integer tuple for accessing tensors from SharedTensors. Signal = namedtuple('Signal', ['instance_idx', 'tensor_idx']) diff --git a/models/r2p1d/model.py b/models/r2p1d/model.py index f595d03..e2d1f05 100644 --- a/models/r2p1d/model.py +++ b/models/r2p1d/model.py @@ -66,16 +66,17 @@ def __init__(self, device, start_index=1, end_index=5, num_classes=400, layer_si stream.synchronize() def input_shape(self): - return self.input_dict[self.start_index] + return (self.input_dict[self.start_index],) @staticmethod def output_shape(): # TODO: the output shape may not be (10, 400), depending on self.end_index # need to change return value accordingly - return (10, 400) + return ((10, 400),) def __call__(self, input): - return self.model(input) + (tensor,), _ = input + return ((self.model(tensor),), None) class R2P1DVideoPathIterator(VideoPathIterator): def __init__(self): @@ -128,14 +129,15 @@ def __init__(self, device): self.loader.flush() def __call__(self, input): - self.loader.loadfile(input) + _, filename = input + self.loader.loadfile(filename) for frames in self.loader: pass self.loader.flush() frames = frames.float() frames = frames.permute(0, 2, 1, 3, 4) - return frames + return ((frames,), None) def __del__(self): self.loader.close() @@ -145,7 +147,7 @@ def input_shape(self): @staticmethod def output_shape(): - return (10, 3, 8, 112, 112) + return ((10, 3, 8, 112, 112),) class R2P1DSingleStep(RunnerModel): @@ -200,7 +202,8 @@ def __init__(self, device, num_classes=400, layer_sizes=[2,2,2,2], stream.synchronize() def __call__(self, input): - self.loader.loadfile(input) + _, filename = input + self.loader.loadfile(filename) for frames in self.loader: pass self.loader.flush() @@ -208,7 +211,7 @@ def __call__(self, input): frames = frames.float() frames = frames.permute(0, 2, 1, 3, 4) - return self.model(frames) + return ((self.model(frames),), None) def __del__(self): self.loader.close() @@ -218,4 +221,4 @@ def input_shape(self): @staticmethod def output_shape(): - return (10, 400) + return ((10, 400),) diff --git a/runner.py b/runner.py index 7daca66..242828c 100644 --- a/runner.py +++ b/runner.py @@ -48,11 +48,13 @@ def runner(input_queue, output_queue, print_summary, shared_output_tensor_counter = 0 # Create placeholder tensor to copy values from shared input tensors. - # In case the model does not provide any tensor shape, then we assume - # the expected input is not a tensor and do not make any placeholders. - shape = model.input_shape() - if shape is not None: - input_placeholder = torch.empty(*shape, dtype=torch.float32).cuda() + # In case the model does not provide any tensor shape, then we do not + # make any placeholders. + shapes = model.input_shape() + if shapes is not None: + tensor_input_placeholder = \ + tuple(torch.empty(*shape, dtype=torch.float32).cuda() + for shape in shapes) sta_bar.wait() @@ -65,12 +67,13 @@ def runner(input_queue, output_queue, print_summary, if tpl is None: break - signal_or_input, time_card = tpl + (signal, non_tensor_inputs), time_card = tpl + time_card.record('runner%d_start' % step_idx) - if isinstance(signal_or_input, Signal): + if signal is not None: # we need to copy values from the designated shared input tensor - instance_idx, tensor_idx = signal_or_input + instance_idx, tensor_idx = signal tensor_event = shared_input_tensors[instance_idx][tensor_idx] # Under normal circumstances, the event should not be set yet. @@ -80,22 +83,24 @@ def runner(input_queue, output_queue, print_summary, termination_flag.value != TerminationFlag.UNSET: break - # This is basically a device-to-device memcpy if the source tensor - # is coming from a different device. If not, then this op becomes + # This is basically a device-to-device memcpy if the source tensors + # are coming from a different device. If not, then this op becomes # a memcpy within the same device. - input_placeholder.copy_(tensor_event.tensor) + for placeholder, shared_tensor in zip(tensor_input_placeholder, + tensor_event.tensors): + placeholder.copy_(shared_tensor) # release the shared tensor to be reused later tensor_event.event.set() else: # this process does not use the shared tensor mechanism - # simply use the input as-is - input_placeholder = signal_or_input + tensor_input_placeholder = None time_card.record('inference%d_start' % step_idx) - outputs = model(input_placeholder) + tensor_outputs, non_tensor_outputs = \ + model((tensor_input_placeholder, non_tensor_inputs)) stream.synchronize() if shared_output_tensors is not None: @@ -103,8 +108,13 @@ def runner(input_queue, output_queue, print_summary, tensor_event = shared_output_tensors[shared_output_tensor_counter] # check to see if the tensor has been released or not + # TODO #59: if this tensor is not ready, then check another one tensor_event.event.wait() - tensor_event.tensor.copy_(outputs) + + for tensor_output, shared_tensor in zip(tensor_outputs, + tensor_event.tensors): + shared_tensor.copy_(tensor_output) + tensor_event.event.clear() time_card.record('inference%d_finish' % step_idx) @@ -138,13 +148,14 @@ def runner(input_queue, output_queue, print_summary, if shared_output_tensors is not None: # pass a Signal object for accessing shared tensors signal = Signal(instance_idx, shared_output_tensor_counter) - output_queue.put_nowait((signal, time_card)) shared_output_tensor_counter = \ (shared_output_tensor_counter + 1) \ % len(shared_output_tensors) else: - # no need to pass any signals, just enqueue outputs directly - output_queue.put_nowait((outputs, time_card)) + # no need to pass any signals, just enqueue empty signal + signal = None + output_queue.put_nowait(((signal, non_tensor_outputs), time_card)) + except Full: print('[WARNING] Queue between runner step %d and %d is full. ' 'Aborting...' % (step_idx, step_idx+1)) diff --git a/runner_model.py b/runner_model.py index 9d2a0d9..c7da8aa 100644 --- a/runner_model.py +++ b/runner_model.py @@ -14,17 +14,52 @@ def __init__(self, device): pass def input_shape(self): - """Returns the expected shape of the input tensor to this model.""" + """Returns the expected shapes of the input tensors to this model. + + The return value should be a nested tuple, containing a shape tuple for each + expected input tensor. Note that this applies even if the model expects only + one tensor; you can create a single-item tuple by doing `(shape,)`. + + If the model does not receive any tensors, then return None. You can still + receive any non-tensor objects the previous step passes, in __call__(). + Keep in mind that returning None and returning an empty tuple (`()`) are + completely different. Copy-paste the previous step's output shape to be + safe. See output_shape() for more details. + """ raise NotImplementedError @staticmethod def output_shape(): - """Returns the expected shape of the output tensor of this model.""" + """Returns the expected shape of the output tensors of this model. + + The return value should be a nested tuple, containing a shape tuple for each + expected output tensor. Note that this applies even if the model outputs + only one tensor; you can create a single-item tuple by doing `(shape,)`. + + If the model does not output any tensors, then return None. You are still + allowed to output any non-tensor objects, in __call__(). + Keep in mind that returning None and returning an empty tuple (`()`) are + completely different. For the former, the benchmark does not even bother + creating any synchronization (multiprocessing.Event) objects for sharing + tensors, but for the latter, the benchmark does create them. + """ raise NotImplementedError def __call__(self, input): """Perform inference on this model with the given input. We purposely follow PyTorch's convention of using __call__ for inference. + The input parameter is a pair of tensor tuples and non-tensor tuples, e.g., + ((tensor1, tensor2, tensor3), (integer, string)). In case the previous step + does not provide any tensor outputs, the tensor tuple is set to None. This + is also true for non-tensor objects. + + Note that even if there is only one tensor input, the tensor tuple is still + a tuple and not a standalone tensor object. In that case, one way you can + extract the single tensor from `input` is `(tensor,), _ = input`. + + This tuple format is the same for the output. For both tensor outputs and + non-tensor outputs, make sure to return None if there is no output, and to + return a tuple if there is at least one output. """ raise NotImplementedError From 0b932e51982a737197fb69af8781478ae3ca13b4 Mon Sep 17 00:00:00 2001 From: "Jason (Joo Seong) Jeong" Date: Mon, 15 Jul 2019 12:03:26 +0000 Subject: [PATCH 4/5] undo local changes --- config/r2p1d-whole.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/r2p1d-whole.json b/config/r2p1d-whole.json index 7659d24..f6861e9 100644 --- a/config/r2p1d-whole.json +++ b/config/r2p1d-whole.json @@ -8,7 +8,7 @@ }, { "model": "models.r2p1d.model.R2P1DRunner", - "gpus": [0,1,2,3,4,5,6,7], + "gpus": [0], "start_index": 1, "end_index": 5 } From c600b2b3d3c94908a8d22c0a562cff9a866cf0e1 Mon Sep 17 00:00:00 2001 From: "Jason (Joo Seong) Jeong" Date: Mon, 15 Jul 2019 12:12:53 +0000 Subject: [PATCH 5/5] minor commment fixes --- runner_model.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/runner_model.py b/runner_model.py index c7da8aa..dfc7f81 100644 --- a/runner_model.py +++ b/runner_model.py @@ -49,17 +49,21 @@ def __call__(self, input): """Perform inference on this model with the given input. We purposely follow PyTorch's convention of using __call__ for inference. - The input parameter is a pair of tensor tuples and non-tensor tuples, e.g., - ((tensor1, tensor2, tensor3), (integer, string)). In case the previous step - does not provide any tensor outputs, the tensor tuple is set to None. This - is also true for non-tensor objects. + The input parameter is a pair of tensor tuples and a non-tensor object + (which could also be a tuple, but does not necessaily have to be), e.g., + ((tensor1, tensor2, tensor3), string). In case the previous step does not + provide any tensor outputs, the tensor tuple is set to None. This is also + true for the non-tensor object. Note that even if there is only one tensor input, the tensor tuple is still a tuple and not a standalone tensor object. In that case, one way you can - extract the single tensor from `input` is `(tensor,), _ = input`. + extract the single tensor from `input` is `(tensor,), _ = input`. This is + NOT true for the non-tensor object; the non-tensor output from the previous + step can be literally anything. - This tuple format is the same for the output. For both tensor outputs and - non-tensor outputs, make sure to return None if there is no output, and to - return a tuple if there is at least one output. + This tuple format is the same for the output. For the tensor outputs, make + sure to return None if there is no output, and to return a tuple if there + is at least one output. Also don't forget to return None for the non-tensor + object if you don't have any non-tensor output. """ raise NotImplementedError