L'optimisation d'optimisation (EM) est une sorte de méthode probabiliste pour classer les données. S'il vous plaît, corrigez-moi si je me trompe s'il ne s'agit pas d'un classificateur.
Qu'est-ce qu'une explication intuitive de cette technique EM? Qu'est-ce que expectation
ici et qu'est-ce que maximized
?
Remarque: le code derrière cette réponse peut être trouvé ici .
Supposons que nous ayons des données échantillonnées à partir de deux groupes différents, rouge et bleu:
Ici, nous pouvons voir quel point de données appartient au groupe rouge ou bleu. Cela facilite la recherche des paramètres qui caractérisent chaque groupe. Par exemple, la moyenne du groupe rouge est d'environ 3, la moyenne du groupe bleu est d'environ 7 (et nous pourrions trouver la moyenne exacte si nous le souhaitions).
Ceci est généralement appelé estimation du maximum de vraisemblance. À partir de certaines données, nous calculons la valeur d’un paramètre (ou de plusieurs paramètres) qui explique le mieux ces données.
Imaginons maintenant que ne peut pas voir quelle valeur a été échantillonnée dans quel groupe. Tout nous semble violet:
Nous savons ici qu'il existe deux groupes de valeurs, mais nous ne savons pas à quel groupe appartient une valeur particulière.
Pouvons-nous toujours estimer les moyennes pour le groupe rouge et le groupe bleu qui correspondent le mieux à ces données?
Oui, souvent nous pouvons! La maximisation de l'attente nous donne un moyen de le faire. L'idée très générale derrière l'algorithme est la suivante:
Ces étapes nécessitent des explications supplémentaires, je vais donc passer en revue le problème décrit ci-dessus.
Je vais utiliser Python dans cet exemple, mais le code devrait être assez facile à comprendre si vous n'êtes pas familier avec ce langage.
Supposons que nous ayons deux groupes, rouge et bleu, avec les valeurs distribuées comme dans l'image ci-dessus. Plus précisément, chaque groupe contient une valeur tirée de distribution normale avec les paramètres suivants:
import numpy as np
from scipy import stats
np.random.seed(110) # for reproducible results
# set parameters
red_mean = 3
red_std = 0.8
blue_mean = 7
blue_std = 2
# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)
both_colours = np.sort(np.concatenate((red, blue))) # for later use...
Voici à nouveau une image de ces groupes rouge et bleu (pour vous éviter de devoir faire défiler vers le haut):
Lorsque nous pouvons voir la couleur de chaque point (c’est-à-dire le groupe auquel il appartient), il est très facile d’estimer la moyenne et l’écart type de chaque groupe. Nous passons simplement les valeurs rouge et bleue aux fonctions intégrées de NumPy. Par exemple:
>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195
Mais que se passe-t-il si ne peut pas voir les couleurs des points? Autrement dit, au lieu de rouge ou de bleu, chaque point a été coloré en violet.
Pour essayer de récupérer les paramètres de moyenne et d'écart type des groupes rouge et bleu, nous pouvons utiliser l'optimisation de l'espérance.
Notre première étape ( étape 1 ci-dessus) consiste à deviner les valeurs des paramètres pour la moyenne et l'écart type de chaque groupe. Nous n'avons pas à deviner intelligemment; nous pouvons choisir tous les nombres que nous aimons:
# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9
# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7
Ces estimations de paramètres produisent des courbes en cloche ressemblant à ceci:
Ce sont de mauvaises estimations. Les deux moyens (les lignes pointillées verticales) semblent éloignés de tout type de "milieu" pour les groupes de points sensibles, par exemple. Nous voulons améliorer ces estimations.
L'étape suivante ( étape 2 ) consiste à calculer la probabilité que chaque point de données apparaisse sous le paramètre actuel devine:
likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)
Ici, nous avons simplement mis chaque point de données dans la fonction de densité de probabilité pour une distribution normale utilisant nos estimations actuelles à la moyenne et à l'écart type pour le rouge et le bleu. Cela nous indique, par exemple, qu’avec nos suppositions actuelles, le point de données à 1,761 est beaucoup plus de chances d’être rouge (0,189) que bleu (0,00003).
Pour chaque point de données, nous pouvons transformer ces deux valeurs de vraisemblance en pondérations ( étape 3 ) afin qu'elles totalisent 1, comme suit:
likelihood_total = likelihood_of_red + likelihood_of_blue
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
Avec nos estimations actuelles et nos poids nouvellement calculés, nous pouvons maintenant calculer nouveau estimations de la moyenne et de l'écart type des groupes rouge et bleu ( étape 4 ).
Nous calculons deux fois la moyenne et l'écart type en utilisant tous points de données, mais avec des pondérations différentes: une fois pour les poids rouges et une fois pour les poids bleus.
Le bit clé de l'intuition est que plus le poids d'une couleur sur un point de données est important, plus le point de données influence les estimations suivantes pour les paramètres de cette couleur. Cela a pour effet de "tirer" les paramètres dans la bonne direction.
def estimate_mean(data, weight):
"""
For each data point, multiply the point by the probability it
was drawn from the colour's distribution (its "weight").
Divide by the total weight: essentially, we're finding where
the weight is centred among our data points.
"""
return np.sum(data * weight) / np.sum(weight)
def estimate_std(data, weight, mean):
"""
For each data point, multiply the point's squared difference
from a mean value by the probability it was drawn from
that distribution (its "weight").
Divide by the total weight: essentially, we're finding where
the weight is centred among the values for the difference of
each data point from the mean.
This is the estimate of the variance, take the positive square
root to find the standard deviation.
"""
variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
return np.sqrt(variance)
# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)
Nous avons de nouvelles estimations pour les paramètres. Pour les améliorer à nouveau, nous pouvons revenir à l’étape 2 et répéter le processus. Nous le faisons jusqu'à ce que les estimations convergent ou après qu'un certain nombre d'itérations ait été effectué ( étape 5 ).
Pour nos données, les cinq premières itérations de ce processus ressemblent à ceci (les itérations récentes ont une apparence plus forte):
Nous voyons que les moyennes convergent déjà sur certaines valeurs et que les formes des courbes (régies par l’écart type) deviennent également plus stables.
Si nous continuons pendant 20 itérations, nous nous retrouvons avec ce qui suit:
Le processus EM a convergé vers les valeurs suivantes, qui s'avèrent très proches des valeurs réelles (où on peut voir les couleurs - pas de variables cachées):
| EM guess | Actual | Delta
----------+----------+--------+-------
Red mean | 2.910 | 2.802 | 0.108
Red std | 0.854 | 0.871 | -0.017
Blue mean | 6.838 | 6.932 | -0.094
Blue std | 2.227 | 2.195 | 0.032
Dans le code ci-dessus, vous avez peut-être remarqué que la nouvelle estimation de l'écart type a été calculée à l'aide de l'estimation de la moyenne de l'itération précédente. En fin de compte, peu importe si nous calculons d'abord une nouvelle valeur pour la moyenne, car nous ne faisons que rechercher la variance (pondérée) des valeurs autour d'un point central. Nous verrons encore les estimations pour les paramètres converger.
EM est un algorithme permettant de maximiser une fonction de vraisemblance lorsque certaines des variables de votre modèle ne sont pas observées (c'est-à-dire lorsque vous avez des variables latentes).
Vous pourriez vous poser la question suivante: si nous essayons simplement de maximiser une fonction, pourquoi ne pas simplement utiliser les mécanismes existants pour maximiser une fonction. Eh bien, si vous essayez de maximiser cela en prenant des dérivés et en les mettant à zéro, vous constaterez que dans de nombreux cas, les conditions de premier ordre n'ont pas de solution. Pour résoudre les paramètres de votre modèle, vous devez connaître la distribution de vos données non observées. mais la distribution de vos données non observées est fonction des paramètres de votre modèle.
E-M essaie de contourner ce problème en devinant de manière itérative une distribution pour les données non observées, puis en estimant les paramètres du modèle en maximisant quelque chose qui constitue une limite inférieure de la fonction de vraisemblance réelle, et en répétant jusqu'à la convergence:
L'algorithme EM
Commencez par deviner les valeurs des paramètres de votre modèle
E-step: pour chaque point de données ayant des valeurs manquantes, utilisez l’équation de votre modèle pour résoudre la distribution des données manquantes en fonction de votre estimation actuelle des paramètres du modèle et des données observées (notez que vous résolvez une distribution pour chaque élément manquant.) valeur, pas pour la valeur attendue). Maintenant que nous avons une distribution pour chaque valeur manquante, nous pouvons calculer la expectation de la fonction de vraisemblance par rapport aux variables non observées. Si notre estimation du paramètre de modèle était correcte, cette probabilité probable sera la probabilité réelle de nos données observées; si les paramètres n'étaient pas corrects, ce sera simplement une limite inférieure.
Étape M: Maintenant que nous avons une fonction de vraisemblance attendue ne contenant aucune variable non observée, maximisez la fonction comme vous le feriez dans le cas parfaitement observé pour obtenir une nouvelle estimation des paramètres de votre modèle.
Répétez jusqu'à la convergence.
Voici une recette simple pour comprendre l'algorithme d'optimisation des attentes:
1 - Lisez ceci document de travail sur l'EM par Do et Batzoglou.
2 - Vous avez peut-être des points d'interrogation dans la tête, jetez un coup d'œil aux explications fournies dans cet échange de piles mathématiques page .
3 - Regardez ce code que j'ai écrit dans Python et qui explique l'exemple présenté dans l'article du didacticiel sur la ME du point 1:
Avertissement: Le code peut être compliqué/sous-optimal, car je ne suis pas un développeur Python. Mais ça fait le travail.
import numpy as np
import math
#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* ####
def get_mn_log_likelihood(obs,probs):
""" Return the (log)likelihood of obs, given the probs"""
# Multinomial Distribution Log PMF
# ln (pdf) = multinomial coeff * product of probabilities
# ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]
multinomial_coeff_denom= 0
prod_probs = 0
for x in range(0,len(obs)): # loop through state counts in each observation
multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
prod_probs = prod_probs + obs[x]*math.log(probs[x])
multinomial_coeff = math.log(math.factorial(sum(obs))) - multinomial_coeff_denom
likelihood = multinomial_coeff + prod_probs
return likelihood
# 1st: Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd: Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd: Coin A, {HTHHHHHTHH}, 8H,2T
# 4th: Coin B, {HTHTTTHHTT}, 4H,6T
# 5th: Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45
# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = Zip(head_counts,tail_counts)
# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50
# E-M begins!
delta = 0.001
j = 0 # iteration counter
improvement = float('inf')
while (improvement>delta):
expectation_A = np.zeros((5,2), dtype=float)
expectation_B = np.zeros((5,2), dtype=float)
for i in range(0,len(experiments)):
e = experiments[i] # i'th experiment
ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B
weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A
weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B
expectation_A[i] = np.dot(weightA, e)
expectation_B[i] = np.dot(weightB, e)
pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A));
pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B));
improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
j = j+1
Techniquement, le terme "EM" est un peu sous-spécifié, mais je suppose que vous faites référence à la technique d’analyse par grappes gaussienne pour la modélisation de mélanges, c’est-à-dire une instance du général Principe EM.
En fait, l'analyse de cluster EM n'est pas un classificateur. Je sais que certaines personnes considèrent le clustering comme une "classification non supervisée", mais en réalité, l'analyse par cluster est quelque chose de très différent.
La principale différence, et le gros malentendu que l'on a toujours sur la classification avec l'analyse en grappes, est la suivante: dans l'analyse en grappes, il n'y a pas de "solution correcte". C'est une méthode de découverte de connaissances , elle est en fait destinée à trouver quelque chose new ! Cela rend l'évaluation très délicate. Elle est souvent évaluée en utilisant une classification connue en tant que référence, mais cela n’est pas toujours approprié: la classification que vous avez peut refléter ou non le contenu des données.
Permettez-moi de vous donner un exemple: vous avez un grand ensemble de données de clients, y compris des données de genre. Une méthode qui divise cet ensemble de données en "hommes" et "femmes" est optimale lorsque vous le comparez aux classes existantes. En termes de "prédiction", c'est bien, car pour les nouveaux utilisateurs, vous pouvez maintenant prédire leur sexe. En termes de "découverte de connaissances", c'est en fait une mauvaise idée, car vous vouliez découvrir une nouvelle structure dans les données. Une méthode qui, par exemple, scinder les données en personnes âgées et enfants entraînerait toutefois un score aussi mauvais que possible par rapport à la classe hommes/femmes. Cependant, ce serait un excellent résultat de regroupement (si l'âge n'était pas indiqué).
Revenons maintenant à EM. Cela suppose essentiellement que vos données sont composées de plusieurs distributions normales multivariées (notez qu'il s'agit d'une hypothèse forte très , en particulier lorsque vous fixez le nombre de clusters !) Il essaie ensuite de trouver un modèle local optimal pour cela en améliorant alternativement le modèle et l'affectation d'objets au modèle.
Pour de meilleurs résultats dans un contexte de classification, choisissez un nombre de clusters supérieur au nombre de classes, ou appliquez même le clustering à single classes uniquement (pour savoir s’il existe une structure dans la classe!).
Supposons que vous vouliez former un classificateur pour distinguer les "voitures", les "vélos" et les "camions". Il est peu utile de supposer que les données consistent en exactement 3 distributions normales. Cependant, vous pouvez supposer qu'il existe plusieurs types de voitures (ainsi que des camions et des vélos). Ainsi, au lieu de former un classificateur pour ces trois classes, vous regroupez les voitures, les camions et les vélos en 10 groupes (ou peut-être 10 voitures, 3 camions et 3 vélos, peu importe), puis vous entraînez un classificateur pour distinguer ces 30 classes, puis fusionner le résultat de la classe avec les classes d'origine. Vous pouvez également découvrir qu’il existe un groupe particulièrement difficile à classer, par exemple Trikes. Ce sont des voitures et des motos. Ou les camions de livraison, qui ressemblent davantage à des voitures surdimensionnées qu'à des camions.
La réponse acceptée fait référence au Chuong EM Paper , qui fait un travail décent en expliquant les EM. Il existe également un vidéo sur youtube qui explique le document de manière plus détaillée.
Pour récapituler, voici le scénario:
1st: {H,T,T,T,H,H,T,H,T,H} 5 Heads, 5 Tails; Did coin A or B generate me?
2nd: {H,H,H,H,T,H,H,H,H,H} 9 Heads, 1 Tails
3rd: {H,T,H,H,H,H,H,T,H,H} 8 Heads, 2 Tails
4th: {H,T,H,T,T,T,H,H,T,T} 4 Heads, 6 Tails
5th: {T,H,H,H,T,H,H,H,T,H} 7 Heads, 3 Tails
Two possible coins, A & B are used to generate these distributions.
A & B have an unknown parameter: their bias towards heads.
We don't know the biases, but we can simply start with a guess: A=60% heads, B=50% heads.
Dans le cas de la question du premier procès, intuitivement, nous penserions que B l’a générée puisque la proportion de têtes correspond très bien au biais de B ... mais cette valeur n’est qu’une estimation, nous ne pouvons donc pas en être sûrs.
Dans cet esprit, j'aime bien penser à la solution de SE:
C'est peut-être une simplification excessive (ou même fondamentalement fausse à certains niveaux), mais j'espère que cela vous aidera au niveau intuitif!
Les autres réponses étant bonnes, je vais essayer de donner un autre point de vue et d’aborder la partie intuitive de la question.
algorithme EM (Expectation-Maximization) est une variante d'une classe d'algorithmes itératifs utilisant dualité
Extrait (c'est moi qui souligne):
En mathématiques, une dualité, en général, traduit des concepts, des théorèmes ou des structures mathématiques en d’autres concepts, théorèmes ou structures, de manière personnalisée, souvent (mais pas toujours) au moyen d’une opération d’involution: si le dual de A est B, alors le dual de B est A. De telles involutions ont parfois des points fixes , de sorte que le dual de A est A lui-même
Habituellement, un dual B d'un objet A est lié à A d'une manière qui préserve quelque symétrie ou compatibilité. Par exemple, AB = const
Voici des exemples d’algorithmes itératifs utilisant la dualité (au sens précédent):
De manière similaire, l'algorithme EM peut également être vu comme deux étapes de maximisation doubles :
.. [EM] est vu comme maximisant une fonction conjointe des paramètres et de la distribution sur les variables non observées. Le pas E maximise cette fonction par rapport à la distribution sur les variables non observées; le M-step par rapport aux paramètres ..
Dans un algorithme itératif utilisant la dualité, il y a l'hypothèse explicite (ou implicite) d'un point de convergence d'équilibre (ou fixe) (pour EM cela est prouvé à l'aide de l'inégalité de Jensen).
Le contour de tels algorithmes est donc:
Notez que lorsqu'un tel algorithme converge vers un optimum (global), il a trouvé une configuration qui est meilleure dans les deux sens (c'est-à-dire dans le domaine/paramètres x et dans le y domaine/paramètres). Cependant, l'algorithme peut simplement trouver un local optimum et non le global optimum.
je dirais que ceci est la description intuitive du contour de l'algorithme
Pour les arguments statistiques et les applications, d’autres réponses ont donné de bonnes explications (cochez aussi les références dans cette réponse)
EM est utilisé pour maximiser la probabilité d'un modèle Q avec des variables latentes Z.
C'est une optimisation itérative.
theta <- initial guess for hidden parameters
while not converged:
#e-step
Q(theta'|theta) = E[log L(theta|Z)]
#m-step
theta <- argmax_theta' Q(theta'|theta)
e-step: à partir de l’estimation actuelle de Z, calculez la fonction de logique attendue
m-step: trouver le thêta qui maximise ce Q
Exemple GMM:
e-step: estimation des assignations d'étiquettes pour chaque point de données en fonction de l'estimation actuelle des paramètres gmm
m-step: maximiser une nouvelle thêta compte tenu des nouvelles affectations d'étiquettes
K-means est également un algorithme EM et il existe de nombreuses animations explicatives sur K-means.
En utilisant le même article de Do et Batzoglou cité dans la réponse de Zhubarb, j'ai implémenté EM pour ce problème dans Java. Les commentaires de sa réponse montrent que l'algorithme reste bloqué à un optimum local, ce qui se produit également avec mon implémentation si les paramètres thetaA et thetaB sont les mêmes.
Ci-dessous se trouve la sortie standard de mon code, montrant la convergence des paramètres.
thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960
Voici mon Java implémentation de EM pour résoudre le problème dans (Do et Batzoglou, 2008). La partie essentielle de l’implémentation est la boucle permettant d’exécuter EM jusqu’à ce que les paramètres convergent.
private Parameters _parameters;
public Parameters run()
{
while (true)
{
expectation();
Parameters estimatedParameters = maximization();
if (_parameters.converged(estimatedParameters)) {
break;
}
_parameters = estimatedParameters;
}
return _parameters;
}
Vous trouverez ci-dessous le code complet.
import Java.util.*;
/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B,
respectively.
*****************************************************************************/
class Parameters
{
double _thetaA = 0.0; // Probability of heads for coin A.
double _thetaB = 0.0; // Probability of heads for coin B.
double _delta = 0.00001;
public Parameters(double thetaA, double thetaB)
{
_thetaA = thetaA;
_thetaB = thetaB;
}
/*************************************************************************
Returns true if this parameter is close enough to another parameter
(typically the estimated parameter coming from the maximization step).
*************************************************************************/
public boolean converged(Parameters other)
{
if (Math.abs(_thetaA - other._thetaA) < _delta &&
Math.abs(_thetaB - other._thetaB) < _delta)
{
return true;
}
return false;
}
public double getThetaA()
{
return _thetaA;
}
public double getThetaB()
{
return _thetaB;
}
public String toString()
{
return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
}
}
/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
experimental observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
double _numHeads = 0;
double _numTails = 0;
public Observation(String s)
{
for (int i = 0; i < s.length(); i++)
{
char c = s.charAt(i);
if (c == 'H')
{
_numHeads++;
}
else if (c == 'T')
{
_numTails++;
}
else
{
throw new RuntimeException("Unknown character: " + c);
}
}
}
public Observation(double numHeads, double numTails)
{
_numHeads = numHeads;
_numTails = numTails;
}
public double getNumHeads()
{
return _numHeads;
}
public double getNumTails()
{
return _numTails;
}
public String toString()
{
return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
}
}
/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
// Current estimated parameters.
private Parameters _parameters;
// Observations from the trials. These observations are set once.
private final List<Observation> _observations;
// Estimated observations per coin. These observations are the output
// of the expectation step.
private List<Observation> _expectedObservationsForCoinA;
private List<Observation> _expectedObservationsForCoinB;
private static Java.io.PrintStream o = System.out;
/*************************************************************************
Principal constructor.
@param observations The observations from the trial.
@param parameters The initial guessed parameters.
*************************************************************************/
public EM(List<Observation> observations, Parameters parameters)
{
_observations = observations;
_parameters = parameters;
}
/*************************************************************************
Run EM until parameters converge.
*************************************************************************/
public Parameters run()
{
while (true)
{
expectation();
Parameters estimatedParameters = maximization();
o.printf("%s\n", estimatedParameters);
if (_parameters.converged(estimatedParameters)) {
break;
}
_parameters = estimatedParameters;
}
return _parameters;
}
/*************************************************************************
Given the observations and current estimated parameters, compute new
estimated completions (distribution over the classes) and observations.
*************************************************************************/
private void expectation()
{
_expectedObservationsForCoinA = new ArrayList<Observation>();
_expectedObservationsForCoinB = new ArrayList<Observation>();
for (Observation observation : _observations)
{
int numHeads = (int)observation.getNumHeads();
int numTails = (int)observation.getNumTails();
double probabilityOfObservationForCoinA=
binomialProbability(10, numHeads, _parameters.getThetaA());
double probabilityOfObservationForCoinB=
binomialProbability(10, numHeads, _parameters.getThetaB());
double normalizer = probabilityOfObservationForCoinA +
probabilityOfObservationForCoinB;
// Compute the completions for coin A and B (i.e. the probability
// distribution of the two classes, summed to 1.0).
double completionCoinA = probabilityOfObservationForCoinA /
normalizer;
double completionCoinB = probabilityOfObservationForCoinB /
normalizer;
// Compute new expected observations for the two coins.
Observation expectedObservationForCoinA =
new Observation(numHeads * completionCoinA,
numTails * completionCoinA);
Observation expectedObservationForCoinB =
new Observation(numHeads * completionCoinB,
numTails * completionCoinB);
_expectedObservationsForCoinA.add(expectedObservationForCoinA);
_expectedObservationsForCoinB.add(expectedObservationForCoinB);
}
}
/*************************************************************************
Given new estimated observations, compute new estimated parameters.
*************************************************************************/
private Parameters maximization()
{
double sumCoinAHeads = 0.0;
double sumCoinATails = 0.0;
double sumCoinBHeads = 0.0;
double sumCoinBTails = 0.0;
for (Observation observation : _expectedObservationsForCoinA)
{
sumCoinAHeads += observation.getNumHeads();
sumCoinATails += observation.getNumTails();
}
for (Observation observation : _expectedObservationsForCoinB)
{
sumCoinBHeads += observation.getNumHeads();
sumCoinBTails += observation.getNumTails();
}
return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));
//o.printf("parameters: %s\n", _parameters);
}
/*************************************************************************
Since the coin-toss experiment posed in this article is a Bernoulli trial,
use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
*************************************************************************/
private static double binomialProbability(int n, int k, double p)
{
double q = 1.0 - p;
return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
}
private static long nChooseK(int n, int k)
{
long numerator = 1;
for (int i = 0; i < k; i++)
{
numerator = numerator * n;
n--;
}
long denominator = factorial(k);
return (long)(numerator / denominator);
}
private static long factorial(int n)
{
long result = 1;
for (; n >0; n--)
{
result = result * n;
}
return result;
}
/*************************************************************************
Entry point into the program.
*************************************************************************/
public static void main(String argv[])
{
// Create the observations and initial parameter guess
// from the (Do and Batzoglou, 2008) article.
List<Observation> observations = new ArrayList<Observation>();
observations.add(new Observation("HTTTHHTHTH"));
observations.add(new Observation("HHHHTHHHHH"));
observations.add(new Observation("HTHHHHHTHH"));
observations.add(new Observation("HTHTTTHHTT"));
observations.add(new Observation("THHHTHHHTH"));
Parameters initialParameters = new Parameters(0.6, 0.5);
EM em = new EM(observations, initialParameters);
Parameters finalParameters = em.run();
o.printf("Final result:\n%s\n", finalParameters);
}
}