10a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim/* Copyright (c) 2008-2011 Octasic Inc.
20a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   Written by Jean-Marc Valin */
30a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim/*
40a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   Redistribution and use in source and binary forms, with or without
50a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   modification, are permitted provided that the following conditions
60a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   are met:
70a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
80a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   - Redistributions of source code must retain the above copyright
90a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   notice, this list of conditions and the following disclaimer.
100a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
110a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   - Redistributions in binary form must reproduce the above copyright
120a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   notice, this list of conditions and the following disclaimer in the
130a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   documentation and/or other materials provided with the distribution.
140a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
150a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
160a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
170a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
180a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
190a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
200a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
210a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
220a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
230a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
240a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
250a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
260a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim*/
270a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
280a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
290a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include "mlp_train.h"
300a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include <stdlib.h>
310a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include <stdio.h>
320a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include <string.h>
330a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include <semaphore.h>
340a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include <pthread.h>
350a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include <time.h>
360a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include <signal.h>
370a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
/* Stop-request flag: starts at 0 and is set to 1 by handler() below.
   NOTE(review): this is a plain int written from a signal handler; the C
   standard only guarantees that for volatile sig_atomic_t objects.  The type
   is left unchanged here because the symbol has external linkage and may be
   declared elsewhere -- confirm before tightening it. */
int stopped = 0;

/* Signal handler: records that signal `sig` arrived by raising `stopped`,
   then installs itself again for the same signal number. */
