dal - dual augmented Lagrangian method for sparse learning/reconstruction

Overview:
  Solves the following optimization problem
    xx = argmin f(x) + lambda*c(x)
  where f is a user-specified (convex, smooth) loss function and c
  is a measure of sparsity (currently L1 or grouped L1).

Syntax:
  [ww, uu, status] = dal(prob, ww0, uu0, A, B, lambda, <opt>)

Inputs:
  prob   : structure that contains the following fields:
     .obj      : function handle to the DAL objective, minimized with
                 respect to the Lagrangian multipliers aa in each outer iteration
     .floss    : structure with three fields (p: primal loss, d: dual loss,
                 args: arguments to the loss functions)
     .fspec    : function handle to the regularizer spectrum function
                 (absolute values for L1, vector of norms for grouped L1, etc.)
     .dnorm    : function handle to the dual norm of the regularizer
                 (max(abs(x)) for L1, max(norms) for grouped L1, etc.)
     .softth   : function handle to the "soft threshold" function
     .mm       : number of samples (scalar)
     .nn       : number of unknown variables (scalar)
     .ll       : lower constraint for the Lagrangian multipliers (mm x 1)
     .uu       : upper constraint for the Lagrangian multipliers (mm x 1)
     .Ac       : inequality constraint Ac*aa<=bc for the Lagrangian multipliers (pp x mm)
     .bc       : right-hand side of the inequality constraint (pp x 1)
     .info     : auxiliary variables for the objective function
     .stopcond : function handle for the stopping condition
     .hessMult : function handle to the Hessian product function (H*x),
                 used by the 'cg' solver
  ww0    : initial solution ((nn x 1) or (ns x nc) with ns*nc=nn)
  uu0    : initial unregularized component (nu x 1)
  A      : design matrix (mm x nn)
  B      : design matrix for the unregularized component (mm x nu)
  lambda : regularization constant (scalar)
  <opt>  : list of 'fieldname1', value1, 'fieldname2', value2, ...
     aa        : initial Lagrangian multiplier [mm,1] (default zeros(mm,1))
     tol       : tolerance passed to the stopping condition (default 1e-3)
     maxiter   : maximum number of outer iterations (default 100)
     eta       : initial barrier parameter (default 1)
     eps       : initial internal tolerance parameter (default 1)
     eta_multp : multiplying factor for eta, applied after each outer
                 iteration whose inner solve succeeded (default 2)
     eps_multp : multiplying factor for eps, applied in the same way (default 0.99)
     solver    : internal solver. One of:
                 'nt'   : Newton method with Cholesky factorization (default)
                 'ntsv' : memory-saving Newton method (slightly slower)
                 'cg'   : Newton method with PCG
                 'qn'   : quasi-Newton method
     display   : display level (0: none, 1: only the last iteration,
                 2: every outer iteration (default), 3: every inner iteration)
     iter      : if true, return the value of ww at each iteration
                 (boolean, default 0)

Outputs:
  ww     : the final solution (if iter is true, a matrix whose columns
           are [ww(:); uu(:)] at each outer iteration)
  uu     : the final unregularized component
  status : structure with fields aa, niter, eta, pred, time, res, opt,
           info and fval

Reference:
  "Dual Augmented Lagrangian Method for Efficient Sparse Reconstruction"
  Ryota Tomioka and Masashi Sugiyama
  http://arxiv.org/abs/0904.0584

Copyright(c) 2009 Ryota Tomioka
This software is distributed under the MIT license. See license.txt
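The fields .fspec, .dnorm and .softth are supplied by the caller. As a rough sketch of what they compute for the two regularizers named in the overview, the anonymous functions below implement them in MATLAB. The names (l1_fspec, gl_softth, gnorms, etc.) are hypothetical, the trailing info argument that dal passes to softth is omitted, and the grouped-L1 versions assume that each column of an (ns x nc) matrix is one group; the toolbox's own functions may be organized differently.

```matlab
% Hypothetical helpers illustrating prob.fspec, prob.dnorm and prob.softth.
% Argument order (vv, lam) mirrors the call softth(ww+eta*gtmp, eta*lambda, info)
% in dal.m; the info argument is omitted here.

% L1 regularizer c(w) = sum(abs(w))
l1_fspec  = @(ww) abs(ww);                                % "spectrum": |w_j|
l1_dnorm  = @(vv) max(abs(vv));                           % dual norm of L1 (max norm)
l1_softth = @(vv, lam) sign(vv) .* max(abs(vv) - lam, 0); % prox of lam*||.||_1

% Grouped L1, ASSUMING each column of an (ns x nc) matrix is one group
gnorms    = @(ww) sqrt(sum(ww.^2, 1));                    % 1 x nc vector of group norms
gl_fspec  = @(ww) gnorms(ww)';                            % group norms as a column
gl_dnorm  = @(vv) max(gnorms(vv));                        % dual norm: largest group norm
% Scale each group by max(1 - lam/||group||, 0); groups with norm <= lam become zero.
% max(..., eps) only guards against division by zero for all-zero groups.
gl_softth = @(vv, lam) bsxfun(@times, vv, max(1 - lam ./ max(gnorms(vv), eps), 0));

w = l1_softth([3; -0.5; -2], 1);   % returns [2; 0; -1]
```

Entries whose magnitude (or group norm) falls below the threshold are set exactly to zero, which is what makes the DAL iterates sparse.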
% dal - dual augmented Lagrangian method for sparse learning/reconstruction
%
% Overview:
%   Solves the following optimization problem
%     xx = argmin f(x) + lambda*c(x)
%   where f is a user-specified (convex, smooth) loss function and c
%   is a measure of sparsity (currently L1 or grouped L1).
%
% Syntax:
%   [ww, uu, status] = dal(prob, ww0, uu0, A, B, lambda, <opt>)
%
% Inputs:
%   prob   : structure that contains the following fields:
%      .obj      : function handle to the DAL objective, minimized with
%                  respect to the Lagrangian multipliers aa in each outer iteration
%      .floss    : structure with three fields (p: primal loss, d: dual loss,
%                  args: arguments to the loss functions)
%      .fspec    : function handle to the regularizer spectrum function
%                  (absolute values for L1, vector of norms for grouped L1, etc.)
%      .dnorm    : function handle to the dual norm of the regularizer
%                  (max(abs(x)) for L1, max(norms) for grouped L1, etc.)
%      .softth   : function handle to the "soft threshold" function
%      .mm       : number of samples (scalar)
%      .nn       : number of unknown variables (scalar)
%      .ll       : lower constraint for the Lagrangian multipliers (mm x 1)
%      .uu       : upper constraint for the Lagrangian multipliers (mm x 1)
%      .Ac       : inequality constraint Ac*aa<=bc for the Lagrangian multipliers (pp x mm)
%      .bc       : right-hand side of the inequality constraint (pp x 1)
%      .info     : auxiliary variables for the objective function
%      .stopcond : function handle for the stopping condition
%      .hessMult : function handle to the Hessian product function (H*x),
%                  used by the 'cg' solver
%   ww0    : initial solution ((nn x 1) or (ns x nc) with ns*nc=nn)
%   uu0    : initial unregularized component (nu x 1)
%   A      : design matrix (mm x nn)
%   B      : design matrix for the unregularized component (mm x nu)
%   lambda : regularization constant (scalar)
%   <opt>  : list of 'fieldname1', value1, 'fieldname2', value2, ...
%      aa        : initial Lagrangian multiplier [mm,1] (default zeros(mm,1))
%      tol       : tolerance passed to the stopping condition (default 1e-3)
%      maxiter   : maximum number of outer iterations (default 100)
%      eta       : initial barrier parameter (default 1)
%      eps       : initial internal tolerance parameter (default 1)
%      eta_multp : multiplying factor for eta, applied after each outer
%                  iteration whose inner solve succeeded (default 2)
%      eps_multp : multiplying factor for eps, applied in the same way (default 0.99)
%      solver    : internal solver. One of:
%                  'nt'   : Newton method with Cholesky factorization (default)
%                  'ntsv' : memory-saving Newton method (slightly slower)
%                  'cg'   : Newton method with PCG
%                  'qn'   : quasi-Newton method
%      display   : display level (0: none, 1: only the last iteration,
%                  2: every outer iteration (default), 3: every inner iteration)
%      iter      : if true, return the value of ww at each iteration
%                  (boolean, default 0)
% Outputs:
%   ww     : the final solution (if iter is true, a matrix whose columns
%            are [ww(:); uu(:)] at each outer iteration)
%   uu     : the final unregularized component
%   status : structure with fields aa, niter, eta, pred, time, res, opt,
%            info and fval
%
% Reference:
%   "Dual Augmented Lagrangian Method for Efficient Sparse Reconstruction"
%   Ryota Tomioka and Masashi Sugiyama
%   http://arxiv.org/abs/0904.0584
%
% Copyright(c) 2009 Ryota Tomioka
% This software is distributed under the MIT license. See license.txt


function [xx, uu, status]=dal(prob, ww0, uu0, A, B, lambda, varargin)

opt=propertylist2struct(varargin{:});
opt=set_defaults(opt, 'aa', [],...
                      'tol', 1e-3, ...
                      'iter', 0, ...
                      'maxiter', 100,...
                      'eta', 1,...
                      'eps', 1, ...
                      'eps_multp', 0.99,...
                      'eta_multp', 2, ...
                      'solver', 'nt', ...
                      'display',2);


prob=set_defaults(prob, 'll', -inf*ones(prob.mm,1), ...
                        'uu', inf*ones(prob.mm,1), ...
                        'Ac', [], ...
                        'bc', [], ...
                        'info', [], ...
                        'finddir', []);

if opt.display>0
  if ~isempty(uu0)
    nuu = length(uu0);
    vstr=sprintf('%d+%d',prob.nn,nuu);
  else
    vstr=sprintf('%d',prob.nn);
  end

  % Short loss name for display: drop the first 5 and the last character
  % of the primal-loss handle name
  lstr=func2str(prob.floss.p); lstr=lstr(6:end-1);
  fprintf(['DAL ver0.98d\n#samples=%d #variables=%s lambda=%g ' ...
           'loss=%s solver=%s\n'],prob.mm, vstr, lambda, lstr, ...
          opt.solver);
end


if opt.iter
  % Column ii of xx holds [ww(:); uu(:)] at outer iteration ii
  xx = [[ww0(:); uu0(:)], ones(length(ww0(:))+length(uu0(:)),opt.maxiter-1)*nan];
end

res    = nan*ones(1,opt.maxiter);
fval   = nan*ones(1,opt.maxiter);
etaout = nan*ones(1,opt.maxiter);
time   = nan*ones(1,opt.maxiter);
pred   = nan*ones(1,opt.maxiter);


time0=cputime;
ww = ww0;
uu = uu0;
gtmp = zeros(size(ww));
if isempty(opt.aa)
  aa = zeros(prob.mm,1);
else
  aa = opt.aa;
end

eta  = opt.eta;
epsl = opt.eps;
info = prob.info;
info.solver=opt.solver;
for ii=1:opt.maxiter-1
  etaout(ii)=eta;
  time(ii)=cputime-time0;

  %% Evaluate objective and check stopping condition
  [ret,fval(ii),spec,res(ii)]=feval(prob.stopcond, ww, uu, aa, opt.tol, prob, A, B, lambda);

  %% Display
  if opt.display>1 || (opt.display>0 && ret~=0)
    if ii>1
      fval1 = fval(ii-1)-pred(ii-1);
    else
      fval1 = nan;
    end
    nactive = full(sum(spec>0));   % number of nonzero components
    fprintf('[[%d]] fval=%g (pred=%g) #(xx~=0)=%d res=%g\n', ii, fval(ii), ...
            fval1,...
            nactive, res(ii));
  end

  if ret~=0
    break;
  end

  %% Save the original dual variable for daltv2d
  info.aa0 = aa;

  %% Solve minimization with respect to aa
  % fun = @(aa,info)prob.obj(aa, prob, ww, uu, A, AT, B, lambda, eta, info);
  args = {prob,ww,uu,A,B,lambda,eta};
  switch(opt.solver)
   case {'nt','ntsv'}
    [aa,dfval,dgg,stat] = newton(prob.obj, aa, prob.ll, prob.uu, prob.Ac, ...
                                 prob.bc, epsl, prob.finddir, info, opt.display>2, args{:});
   case 'cg'
    funh = {prob.hessMult,A,eta};
    fh = {prob.obj, funh};
    [aa,dfval,dgg,stat] = newton(fh, aa, prob.ll, prob.uu, prob.Ac, ...
                                 prob.bc, epsl, prob.finddir, info, opt.display>2, args{:});
   case 'qn'
    optlbfgs=struct('epsginfo',epsl,'display',opt.display-1);
    [aa,stat]=lbfgs(prob.obj,aa,prob.ll,prob.uu,prob.Ac,prob.bc,info,optlbfgs,args{:});
   case 'fminunc'
    optfm=optimset('LargeScale','on','GradObj','on','Hessian', ...
                   'on','TolFun',1e-16,'TolX',0,'MaxIter',1000,'display','iter');
    [aa,fvalin,exitflag]=fminunc(@(xx)objdall1fminunc(xx,prob,ww, ...
                                 uu,A,B,lambda,eta,epsl), aa, optfm);
    stat.info=info;
    stat.ret=exitflag~=1;
   otherwise
    error('Unknown method [%s]',opt.solver);
  end
  info=stat.info;


  if isfield(prob,'Aeq')
    gtmp(:) = [A', prob.Aeq']*aa;
  else
    gtmp(:) = A'*aa;
  end

  ww1 = fevals(prob.softth, ww+eta*gtmp,eta*lambda,info);

  %% Predicted decrease in the objective
  %pred(ii) = norm(ww1(:)-ww(:))^2/(2*eta);
  %if ~isempty(uu)
  %pred(ii) = pred(ii) + 0.5*eta*norm(B'*aa)^2;
  %end

  %% Update primal variable
  ww = ww1;
  if ~isempty(uu)
    if isfield(prob,'Aeq')
      uu = uu+eta*B'*aa(1:end-prob.meq);
    else
      uu = uu+eta*B'*aa;
    end
  end


  %% Update barrier parameter eta and tolerance parameter epsl
  % (only after outer iterations whose inner solve succeeded, i.e. stat.ret==0)
  eta  = eta*opt.eta_multp^(stat.ret==0);
  epsl = epsl*opt.eps_multp^(stat.ret==0);
  if opt.iter
    xx(:,ii+1)=[ww(:);uu(:)];
  end
end

res(ii+1:end)=[];
fval(ii+1:end)=[];
time(ii+1:end)=[];
etaout(ii+1:end)=[];
pred(ii+1:end)=[];
if opt.iter
  xx(:,ii+1:end)=[];
else
  xx = ww;
end


status=struct('aa', aa,...
              'niter',length(res),...
              'eta', etaout,...
              'pred', pred,...
              'time', time,...
              'res', res,...
              'opt', opt, ...
              'info', info,...
              'fval', fval);
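Each outer iteration of dal minimizes the augmented dual objective over aa, applies the soft threshold to ww + eta*A'*aa, and multiplies eta by eta_multp. For the special case f(z) = 0.5*||z - y||^2 with an L1 penalty and no unregularized component B, this loop can be sketched without the toolbox. This is an illustrative sketch under stated assumptions, not the toolbox's code: the toy data, the relative duality gap used as the stopping residual (dal delegates this to prob.stopcond), the fixed inner tolerance (in place of the decreasing epsl), and the Newton method with backtracking on the active set (in place of dal's newton helper) are all choices made here. randn('state', 0) is used only because it seeds the generator in both MATLAB and Octave.

```matlab
% Hypothetical DAL sketch for the lasso: min_w 0.5*||A*w - y||^2 + lambda*||w||_1.
% Toy data (assumption, for illustration only)
randn('state', 0);
m = 40; n = 100;
A = randn(m, n);
w_true = zeros(n, 1); w_true([5 20 60]) = [1.5; -2; 1];
y = A * w_true + 0.01 * randn(m, 1);
lambda = 0.1 * max(abs(A' * y));

softth = @(v, t) sign(v) .* max(abs(v) - t, 0);   % prox of t*||.||_1 (role of prob.softth)

% Outer-loop parameters named after dal's options (values are assumptions)
eta = 1; eta_multp = 2; tol = 1e-6; maxiter = 50;

w = zeros(n, 1);       % primal variable (ww in dal)
alpha = zeros(m, 1);   % Lagrangian multipliers (aa in dal)
for k = 1:maxiter
  % Stopping residual: relative duality gap (assumption; dal uses prob.stopcond)
  r = y - A * w;
  pval = 0.5 * (r' * r) + lambda * sum(abs(w));
  a = r * min(1, lambda / max(abs(A' * r)));   % dual feasible: ||A'*a||_inf <= lambda
  dval = a' * y - 0.5 * (a' * a);
  if (pval - dval) / pval < tol, break; end

  % Inner problem: minimize over alpha
  %   phi(alpha) = 0.5*||alpha||^2 - alpha'*y + ||softth(w + eta*A'*alpha, eta*lambda)||^2/(2*eta)
  phi = @(al) 0.5 * (al' * al) - al' * y + ...
              sum(softth(w + eta * (A' * al), eta * lambda).^2) / (2 * eta);
  for j = 1:50
    v = softth(w + eta * (A' * alpha), eta * lambda);
    g = alpha - y + A * v;                       % gradient of phi
    if norm(g) < 1e-9, break; end
    J = (v ~= 0);                                % active set
    H = eye(m) + eta * (A(:, J) * A(:, J)');     % generalized Hessian of phi
    R = chol(H);                                 % H = R'*R
    d = -(R \ (R' \ g));                         % Newton direction
    s = 1; f0 = phi(alpha);
    while phi(alpha + s * d) > f0 + 1e-4 * s * (g' * d) && s > 1e-10
      s = s / 2;                                 % Armijo backtracking
    end
    alpha = alpha + s * d;
  end

  % Primal update (the ww1 = softth(...) step in dal), then increase eta
  w = softth(w + eta * (A' * alpha), eta * lambda);
  eta = eta * eta_multp;
end
```

The Hessian only involves the columns of A whose soft-thresholded entries are nonzero, so each Newton step factors an m x m matrix built from a small active set. As eta grows, this matrix becomes more ill-conditioned, which is one reason dal ties the inner tolerance epsl to the outer iteration instead of using a fixed value as this sketch does.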