-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathload_data.py
150 lines (107 loc) · 4.46 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import numpy as np
import os, urllib.request
import pickle
import nmastandard as nmas
# Data URLs: mapping of OSF file IDs -> local filenames for the Steinmetz
# et al. (2019) dataset; each is fetched from https://osf.io/<id>/download
urlfiles = {
'agvxh': 'steinmetz_part1.npz',
'uv3mw': 'steinmetz_part2.npz',
'ehmw2': 'steinmetz_part3.npz'}
# trials are all 2500ms, with 500 ms ITI and stim onset at 500 ms or 50 bins
# (https://neurostars.org/t/steinmetz-et-al-2019-dataset-questions/14539/8)
# input e.g: dat['response_time']
def timePoints2Bins(timePoints):
    """Convert event times in seconds to 10-ms bin indices (floored).

    One second spans 100 bins, so multiplying by 100 maps seconds
    directly onto the 10-ms bin axis used by the spike matrices.
    """
    return np.floor(timePoints * 100)
# for a single session, input the time bin of the reference timepoint (e.g. response time),
# plus or minus duration in ms (e.g. 300ms)
# if method = 'add','minus'
def cutSpikeTimes(referenceTime_bin, spikes, method, msDuration):
    """Cut a fixed-length window of spikes around a per-trial reference bin.

    Parameters
    ----------
    referenceTime_bin : 1-D array of per-trial reference time bins (10-ms bins).
    spikes : 3-D array, neurons x trials x time-bins.
    method : 'add' takes the window after the reference bin,
             'minus' takes the window before it.
    msDuration : window length in ms (one bin per 10 ms).

    Returns
    -------
    3-D array of shape (neurons, trials, msDuration // 10).

    Raises
    ------
    ValueError : unknown `method` (previously an unknown method silently
                 returned all-zero output).
    IndexError : a window would fall outside the time axis. The original
                 only checked the left edge, which both rejected valid
                 'add' calls and missed 'add' overruns on the right edge.
    """
    msIdx = int(msDuration / 10)  # number of 10-ms bins in the window
    if method not in ('add', 'minus'):
        raise ValueError("method must be 'add' or 'minus'")
    # Guard the edge actually at risk for the requested direction.
    if method == 'minus' and np.min(referenceTime_bin) < msIdx:
        raise IndexError('index falls outside of the dataframe')
    if method == 'add' and np.max(referenceTime_bin) + msIdx > spikes.shape[2]:
        raise IndexError('index falls outside of the dataframe')
    # initialize new 3D matrix
    newSpikes = np.zeros((spikes.shape[0], spikes.shape[1], msIdx))
    for iTrial, x in enumerate(referenceTime_bin):
        start = int(x)
        if method == 'add':
            newSpikes[:, iTrial, :] = spikes[:, iTrial, start:start + msIdx]
        else:
            newSpikes[:, iTrial, :] = spikes[:, iTrial, start - msIdx:start]
    return newSpikes
# 450 ms before response time:responseTime
# cutSpikeTimes(referenceTime_bin, spikes, 'minus', 450)
def pad_along_axis(array: np.ndarray, target_length: int, axis: int = 0):
    """Zero-pad `array` at the trailing end of `axis` up to `target_length`.

    Returns the input array unchanged when it already has at least
    `target_length` elements along `axis`.
    """
    deficit = target_length - array.shape[axis]
    if deficit <= 0:
        return array
    widths = [(0, 0) for _ in range(array.ndim)]
    widths[axis] = (0, deficit)
    return np.pad(array, pad_width=widths, mode='constant', constant_values=0)
def concatSpikesPerSession(sessionIdx, data=None, pad_to=300):
    """Stack spikes from several sessions along the neuron axis.

    Each session's spike matrix (neurons x trials x time) is zero-padded
    to `pad_to` trials, then all sessions are stacked on the neuron axis.

    Parameters
    ----------
    sessionIdx : indices of the sessions to concatenate.
    data : sequence of session dicts with 'spks' and 'brain_area' keys.
        Defaults to the module-level `alldat` for backward compatibility.
        NOTE(review): `alldat` is only ever assigned as a local inside
        `load_neural_data`, so the default raises NameError unless it is
        made global elsewhere — pass `data` explicitly.
    pad_to : trial count every session is padded to (original hard-coded
        300; a session with more trials than this will fail to stack).

    Returns
    -------
    (stacked, brainRegionIdx) : stacked spikes of shape
        (total_neurons, pad_to, time) and the concatenated per-neuron
        brain-area labels.
    """
    if data is None:
        data = alldat
    brainRegionIdx = np.concatenate([data[x]['brain_area'] for x in sessionIdx])
    # Original computed maxTrials here but never used it; padding is
    # governed by `pad_to` (default 300, matching the old behavior).
    padded = [pad_along_axis(data[x]['spks'], pad_to, axis=1) for x in sessionIdx]
    stacked = np.vstack(padded)
    return (stacked, brainRegionIdx)
