Home > dal > objdall1.m

objdall1

PURPOSE ^

objdall1 - objective function of DAL with L1 regularization

SYNOPSIS ^

function varargout=objdall1(aa, info, prob, ww, uu, A, B, lambda, eta)

DESCRIPTION ^

 objdall1 - objective function of DAL with L1 regularization

 Copyright(c) 2009 Ryota Tomioka
 This software is distributed under the MIT license. See license.txt

CROSS-REFERENCE INFORMATION ^

This function calls:
This function is called by:

SOURCE CODE ^

% objdall1 - objective function of DAL with L1 regularization
%
% Copyright(c) 2009 Ryota Tomioka
% This software is distributed under the MIT license. See license.txt

function varargout=objdall1(aa, info, prob, ww, uu, A, B, lambda, eta)
% OBJDALL1  Inner objective of DAL with L1 regularization, as a function of aa.
%
%   [fval, info]           = objdall1(aa, info, prob, ww, uu, A, B, lambda, eta)
%   [fval, gg, info]       = objdall1(...)
%   [fval, gg, hess, info] = objdall1(...)
%
%   With
%     vv   = A'*aa + ww/eta
%     vsth = l1_softth(vv, lambda)
%     u1   = uu/eta + B'*aa                 (only when uu is non-empty)
%   the returned value is
%     fval = floss + 0.5*eta*sum(vsth.^2) [+ 0.5*eta*sum(u1.^2)]
%   where floss is returned by the loss function prob.floss.d.
%   NOTE(review): l1_softth is presumably elementwise soft-thresholding at
%   lambda (its nonzeros are assumed to be exactly where |vv| > lambda) --
%   confirm against l1_softth.
%
% Inputs:
%   aa     - point at which the objective is evaluated (column vector).
%   info   - struct returned as the last output; when the gradient is
%            requested its field ginfo is set (see below). Its field
%            solver is read when the Hessian is requested.
%   prob   - struct. The loss is evaluated as
%              feval(prob.floss.d, aa, prob.floss.args{:})
%            and must return [floss, gloss, hmin] (up to 3 outputs
%            requested) or [floss, gloss, hloss, hmin] (4 outputs).
%   ww     - vector; it enters only through ww/eta.
%   uu     - vector, or [] to drop the uu/B term entirely.
%   A      - matrix; the thresholded quantity is built from A'*aa.
%   B      - matrix for the uu term (ignored when uu is empty, except that
%            it is stored in the 'cg' Hessian struct).
%   lambda - threshold passed to l1_softth; assumed >= 0.
%   eta    - scalar weight; assumed positive (it is used as a divisor).
%
% Outputs (which ones are produced depends on nargout):
%   fval - objective value.
%   gg   - gradient: gloss + eta*A*vsth [+ eta*B*u1].
%   hess - with I = find(abs(vv) > lambda) and AF = A(:,I):
%          * info.solver == 'cg': struct with fields hloss, AF, I, n
%            (= length(ww)), prec and B, where prec is hloss plus the
%            diagonal of eta*AF*AF' [and of eta*B*B'], presumably for a
%            preconditioned CG solver.
%          * otherwise: the matrix hloss + eta*AF*AF' [+ eta*B*B'].
%          Assuming hloss is the Hessian of floss and l1_softth is
%          soft-thresholding, this is the Hessian of fval.
%   info - the input info. When gg is requested, info.ginfo is set to
%          norm(gg)/sqrt(eta*hmin*soc), or inf when soc == 0, where
%          soc = sum((vsth - ww/eta).^2) [+ sum((B'*aa).^2)].

vv = A'*aa + ww/eta;

% Loss term. Both branches go through feval, so prob.floss.d may be either
% a function handle or a function name.
if nargout<=3
  [floss, gloss, hmin] = feval(prob.floss.d, aa, prob.floss.args{:});
else
  [floss, gloss, hloss, hmin] = feval(prob.floss.d, aa, prob.floss.args{:});
end

vsth = l1_softth(vv, lambda);

hasU = ~isempty(uu);
fval = floss + 0.5*eta*sum(vsth.^2);
if hasU
  Ba   = B'*aa;           % cached: also needed for soc below
  u1   = uu/eta + Ba;
  fval = fval + 0.5*eta*sum(u1.^2);
end
varargout{1} = fval;

if nargout<=2
  varargout{2} = info;
  return;
end

% Gradient, plus the ginfo diagnostic stored in info.
gg  = gloss + eta*(A*vsth);
soc = sum((vsth - ww/eta).^2);
if hasU
  gg  = gg + eta*(B*u1);
  soc = soc + sum(Ba.^2);
end

if soc>0
  info.ginfo = norm(gg)/(sqrt(eta*hmin*soc));
else
  info.ginfo = inf;
end
varargout{2} = gg;

if nargout==3
  varargout{3} = info;
  return;
end

% Hessian: only the coordinates with |vv| > lambda contribute through A.
n  = length(ww);
I  = find(abs(vv)>lambda);
AF = A(:,I);

switch(info.solver)
 case 'cg'
  % sum(AF.^2,2) is the diagonal of AF*AF' (likewise for B).
  prec = hloss + spdiag(eta*sum(AF.^2,2));
  if hasU
    prec = prec + spdiag(eta*sum(B.^2,2));
  end
  varargout{3} = struct('hloss',hloss,'AF',AF,'I',I,'n',n,'prec',prec,'B',B);
 otherwise
  if ~isempty(I)
    H = hloss + eta*AF*AF';
  else
    % Skip adding an all-zero matrix (keeps hloss's storage type as is).
    H = hloss;
  end
  if hasU
    H = H + eta*B*B';
  end
  varargout{3} = H;
end
varargout{4} = info;

Generated on Sat 22-Aug-2009 22:15:36 by m2html © 2003