void handler(int sig)
{
    stopped = 1;
    /* Re-arm: presumably for platforms where signal() resets the disposition
       to the default after delivery -- TODO confirm this is still needed. */
    signal(sig, handler);
}
450a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
460a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia LimMLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int nbSamples)
470a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim{
480a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int i, j, k;
490a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    MLPTrain *net;
500a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int inDim, outDim;
510a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    net = malloc(sizeof(*net));
520a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    net->topo = malloc(nbLayers*sizeof(net->topo[0]));
530a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (i=0;i<nbLayers;i++)
540a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        net->topo[i] = topo[i];
550a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    inDim = topo[0];
560a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    outDim = topo[nbLayers-1];
570a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    net->in_rate = malloc((inDim+1)*sizeof(net->in_rate[0]));
580a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    net->weights = malloc((nbLayers-1)*sizeof(net->weights));
590a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    net->best_weights = malloc((nbLayers-1)*sizeof(net->weights));
600a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (i=0;i<nbLayers-1;i++)
610a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    {
620a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        net->weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
630a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        net->best_weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
640a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    }
650a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double inMean[inDim];
660a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (j=0;j<inDim;j++)
670a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    {
680a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        double std=0;
690a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        inMean[j] = 0;
700a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (i=0;i<nbSamples;i++)
710a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        {
720a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            inMean[j] += inputs[i*inDim+j];
730a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            std += inputs[i*inDim+j]*inputs[i*inDim+j];
740a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        }
750a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        inMean[j] /= nbSamples;
760a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        std /= nbSamples;
770a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        net->in_rate[1+j] = .5/(.0001+std);
780a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        std = std-inMean[j]*inMean[j];
790a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        if (std<.001)
800a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            std = .001;
810a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        std = 1/sqrt(inDim*std);
820a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (k=0;k<topo[1];k++)
830a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            net->weights[0][k*(topo[0]+1)+j+1] = randn(std);
840a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    }
850a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    net->in_rate[0] = 1;
860a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (j=0;j<topo[1];j++)
870a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    {
880a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        double sum = 0;
890a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (k=0;k<inDim;k++)
900a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            sum += inMean[k]*net->weights[0][j*(topo[0]+1)+k+1];
910a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        net->weights[0][j*(topo[0]+1)] = -sum;
920a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    }
930a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (j=0;j<outDim;j++)
940a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    {
950a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        double mean = 0;
960a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        double std;
970a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (i=0;i<nbSamples;i++)
980a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            mean += outputs[i*outDim+j];
990a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        mean /= nbSamples;
1000a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        std = 1/sqrt(topo[nbLayers-2]);
1010a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)] = mean;
1020a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (k=0;k<topo[nbLayers-2];k++)
1030a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)+k+1] = randn(std);
1040a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    }
1050a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    return net;
1060a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim}
1070a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
1080a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#define MAX_NEURONS 100
1090a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#define MAX_OUT 10
1100a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
1110a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Limdouble compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamples, double *W0_grad, double *W1_grad, double *error_rate)
1120a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim{
1130a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int i,j;
1140a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int s;
1150a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int inDim, outDim, hiddenDim;
1160a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int *topo;
1170a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double *W0, *W1;
1180a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double rms=0;
1190a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int W0_size, W1_size;
1200a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double hidden[MAX_NEURONS];
1210a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double netOut[MAX_NEURONS];
1220a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double error[MAX_NEURONS];
1230a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
1240a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    topo = net->topo;
1250a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    inDim = net->topo[0];
1260a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    hiddenDim = net->topo[1];
1270a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    outDim = net->topo[2];
1280a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W0_size = (topo[0]+1)*topo[1];
1290a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W1_size = (topo[1]+1)*topo[2];
1300a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W0 = net->weights[0];
1310a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W1 = net->weights[1];
1320a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    memset(W0_grad, 0, W0_size*sizeof(double));
1330a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    memset(W1_grad, 0, W1_size*sizeof(double));
1340a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (i=0;i<outDim;i++)
1350a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        netOut[i] = outputs[i];
1360a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (i=0;i<outDim;i++)
1370a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        error_rate[i] = 0;
1380a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (s=0;s<nbSamples;s++)
1390a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    {
1400a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        float *in, *out;
1410a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        in = inputs+s*inDim;
1420a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        out = outputs + s*outDim;
1430a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (i=0;i<hiddenDim;i++)
1440a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        {
1450a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            double sum = W0[i*(inDim+1)];
1460a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            for (j=0;j<inDim;j++)
1470a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                sum += W0[i*(inDim+1)+j+1]*in[j];
1480a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            hidden[i] = tansig_approx(sum);
1490a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        }
1500a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (i=0;i<outDim;i++)
1510a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        {
1520a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            double sum = W1[i*(hiddenDim+1)];
1530a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            for (j=0;j<hiddenDim;j++)
1540a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                sum += W1[i*(hiddenDim+1)+j+1]*hidden[j];
1550a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            netOut[i] = tansig_approx(sum);
1560a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            error[i] = out[i] - netOut[i];
1570a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            rms += error[i]*error[i];
1580a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            error_rate[i] += fabs(error[i])>1;
1590a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            /*error[i] = error[i]/(1+fabs(error[i]));*/
1600a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        }
1610a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        /* Back-propagate error */
1620a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (i=0;i<outDim;i++)
1630a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        {
1640a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            float grad = 1-netOut[i]*netOut[i];
1650a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            W1_grad[i*(hiddenDim+1)] += error[i]*grad;
1660a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            for (j=0;j<hiddenDim;j++)
1670a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                W1_grad[i*(hiddenDim+1)+j+1] += grad*error[i]*hidden[j];
1680a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        }
1690a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (i=0;i<hiddenDim;i++)
1700a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        {
1710a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            double grad;
1720a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            grad = 0;
1730a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            for (j=0;j<outDim;j++)
1740a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                grad += error[j]*W1[j*(hiddenDim+1)+i+1];
1750a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            grad *= 1-hidden[i]*hidden[i];
1760a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            W0_grad[i*(inDim+1)] += grad;
1770a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            for (j=0;j<inDim;j++)
1780a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                W0_grad[i*(inDim+1)+j+1] += grad*in[j];
1790a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        }
1800a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    }
1810a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    return rms;
1820a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim}
1830a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
/* Number of gradient worker threads, and the length of the semaphore arrays
   below. */
