

% ------------------------------------------------------------------
% Experiment setup: reset the workspace, fix the random seeds, and
% build the covariance Sigma and the target vector theta_ast.
% ------------------------------------------------------------------
clear all
rngSeed = 1;
randn('state', rngSeed);   % legacy seeding of the normal generator
rand('state', rngSeed);    % legacy seeding of the uniform generator

% Problem sizes: later code builds X with n rows and d columns.
n = 200;
d = 400;

% Eigenvectors of a random symmetric matrix.
symG = randn(d, d);
symG = symG + symG';
[eigVecs, eigVals] = eig(symG);

% Covariance built from the eigenvectors only, then scaled to unit trace.
% NOTE(review): eig of a symmetric matrix normally returns orthonormal
% eigenvectors, in which case eigVecs*eigVecs' is (numerically) eye(d)
% and Sigma is eye(d)/d; eigVals is not used. Confirm that an isotropic
% covariance is what is intended here.
Sigma = eigVecs * eigVecs';
Sigma = Sigma / trace(Sigma);
Sigmasqrt = sqrtm(Sigma);

% Random target vector, rescaled so that theta_ast' * Sigma * theta_ast = 1.
theta_ast = randn(d, 1);
sigmaNormSq = theta_ast' * Sigma * theta_ast;
theta_ast = theta_ast / sqrt(sigmaNormSq);



% ------------------------------------------------------------------
% Theoretical variance / bias bounds as a function of the sketch size m.
%
% Reads   : n, d, Sigma, Sigmasqrt, theta_ast
% Creates : nrep, ms, kappas, Rvar_bound, Rbias_bound and the
%           Inf-initialised containers empRvar / empRbias that the
%           simulation loop fills in.
% ------------------------------------------------------------------
nrep = 40;
ms = 0:2:2*d;
ms( abs(ms-n) < 9 ) = []; % remove the m's which are too close to n

empRvar = Inf * ones(length(ms),nrep);
Rvar_bound = Inf * ones(length(ms),1);
empRbias = Inf * ones(length(ms),nrep);
% Not written anywhere in the visible script; kept so the workspace
% contents are unchanged.
Rbias_bound_old = Inf * ones(length(ms),1);
Rbias_bound = Inf * ones(length(ms),1);

% Preallocate the per-m and per-repetition scratch vectors.
kappas = zeros(1,length(ms));
mmloc = zeros(1,nrep);

% Loop-invariant quantities.
Id = eye(d);
Sigma2 = Sigma * Sigma;
thetaSigma = theta_ast' * Sigma;   % 1 x d row vector

