forked from p-koo/pEM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_pEM.m
162 lines (125 loc) · 5.12 KB
/
main_pEM.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
% main_pEM: choose a .mat file of particle tracks, fit diffusive-state models
% with rEM/pEM over a range of state counts, select a model size by BIC, then
% save and plot the results.
clear all;
clc;
close all;
addpath('pEM');
%% load file
% The selected .mat file must define a cell array X; each cell is one track
% whose rows are successive positions (diff(X{i}) gives the displacements).
[filename,dirpath] = uigetfile('*.mat','Select protein track positions mat file');
if isequal(filename,0)
    % uigetfile returns 0 (not a path) when the dialog is cancelled; stop here
    % instead of handing 0 to fullfile/load.
    error('main_pEM:noFile','No track positions file was selected.');
end
load(fullfile(dirpath,filename));
if ~exist('X','var') || ~iscell(X) || isempty(X)
    % Everything below indexes X{...}; fail early with a clear message.
    error('main_pEM:missingX', ...
        'The selected file must contain a non-empty cell array X of track positions.');
end
%% user set parameters
% Acquisition settings. dt and dE must share units; D is reported later in
% um^2/s, which suggests seconds (NOTE(review): confirm).
dt = 32e-3;            % frame interval
dE = 32e-3;            % exposure time; enters the motion-blur coefficient R
% rEM: number of random re-initialisations handed to rEM.
numReinitialize = 3;
% pEM: range of model sizes to search and EM stopping controls.
minStates   = 1;       % smallest number of states tried
maxStates   = 20;      % largest number of states tried
numPerturb  = 50;      % perturbation trials per model size
maxiter     = 1e4;     % iteration cap for a single EM run
convergence = 1e-7;    % stop when the log-likelihood changes by less than this
showplot    = 0;       % 1 = plot parameter estimates while they evolve
%% run pEM
% ---- dataset description shared by rEM and pEM ----
trackInfo.numberOfTracks = length(X);      % one cell of X per track
trackInfo.dimensions     = size(X{1},2);   % column count of the first track
trackInfo.dt             = dt;             % frame duration
trackInfo.R              = 1/6*dE/dt;      % motion-blur coefficient
% ---- EM control settings passed to rEM/pEM ----
params = struct( ...
    'numPerturbation', numPerturb, ...     % perturbation trials
    'converged',       convergence, ...    % log-likelihood convergence threshold
    'maxiter',         maxiter, ...        % EM iteration cap
    'showplot',        showplot, ...       % plot parameter estimates (0/1)
    'verbose',         1);                 % print progress to the command window (0/1)
% ---- frame-to-frame displacements of every track (as a column cell array) ----
deltaX = cellfun(@diff, X(1:trackInfo.numberOfTracks), 'UniformOutput', false);
deltaX = deltaX(:);
% Precompute track-length and covariance summaries once to save computation time.
[trackInfo.trackLength,trackInfo.uniqueLength] = TrackLengthParameters(deltaX);
[trackInfo.diagonals,trackInfo.correlations,trackInfo.C] = CovarianceProperties(deltaX);
% Covariance-based estimates (CVE), averaged along dimension 2: diffusivity,
% and a static localization term. NOTE(review): as computed, sigma_cve has
% units of squared displacement (a variance) -- confirm downstream use.
trackInfo.D_cve = mean((trackInfo.diagonals+2*trackInfo.correlations) ...
    /(2*trackInfo.dt),2);
trackInfo.sigma_cve = mean(trackInfo.diagonals,2)/2 ...
    - trackInfo.D_cve*trackInfo.dt*(1-2*trackInfo.R);
% BIC Model Selection Loop
% For each candidate number of states: random initialisation -> rEM -> pEM,
% score the fit with BIC, and stop as soon as BIC drops below the previous size.
if minStates < 1 || maxStates < minStates
    error('main_pEM:badRange', ...
        'Require 1 <= minStates <= maxStates (got %d and %d).', minStates, maxStates);
