function [bestw bestppi igmax]=ppoptimize5(x,y,varargin)
% PPOPTIMIZE5 runs the optimization of the PPI function for dataset X and labels Y.
%
%   [BESTW, BESTPPI, IGMAX] = PPOPTIMIZE5(X,Y,PARAMETER,VALUE,...)
%	searches for unit-norm directions W that maximize the PP index value.
%	Several starting points ('initiations') are optimized in parallel with
%	the update W = W + learningRate * dPPI. After 'killPeriod' iterations
%	only the most promising ones are kept. At the end, nearly parallel
%	solutions are merged.
%
%	X : dataset (vectors in rows, features in columns)
%	Y : labels
%	BESTW   : distinct solutions, one unit-norm row each, ordered by
%	          decreasing PPI value
%	BESTPPI : PPI value reached by each row of BESTW
%	IGMAX   : (optional) for the G matrix returned by the PPI function in
%	          the last iteration, the column index of the maximum of each row
%
%	Possible PARAMETERS are:
%	   Function:
%		'function' - 'triangle', 'fx4' (or 'f_x4') or 'bicentral' select
%		           f_triangular, f_x4 and f_bicentral respectively (default f_x4)
%		'beta'     - value of the BETA parameter - width of the given 'function'
%		           (default 2, seems to be a good value for normalized data and the f_x4 function)
%
%     Optimization:
%		'maxIterations'	  - maximum number of iterations (default 2000)
%       'initiations'     - number of starting points optimized in parallel;
%                         the best solutions over those initiations are
%                         chosen at the end of training (default 20)
%       'killPeriod'      - after this many iterations only the fastest
%                         growing and the largest initiations are kept (default 10)
%       'killRatio'       - reserved, currently unused (default 0.5)
%		'learningRate'    - step size of the gradient update (default 0.1)
%		'initWeights'     - initial weights (starting points), one row per
%                         initiation; the first min(initiations, rows) rows are used
%       'stopCriterium'   - criterion for ending the optimization of an
%                         initiation (default 2), checked only after
%                         'killPeriod' iterations:
%                         1 - relative difference between the two last values,
%                         if |PPI(t)-PPI(t-1)| <= eps*|PPI(t-1)| then stop
%                         2 (default) - relative difference between the two
%                         last averages over 'checkPeriod' iterations
%       'checkPeriod'     - if 'stopCriterium' = 2, the PPI value is averaged
%                         over blocks of checkPeriod iterations, and an
%                         initiation stops when two consecutive averages
%                         differ (relatively) by no more than eps (default: 5)
%       'eps'             - tolerance of the stopping criterion (default 0.001)
%
%     Searching for an orthogonal solution:
%        'ortoWeights'    - weights for which an orthogonal solution will be searched (default [])
%        'lambda'         - value of the LAMBDA parameter; controls the influence of the orthogonalization process (default 0.1)
%
%     Other optimization parameters:
%        'indGmax'        - index of the input vector for which the maximum value of G
%                         will be searched; required in the second
%                         stage of learning (not working right now)
%
%     Plotting:
%		 'plot'           - (default 'none')
%                        'ppi' - plots the PPI value of every active
%                        initiation in each iteration
%                        'all', 'last' - accepted, currently not implemented
%
%     Results display:
%        'display'        - 'none' (default), 'short' or 'all'; anything but
%                         'none' echoes the settings and convergence messages
%                         to the command window
%     Results saving:
%        'savedir'        - name of the directory that stores log files and
%                         pictures created during the learning process.
%                         Default dir name is: ppi-results-yyyy.mm.dd,
%                         where yyyy.mm.dd is the current date.
%        'save'           - 'all', 'last' - accepted, but saving pictures is
%                         currently not implemented (default 'none')
%        'logFile'        - name of the log file (default:
%                         <savedir>/<dataName>.<procedure>-<function>.<beta>.log)
%        'dataName'       - name of the data used in names of output files
%                         (logs and pictures), (default: data)
%
%	e.g.   w=ppoptimize5(x,y,'function','triangle','beta',3,'plot','ppi')
%		uses the triangular function with beta = 3 and plots the change of
%		the PPI function during learning
%
% TODO :
%   * find a smarter and faster way of automatically estimating the sigma
%   value in the bgraph3.m plot
%   * simple gui for controlling and visualization of the optimization
%   * new pictures - G values, weights values as bar plot, ...
%   * implementation of different optimization methods: e.g. simulated annealing
%   (might be done with MATLAB built-in optimizers and ppifunction as the objective)
%   * MATLAB tricks to generate animations
%   * improve writing of logs to files
%
%  DONE
%   * multistart, choosing the most promising initializations after a
%   few iterations
%   * learning of two directions (see ppi2woptimise.m)
%   * saving results: log files
%   * improve searching of orthogonal directions (orthogonalization included into ppifunction.m)

