-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisualize_threat_models.py
43 lines (32 loc) · 1.21 KB
/
visualize_threat_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import math
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_moons
from sklearn.preprocessing import MinMaxScaler

from groot.adversary import DecisionTreeAdversary
from groot.model import GrootTreeClassifier
from groot.visualization import plot_adversary
# Train GROOT trees on a small toy dataset under several pre-defined threat
# models, then visualize each learned tree using a DecisionTreeAdversary.

sns.set(style="white", context="paper")

# One threat spec per input feature (make_moons yields two features).
# NOTE(review): per GROOT's documented convention, a number eps allows a
# perturbation of +/- eps, "<" / ">" allow an arbitrary decrease / increase,
# and "<>" allows both; (0, 0) is then the unperturbed baseline. Confirm
# against the installed groot version.
attack_models = [
    (0, 0),
    (0.1, 0.1),
    ("<", 0.1),
    (0.1, "<>"),
]

# Fixed seeds keep the data, the trees, and therefore the figure reproducible.
X, y = make_moons(n_samples=100, noise=0.3, random_state=1)
# Rescale each feature to [0, 1] so the numeric budgets above are relative to
# a unit range.
X = MinMaxScaler().fit_transform(X)

# Size the grid from the number of threat models instead of hard-coding 2x2,
# so adding or removing an entry above cannot break the axis indexing.
# squeeze=False guarantees a 2-D array of axes even for a single row.
n_cols = 2
n_rows = math.ceil(len(attack_models) / n_cols)
_, ax = plt.subplots(
    n_rows, n_cols, figsize=(2.5 * n_cols, 2.5 * n_rows), squeeze=False
)
axes = ax.ravel()

for attack_model, axis in zip(attack_models, axes):
    tree = GrootTreeClassifier(
        max_depth=3, attack_model=attack_model, random_state=1
    )
    tree.fit(X, y)

    # Both features are numeric, so neither has a category count.
    adversary = DecisionTreeAdversary(
        tree,
        "groot",
        attack_model=attack_model,
        is_numeric=[True, True],
        n_categories=[None, None],
    )
    accuracy = adversary.adversarial_accuracy(X, y)
    print(f"{attack_model}: adversarial accuracy {accuracy}")

    axis.set_title(str(attack_model), fontsize=12.0)
    plot_adversary(X, y, adversary, ax=axis)

# Hide leftover panels when the number of models does not fill the grid.
for unused_axis in axes[len(attack_models):]:
    unused_axis.set_visible(False)

plt.tight_layout()

# savefig raises FileNotFoundError if the target directory is missing,
# so create it first.
output_path = Path("out") / "threat_models_vis.png"
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=200)
plt.show()