end
results = struct;
% NaN (not 0) marks sizes that were never fitted: max() skips NaN, so an
% unfitted entry can never be selected, and a genuine BIC of 0 is not confused
% with "not fitted".
BIC = nan(maxStates,1);
for numStates = minStates:maxStates
    startTime = tic;
    disp([num2str(numStates) ' state model']);
    % random initialization
    [D0,P0,S0] = RandomInitialization(numStates,trackInfo.D_cve,trackInfo.sigma_cve);
    % run rEM
    [baseD,baseS,baseP,Lmax] = rEM(deltaX,D0,P0,S0,params,trackInfo,numReinitialize);
    % run pEM, starting from the rEM estimates
    [baseD,baseS,baseP,Lmax,posteriorProb] = pEM(deltaX,baseD,baseP,baseS,Lmax,params,trackInfo);
    % calculate BIC
    % NOTE(review): the penalty counts numStates free parameters, while a K-state
    % model carries K diffusivities, K localization terms and K-1 free weights --
    % confirm the intended parameter count before relying on the selection.
    BIC(numStates) = Lmax - numStates/2*log(trackInfo.numberOfTracks);
    % display results (vectors are forced to rows so the text concatenation
    % works whichever orientation the estimators return)
    disp('-------------------------------------------------------');
    disp([num2str(numStates) ' state model results:']);
    disp(['D_k = ' num2str(baseD(:).') ' um^2/s']);
    disp(['sigma_k = ' num2str(baseS(:).') ' um']);
    disp(['pi_k = ' num2str(baseP(:).') ]);
    disp(['L = ' num2str(Lmax)]);
    disp(['BIC = ' num2str(BIC(numStates))]);
    disp('-------------------------------------------------------');
    % store results, indexed by number of states
    results(numStates).numberOfStates = numStates;
    results(numStates).BIC = BIC(numStates);
    results(numStates).optimalD = baseD;
    results(numStates).optimalS = baseS;
    results(numStates).optimalP = baseP;
    results(numStates).optimalL = Lmax;
    results(numStates).posteriorProb = posteriorProb;
    results(numStates).elapsedTime = toc(startTime);
    % check BIC model selection: numStates-1 was fitted iff numStates > minStates
    if numStates > minStates
        if BIC(numStates) < BIC(numStates-1)
            disp(['Lower BIC found. Optimal number of states: ' num2str(numStates-1)]);
            break;
        end
        if numStates == maxStates
            disp(['Optimal number of states not found; it may be larger than ' num2str(maxStates)]);
            break;
        end
    end
end
% Select the fitted model size with the largest BIC. Only sizes that were
% actually fitted (results(k).numberOfStates non-empty) are candidates, so
% placeholder BIC entries for skipped or unreached sizes can never be chosen.
if ~isfield(results,'numberOfStates') || isempty([results.numberOfStates])
    error('main_pEM:noModels','No models were fitted; check minStates/maxStates.');
end
fittedStates = [results.numberOfStates];   % ascending, so ties keep the smallest size
[~,bestIdx] = max(BIC(fittedStates));
numStates = fittedStates(bestIdx);
% store results
data.X = X;
data.params = params;
data.trackInfo = trackInfo;
data.results = results;
data.BIC = BIC(numStates);
data.optimalD = results(numStates).optimalD;
data.optimalP = results(numStates).optimalP;
data.optimalS = results(numStates).optimalS;
data.optimalL = results(numStates).optimalL;
data.posteriorProb = results(numStates).posteriorProb;
% display results (vectors forced to rows for safe text concatenation)
disp('Finished analysis');
disp('-------------------------------------------------------');
disp(['Optimal size: ' num2str(numStates) ' states']);
disp(['D_k = ' num2str(data.optimalD(:).') ' um^2/s']);
disp(['sigma_k = ' num2str(data.optimalS(:).') ' um']);
disp(['pi_k = ' num2str(data.optimalP(:).') ]);
disp(['L = ' num2str(data.optimalL(end))]);
disp(['BIC = ' num2str(BIC(numStates))]);
disp('-------------------------------------------------------');
% save results: write the data struct to <saveFolder>/<input file name>.mat
saveFolder = 'Results';
if exist(saveFolder,'dir') ~= 7   % exist returns 7 for an existing folder
    mkdir(saveFolder);
end
[~, name] = fileparts(filename);
savePath = fullfile(saveFolder,[name '.mat']);
% Report the path actually used, so the message follows saveFolder if it changes.
disp(['Saving results: ' savePath]);
save(savePath,'data');
%% Display posterior-weighted tracks
DisplayPosteriorTracks(X,data.posteriorProb);
%% Display posterior-weighted MSD
numLags = 10;   % presumably the number of lag times in the MSD plot -- confirm
DisplayWeightedMSD(X,data.posteriorProb,numLags,trackInfo.dt);
%%