for im = 1:length(ms)
    m = ms(im);
    
    if m < n
        
        % compute implicit regularization parameters
        for irep=1:nrep
            S = sign(randn(d,m));
            mmloc(irep) = trace( inv(S'*Sigma*S) );
        end
        mm = mean(mmloc);
        kappa = 1/mm;
        kappas(im) = kappa;
        
        Rvar_bound(im) = m / (n - m) ;
        
        if isinf(kappa)
            % m = 0: S is empty, trace(inv([])) is 0 and kappa = 1/0 = Inf.
            % Inf*eye(d) has NaN off the diagonal (0*Inf), which would
            % turn the bound into NaN.  Use the limit
            %   kappa * Sigma * inv(Sigma + kappa*I)  ->  Sigma
            % as kappa -> Inf instead.
            Rbias_bound(im) = n / (n - m) * (thetaSigma * theta_ast);
        else
            invSigmakappa = inv( Sigma + kappa * Id );
            Rbias_bound(im) = n / (n - m) * kappa * (thetaSigma * invSigmakappa * theta_ast);
        end
        
    elseif m > n
        
        % compute implicit regularization parameters
        for irep=1:nrep
            XX = sign(randn(n,d)) * Sigmasqrt;
            mmloc(irep) = trace( inv(XX*XX') );
        end
        mm = mean(mmloc);
        kappa = 1/mm;
        kappas(im) = kappa;
        
        invSigmakappa = inv( Sigma + kappa * Id );
        invSigmakappa2 = invSigmakappa * invSigmakappa;
        
        df2 = trace( Sigma2 * invSigmakappa2 );
        Rvar_bound(im) = df2/n / ( 1-df2/n) + n / ( m - n);
        
        % Two contributions to the bias bound, accumulated in one place.
        biasRidge  = kappa^2 * (thetaSigma * invSigmakappa2 * theta_ast) / ( 1-df2/n);
        biasSketch = n / (m - n) * kappa * (thetaSigma * invSigmakappa * theta_ast);
        Rbias_bound(im) = biasRidge + biasSketch;
        
    end
    % m == n cannot occur: those values were removed from ms above.
end




% ------------------------------------------------------------------
% Monte-Carlo estimates of the variance and bias excess risks.
%
% Reads  : n, d, nrep, ms, Sigma, Sigmasqrt, theta_ast
% Writes : empRvar(im,irep), empRbias(im,irep)
%
% For each repetition one design X, one full sketch Sfull and one noise
% vector epsilon are drawn; the sketch for a given m is the first m
% columns of Sfull.
% ------------------------------------------------------------------
for irep=1:nrep
    
    % Random draws (order matters for reproducibility with a fixed seed).
    X = sign(randn(n,d)) * Sigmasqrt;
    Sfull = sign(randn(d,max(ms)));
    epsilon = sign(randn(n,1));
    
    % Work shared by every m: the sketched design for the full sketch and
    % the two right-hand sides (signal part and noise part).
    XSfull = X * Sfull;                  % n x max(ms)
    rhs = [X * theta_ast, epsilon];      % n x 2
    
    for im = 1:length(ms)
        m = ms(im);
        S = Sfull(:,1:m);
        XS = XSfull(:,1:m);              % equals X*S
        
        if m < n
            % S * inv(S'X'XS) * S'X' * rhs, via one m x m solve.
            est = S * ( (XS' * XS) \ (XS' * rhs) );
        else
            % S*S'X' * inv(XSS'X') * rhs, via one n x n solve.
            % Also covers m == n, which the original handled with the
            % same formula in a separate branch.
            est = S * ( XS' * ( (XS * XS') \ rhs ) );
        end
        
        thetabias = est(:,1);
        thetavar = est(:,2);
        biasErr = thetabias - theta_ast;
        
        empRvar(im,irep) = thetavar' * Sigma * thetavar;
        empRbias(im,irep) = biasErr' * Sigma * biasErr;
        
    end
end

% ------------------------------------------------------------------
% Figure: theoretical bound (red) against the empirical mean (solid
% blue) and mean +/- one standard deviation across repetitions (dotted
% blue).  Left panel: variance.  Right panel: bias.
% ------------------------------------------------------------------

% ---- left panel: variance ----
subplot(1,2,1);
avgVar = mean(empRvar,2);
sdVar = std(empRvar, [],2);
plot(ms, Rvar_bound,'r','linewidth',2);
hold on;
plot(ms, avgVar,'b','linewidth',2);
plot(ms, avgVar + sdVar,'b:','linewidth',2);
temp = avgVar - sdVar;
plot(ms, temp,'b:','linewidth',2);
% redraw the bound last so that it sits on top of the empirical curves
plot(ms, Rvar_bound,'r','linewidth',2);
hold off

axis([ 0 max(ms) 0 8])
title('variance','fontweight','normal')
set(gca,'fontsize',18);
legend('theoretical bound','empirical estimate');
xlabel('m');
ylabel('excess risks');

% ---- right panel: bias ----
subplot(1,2,2);
avgBias = mean(empRbias,2);
sdBias = std(empRbias, [],2);
plot(ms, Rbias_bound,'r','linewidth',2);
hold on
plot(ms, avgBias,'b','linewidth',2);
plot(ms, avgBias + sdBias,'b:','linewidth',2);
temp = avgBias - sdBias;
plot(ms, temp,'b:','linewidth',2);
% redraw the bound last so that it sits on top of the empirical curves
plot(ms, Rbias_bound,'r','linewidth',2);
hold off

axis([ 0 max(ms) 0 2])
title('bias','fontweight','normal')
set(gca,'fontsize',18);
legend('theoretical bound','empirical estimate');
xlabel('m');
ylabel('excess risks');