#define NB_THREADS 8

/* sem_begin[i] is waited on by worker i at the top of its loop; the training
   loop posts it to release that worker. */
sem_t sem_begin[NB_THREADS];
/* sem_end[i] is posted by worker i after each gradient computation; the
   training loop waits on it before reading that worker's results. */
sem_t sem_end[NB_THREADS];
1880a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
/* Per-worker job descriptor handed to gradient_thread_process(). */
struct GradientArg {
    /* Index of this worker into sem_begin[] / sem_end[] */
    int id;
    /* Checked by the worker each time it is woken; non-zero makes it exit */
    int done;
    /* Network passed through to compute_gradient() */
    MLPTrain *net;
    /* First input row of this worker's share of the samples */
    float *inputs;
    /* First target row of this worker's share of the samples */
    float *outputs;
    /* Number of sample rows in this worker's share */
    int nbSamples;
    /* Set by the worker thread to its own gradient buffers (they are local
       to the worker function, so only valid while that thread is running) */
    double *W0_grad;
    double *W1_grad;
    /* Value returned by the worker's last compute_gradient() call */
    double rms;
    /* Per-output counters filled in by compute_gradient() */
    double error_rate[MAX_OUT];
};
2010a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
2020a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Limvoid *gradient_thread_process(void *_arg)
2030a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim{
2040a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int W0_size, W1_size;
2050a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    struct GradientArg *arg = _arg;
2060a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int *topo = arg->net->topo;
2070a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W0_size = (topo[0]+1)*topo[1];
2080a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W1_size = (topo[1]+1)*topo[2];
2090a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double W0_grad[W0_size];
2100a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double W1_grad[W1_size];
2110a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    arg->W0_grad = W0_grad;
2120a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    arg->W1_grad = W1_grad;
2130a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    while (1)
2140a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    {
2150a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        sem_wait(&sem_begin[arg->id]);
2160a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        if (arg->done)
2170a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            break;
2180a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, arg->error_rate);
2190a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        sem_post(&sem_end[arg->id]);
2200a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    }
2210a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    fprintf(stderr, "done\n");
2220a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    return NULL;
2230a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim}
2240a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
2250a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Limfloat mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSamples, int nbEpoch, float rate)
2260a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim{
2270a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int i, j;
2280a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int e;
2290a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    float best_rms = 1e10;
2300a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int inDim, outDim, hiddenDim;
2310a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int *topo;
2320a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double *W0, *W1, *best_W0, *best_W1;
2330a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double *W0_old, *W1_old;
2340a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double *W0_old2, *W1_old2;
2350a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double *W0_grad, *W1_grad;
2360a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double *W0_oldgrad, *W1_oldgrad;
2370a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double *W0_rate, *W1_rate;
2380a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    double *best_W0_rate, *best_W1_rate;
2390a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int W0_size, W1_size;
2400a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    topo = net->topo;
2410a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W0_size = (topo[0]+1)*topo[1];
2420a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W1_size = (topo[1]+1)*topo[2];
2430a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    struct GradientArg args[NB_THREADS];
2440a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    pthread_t thread[NB_THREADS];
2450a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int samplePerPart = nbSamples/NB_THREADS;
2460a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int count_worse=0;
2470a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int count_retries=0;
2480a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
2490a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    topo = net->topo;
2500a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    inDim = net->topo[0];
2510a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    hiddenDim = net->topo[1];
2520a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    outDim = net->topo[2];
2530a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W0 = net->weights[0];
2540a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W1 = net->weights[1];
2550a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    best_W0 = net->best_weights[0];
2560a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    best_W1 = net->best_weights[1];
2570a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W0_old = malloc(W0_size*sizeof(double));
2580a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W1_old = malloc(W1_size*sizeof(double));
2590a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W0_old2 = malloc(W0_size*sizeof(double));
2600a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W1_old2 = malloc(W1_size*sizeof(double));
2610a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W0_grad = malloc(W0_size*sizeof(double));
2620a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W1_grad = malloc(W1_size*sizeof(double));
2630a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W0_oldgrad = malloc(W0_size*sizeof(double));
2640a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W1_oldgrad = malloc(W1_size*sizeof(double));
2650a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W0_rate = malloc(W0_size*sizeof(double));
2660a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    W1_rate = malloc(W1_size*sizeof(double));
2670a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    best_W0_rate = malloc(W0_size*sizeof(double));
2680a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    best_W1_rate = malloc(W1_size*sizeof(double));
2690a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    memcpy(W0_old, W0, W0_size*sizeof(double));
2700a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    memcpy(W0_old2, W0, W0_size*sizeof(double));
2710a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    memset(W0_grad, 0, W0_size*sizeof(double));
2720a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    memset(W0_oldgrad, 0, W0_size*sizeof(double));
2730a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    memcpy(W1_old, W1, W1_size*sizeof(double));
2740a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    memcpy(W1_old2, W1, W1_size*sizeof(double));
2750a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    memset(W1_grad, 0, W1_size*sizeof(double));
2760a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    memset(W1_oldgrad, 0, W1_size*sizeof(double));
2770a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
2780a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    rate /= nbSamples;
2790a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (i=0;i<hiddenDim;i++)
2800a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (j=0;j<inDim+1;j++)
2810a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            W0_rate[i*(inDim+1)+j] = rate*net->in_rate[j];
2820a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (i=0;i<W1_size;i++)
2830a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        W1_rate[i] = rate;
2840a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
2850a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (i=0;i<NB_THREADS;i++)
2860a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    {
2870a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        args[i].net = net;
2880a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        args[i].inputs = inputs+i*samplePerPart*inDim;
2890a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        args[i].outputs = outputs+i*samplePerPart*outDim;
2900a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        args[i].nbSamples = samplePerPart;
2910a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        args[i].id = i;
2920a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        args[i].done = 0;
2930a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        sem_init(&sem_begin[i], 0, 0);
2940a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        sem_init(&sem_end[i], 0, 0);
2950a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        pthread_create(&thread[i], NULL, gradient_thread_process, &args[i]);
2960a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    }
2970a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (e=0;e<nbEpoch;e++)
2980a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    {
2990a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        double rms=0;
3000a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        double error_rate[2] = {0,0};
3010a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (i=0;i<NB_THREADS;i++)
3020a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        {
3030a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            sem_post(&sem_begin[i]);
3040a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        }
3050a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        memset(W0_grad, 0, W0_size*sizeof(double));
3060a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        memset(W1_grad, 0, W1_size*sizeof(double));
3070a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (i=0;i<NB_THREADS;i++)
3080a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        {
3090a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            sem_wait(&sem_end[i]);
3100a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            rms += args[i].rms;
3110a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            error_rate[0] += args[i].error_rate[0];
3120a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            error_rate[1] += args[i].error_rate[1];
3130a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            for (j=0;j<W0_size;j++)
3140a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                W0_grad[j] += args[i].W0_grad[j];
3150a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            for (j=0;j<W1_size;j++)
3160a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                W1_grad[j] += args[i].W1_grad[j];
3170a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        }
3180a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
3190a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        float mean_rate = 0, min_rate = 1e10;
3200a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        rms = (rms/(outDim*nbSamples));
3210a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        error_rate[0] = (error_rate[0]/(nbSamples));
3220a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        error_rate[1] = (error_rate[1]/(nbSamples));
3230a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        fprintf (stderr, "%f %f (%f %f) ", error_rate[0], error_rate[1], rms, best_rms);
3240a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        if (rms < best_rms)
3250a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        {
3260a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            best_rms = rms;
3270a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            for (i=0;i<W0_size;i++)
3280a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            {
3290a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                best_W0[i] = W0[i];
3300a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                best_W0_rate[i] = W0_rate[i];
3310a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            }
3320a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            for (i=0;i<W1_size;i++)
3330a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            {
3340a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                best_W1[i] = W1[i];
3350a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                best_W1_rate[i] = W1_rate[i];
3360a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            }
3370a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            count_worse=0;
3380a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            count_retries=0;
3390a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        } else {
3400a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            count_worse++;
3410a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            if (count_worse>30)
3420a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            {
3430a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                count_retries++;
3440a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                count_worse=0;
3450a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                for (i=0;i<W0_size;i++)
3460a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                {
3470a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                    W0[i] = best_W0[i];
3480a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                    best_W0_rate[i] *= .7;
3490a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                    if (best_W0_rate[i]<1e-15) best_W0_rate[i]=1e-15;
3500a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                    W0_rate[i] = best_W0_rate[i];
3510a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                    W0_grad[i] = 0;
3520a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                }
3530a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                for (i=0;i<W1_size;i++)
3540a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                {
3550a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                    W1[i] = best_W1[i];
3560a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                    best_W1_rate[i] *= .8;
3570a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                    if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15;
3580a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                    W1_rate[i] = best_W1_rate[i];
3590a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                    W1_grad[i] = 0;
3600a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                }
3610a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            }
3620a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        }
3630a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        if (count_retries>10)
3640a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            break;
3650a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (i=0;i<W0_size;i++)
3660a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        {
3670a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            if (W0_oldgrad[i]*W0_grad[i] > 0)
3680a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                W0_rate[i] *= 1.01;
3690a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            else if (W0_oldgrad[i]*W0_grad[i] < 0)
3700a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                W0_rate[i] *= .9;
3710a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            mean_rate += W0_rate[i];
3720a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            if (W0_rate[i] < min_rate)
3730a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                min_rate = W0_rate[i];
3740a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            if (W0_rate[i] < 1e-15)
3750a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                W0_rate[i] = 1e-15;
3760a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            /*if (W0_rate[i] > .01)
3770a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                W0_rate[i] = .01;*/
3780a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            W0_oldgrad[i] = W0_grad[i];
3790a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            W0_old2[i] = W0_old[i];
3800a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            W0_old[i] = W0[i];
3810a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            W0[i] += W0_grad[i]*W0_rate[i];
3820a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        }
3830a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (i=0;i<W1_size;i++)
3840a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        {
3850a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            if (W1_oldgrad[i]*W1_grad[i] > 0)
3860a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                W1_rate[i] *= 1.01;
3870a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            else if (W1_oldgrad[i]*W1_grad[i] < 0)
3880a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                W1_rate[i] *= .9;
3890a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            mean_rate += W1_rate[i];
3900a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            if (W1_rate[i] < min_rate)
3910a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                min_rate = W1_rate[i];
3920a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            if (W1_rate[i] < 1e-15)
3930a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim                W1_rate[i] = 1e-15;
3940a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            W1_oldgrad[i] = W1_grad[i];
3950a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            W1_old2[i] = W1_old[i];
3960a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            W1_old[i] = W1[i];
3970a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            W1[i] += W1_grad[i]*W1_rate[i];
3980a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        }
3990a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2];
4000a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        fprintf (stderr, "%g %d", mean_rate, e);
4010a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        if (count_retries)
4020a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            fprintf(stderr, " %d", count_retries);
4030a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        fprintf(stderr, "\n");
4040a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        if (stopped)
4050a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            break;
4060a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    }
4070a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (i=0;i<NB_THREADS;i++)
4080a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    {
4090a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        args[i].done = 1;
4100a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        sem_post(&sem_begin[i]);
4110a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        pthread_join(thread[i], NULL);
4120a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        fprintf (stderr, "joined %d\n", i);
4130a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    }
4140a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    free(W0_old);
4150a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    free(W1_old);
4160a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    free(W0_grad);
4170a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    free(W1_grad);
4180a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    free(W0_rate);
4190a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    free(W1_rate);
4200a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    return best_rms;
4210a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim}
4220a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
4230a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Limint main(int argc, char **argv)
4240a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim{
4250a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int i, j;
4260a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int nbInputs;
4270a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int nbOutputs;
4280a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int nbHidden;
4290a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int nbSamples;
4300a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int nbEpoch;
4310a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int nbRealInputs;
4320a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    unsigned int seed;
4330a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int ret;
4340a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    float rms;
4350a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    float *inputs;
4360a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    float *outputs;
4370a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    if (argc!=6)
4380a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    {
4390a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        fprintf (stderr, "usage: mlp_train <inputs> <hidden> <outputs> <nb samples> <nb epoch>\n");
4400a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        return 1;
4410a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    }
4420a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    nbInputs = atoi(argv[1]);
4430a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    nbHidden = atoi(argv[2]);
4440a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    nbOutputs = atoi(argv[3]);
4450a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    nbSamples = atoi(argv[4]);
4460a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    nbEpoch = atoi(argv[5]);
4470a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    nbRealInputs = nbInputs;
4480a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    inputs = malloc(nbInputs*nbSamples*sizeof(*inputs));
4490a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    outputs = malloc(nbOutputs*nbSamples*sizeof(*outputs));
4500a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
4510a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    seed = time(NULL);
4520a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    /*seed = 1361480659;*/
4530a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    fprintf (stderr, "Seed is %u\n", seed);
4540a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    srand(seed);
4550a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    build_tansig_table();
4560a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    signal(SIGTERM, handler);
4570a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    signal(SIGINT, handler);
4580a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    signal(SIGHUP, handler);
4590a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (i=0;i<nbSamples;i++)
4600a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    {
4610a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (j=0;j<nbRealInputs;j++)
4620a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            ret = scanf(" %f", &inputs[i*nbInputs+j]);
4630a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        for (j=0;j<nbOutputs;j++)
4640a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            ret = scanf(" %f", &outputs[i*nbOutputs+j]);
4650a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        if (feof(stdin))
4660a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        {
4670a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            nbSamples = i;
4680a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            break;
4690a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        }
4700a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    }
4710a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    int topo[3] = {nbInputs, nbHidden, nbOutputs};
4720a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    MLPTrain *net;
4730a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim
4740a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    fprintf (stderr, "Got %d samples\n", nbSamples);
4750a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    net = mlp_init(topo, 3, inputs, outputs, nbSamples);
4760a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    rms = mlp_train_backprop(net, inputs, outputs, nbSamples, nbEpoch, 1);
4770a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    printf ("#include \"mlp.h\"\n\n");
4780a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    printf ("/* RMS error was %f, seed was %u */\n\n", rms, seed);
4790a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    printf ("static const float weights[%d] = {\n", (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]);
4800a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    printf ("\n/* hidden layer */\n");
4810a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (i=0;i<(topo[0]+1)*topo[1];i++)
4820a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    {
4830a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        printf ("%gf, ", net->weights[0][i]);
4840a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        if (i%5==4)
4850a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            printf("\n");
4860a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    }
4870a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    printf ("\n/* output layer */\n");
4880a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    for (i=0;i<(topo[1]+1)*topo[2];i++)
4890a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    {
4900a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        printf ("%g, ", net->weights[1][i]);
4910a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim        if (i%5==4)
4920a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim            printf("\n");
4930a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    }
4940a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    printf ("};\n\n");
4950a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    printf ("static const int topo[3] = {%d, %d, %d};\n\n", topo[0], topo[1], topo[2]);
4960a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    printf ("const MLP net = {\n");
4970a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    printf ("\t3,\n");
4980a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    printf ("\ttopo,\n");
4990a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    printf ("\tweights\n};\n");
5000a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim    return 0;
5010a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim}
502