param = inputParser;
% data
param.addRequired('x',@isnumeric);
param.addRequired('y',@isnumeric);

% inner function
param.addOptional('function','f_x4',@(x)any(strcmpi(x,{'triangle','f_x4','fx4','bicentral'})));
param.addOptional('beta',2, @(x)isnumeric(x) && x>0);

% optimization - global
param.addParamValue('learningRate',0.1,@(x)isnumeric(x) && x >= 0);
param.addParamValue('eps',0.001,@(x)isnumeric(x) && x>0);
param.addParamValue('maxIterations',2000,@(x)isnumeric(x) && x>0 && mod(x,1)==0);
param.addParamValue('killPeriod',10,@(x)isnumeric(x) && x>0 && mod(x,1)==0);
% killRatio is a ratio (its default is 0.5), so it must not be forced to be an integer
param.addParamValue('killRatio',0.5,@(x)isnumeric(x) && x>0);
param.addParamValue('initiations',20,@(x)isnumeric(x) && x > 0 && mod(x,1)==0);
param.addParamValue('checkPeriod',5,@(x)isnumeric(x) && x>0 && mod(x,1)==0);
param.addParamValue('stopCriterium',2,@(x)x==1|| x==2);

% optimization - case dependent
param.addParamValue('initWeights', [],@(x)isnumeric(x));
param.addParamValue('lambda', 0.1,@(x)isnumeric(x) && x > 0 );
param.addParamValue('ortoWeights',[],@isnumeric);
param.addParamValue('indGmax',[],@isnumeric);

% logging
param.addParamValue('logFile',[],@ischar);
param.addParamValue('dataName','data',@ischar);
param.addParamValue('save','none',@(x)any(strcmpi(x,{'none','all','last'})));
param.addParamValue('savedir',[],@ischar);
param.addParamValue('display','none',@(x)any(strcmpi(x,{'none','all','short'})));

% plotting
param.addParamValue('plot','none',@(x)any(strcmpi(x,{'none','all','ppi','last'})));

param.parse(x,y,varargin{:});

[vx, fx] = size(x);

beta          = param.Results.beta;
lrate         = param.Results.learningRate;   % step size of the gradient update
tol           = param.Results.eps;            % tolerance of the stopping criterion
nmax          = param.Results.maxIterations;  % max. number of iterations
killPeriod    = param.Results.killPeriod;
ninit         = param.Results.initiations;
ww            = param.Results.initWeights;
dataname      = param.Results.dataName;
avgtest       = param.Results.checkPeriod;
stopcriterium = param.Results.stopCriterium;
iGmax         = param.Results.indGmax;
wort          = param.Results.ortoWeights;
lambda        = param.Results.lambda;
orto          = ~isempty(wort);
funcname      = param.Results.function;
nkeep         = 10;    % initiations kept by each ranking at kill time
% orthogonality threshold: directions whose |dot product| is larger than
% alpha are treated as the same solution
alpha         = 0.95;

savedir = param.Results.savedir;
if isempty(savedir)
    savedir = strcat('ppi-results-',datestr(now,'yyyy.mm.dd'));
end

% The validators compare case-insensitively, so the dispatch must as well.
switch lower(funcname)
    case { 'fx4' , 'f_x4'}
        func = @(xx)f_x4(xx,beta);
    case 'triangle'
        func = @(xx)f_triangular(xx,beta);
    case 'bicentral'
        func = @(xx)f_bicentral(xx,beta,0,10);
    otherwise
        % unreachable: the validator only admits the names above
        error('Ten blad nie powinien wystapic - ale jesli wystapil to znaczy, ze jest cos nie tak z podana funkcja');
end

% Only the 'ppi' plot is implemented; 'all'/'last' plotting and the 'save'
% option are accepted but currently have no effect.
iplot = strcmpi(param.Results.plot,'ppi');

switch lower(param.Results.display)
    case 'all'
        verbosity = 2;
    case 'short'
        verbosity = 1;
    otherwise
        verbosity = 0;
end
toConsole = verbosity > 0;

[s, comment] = mkdir(savedir);
if s == 0
    error(comment);
end

% starting points, one per row
if isempty(ww)
    w = rand(ninit,fx)*2-1;   % random initialization in [-1,1]
else
    if size(ww,2) ~= fx
        error('ppoptimize5:initWeights', ...
              'initWeights must have %d columns (one per feature of X)', fx);
    end
    % without this, fewer rows than 'initiations' caused an index error
    ninit = min(ninit, size(ww,1));
    w = ww(1:ninit,:);
end

if isempty(iGmax)
    if ~orto
        procedure = 'ppi';
        ppifun = @(wx)ppifunction(x,y,wx,func);
    else
        procedure = strcat('ppi-orto',sprintf('.%0.1f',lambda));
        wort = wort/norm(wort);
        ppifun = @(wx)ppifunction(x,y,wx,func,wort,lambda);
    end
