-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
57 lines (45 loc) · 1.34 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import argparse
import flwr as fl
import matplotlib.pyplot as plt
from server import strategy
from client import client_fn
# Command-line interface of the simulation driver. Its single option sets
# how many federated rounds the Flower server runs (default: 20).
parser = argparse.ArgumentParser(description="Finetuning of a ViT with Flower Simulation.")
parser.add_argument("--num-rounds", default=20, type=int, help="Number of rounds.")
def _plot_centralised_accuracy(history, path):
    """Save a line plot of centralised accuracy (in %) versus round to *path*.

    Args:
        history: The object returned by ``fl.simulation.start_simulation``.
            Its ``metrics_centralized["accuracy"]`` entry is read as a
            sequence of ``(round, value)`` pairs.
        path: Destination file for the saved figure.
    """
    # The "accuracy" key only exists if the strategy's centralised evaluation
    # reports a metric under that name (NOTE(review): confirm in server.py).
    # Without it there is nothing to plot, so skip rather than crash after a
    # potentially long simulation.
    accuracy_points = history.metrics_centralized.get("accuracy")
    if not accuracy_points:
        print("No centralised 'accuracy' metric recorded; skipping plot.")
        return

    rounds = [int(server_round) for server_round, _ in accuracy_points]
    accuracy_pct = [100.0 * value for _, value in accuracy_points]

    fig, ax = plt.subplots()
    try:
        ax.plot(rounds, accuracy_pct)
        ax.set_xticks(rounds)
        ax.grid()
        ax.set_ylabel("Accuracy (%)")
        ax.set_xlabel("Round")
        ax.set_title("Federated finetuning of ViT for Flowers-102")
        fig.savefig(path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def main(argv=None):
    """Run the federated ViT finetuning simulation and plot central accuracy.

    Parses the command line and launches a Flower simulation with 20 clients.
    The simulation uses ``client_fn`` from the local ``client`` module and
    ``strategy`` from the local ``server`` module. It prints the returned
    history and saves the centralised accuracy per round to
    ``central_evaluation.png``.

    Args:
        argv: Optional list of command-line arguments. With ``None`` (the
            default), argparse reads ``sys.argv[1:]`` exactly as before.
            Passing a list allows programmatic invocation.
    """
    args = parser.parse_args(argv)

    # Controls the degree of parallelism. With the default settings in this
    # example, each client should take just ~1GB of VRAM.
    client_resources = {
        "num_cpus": 4,
        "num_gpus": 0.2,
    }

    # Launch simulation
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=20,
        client_resources=client_resources,
        config=fl.server.ServerConfig(num_rounds=args.num_rounds),
        strategy=strategy,
    )
    print(history)

    _plot_centralised_accuracy(history, "central_evaluation.png")
# Script entry point (e.g. `python main.py --num-rounds 10`). Merely importing
# this module does not start a simulation.
if __name__ == "__main__":
    main()