function [x,f,g,hist,histx] = gradient_descent_2(fg,x0,varargin)
% GRADIENT_DESCENT_2 A simple implementation of pure gradient descent
%
% x = gradient_descent_2(f,x0) computes x, the argmin of f using a pure
% gradient descent algorithm. Every step has unit length, i.e.
% x = x - g/norm(g). This routine SHOULD NOT be used for anything
% other than pedagogical purposes as it does not include any technique to
% make gradient descent a practical algorithm.
%
% The function f(x0) must return both the function value and the gradient
% at x0, i.e. [fx,gx] = f(x0); See FCOMBINE for a helper function or
% ROSENBROCK for an example.
%
% [x,f,g,hist,histx] = gradient_descent_2(...) returns the final function
% value, final gradient, a history of iteration data, and a history of
% all the iterates x. Row 1 of hist holds the function values and row 2
% the infinity norms of the gradient, one column per iteration. Memory for
% histx is only allocated if this output is requested as it could grow
% rather large.
%
% The function call
%   [...] = gradient_descent_2(f,x0,'Key',Value,'Key',Value)
% provides optional arguments to the function that all have reasonable
% default values set.
%
% The options are:
%
%   'maxiter' : the maximum number of iterations/function evaluations
%   'tol'     : the stopping tolerance in terms of the infinity norm of the
%               gradient
%   'quiet'   : do not display iteration output
%
% Example:
%   x = gradient_descent_2(@rosenbrock,[1+0.1;1+0.1]); % does not converge

% Parse the options. A leading options struct needs no special handling:
% inputParser expands a struct into 'Key',Value pairs itself (StructExpand
% is true by default), so varargin is passed through unchanged.
p = inputParser;
p.addParamValue('maxiter',10000,@(x) isnumeric(x) && x >= 0);
p.addParamValue('tol',1e-8, @(x) isnumeric(x) && x >= 0);
p.addParamValue('quiet',0);
p.parse(varargin{:});
opts = p.Results;

x = x0;
n = numel(x);

% hist stores [f; normg] for each iteration
hist = zeros(2,opts.maxiter);

% only allocate the iterate history if the fifth output is requested
savehistx = nargout > 4;
if savehistx
    histx = zeros(n,opts.maxiter);
end

f = Inf;

if ~opts.quiet
    fprintf(' %6s %9s %9s %9s\n', 'iter', ...
        'fval', 'normg', 'fdiff');
end

for iter=1:opts.maxiter
    if iter>1
        % unit-length step along the negative gradient
        x = x - g/norm(g);
    end
    if savehistx
        % store the iterate after the update so that histx(:,iter) is the
        % point at which hist(:,iter) is evaluated
        histx(:,iter) = x;
    end
    flast = f;
    [f,g] = fg(x);
    normg = norm(g,'inf');
    fdiff = flast - f;
    if ~opts.quiet
        fprintf(' %6i %9.2e %9.2e %9.2e\n', iter, ...
            f, normg, fdiff)
    end
    hist(:,iter) = [f; normg];
    if normg <= opts.tol, break; end
end

% trim the unused columns if we stopped early
if iter < opts.maxiter
    hist = hist(:,1:iter);
    if savehistx
        histx = histx(:,1:iter); % trim histx itself, not hist
    end
end

if normg > opts.tol
    warning('Did not converge');
end
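
% -------------------------------------------------------------------------
% Usage sketch (illustrative only). It assumes a hypothetical objective
% f(x) = 0.5*x'*x, built with DEAL so that [fx,gx] = fg(x) returns the value
% and the gradient x; the name fg and the starting points are assumptions
% chosen for this sketch. Because every step has unit length, the iteration
% can only land on the minimizer when the starting distance from it is a
% whole number, which suggests why the routine usually fails to converge:
%
%   fg = @(x) deal(0.5*(x'*x), x);   % hypothetical test objective
%
%   % distance 3 from the origin: iterates 3 -> 2 -> 1 -> 0 along the
%   % first axis, so the gradient vanishes at iteration 4
%   x = gradient_descent_2(fg,[3;0],'quiet',1);
%
%   % distance 2.5: iterates 2.5 -> 1.5 -> 0.5 -> -0.5 -> 0.5 -> ... and
%   % the routine should end with the 'Did not converge' warning
%   x = gradient_descent_2(fg,[2.5;0],'maxiter',20,'quiet',1);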