-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathtest_nonNegFactorization.m
67 lines (54 loc) · 1.92 KB
/
test_nonNegFactorization.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
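% This script shows how to use FASTA to solve the sparse non-negative
% matrix factorization problem
%     minimize_{X,Y}  mu*||X||_1 + 0.5*||S - X*Y'||_F^2
%     subject to      X >= 0,  0 <= Y <= 1   (elementwise)
% where S is an MxN matrix of data, X is an MxK matrix, and Y is an NxK
% matrix. ||X||_1 = sum(abs(X(:))) is the entrywise L1 norm, and the
% parameter 'mu' controls the strength of this L1 regularizer.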
%% Problem size and regularization weight
M  = 800;   % number of rows of the data matrix S
N  = 200;   % number of columns of the data matrix S
K  = 10;    % inner dimension (rank) of the factorization X*Y'
mu = 1;     % weight of the L1 penalty on X
fprintf('Testing non-negative factorization using N=%d, M=%d, K=%d\n', N, M, K);

%% Synthetic ground truth and noisy observations
% Random numbers are drawn in a fixed order (X, Y, sparsity mask, noise);
% keep this order if identical data is wanted for a given RNG state.
X = rand(M, K);                 % non-negative, entries uniform on (0,1)
Y = rand(N, K);                 % non-negative, entries uniform on (0,1)
X = X .* (rand(M, K) > 0.75);   % zero out ~75% of X's entries (keep ~25%)
S = X*Y' + 0.1*randn(M, N);     % data = product of factors + Gaussian noise (std 0.1)
%% Starting point for the solver
% X starts at zero; Y starts from random entries uniform on (0,1).
X0 = zeros(M,K);
Y0 = rand(N,K);

%% Solver options shared by all three runs
%   recordObjective - keep the objective history so convergence can be plotted
%   verbose         - print progress text while iterating
%   stringHeader    - whitespace prefixed to FASTA's text output (cosmetic only)
% opts.tol = 1e-8;   % optional: a much stricter stopping tolerance
opts = struct('recordObjective', true, ...
              'verbose',         true, ...
              'stringHeader',    ' ');

%% Run the solver three times, changing only the step-size/acceleration flags
% 1) Default behavior: adaptive step sizes.
[Xsol, Ysol, outs_adapt] = fasta_nonNegativeFactorization(S, X0, Y0, mu, opts);

% 2) Same as (1), plus FISTA-type acceleration.
opts.accelerate = true;
[Xsol, Ysol, outs_accel] = fasta_nonNegativeFactorization(S, X0, Y0, mu, opts);

% 3) Plain vanilla forward-backward splitting: no acceleration, no adaptivity.
%    Xsol/Ysol are overwritten by every call, so this run's solution is the
%    one left in the workspace afterwards.
opts.accelerate = false;
opts.adaptive   = false;
[Xsol, Ysol, outs_fbs] = fasta_nonNegativeFactorization(S, X0, Y0, mu, opts);
%% Plot results
% Show the ground-truth factors next to the recovered ones, then the
% convergence curves. Xsol/Ysol come from the LAST solver call above (plain
% FBS), because each call overwrites them.
% Plotting is skipped when a variable noPlots exists and is true, e.g. set
% by a batch test runner before running this script. noPlots must be a
% logical/numeric scalar. noPlots=false (or not defined) leaves plotting on.
% Previously any existing noPlots, even false, disabled plotting.
if exist('noPlots','var') && noPlots
    return;
end
figure('Position', [300, 300, 400, 300]);
subplot(2,2,1)
imagesc(X); title('Xtrue');
subplot(2,2,2)
imagesc(Y); title('Ytrue');
subplot(2,2,3)
imagesc(Xsol); title('Xrecovered');
subplot(2,2,4)
imagesc(Ysol); title('Yrecovered');
% External script; presumably plots the objective histories recorded in
% outs_adapt, outs_accel and outs_fbs -- confirm in plotConvergenceCurves.m.
plotConvergenceCurves;