/* Copyright (c) 2008-2011 Octasic Inc.
   Written by Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
27e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
28e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
29e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include "mlp_train.h"
30e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include <stdlib.h>
31e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include <stdio.h>
32e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include <string.h>
33e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include <semaphore.h>
34e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include <pthread.h>
35e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include <time.h>
36e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#include <signal.h>
37e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
/* Stop-request flag: starts at 0 and is set to 1 by handler(). */
int stopped = 0;

/* Signal handler.
 *
 * Records that a stop was requested by raising `stopped`, then installs
 * itself again for the same signal number so that a later delivery of
 * `sig` is routed here as well.
 *
 * sig: number of the signal being delivered.
 */
void handler(int sig)
{
	stopped = 1;
	signal(sig, handler);
}
45e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
46e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgMLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int nbSamples)
47e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org{
48e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int i, j, k;
49e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	MLPTrain *net;
50e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int inDim, outDim;
51e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	net = malloc(sizeof(*net));
52e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	net->topo = malloc(nbLayers*sizeof(net->topo[0]));
53e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (i=0;i<nbLayers;i++)
54e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		net->topo[i] = topo[i];
55e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	inDim = topo[0];
56e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	outDim = topo[nbLayers-1];
57e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	net->in_rate = malloc((inDim+1)*sizeof(net->in_rate[0]));
58e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	net->weights = malloc((nbLayers-1)*sizeof(net->weights));
59e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	net->best_weights = malloc((nbLayers-1)*sizeof(net->weights));
60e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (i=0;i<nbLayers-1;i++)
61e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	{
62e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		net->weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
63e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		net->best_weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
64e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	}
65e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double inMean[inDim];
66e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (j=0;j<inDim;j++)
67e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	{
68e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		double std=0;
69e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		inMean[j] = 0;
70e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (i=0;i<nbSamples;i++)
71e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		{
72e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			inMean[j] += inputs[i*inDim+j];
73e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			std += inputs[i*inDim+j]*inputs[i*inDim+j];
74e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		}
75e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		inMean[j] /= nbSamples;
76e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		std /= nbSamples;
77e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		net->in_rate[1+j] = .5/(.0001+std);
78e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		std = std-inMean[j]*inMean[j];
79e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		if (std<.001)
80e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			std = .001;
81e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		std = 1/sqrt(inDim*std);
82e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (k=0;k<topo[1];k++)
83e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			net->weights[0][k*(topo[0]+1)+j+1] = randn(std);
84e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	}
85e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	net->in_rate[0] = 1;
86e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (j=0;j<topo[1];j++)
87e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	{
88e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		double sum = 0;
89e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (k=0;k<inDim;k++)
90e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			sum += inMean[k]*net->weights[0][j*(topo[0]+1)+k+1];
91e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		net->weights[0][j*(topo[0]+1)] = -sum;
92e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	}
93e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (j=0;j<outDim;j++)
94e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	{
95e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		double mean = 0;
96e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		double std;
97e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (i=0;i<nbSamples;i++)
98e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			mean += outputs[i*outDim+j];
99e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		mean /= nbSamples;
100e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		std = 1/sqrt(topo[nbLayers-2]);
101e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)] = mean;
102e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (k=0;k<topo[nbLayers-2];k++)
103e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)+k+1] = randn(std);
104e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	}
105e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	return net;
106e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org}
107e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
108e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#define MAX_NEURONS 100
109e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#define MAX_OUT 10
110e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
111e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgdouble compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamples, double *W0_grad, double *W1_grad, double *error_rate)
112e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org{
113e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int i,j;
114e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int s;
115e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int inDim, outDim, hiddenDim;
116e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int *topo;
117e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double *W0, *W1;
118e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double rms=0;
119e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int W0_size, W1_size;
120e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double hidden[MAX_NEURONS];
121e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double netOut[MAX_NEURONS];
122e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double error[MAX_NEURONS];
123e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
124e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (i=0;i<outDim;i++)
125e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	   error_rate[i] = 0;
126e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	topo = net->topo;
127e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	inDim = net->topo[0];
128e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	hiddenDim = net->topo[1];
129e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	outDim = net->topo[2];
130e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W0_size = (topo[0]+1)*topo[1];
131e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W1_size = (topo[1]+1)*topo[2];
132e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W0 = net->weights[0];
133e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W1 = net->weights[1];
134e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	memset(W0_grad, 0, W0_size*sizeof(double));
135e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	memset(W1_grad, 0, W1_size*sizeof(double));
136e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (i=0;i<outDim;i++)
137e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		netOut[i] = outputs[i];
138e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (s=0;s<nbSamples;s++)
139e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	{
140e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		float *in, *out;
141e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		in = inputs+s*inDim;
142e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		out = outputs + s*outDim;
143e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (i=0;i<hiddenDim;i++)
144e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		{
145e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			double sum = W0[i*(inDim+1)];
146e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			for (j=0;j<inDim;j++)
147e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				sum += W0[i*(inDim+1)+j+1]*in[j];
148e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			hidden[i] = tansig_approx(sum);
149e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		}
150e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (i=0;i<outDim;i++)
151e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		{
152e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			double sum = W1[i*(hiddenDim+1)];
153e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			for (j=0;j<hiddenDim;j++)
154e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				sum += W1[i*(hiddenDim+1)+j+1]*hidden[j];
155e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			netOut[i] = tansig_approx(sum);
156e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			error[i] = out[i] - netOut[i];
157e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			rms += error[i]*error[i];
158e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			error_rate[i] += fabs(error[i])>1;
159e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			/*error[i] = error[i]/(1+fabs(error[i]));*/
160e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		}
161e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		/* Back-propagate error */
162e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (i=0;i<outDim;i++)
163e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		{
164e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org                        float grad = 1-netOut[i]*netOut[i];
165e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			W1_grad[i*(hiddenDim+1)] += error[i]*grad;
166e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			for (j=0;j<hiddenDim;j++)
167e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				W1_grad[i*(hiddenDim+1)+j+1] += grad*error[i]*hidden[j];
168e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		}
169e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (i=0;i<hiddenDim;i++)
170e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		{
171e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			double grad;
172e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			grad = 0;
173e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			for (j=0;j<outDim;j++)
174e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				grad += error[j]*W1[j*(hiddenDim+1)+i+1];
175e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			grad *= 1-hidden[i]*hidden[i];
176e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			W0_grad[i*(inDim+1)] += grad;
177e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			for (j=0;j<inDim;j++)
178e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				W0_grad[i*(inDim+1)+j+1] += grad*in[j];
179e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		}
180e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	}
181e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	return rms;
182e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org}
183e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
184e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org#define NB_THREADS 8
185e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
186e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgsem_t sem_begin[NB_THREADS];
187e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgsem_t sem_end[NB_THREADS];
188e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
189e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgstruct GradientArg {
190e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int id;
191e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int done;
192e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	MLPTrain *net;
193e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	float *inputs;
194e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	float *outputs;
195e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int nbSamples;
196e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double *W0_grad;
197e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double *W1_grad;
198e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double rms;
199e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double error_rate[MAX_OUT];
200e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org};
201e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
202e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgvoid *gradient_thread_process(void *_arg)
203e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org{
204e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int W0_size, W1_size;
205e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	struct GradientArg *arg = _arg;
206e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int *topo = arg->net->topo;
207e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W0_size = (topo[0]+1)*topo[1];
208e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W1_size = (topo[1]+1)*topo[2];
209e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double W0_grad[W0_size];
210e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double W1_grad[W1_size];
211e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	arg->W0_grad = W0_grad;
212e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	arg->W1_grad = W1_grad;
213e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	while (1)
214e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	{
215e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		sem_wait(&sem_begin[arg->id]);
216e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		if (arg->done)
217e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			break;
218e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, arg->error_rate);
219e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		sem_post(&sem_end[arg->id]);
220e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	}
221e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	fprintf(stderr, "done\n");
222e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	return NULL;
223e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org}
224e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
225e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgfloat mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSamples, int nbEpoch, float rate)
226e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org{
227e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int i, j;
228e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int e;
229e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	float best_rms = 1e10;
230e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int inDim, outDim, hiddenDim;
231e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int *topo;
232e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double *W0, *W1, *best_W0, *best_W1;
233e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double *W0_old, *W1_old;
234e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double *W0_old2, *W1_old2;
235e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double *W0_grad, *W1_grad;
236e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double *W0_oldgrad, *W1_oldgrad;
237e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double *W0_rate, *W1_rate;
238e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	double *best_W0_rate, *best_W1_rate;
239e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int W0_size, W1_size;
240e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	topo = net->topo;
241e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W0_size = (topo[0]+1)*topo[1];
242e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W1_size = (topo[1]+1)*topo[2];
243e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	struct GradientArg args[NB_THREADS];
244e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	pthread_t thread[NB_THREADS];
245e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int samplePerPart = nbSamples/NB_THREADS;
246e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int count_worse=0;
247e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int count_retries=0;
248e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
249e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	topo = net->topo;
250e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	inDim = net->topo[0];
251e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	hiddenDim = net->topo[1];
252e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	outDim = net->topo[2];
253e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W0 = net->weights[0];
254e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W1 = net->weights[1];
255e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	best_W0 = net->best_weights[0];
256e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	best_W1 = net->best_weights[1];
257e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W0_old = malloc(W0_size*sizeof(double));
258e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W1_old = malloc(W1_size*sizeof(double));
259e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W0_old2 = malloc(W0_size*sizeof(double));
260e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W1_old2 = malloc(W1_size*sizeof(double));
261e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W0_grad = malloc(W0_size*sizeof(double));
262e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W1_grad = malloc(W1_size*sizeof(double));
263e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W0_oldgrad = malloc(W0_size*sizeof(double));
264e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W1_oldgrad = malloc(W1_size*sizeof(double));
265e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W0_rate = malloc(W0_size*sizeof(double));
266e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	W1_rate = malloc(W1_size*sizeof(double));
267e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	best_W0_rate = malloc(W0_size*sizeof(double));
268e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	best_W1_rate = malloc(W1_size*sizeof(double));
269e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	memcpy(W0_old, W0, W0_size*sizeof(double));
270e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	memcpy(W0_old2, W0, W0_size*sizeof(double));
271e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	memset(W0_grad, 0, W0_size*sizeof(double));
272e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	memset(W0_oldgrad, 0, W0_size*sizeof(double));
273e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	memcpy(W1_old, W1, W1_size*sizeof(double));
274e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	memcpy(W1_old2, W1, W1_size*sizeof(double));
275e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	memset(W1_grad, 0, W1_size*sizeof(double));
276e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	memset(W1_oldgrad, 0, W1_size*sizeof(double));
277e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
278e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	rate /= nbSamples;
279e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (i=0;i<hiddenDim;i++)
280e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (j=0;j<inDim+1;j++)
281e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			W0_rate[i*(inDim+1)+j] = rate*net->in_rate[j];
282e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (i=0;i<W1_size;i++)
283e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		W1_rate[i] = rate;
284e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
285e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (i=0;i<NB_THREADS;i++)
286e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	{
287e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		args[i].net = net;
288e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		args[i].inputs = inputs+i*samplePerPart*inDim;
289e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		args[i].outputs = outputs+i*samplePerPart*outDim;
290e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		args[i].nbSamples = samplePerPart;
291e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		args[i].id = i;
292e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		args[i].done = 0;
293e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		sem_init(&sem_begin[i], 0, 0);
294e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		sem_init(&sem_end[i], 0, 0);
295e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		pthread_create(&thread[i], NULL, gradient_thread_process, &args[i]);
296e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	}
297e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (e=0;e<nbEpoch;e++)
298e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	{
299e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		double rms=0;
300e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		double error_rate[2] = {0,0};
301e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (i=0;i<NB_THREADS;i++)
302e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		{
303e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			sem_post(&sem_begin[i]);
304e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		}
305e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		memset(W0_grad, 0, W0_size*sizeof(double));
306e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		memset(W1_grad, 0, W1_size*sizeof(double));
307e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (i=0;i<NB_THREADS;i++)
308e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		{
309e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			sem_wait(&sem_end[i]);
310e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			rms += args[i].rms;
311e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			error_rate[0] += args[i].error_rate[0];
312e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org            error_rate[1] += args[i].error_rate[1];
313e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			for (j=0;j<W0_size;j++)
314e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				W0_grad[j] += args[i].W0_grad[j];
315e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			for (j=0;j<W1_size;j++)
316e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				W1_grad[j] += args[i].W1_grad[j];
317e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		}
318e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
319e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		float mean_rate = 0, min_rate = 1e10;
320e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		rms = (rms/(outDim*nbSamples));
321e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		error_rate[0] = (error_rate[0]/(nbSamples));
322e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org        error_rate[1] = (error_rate[1]/(nbSamples));
323e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		fprintf (stderr, "%f %f (%f %f) ", error_rate[0], error_rate[1], rms, best_rms);
324e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		if (rms < best_rms)
325e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		{
326e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			best_rms = rms;
327e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			for (i=0;i<W0_size;i++)
328e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			{
329e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				best_W0[i] = W0[i];
330e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				best_W0_rate[i] = W0_rate[i];
331e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			}
332e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			for (i=0;i<W1_size;i++)
333e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			{
334e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				best_W1[i] = W1[i];
335e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				best_W1_rate[i] = W1_rate[i];
336e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			}
337e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			count_worse=0;
338e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			count_retries=0;
339e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		} else {
340e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			count_worse++;
341e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			if (count_worse>30)
342e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			{
343e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			    count_retries++;
344e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				count_worse=0;
345e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				for (i=0;i<W0_size;i++)
346e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				{
347e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org					W0[i] = best_W0[i];
348e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org					best_W0_rate[i] *= .7;
349e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org					if (best_W0_rate[i]<1e-15) best_W0_rate[i]=1e-15;
350e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org					W0_rate[i] = best_W0_rate[i];
351e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org					W0_grad[i] = 0;
352e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				}
353e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				for (i=0;i<W1_size;i++)
354e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				{
355e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org					W1[i] = best_W1[i];
356e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org					best_W1_rate[i] *= .8;
357e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org					if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15;
358e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org					W1_rate[i] = best_W1_rate[i];
359e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org					W1_grad[i] = 0;
360e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				}
361e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			}
362e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		}
363e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		if (count_retries>10)
364e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		    break;
365e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (i=0;i<W0_size;i++)
366e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		{
367e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			if (W0_oldgrad[i]*W0_grad[i] > 0)
368e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				W0_rate[i] *= 1.01;
369e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			else if (W0_oldgrad[i]*W0_grad[i] < 0)
370e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				W0_rate[i] *= .9;
371e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			mean_rate += W0_rate[i];
372e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			if (W0_rate[i] < min_rate)
373e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				min_rate = W0_rate[i];
374e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			if (W0_rate[i] < 1e-15)
375e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				W0_rate[i] = 1e-15;
376e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			/*if (W0_rate[i] > .01)
377e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				W0_rate[i] = .01;*/
378e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			W0_oldgrad[i] = W0_grad[i];
379e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			W0_old2[i] = W0_old[i];
380e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			W0_old[i] = W0[i];
381e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			W0[i] += W0_grad[i]*W0_rate[i];
382e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		}
383e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (i=0;i<W1_size;i++)
384e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		{
385e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			if (W1_oldgrad[i]*W1_grad[i] > 0)
386e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				W1_rate[i] *= 1.01;
387e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			else if (W1_oldgrad[i]*W1_grad[i] < 0)
388e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				W1_rate[i] *= .9;
389e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			mean_rate += W1_rate[i];
390e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			if (W1_rate[i] < min_rate)
391e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				min_rate = W1_rate[i];
392e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			if (W1_rate[i] < 1e-15)
393e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org				W1_rate[i] = 1e-15;
394e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			W1_oldgrad[i] = W1_grad[i];
395e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			W1_old2[i] = W1_old[i];
396e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			W1_old[i] = W1[i];
397e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			W1[i] += W1_grad[i]*W1_rate[i];
398e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		}
399e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2];
400e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		fprintf (stderr, "%g %d", mean_rate, e);
401e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		if (count_retries)
402e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		    fprintf(stderr, " %d", count_retries);
403e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		fprintf(stderr, "\n");
404e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		if (stopped)
405e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			break;
406e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	}
407e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (i=0;i<NB_THREADS;i++)
408e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	{
409e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		args[i].done = 1;
410e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		sem_post(&sem_begin[i]);
411e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		pthread_join(thread[i], NULL);
412e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		fprintf (stderr, "joined %d\n", i);
413e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	}
414e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	free(W0_old);
415e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	free(W1_old);
416e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	free(W0_grad);
417e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	free(W1_grad);
418e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	free(W0_rate);
419e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	free(W1_rate);
420e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	return best_rms;
421e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org}
422e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
423e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.orgint main(int argc, char **argv)
424e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org{
425e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int i, j;
426e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int nbInputs;
427e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int nbOutputs;
428e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int nbHidden;
429e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int nbSamples;
430e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int nbEpoch;
431e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int nbRealInputs;
432e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	unsigned int seed;
433e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int ret;
434e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	float rms;
435e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	float *inputs;
436e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	float *outputs;
437e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	if (argc!=6)
438e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	{
439e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		fprintf (stderr, "usage: mlp_train <inputs> <hidden> <outputs> <nb samples> <nb epoch>\n");
440e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		return 1;
441e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	}
442e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	nbInputs = atoi(argv[1]);
443e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	nbHidden = atoi(argv[2]);
444e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	nbOutputs = atoi(argv[3]);
445e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	nbSamples = atoi(argv[4]);
446e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	nbEpoch = atoi(argv[5]);
447e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	nbRealInputs = nbInputs;
448e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	inputs = malloc(nbInputs*nbSamples*sizeof(*inputs));
449e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	outputs = malloc(nbOutputs*nbSamples*sizeof(*outputs));
450e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
451e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	seed = time(NULL);
452e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org    /*seed = 1361480659;*/
453e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	fprintf (stderr, "Seed is %u\n", seed);
454e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	srand(seed);
455e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	build_tansig_table();
456e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	signal(SIGTERM, handler);
457e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	signal(SIGINT, handler);
458e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	signal(SIGHUP, handler);
459e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (i=0;i<nbSamples;i++)
460e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	{
461e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (j=0;j<nbRealInputs;j++)
462e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			ret = scanf(" %f", &inputs[i*nbInputs+j]);
463e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		for (j=0;j<nbOutputs;j++)
464e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			ret = scanf(" %f", &outputs[i*nbOutputs+j]);
465e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		if (feof(stdin))
466e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		{
467e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			nbSamples = i;
468e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			break;
469e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		}
470e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	}
471e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	int topo[3] = {nbInputs, nbHidden, nbOutputs};
472e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	MLPTrain *net;
473e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org
474e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	fprintf (stderr, "Got %d samples\n", nbSamples);
475e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	net = mlp_init(topo, 3, inputs, outputs, nbSamples);
476e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	rms = mlp_train_backprop(net, inputs, outputs, nbSamples, nbEpoch, 1);
477e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	printf ("#include \"mlp.h\"\n\n");
478e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	printf ("/* RMS error was %f, seed was %u */\n\n", rms, seed);
479e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	printf ("static const float weights[%d] = {\n", (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]);
480e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	printf ("\n/* hidden layer */\n");
481e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (i=0;i<(topo[0]+1)*topo[1];i++)
482e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	{
483e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		printf ("%gf, ", net->weights[0][i]);
484e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		if (i%5==4)
485e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			printf("\n");
486e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	}
487e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	printf ("\n/* output layer */\n");
488e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	for (i=0;i<(topo[1]+1)*topo[2];i++)
489e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	{
490e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		printf ("%g, ", net->weights[1][i]);
491e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org		if (i%5==4)
492e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org			printf("\n");
493e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	}
494e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	printf ("};\n\n");
495e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	printf ("static const int topo[3] = {%d, %d, %d};\n\n", topo[0], topo[1], topo[2]);
496e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	printf ("const MLP net = {\n");
497e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	printf ("\t3,\n");
498e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	printf ("\ttopo,\n");
499e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	printf ("\tweights\n};\n");
500e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org	return 0;
501e3ea049fcaee2247e45f0ce793d4313babb4ef69tlegrand@chromium.org}
502