fp_solver.py
""" A fixed point solver
"""
from abc import abstractmethod
import torch
from torch import autograd
class FixedPointSolver(object):
""" fixed point solver base class """
@abstractmethod
def get_fixed_point(self, init_states, energy_fn):
"""
:param init_states: A list of tensor
:param energy_fn: A function that take `states` and return energy for each example
:return: The fixed point state
"""
pass
class FixedStepSolver(FixedPointSolver):
    """Gradient descent on the energy with a fixed step size."""

    def __init__(self, step_size, max_steps=500):
        self.step_size = step_size
        self.max_steps = max_steps

    def get_fixed_point(self, states, energy_fn):
        """Use fixed step size gradient descent."""
        step = 0
        while step < self.max_steps:
            energy = energy_fn(states)
            # The gradient of the negated total energy is the descent direction.
            grads = autograd.grad(-torch.sum(energy), states)
            # Apply the update in-place without recording it in autograd;
            # in-place writes to leaf tensors that require grad would
            # otherwise raise a RuntimeError.
            with torch.no_grad():
                for tensor, grad in zip(states, grads):
                    tensor[:] = tensor + self.step_size * grad
                    # Project each state back onto [0, 1].
                    tensor[:] = torch.clamp(tensor, 0, 1)
            step += 1
        return states
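

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module):
# `quadratic_energy`, the tensor shapes, and the hyperparameters below are
# assumptions. Since the solver performs descent on the energy, the returned
# states should approach 0.5, the minimizer of (x - 0.5)^2, while the clamp
# keeps them inside [0, 1].
# ---------------------------------------------------------------------------
def quadratic_energy(states):
    """Hypothetical per-example energy: sum of (x - 0.5)^2 over features."""
    return torch.sum((states[0] - 0.5) ** 2, dim=1)


if __name__ == "__main__":
    # States must require grad so `autograd.grad` can differentiate the energy.
    init_states = [torch.rand(4, 8, requires_grad=True)]
    solver = FixedStepSolver(step_size=0.1, max_steps=100)
    fixed = solver.get_fixed_point(init_states, quadratic_energy)
    print(fixed[0])  # every entry should be close to 0.5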