function [Us,z,y,output] = diffrac(x,kkk,varargin)
% DIFFRAC
% Run diffrac on some data
%
% [Us,z,y,output] = diffrac(x,kkk,'option',value,...)
%
% REQUIRED PARAMETERS
% x:        n x d data matrix, or n x n kernel matrix when
%           'kernel_type' is 'kernel_matrix'
% kkk:      number of desired clusters
%
% OPTIONAL PARAMETERS (name/value pairs)
% 'kernel_type'         kernel type
%                       'linear'        = linear kernel (default)
%                       'rbf'           = Gaussian RBF
%                       'kernel_matrix' = kernel matrix directly
% 'kernel_parameters'   kernel parameters = alpha
%                       gaussian kernel -> exp(-alpha ||x-y||^2)
% 'lambda'              regularization parameter (default=1e-6)
% 'df'                  degrees of freedom: alternative way of setting
%                       regularization parameter (overrules 'lambda')
% 'labelled_option'     1 -> 'positive constraints' will be used (default)
%                       2 -> 'labelled' will be used (SSL)
% 'labelled'                    vector of labels in {1,...,kkk,NaN}
% 'positive_constraints'        set of pairs of positive constraints
%                               (ignored and rebuilt from 'labelled' when
%                               'labelled' is given with labelled_option 1)
% 'lambda0'             minimum number of elements per cluster
%                       default = round(n/2/kkk)
% 'subsample'           number of points subsampled for the pointwise
%                       positivity constraints
% 'subsample_seed'      seed of the random generator used for the subsample
% 'restart'             number of restarts of the K-means rounding
% 'display'             1 -> verbose, 0 -> no display
% 'kmax', 'warm_restart', 'eps_dual', 'missing_constraints'
%                       optimization parameters (see defaults below)
%
% OUTPUTS
% Us:       M1 * diag(sqrt(M2)) as returned by dual_cost_full_fast at the
%           final dual point, mapped back to one row per data point
%           through the chunk matrix when positive constraints are used
% z:        cluster index of each data point after K-means rounding
%           (relabelled to agree with 'labelled' when labelled_option is 1)
% y:        class2indmat(z-1,kkk)
% output:   structure with fields
%               df, lambda  degrees of freedom / regularization actually used
%               DUAL        final dual variable
%               zu          rounded labels, one per data point
%               zb, Usb     rounded labels and Us at the chunk level
%
% SIDE EFFECTS
% adds the 'utilities' directory to the path, resets the legacy rand/randn
% generator states, and resets the globals u_diffrac and j_diffrac.

addpath utilities

x = x';
[ d n ] = size(x);


kernel_parameters = [];         % kernel parameters
positive_constraints = [];      % set of pairs of positive constraints
negative_constraints = [];      % set of pairs of negative constraints
if kkk>2
    subsample  = min(max(10,round(sqrt(n)*2)),n);       % number of subsample for pointwise positivity constraint
else
    subsample = 0;
end
subsample_seed = 1;             % random seed generator for subsample
restart = 100;                  % number of restarts for Kmeans rounding
lambda0     = round(n/2/kkk);   % minimum number of elements per clusters
kernel_type = 'linear';         % type of kernel: 'linear'
lambda      = 1e-6;             % regularization parameter
df = [];                        % degrees of freedom: overrules lambda

labelled = [];                  % vector of labels in {1,...,kkk,NaN}
% all labels must be present!!
% SSL: usual semi-supervised classification
% will make sure that classes matches
labelled_option = 1;
% 1: only positive constraints
% 2: preclustering, no negative constraints

% OPTIONAL OPTIMIZATION PARAMETERS (DO NOT TOUCH...)
kmax = 200;                     % maximal number of parameters
warm_restart = [];              % starting matrix U or DUAL matrix
eps_dual = 1e-3;                % smoothing parameter for dual optimization
missing_constraints = [0 0 0];  % missing constraints (i.e., N,b,mu) -> for debugging
display = 1;                    % verbose

% READ OPTIONAL PARAMETERS
args = varargin;
nargs = length(args);
if mod(nargs,2) ~= 0
    % a dangling option name would otherwise fail with an obscure
    % "index exceeds" error (or be silently dropped)
    error('optional parameters must be given as name/value pairs');
