-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsavgol_filer.py
71 lines (64 loc) · 2.25 KB
/
savgol_filer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import warnings
import numpy as np
import scipy.signal as signal
import torch
import time
class SAVGOLFilter:
    """Smooth a pose sequence with a 1-D Savitzky-Golay filter.

    Thin wrapper around ``scipy.signal.savgol_filter``, see
    https://docs.scipy.org/doc/scipy/reference/generated/
    scipy.signal.savgol_filter.html.

    Args:
        window_size (int):
            Requested length of the filter window (i.e., the number of
            coefficients). An even value is reduced by one at call time,
            and the window is clamped to the sequence length.

    Attributes:
        window_size (int): The requested window length, stored as given.
        polyorder (int): Order of the fitted polynomial, fixed to 2. It is
            lowered at call time when the effective window is too short.
    """

    def __init__(self, window_size):
        super().__init__()
        # 1-D Savitzky-Golay filter
        self.window_size = window_size
        self.polyorder = 2

    def __call__(self, x=None):
        """Filter ``x`` along its first (time) axis.

        Args:
            x (np.ndarray | torch.Tensor): Sequence of shape [T, K, C].
                Every channel ``x[..., c]`` is filtered independently
                along axis 0.

        Returns:
            tuple: ``(smooth_poses, inference_time)`` where ``smooth_poses``
            has the same container type, shape and dtype as ``x`` (a tensor
            input yields a tensor on the same device), and
            ``inference_time`` is the time in seconds spent inside the
            SciPy filtering loop.

        Raises:
            AssertionError: If ``x`` is not 3-dimensional, or if the
                effective window/polynomial order are not usable
                (``polyorder > 0`` and ``window_size > polyorder``).
        """
        # Check the rank before anything indexes into x.shape.
        if len(x.shape) != 3:
            warnings.warn('x should be a tensor or numpy of [T*M,K,C]')
        assert len(x.shape) == 3, 'x must have exactly 3 dimensions'

        # x.shape: [t,k,c]
        # Savitzky-Golay windows are conventionally odd: drop one sample
        # from an even request.
        if self.window_size % 2 == 0:
            window_size = self.window_size - 1
        else:
            window_size = self.window_size
        # The window cannot be longer than the sequence itself.
        # NOTE(review): after clamping, the window is even whenever T is
        # even; some SciPy releases reject even window lengths - confirm
        # against the SciPy version this project pins.
        if window_size > x.shape[0]:
            window_size = x.shape[0]
        # The polynomial order must stay strictly below the window length.
        if window_size <= self.polyorder:
            polyorder = window_size - 1
        else:
            polyorder = self.polyorder
        assert polyorder > 0, 'sequence/window too short for filtering'
        assert window_size > polyorder, \
            'window_size must be larger than polyorder'

        is_tensor = isinstance(x, torch.Tensor)
        if is_tensor:
            # Remember the source device so the result goes back to the
            # same place (``.cuda()`` would pick the default GPU instead).
            device = x.device
            # detach(): ``.numpy()`` refuses tensors that require grad.
            x = x.detach().cpu().numpy()

        smooth_poses = np.zeros_like(x)
        # smooth at different axis
        num_channels = x.shape[-1]
        # perf_counter is monotonic, unlike time.time().
        start = time.perf_counter()
        for i in range(num_channels):
            smooth_poses[..., i] = signal.savgol_filter(
                x[..., i], window_size, polyorder, axis=0)
        inference_time = time.perf_counter() - start

        if is_tensor:
            # we also return tensor by default
            smooth_poses = torch.from_numpy(smooth_poses).to(device)
        return smooth_poses, inference_time