Skip to content

Commit

Permalink
Merge pull request #10 from viktor-ktorvi/master
Browse files Browse the repository at this point in the history
Fix setting device parameter even more
  • Loading branch information
KindXiaoming authored May 1, 2024
2 parents 0de3c70 + 1e50486 commit c63ff04
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions kan/KANLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,14 @@ def update_grid_from_samples(self, x):
batch = x.shape[0]
x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
x_pos = torch.sort(x, dim=1)[0]
y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
y_eval = coef2curve(x_pos, self.grid, self.coef, self.k, device=self.device)
num_interval = self.grid.shape[1] - 1
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[:, ids]
margin = 0.01
grid_uniform = torch.cat([grid_adaptive[:, [0]] - margin + (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) * a for a in np.linspace(0, 1, num=self.grid.shape[1])], dim=1)
self.grid.data = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)
self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k, device=self.device)

def initialize_grid_from_parent(self, parent, x):
'''
Expand Down Expand Up @@ -251,10 +251,10 @@ def initialize_grid_from_parent(self, parent, x):
x_pos = parent.grid
sp2 = KANLayer(in_dim=1, out_dim=self.size, k=1, num=x_pos.shape[1] - 1, scale_base=0.).to(self.device)
sp2.coef.data = curve2coef(sp2.grid, x_pos, sp2.grid, k=1)
y_eval = coef2curve(x_eval, parent.grid, parent.coef, parent.k)
y_eval = coef2curve(x_eval, parent.grid, parent.coef, parent.k, device=self.device)
percentile = torch.linspace(-1, 1, self.num + 1).to(self.device)
self.grid.data = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0)
self.coef.data = curve2coef(x_eval, y_eval, self.grid, self.k)
self.coef.data = curve2coef(x_eval, y_eval, self.grid, self.k, self.device)

def get_subset(self, in_id, out_id):
'''
Expand Down

0 comments on commit c63ff04

Please sign in to comment.