end
for i=1:2:nargs
    switch args{i},
        case 'display',        display = args{i+1};
        case 'lambda0',     lambda0 = args{i+1};
        case 'lambda',     lambda = args{i+1};
        case 'kernel_type',  kernel_type = args{i+1};
        case 'kernel_parameters',  kernel_parameters = args{i+1};
        case 'positive_constraints',  positive_constraints = args{i+1};
        case 'subsample',        subsample = args{i+1};
        case 'subsample_seed',        subsample_seed = args{i+1};
        case 'restart',             restart = args{i+1};
        case 'df',        df = args{i+1};
        case 'warm_restart',        warm_restart = args{i+1};
        case 'eps_dual',        eps_dual = args{i+1};
        case 'missing_constraints',        missing_constraints = args{i+1};
        case 'labelled',        labelled = args{i+1};
        case 'labelled_option',        labelled_option = args{i+1};
        case 'kmax',        kmax = args{i+1};
        otherwise
            % the option is still ignored, but no longer silently
            if ischar(args{i})
                warning('unknown option ''%s'' ignored', args{i});
            else
                warning('option name is not a string and was ignored');
            end
    end
end
D = ones(n,1);                  % weight of each point (later: size of each chunk)





% BUILD KERNEL MATRIX
% every branch produces xtilde (reduced, centered and whitened data, one
% column per point), trA, and the effective degrees of freedom df
switch kernel_type
    case 'linear'
        xtilde = x - repmat(mean(x,2),1,size(x,2));
        %[u,e] = eig(xtilde * xtilde');
        [u,e,v] = svd(xtilde,'econ');
        e=e.^2;
        if ~isempty(df);
            lambda = df_to_lambda_eig(df,real(diag(e))/n + 1e-10);
            lambda = 1/lambda;
        end
        % keep only the directions that are not negligible w.r.t. n*lambda
        ind = find( diag(e) > n*lambda*1e-2);
        q = u(:,ind);
        xtilde = q' * xtilde;
        C = xtilde * xtilde' + n * lambda * eye(size(xtilde,1));
        R = chol(C);
        R = inv(R)';
        xtilde = R * xtilde;
        trA = n-1 - sum( xtilde(:).^2 );
        xtilde0 = xtilde;
        df = sum( xtilde(:).^2 );


    case 'rbf'
        if ~isempty(df);
            % df is given
            mmax = round(min(1000,4*df));   % maximum size for the incomplete Cholesky
            [G,P,m] = icd_gauss(x,kernel_parameters,n*1e-8,min(n,mmax)); % perform largest ICD to get lambda
            I = P(1:m);
            [temp,Pi] = sort(P);
            G = G(Pi,1:m);                  % undo the pivoting permutation

            xtilde = G';
            xtilde = xtilde - repmat(mean(xtilde,2),1,size(xtilde,2));
            % [u,e] = eig(xtilde * xtilde');
            [u,e,v] = svd(xtilde,'econ');
            e=e.^2;
            ind =find( real(diag(e))/n > 1e-10 );
            if df >= length(ind),
                % df is too big
                lambda = 1e-10;
            else
                lambda = df_to_lambda_eig(df,real(diag(e(ind,ind)))/n);
                lambda = 1/lambda;
            end
            ind = find( real(diag(e)) > n*lambda*1e-2);
            q = u(:,ind);
            xtilde = q' * xtilde;
            C = xtilde * xtilde' + n * lambda * eye(size(xtilde,1));
            R = chol(C);
            R = inv(R)';
            xtilde = R * xtilde;
            trA = n-1 - sum( xtilde(:).^2 );

        else
            % lambda is given
            [G,P,m] = icd_gauss(x,kernel_parameters,n*lambda*1e-2,n);
            I = P(1:m);
            [temp,Pi] = sort(P);
            G = G(Pi,1:m);                  % undo the pivoting permutation
            G = G';
            G = G - repmat(mean(G,2),1,size(G,2));
            C = G * G' + n * lambda * eye(size(G,1));
            try
                R = chol(C);
            catch
                % augment lambda if not positive enough and error is produced
                R = chol(C + norm(G,'fro')^2 * 1e-10 * eye(size(G,1)));
                warning('lambda was augmented');
            end
            R = inv(R)';
            xtilde = R * G;
            trA = n-1 - sum( xtilde(:).^2 );
        end
        xtilde0 = xtilde;
        df = sum( xtilde(:).^2 );



    case 'kernel_matrix',
        K = x;

        if ~isempty(df);
            % df is given
            mmax = round(min(1000,4*df));   % maximum size for the incomplete Cholesky
            [G,P,m] = icd_full(K,n*1e-8,min(n,mmax)); % perform largest ICD to get lambda
            I = P(1:m);
            [temp,Pi] = sort(P);
            G = G(Pi,1:m);                  % undo the pivoting permutation

            xtilde = G';
            xtilde = xtilde - repmat(mean(xtilde,2),1,size(xtilde,2));
            % [u,e] = eig(xtilde * xtilde');
            [u,e,v] = svd(xtilde,'econ');
            e=e.^2;
            ind =find( real(diag(e))/n > 1e-10 );
            if df >= length(ind),
                % df is too big
                lambda = 1e-10;
            else
                lambda = df_to_lambda_eig(df,real(diag(e(ind,ind)))/n);
                lambda = 1/lambda;
            end
            ind = find( real(diag(e)) > n*lambda*1e-2);
            q = u(:,ind);
            xtilde = q' * xtilde;
            C = xtilde * xtilde' + n * lambda * eye(size(xtilde,1));
            R = chol(C);
            R = inv(R)';
            xtilde = R * xtilde;
            trA = n-1 - sum( xtilde(:).^2 );

        else
            % lambda is given
            [G,P,m] = icd_full(K,n*lambda*1e-2,n);
            I = P(1:m);
            [temp,Pi] = sort(P);
            G = G(Pi,1:m);                  % undo the pivoting permutation
            G = G';
            G = G - repmat(mean(G,2),1,size(G,2));
            C = G * G' + n * lambda * eye(size(G,1));
            try
                R = chol(C);
            catch
                % augment lambda if not positive enough and error is produced
                R = chol(C + norm(G,'fro')^2 * 1e-10 * eye(size(G,1)));
                warning('lambda was augmented');
            end
            R = inv(R)';
            xtilde = R * G;
            trA = n-1 - sum( xtilde(:).^2 );
        end
        xtilde0 = xtilde;
        df = sum( xtilde(:).^2 );

    otherwise
        % without this, an unsupported type only failed later on an
        % undefined xtilde
        error('unknown kernel_type (use ''linear'', ''rbf'' or ''kernel_matrix'')');

end

if (display)
    fprintf('DEGREES OF FREEDOM = %f\n', df);
    fprintf('OPTIMIZATION DIMENSION = %d\n',size(xtilde,1));
end
output.df = df;
output.lambda = lambda;

% CONSTRUCT SUBSAMPLE FOR POSITIVE CONSTRAINTS
% (legacy generator syntax kept on purpose: changing it would change the
% subsample drawn for a given seed)
rand('state',subsample_seed);
randn('state',subsample_seed);
temp = subsample;
subsample = randperm(n);
subsample = subsample(1:temp);


if labelled_option==1
    if ~isempty(labelled)
        z0max = kkk;
        labelled_examples = cell(1,z0max);
        positive_constraints = [];
        for k=1:z0max
            % labelled(:) makes this work for row and column label vectors;
            % labelled_examples{k} is always a row of point indices
            labelled_examples{k} = find(labelled(:)==k)';
            if isempty(labelled_examples{k}), error('must have at least one point labelled for each class'); end
            % tie every labelled point of class k to the first one
            positive_constraints = [ positive_constraints ; ...
                [ labelled_examples{k}(1) * ones( length(labelled_examples{k})-1,1 ), labelled_examples{k}(2:end)' ] ];

        end
        % one negative constraint between the representatives of each
        % pair of classes
        negative_constraints =[];
        for k1=1:kkk-1
            for k2=k1+1:kkk
                negative_constraints = [ negative_constraints; [ labelled_examples{k1}(1), labelled_examples{k2}(1)] ];
            end
        end

    end



    % POSITIVE CONSTRAINTS
    % transform pairs of matching points to chunks
    pairs = positive_constraints;
    D0=D;
    if ~isempty(positive_constraints)
        Mc = sparse( pairs(:,1),pairs(:,2),ones(size(pairs,1),1),n,n);
        Mc = Mc + Mc';
        Mc = Mc - diag(diag(Mc));
        Mc = Mc + speye(n);             % Mc is the affinity matrix of the chunks
        if ~isempty(labelled)
            % explicit transitive closure in pure SSL settings
            Mcr = Mc;
            for k=1:z0max
                Mcr(labelled_examples{k},labelled_examples{k})=1;
            end
        else
            Mcr = reachability_graph(Mc);   % transitive closure
        end
        [chunks,chunksize,reverse_index] = affinity_to_chunks(Mcr);
        % PC(i,ic)=1 iff point i belongs to chunk ic; D(ic) = weight of chunk ic
        PC = sparse(n,length(chunks));
        D = zeros(length(chunks),1);
        for ic=1:length(chunks)
            PC(chunks{ic},ic)=1;
            D(ic) = sum(D0(chunks{ic}));
        end
        % collapse the data onto chunks. All kernel types are represented
        % through xtilde at this point; the former 'kernel_matrix' branch
        % used a matrix A that is never defined in this function.
        xtilde = xtilde * PC;
        trA = sum(D) - 1/sum(D) * sum(D.^2)  - sum( xtilde(:).^2 );

        % updates the negative constraints
        if ~isempty(negative_constraints),
            negative_constraints = reverse_index(negative_constraints);

            if size(negative_constraints,2)==1, negative_constraints=negative_constraints'; end
            if any( negative_constraints(:,1)==negative_constraints(:,2) ), error('INCONSISTENT CONSTRAINTS'); end
            negative_constraints= unique(negative_constraints,'rows');
        end

        % update the pointwise positivity constraints
        subsample = unique(reverse_index(subsample));

    end
    if ~isempty(negative_constraints),
        if any( negative_constraints(:,1)==negative_constraints(:,2) ), error('INCONSISTENT CONSTRAINTS'); end
    end
else
    % known labels for all
    z0max = max(labelled);
    for k=1:z0max
        chunks{k} = find(labelled==k)';
    end
    PC = class2indmat(labelled-1,z0max);
    D = sum(PC,1)';
    % same collapse as above (the undefined-A branch was removed here too)
    xtilde = xtilde * PC;
    trA = sum(D) - 1/sum(D) * sum(D.^2)  - sum( xtilde(:).^2 );
    negative_constraints = [];
    positive_constraints = NaN;     % non-empty flag: triggers the mapping back to points below
    subsample = 1:z0max;
end
% reduce the size of xtilde if necessary
if (size(xtilde,1) > size(xtilde,2))
    % more rows than columns: keep the square triangular factor, which has
    % the same Gram matrix xtilde'*xtilde
    [q,r]=qr(xtilde);
    xtilde = r(1:size(r,2),1:size(r,2));


end


p = length(D);

% launch optimization with dual formulation - ICML paper constraint + constraints on sum(M)

optparam.kmax = kmax;
optparam.display = display;     % was hard-coded to 1, which ignored the 'display' option
nsub = length(subsample);
nneg = size(negative_constraints,1);

% NOTE(review): the expected length below does not count nneg, although the
% warm-restart call passes negative_constraints - confirm with the optimizer
if ~isempty(warm_restart) &&  length(warm_restart)~= (nsub * (nsub-1)/2 + p + p + p + 0 + 1 + 1)
    warning('wrong size of warm restart');
end

% uses global variables to store eigenvectors
global u_diffrac;
global j_diffrac;
u_diffrac = [];     % eigenvectors
j_diffrac = 0;      % eigenvalues

% column j of xtilde scaled by D(j)^(-1/2); shared by all optimizer calls
xscaled = repmat(D.^-.5,1,size(xtilde',2))' .* xtilde;

if ~isempty(warm_restart)
    DUAL=minimize_projected_gradient_dual_full_fast(warm_restart,optparam,xscaled,trA,kkk,subsample,lambda0,eps_dual,D,negative_constraints,missing_constraints);
else
    if nneg>0
        % two stages, each with half of the iteration budget: first
        % without the negative constraints, then with them, starting from
        % the first-stage solution
        optparam.kmax = optparam.kmax / 2;
        DUAL = zeros( nsub * (nsub-1)/2 + p + p + p + 0 + 1 + 1, 1);
        DUAL(end) = 1;      % c parameter
        DUAL=minimize_projected_gradient_dual_full_fast(DUAL,optparam,xscaled,trA,kkk,subsample,lambda0,eps_dual,D,[],missing_constraints);
        % enlarge the dual variable with zeros for the nneg new entries
        DUALNEW = zeros( nsub * (nsub-1)/2 + p + p + p + nneg + 1 + 1, 1);
        DUALNEW(1: nsub * (nsub-1)/2 + p+ p + p )=DUAL(1: nsub * (nsub-1)/2 + p+ p + p );
        DUALNEW(end)=DUAL(end);
        DUALNEW(end-1)=DUAL(end-1);
        DUAL=DUALNEW;
        DUAL=minimize_projected_gradient_dual_full_fast(DUAL,optparam,xscaled,trA,kkk,subsample,lambda0,eps_dual,D,negative_constraints,missing_constraints);
    else
        DUAL = zeros( nsub * (nsub-1)/2 + p + p + p + nneg + 1 + 1, 1);
        DUAL(end) = 1;      % c parameter

        DUAL=minimize_projected_gradient_dual_full_fast(DUAL,optparam,xscaled,trA,kkk,subsample,lambda0,eps_dual,D,negative_constraints,missing_constraints);
    end
end
[fx,subgradient,M1,M2] = dual_cost_full_fast(DUAL,xscaled,trA,kkk,subsample,lambda0,eps_dual,D,negative_constraints,missing_constraints);

output.DUAL=DUAL;

if any(M2 < -1e-6)
    % dumps the whole workspace for debugging before failing
    save temp_negative_eigenvalue, error('one negative eigenvalue :  call francis bach!!');
end
M2 = max(M2,0);
Us = M1 * diag(M2.^.5);




% ROUNDING
% project the rows of Us on its kkk leading singular directions, normalize
% each row to unit norm, then run weighted K-means on the result
[v,e,u]=svd(Us,0);
e = e.^2;
% [u,e]=eig(Us'*Us);
[a,b]=sort(-diag(e));
un =u(:,b(1:kkk));
en =e(b(1:kkk),b(1:kkk));

un = Us * un * diag(1./sqrt(diag(en)));

normalized = sqrt(sum(un.^2,2));
un = un ./ repmat( normalized, 1, kkk);
z = kmeans_restarts_1_0(un',kkk, 'weights', D,'negative_constraints',negative_constraints,'restarts',restart);


output.zu = z;
y = class2indmat(z-1,kkk);
yu = y;
output.zb = z;
output.Usb = Us;

if ~isempty(positive_constraints)
    % go back to actual points
    Us = PC * Us;
    un = PC * un;
    zc = zeros(n,1);
    for ic=1:length(chunks)
        zc(chunks{ic})=z(ic);
    end
    z= zc;
    zc = zeros(n,1);
    for ic=1:length(chunks)
        zc(chunks{ic})=output.zu(ic);
    end
    output.zu= zc;
end
y = class2indmat(z-1,kkk);


if ~isempty(labelled)
    if labelled_option == 1
        % changes the labels to take into account supervision:
        % change_matrix(i) is the cluster found for the representative of
        % class i, and change_matrix_reverse maps that cluster back to i
        change_matrix = zeros(kkk,1);
        for i=1:kkk
            change_matrix(i) = z(labelled_examples{i}(1));
        end
        change_matrix_reverse = 1:kkk;
        change_matrix_reverse(change_matrix)=1:kkk;
        z = change_matrix_reverse(z);
        y = class2indmat(z-1,kkk);
        output.zu = change_matrix_reverse(output.zu);
    end
end

if any(~isreal(Us(:))), save temp_imaginary_diffrac, error('imaginary_diffrac call francis bach!!'); end
