1e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org/* Copyright (c) 2008-2011 Octasic Inc. 2e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org Written by Jean-Marc Valin */ 3e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org/* 4e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org Redistribution and use in source and binary forms, with or without 5e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org modification, are permitted provided that the following conditions 6e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org are met: 7e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 8e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org - Redistributions of source code must retain the above copyright 9e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org notice, this list of conditions and the following disclaimer. 10e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 11e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org - Redistributions in binary form must reproduce the above copyright 12e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org notice, this list of conditions and the following disclaimer in the 13e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org documentation and/or other materials provided with the distribution. 14e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 15e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 19e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org*/ 27e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 28e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 29e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include "mlp_train.h" 30e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include <stdlib.h> 31e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include <stdio.h> 32e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include <string.h> 33e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include <semaphore.h> 34e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include <pthread.h> 35e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include <time.h> 36e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include <signal.h> 37e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 38e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgint stopped = 0; 39e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 40e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgvoid handler(int sig) 41e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org{ 42e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org stopped = 1; 43e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org signal(sig, handler); 44e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org} 45e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 46e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgMLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int nbSamples) 47e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org{ 48e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int i, j, k; 49e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org MLPTrain *net; 50e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int inDim, outDim; 51e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net = malloc(sizeof(*net)); 52e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net->topo = malloc(nbLayers*sizeof(net->topo[0])); 53e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<nbLayers;i++) 54e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net->topo[i] = topo[i]; 55e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org inDim = topo[0]; 56e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org outDim = topo[nbLayers-1]; 57e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net->in_rate = malloc((inDim+1)*sizeof(net->in_rate[0])); 58e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net->weights = malloc((nbLayers-1)*sizeof(net->weights)); 59e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net->best_weights = malloc((nbLayers-1)*sizeof(net->weights)); 60e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<nbLayers-1;i++) 61e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 62e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net->weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0])); 63e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net->best_weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0])); 64e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 65e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double inMean[inDim]; 66e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (j=0;j<inDim;j++) 67e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 68e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double std=0; 69e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org inMean[j] = 0; 70e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<nbSamples;i++) 71e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 72e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org inMean[j] += inputs[i*inDim+j]; 73e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org std += inputs[i*inDim+j]*inputs[i*inDim+j]; 74e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 75e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org inMean[j] /= nbSamples; 76e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org std /= nbSamples; 77e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net->in_rate[1+j] = .5/(.0001+std); 78e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org std = std-inMean[j]*inMean[j]; 79e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (std<.001) 80e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org std = .001; 81e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org std = 1/sqrt(inDim*std); 82e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (k=0;k<topo[1];k++) 83e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net->weights[0][k*(topo[0]+1)+j+1] = randn(std); 84e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 85e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net->in_rate[0] = 1; 86e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (j=0;j<topo[1];j++) 87e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 88e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double sum = 0; 89e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (k=0;k<inDim;k++) 90e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org sum += inMean[k]*net->weights[0][j*(topo[0]+1)+k+1]; 91e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net->weights[0][j*(topo[0]+1)] = -sum; 92e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 93e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (j=0;j<outDim;j++) 94e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 95e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double mean = 0; 96e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double std; 97e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<nbSamples;i++) 98e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org mean += outputs[i*outDim+j]; 99e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org mean /= nbSamples; 100e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org std = 1/sqrt(topo[nbLayers-2]); 101e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)] = mean; 102e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (k=0;k<topo[nbLayers-2];k++) 103e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)+k+1] = randn(std); 104e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 105e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org return net; 106e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org} 107e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 108e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#define MAX_NEURONS 100 109e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#define MAX_OUT 10 110e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 111e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgdouble compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamples, double *W0_grad, double *W1_grad, double *error_rate) 112e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org{ 113e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int i,j; 114e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int s; 115e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int inDim, outDim, hiddenDim; 116e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int *topo; 117e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double *W0, *W1; 118e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double rms=0; 119e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int W0_size, W1_size; 120e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double hidden[MAX_NEURONS]; 121e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double netOut[MAX_NEURONS]; 122e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double error[MAX_NEURONS]; 123e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 124e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<outDim;i++) 125e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org error_rate[i] = 0; 126e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org topo = net->topo; 127e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org inDim = net->topo[0]; 128e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org hiddenDim = net->topo[1]; 129e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org outDim = net->topo[2]; 130e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_size = (topo[0]+1)*topo[1]; 131e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_size = (topo[1]+1)*topo[2]; 132e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0 = net->weights[0]; 133e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1 = net->weights[1]; 134e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org memset(W0_grad, 0, W0_size*sizeof(double)); 135e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org memset(W1_grad, 0, W1_size*sizeof(double)); 136e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<outDim;i++) 137e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org netOut[i] = outputs[i]; 138e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (s=0;s<nbSamples;s++) 139e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 140e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org float *in, *out; 141e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org in = inputs+s*inDim; 142e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org out = outputs + s*outDim; 143e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<hiddenDim;i++) 144e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 145e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double sum = W0[i*(inDim+1)]; 146e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (j=0;j<inDim;j++) 147e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org sum += W0[i*(inDim+1)+j+1]*in[j]; 148e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org hidden[i] = tansig_approx(sum); 149e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 150e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<outDim;i++) 151e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 152e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double sum = W1[i*(hiddenDim+1)]; 153e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (j=0;j<hiddenDim;j++) 154e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org sum += W1[i*(hiddenDim+1)+j+1]*hidden[j]; 155e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org netOut[i] = tansig_approx(sum); 156e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org error[i] = out[i] - netOut[i]; 157e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org rms += error[i]*error[i]; 158e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org error_rate[i] += fabs(error[i])>1; 159e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org /*error[i] = error[i]/(1+fabs(error[i]));*/ 160e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 161e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org /* Back-propagate error */ 162e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<outDim;i++) 163e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 164e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org float grad = 1-netOut[i]*netOut[i]; 165e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_grad[i*(hiddenDim+1)] += error[i]*grad; 166e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (j=0;j<hiddenDim;j++) 167e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_grad[i*(hiddenDim+1)+j+1] += grad*error[i]*hidden[j]; 168e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 169e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<hiddenDim;i++) 170e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 171e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double grad; 172e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org grad = 0; 173e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (j=0;j<outDim;j++) 174e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org grad += error[j]*W1[j*(hiddenDim+1)+i+1]; 175e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org grad *= 1-hidden[i]*hidden[i]; 176e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_grad[i*(inDim+1)] += grad; 177e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (j=0;j<inDim;j++) 178e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_grad[i*(inDim+1)+j+1] += grad*in[j]; 179e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 180e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 181e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org return rms; 182e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org} 183e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 184e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#define NB_THREADS 8 185e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 186e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgsem_t sem_begin[NB_THREADS]; 187e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgsem_t sem_end[NB_THREADS]; 188e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 189e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgstruct GradientArg { 190e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int id; 191e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int done; 192e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org MLPTrain *net; 193e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org float *inputs; 194e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org float *outputs; 195e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int nbSamples; 196e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double *W0_grad; 197e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double *W1_grad; 198e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double rms; 199e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double error_rate[MAX_OUT]; 200e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org}; 201e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 202e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgvoid *gradient_thread_process(void *_arg) 203e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org{ 204e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int W0_size, W1_size; 205e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org struct GradientArg *arg = _arg; 206e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int *topo = arg->net->topo; 207e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_size = (topo[0]+1)*topo[1]; 208e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_size = (topo[1]+1)*topo[2]; 209e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double W0_grad[W0_size]; 210e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double W1_grad[W1_size]; 211e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org arg->W0_grad = W0_grad; 212e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org arg->W1_grad = W1_grad; 213e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org while (1) 214e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 215e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org sem_wait(&sem_begin[arg->id]); 216e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (arg->done) 217e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org break; 218e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, arg->error_rate); 219e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org sem_post(&sem_end[arg->id]); 220e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 221e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org fprintf(stderr, "done\n"); 222e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org return NULL; 223e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org} 224e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 225e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgfloat mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSamples, int nbEpoch, float rate) 226e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org{ 227e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int i, j; 228e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int e; 229e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org float best_rms = 1e10; 230e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int inDim, outDim, hiddenDim; 231e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int *topo; 232e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double *W0, *W1, *best_W0, *best_W1; 233e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double *W0_old, *W1_old; 234e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double *W0_old2, *W1_old2; 235e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double *W0_grad, *W1_grad; 236e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double *W0_oldgrad, *W1_oldgrad; 237e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double *W0_rate, *W1_rate; 238e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double *best_W0_rate, *best_W1_rate; 239e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int W0_size, W1_size; 240e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org topo = net->topo; 241e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_size = (topo[0]+1)*topo[1]; 242e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_size = (topo[1]+1)*topo[2]; 243e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org struct GradientArg args[NB_THREADS]; 244e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org pthread_t thread[NB_THREADS]; 245e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int samplePerPart = nbSamples/NB_THREADS; 246e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int count_worse=0; 247e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int count_retries=0; 248e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 249e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org topo = net->topo; 250e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org inDim = net->topo[0]; 251e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org hiddenDim = net->topo[1]; 252e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org outDim = net->topo[2]; 253e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0 = net->weights[0]; 254e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1 = net->weights[1]; 255e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org best_W0 = net->best_weights[0]; 256e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org best_W1 = net->best_weights[1]; 257e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_old = malloc(W0_size*sizeof(double)); 258e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_old = malloc(W1_size*sizeof(double)); 259e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_old2 = malloc(W0_size*sizeof(double)); 260e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_old2 = malloc(W1_size*sizeof(double)); 261e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_grad = malloc(W0_size*sizeof(double)); 262e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_grad = malloc(W1_size*sizeof(double)); 263e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_oldgrad = malloc(W0_size*sizeof(double)); 264e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_oldgrad = malloc(W1_size*sizeof(double)); 265e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_rate = malloc(W0_size*sizeof(double)); 266e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_rate = malloc(W1_size*sizeof(double)); 267e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org best_W0_rate = malloc(W0_size*sizeof(double)); 268e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org best_W1_rate = malloc(W1_size*sizeof(double)); 269e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org memcpy(W0_old, W0, W0_size*sizeof(double)); 270e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org memcpy(W0_old2, W0, W0_size*sizeof(double)); 271e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org memset(W0_grad, 0, W0_size*sizeof(double)); 272e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org memset(W0_oldgrad, 0, W0_size*sizeof(double)); 273e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org memcpy(W1_old, W1, W1_size*sizeof(double)); 274e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org memcpy(W1_old2, W1, W1_size*sizeof(double)); 275e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org memset(W1_grad, 0, W1_size*sizeof(double)); 276e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org memset(W1_oldgrad, 0, W1_size*sizeof(double)); 277e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 278e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org rate /= nbSamples; 279e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<hiddenDim;i++) 280e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (j=0;j<inDim+1;j++) 281e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_rate[i*(inDim+1)+j] = rate*net->in_rate[j]; 282e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<W1_size;i++) 283e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_rate[i] = rate; 284e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 285e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<NB_THREADS;i++) 286e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 287e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org args[i].net = net; 288e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org args[i].inputs = inputs+i*samplePerPart*inDim; 289e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org args[i].outputs = outputs+i*samplePerPart*outDim; 290e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org args[i].nbSamples = samplePerPart; 291e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org args[i].id = i; 292e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org args[i].done = 0; 293e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org sem_init(&sem_begin[i], 0, 0); 294e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org sem_init(&sem_end[i], 0, 0); 295e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org pthread_create(&thread[i], NULL, gradient_thread_process, &args[i]); 296e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 297e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (e=0;e<nbEpoch;e++) 298e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 299e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double rms=0; 300e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org double error_rate[2] = {0,0}; 301e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<NB_THREADS;i++) 302e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 303e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org sem_post(&sem_begin[i]); 304e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 305e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org memset(W0_grad, 0, W0_size*sizeof(double)); 306e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org memset(W1_grad, 0, W1_size*sizeof(double)); 307e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<NB_THREADS;i++) 308e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 309e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org sem_wait(&sem_end[i]); 310e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org rms += args[i].rms; 311e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org error_rate[0] += args[i].error_rate[0]; 312e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org error_rate[1] += args[i].error_rate[1]; 313e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (j=0;j<W0_size;j++) 314e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_grad[j] += args[i].W0_grad[j]; 315e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (j=0;j<W1_size;j++) 316e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_grad[j] += args[i].W1_grad[j]; 317e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 318e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 319e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org float mean_rate = 0, min_rate = 1e10; 320e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org rms = (rms/(outDim*nbSamples)); 321e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org error_rate[0] = (error_rate[0]/(nbSamples)); 322e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org error_rate[1] = (error_rate[1]/(nbSamples)); 323e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org fprintf (stderr, "%f %f (%f %f) ", error_rate[0], error_rate[1], rms, best_rms); 324e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (rms < best_rms) 325e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 326e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org best_rms = rms; 327e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<W0_size;i++) 328e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 329e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org best_W0[i] = W0[i]; 330e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org best_W0_rate[i] = W0_rate[i]; 331e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 332e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<W1_size;i++) 333e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 334e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org best_W1[i] = W1[i]; 335e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org best_W1_rate[i] = W1_rate[i]; 336e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 337e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org count_worse=0; 338e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org count_retries=0; 339e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } else { 340e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org count_worse++; 341e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (count_worse>30) 342e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 343e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org count_retries++; 344e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org count_worse=0; 345e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<W0_size;i++) 346e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 347e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0[i] = best_W0[i]; 348e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org best_W0_rate[i] *= .7; 349e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (best_W0_rate[i]<1e-15) best_W0_rate[i]=1e-15; 350e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_rate[i] = best_W0_rate[i]; 351e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_grad[i] = 0; 352e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 353e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<W1_size;i++) 354e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 355e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1[i] = best_W1[i]; 356e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org best_W1_rate[i] *= .8; 357e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15; 358e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_rate[i] = best_W1_rate[i]; 359e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_grad[i] = 0; 360e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 361e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 362e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 363e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (count_retries>10) 364e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org break; 365e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<W0_size;i++) 366e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 367e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (W0_oldgrad[i]*W0_grad[i] > 0) 368e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_rate[i] *= 1.01; 369e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org else if (W0_oldgrad[i]*W0_grad[i] < 0) 370e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_rate[i] *= .9; 371e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org mean_rate += W0_rate[i]; 372e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (W0_rate[i] < min_rate) 373e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org min_rate = W0_rate[i]; 374e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (W0_rate[i] < 1e-15) 375e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_rate[i] = 1e-15; 376e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org /*if (W0_rate[i] > .01) 377e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_rate[i] = .01;*/ 378e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_oldgrad[i] = W0_grad[i]; 379e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_old2[i] = W0_old[i]; 380e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0_old[i] = W0[i]; 381e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W0[i] += W0_grad[i]*W0_rate[i]; 382e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 383e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<W1_size;i++) 384e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 385e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (W1_oldgrad[i]*W1_grad[i] > 0) 386e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_rate[i] *= 1.01; 387e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org else if (W1_oldgrad[i]*W1_grad[i] < 0) 388e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_rate[i] *= .9; 389e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org mean_rate += W1_rate[i]; 390e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (W1_rate[i] < min_rate) 391e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org min_rate = W1_rate[i]; 392e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (W1_rate[i] < 1e-15) 393e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_rate[i] = 1e-15; 394e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_oldgrad[i] = W1_grad[i]; 395e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_old2[i] = W1_old[i]; 396e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1_old[i] = W1[i]; 397e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org W1[i] += W1_grad[i]*W1_rate[i]; 398e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 399e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]; 400e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org fprintf (stderr, "%g %d", mean_rate, e); 401e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (count_retries) 402e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org fprintf(stderr, " %d", count_retries); 403e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org fprintf(stderr, "\n"); 404e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (stopped) 405e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org break; 406e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 407e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<NB_THREADS;i++) 408e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 409e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org args[i].done = 1; 410e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org sem_post(&sem_begin[i]); 411e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org pthread_join(thread[i], NULL); 412e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org fprintf (stderr, "joined %d\n", i); 413e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 414e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org free(W0_old); 415e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org free(W1_old); 416e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org free(W0_grad); 417e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org free(W1_grad); 418e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org free(W0_rate); 419e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org free(W1_rate); 420e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org return best_rms; 421e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org} 422e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 423e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgint main(int argc, char **argv) 424e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org{ 425e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int i, j; 426e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int nbInputs; 427e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int nbOutputs; 428e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int nbHidden; 429e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int nbSamples; 430e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int nbEpoch; 431e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int nbRealInputs; 432e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org unsigned int seed; 433e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int ret; 434e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org float rms; 435e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org float *inputs; 436e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org float *outputs; 437e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (argc!=6) 438e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 439e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org fprintf (stderr, "usage: mlp_train <inputs> <hidden> <outputs> <nb samples> <nb epoch>\n"); 440e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org return 1; 441e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 442e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org nbInputs = atoi(argv[1]); 443e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org nbHidden = atoi(argv[2]); 444e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org nbOutputs = atoi(argv[3]); 445e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org nbSamples = atoi(argv[4]); 446e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org nbEpoch = atoi(argv[5]); 447e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org nbRealInputs = nbInputs; 448e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org inputs = malloc(nbInputs*nbSamples*sizeof(*inputs)); 449e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org outputs = malloc(nbOutputs*nbSamples*sizeof(*outputs)); 450e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 451e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org seed = time(NULL); 452e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org /*seed = 1361480659;*/ 453e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org fprintf (stderr, "Seed is %u\n", seed); 454e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org srand(seed); 455e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org build_tansig_table(); 456e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org signal(SIGTERM, handler); 457e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org signal(SIGINT, handler); 458e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org signal(SIGHUP, handler); 459e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<nbSamples;i++) 460e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 461e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (j=0;j<nbRealInputs;j++) 462e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org ret = scanf(" %f", &inputs[i*nbInputs+j]); 463e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (j=0;j<nbOutputs;j++) 464e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org ret = scanf(" %f", &outputs[i*nbOutputs+j]); 465e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (feof(stdin)) 466e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 467e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org nbSamples = i; 468e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org break; 469e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 470e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 471e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org int topo[3] = {nbInputs, nbHidden, nbOutputs}; 472e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org MLPTrain *net; 473e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org 474e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org fprintf (stderr, "Got %d samples\n", nbSamples); 475e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org net = mlp_init(topo, 3, inputs, outputs, nbSamples); 476e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org rms = mlp_train_backprop(net, inputs, outputs, nbSamples, nbEpoch, 1); 477e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf ("#include \"mlp.h\"\n\n"); 478e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf ("/* RMS error was %f, seed was %u */\n\n", rms, seed); 479e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf ("static const float weights[%d] = {\n", (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]); 480e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf ("\n/* hidden layer */\n"); 481e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<(topo[0]+1)*topo[1];i++) 482e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 483e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf ("%gf, ", net->weights[0][i]); 484e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (i%5==4) 485e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf("\n"); 486e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 487e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf ("\n/* output layer */\n"); 488e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org for (i=0;i<(topo[1]+1)*topo[2];i++) 489e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org { 490e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf ("%g, ", net->weights[1][i]); 491e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org if (i%5==4) 492e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf("\n"); 493e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org } 494e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf ("};\n\n"); 495e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf ("static const int topo[3] = {%d, %d, %d};\n\n", topo[0], topo[1], topo[2]); 496e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf ("const MLP net = {\n"); 497e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf ("\t3,\n"); 498e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf ("\ttopo,\n"); 499e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org printf ("\tweights\n};\n"); 500e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org return 0; 501e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org} 502