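"""Flower client for federated finetuning of a ViT with PEFT adapters.

Each client finetunes a parameter-efficient adapter on its own partition of
the Oxford Flowers dataset and sends only the adapter weights back to the
server for aggregation.
"""
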
import torch
from torch.utils.data import DataLoader

import flwr
from flwr.client import NumPyClient
from peft import get_peft_model_state_dict

from dataset import apply_transforms, get_dataset_with_partitions
from model import get_model, set_parameters, train


class FedViTClient(NumPyClient):
    def __init__(self, trainset):
        self.trainset = trainset
        self.model = get_model()

        # Determine device
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)  # send model to device
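        # get_model() is expected to return a PEFT-wrapped ViT (see model.py);
        # only its adapter weights are exchanged with the server below.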

    # NOTE: an earlier revision froze every parameter except the output head
    # (self.model.requires_grad_(False) followed by
    # self.model.heads.requires_grad_(True)) so that only the head's MLP was
    # trained and federated (i.e., communicated back to the server for
    # aggregation). The PEFT adapter returned by get_model() now determines
    # which parameters are trainable, so that manual freezing step is gone.

    def get_parameters(self, config):
        """Return the locally updated parameters (PEFT adapter weights only)."""
        state_dict = get_peft_model_state_dict(self.model)
        return [val.cpu().numpy() for _, val in state_dict.items()]

    def fit(self, parameters, config):
        set_parameters(self.model, parameters)

        # Read the batch size and learning rate the server sent in the config
        batch_size = config["batch_size"]
        lr = config["lr"]
        trainloader = DataLoader(
            self.trainset, batch_size=batch_size, num_workers=2, shuffle=True
        )

        # Set optimizer
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        # Train locally
        avg_train_loss = train(
            self.model, trainloader, optimizer, epochs=1, device=self.device
        )
        # Return the locally finetuned part of the model
        return (
            self.get_parameters(config={}),
            len(trainloader.dataset),
            {"train_loss": avg_train_loss},
        )


# Download and partition the dataset into 20 client partitions
federated_ox_flowers, _ = get_dataset_with_partitions(num_partitions=20)
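# (federated_ox_flowers is presumably a flwr-datasets FederatedDataset-style
# object: load_partition(i, "train") returns the i-th client's training split.)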


def client_fn(cid: str):
    """Return a FedViTClient that trains with the cid-th data partition."""
    trainset_for_this_client = federated_ox_flowers.load_partition(int(cid), "train")
    trainset = trainset_for_this_client.with_transform(apply_transforms)
    return FedViTClient(trainset).to_client()


# To be used with Flower Next
app = flwr.client.ClientApp(
    client_fn=client_fn,
)
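

if __name__ == "__main__":
    # Minimal local smoke test, a sketch rather than part of the Flower
    # deployment: run one round of local finetuning on partition 0. The
    # batch_size and lr values below are illustrative stand-ins for what the
    # server strategy would normally send in the fit config.
    trainset = federated_ox_flowers.load_partition(0, "train").with_transform(
        apply_transforms
    )
    client = FedViTClient(trainset)
    params = client.get_parameters(config={})
    updated, num_examples, metrics = client.fit(
        params, config={"batch_size": 32, "lr": 0.01}
    )
    print(f"Trained on {num_examples} examples, train_loss={metrics['train_loss']}")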