10a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim/* Copyright (c) 2008-2011 Octasic Inc. 20a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim Written by Jean-Marc Valin */ 30a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim/* 40a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim Redistribution and use in source and binary forms, with or without 50a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim modification, are permitted provided that the following conditions 60a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim are met: 70a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 80a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim - Redistributions of source code must retain the above copyright 90a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim notice, this list of conditions and the following disclaimer. 100a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 110a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim - Redistributions in binary form must reproduce the above copyright 120a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim notice, this list of conditions and the following disclaimer in the 130a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim documentation and/or other materials provided with the distribution. 140a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 150a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 160a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 170a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 180a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 190a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 200a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 210a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 220a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 230a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 240a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 250a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 260a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim*/ 270a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 280a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 290a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include "mlp_train.h" 300a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include <stdlib.h> 310a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include <stdio.h> 320a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include <string.h> 330a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include <semaphore.h> 340a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include <pthread.h> 350a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include <time.h> 360a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#include <signal.h> 370a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 380a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Limint stopped = 0; 390a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 400a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Limvoid handler(int sig) 410a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim{ 420a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim stopped = 1; 430a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim signal(sig, handler); 440a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim} 450a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 460a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia LimMLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int nbSamples) 470a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim{ 480a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int i, j, k; 490a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim MLPTrain *net; 500a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int inDim, outDim; 510a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net = malloc(sizeof(*net)); 520a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net->topo = malloc(nbLayers*sizeof(net->topo[0])); 530a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<nbLayers;i++) 540a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net->topo[i] = topo[i]; 550a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim inDim = topo[0]; 560a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim outDim = topo[nbLayers-1]; 570a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net->in_rate = malloc((inDim+1)*sizeof(net->in_rate[0])); 580a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net->weights = malloc((nbLayers-1)*sizeof(net->weights)); 590a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net->best_weights = malloc((nbLayers-1)*sizeof(net->weights)); 600a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<nbLayers-1;i++) 610a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 620a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net->weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0])); 630a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net->best_weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0])); 640a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 650a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double inMean[inDim]; 660a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (j=0;j<inDim;j++) 670a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 680a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double std=0; 690a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim inMean[j] = 0; 700a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<nbSamples;i++) 710a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 720a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim inMean[j] += inputs[i*inDim+j]; 730a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim std += inputs[i*inDim+j]*inputs[i*inDim+j]; 740a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 750a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim inMean[j] /= nbSamples; 760a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim std /= nbSamples; 770a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net->in_rate[1+j] = .5/(.0001+std); 780a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim std = std-inMean[j]*inMean[j]; 790a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (std<.001) 800a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim std = .001; 810a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim std = 1/sqrt(inDim*std); 820a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (k=0;k<topo[1];k++) 830a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net->weights[0][k*(topo[0]+1)+j+1] = randn(std); 840a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 850a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net->in_rate[0] = 1; 860a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (j=0;j<topo[1];j++) 870a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 880a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double sum = 0; 890a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (k=0;k<inDim;k++) 900a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim sum += inMean[k]*net->weights[0][j*(topo[0]+1)+k+1]; 910a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net->weights[0][j*(topo[0]+1)] = -sum; 920a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 930a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (j=0;j<outDim;j++) 940a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 950a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double mean = 0; 960a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double std; 970a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<nbSamples;i++) 980a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim mean += outputs[i*outDim+j]; 990a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim mean /= nbSamples; 1000a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim std = 1/sqrt(topo[nbLayers-2]); 1010a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)] = mean; 1020a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (k=0;k<topo[nbLayers-2];k++) 1030a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)+k+1] = randn(std); 1040a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 1050a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim return net; 1060a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim} 1070a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 1080a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#define MAX_NEURONS 100 1090a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#define MAX_OUT 10 1100a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 1110a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Limdouble compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamples, double *W0_grad, double *W1_grad, double *error_rate) 1120a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim{ 1130a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int i,j; 1140a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int s; 1150a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int inDim, outDim, hiddenDim; 1160a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int *topo; 1170a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double *W0, *W1; 1180a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double rms=0; 1190a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int W0_size, W1_size; 1200a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double hidden[MAX_NEURONS]; 1210a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double netOut[MAX_NEURONS]; 1220a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double error[MAX_NEURONS]; 1230a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 1240a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim topo = net->topo; 1250a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim inDim = net->topo[0]; 1260a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim hiddenDim = net->topo[1]; 1270a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim outDim = net->topo[2]; 1280a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_size = (topo[0]+1)*topo[1]; 1290a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_size = (topo[1]+1)*topo[2]; 1300a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0 = net->weights[0]; 1310a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1 = net->weights[1]; 1320a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim memset(W0_grad, 0, W0_size*sizeof(double)); 1330a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim memset(W1_grad, 0, W1_size*sizeof(double)); 1340a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<outDim;i++) 1350a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim netOut[i] = outputs[i]; 1360a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<outDim;i++) 1370a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim error_rate[i] = 0; 1380a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (s=0;s<nbSamples;s++) 1390a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 1400a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim float *in, *out; 1410a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim in = inputs+s*inDim; 1420a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim out = outputs + s*outDim; 1430a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<hiddenDim;i++) 1440a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 1450a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double sum = W0[i*(inDim+1)]; 1460a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (j=0;j<inDim;j++) 1470a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim sum += W0[i*(inDim+1)+j+1]*in[j]; 1480a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim hidden[i] = tansig_approx(sum); 1490a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 1500a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<outDim;i++) 1510a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 1520a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double sum = W1[i*(hiddenDim+1)]; 1530a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (j=0;j<hiddenDim;j++) 1540a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim sum += W1[i*(hiddenDim+1)+j+1]*hidden[j]; 1550a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim netOut[i] = tansig_approx(sum); 1560a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim error[i] = out[i] - netOut[i]; 1570a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim rms += error[i]*error[i]; 1580a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim error_rate[i] += fabs(error[i])>1; 1590a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim /*error[i] = error[i]/(1+fabs(error[i]));*/ 1600a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 1610a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim /* Back-propagate error */ 1620a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<outDim;i++) 1630a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 1640a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim float grad = 1-netOut[i]*netOut[i]; 1650a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_grad[i*(hiddenDim+1)] += error[i]*grad; 1660a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (j=0;j<hiddenDim;j++) 1670a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_grad[i*(hiddenDim+1)+j+1] += grad*error[i]*hidden[j]; 1680a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 1690a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<hiddenDim;i++) 1700a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 1710a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double grad; 1720a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim grad = 0; 1730a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (j=0;j<outDim;j++) 1740a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim grad += error[j]*W1[j*(hiddenDim+1)+i+1]; 1750a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim grad *= 1-hidden[i]*hidden[i]; 1760a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_grad[i*(inDim+1)] += grad; 1770a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (j=0;j<inDim;j++) 1780a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_grad[i*(inDim+1)+j+1] += grad*in[j]; 1790a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 1800a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 1810a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim return rms; 1820a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim} 1830a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 1840a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim#define NB_THREADS 8 1850a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 1860a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Limsem_t sem_begin[NB_THREADS]; 1870a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Limsem_t sem_end[NB_THREADS]; 1880a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 1890a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Limstruct GradientArg { 1900a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int id; 1910a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int done; 1920a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim MLPTrain *net; 1930a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim float *inputs; 1940a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim float *outputs; 1950a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int nbSamples; 1960a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double *W0_grad; 1970a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double *W1_grad; 1980a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double rms; 1990a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double error_rate[MAX_OUT]; 2000a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim}; 2010a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 2020a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Limvoid *gradient_thread_process(void *_arg) 2030a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim{ 2040a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int W0_size, W1_size; 2050a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim struct GradientArg *arg = _arg; 2060a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int *topo = arg->net->topo; 2070a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_size = (topo[0]+1)*topo[1]; 2080a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_size = (topo[1]+1)*topo[2]; 2090a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double W0_grad[W0_size]; 2100a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double W1_grad[W1_size]; 2110a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim arg->W0_grad = W0_grad; 2120a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim arg->W1_grad = W1_grad; 2130a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim while (1) 2140a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 2150a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim sem_wait(&sem_begin[arg->id]); 2160a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (arg->done) 2170a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim break; 2180a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, arg->error_rate); 2190a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim sem_post(&sem_end[arg->id]); 2200a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 2210a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim fprintf(stderr, "done\n"); 2220a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim return NULL; 2230a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim} 2240a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 2250a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Limfloat mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSamples, int nbEpoch, float rate) 2260a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim{ 2270a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int i, j; 2280a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int e; 2290a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim float best_rms = 1e10; 2300a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int inDim, outDim, hiddenDim; 2310a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int *topo; 2320a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double *W0, *W1, *best_W0, *best_W1; 2330a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double *W0_old, *W1_old; 2340a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double *W0_old2, *W1_old2; 2350a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double *W0_grad, *W1_grad; 2360a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double *W0_oldgrad, *W1_oldgrad; 2370a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double *W0_rate, *W1_rate; 2380a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double *best_W0_rate, *best_W1_rate; 2390a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int W0_size, W1_size; 2400a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim topo = net->topo; 2410a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_size = (topo[0]+1)*topo[1]; 2420a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_size = (topo[1]+1)*topo[2]; 2430a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim struct GradientArg args[NB_THREADS]; 2440a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim pthread_t thread[NB_THREADS]; 2450a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int samplePerPart = nbSamples/NB_THREADS; 2460a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int count_worse=0; 2470a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int count_retries=0; 2480a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 2490a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim topo = net->topo; 2500a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim inDim = net->topo[0]; 2510a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim hiddenDim = net->topo[1]; 2520a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim outDim = net->topo[2]; 2530a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0 = net->weights[0]; 2540a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1 = net->weights[1]; 2550a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim best_W0 = net->best_weights[0]; 2560a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim best_W1 = net->best_weights[1]; 2570a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_old = malloc(W0_size*sizeof(double)); 2580a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_old = malloc(W1_size*sizeof(double)); 2590a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_old2 = malloc(W0_size*sizeof(double)); 2600a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_old2 = malloc(W1_size*sizeof(double)); 2610a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_grad = malloc(W0_size*sizeof(double)); 2620a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_grad = malloc(W1_size*sizeof(double)); 2630a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_oldgrad = malloc(W0_size*sizeof(double)); 2640a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_oldgrad = malloc(W1_size*sizeof(double)); 2650a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_rate = malloc(W0_size*sizeof(double)); 2660a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_rate = malloc(W1_size*sizeof(double)); 2670a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim best_W0_rate = malloc(W0_size*sizeof(double)); 2680a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim best_W1_rate = malloc(W1_size*sizeof(double)); 2690a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim memcpy(W0_old, W0, W0_size*sizeof(double)); 2700a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim memcpy(W0_old2, W0, W0_size*sizeof(double)); 2710a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim memset(W0_grad, 0, W0_size*sizeof(double)); 2720a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim memset(W0_oldgrad, 0, W0_size*sizeof(double)); 2730a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim memcpy(W1_old, W1, W1_size*sizeof(double)); 2740a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim memcpy(W1_old2, W1, W1_size*sizeof(double)); 2750a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim memset(W1_grad, 0, W1_size*sizeof(double)); 2760a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim memset(W1_oldgrad, 0, W1_size*sizeof(double)); 2770a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 2780a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim rate /= nbSamples; 2790a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<hiddenDim;i++) 2800a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (j=0;j<inDim+1;j++) 2810a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_rate[i*(inDim+1)+j] = rate*net->in_rate[j]; 2820a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<W1_size;i++) 2830a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_rate[i] = rate; 2840a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 2850a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<NB_THREADS;i++) 2860a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 2870a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim args[i].net = net; 2880a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim args[i].inputs = inputs+i*samplePerPart*inDim; 2890a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim args[i].outputs = outputs+i*samplePerPart*outDim; 2900a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim args[i].nbSamples = samplePerPart; 2910a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim args[i].id = i; 2920a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim args[i].done = 0; 2930a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim sem_init(&sem_begin[i], 0, 0); 2940a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim sem_init(&sem_end[i], 0, 0); 2950a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim pthread_create(&thread[i], NULL, gradient_thread_process, &args[i]); 2960a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 2970a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (e=0;e<nbEpoch;e++) 2980a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 2990a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double rms=0; 3000a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim double error_rate[2] = {0,0}; 3010a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<NB_THREADS;i++) 3020a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 3030a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim sem_post(&sem_begin[i]); 3040a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 3050a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim memset(W0_grad, 0, W0_size*sizeof(double)); 3060a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim memset(W1_grad, 0, W1_size*sizeof(double)); 3070a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<NB_THREADS;i++) 3080a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 3090a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim sem_wait(&sem_end[i]); 3100a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim rms += args[i].rms; 3110a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim error_rate[0] += args[i].error_rate[0]; 3120a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim error_rate[1] += args[i].error_rate[1]; 3130a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (j=0;j<W0_size;j++) 3140a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_grad[j] += args[i].W0_grad[j]; 3150a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (j=0;j<W1_size;j++) 3160a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_grad[j] += args[i].W1_grad[j]; 3170a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 3180a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 3190a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim float mean_rate = 0, min_rate = 1e10; 3200a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim rms = (rms/(outDim*nbSamples)); 3210a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim error_rate[0] = (error_rate[0]/(nbSamples)); 3220a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim error_rate[1] = (error_rate[1]/(nbSamples)); 3230a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim fprintf (stderr, "%f %f (%f %f) ", error_rate[0], error_rate[1], rms, best_rms); 3240a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (rms < best_rms) 3250a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 3260a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim best_rms = rms; 3270a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<W0_size;i++) 3280a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 3290a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim best_W0[i] = W0[i]; 3300a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim best_W0_rate[i] = W0_rate[i]; 3310a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 3320a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<W1_size;i++) 3330a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 3340a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim best_W1[i] = W1[i]; 3350a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim best_W1_rate[i] = W1_rate[i]; 3360a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 3370a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim count_worse=0; 3380a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim count_retries=0; 3390a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } else { 3400a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim count_worse++; 3410a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (count_worse>30) 3420a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 3430a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim count_retries++; 3440a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim count_worse=0; 3450a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<W0_size;i++) 3460a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 3470a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0[i] = best_W0[i]; 3480a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim best_W0_rate[i] *= .7; 3490a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (best_W0_rate[i]<1e-15) best_W0_rate[i]=1e-15; 3500a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_rate[i] = best_W0_rate[i]; 3510a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_grad[i] = 0; 3520a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 3530a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<W1_size;i++) 3540a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 3550a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1[i] = best_W1[i]; 3560a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim best_W1_rate[i] *= .8; 3570a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15; 3580a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_rate[i] = best_W1_rate[i]; 3590a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_grad[i] = 0; 3600a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 3610a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 3620a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 3630a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (count_retries>10) 3640a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim break; 3650a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<W0_size;i++) 3660a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 3670a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (W0_oldgrad[i]*W0_grad[i] > 0) 3680a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_rate[i] *= 1.01; 3690a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim else if (W0_oldgrad[i]*W0_grad[i] < 0) 3700a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_rate[i] *= .9; 3710a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim mean_rate += W0_rate[i]; 3720a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (W0_rate[i] < min_rate) 3730a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim min_rate = W0_rate[i]; 3740a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (W0_rate[i] < 1e-15) 3750a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_rate[i] = 1e-15; 3760a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim /*if (W0_rate[i] > .01) 3770a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_rate[i] = .01;*/ 3780a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_oldgrad[i] = W0_grad[i]; 3790a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_old2[i] = W0_old[i]; 3800a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0_old[i] = W0[i]; 3810a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W0[i] += W0_grad[i]*W0_rate[i]; 3820a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 3830a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<W1_size;i++) 3840a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 3850a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (W1_oldgrad[i]*W1_grad[i] > 0) 3860a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_rate[i] *= 1.01; 3870a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim else if (W1_oldgrad[i]*W1_grad[i] < 0) 3880a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_rate[i] *= .9; 3890a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim mean_rate += W1_rate[i]; 3900a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (W1_rate[i] < min_rate) 3910a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim min_rate = W1_rate[i]; 3920a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (W1_rate[i] < 1e-15) 3930a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_rate[i] = 1e-15; 3940a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_oldgrad[i] = W1_grad[i]; 3950a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_old2[i] = W1_old[i]; 3960a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1_old[i] = W1[i]; 3970a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim W1[i] += W1_grad[i]*W1_rate[i]; 3980a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 3990a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]; 4000a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim fprintf (stderr, "%g %d", mean_rate, e); 4010a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (count_retries) 4020a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim fprintf(stderr, " %d", count_retries); 4030a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim fprintf(stderr, "\n"); 4040a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (stopped) 4050a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim break; 4060a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 4070a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<NB_THREADS;i++) 4080a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 4090a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim args[i].done = 1; 4100a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim sem_post(&sem_begin[i]); 4110a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim pthread_join(thread[i], NULL); 4120a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim fprintf (stderr, "joined %d\n", i); 4130a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 4140a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim free(W0_old); 4150a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim free(W1_old); 4160a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim free(W0_grad); 4170a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim free(W1_grad); 4180a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim free(W0_rate); 4190a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim free(W1_rate); 4200a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim return best_rms; 4210a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim} 4220a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 4230a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Limint main(int argc, char **argv) 4240a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim{ 4250a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int i, j; 4260a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int nbInputs; 4270a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int nbOutputs; 4280a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int nbHidden; 4290a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int nbSamples; 4300a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int nbEpoch; 4310a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int nbRealInputs; 4320a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim unsigned int seed; 4330a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int ret; 4340a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim float rms; 4350a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim float *inputs; 4360a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim float *outputs; 4370a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (argc!=6) 4380a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 4390a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim fprintf (stderr, "usage: mlp_train <inputs> <hidden> <outputs> <nb samples> <nb epoch>\n"); 4400a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim return 1; 4410a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 4420a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim nbInputs = atoi(argv[1]); 4430a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim nbHidden = atoi(argv[2]); 4440a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim nbOutputs = atoi(argv[3]); 4450a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim nbSamples = atoi(argv[4]); 4460a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim nbEpoch = atoi(argv[5]); 4470a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim nbRealInputs = nbInputs; 4480a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim inputs = malloc(nbInputs*nbSamples*sizeof(*inputs)); 4490a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim outputs = malloc(nbOutputs*nbSamples*sizeof(*outputs)); 4500a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 4510a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim seed = time(NULL); 4520a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim /*seed = 1361480659;*/ 4530a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim fprintf (stderr, "Seed is %u\n", seed); 4540a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim srand(seed); 4550a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim build_tansig_table(); 4560a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim signal(SIGTERM, handler); 4570a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim signal(SIGINT, handler); 4580a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim signal(SIGHUP, handler); 4590a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<nbSamples;i++) 4600a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 4610a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (j=0;j<nbRealInputs;j++) 4620a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim ret = scanf(" %f", &inputs[i*nbInputs+j]); 4630a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (j=0;j<nbOutputs;j++) 4640a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim ret = scanf(" %f", &outputs[i*nbOutputs+j]); 4650a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (feof(stdin)) 4660a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 4670a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim nbSamples = i; 4680a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim break; 4690a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 4700a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 4710a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim int topo[3] = {nbInputs, nbHidden, nbOutputs}; 4720a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim MLPTrain *net; 4730a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim 4740a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim fprintf (stderr, "Got %d samples\n", nbSamples); 4750a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim net = mlp_init(topo, 3, inputs, outputs, nbSamples); 4760a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim rms = mlp_train_backprop(net, inputs, outputs, nbSamples, nbEpoch, 1); 4770a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf ("#include \"mlp.h\"\n\n"); 4780a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf ("/* RMS error was %f, seed was %u */\n\n", rms, seed); 4790a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf ("static const float weights[%d] = {\n", (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]); 4800a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf ("\n/* hidden layer */\n"); 4810a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<(topo[0]+1)*topo[1];i++) 4820a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 4830a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf ("%gf, ", net->weights[0][i]); 4840a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (i%5==4) 4850a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf("\n"); 4860a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 4870a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf ("\n/* output layer */\n"); 4880a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim for (i=0;i<(topo[1]+1)*topo[2];i++) 4890a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim { 4900a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf ("%g, ", net->weights[1][i]); 4910a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim if (i%5==4) 4920a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf("\n"); 4930a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim } 4940a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf ("};\n\n"); 4950a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf ("static const int topo[3] = {%d, %d, %d};\n\n", topo[0], topo[1], topo[2]); 4960a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf ("const MLP net = {\n"); 4970a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf ("\t3,\n"); 4980a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf ("\ttopo,\n"); 4990a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim printf ("\tweights\n};\n"); 5000a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim return 0; 5010a1406acbe87c63044e9da7e0ab41bcbfa704f3dFelicia Lim} 502