Softmax classifier (one hidden layer)

16-09-24

A MATLAB implementation of a softmax classifier with one hidden layer. The activation function is ReLU, f(x) = max(0, x), and the network computes:
$$f_1 = W_1 x + b_1$$

$$h_1 = \max(0, f_1)$$

$$f_2 = W_2 h_1 + b_2$$

$$y_i = \frac{e^{f_{2,i}}}{\sum_j e^{f_{2,j}}}$$
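The forward pass can be traced by hand on a tiny batch. This is a sketch with made-up sizes and weights (2 examples, D = 3, h = 4, K = 3; all values are arbitrary assumptions), written in the same row-vector convention as the code, i.e. f1 = x*W1 + b1 with W1 of size D x h. It also subtracts each row's maximum score before exponentiating, a standard trick that leaves y unchanged but keeps exp from overflowing; center_rows and stable_softmax are hypothetical helper names.

```matlab
% Forward pass of the one-hidden-layer network on a tiny, made-up batch.
% Row-vector convention as in Softmax_Classifier_1: each row of X is one example.
X  = [1.0 -0.5  2.0;
      0.2  0.3 -1.0];               % 2 examples, D = 3
W1 = [ 0.1 -0.2  0.3  0.0;
       0.4  0.1 -0.1  0.2;
      -0.3  0.2  0.1  0.5];         % D x h, h = 4
b1 = [0.01 0 -0.01 0];              % 1 x h
W2 = [ 0.2 -0.1  0.0;
       0.1  0.3 -0.2;
      -0.4  0.1  0.2;
       0.3  0.0  0.1];              % h x K, K = 3
b2 = [0 0 0];                       % 1 x K
N = size(X, 1);
K = size(W2, 2);

f1 = X*W1 + repmat(b1, N, 1);       % pre-activation of the hidden layer
h1 = max(0, f1);                    % ReLU
f2 = h1*W2 + repmat(b2, N, 1);      % class scores

% Stable softmax: subtract each row's maximum before exp (y is unchanged).
center_rows    = @(s) s - repmat(max(s, [], 2), 1, K);
stable_softmax = @(s) exp(center_rows(s)) ./ repmat(sum(exp(center_rows(s)), 2), 1, K);

y = stable_softmax(f2)              % each row is a probability vector over K classes

naive_big = exp(f2 + 1000);         % overflows to Inf, so the naive ratio would be NaN
y_big = stable_softmax(f2 + 1000);  % same probabilities as y
```

Adding the same constant to every score in a row cancels in the ratio, which is why y_big matches y even though the naive exp(f2 + 1000) overflows.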

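function Out=Softmax_Classifier_1(train_x, train_y, opts)
% Train a softmax classifier with one ReLU hidden layer using mini-batch gradient descent.
%   train_x : N x D matrix, one example per row
%   train_y : N x K one-hot label matrix
%   opts    : struct with fields step_size, reg, batchsize, numepochs,
%             class (number of classes K) and hidden (number of hidden units h)

% read the learning parameters from opts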
step_size=opts.step_size;
reg=opts.reg;
batchsize = opts.batchsize;
numepochs = opts.numepochs;
K=opts.class;
h=opts.hidden;

D=size(train_x, 2);          % input dimension
W1=0.01*randn(D,h);          % small random weights, D x h
b1=zeros(1,h);
W2=0.01*randn(h, K);         % small random weights, h x K
b2=zeros(1,K);


loss=zeros(1, numepochs);    % average training loss of each epoch

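num_examples=size(train_x, 1);
% use only full batches; the leftover examples (a different random subset each epoch) are skipped
numbatches = floor(num_examples / batchsize);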

for epoch=1:numepochs

    kk = randperm(num_examples);
    loss(epoch)=0;

    for bat=1:numbatches

        batch_x = train_x(kk((bat - 1) * batchsize + 1 : bat * batchsize), :);
        batch_y = train_y(kk((bat - 1) * batchsize + 1 : bat * batchsize), :);

        %% forward
        f1=batch_x*W1+repmat(b1, batchsize, 1);
        hiddenval_1=max(0, f1);
        scores=hiddenval_1*W2+repmat(b2, batchsize, 1);

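        %% the loss
        % subtract each row's maximum before exp to avoid overflow; the softmax is unchanged
        exp_scores=exp(scores-repmat(max(scores, [], 2), 1, K));
        dd=repmat(sum(exp_scores, 2), 1, K);
        probs=exp_scores./dd;
        % batch_y is one-hot, so this picks out the probability of the correct class
        correct_logprobs=-log(sum(probs.*batch_y, 2));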
        data_loss=sum(correct_logprobs)/batchsize;
        reg_loss=0.5*reg*sum(sum(W1.*W1))+0.5*reg*sum(sum(W2.*W2));
        loss(epoch) =loss(epoch)+ data_loss + reg_loss;

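        %% back propagation
        % gradient of the mean cross-entropy with respect to the scores
        dscores = probs-batch_y;
        dscores=dscores/batchsize;
        dW2=hiddenval_1'*dscores;
        db2=sum(dscores, 1);

        dhiddenval_1=dscores*W2';
        % ReLU derivative: 1 where the hidden unit is active (h1 > 0), else 0
        mask=max(sign(hiddenval_1), 0);
        df_1=dhiddenval_1.*mask;
        dW1=batch_x'*df_1;
        db1=sum(df_1, 1);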

        %% update
        % add the gradient of the L2 penalty 0.5*reg*||W||^2
        dW2=dW2+reg*W2;
        dW1=dW1+reg*W1;

        W1=W1-step_size*dW1;
        b1=b1-step_size*db1;

        W2=W2-step_size*dW2;
        b2=b2-step_size*db2;

    end

    loss(epoch)=loss(epoch)/numbatches;   % mean loss over the epoch's batches

    if (mod(epoch, 10)==0)
        fprintf('epoch: %d, training loss: %f\n', epoch, loss(epoch));
    end

end

Out.W1=W1;
Out.b1=b1;
Out.b2=b2;
Out.W2=W2;
Out.loss=loss;

end
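The back-propagation step of Softmax_Classifier_1 can be verified with a finite-difference gradient check. The script is a sketch on a hypothetical tiny problem (N = 5, D = 4, h = 6, K = 3, reg = 1e-3, with data generated from sin/cos so it is deterministic). It recomputes the analytic gradients the same way the function does and compares them with central differences of the loss; scores_fn, lse and loss_fn are illustrative names, not part of the original code.

```matlab
% Gradient check for the backprop formulas of Softmax_Classifier_1 (sketch).
N = 5; D = 4; h = 6; K = 3; reg = 1e-3;
X  = sin((1:N)' * (1:D));              % N x D inputs (deterministic stand-in data)
W1 = 0.5 * cos((1:D)' * (1:h));        % D x h
b1 = 0.1 * sin(1:h);                   % 1 x h
W2 = 0.5 * sin((1:h)' * (1:K) + 0.3);  % h x K
b2 = 0.1 * cos(1:K);                   % 1 x K
labels = mod((1:N)', K) + 1;           % class index of each example
Y = zeros(N, K);
Y(sub2ind([N K], (1:N)', labels)) = 1; % one-hot rows, like train_y

% Loss = mean cross-entropy + L2 penalty; -log(softmax) = logsumexp(s) - s_correct.
scores_fn = @(W1, b1, W2, b2) max(0, X*W1 + repmat(b1, N, 1))*W2 + repmat(b2, N, 1);
lse = @(S) max(S, [], 2) + log(sum(exp(S - repmat(max(S, [], 2), 1, K)), 2));
loss_fn = @(W1, b1, W2, b2) ...
    sum(lse(scores_fn(W1, b1, W2, b2)) - sum(scores_fn(W1, b1, W2, b2) .* Y, 2)) / N ...
    + 0.5*reg*sum(W1(:).^2) + 0.5*reg*sum(W2(:).^2);

% Analytic gradients, mirroring the function's back propagation and update steps.
f1 = X*W1 + repmat(b1, N, 1);
h1 = max(0, f1);
S  = h1*W2 + repmat(b2, N, 1);
P  = exp(S - repmat(max(S, [], 2), 1, K));
P  = P ./ repmat(sum(P, 2), 1, K);
dS  = (P - Y) / N;
dW2 = h1' * dS + reg*W2;
db2 = sum(dS, 1);
df1 = (dS * W2') .* max(sign(h1), 0);  % ReLU passes gradient only where h1 > 0
dW1 = X' * df1 + reg*W1;
db1 = sum(df1, 1);

% Central differences, one parameter entry at a time.
params = {W1, b1, W2, b2};
num = cell(1, 4);
step = 1e-5;
for p = 1:4
    num{p} = zeros(size(params{p}));
    for idx = 1:numel(params{p})
        pp = params; pm = params;
        pp{p}(idx) = pp{p}(idx) + step;
        pm{p}(idx) = pm{p}(idx) - step;
        num{p}(idx) = (loss_fn(pp{:}) - loss_fn(pm{:})) / (2*step);
    end
end

% Relative error between analytic and numerical gradients.
rel = @(a, b) norm(a(:) - b(:)) / max(norm(a(:) + b(:)), 1e-12);
analytic = {dW1, db1, dW2, db2};
names = {'dW1', 'db1', 'dW2', 'db2'};
errs = zeros(1, 4);
for p = 1:4
    errs(p) = rel(analytic{p}, num{p});
    fprintf('%s relative error: %.2e\n', names{p}, errs(p));
end
```

If the formulas are right, every relative error should be far below 1e-5; a large value for one parameter points at a mistake in that gradient. The check can become unreliable if some pre-activation f1 lies within step of 0, because ReLU is not differentiable there.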
