Cooperative Learning
Cooperative learning works by taking a weighted average of the models received from workers. The weights are determined by the number of batches of data each received model has been trained on. Each model is assigned a learning number, which is the total number of batches the model has been trained on. When the manager sends a model to a worker, it records the model's learning number at that point, so that when the model returns the manager knows how many batches the worker trained it on. The model's weight in the averaging procedure depends on this number, which is the difference between the model's current learning number and the learning number it had when it was sent to the worker.
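The bookkeeping described above can be sketched in Python. This is a minimal illustration under stated assumptions, not the project's implementation: the `Manager` class, the `dispatch` and `batches_trained` methods, and the worker ID are all hypothetical names introduced for the example.

```python
from dataclasses import dataclass, field


@dataclass
class Manager:
    """Hypothetical manager that tracks learning numbers per worker."""
    main_ln: int = 0  # learning number of the main model
    # learning number each worker's copy had when it was sent out
    starting_ln: dict = field(default_factory=dict)

    def dispatch(self, worker_id: str) -> int:
        """Record the learning number at send time and return it."""
        self.starting_ln[worker_id] = self.main_ln
        return self.main_ln

    def batches_trained(self, worker_id: str, current_ln: int) -> int:
        """Batches the worker trained on since the model was sent out."""
        return current_ln - self.starting_ln[worker_id]


manager = Manager(main_ln=100)
sent_ln = manager.dispatch("worker-1")
# Suppose the worker trains on 30 batches, so its copy returns with 130.
trained = manager.batches_trained("worker-1", sent_ln + 30)
print(trained)  # 30
```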
Each averaging step involves only two models: the manager's model, which will be referred to as the main model, and the model the manager is receiving from a worker. The worker's weight is proportional to the difference between its model's current learning number and the learning number the model had when it was first sent out. The main model's weight is proportional to the difference between its learning number and the learning number of the worker's model when it was sent out. I say proportional because the weights must be normalized so that they sum to one.
After the averaging procedure, the main model's learning number is incremented by the number of batches the worker trained on, that is, the difference between the received model's current learning number and its learning number when it was sent out.
starting_ln = the received model's learning number when it was sent out
current_ln = the received model's current learning number
main_ln = the main model's learning number
a = current_ln - starting_ln
b = main_ln - starting_ln
normConst = a + b
newModel = (b*mainModel + a*receivedModel)/normConst
main_ln += a
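
The averaging step can be sketched as a small Python function. This is an illustration under assumptions rather than the actual implementation: the `merge` function name is hypothetical, a model is represented as a flat list of floats, and the handling of the case where both weights are zero is an assumption the text does not cover. Following the prose, the received model is weighted by a and the main model by b.

```python
from typing import List, Tuple


def merge(main_model: List[float], received_model: List[float],
          starting_ln: int, current_ln: int,
          main_ln: int) -> Tuple[List[float], int]:
    """Average a worker's model into the main model (hypothetical helper)."""
    a = current_ln - starting_ln  # batches the worker trained on
    b = main_ln - starting_ln     # batches the main model gained meanwhile
    norm_const = a + b
    if norm_const == 0:
        # Assumption: with no training on either side, keep the main model.
        return list(main_model), main_ln
    new_model = [(b * m + a * r) / norm_const
                 for m, r in zip(main_model, received_model)]
    return new_model, main_ln + a


# Sent out at 100, returned at 130, main model now at 150:
# a = 30, b = 50, so the weights are 30/80 and 50/80.
new_model, new_ln = merge([1.0, 2.0], [3.0, 6.0],
                          starting_ln=100, current_ln=130, main_ln=150)
print(new_model, new_ln)  # [1.75, 3.5] 180
```

If the main model has not changed since the copy was sent out, b is zero and the received model replaces the main model outright.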