Une implémentation Java de l'algorithme Expectation-maximization (EM).


exemple de convergence avec 3 distributions normales (voir post suivant)

L'interface qui définit une loi de probabilité
Code java : Sélectionner tout - Visualiser dans une fenêtre à part
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
interface Law {
	/**
         * @param x some value
         * @return the probability of the value x
         */
	double proba(double x);
 
	/**
         * improve law parameters
         * 
         * @param N number of samples
         * @param x samples
         * @param tk probability of each sample
         */
	void improveParameters(int N, double[] x, double[] tk);
}

Le code de l'algorithme EM
Code java : Sélectionner tout - Visualiser dans une fenêtre à part
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
/**
 * Compute the mixture coefficients using EM algorithm
 * 
 * @param x sample values
 * @param laws instances of the laws
 * @return mixture coefficients 
 */
public double[] algorithmEM(double[] x, Law[] laws) {
	int N=x.length;
	int G=laws.length;
 
	double[] pi = new double[G];
	double[][] t = new double[G][N];
 
	// initial guess for the mixture coefficients (uniform)
	for(int k=0;k<G;k++) pi[k]=1.0/G;
 
	// iterative loop (until convergence or 5000 iterations)
	double convergence;
	for(int loop=0;loop<5000;loop++) {
		convergence=0;
 
		// ---- E Step ----
 
		//(Bayes inversion formula)
		for(int i=0;i<N;i++) {
			double denominator = 0;
			for(int l=0;l<G;l++) denominator+=pi[l]*laws[l].proba(x[i]);
			for(int k=0;k<G;k++) {
				double numerator = pi[k]*laws[k].proba(x[i]);
				t[k][i]=numerator/denominator;
			}
		}
 
		// ---- M Step ----
 
		// mixture coefficients (maximum likelihood estimate of binomial distribution)
		for(int k=0;k<G;k++) {
			double savedpi=pi[k];
			pi[k]=0;
			for(int i=0;i<N;i++) pi[k]+=t[k][i];
			pi[k]/=N;
			convergence+=(savedpi-pi[k])*(savedpi-pi[k]);
		}
 
		// update the parameters of the laws
		for(int k=0;k<G;k++)
			laws[k].improveParameters(N, x, t[k]);
 
		if( convergence < 1E-10 ) break;
	}
 
	return pi;
}