# print(alldatTmp[i]['spks'].shape)
# print(stacked.shape)
def load_neural_data(subject=11):
    """Download (if missing) the three Steinmetz .npz parts from OSF, load
    them, and return MOs-region spikes from the 'Cori' sessions, shaped
    time x trials x neurons.

    NOTE(review): `subject` only selects `dat = alldat[subject]` below, but
    the returned value is built from ALL sessions whose mouse name contains
    'Cori', regardless of `subject` — confirm this is intended.
    NOTE(review): `concatSpikesPerSession` reads a module-level `alldat`,
    while `alldat` here is a local — as written that call should raise
    NameError; verify how this is actually run.
    """
    alldat = []
    # get spike data
    for aurl in urlfiles:
        file = urlfiles[aurl]
        if not os.path.exists(file):
            urllib.request.urlretrieve('https://osf.io/{}/download'.format(aurl),
                                       filename=file)
        alldat = np.hstack((alldat, np.load(file, allow_pickle=True)['dat']))
        print('Loaded: {}'.format(file))
    dat = alldat[subject]
    # spike data: dat['spks']
    # spike data is a 3D matrix: neurons x trials x time
    spikes = dat['spks']
    _, trialLen, _ = spikes.shape  # NOTE(review): trialLen is never used
    # for cutting spikes from stim onset to 300 ms
    # (stim onset is bin 50; see trial-timing comment at top of file)
    stimDuration = spikes[:, :, 50:80]  # NOTE(review): unused below
    referenceTime_bin = timePoints2Bins(dat['response_time'])  # NOTE(review): unused below
    mouseList = [alldat[x]['mouse_name'] for x in range(len(alldat))]
    # index dataframe by mouse
    indices = [i for i, elem in enumerate(mouseList) if 'Cori' in elem]
    # NeuronsxtrialsxTime
    stackedCori, brainRegion_Cori = concatSpikesPerSession(indices)
    np.sum(brainRegion_Cori == 'MOs')  # NOTE(review): result discarded
    Cori_MOs = stackedCori[brainRegion_Cori == 'MOs', :, :]
    Cori_MOs = Cori_MOs.transpose(2, 1, 0)  # time x batch x neurons
    return Cori_MOs
def gen_save_indices(spk, ntrials):
    """Generate a reproducible 60/20/20 train/val/test split of trial
    indices and pickle each split to the working directory.

    Parameters
    ----------
    spk : 3-D spike array (neurons x trials x time); only shape[1] is read.
    ntrials : number of trials used to compute the split boundaries
        (normally spk.shape[1]).
    """
    SEED = 2021
    nmas.set_seed(seed=SEED)
    # Bug fix: np.random.choice samples WITH replacement by default, so
    # the three splits overlapped and some trials were never used.
    # A permutation gives each trial exactly one split.
    idx_selection = np.random.permutation(spk.shape[1])
    trainidx = idx_selection[:ntrials*3//5]
    validx = idx_selection[ntrials*3//5:ntrials*4//5]
    testidx = idx_selection[ntrials*4//5:]
    # with-blocks close the files; the original reassigned one leaked handle.
    with open("Cori_Post_trainidx.pkl", 'wb') as file:
        pickle.dump(trainidx, file)
    with open("Cori_Post_validx.pkl", 'wb') as file:
        pickle.dump(validx, file)
    with open("Cori_Post_testidx.pkl", 'wb') as file:
        pickle.dump(testidx, file)
    return
def save_data(data, filename):
    """Pickle `data` to `<filename>.pkl` in the working directory.

    Bug fix: the f-string previously contained no placeholder, so the
    `filename` argument was ignored and every call overwrote the same
    file. The handle is now also closed via a with-block.
    """
    with open(f"{filename}.pkl", 'wb') as file:
        pickle.dump(data, file)
    return
def pull_indices():
    """Load the pickled train/val/test trial-index splits from the
    working directory (the files written by `gen_save_indices`).

    Returns
    -------
    (trainidx, validx, testidx)

    Fix: the original opened three files and never closed any of them
    (each reassignment leaked the previous handle); with-blocks close
    each file deterministically.
    """
    with open("Cori_Post_trainidx.pkl", 'rb') as file:
        trainidx = pickle.load(file)
    with open("Cori_Post_validx.pkl", 'rb') as file:
        validx = pickle.load(file)
    with open("Cori_Post_testidx.pkl", 'rb') as file:
        testidx = pickle.load(file)
    return trainidx, validx, testidx