else
    procedure = strcat('ppi-gmax',sprintf('.%d',iGmax));
    ppifun = @(wx)ppigmaxfunction2(x,y,wx,iGmax,func);
end

prefix = fullfile(savedir, [dataname sprintf('.%s-%s.%0.1f',procedure,funcname,beta)]);

logfilename = param.Results.logFile;
if isempty(logfilename)
    logfilename = strcat(prefix,'.log');
end

[logfile, msg] = fopen(logfilename,'wt');
if logfile < 0
    error('ppoptimize5:logFile','Cannot open log file ''%s'': %s',logfilename,msg);
end
% close the log even if the optimization below throws
closeLog = onCleanup(@() fclose(logfile)); %#ok<NASGU>

% The settings always go to the log file; 'display' only controls the echo
% to the command window.
logmsg(logfile,toConsole,'procedure     = %s\n',procedure);
logmsg(logfile,toConsole,'dataname      = %s\n',dataname);
logmsg(logfile,toConsole,'vectors       = %d\n',vx);
logmsg(logfile,toConsole,'features      = %d\n',fx);
logmsg(logfile,toConsole,'learningRate  = %f\n',lrate);
logmsg(logfile,toConsole,'eps           = %f\n',tol);
logmsg(logfile,toConsole,'maxIterations = %d\n',nmax);
logmsg(logfile,toConsole,'function      = %s\n',func2str(func));
logmsg(logfile,toConsole,'ppi function  = %s\n',func2str(ppifun));
logmsg(logfile,toConsole,'beta          = %f\n',beta);
logmsg(logfile,toConsole,'initiations   = %d\n',ninit);
if orto
    fprintf(logfile,'ortogonal W   = ');
    fprintf(logfile,'%f  ',wort);
    fprintf(logfile,'\n');
    fprintf(logfile,'lambda        = %f\n',lambda);
end
if ~isempty(iGmax)
    fprintf(logfile,'G max index   = %d\n',iGmax);
end

%%%%%%%%%%%%%%   the actual optimization starts here   %%%%%%%%%%%%%%

% normalize every starting direction to unit length
for k = 1:ninit
    w(k,:) = w(k,:)/norm(w(k,:));
end

n          = 0;
ppitable   = zeros(ninit,nmax+1);      % PPI of every initiation in every iteration
avgppi     = zeros(1,ninit);           % current block average (placeholder until computed)
lastavgppi = zeros(1,ninit);           % previous block average (placeholder until computed)
initindex  = 1:ninit;                  % initiations still being optimized
bestindex  = initindex;                % candidates for the result (narrowed at kill time)
lastn      = ones(1,ninit).*killPeriod;  % iteration at which each initiation stopped
nkeep      = min(nkeep, ninit);

% ppifun is called with one direction per row; its outputs used here are the
% PPI values, the update direction (presumably the gradient) and G.
[ppi, ppid, ~, G] = ppifun(w(initindex,:));
ppiinit = ppi;
ppitable(initindex,1) = ppiinit;
lastppi = ppiinit;

while true
    % gradient step, then back onto the unit sphere
    w(initindex,:) = w(initindex,:) + lrate * ppid;
    for k = initindex
        w(k,:) = w(k,:)./norm(w(k,:));
    end

    n = n + 1;
    lastppi(initindex) = ppi;   % PPI before this step

    if n == killPeriod + 1
        % Kill time. The stopping criteria only run for n > killPeriod, so
        % no initiation has been stopped yet: initindex is still 1:ninit
        % and ppi/ppiinit are indexed by initiation number.
        [~, byGrowth] = sort(ppi - ppiinit,'descend');  % is it growing fast?
        [~, byValue]  = sort(ppi,'descend');            % is it large enough?
        candidates = unique([initindex(byGrowth(1:nkeep)) initindex(byValue(1:nkeep))]);
        initindex  = keepDistinct(w, candidates, ppi(candidates), alpha);
        bestindex  = initindex;
        % from here on the averages are indexed by position in initindex
        avgppi     = avgppi(initindex);
        lastavgppi = lastavgppi(initindex);
    end

    if mod(n,avgtest) == 0
        lastavgppi = avgppi;
        avgppi = mean(ppitable(initindex,n-avgtest+1:n),2);
    end

    [ppi, ppid, ~, G] = ppifun(w(initindex,:));
    ppitable(initindex,n+1) = ppi;

    if iplot
        plotppi2(lastppi(initindex)',ppi',n,avgtest);
        drawnow;
    end

    if n > killPeriod
        if stopcriterium == 1
            % relative change of PPI over the last step
            prev = reshape(lastppi(initindex), size(ppi));
            active = abs(ppi - prev) > tol.*abs(prev);
        elseif n >= 2*avgtest
            % relative change between the two last block averages
            active = abs(avgppi - lastavgppi) > tol.*abs(lastavgppi);
        else
            % fewer than two real averages yet: lastavgppi is still the zero
            % placeholder, which previously gave 0/0 = NaN and stopped all
            active = true(numel(initindex),1);
        end
        if any(~active)
            if toConsole
                fprintf('Finish N = %d',n); fprintf('   I = %f ',ppi(~active)); fprintf('[%d ]',initindex(~active)); fprintf('\n'); disp(w(initindex(~active),:));
            end
            lastn(initindex(~active)) = n;
            initindex  = initindex(active);
            ppid       = ppid(active,:);
            ppi        = ppi(active);
            avgppi     = avgppi(active);
            lastavgppi = lastavgppi(active);
        end
    end

    if n >= nmax || isempty(initindex)
        break
    end
