Shared tensor mechanism for avoiding tensor serialization/ipc costs #58

Merged: 5 commits, Jul 16, 2019
Changes from 2 commits
29 changes: 20 additions & 9 deletions benchmark.py
@@ -131,7 +131,7 @@ def sanity_check(args):

# change these if you want to use different client/loader/runner impls
from rnb_logging import logmeta, logroot
from control import TerminationFlag, BenchmarkQueues
from control import TerminationFlag, BenchmarkQueues, BenchmarkTensors
from client import *
from runner import runner

@@ -152,17 +152,21 @@ def sanity_check(args):
parser.add_argument('-p', '--per_gpu_queue',
help='Whether to place intermediate queues on each GPU',
action='store_true')
parser.add_argument('-t', '--tensors_per_process',
help='Number of shared output tensors per process',
type=positive_int, default=100)
args = parser.parse_args()
print('Args:', args)

sanity_check(args)

job_id = '%s-mi%d-b%d-v%d-qs%d-p%d' % (dt.today().strftime('%y%m%d_%H%M%S'),
args.mean_interval_ms,
args.batch_size,
args.videos,
args.queue_size,
args.per_gpu_queue)
job_id = '%s-mi%d-b%d-v%d-qs%d-p%d-t%d' % (dt.today().strftime('%y%m%d_%H%M%S'),
args.mean_interval_ms,
args.batch_size,
args.videos,
args.queue_size,
args.per_gpu_queue,
args.tensors_per_process)

# do a quick pass through the pipeline to count the total number of runners
with open(args.config_file_path, 'r') as f:
@@ -210,6 +214,9 @@ def sanity_check(args):
process_client = Process(target=client_impl,
args=client_args)

# create BenchmarkTensors object for managing shared tensors between steps
benchmark_tensors = BenchmarkTensors(pipeline, args.tensors_per_process)

process_runner_list = []
for step_idx, step in enumerate(pipeline):
is_final_step = step_idx == len(pipeline) - 1
@@ -224,6 +231,9 @@

prev_queue, next_queue = benchmark_queues.get_tensor_queue(step_idx, gpu)

shared_input_tensors, shared_output_tensors = \
benchmark_tensors.get_shared_tensors(step_idx, instance_idx)

# check the replica index of this particular runner, for this gpu
# if this runner is the first, then give it index 0
replica_idx = replica_dict.get(gpu, 0)
@@ -236,11 +246,12 @@
process_runner = Process(target=runner,
args=(prev_queue, next_queue,
print_summary,
job_id, gpu, replica_idx,
job_id, gpu, replica_idx, instance_idx,
global_inference_counter, args.videos,
termination_flag, step_idx,
sta_bar, fin_bar,
model),
model, shared_input_tensors,
shared_output_tensors),
kwargs=step)

replica_dict[gpu] = replica_idx + 1
93 changes: 93 additions & 0 deletions control.py
@@ -1,3 +1,9 @@
import torch
from collections import namedtuple
from torch.multiprocessing import Event
from utils.class_utils import load_class


class TerminationFlag:
"""An enum class for representing various termination states."""
UNSET = -1
@@ -78,3 +84,90 @@ def get_tensor_queue(self, step_idx, gpu_idx):

def get_filename_queue(self):
return self.filename_queue


class TensorEvent:
"""Basically a tuple of a torch.Tensor and multiprocessing.Event.

The Tensor can be used as a "shared tensor" for passing intermediate tensors
across processes.

The Event should be used to signal that the consumer process has finished
reading from the Tensor. When writing values to the Tensor, the producer
process should first check if the Tensor is free, by calling event.wait(). If
the Tensor is indeed free, then event.wait() will return immediately. If not,
then event.wait() will block until the consumer process calls event.set().
Thus, the consumer should make sure that it calls event.set() AFTER the
Tensor's contents have been copied to a safe area, such as the consumer's own
local tensor.
"""
def __init__(self, shape, device, dtype=torch.float32):
self.tensor = torch.empty(*shape, dtype=dtype, device=device)
self.event = Event()
self.event.set()


class BenchmarkTensors:
"""Manages intermediate tensors that are passed across steps in the benchmark.

Args:
pipeline: The whole pipeline info parsed from the input configuration file
num_tensors_per_process: The number of shared output tensors that are given
to each process, for writing tensor values. A big value allows
processes to produce many tensors before having to block, but requires
a lot of GPU memory. A small value saves memory, but results in early
blocking.
"""
def __init__(self, pipeline, num_tensors_per_process):
# self.tensors is a 3-level list of TensorEvents, e.g.,
# [
# None, (the first step does not need shared input tensors)
# [ (shared tensors between step 0 & 1)
# [tensorEvent000, tensorEvent001, ...] (outputs of process 0 in step 0)
# [tensorEvent010, tensorEvent011, ...] (outputs of process 1 in step 0)
# [tensorEvent020, tensorEvent021, ...] (outputs of process 2 in step 0)
# ],

# [ (shared tensors between step 1 & 2)
# [tensorEvent100, tensorEvent101, ...] (outputs of process 0 in step 1)
# [tensorEvent110, tensorEvent111, ...] (outputs of process 1 in step 1)
# [tensorEvent120, tensorEvent121, ...] (outputs of process 2 in step 1)
# ],
# ...,
# [None, None, ...] (the last step does not need shared output tensors)
# ]
self.tensors = [None]

# we exclude the last step since the last step does not need to output tensors
for step in pipeline[:-1]:
step_output_tensors = []

# load the model class to check the output tensor shape of this step
model_module_path = step['model']
model_class = load_class(model_module_path)
shape = model_class.output_shape()

for gpu in step['gpus']:
device = torch.device('cuda:%d' % gpu)
tensors = [TensorEvent(shape, device)
for _ in range(num_tensors_per_process)]
step_output_tensors.append(tensors)

self.tensors.append(step_output_tensors)

# add Nones as output placeholders for the last step
self.tensors.append([None for _ in pipeline[-1]['gpus']])

def get_shared_tensors(self, step_idx, instance_idx):
"""Returns the shared input tensors and output tensors for a given process.

The shared input tensors are returned as a 2-level list, containing the
output tensors of all processes of the previous step. On the other hand,
the output tensors are returned as a 1-level list, since this process does
not need to access the output tensors of other processes from the same step.
"""
return self.tensors[step_idx], self.tensors[step_idx + 1][instance_idx]


# An integer tuple for accessing tensors from BenchmarkTensors.
Signal = namedtuple('Signal', ['instance_idx', 'tensor_idx'])
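
To make the layout concrete, here is a minimal, self-contained sketch of the indexing that get_shared_tensors performs, for a hypothetical two-step pipeline with two runner processes per step and two shared tensors per process. Plain strings stand in for TensorEvent objects, and every count here is illustrative rather than taken from this PR.

# Hypothetical layout: 2 steps, 2 runner processes per step, 2 tensors per process.
tensors = [
    None,                      # step 0 reads raw inputs, so it has no shared input tensors
    [                          # outputs of step 0 == shared inputs of step 1
        ['te00_0', 'te00_1'],  # written by process 0 of step 0
        ['te01_0', 'te01_1'],  # written by process 1 of step 0
    ],
    [None, None],              # the final step produces no shared output tensors
]

def get_shared_tensors(step_idx, instance_idx):
    # same indexing as BenchmarkTensors.get_shared_tensors
    return tensors[step_idx], tensors[step_idx + 1][instance_idx]

# Process 0 of step 1 can read the outputs of every step-0 process (a 2-level list),
# but only owns its own output slot (None here, because step 1 is the final step).
shared_inputs, shared_outputs = get_shared_tensors(1, 0)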
20 changes: 20 additions & 0 deletions models/r2p1d/model.py
@@ -68,6 +68,12 @@ def __init__(self, device, start_index=1, end_index=5, num_classes=400, layer_si
def input_shape(self):
return self.input_dict[self.start_index]

@staticmethod
def output_shape():
# TODO: the output shape may not be (10, 400), depending on self.end_index
# need to change return value accordingly
return (10, 400)

def __call__(self, input):
return self.model(input)

@@ -134,6 +140,13 @@ def __call__(self, input):
def __del__(self):
self.loader.close()

def input_shape(self):
return None

@staticmethod
def output_shape():
return (10, 3, 8, 112, 112)


class R2P1DSingleStep(RunnerModel):
"""RunnerModel impl that contains all inference logic regarding R(2+1)D.
@@ -199,3 +212,10 @@ def __call__(self, input):

def __del__(self):
self.loader.close()

def input_shape(self):
return None

@staticmethod
def output_shape():
return (10, 400)
76 changes: 68 additions & 8 deletions runner.py
@@ -3,10 +3,11 @@
NUM_EXIT_MARKERS = 10
NUM_SUMMARY_SKIPS = 10
def runner(input_queue, output_queue, print_summary,
job_id, g_idx, r_idx, global_inference_counter, num_videos,
job_id, g_idx, r_idx, instance_idx,
global_inference_counter, num_videos,
termination_flag, step_idx,
sta_bar, fin_bar,
model_module_path,
model_module_path, shared_input_tensors, shared_output_tensors,
**model_kwargs):
# PyTorch seems to have an issue with sharing modules between
# multiple processes, so we just do the imports here and
@@ -16,7 +17,7 @@ def runner(input_queue, output_queue, print_summary,
from queue import Empty, Full
from tqdm import tqdm
from rnb_logging import logname, TimeCardSummary
from control import TerminationFlag
from control import TerminationFlag, Signal
from utils.class_utils import load_class

# Use our own CUDA stream to avoid synchronizing with other processes
@@ -43,6 +44,16 @@
# collect incoming time measurements for later logging
time_card_summary = TimeCardSummary()

# keep track of the next position to write output tensors
shared_output_tensor_counter = 0

# Create placeholder tensor to copy values from shared input tensors.
# If the model does not provide a tensor shape, we assume the expected
# input is not a tensor and skip creating a placeholder.
shape = model.input_shape()
if shape is not None:
input_placeholder = torch.empty(*shape, dtype=torch.float32).cuda()

sta_bar.wait()

if print_summary:
@@ -54,15 +65,48 @@
if tpl is None:
break

tensor, time_card = tpl
signal_or_input, time_card = tpl
time_card.record('runner%d_start' % step_idx)

if isinstance(tensor, torch.Tensor) and tensor.device != device:
tensor = tensor.to(device=device)
if isinstance(signal_or_input, Signal):
# we need to copy values from the designated shared input tensor
instance_idx, tensor_idx = signal_or_input
tensor_event = shared_input_tensors[instance_idx][tensor_idx]

# Under normal circumstances, the event should not be set yet.
# However, this may not be true if the job is terminating, in which
# case we immediately exit.
if tensor_event.event.is_set() and \
termination_flag.value != TerminationFlag.UNSET:
break

# This is basically a device-to-device memcpy if the source tensor
# is coming from a different device. If not, then this op becomes
# a memcpy within the same device.
input_placeholder.copy_(tensor_event.tensor)

# release the shared tensor to be reused later
tensor_event.event.set()

else:
# this process does not use the shared tensor mechanism
# simply use the input as-is
input_placeholder = signal_or_input

time_card.record('inference%d_start' % step_idx)

outputs = model(tensor)
outputs = model(input_placeholder)
stream.synchronize()

if shared_output_tensors is not None:
# we need to copy the results into a shared output tensor
tensor_event = shared_output_tensors[shared_output_tensor_counter]

# check to see if the tensor has been released or not
tensor_event.event.wait()
tensor_event.tensor.copy_(outputs)
tensor_event.event.clear()

time_card.record('inference%d_finish' % step_idx)


@@ -91,7 +135,16 @@ def runner(input_queue, output_queue, print_summary,
# this is NOT the final step
# pass on the intermediate tensor to the next step
try:
output_queue.put_nowait((outputs, time_card))
if shared_output_tensors is not None:
# pass a Signal object for accessing shared tensors
signal = Signal(instance_idx, shared_output_tensor_counter)
output_queue.put_nowait((signal, time_card))
shared_output_tensor_counter = \
(shared_output_tensor_counter + 1) \
% len(shared_output_tensors)
else:
# no need to pass any signals, just enqueue outputs directly
output_queue.put_nowait((outputs, time_card))
except Full:
print('[WARNING] Queue between runner step %d and %d is full. '
'Aborting...' % (step_idx, step_idx+1))
@@ -108,6 +161,13 @@
except Full:
pass

if shared_input_tensors is not None:
# release all shared input tensors in case any process from the
# previous step is waiting for a tensor to be released
for instance_tensors in shared_input_tensors:
for protected_tensor in instance_tensors:
protected_tensor.event.set()

fin_bar.wait()
if output_queue is not None:
output_queue.cancel_join_thread()
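
The correctness of the mechanism hinges on the order of event operations above: a producer may only overwrite a slot after its consumer has copied the previous contents out. The following is a condensed, single-process sketch of that handshake, assuming a CPU tensor and a locally created Event so it runs without a GPU; in the benchmark the two sides are separate processes and the tensor lives in GPU memory.

import torch
from torch.multiprocessing import Event

shared = torch.empty(2, 3)      # stands in for TensorEvent.tensor (GPU-resident in the benchmark)
free = Event()                  # stands in for TensorEvent.event
free.set()                      # slots start out free, as in TensorEvent.__init__

# Producer side (previous step), cf. the shared_output_tensors branch above:
free.wait()                     # block until the consumer has released this slot
shared.copy_(torch.ones(2, 3))  # write the freshly computed outputs into the slot
free.clear()                    # mark the slot as in use
signal = (0, 0)                 # a Signal(instance_idx, tensor_idx) would go on the queue

# Consumer side (next step), cf. the Signal branch above:
placeholder = torch.empty(2, 3)
placeholder.copy_(shared)       # copy into a private tensor first...
free.set()                      # ...and only then release the slot for reuse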
5 changes: 5 additions & 0 deletions runner_model.py
@@ -17,6 +17,11 @@ def input_shape(self):
"""Returns the expected shape of the input tensor to this model."""
raise NotImplementedError

@staticmethod
def output_shape():
"""Returns the expected shape of the output tensor of this model."""
Review comments on output_shape():

Contributor: As we discussed previously, could we add a description for the None case in input_shape()/output_shape() (i.e., when the data at that point is not a tensor)?

Contributor: The change should be applied to input_shape() as well.

Author: Will do.

raise NotImplementedError

def __call__(self, input):
"""Perform inference on this model with the given input.

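
For reference, a hypothetical RunnerModel subclass pair illustrating the convention raised in the review: the shape methods return a tuple when the data crossing that boundary is a tensor and None otherwise. The class names, layer, and shapes are invented for illustration and are not part of this PR; the constructor signature is assumed to mirror the existing models, which take a device argument.

import torch
from runner_model import RunnerModel

class ToyClassifier(RunnerModel):
    """Hypothetical step that maps a (10, 512) feature tensor to (10, 400) scores."""
    def __init__(self, device):
        self.model = torch.nn.Linear(512, 400).to(device)

    def input_shape(self):
        return (10, 512)    # the input is a tensor, so a shape tuple is returned

    @staticmethod
    def output_shape():
        return (10, 400)    # the output is a tensor as well

    def __call__(self, input):
        return self.model(input)

class ToyLoader(RunnerModel):
    """Hypothetical first step that consumes filenames rather than tensors."""
    def __init__(self, device):
        self.device = device

    def input_shape(self):
        return None         # the input is not a tensor, so there is no shape to report

    @staticmethod
    def output_shape():
        return (10, 512)

    def __call__(self, filename):
        # dummy decode step; a real loader would read and preprocess the file
        return torch.randn(10, 512, device=self.device)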