-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnonconvex_problem.m
39 lines (32 loc) · 1.09 KB
/
nonconvex_problem.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
function problem = nonconvex_problem(A,r,p)
% NONCONVEX_PROBLEM Build a factorized (nonconvex) barycenter problem struct.
%
%   problem = nonconvex_problem(A, r, p) factorizes every slice A(:,:,i) as
%   Y_i = V_i*sqrt(Lambda_i), using the r eigenpairs returned by eigs, and
%   returns a struct of cost/gradient handles defined on the manifold built
%   by euclidean_orthogonal_factory(n, r, p, k).
%
%   Inputs:
%     A - n-by-n-by-k array of data matrices (presumably symmetric positive
%         semidefinite -- TODO confirm with callers).
%     r - width of each data factor Y_i (number of eigenpairs kept).
%     p - width of the barycenter factorization; must satisfy r >= p.
%
%   Output:
%     problem - struct with fields n, r, p, k, A, Y, cost, M, egrad,
%               costgrad, ehess, variance, egrad_variance (field names
%               presumably follow the Manopt problem-structure convention).
%
%   Iterates Z carry fields X and Qt. The helpers cost/egrad/costegrad are
%   called with Q = multitransp(Z.Qt), so their gradients with respect to Q
%   are transposed back into the Qt layout before being returned.

    [n, n2, k] = size(A);
    if n ~= n2
        error('nonconvex_problem:notSquare', ...
              'A must be n-by-n-by-k; got slices of size %d-by-%d.', n, n2);
    end
    if r < p
        error('nonconvex_problem:badWidths', ...
              'Require r >= p (got r = %d, p = %d).', r, p);
    end

    problem.n = n;  % dimension
    problem.r = r;  % width of data factorization
    problem.p = p;  % width of barycenter factorization
    problem.k = k;  % number of data matrices
    problem.A = A;

    % Data factorization: Y(:,:,i) is built from the r eigenpairs of A_i.
    problem.Y = zeros(n, r, k);
    for i = 1:k
        [V, D] = eigs(A(:,:,i), r);
        lambda = diag(D);
        % Clamp negative eigenvalues (round-off, or rank(A_i) < r) to zero so
        % sqrt stays real. NaN from a non-converged eigs is left untouched
        % (NaN < 0 is false) so the failure stays visible.
        lambda(lambda < 0) = 0;
        problem.Y(:,:,i) = V*diag(sqrt(lambda));
    end

    problem.cost = @(Z) cost(problem.Y, Z.X, multitransp(Z.Qt));
    problem.M = euclidean_orthogonal_factory(n, r, p, k);

    % Convert a gradient with fields X and Q into the X/Qt layout used by the
    % iterates. Only X and Qt are kept, so no stale Q field leaks through.
    function g = to_qt_layout(g0)
        g.X = g0.X;
        g.Qt = multitransp(g0.Q);
    end

    % Euclidean gradient at Z, in the X/Qt layout.
    problem.egrad = @egrad_transp;
    function g = egrad_transp(Z)
        g = to_qt_layout(egrad(problem.Y, Z.X, multitransp(Z.Qt)));
    end

    % Cost and Riemannian gradient in one call: the Euclidean gradient from
    % costegrad is converted to the X/Qt layout, then projected via
    % problem.M.egrad2rgrad.
    problem.costgrad = @costgrad;
    function [c, g] = costgrad(Z)
        [c, ge] = costegrad(problem.Y, Z.X, multitransp(Z.Qt));
        g = problem.M.egrad2rgrad(Z, to_qt_layout(ge));
    end

    % Euclidean Hessian-vector product taken as the Euclidean gradient
    % evaluated at W. NOTE(review): this is only exact if egrad is linear and
    % homogeneous in (X, Q), e.g. for a purely quadratic cost -- confirm
    % against cost.m / egrad.m.
    problem.ehess = @(Z,W) problem.egrad(W);

    problem.variance = @(B) cost_variance(problem.A, B);
    problem.egrad_variance = @(B) egrad_variance(problem.A, B);
end