EM Clustering of the Iris Dataset
=================================
The Iris dataset contains data regarding 150 iris plants. Each plant is described by four continuous features.
The plants are divided into three classes: Iris setosa, Iris virginica and Iris versicolor, 50 plants for each class.
The Iris dataset was introduced by Ronald Fisher in 1936 as an example of linear discriminant analysis and has become a standard dataset for data analysis.
Here we use it as an example of clustering to showcase the data analysis capabilities of SWISH and its integration with R. We aim at grouping the data in such a way that objects in the same group (called a cluster) are more similar to each other than to those in other groups. Since for the Iris dataset a natural grouping is provided, we aim at rediscovering that grouping using clustering.
A clustering algorithm that is particularly suitable for data described by continuous features is EM clustering.
In applying such an algorithm, we assume that the data is generated by a mixture of components, where each component corresponds to a cluster and is associated with a probability distribution of the features. So we assume that the data is generated in this way: a component from a fixed and finite set is sampled and then the features are sampled from the distribution associated with the component.
Here we assume that the features given the component are independent random variables with a Gaussian distribution.
So the joint probability distribution of the features given the component is the product of four Gaussians.
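In symbols, if x = (x_1, x_2, x_3, x_4) are the features and c_k is a component:

$$p(x \mid c_k) = \prod_{j=1}^{4} \mathcal{N}(x_j;\, \mu_{kj}, \sigma^2_{kj}), \qquad \mathcal{N}(x; \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}}\, e^{-\frac{(x-\mu)^2}{2\sigma^2}}.$$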
The aim of EM clustering is to find the probability distribution over the components (a discrete distribution) and the parameters of the Gaussians over the features given the components.
For the Iris dataset, we assume that there are three components/clusters, so the distribution over the components has two free parameters (the three probabilities must sum to one). For the distributions over the features we have two parameters (mean and variance) for each cluster and each feature, so 3 × 4 × 2 = 24 parameters, for a total of 26 free parameters.
EM clustering applies the EM algorithm: it first assigns random values to the parameters and then computes, for each individual and each cluster, the probability that the individual belongs to the cluster (expectation step).
Given these probabilities, it is possible to estimate the values of the parameters that maximize the likelihood that the data is generated by that model (maximization step, using weighted relative frequencies). Then the expectation and maximization steps are repeated. It is possible to prove that during this cycle the likelihood of the data never decreases. The cycle is performed a fixed number of times or until the likelihood stops increasing.
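Concretely, with mixing probabilities $\pi_k$ and per-component densities $p(x \mid c_k)$ as above, the expectation step computes for each example $x^{(i)}$ the responsibility

$$w_{ik} = \frac{\pi_k\, p(x^{(i)} \mid c_k)}{\sum_{k'} \pi_{k'}\, p(x^{(i)} \mid c_{k'})},$$

and the maximization step re-estimates the parameters by weighted relative frequency over the $N$ examples:

$$\pi_k = \frac{\sum_i w_{ik}}{N}, \qquad \mu_{kj} = \frac{\sum_i w_{ik}\, x^{(i)}_j}{\sum_i w_{ik}}, \qquad \sigma^2_{kj} = \frac{\sum_i w_{ik}\, (x^{(i)}_j - \mu_{kj})^2}{\sum_i w_{ik}}.$$

The quantity that never decreases is the log-likelihood $\mathrm{LL} = \sum_i \log \sum_k \pi_k\, p(x^{(i)} \mid c_k)$.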
In this notebook, we describe a Prolog implementation of EM applied to the Iris dataset.
The predicate =em/1= performs the EM algorithm with the assumption
that there are 3 clusters and that the four features of the dataset are independent given the cluster.
To show the quality of the clustering, =em/1= performs a Principal Component Analysis (PCA)
of the dataset and draws a scatter plot using the first two
principal components (the directions that account for as much of the variability in
the data as possible). The color of the points indicates their class.
Two plots are drawn, one with the original classes of the dataset and
one with the clusters assigned by the EM algorithm, so that the two groupings
can be compared.
The two groupings should be similar, up to a permutation of the cluster labels (and hence of the colors), since EM assigns arbitrary names to the clusters.
=em/1= prints the log-likelihood of the dataset (LL) at each iteration. The values increase, showing that
the clustering improves at each iteration.
References
----------
For EM clustering, see
Witten, Ian H., and Eibe Frank.
Data Mining: Practical machine learning tools and techniques.
Morgan Kaufmann, 2005.
and
https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm#Gaussian_mixture
For the iris dataset, see
https://en.wikipedia.org/wiki/Iris_flower_data_set
For principal component analysis, see
https://en.wikipedia.org/wiki/Principal_component_analysis
Author: Fabrizio Riguzzi
http://ds.ing.unife.it/~friguzzi/
Running the clustering
----------------------
EM clustering is started by calling =em/1= with the number of iterations, for example 10.
Running =em(10)= produces two scatter plots that should look similar, showing that EM clustering largely recovers the natural classes of the examples (Iris versicolor and Iris virginica overlap, so a perfect match should not be expected).
em(10).
Exercise
--------
As an exercise for the brave, develop a general version of the EM algorithm that is applicable to any dataset, not limited to three clusters and four features.
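As a hint, the initialization could be made parametric in the number of clusters. The following sketch (the predicate names =init_means/3= and =sample_means/2= are hypothetical, not part of the notebook's code) would replace the four hard-coded =findall/3= calls in =em/1= below, producing one list of sampled means per cluster:

init_means(NClusters,FeatureMeans,ClusterMeans):-
    length(ClusterMeans,NClusters),
    maplist(sample_means(FeatureMeans),ClusterMeans).

sample_means(FeatureMeans,Means):-
    maplist([FM,M]>>gauss(FM,1.0,M),FeatureMeans).

The rest of the algorithm would similarly replace the fixed lists of three clusters and four features with =maplist/foldl= over lists of arbitrary length.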
Code Description
----------------
The =em/1= predicate first collects all the data, then it performs a PCA on it and draws a scatter plot using the first two components, with the color indicating the real class (Iris setosa, Iris virginica or Iris versicolor).
Then it computes the mean and variance of the four features over all the data and samples the initial values of the parameters: the mean of the Gaussian for each feature and cluster is sampled from a Gaussian whose mean is the mean of that feature over all the data and whose variance is 1.0, while the initial variances are set to the overall feature variances. The distribution over the components is initially taken to be `[c1:0.2,c2:0.5,c3:0.3]`.
The EM cycle is entered using predicate =em_it/7=.
Then the individuals are classified into their most probable cluster (=find_most_likely/2=), the original labels are replaced by the new ones and a PCA is applied to the new dataset, drawing a scatter plot using the first two components and indicating the cluster with color.
:- use_module(library(clpfd), [transpose/2]).
:- use_module(library(apply)).
em(It):-
    % collect the data and compute per-feature means and variances
    findall([A,B,C,D,Type],data(A,B,C,D,Type),LA),
    data_stats(LA,DM1,DM2,DM3,DM4,V1,V2,V3,V4),
    pca(LA),
    % sample an initial mean for each cluster and feature
    findall(M,(between(1,3,_),gauss(DM1,1.0,M)),[M11,M21,M31]),
    findall(M,(between(1,3,_),gauss(DM2,1.0,M)),[M12,M22,M32]),
    findall(M,(between(1,3,_),gauss(DM3,1.0,M)),[M13,M23,M33]),
    findall(M,(between(1,3,_),gauss(DM4,1.0,M)),[M14,M24,M34]),
    em_it(0,It,LA,par([c1:0.2,c2:0.5,c3:0.3],
        [M11,M12,M13,M14,V1,V2,V3,V4],
        [M21,M22,M23,M24,V1,V2,V3,V4],
        [M31,M32,M33,M34,V1,V2,V3,V4]),_Par,_,LA1),
    % relabel each example with its most likely cluster and redraw
    find_most_likely(LA1,LPostCat),
    maplist(replace_cat,LA,LPostCat,LPost),
    pca(LPost).
=em_it/7= takes as input the current iteration number, the total number of iterations, the data, the current values of the parameters and the current weighted assignment of examples to clusters, and returns the final parameters and the final weighted assignment of examples to clusters.
=em_it/7= first performs the expectation step (=expect/4=) and then the maximization step (=maxim/2=).
=expect/4= takes as input the data and the current parameters and returns the weighted assignments and the log likelihood.
It first computes the weight with which each example belongs to each cluster, then it normalizes (=normal/6=) the weights so that they sum to 1 for each example and then computes the log likelihood (=log_lik/7=).
em_it(It,It,_LA,Par,Par,LAOut,LAOut):- !.
em_it(It0,It,LA,Par0,Par,_,LAOut):-
    expect(LA,Par0,LA1,LL),
    write('LL '),write(LL),nl,
    maxim(LA1,Par1),
    It1 is It0+1,
    em_it(It1,It,LA,Par1,Par,LA1,LAOut).
expect(LA,par([c1:P1,c2:P2,c3:P3],G1,G2,G3),[L1,L2,L3],LL):-
    maplist(weight(G1,P1),LA,L01),
    maplist(weight(G2,P2),LA,L02),
    maplist(weight(G3,P3),LA,L03),
    normal(L01,L02,L03,L1,L2,L3),
    log_lik(L01,L02,L03,P1,P2,P3,LL).
maxim([LA1,LA2,LA3],par([c1:P1,c2:P2,c3:P3],C1,C2,C3)):-
    stats(LA1,W1,C1),
    stats(LA2,W2,C2),
    stats(LA3,W3,C3),
    SW is W1+W2+W3,
    P1 is W1/SW,
    P2 is W2/SW,
    P3 is W3/SW.
=normal/6= normalizes the weights so that they sum up to one for each example.
=weight/4= takes as input the current parameters of the Gaussians of a cluster, the probability of the cluster and an example, and returns a pair Example-Weight. Note that the weight already includes the cluster probability.
=log_lik/7= takes as input the weighted examples for each cluster and the probability of each cluster and returns the log-likelihood of the data.
normal(L01,L02,L03,L1,L2,L3):-
    maplist(px,L01,L02,L03,L1,L2,L3).

px(X-W01,X-W02,X-W03,X-W1,X-W2,X-W3):-
    S is W01+W02+W03,
    W1 is W01/S,
    W2 is W02/S,
    W3 is W03/S.
weight([M1,M2,M3,M4,V1,V2,V3,V4],P,[A,B,C,D,_],[A,B,C,D]-W):-
    gauss_density_0(M1,V1,A,W1),
    gauss_density_0(M2,V2,B,W2),
    gauss_density_0(M3,V3,C,W3),
    gauss_density_0(M4,V4,D,W4),
    % independence assumption: the joint density is the product of the
    % four feature densities, times the cluster probability
    W is W1*W2*W3*W4*P.
log_lik(L1,L2,L3,P1,P2,P3,LL):-
    foldl(combine(P1,P2,P3),L1,L2,L3,0,LL).

% the weights computed by weight/4 already include the cluster
% probabilities, so they are summed directly
combine(_P1,_P2,_P3,_-W1,_-W2,_-W3,LL0,LL):-
    LLs is log(W1+W2+W3),
    LL is LL0+LLs.
=find_most_likely/2= takes as input a list with three elements, each being the list of weighted examples for a cluster, and returns the list of clusters that are most likely for each example.
=replace_cat/3= replaces the category (class) of each example with the given category =Cat=.
find_most_likely([L1,L2,L3],LC):-
    maplist(classify,L1,L2,L3,LC).

classify(_-W1,_-W2,_-W3,Cat):-
    find_max([W1,W2,W3],Cat).

find_max(Counts,MaxC):-
    max_list(Counts,MV),
    nth1(Max,Counts,MV),!,
    atomic_list_concat([c,Max],MaxC).

replace_cat([A,B,C,D,_],Cat,[A,B,C,D,Cat]).
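For instance, the query replace_cat([5.1,3.5,1.4,0.2,'Iris-setosa'],c1,E) yields E = [5.1,3.5,1.4,0.2,c1], relabeling the example with the cluster name.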
=stats/3= takes as input the list of weighted examples of a cluster and returns the total weight of the cluster together with the parameters of the Gaussian distributions of the features for that cluster. It uses the standard formulas for mean and variance, modified to take the weights into account.
stats(LA,SW,[M1,M2,M3,M4,V1,V2,V3,V4]):-
    maplist(component_weight,LA,CA,CB,CC,CD),
    weighted_mean(CA,M1,SW),
    weighted_mean(CB,M2,_),
    weighted_mean(CC,M3,_),
    weighted_mean(CD,M4,_),
    weighted_var(CA,M1,V1),
    weighted_var(CB,M2,V2),
    weighted_var(CC,M3,V3),
    weighted_var(CD,M4,V4).
weighted_var(L,M,Var):-
    foldl(agg_val_var(M),L,(0,0),(S,SW0)),
    SW is SW0,
    (   SW=:=0.0
    ->  write(zero_var),nl,
        Var=1.0
    ;   Var is S/SW
    ).

weighted_mean(L,M,SW):-
    foldl(agg_val,L,(0,0),(S,SW0)),
    SW is SW0,
    (   SW=:=0.0
    ->  write(zero_mean),nl,
        M is 0
    ;   M is S/SW
    ).

% the accumulators build arithmetic expression terms;
% they are evaluated once at the end with is/2
agg_val(V-N,(S,SW),(S+V*N,SW+N)).
agg_val_var(M,V-N,(S,SW),(S+(M-V)^2*N,SW+N)).

component_weight([A,B,C,D]-W,A-W,B-W,C-W,D-W).
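As a small illustration on toy weighted data (not taken from the dataset), the query

stats([[1.0,2.0,3.0,4.0]-0.5,[2.0,2.0,3.0,4.0]-0.5],SW,Par).

yields SW = 1.0 and Par = [1.5,2.0,3.0,4.0,0.25,0.0,0.0,0.0]: the first four elements are the weighted means of the features, the last four their weighted variances.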
=data_stats/9= computes the mean and the variance of each feature over the whole dataset.
data_stats(LA,M1,M2,M3,M4,V1,V2,V3,V4):-
    maplist(component,LA,CA,CB,CC,CD),
    mean(CA,M1),
    mean(CB,M2),
    mean(CC,M3),
    mean(CD,M4),
    variance(CA,M1,V1),
    variance(CB,M2,V2),
    variance(CC,M3,V3),
    variance(CD,M4,V4).

mean(L,M):-
    length(L,N),
    sum_list(L,S),
    M is S/N.

variance(L,M,Var):-
    length(L,N),
    foldl(agg_var(M),L,0,S),
    Var is S/N.

component([A,B,C,D,_],A,B,C,D).
agg_var(M,V,S,S+(M-V)^2).
=gauss_density_0/4= and =gauss_density/4= compute the value of the Gaussian density with a given mean and variance at a certain point; the former treats the case of zero variance specially.
=gauss/3= samples a value from a Gaussian with a given mean and variance using the Box-Muller transform.
gauss_density_0(M,V,X,W):-
    (   V=:=0.0
    ->  write(zero_var_gauss),
        W=0.0
    ;   gauss_density(M,V,X,W)
    ).

gauss_density(Mean,Variance,S,D):-
    StdDev is sqrt(Variance),
    D is 1/(StdDev*sqrt(2*pi))*exp(-(S-Mean)*(S-Mean)/(2*Variance)).
gauss(Mean,Variance,S):-
    number(Mean),!,
    random(U1),
    random(U2),
    % Box-Muller transform: two independent uniform samples
    % yield a standard normal sample S0
    R is sqrt(-2*log(U1)),
    Theta is 2*pi*U2,
    S0 is R*cos(Theta),
    StdDev is sqrt(Variance),
    S is StdDev*S0+Mean.
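The Box-Muller transform works because, if U1 and U2 are independent and uniform on (0,1), then sqrt(-2*log(U1))*cos(2*pi*U2) follows a standard normal distribution; =gauss/3= rescales this standard sample with the desired standard deviation and adds the mean. For example, the query

gauss(5.0,2.0,S).

binds S to a fresh sample from a Gaussian with mean 5.0 and variance 2.0; repeated calls return different values.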
=pca/1= computes the PCA of the dataset passed as input and draws a scatter plot of the examples using the first two principal components, with color indicating the class.
=pca/1= uses R in two ways. First, it uses the =prcomp= function of R to compute the principal components of the dataset. Then it uses =qplot= to draw a scatter plot of the dataset.
:- <- library("ggplot2").
pca(LA):-
    length(LA,NP),
    maplist(add_cat,LA,LCat,L),
    L=[H|_],
    length(H,Comp),
    % flatten the examples into an R matrix with one row per example
    append(L,LLin),
    D =.. [c|LLin],
    data<- matrix(D,ncol=Comp,byrow='TRUE'),
    pc<- prcomp(data),
    Data0<- pc["x"],
    Data0=[Data1],
    % regroup the flat list of scores into one list per component
    foldl(getn(NP),Data2,Data1,[]),!,
    transpose(Data2,Data),
    maplist(getx,Data,X),
    maplist(gety,Data,Y),
    x<- X,
    y<- Y,
    class<- LCat,
    <-qplot(x,y,colour=class),
    r_download,
    nl.
getn(N,LN,L,Rest):-
    length(LN,N),
    append(LN,Rest,L).

getx([X,_,_,_],X).
gety([_,Y,_,_],Y).
add_cat([X,Y,Z,W,C],C,[X,Y,Z,W]).
The Iris dataset is represented using =data/5=.
% Iris dataset
data(5.1,3.5,1.4,0.2,'Iris-setosa').
data(4.9,3.0,1.4,0.2,'Iris-setosa').
data(4.7,3.2,1.3,0.2,'Iris-setosa').
data(4.6,3.1,1.5,0.2,'Iris-setosa').
data(5.0,3.6,1.4,0.2,'Iris-setosa').
data(5.4,3.9,1.7,0.4,'Iris-setosa').
data(4.6,3.4,1.4,0.3,'Iris-setosa').
data(5.0,3.4,1.5,0.2,'Iris-setosa').
data(4.4,2.9,1.4,0.2,'Iris-setosa').
data(4.9,3.1,1.5,0.1,'Iris-setosa').
data(5.4,3.7,1.5,0.2,'Iris-setosa').
data(4.8,3.4,1.6,0.2,'Iris-setosa').
data(4.8,3.0,1.4,0.1,'Iris-setosa').
data(4.3,3.0,1.1,0.1,'Iris-setosa').
data(5.8,4.0,1.2,0.2,'Iris-setosa').
data(5.7,4.4,1.5,0.4,'Iris-setosa').
data(5.4,3.9,1.3,0.4,'Iris-setosa').
data(5.1,3.5,1.4,0.3,'Iris-setosa').
data(5.7,3.8,1.7,0.3,'Iris-setosa').
data(5.1,3.8,1.5,0.3,'Iris-setosa').
data(5.4,3.4,1.7,0.2,'Iris-setosa').
data(5.1,3.7,1.5,0.4,'Iris-setosa').
data(4.6,3.6,1.0,0.2,'Iris-setosa').
data(5.1,3.3,1.7,0.5,'Iris-setosa').
data(4.8,3.4,1.9,0.2,'Iris-setosa').
data(5.0,3.0,1.6,0.2,'Iris-setosa').
data(5.0,3.4,1.6,0.4,'Iris-setosa').
data(5.2,3.5,1.5,0.2,'Iris-setosa').
data(5.2,3.4,1.4,0.2,'Iris-setosa').
data(4.7,3.2,1.6,0.2,'Iris-setosa').
data(4.8,3.1,1.6,0.2,'Iris-setosa').
data(5.4,3.4,1.5,0.4,'Iris-setosa').
data(5.2,4.1,1.5,0.1,'Iris-setosa').
data(5.5,4.2,1.4,0.2,'Iris-setosa').
data(4.9,3.1,1.5,0.1,'Iris-setosa').
data(5.0,3.2,1.2,0.2,'Iris-setosa').
data(5.5,3.5,1.3,0.2,'Iris-setosa').
data(4.9,3.1,1.5,0.1,'Iris-setosa').
data(4.4,3.0,1.3,0.2,'Iris-setosa').
data(5.1,3.4,1.5,0.2,'Iris-setosa').
data(5.0,3.5,1.3,0.3,'Iris-setosa').
data(4.5,2.3,1.3,0.3,'Iris-setosa').
data(4.4,3.2,1.3,0.2,'Iris-setosa').
data(5.0,3.5,1.6,0.6,'Iris-setosa').
data(5.1,3.8,1.9,0.4,'Iris-setosa').
data(4.8,3.0,1.4,0.3,'Iris-setosa').
data(5.1,3.8,1.6,0.2,'Iris-setosa').
data(4.6,3.2,1.4,0.2,'Iris-setosa').
data(5.3,3.7,1.5,0.2,'Iris-setosa').
data(5.0,3.3,1.4,0.2,'Iris-setosa').
data(7.0,3.2,4.7,1.4,'Iris-versicolor').
data(6.4,3.2,4.5,1.5,'Iris-versicolor').
data(6.9,3.1,4.9,1.5,'Iris-versicolor').
data(5.5,2.3,4.0,1.3,'Iris-versicolor').
data(6.5,2.8,4.6,1.5,'Iris-versicolor').
data(5.7,2.8,4.5,1.3,'Iris-versicolor').
data(6.3,3.3,4.7,1.6,'Iris-versicolor').
data(4.9,2.4,3.3,1.0,'Iris-versicolor').
data(6.6,2.9,4.6,1.3,'Iris-versicolor').
data(5.2,2.7,3.9,1.4,'Iris-versicolor').
data(5.0,2.0,3.5,1.0,'Iris-versicolor').
data(5.9,3.0,4.2,1.5,'Iris-versicolor').
data(6.0,2.2,4.0,1.0,'Iris-versicolor').
data(6.1,2.9,4.7,1.4,'Iris-versicolor').
data(5.6,2.9,3.6,1.3,'Iris-versicolor').
data(6.7,3.1,4.4,1.4,'Iris-versicolor').
data(5.6,3.0,4.5,1.5,'Iris-versicolor').
data(5.8,2.7,4.1,1.0,'Iris-versicolor').
data(6.2,2.2,4.5,1.5,'Iris-versicolor').
data(5.6,2.5,3.9,1.1,'Iris-versicolor').
data(5.9,3.2,4.8,1.8,'Iris-versicolor').
data(6.1,2.8,4.0,1.3,'Iris-versicolor').
data(6.3,2.5,4.9,1.5,'Iris-versicolor').
data(6.1,2.8,4.7,1.2,'Iris-versicolor').
data(6.4,2.9,4.3,1.3,'Iris-versicolor').
data(6.6,3.0,4.4,1.4,'Iris-versicolor').
data(6.8,2.8,4.8,1.4,'Iris-versicolor').
data(6.7,3.0,5.0,1.7,'Iris-versicolor').
data(6.0,2.9,4.5,1.5,'Iris-versicolor').
data(5.7,2.6,3.5,1.0,'Iris-versicolor').
data(5.5,2.4,3.8,1.1,'Iris-versicolor').
data(5.5,2.4,3.7,1.0,'Iris-versicolor').
data(5.8,2.7,3.9,1.2,'Iris-versicolor').
data(6.0,2.7,5.1,1.6,'Iris-versicolor').
data(5.4,3.0,4.5,1.5,'Iris-versicolor').
data(6.0,3.4,4.5,1.6,'Iris-versicolor').
data(6.7,3.1,4.7,1.5,'Iris-versicolor').
data(6.3,2.3,4.4,1.3,'Iris-versicolor').
data(5.6,3.0,4.1,1.3,'Iris-versicolor').
data(5.5,2.5,4.0,1.3,'Iris-versicolor').
data(5.5,2.6,4.4,1.2,'Iris-versicolor').
data(6.1,3.0,4.6,1.4,'Iris-versicolor').
data(5.8,2.6,4.0,1.2,'Iris-versicolor').
data(5.0,2.3,3.3,1.0,'Iris-versicolor').
data(5.6,2.7,4.2,1.3,'Iris-versicolor').
data(5.7,3.0,4.2,1.2,'Iris-versicolor').
data(5.7,2.9,4.2,1.3,'Iris-versicolor').
data(6.2,2.9,4.3,1.3,'Iris-versicolor').
data(5.1,2.5,3.0,1.1,'Iris-versicolor').
data(5.7,2.8,4.1,1.3,'Iris-versicolor').
data(6.3,3.3,6.0,2.5,'Iris-virginica').
data(5.8,2.7,5.1,1.9,'Iris-virginica').
data(7.1,3.0,5.9,2.1,'Iris-virginica').
data(6.3,2.9,5.6,1.8,'Iris-virginica').
data(6.5,3.0,5.8,2.2,'Iris-virginica').
data(7.6,3.0,6.6,2.1,'Iris-virginica').
data(4.9,2.5,4.5,1.7,'Iris-virginica').
data(7.3,2.9,6.3,1.8,'Iris-virginica').
data(6.7,2.5,5.8,1.8,'Iris-virginica').
data(7.2,3.6,6.1,2.5,'Iris-virginica').
data(6.5,3.2,5.1,2.0,'Iris-virginica').
data(6.4,2.7,5.3,1.9,'Iris-virginica').
data(6.8,3.0,5.5,2.1,'Iris-virginica').
data(5.7,2.5,5.0,2.0,'Iris-virginica').
data(5.8,2.8,5.1,2.4,'Iris-virginica').
data(6.4,3.2,5.3,2.3,'Iris-virginica').
data(6.5,3.0,5.5,1.8,'Iris-virginica').
data(7.7,3.8,6.7,2.2,'Iris-virginica').
data(7.7,2.6,6.9,2.3,'Iris-virginica').
data(6.0,2.2,5.0,1.5,'Iris-virginica').
data(6.9,3.2,5.7,2.3,'Iris-virginica').
data(5.6,2.8,4.9,2.0,'Iris-virginica').
data(7.7,2.8,6.7,2.0,'Iris-virginica').
data(6.3,2.7,4.9,1.8,'Iris-virginica').
data(6.7,3.3,5.7,2.1,'Iris-virginica').
data(7.2,3.2,6.0,1.8,'Iris-virginica').
data(6.2,2.8,4.8,1.8,'Iris-virginica').
data(6.1,3.0,4.9,1.8,'Iris-virginica').
data(6.4,2.8,5.6,2.1,'Iris-virginica').
data(7.2,3.0,5.8,1.6,'Iris-virginica').
data(7.4,2.8,6.1,1.9,'Iris-virginica').
data(7.9,3.8,6.4,2.0,'Iris-virginica').
data(6.4,2.8,5.6,2.2,'Iris-virginica').
data(6.3,2.8,5.1,1.5,'Iris-virginica').
data(6.1,2.6,5.6,1.4,'Iris-virginica').
data(7.7,3.0,6.1,2.3,'Iris-virginica').
data(6.3,3.4,5.6,2.4,'Iris-virginica').
data(6.4,3.1,5.5,1.8,'Iris-virginica').
data(6.0,3.0,4.8,1.8,'Iris-virginica').
data(6.9,3.1,5.4,2.1,'Iris-virginica').
data(6.7,3.1,5.6,2.4,'Iris-virginica').
data(6.9,3.1,5.1,2.3,'Iris-virginica').
data(5.8,2.7,5.1,1.9,'Iris-virginica').
data(6.8,3.2,5.9,2.3,'Iris-virginica').
data(6.7,3.3,5.7,2.5,'Iris-virginica').
data(6.7,3.0,5.2,2.3,'Iris-virginica').
data(6.3,2.5,5.0,1.9,'Iris-virginica').
data(6.5,3.0,5.2,2.0,'Iris-virginica').
data(6.2,3.4,5.4,2.3,'Iris-virginica').
data(5.9,3.0,5.1,1.8,'Iris-virginica').