-
Notifications
You must be signed in to change notification settings - Fork 429
/
Copy path周志华《机器学习》课后习题解答系列(五):Ch4.4 - 编程实现CART算法与剪枝操作.html
434 lines (365 loc) · 12.4 KB
/
周志华《机器学习》课后习题解答系列(五):Ch4.4 - 编程实现CART算法与剪枝操作.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
<!DOCTYPE html>
<html>
<head>
<title>周志华《机器学习》课后习题解答系列(五):Ch4.4 - 编程实现CART算法与剪枝操作</title>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
<style type="text/css">
/* GitHub stylesheet for MarkdownPad (http://markdownpad.com) */
/* Author: Nicolas Hery - http://nicolashery.com */
/* Version: b13fe65ca28d2e568c6ed5d7f06581183df8f2ff */
/* Source: https://github.com/nicolahery/markdownpad-github */
/* RESET
=============================================================================*/
html, body, div, span, applet, object, iframe, h1, h2, h3, h4, h5, h6, p, blockquote, pre, a, abbr, acronym, address, big, cite, code, del, dfn, em, img, ins, kbd, q, s, samp, small, strike, strong, sub, sup, tt, var, b, u, i, center, dl, dt, dd, ol, ul, li, fieldset, form, label, legend, table, caption, tbody, tfoot, thead, tr, th, td, article, aside, canvas, details, embed, figure, figcaption, footer, header, hgroup, menu, nav, output, ruby, section, summary, time, mark, audio, video {
margin: 0;
padding: 0;
border: 0;
}
/* BODY
=============================================================================*/
body {
font-family: Helvetica, arial, freesans, clean, sans-serif;
font-size: 14px;
line-height: 1.6;
color: #333;
background-color: #fff;
padding: 20px;
max-width: 960px;
margin: 0 auto;
}
body>*:first-child {
margin-top: 0 !important;
}
body>*:last-child {
margin-bottom: 0 !important;
}
/* BLOCKS
=============================================================================*/
p, blockquote, ul, ol, dl, table, pre {
margin: 15px 0;
}
/* HEADERS
=============================================================================*/
h1, h2, h3, h4, h5, h6 {
margin: 20px 0 10px;
padding: 0;
font-weight: bold;
-webkit-font-smoothing: antialiased;
}
h1 tt, h1 code, h2 tt, h2 code, h3 tt, h3 code, h4 tt, h4 code, h5 tt, h5 code, h6 tt, h6 code {
font-size: inherit;
}
h1 {
font-size: 28px;
color: #000;
}
h2 {
font-size: 24px;
border-bottom: 1px solid #ccc;
color: #000;
}
h3 {
font-size: 18px;
}
h4 {
font-size: 16px;
}
h5 {
font-size: 14px;
}
h6 {
color: #777;
font-size: 14px;
}
body>h2:first-child, body>h1:first-child, body>h1:first-child+h2, body>h3:first-child, body>h4:first-child, body>h5:first-child, body>h6:first-child {
margin-top: 0;
padding-top: 0;
}
a:first-child h1, a:first-child h2, a:first-child h3, a:first-child h4, a:first-child h5, a:first-child h6 {
margin-top: 0;
padding-top: 0;
}
h1+p, h2+p, h3+p, h4+p, h5+p, h6+p {
margin-top: 10px;
}
/* LINKS
=============================================================================*/
a {
color: #4183C4;
text-decoration: none;
}
a:hover {
text-decoration: underline;
}
/* LISTS
=============================================================================*/
ul, ol {
padding-left: 30px;
}
ul li > :first-child,
ol li > :first-child,
ul li ul:first-of-type,
ol li ol:first-of-type,
ul li ol:first-of-type,
ol li ul:first-of-type {
margin-top: 0px;
}
ul ul, ul ol, ol ol, ol ul {
margin-bottom: 0;
}
dl {
padding: 0;
}
dl dt {
font-size: 14px;
font-weight: bold;
font-style: italic;
padding: 0;
margin: 15px 0 5px;
}
dl dt:first-child {
padding: 0;
}
dl dt>:first-child {
margin-top: 0px;
}
dl dt>:last-child {
margin-bottom: 0px;
}
dl dd {
margin: 0 0 15px;
padding: 0 15px;
}
dl dd>:first-child {
margin-top: 0px;
}
dl dd>:last-child {
margin-bottom: 0px;
}
/* CODE
=============================================================================*/
pre, code, tt {
font-size: 12px;
font-family: Consolas, "Liberation Mono", Courier, monospace;
}
code, tt {
margin: 0 0px;
padding: 0px 0px;
white-space: nowrap;
border: 1px solid #eaeaea;
background-color: #f8f8f8;
border-radius: 3px;
}
pre>code {
margin: 0;
padding: 0;
white-space: pre;
border: none;
background: transparent;
}
pre {
background-color: #f8f8f8;
border: 1px solid #ccc;
font-size: 13px;
line-height: 19px;
overflow: auto;
padding: 6px 10px;
border-radius: 3px;
}
pre code, pre tt {
background-color: transparent;
border: none;
}
kbd {
-moz-border-bottom-colors: none;
-moz-border-left-colors: none;
-moz-border-right-colors: none;
-moz-border-top-colors: none;
background-color: #DDDDDD;
background-image: linear-gradient(#F1F1F1, #DDDDDD);
background-repeat: repeat-x;
border-color: #DDDDDD #CCCCCC #CCCCCC #DDDDDD;
border-image: none;
border-radius: 2px 2px 2px 2px;
border-style: solid;
border-width: 1px;
font-family: "Helvetica Neue",Helvetica,Arial,sans-serif;
line-height: 10px;
padding: 1px 4px;
}
/* QUOTES
=============================================================================*/
blockquote {
border-left: 4px solid #DDD;
padding: 0 15px;
color: #777;
}
blockquote>:first-child {
margin-top: 0px;
}
blockquote>:last-child {
margin-bottom: 0px;
}
/* HORIZONTAL RULES
=============================================================================*/
hr {
clear: both;
margin: 15px 0;
height: 0px;
overflow: hidden;
border: none;
background: transparent;
border-bottom: 4px solid #ddd;
padding: 0;
}
/* TABLES
=============================================================================*/
table th {
font-weight: bold;
}
table th, table td {
border: 1px solid #ccc;
padding: 6px 13px;
}
table tr {
border-top: 1px solid #ccc;
background-color: #fff;
}
table tr:nth-child(2n) {
background-color: #f8f8f8;
}
/* IMAGES
=============================================================================*/
img {
max-width: 100%
}
</style>
</head>
<body>
<p>这里采用<strong>Python-sklearn</strong>的方式,环境搭建可参考<a href="http://blog.csdn.net/snoopy_yuan/article/details/61211639"> 数据挖掘入门:Python开发环境搭建(eclipse-pydev模式)</a>.</p>
<p>相关答案和源代码托管在我的Github上:<a href="https://github.com/PY131/Machine-Learning_ZhouZhihua">PY131/Machine-Learning_ZhouZhihua</a>.</p>
<h2>4.6 编程实现CART算法与剪枝操作</h2>
<blockquote>
<p><img src="Ch4/4.4.png" />
<img src="Ch4/4.4.1.png" /></p>
</blockquote>
<ul>
<li>
<p>决策树基于训练集完全构建易陷入<strong>过拟合</strong>。为提升泛化能力。通常需要对决策树进行<strong>剪枝</strong>。</p>
</li>
<li>
<p>原始的CART算法采用<strong>基尼指数</strong>作为最优属性划分选择标准。</p>
</li>
</ul>
<p>编码基于Python实现,详细解答和编码过程如下:(<a href="https://github.com/PY131/Machine-Learning_ZhouZhihua/tree/master/ch4_decision_tree/4.4_CART">查看完整代码和数据集</a>):</p>
<h3>1.最优划分属性选择 - 基尼指数</h3>
<p>同信息熵类似,<strong>基尼指数(Gini index)</strong>也常用以度量<strong>数据纯度<strong>,一般基尼值越小,数据纯度越高,相关内容可参考书p79,最典型的相关决策树生成算法是</strong>CART算法</strong>。</p>
<p>下面是某属性下数据的基尼指数计算代码样例(连续和离散的不同操作):</p>
<pre><code>def GiniIndex(df, attr_id):
'''
calculating the gini index of an attribution
@param df: dataframe, the pandas dataframe of the data_set
@param attr_id: the target attribution in df
@return gini_index: the gini index of current attribution
@return div_value: for discrete variable, value = 0
for continuous variable, value = t (the division value)
'''
gini_index = 0 # info_gain for the whole label
div_value = 0 # div_value for continuous attribute
n = len(df[attr_id]) # the number of sample
# 1.for continuous variable using method of bisection
if df[attr_id].dtype == (float, int):
sub_gini = {} # store the div_value (div) and it's subset gini value
df = df.sort([attr_id], ascending=1) # sorting via column
df = df.reset_index(drop=True)
data_arr = df[attr_id]
label_arr = df[df.columns[-1]]
for i in range(n-1):
div = (data_arr[i] + data_arr[i+1]) / 2
sub_gini[div] = ( (i+1) * Gini(label_arr[0:i+1]) / n ) \
+ ( (n-i-1) * Gini(label_arr[i+1:-1]) / n )
# our goal is to get the min subset entropy sum and it's divide value
div_value, gini_index = min(sub_gini.items(), key=lambda x: x[1])
# 2.for discrete variable (categoric variable)
else:
data_arr = df[attr_id]
label_arr = df[df.columns[-1]]
value_count = ValueCount(data_arr)
for key in value_count:
key_label_arr = label_arr[data_arr == key]
gini_index += value_count[key] * Gini(key_label_arr) / n
return gini_index, div_value
</code></pre>
<hr />
<h3>2.完全决策树生成</h3>
<p>下图是基于基尼指数进行最优划分属性选择,然后在数据集watermelon-2.0全集上递归生成的完全决策树。(基础算法和流程可参考<a href="http://blog.csdn.net/snoopy_yuan/article/details/68959025">题4.3</a>,或<a href="https://github.com/PY131/Machine-Learning_ZhouZhihua/blob/master/ch4_decision_tree/4.4_CART/src/decision_tree.py">查看完整代码</a>)</p>
<p><img src="Ch4/4.4_decision_tree_CART.png" /></p>
<hr />
<h3>3.剪枝操作</h3>
<p>参考书4.3节(p79-83),<strong>剪枝</strong>是提高决策树模型泛化能力的重要手段,一般将剪枝操作分为<strong>预剪枝、后剪枝</strong>两种方式,简要说明如下:</p>
<table>
<thead>
<tr>
<th>剪枝类型</th>
<th>搜索方向</th>
<th>方法开销</th>
<th>结果树的大小</th>
<th>拟合风险</th>
<th>泛化能力</th>
</tr>
</thead>
<tbody>
<tr>
<td>预剪枝(prepruning)</td>
<td>自顶向下</td>
<td>小(与建树同时进行)</td>
<td>很小</td>
<td>存在欠拟合风险</td>
<td>较强</td>
</tr>
<tr>
<td>后剪枝(postpruning)</td>
<td>自底向上</td>
<td>较大(决策树已建好)</td>
<td>较小</td>
<td></td>
<td>很强</td>
</tr>
</tbody>
</table>
<p>基于训练集与测试集的划分,编程实现预剪枝与后剪枝操作:</p>
<h4>3.1 完全决策树</h4>
<p>下图是基于训练集生成的完全决策树模型,可以看到,在有限的数据集下,树的结构过于复杂,模型的泛化能力应该很差:</p>
<p><img src="Ch4/4.4_decision_tree_full.png" /></p>
<p>此时在测试集(验证集)上进行预测,精度结果如下:</p>
<pre><code>accuracy of full tree: 0.571
</code></pre>
<h4>3.2 预剪枝</h4>
<p>参考书p81,采用预剪枝生成决策树,<a href="https://github.com/PY131/Machine-Learning_ZhouZhihua/blob/master/ch4_decision_tree/4.4_CART/src/decision_tree.py">查看相关代码</a>, 结果树如下:</p>
<p><img src="Ch4/4.4_decision_tree_pre.png" /></p>
<p>现在的决策树退化成了单个节点,(比决策树桩还要简单),其测试精度为:</p>
<pre><code>accuracy of pre-pruning tree: 0.571
</code></pre>
<p>此精度与完全决策树相同。进一步分析如下:</p>
<ul>
<li>基于<strong>奥卡姆剃刀</strong>准则,这棵决策树模型要优于前者;</li>
<li>由于数据集小,所以预剪枝优越性不明显,实际预剪枝操作是有较好的模型提升效果的。</li>
<li>此处结果模型太简单,有严重的<strong>欠拟合风险</strong>。</li>
</ul>
<h4>3.3 后剪枝</h4>
<p>参考书p83-84 ,采用后剪枝生成决策树,<a href="https://github.com/PY131/Machine-Learning_ZhouZhihua/blob/master/ch4_decision_tree/4.4_CART/src/decision_tree.py">查看相关代码</a>,结果树如下:</p>
<p><img src="Ch4/4.4_decision_tree_post.png" /></p>
<p>决策树相较完全决策树有了很大的简化,其测试精度为:</p>
<pre><code>accuracy of post-pruning tree: 0.714
</code></pre>
<p>此精度相较于前者有了很大的提升,说明经过后剪枝,模型<strong>泛化能力</strong>变强,同时保留了一定树规模,<strong>拟合</strong>较好。</p>
<h3>4.总结</h3>
<ul>
<li>由于本题数据集较差,决策树的总体表现一般,交叉验证存在很大波动性。</li>
<li>剪枝操作是提升模型泛化能力的重要途径,在不考虑建模开销的情况下,后剪枝一般会优于预剪枝。</li>
<li>除剪枝外,常采用最大叶深度约束等方法来保持决策树泛化能力。</li>
</ul>
<hr />
</body>
</html>
<!-- This document was created with MarkdownPad, the Markdown editor for Windows (http://markdownpad.com) -->