-
Notifications
You must be signed in to change notification settings - Fork 56
/
Copy pathtuto.html
600 lines (548 loc) · 49.7 KB
/
tuto.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- The above 3 meta tags *must* come first in the head; any other head content must come *after* these tags -->
<title>Writing Distributed Applications with PyTorch - Séb Arnold</title>
<!-- Bootstrap -->
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/3.3.7/css/bootstrap.min.css" />
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/bootswatch/3.3.7/sandstone/bootstrap.min.css" />
<!--Prism for code high-lighting-->
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/prism/1.5.1/themes/prism.min.css" />
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/prism/1.5.1/themes/prism-solarizedlight.css" />
<!--KaTeX for fast embedded math-->
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.6.0/katex.min.css">
<!--Pseudocode.js-->
<link rel="stylesheet" href="https://cdn.rawgit.com/seba-1511/cdn/master/pseudocode.js/pseudocode.min.css">
<style type="text/css" media="all">
/* Space out content a bit */
body {
padding-top: 20px;
padding-bottom: 20px;
}
p {
font-size: 16px;
text-align: justify;
}
/* Everything but the jumbotron gets side spacing for mobile first views */
.header,
.footer {
padding-right: 15px;
padding-left: 15px;
}
/* Custom page header */
.header {
padding-bottom: 20px;
border-bottom: 1px solid #e5e5e5;
}
/* Make the masthead heading the same height as the navigation */
.header h3 {
margin-top: 0;
margin-bottom: 0;
line-height: 40px;
}
img {
max-width: 100%;
}
/* Custom page footer */
.footer {
padding-top: 19px;
color: #777;
border-top: 1px solid #e5e5e5;
}
/* Customize container */
@media (min-width: 768px) {
.container {
max-width: 730px;
}
}
.container-narrow > hr {
margin: 30px 0;
}
/* Responsive: Portrait tablets and up */
@media screen and (min-width: 768px) {
/* Remove the padding we set earlier */
.header,
.marketing,
.footer {
padding-right: 0;
padding-left: 0;
}
/* Space out the masthead */
.header {
margin-bottom: 30px;
}
/* Remove the bottom border on the jumbotron for visual effect */
.jumbotron {
border-bottom: 0;
}
}
.well {
border: 1px solid #767676;
width: 157px;
max-width: 157px;
}
.well a {
color:#767676;
margin-bottom:5px;
}
.well ul {
list-style: none;
margin: 0px;
padding-left: 10px;
}
</style>
<!-- HTML5 shim and Respond.js for IE8 support of HTML5 elements and media queries -->
<!-- WARNING: Respond.js doesn't work if you view the page via file:// -->
<!--[if lt IE 9]>
<script src="https://oss.maxcdn.com/html5shiv/3.7.3/html5shiv.min.js"></script>
<script src="https://oss.maxcdn.com/respond/1.4.2/respond.min.js"></script>
<![endif]-->
<!--Plotly.js-->
<!--Needs to be imported before body, else figs won't load.-->
<script src="https://cdn.plot.ly/plotly-1.2.0.min.js"></script>
</head>
<body>
<div class="container">
<div class="header clearfix">
<!--<nav>-->
<!--<ul class="nav nav-pills pull-right">-->
<!--<li role="presentation" class="active"><a href="#">Home</a></li>-->
<!--<li role="presentation"><a href="#">About</a></li>-->
<!--<li role="presentation"><a href="#">Contact</a></li>-->
<!--</ul>-->
<!--</nav>-->
<h1 class="text-center">Writing Distributed Applications with PyTorch</h1>
<h4 class="text-sm text-muted text-center"> by Séb Arnold, <span style="font-weight:normal;"><i>June 14, 2017</i></span></h4>
</div>
<div style="margin-top:20px;margin-bottom:20px;"><p style="text-align:center;margin:0px;"><b>Abstract</b></p><p><br/>
In this short tutorial, we will be going over the distributed package of PyTorch. We'll see how to set up the distributed setting, use the different communication strategies, and go over some of the internals of the package.
</p></div>
<h1 id="setup">Setup</h1>
<!--
* Processes & machines
* variables and init_process_group
-->
<p>The distributed package included in PyTorch (i.e., <code>torch.distributed</code>) enables researchers and practitioners to easily parallelize their computations across processes and clusters of machines. To do so, it leverages message passing semantics, allowing each process to communicate data to any of the other processes. As opposed to the multiprocessing (<code>torch.multiprocessing</code>) package, processes can use different communication backends and are not restricted to being executed on the same machine.</p>
<p>In order to get started we need the ability to run multiple processes simultaneously. If you have access to a compute cluster you should check with your local sysadmin or use your favorite coordination tool (e.g., <a href="https://linux.die.net/man/1/pdsh">pdsh</a>, <a href="http://cea-hpc.github.io/clustershell/">clustershell</a>, or <a href="https://slurm.schedmd.com/">others</a>). For the purpose of this tutorial, we will use a single machine and fork multiple processes using the following template.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python"><span class="co">"""run.py:"""</span>
<span class="co">#!/usr/bin/env python</span>
<span class="im">import</span> os
<span class="im">import</span> torch
<span class="im">import</span> torch.distributed <span class="im">as</span> dist
<span class="im">from</span> torch.multiprocessing <span class="im">import</span> Process
<span class="kw">def</span> run(rank, size):
<span class="co">""" Distributed function to be implemented later. """</span>
<span class="cf">pass</span>
<span class="kw">def</span> init_processes(rank, size, fn, backend<span class="op">=</span><span class="st">'tcp'</span>):
<span class="co">""" Initialize the distributed environment. """</span>
os.environ[<span class="st">'MASTER_ADDR'</span>] <span class="op">=</span> <span class="st">'127.0.0.1'</span>
os.environ[<span class="st">'MASTER_PORT'</span>] <span class="op">=</span> <span class="st">'29500'</span>
dist.init_process_group(backend, rank<span class="op">=</span>rank, world_size<span class="op">=</span>size)
fn(rank, size)
<span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:
size <span class="op">=</span> <span class="dv">2</span>
processes <span class="op">=</span> []
<span class="cf">for</span> rank <span class="kw">in</span> <span class="bu">range</span>(size):
p <span class="op">=</span> Process(target<span class="op">=</span>init_processes, args<span class="op">=</span>(rank, size, run))
p.start()
processes.append(p)
<span class="cf">for</span> p <span class="kw">in</span> processes:
p.join()</code></pre></div>
<p>The above script spawns two processes that will each set up the distributed environment, initialize the process group (<code>dist.init_process_group</code>), and finally execute the given <code>run</code> function.</p>
<p>Let's have a look at the <code>init_processes</code> function. It ensures that every process will be able to coordinate through a master, using the same ip address and port. Note that we used the TCP backend, but we could have used <a href="https://en.wikipedia.org/wiki/Message_Passing_Interface">MPI</a> or <a href="http://github.com/facebookincubator/gloo">Gloo</a> instead. (c.f. <a href="#communication-backends">Section 5.1</a>) We will go over the magic happening in <code>dist.init_process_group</code> at the end of this tutorial, but it essentially allows processes to communicate with each other by sharing their locations.</p>
<h1 id="point-to-point-communication">Point-to-Point Communication</h1>
<!--
* send/recv
* isend/irecv
-->
<table>
<tbody>
<tr>
</tr>
<tr>
<td align="center">
<img src='./figs/send_recv.png' width=100% /><br/> <b>Send and Recv</b>
</td>
</tr>
</tbody>
</table>
<p>A transfer of data from one process to another is called a point-to-point communication. These are achieved through the <code>send</code> and <code>recv</code> functions or their <em>immediate</em> counter-parts, <code>isend</code> and <code>irecv</code>.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python"><span class="co">"""Blocking point-to-point communication."""</span>
<span class="kw">def</span> run(rank, size):
tensor <span class="op">=</span> torch.zeros(<span class="dv">1</span>)
<span class="cf">if</span> rank <span class="op">==</span> <span class="dv">0</span>:
tensor <span class="op">+=</span> <span class="dv">1</span>
<span class="co"># Send the tensor to process 1</span>
dist.send(tensor<span class="op">=</span>tensor, dst<span class="op">=</span><span class="dv">1</span>)
<span class="cf">else</span>:
<span class="co"># Receive tensor from process 0</span>
dist.recv(tensor<span class="op">=</span>tensor, src<span class="op">=</span><span class="dv">0</span>)
<span class="bu">print</span>(<span class="st">'Rank '</span>, rank, <span class="st">' has data '</span>, tensor[<span class="dv">0</span>])</code></pre></div>
<p>In the above example, both processes start with a zero tensor, then process 0 increments the tensor and sends it to process 1 so that they both end up with 1.0. Notice that process 1 needs to allocate memory in order to store the data it will receive.</p>
<p>Also notice that <code>send</code>/<code>recv</code> are <strong>blocking</strong>: both processes stop until the communication is completed. On the other hand immediates are <strong>non-blocking</strong>; the script continues its execution and the methods return a <code>DistributedRequest</code> object upon which we can choose to <code>wait()</code>.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python"><span class="co">"""Non-blocking point-to-point communication."""</span>
<span class="kw">def</span> run(rank, size):
tensor <span class="op">=</span> torch.zeros(<span class="dv">1</span>)
req <span class="op">=</span> <span class="va">None</span>
<span class="cf">if</span> rank <span class="op">==</span> <span class="dv">0</span>:
tensor <span class="op">+=</span> <span class="dv">1</span>
<span class="co"># Send the tensor to process 1</span>
req <span class="op">=</span> dist.isend(tensor<span class="op">=</span>tensor, dst<span class="op">=</span><span class="dv">1</span>)
<span class="bu">print</span>(<span class="st">'Rank 0 started sending'</span>)
<span class="cf">else</span>:
<span class="co"># Receive tensor from process 0</span>
req <span class="op">=</span> dist.irecv(tensor<span class="op">=</span>tensor, src<span class="op">=</span><span class="dv">0</span>)
<span class="bu">print</span>(<span class="st">'Rank 1 started receiving'</span>)
<span class="bu">print</span>(<span class="st">'Rank 1 has data '</span>, tensor[<span class="dv">0</span>])
req.wait()
<span class="bu">print</span>(<span class="st">'Rank '</span>, rank, <span class="st">' has data '</span>, tensor[<span class="dv">0</span>])</code></pre></div>
<p>Running the above function might result in process 1 still having 0.0 while having already started receiving. However, after <code>req.wait()</code> has been executed we are guaranteed that the communication took place, and that the value stored in <code>tensor[0]</code> is 1.0.</p>
<p>Point-to-point communication is useful when we want fine-grained control over the communication of our processes. It can be used to implement fancy algorithms, such as the ones used in <a href="https://github.com/baidu-research/baidu-allreduce">Baidu's DeepSpeech</a> or <a href="https://research.fb.com/publications/imagenet1kin1h/">Facebook's large-scale experiments</a> (cf. <a href="#our-own-ring-allreduce">Section 4.1</a>).</p>
<h1 id="collective-communication">Collective Communication</h1>
<!--
* gather
* reduce
* broadcast
* scatter
* all_reduce
-->
<table>
<tbody>
<tr>
<td align="center">
<img src='./figs/broadcast.png' width=100% /><br/> <b>Broadcast</b>
</td>
<td align="center">
<img src='./figs/all_gather.png' width=100% /><br/> <b>AllGather</b>
</td>
</tr>
<tr>
<td align="center">
<img src='./figs/reduce.png' width=100% /><br/> <b>Reduce</b>
</td>
<td align="center">
<img src='./figs/all_reduce.png' width=100% /><br/> <b>AllReduce</b>
</td>
</tr>
<tr>
<td align="center">
<img src='./figs/scatter.png' width=100% /><br/> <b>Scatter</b>
</td>
<td align="center">
<img src='./figs/gather.png' width=100% /><br/> <b>Gather</b>
</td>
</tr>
</tbody>
</table>
<p>As opposed to point-to-point communication, collectives allow for communication patterns across all processes in a <strong>group</strong>. A group is a subset of all our processes. To create a group, we can pass a list of ranks to <code>dist.new_group(group)</code>. By default, collectives are executed on all processes, also known as the <strong>world</strong>. For example, in order to obtain the sum of all tensors at all processes, we can use the <code>dist.all_reduce(tensor, op, group)</code> collective.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python"><span class="co">""" All-Reduce example."""</span>
<span class="kw">def</span> run(rank, size):
<span class="co">""" Simple collective communication. """</span>
group <span class="op">=</span> dist.new_group([<span class="dv">0</span>, <span class="dv">1</span>])
tensor <span class="op">=</span> torch.ones(<span class="dv">1</span>)
dist.all_reduce(tensor, op<span class="op">=</span>dist.reduce_op.SUM, group<span class="op">=</span>group)
<span class="bu">print</span>(<span class="st">'Rank '</span>, rank, <span class="st">' has data '</span>, tensor[<span class="dv">0</span>])</code></pre></div>
<p>Since we want the sum of all tensors in the group, we use <code>dist.reduce_op.SUM</code> as the reduce operator. Generally speaking, any commutative mathematical operation can be used as an operator. Out-of-the-box, PyTorch comes with 4 such operators, all working at the element-wise level:</p>
<ul>
<li><code>dist.reduce_op.SUM</code>,</li>
<li><code>dist.reduce_op.PRODUCT</code>,</li>
<li><code>dist.reduce_op.MAX</code>,</li>
<li><code>dist.reduce_op.MIN</code>.</li>
</ul>
<p>In addition to <code>dist.all_reduce(tensor, op, group)</code>, there are a total of 6 collectives currently implemented in PyTorch.</p>
<ul>
<li><code>dist.broadcast(tensor, src, group)</code>: Copies <code>tensor</code> from <code>src</code> to all other processes.</li>
<li><code>dist.reduce(tensor, dst, op, group)</code>: Applies <code>op</code> to every <code>tensor</code> and stores the result in <code>dst</code>.</li>
<li><code>dist.all_reduce(tensor, op, group)</code>: Same as reduce, but the result is stored in all processes.</li>
<li><code>dist.scatter(tensor, src, scatter_list, group)</code>: Copies the <span class="math inline">\(i^{\text{th}}\)</span> tensor <code>scatter_list[i]</code> to the <span class="math inline">\(i^{\text{th}}\)</span> process.</li>
<li><code>dist.gather(tensor, dst, gather_list, group)</code>: Copies <code>tensor</code> from all processes in <code>dst</code>.</li>
<li><code>dist.all_gather(tensor_list, tensor, group)</code>: Copies <code>tensor</code> from all processes to <code>tensor_list</code>, on all processes.</li>
</ul>
<h1 id="distributed-training">Distributed Training</h1>
<!--
* Gloo Backend
* Simple all_reduce on the gradients
* Point to optimized DistributedDataParallel
TODO: Custom ring-allreduce
-->
<p><strong>Note:</strong> You can find the example script of this section in <a href="https://github.com/seba-1511/dist_tuto.pth/">this GitHub repository</a>.</p>
<p>Now that we understand how the distributed module works, let us write something useful with it. Our goal will be to replicate the functionality of <a href="http://pytorch.org/docs/master/nn.html#torch.nn.parallel.DistributedDataParallel">DistributedDataParallel</a>. Of course, this will be a didactic example and in a real-world situation you should use the official, well-tested and well-optimized version linked above.</p>
<p>Quite simply we want to implement a distributed version of stochastic gradient descent. Our script will let all processes compute the gradients of their model on their batch of data and then average their gradients. In order to ensure similar convergence results when changing the number of processes, we will first have to partition our dataset. (You could also use <a href="https://github.com/pytorch/tnt/blob/master/torchnet/dataset/splitdataset.py#L4">tnt.dataset.SplitDataset</a>, instead of the snippet below.)</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python"><span class="co">""" Dataset partitioning helper """</span>
<span class="kw">class</span> Partition(<span class="bu">object</span>):
<span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, data, index):
<span class="va">self</span>.data <span class="op">=</span> data
<span class="va">self</span>.index <span class="op">=</span> index
<span class="kw">def</span> <span class="fu">__len__</span>(<span class="va">self</span>):
<span class="cf">return</span> <span class="bu">len</span>(<span class="va">self</span>.index)
<span class="kw">def</span> <span class="fu">__getitem__</span>(<span class="va">self</span>, index):
data_idx <span class="op">=</span> <span class="va">self</span>.index[index]
<span class="cf">return</span> <span class="va">self</span>.data[data_idx]
<span class="kw">class</span> DataPartitioner(<span class="bu">object</span>):
<span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, data, sizes<span class="op">=</span>[<span class="fl">0.7</span>, <span class="fl">0.2</span>, <span class="fl">0.1</span>], seed<span class="op">=</span><span class="dv">1234</span>):
<span class="va">self</span>.data <span class="op">=</span> data
<span class="va">self</span>.partitions <span class="op">=</span> []
rng <span class="op">=</span> Random()
rng.seed(seed)
data_len <span class="op">=</span> <span class="bu">len</span>(data)
indexes <span class="op">=</span> [x <span class="cf">for</span> x <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, data_len)]
rng.shuffle(indexes)
<span class="cf">for</span> frac <span class="kw">in</span> sizes:
part_len <span class="op">=</span> <span class="bu">int</span>(frac <span class="op">*</span> data_len)
<span class="va">self</span>.partitions.append(indexes[<span class="dv">0</span>:part_len])
indexes <span class="op">=</span> indexes[part_len:]
<span class="kw">def</span> use(<span class="va">self</span>, partition):
<span class="cf">return</span> Partition(<span class="va">self</span>.data, <span class="va">self</span>.partitions[partition])</code></pre></div>
<p>With the above snippet, we can now simply partition any dataset using the following few lines:</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python"><span class="co">""" Partitioning MNIST """</span>
<span class="kw">def</span> partition_dataset():
dataset <span class="op">=</span> datasets.MNIST(<span class="st">'./data'</span>, train<span class="op">=</span><span class="va">True</span>, download<span class="op">=</span><span class="va">True</span>,
transform<span class="op">=</span>transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((<span class="fl">0.1307</span>,), (<span class="fl">0.3081</span>,))
]))
size <span class="op">=</span> dist.get_world_size()
bsz <span class="op">=</span> <span class="dv">128</span> <span class="op">/</span> <span class="bu">float</span>(size)
partition_sizes <span class="op">=</span> [<span class="fl">1.0</span> <span class="op">/</span> size <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(size)]
partition <span class="op">=</span> DataPartitioner(dataset, partition_sizes)
partition <span class="op">=</span> partition.use(dist.get_rank())
train_set <span class="op">=</span> torch.utils.data.DataLoader(partition,
batch_size<span class="op">=</span>bsz,
shuffle<span class="op">=</span><span class="va">True</span>)
<span class="cf">return</span> train_set, bsz</code></pre></div>
<p>Assuming we have 2 replicas, then each process will have a <code>train_set</code> of 60000 / 2 = 30000 samples. We also divide the batch size by the number of replicas in order to maintain the <em>overall</em> batch size of 128.</p>
<p>We can now write our usual forward-backward-optimize training code, and add a function call to average the gradients of our models. (The following is largely inspired from the official <a href="https://github.com/pytorch/examples/blob/master/mnist/main.py">PyTorch MNIST example</a>.)</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python"><span class="co">""" Distributed Synchronous SGD Example """</span>
<span class="kw">def</span> run(rank, size):
torch.manual_seed(<span class="dv">1234</span>)
train_set, bsz <span class="op">=</span> partition_dataset()
model <span class="op">=</span> Net()
optimizer <span class="op">=</span> optim.SGD(model.parameters(),
lr<span class="op">=</span><span class="fl">0.01</span>, momentum<span class="op">=</span><span class="fl">0.5</span>)
num_batches <span class="op">=</span> ceil(<span class="bu">len</span>(train_set.dataset) <span class="op">/</span> <span class="bu">float</span>(bsz))
<span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):
epoch_loss <span class="op">=</span> <span class="fl">0.0</span>
<span class="cf">for</span> data, target <span class="kw">in</span> train_set:
data, target <span class="op">=</span> Variable(data), Variable(target)
optimizer.zero_grad()
output <span class="op">=</span> model(data)
loss <span class="op">=</span> F.nll_loss(output, target)
epoch_loss <span class="op">+=</span> loss.data[<span class="dv">0</span>]
loss.backward()
average_gradients(model)
optimizer.step()
<span class="bu">print</span>(<span class="st">'Rank '</span>, dist.get_rank(), <span class="st">', epoch '</span>,
epoch, <span class="st">': '</span>, epoch_loss <span class="op">/</span> num_batches) </code></pre></div>
<p>It remains to implement the <code>average_gradients(model)</code> function, which simply takes in a model and averages its gradients across the whole world.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python"><span class="co">""" Gradient averaging. """</span>
<span class="kw">def</span> average_gradients(model):
size <span class="op">=</span> <span class="bu">float</span>(dist.get_world_size())
<span class="cf">for</span> param <span class="kw">in</span> model.parameters():
dist.all_reduce(param.grad.data, op<span class="op">=</span>dist.reduce_op.SUM)
param.grad.data <span class="op">/=</span> size </code></pre></div>
<p><em>Et voilà </em>! We successfully implemented distributed synchronous SGD and could train any model on a large computer cluster.</p>
<p><strong>Note:</strong> While the last sentence is <em>technically</em> true, there are <a href="http://seba-1511.github.io/dist_blog">a lot more tricks</a> required to implement a production-level implementation of synchronous SGD. Again, use what <a href="http://pytorch.org/docs/master/nn.html#torch.nn.parallel.DistributedDataParallel">has been tested and optimized</a>.</p>
<h2 id="our-own-ring-allreduce">Our Own Ring-Allreduce</h2>
<p>As an additional challenge, imagine that we wanted to implement DeepSpeech's efficient ring allreduce. This is fairly easily implemented using point-to-point collectives.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python"><span class="co">""" Implementation of a ring-reduce with addition. """</span>
<span class="kw">def</span> allreduce(send, recv):
rank <span class="op">=</span> dist.get_rank()
size <span class="op">=</span> dist.get_world_size()
send_buff <span class="op">=</span> send.clone()
recv_buff <span class="op">=</span> send.clone()
accum <span class="op">=</span> send.clone()
left <span class="op">=</span> ((rank <span class="op">-</span> <span class="dv">1</span>) <span class="op">+</span> size) <span class="op">%</span> size
right <span class="op">=</span> (rank <span class="op">+</span> <span class="dv">1</span>) <span class="op">%</span> size
<span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(size <span class="op">-</span> <span class="dv">1</span>):
<span class="cf">if</span> i <span class="op">%</span> <span class="dv">2</span> <span class="op">==</span> <span class="dv">0</span>:
<span class="co"># Send send_buff</span>
send_req <span class="op">=</span> dist.isend(send_buff, right)
dist.recv(recv_buff, left)
accum[:] <span class="op">+=</span> recv_buff[:]
<span class="cf">else</span>:
<span class="co"># Send recv_buff</span>
send_req <span class="op">=</span> dist.isend(recv_buff, right)
dist.recv(send_buff, left)
accum[:] <span class="op">+=</span> send_buff[:]
send_req.wait()
recv[:] <span class="op">=</span> accum[:]</code></pre></div>
<p>In the above script, the <code>allreduce(send, recv)</code> function has a slightly different signature than the ones in PyTorch. It takes a <code>recv</code> tensor and will store the sum of all <code>send</code> tensors in it. As an exercise left to the reader, there is still one difference between our version and the one in DeepSpeech: their implementation divides the gradient tensor into <em>chunks</em>, so as to optimally utilize the communication bandwidth. (Hint: <a href="http://pytorch.org/docs/master/torch.html#torch.chunk">torch.chunk</a>)</p>
<h1 id="advanced-topics">Advanced Topics</h1>
<p>We are now ready to discover some of the more advanced functionalities of <code>torch.distributed</code>. Since there is a lot to cover, this section is divided into two subsections:</p>
<ol style="list-style-type: decimal">
<li>Communication Backends: where we learn how to use MPI and Gloo for GPU-GPU communication.</li>
<li>Initialization Methods: where we understand how to best setup the initial coordination phase in <code>dist.init_process_group()</code>.</li>
</ol>
<h2 id="communication-backends">Communication Backends</h2>
<p>One of the most elegant aspects of <code>torch.distributed</code> is its ability to abstract and build on top of different backends. As mentioned before, there are currently three backends implemented in PyTorch: TCP, MPI, and Gloo. They each have different specifications and tradeoffs, depending on the desired use-case. A comparative table of supported functions can be found <a href="http://pytorch.org/docs/master/distributed.html#module-torch.distributed">here</a>.</p>
<h3 id="tcp-backend">TCP Backend</h3>
<p>So far we have made extensive usage of the TCP backend. It is quite handy as a development platform, as it is guaranteed to work on most machines and operating systems. It also supports all point-to-point and collective functions on CPU. However, there is no support for GPUs and its communication routines are not as optimized as the MPI one.</p>
<h3 id="gloo-backend">Gloo Backend</h3>
<p>The <a href="https://github.com/facebookincubator/gloo">Gloo backend</a> provides an optimized implementation of <em>collective</em> communication procedures, both for CPUs and GPUs. It particularly shines on GPUs as it can perform communication without transferring data to the CPU's memory using <a href="https://developer.nvidia.com/gpudirect">GPUDirect</a>. It is also capable of using <a href="https://github.com/NVIDIA/nccl">NCCL</a> to perform fast intra-node communication and implements its <a href="https://github.com/facebookincubator/gloo/blob/master/docs/algorithms.md">own algorithms</a> for inter-node routines.</p>
<p>Since version 0.2.0, the Gloo backend is automatically included with the pre-compiled binaries of PyTorch. As you have surely noticed, our distributed SGD example does not work if you put <code>model</code> on the GPU. Let's fix it by first replacing <code>backend='tcp'</code> with <code>backend='gloo'</code> in <code>init_processes(rank, size, fn, backend='tcp')</code>. At this point, the script will still run on CPU but uses the Gloo backend behind the scenes. In order to use multiple GPUs, let us also do the following modifications:</p>
<ol start="0" style="list-style-type: decimal">
<li><code>init_processes(rank, size, fn, backend='tcp')</code> <span class="math inline">\(\rightarrow\)</span> <code>init_processes(rank, size, fn, backend='gloo')</code></li>
<li><code>model = Net()</code> <span class="math inline">\(\rightarrow\)</span> <code>model = Net().cuda(rank)</code></li>
<li><code>data, target = Variable(data), Variable(target)</code> <span class="math inline">\(\rightarrow\)</span> <code>data, target = Variable(data.cuda(rank)), Variable(target.cuda(rank))</code></li>
</ol>
<p>With the above modifications, our model is now training on two GPUs and you can monitor their utilization with <code>watch nvidia-smi</code>.</p>
<h3 id="mpi-backend">MPI Backend</h3>
<p>The Message Passing Interface (MPI) is a standardized tool from the field of high-performance computing. It allows one to do point-to-point and collective communications and was the main inspiration for the API of <code>torch.distributed</code>. Several implementations of MPI exist (e.g. <a href="https://www.open-mpi.org/">Open-MPI</a>, <a href="http://mvapich.cse.ohio-state.edu/">MVAPICH2</a>, <a href="https://software.intel.com/en-us/intel-mpi-library">Intel MPI</a>) each optimized for different purposes. The advantage of using the MPI backend lies in MPI's wide availability - and high level of optimization - on large computer clusters. <a href="https://developer.nvidia.com/mvapich">Some</a> <a href="https://developer.nvidia.com/ibm-spectrum-mpi">recent</a> <a href="http://www.open-mpi.org/">implementations</a> are also able to take advantage of CUDA IPC and GPU Direct technologies in order to avoid memory copies through the CPU.</p>
<p>Unfortunately, PyTorch's binaries cannot include an MPI implementation and we'll have to recompile it by hand. Fortunately, this process is fairly simple given that upon compilation, PyTorch will look <em>by itself</em> for an available MPI implementation. The following steps install the MPI backend, by installing PyTorch <a href="https://github.com/pytorch/pytorch#from-source">from sources</a>.</p>
<ol style="list-style-type: decimal">
<li>Create and activate your Anaconda environment, install all the pre-requisites following <a href="https://github.com/pytorch/pytorch#from-source">the guide</a>, but do <strong>not</strong> run <code>python setup.py install</code> yet.</li>
<li>Choose and install your favorite MPI implementation. Note that enabling CUDA-aware MPI might require some additional steps. In our case, we'll stick to Open-MPI <em>without</em> GPU support: <code>conda install -c conda-forge openmpi</code></li>
<li>Now, go to your cloned PyTorch repo and execute <code>python setup.py install</code>.</li>
</ol>
<p>In order to test our newly installed backend, a few modifications are required.</p>
<ol style="list-style-type: decimal">
<li>Replace the content under <code>if __name__ == '__main__':</code> with <code>init_processes(0, 0, run, backend='mpi')</code>.</li>
<li>Run <code>mpirun -n 4 python myscript.py</code>.</li>
</ol>
<p>The reason for these changes is that MPI needs to create its own environment before spawning the processes. MPI will also spawn its own processes and perform the handshake described in <a href="#initialization-methods">Initialization Methods</a>, making the <code>rank</code> and <code>size</code> arguments of <code>init_process_group</code> superfluous. This is actually quite powerful as you can pass additional arguments to <code>mpirun</code> in order to tailor computational resources for each process (things like the number of cores per process, hand-assigning machines to specific ranks, and <a href="https://www.open-mpi.org/faq/?category=running#mpirun-hostfile">some more</a>). Doing so, you should obtain the same familiar output as with the other communication backends.</p>
<h2 id="initialization-methods">Initialization Methods</h2>
<p>To finish this tutorial, let's talk about the very first function we called: <code>dist.init_process_group(backend, init_method)</code>. In particular, we will go over the different initialization methods which are responsible for the initial coordination step between each process. Those methods allow you to define how this coordination is done. Depending on your hardware setup, one of these methods should be naturally more suitable than the others. In addition to the following sections, you should also have a look at the <a href="http://pytorch.org/docs/master/distributed.html#initialization">official documentation</a>.</p>
<p>Before diving into the initialization methods, let's have a quick look at what happens behind <code>init_process_group</code> from the C/C++ perspective.</p>
<ol style="list-style-type: decimal">
<li>First, the arguments are parsed and validated.</li>
<li>The backend is resolved via the <code>name2channel.at()</code> function. A <code>Channel</code> class is returned, and will be used to perform the data transmission.</li>
<li>The GIL is dropped, and <code>THDProcessGroupInit()</code> is called. This instantiates the channel and adds the address of the master node.</li>
<li>The process with rank 0 will execute the <code>master</code> procedure, while all other ranks will be <code>workers</code>.</li>
<li>The master
<ol style="list-style-type: lower-alpha">
<li>Creates sockets for all workers.</li>
<li>Waits for all workers to connect.</li>
<li>Sends them information about the location of the other processes.</li>
</ol></li>
<li>Each worker
<ol style="list-style-type: lower-alpha">
<li>Creates a socket to the master.</li>
<li>Sends their own location information.</li>
<li>Receives information about the other workers.</li>
<li>Opens a socket and handshakes with all other workers.</li>
</ol></li>
<li>The initialization is done, and everyone is connected to everyone.</li>
</ol>
<h3 id="environment-variable">Environment Variable</h3>
<p>We have been using the environment variable initialization method throughout this tutorial. By setting the following four environment variables on all machines, all processes will be able to properly connect to the master, obtain information about the other processes, and finally handshake with them.</p>
<ul>
<li><code>MASTER_PORT</code>: A free port on the machine that will host the process with rank 0.</li>
<li><code>MASTER_ADDR</code>: IP address of the machine that will host the process with rank 0.</li>
<li><code>WORLD_SIZE</code>: The total number of processes, so that the master knows how many workers to wait for.</li>
<li><code>RANK</code>: Rank of each process, so that each one knows whether it is the master or a worker.</li>
</ul>
<h3 id="shared-file-system">Shared File System</h3>
<p>The shared file system method requires all processes to have access to a shared file system, and will coordinate them through a shared file. This means that each process will open the file, write its information, and wait until every other process has done so. After that, all required information will be readily available to all processes. In order to avoid race conditions, the file system must support locking through <a href="http://man7.org/linux/man-pages/man2/fcntl.2.html">fcntl</a>. Note that you can specify ranks manually or let the processes figure it out by themselves. By defining a unique <code>group_name</code> per job you can use the same file path for multiple jobs and safely avoid collision.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">dist.init_process_group(init_method<span class="op">=</span><span class="st">'file:///mnt/nfs/sharedfile'</span>, world_size<span class="op">=</span><span class="dv">4</span>,
group_name<span class="op">=</span><span class="st">'mygroup'</span>)</code></pre></div>
<h3 id="tcp-init-multicast">TCP Init & Multicast</h3>
<p>Initializing via TCP can be achieved in two different ways:</p>
<ol style="list-style-type: decimal">
<li>By providing the IP address of the process with rank 0 and the world size.</li>
<li>By providing <em>any</em> valid IP <a href="https://en.wikipedia.org/wiki/Multicast_address">multicast address</a> and the world size.</li>
</ol>
<p>In the first case, all workers will be able to connect to the process with rank 0 and follow the procedure described above.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">dist.init_process_group(init_method<span class="op">=</span><span class="st">'tcp://10.1.1.20:23456'</span>, rank<span class="op">=</span>args.rank, world_size<span class="op">=</span><span class="dv">4</span>)</code></pre></div>
<p>In the second case, the multicast address specifies the group of nodes that might potentially be active, and the coordination can be handled by allowing each process to have an initial handshake before following the above procedure. In addition, TCP multicast initialization also supports a <code>group_name</code> argument (as with the shared file method), allowing multiple jobs to be scheduled on the same cluster.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">dist.init_process_group(init_method<span class="op">=</span><span class="st">'tcp://[ff15:1e18:5d4c:4cf0:d02d:b659:53ba:b0a7]:23456'</span>,
world_size<span class="op">=</span><span class="dv">4</span>)</code></pre></div>
<!--
## Internals
* The magic behind init_process_group:
1. validate and parse the arguments
2. resolve the backend: name2channel.at()
3. Drop GIL & THDProcessGroupInit: instantiate the channel and add address of master from config
4. rank 0 inits master, others workers
5. master: create sockets for all workers -> wait for all workers to connect -> send them each the info about location of other processes
6. worker: create socket to master, send own info, receive info about each worker, and then handshake with each of them
7. By this time everyone has handshake with everyone.
-->
<br /><br />
<center>
<strong>Acknowledgements</strong>
</center>
<p><small>I'd like to thank the PyTorch developers for doing such a good job on their implementation, documentation, and tests. When the code was unclear, I could always count on the <a href="http://pytorch.org/docs/master/distributed.html">docs</a> or the <a href="https://github.com/pytorch/pytorch/blob/master/test/test_distributed.py">tests</a> to find an answer. In particular, I'd like to thank Soumith Chintala, Adam Paszke, and Natalia Gimelshein for providing insightful comments and answering questions on early drafts.</small></p>
<footer class="footer">
<p><b>Writing Distributed Applications with PyTorch</b> - <i>Séb Arnold</i>, June 14, 2017.</p>
</footer>
</div> <!-- /container -->
<!-- jQuery comes first: Bootstrap's JavaScript plugins need it -->
<script src="https://ajax.googleapis.com/ajax/libs/jquery/1.12.4/jquery.min.js"></script>
<!-- Bootstrap: all compiled plugins in one file (individual files could be included instead) -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/3.3.7/js/bootstrap.min.js"></script>
<!-- Prism for code highlighting: core script, then the Python, C and Java components -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.5.1/prism.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.5.1/components/prism-python.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.5.1/components/prism-c.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.5.1/components/prism-java.min.js"></script>
<!--MathJax configuration-->
<script type="text/x-mathjax-config">
  // U+0024 is the dollar sign; it is spelled as an escape sequence rather than
  // as a literal character.
  var delim = '\u0024';
  // Treat both dollar-delimited spans and \( ... \) spans as inline math.
  MathJax.Hub.Config({
    tex2jax: {
      inlineMath: [
        [delim, delim],
        ['\\(', '\\)']
      ]
    }
  });
</script>
<!--MathJax library, same version and config as before. Served from cdnjs
    because the RawGit CDN has been shut down.-->
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<!--KaTeX JavaScript-->
<script src="https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.6.0/katex.min.js"></script>
<!--<script src="https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.6.0/contrib/auto-render.min.js"></script>-->
<!--Pseudocode.js: the same file from the same GitHub repository, now mirrored
    by jsDelivr instead of RawGit.-->
<script src="https://cdn.jsdelivr.net/gh/seba-1511/cdn@master/pseudocode.js/pseudocode.min.js"></script>
<!--Custom scripting-->
<script type="text/javascript">
  // Once the DOM is ready: rename language classes so Prism recognises them,
  // apply Bootstrap classes to tables and images, typeset the math spans,
  // highlight the code, and replace every `pre.algo` block with an algorithm
  // rendered by pseudocode.js.
  jQuery(document).ready(function() {
    // Allows prism to work properly: swap the bare language class for the
    // `language-*` form.
    jQuery('.python').addClass('language-python').removeClass('python');
    jQuery('.javascript').addClass('language-js').removeClass('javascript');
    jQuery('.c').addClass('language-c').removeClass('c');
    jQuery('.java').addClass('language-java').removeClass('java');
    jQuery('.sourceCode').removeClass('sourceCode');

    // Bootstrap styling for tables and images.
    jQuery('table').addClass('table table-striped table-bordered');
    jQuery('img').addClass('img-responsive');

    // Typeset the math spans. MathJax.Hub.Queue expects a callback
    // specification of the form [methodName, object, args...], and Typeset
    // takes a real array of elements, so the matches are converted with .get()
    // instead of passing a live HTMLCollection.
    // The guard keeps the remaining steps running if MathJax failed to load.
    var math = jQuery('.math').get();
    if (window.MathJax && MathJax.Hub) {
      MathJax.Hub.Queue(["Typeset", MathJax.Hub, math]);
    }

    if (window.Prism) {
      Prism.highlightAll(false);
    }

    // The following uses pseudocode.js to render algorithms
    var displayOptions = {
      indentSize: '1.5em',
      commentDelimiter: '//',
      lineNumber: true,
      lineNumberPunc: ':',
      noEnd: true,
      captionCount: 1,
      throwOnError: false,
    };
    // querySelectorAll returns a static list, so removing blocks while
    // iterating is safe.
    var algoBlocks = document.querySelectorAll('pre.algo');
    var i, block, code, container;
    if (window.pseudocode) {
      for (i = 0; i < algoBlocks.length; i++) {
        block = algoBlocks[i];
        // Look the <code> child up inside its own block, so a block without
        // one cannot shift the pairing for the blocks that follow.
        code = block.querySelector('code');
        if (!code) {
          continue;
        }
        container = document.createElement('div');
        block.parentNode.insertBefore(container, block);
        pseudocode.render(code.textContent, container, displayOptions);
        // The rendered container replaces the raw source block.
        block.parentNode.removeChild(block);
      }
    }
  });
</script>
<!--Google Analytics-->
<script>
// Minified loader: defines the global `ga` function, which queues its
// arguments, and inserts an async script element for analytics.js before the
// first script element on the page.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
// Create the tracker for this property, then record a page view.
ga('create', 'UA-68693545-3', 'auto');
ga('send', 'pageview');
</script>
</body>
</html>