end

if ~isempty(initindex)
    lastn(initindex) = n;
    if toConsole
        fprintf('%d still not coverge\n',length(initindex));
    end
end

% end of computation: PPI of each initiation at its final weights (the value
% stored when it stopped), shaped like lastppi
finalppi = lastppi;
finalppi(:) = ppitable(sub2ind(size(ppitable), 1:ninit, lastn+1));

% remove near-parallel (redundant) projections
finalindex = keepDistinct(w, bestindex, finalppi(bestindex), alpha);
bestw   = w(finalindex,:);
bestppi = finalppi(finalindex);

if nargout > 2
    [~, igmax] = max(G,[],2);
end

if toConsole
    fprintf('Leving %d best solutions\n',length(finalindex));
end

fprintf(logfile,'---------------------------------------------\n\n');

function logmsg(fid, toConsole, fmt, varargin)
% LOGMSG writes a formatted message to the file FID and, when TOCONSOLE is
% true, echoes the same text to the command window.
    fprintf(fid, fmt, varargin{:});
    if toConsole
        fprintf(fmt, varargin{:});
    end

function kept = keepDistinct(w, idx, scores, alpha)
% KEEPDISTINCT orders candidate directions by score and drops near-duplicates.
%   W      : matrix of directions, one unit-norm row each
%   IDX    : row indices into W of the candidates (row vector)
%   SCORES : score of each candidate, same length as IDX
%   ALPHA  : a candidate is dropped when its |dot product| with any already
%            kept direction exceeds ALPHA (w and -w count as the same)
%   KEPT   : surviving indices (row vector), by decreasing score
    [~, order] = sort(scores(:), 'descend');
    candidates = idx(order);
    kept = candidates([]);
    for c = candidates(:)'
        if isempty(kept) || all(abs(w(kept,:) * w(c,:)') <= alpha)
            kept(end+1) = c; %#ok<AGROW>
        end
    end

function saveplot(prefix)
% SAVEPLOT writes the current figure twice: as PREFIX.png (96 dpi raster)
% and as PREFIX.eps (colour EPS).
    pngFile = strcat(prefix,'.png');
    epsFile = strcat(prefix,'.eps');
    print('-dpng','-r96',pngFile);
    % % octave only
    % %	print('-dpng','-S640,480',pngFile);
    print('-depsc',epsFile);

% function plotppi(ppitable,n,scalefactor)
%     if nargin < 3
%         scalefactor = 10;
%     end
%     if nargin < 2
%         n = length(ppitable);
%     end
%     xl = scalefactor.*(fix(n./scalefactor)+1);
%     cla;
%     hold on;
%     xlim([0 xl]);
%     xlabel('Iterations');
%     ylabel('PPI value');
%     plot(0:n,ppitable(:,1:n+1),'-b','LineWidth',2);
%     box on;
%  	hold off;

function plotppi2(ppilast,ppi,n,scalefactor)
% PLOTPPI2 appends one iteration to an incremental plot of PPI values.
%   PPILAST     : PPI values at iteration N-1 (row vector, one entry per curve)
%   PPI         : PPI values at iteration N (same size as PPILAST)
%   N           : current iteration number
%   SCALEFACTOR : the x-axis limit grows in steps of this many iterations
%                 (default 10)
% A segment from (N-1, PPILAST) to (N, PPI) is drawn for every curve. The
% axes are left in 'hold on' state so successive calls build up the lines.
    % scalefactor is the 4th argument; the old 'nargin < 3' test never
    % supplied the default for a 3-argument call
    if nargin < 4 || isempty(scalefactor)
        scalefactor = 10;
    end
    xl = scalefactor.*(fix(n./scalefactor)+1);
    hold on;
    xlim([0 xl]);
    xlabel('Iterations');
    ylabel('PPI value');
    plot([n-1 n],[ppilast; ppi],'-b','LineWidth',2);
    box on;