SAT_LOOCV_TemplateMatching.py
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 30 15:20:54 2024
Classification of Spatial Auditory Attention using Template Matching
------------------------------------------------------------------------
Based on the procedure described in Bleichner et al. 2016
DOI: 10.1088/1741-2560/13/6/066004
Feature used: Normalised Cross-Correlation (NCC) between each trial
and the templates.
- Templates are the average ERPs of left-attended trials
  and right-attended trials
Classification: Leave-One-Out Cross-Validation Template Matching Approach
@author: Abin Jacob
Carl von Ossietzky University Oldenburg
"""
#%% libraries
import mne
import numpy as np
import matplotlib.pyplot as plt
import os.path as op
from scipy.io import loadmat
from scipy.signal import correlate
from sklearn.model_selection import LeaveOneOut
from sklearn.metrics import accuracy_score
#%% load data
rootpath = r'/Users/abinjacob/Documents/02. NeuroCFN/Research Module/RM02/Data'
# EEGLab file to load (.set)
filename = 'P04_SAT_AllProcessed.set'
filepath = op.join(rootpath,filename)
# load file in mne
raw = mne.io.read_raw_eeglab(filepath, eog= 'auto', preload= True)
# EEG parameters
sfreq = raw.info['sfreq']
# eeg signal
EEG = raw.get_data()
nchannels, nsamples = EEG.shape
# channel names
chnames = raw.info['ch_names']
# re-referencing to channel E25
raw.set_eeg_reference(ref_channels=['E25'])
# extracting events
events, eventinfo = mne.events_from_annotations(raw, verbose= False)
# loading correct trials
trialsfile = 'P04Sat_CorrTrials.mat'
corrTrialsData = loadmat(op.join(rootpath, trialsfile))
# correct trials
corrTrials = [item[0] for item in corrTrialsData['event_name'][0]]
#%% epoching
tmin = -0.25
tmax = 3
# extracting event ids of correct trials from eventinfo
event_id = [eventinfo[corrTrials[idx]] for idx in range(len(corrTrials))]
# epoching
epochs = mne.Epochs(
    raw,
    events=events,
    event_id=event_id,
    tmin=tmin, tmax=tmax,
    baseline=(tmin, 0),
    preload=True,
    event_repeated='merge',
    reject={'eeg': 4.0})    # reject epochs based on maximum peak-to-peak signal amplitude (PTP)
# event id of left attended trials
trlsLeft = [event_id[idx] for idx, trial in enumerate(corrTrials) if 'left' in trial]
# event id of right attended trials
trlsRight = [event_id[idx] for idx, trial in enumerate(corrTrials) if 'right' in trial]
#%% creating template after splitting the data into train and test sets
# data vector
data = epochs.get_data()
# labels
labels = []
for trial in corrTrials:
    if 'left' in trial:
        labels.append(0)
    elif 'right' in trial:
        labels.append(1)
# converting labels to numpy array
labels = np.array(labels)
#%% functions to compute Normalised Cross-Correlation and Template Matching
# function to create template
def createTemplate(X, y):
    templateLeft = np.mean(X[y==0], axis=0)
    templateRight = np.mean(X[y==1], axis=0)
    return np.stack((templateLeft, templateRight))
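# stacked output has shape (2 x nchan x tpts): index 0 = left template, index 1 = right template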
# function to compute Normalised Cross-Correlation
def computeNCC(signal, template, jitter, start, stop):
    # shape of template
    ntemp, nchan, _ = template.shape
    # convert jitter to samples
    jitter = int(jitter * sfreq / 1000)
    # convert start and stop period to samples
    start = int(start * sfreq / 1000)
    stop = int(stop * sfreq / 1000)
    # initialise array to store ncc
    ncc = np.zeros((nchan, ntemp))
    # loop over channels
    for ichan in range(nchan):
        s = signal[ichan, start:stop]
        # loop over left and right templates for the channel
        for itemp in range(ntemp):
            t = template[itemp, ichan, start:stop]
            # normalise signal and template
            snorm = (s - np.mean(s)) / (np.std(s) * len(s))
            tnorm = (t - np.mean(t)) / np.std(t)
            # computing cross-correlation
            corr = correlate(snorm, tnorm, mode='full')
            lag = np.arange(-len(snorm)+1, len(snorm))
            # find the index of maximal correlation within the jitter range
            jittrange = (lag >= -jitter) & (lag <= jitter)
            # find max corr value within this range
            corrMAX = np.max(corr[jittrange])
            ncc[ichan, itemp] = corrMAX
    return ncc
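# ncc has shape (nchan x 2): column 0 holds the maximal NCC of the trial with the
# left template within the jitter window, column 1 the same for the right template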
# function to compute the difference between left and right template correlation
def computeDiff(ncc):
    nchan, _ = ncc.shape
    diff = np.zeros(nchan)
    for ichan in range(nchan):
        diff[ichan] = ncc[ichan, 0] - ncc[ichan, 1]
    value = np.nansum(diff)
    return value
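# computeDiff sums (left NCC - right NCC) over channels: a positive total means the
# trial matches the left template better, a negative total the right template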
#%% leave-one-out template matching
# initialising leave-one-out cross validation
loo = LeaveOneOut()
# duration of ERP to consider (400 ms to 2800 ms - to exclude the onset and offset responses)
start = 400
stop = 2800
# jitter in ms
jitter = 50
acc = []
for trainid, testid in loo.split(data):
    # split data and labels into train and test set
    trainData, testData = data[trainid], data[testid]
    trainLabel, testLabel = labels[trainid], labels[testid]
    # creating templates for left and right trials (cond x nchan x tpts)
    template = createTemplate(trainData, trainLabel)
    # create signal matrix (nchan x tpts)
    signal = testData[0, :, :]
    # computing NCC
    ncc = computeNCC(signal, template, jitter, start, stop)
    # find the difference of left and right NCC across channels
    value = computeDiff(ncc)
    # classification (ties count as right)
    if value > 0:
        # classify as left
        decision = 0
    else:
        # classify as right
        decision = 1
    # computing accuracy
    if decision == testLabel[0]:
        # true
        acc.append(1)
    else:
        # false
        acc.append(0)
# final accuracy in percentage
accuracy = (np.sum(acc) / data.shape[0]) * 100
print(f'Accuracy: {accuracy:.2f}%')
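#%% optional sanity check (not part of the original analysis)
# A minimal sketch, assuming synthetic trials with a known left/right signal
# difference, to exercise createTemplate / computeNCC / computeDiff end to end.
# The trial count, channel count and the 2 Hz class component below are
# illustrative assumptions only; sfreq, jitter, start and stop are reused from above.
def _sanity_check(ntrials=20, nchan=4, seed=0):
    rng = np.random.default_rng(seed)
    ntpts = int(3.0 * sfreq)                  # match the 3 s epoch length
    t = np.arange(ntpts) / sfreq
    sig = np.sin(2 * np.pi * 2 * t)           # 2 Hz component (slow relative to the +/- 50 ms jitter)
    X = rng.standard_normal((ntrials, nchan, ntpts))
    y = np.tile([0, 1], ntrials // 2)
    X[y == 0] += sig                          # synthetic 'left' trials
    X[y == 1] -= sig                          # synthetic 'right' trials
    hits = []
    for tr, te in LeaveOneOut().split(X):
        tmpl = createTemplate(X[tr], y[tr])
        ncc = computeNCC(X[te][0], tmpl, jitter, start, stop)
        hits.append(int((0 if computeDiff(ncc) > 0 else 1) == y[te][0]))
    # should be well above chance if the functions above behave as intended
    print(f'sanity-check accuracy: {100 * np.mean(hits):.2f}%')
# _sanity_check()   # uncomment to run the synthetic check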