Posted to dev@mahout.apache.org by "Jeff Eastman (JIRA)" <ji...@apache.org> on 2008/11/13 05:51:44 UTC
[jira] Issue Comment Edited: (MAHOUT-30) dirichlet process implementation
[ https://issues.apache.org/jira/browse/MAHOUT-30?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=12647186#action_12647186 ]
jeastman edited comment on MAHOUT-30 at 11/12/08 8:50 PM:
--------------------------------------------------------------
I refactored again and was able to eliminate materializing the posterior {{data}} sets by adding {{observe()}} and {{computeParameters()}} operations to {{Model}}. The idea is that all models begin in their prior state and are asked to observe each sample that is assigned to them. Then, before {{pdf()}} is called on them in the next iteration, a call to {{computeParameters()}} finalizes the parameters once and turns the model into a posterior model. I also compute {{counts}} on the fly, eliminating the need to materialize {{z}} altogether. I hope I didn't throw the baby out with the bath water.
Finally, I introduced a {{DirichletState}} bean to hold the models, the Dirichlet distribution and the mixture, simplifying the arguments and, I think, fixing a bug in the earlier refactoring. The algorithm runs over 10,000 points and produces the following outputs (prior() indicates a model with no observations, n is the number of observations, m the mean and sd the standard deviation):
Generating 4000 samples m=[1.0, 1.0] sd=3.0
Generating 3000 samples m=[1.0, 0.0] sd=0.1
Generating 3000 samples m=[0.0, 1.0] sd=0.1
* sample[0]= [prior(), normal(n=6604 m=[0.67, 0.63] sd=1.11), normal(n=86 m=[0.77, 2.81] sd=2.15), prior(), normal(n=242 m=[2.89, 1.67] sd=2.14), normal(n=2532 m=[0.53, 0.55] sd=0.69), normal(n=339 m=[0.99, 1.70] sd=2.18), normal(n=77 m=[0.53, 0.47] sd=0.51), normal(n=119 m=[0.36, 0.47] sd=2.85), normal(n=1 m=[0.00, 0.00] sd=0.33)]
* sample[1]= [prior(), normal(n=6626 m=[0.62, 0.54] sd=0.91), normal(n=137 m=[0.51, 2.99] sd=1.56), normal(n=2 m=[0.57, 0.25] sd=0.70), normal(n=506 m=[2.55, 0.93] sd=1.73), normal(n=1573 m=[0.38, 0.60] sd=0.50), normal(n=848 m=[0.81, 1.59] sd=2.11), normal(n=67 m=[0.76, 0.31] sd=0.45), normal(n=240 m=[0.73, 0.31] sd=2.24), normal(n=1 m=[0.00, 0.00] sd=0.98)]
* sample[2]= [prior(), normal(n=5842 m=[0.67, 0.39] sd=0.73), normal(n=157 m=[0.73, 3.12] sd=1.14), prior(), normal(n=655 m=[2.32, 0.64] sd=1.60), normal(n=1439 m=[0.00, 1.00] sd=0.33), normal(n=1439 m=[0.78, 1.53] sd=1.89), normal(n=66 m=[0.96, -0.04] sd=0.24), normal(n=399 m=[0.63, -0.03] sd=1.99), normal(n=3 m=[-0.07, 0.76] sd=0.41)]
{code:title=Model}
/**
 * A model is a probability distribution over observed data points and allows
 * the probability of any data point to be computed.
 */
public interface Model<Observation> {
  /**
   * Observe the given observation, retaining information about it
   *
   * @param x an Observation from the posterior
   */
  public abstract void observe(Observation x);

  /**
   * Compute a new set of posterior parameters based upon the Observations
   * that have been observed since my creation
   */
  public abstract void computeParameters();

  /**
   * Return the probability that the observation is described by this model
   *
   * @param x an Observation from the posterior
   * @return the probability that x is described by this model
   */
  public abstract double pdf(Observation x);
}
{code}
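To make the lifecycle concrete, here is a minimal 1-D sketch of a model that could satisfy this contract (the class name {{NormalModel}}, its fields, and the use of plain doubles instead of Mahout Vectors are illustrative assumptions, not part of the patch):

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical 1-D normal model illustrating the observe()/computeParameters()/
// pdf() lifecycle. It starts as a prior (mean 0, sd 1) and becomes a posterior
// model once computeParameters() runs over its retained observations.
public class NormalModel {
  private final List<Double> observations = new ArrayList<Double>();
  private double mean = 0.0;
  private double sd = 1.0;

  // retain the observation for the next parameter computation
  public void observe(double x) {
    observations.add(x);
  }

  // finalize mean and sd from the retained observations, then discard them
  public void computeParameters() {
    if (observations.isEmpty())
      return; // no observations: remain a prior() model
    double sum = 0;
    for (double x : observations)
      sum += x;
    mean = sum / observations.size();
    double ss = 0;
    for (double x : observations)
      ss += (x - mean) * (x - mean);
    sd = Math.max(Math.sqrt(ss / observations.size()), 1e-9);
    observations.clear();
  }

  // normal density of x under the current parameters
  public double pdf(double x) {
    double z = (x - mean) / sd;
    return Math.exp(-0.5 * z * z) / (sd * Math.sqrt(2.0 * Math.PI));
  }
}
```

After observing samples and calling {{computeParameters()}}, the density peaks at the sample mean, which is all {{pdf()}} needs to provide for the clustering loop below.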
{code:title=DirichletCluster}
/**
 * Initialize the variables and run the iterations to assign the sample data
 * points to a computed number of clusters
 *
 * @return a List<List<Model<Observation>>> of the observed models
 */
public List<List<Model<Observation>>> dirichletCluster() {
  DirichletState<Observation> state = initializeState();
  // create a posterior sample list to collect results
  List<List<Model<Observation>>> clusterSamples = new ArrayList<List<Model<Observation>>>();
  // now iterate
  for (int iteration = 0; iteration < maxIterations; iteration++)
    iterate(state, iteration, clusterSamples);
  return clusterSamples;
}

/**
 * Initialize the state of the computation
 *
 * @return the DirichletState
 */
private DirichletState<Observation> initializeState() {
  // get initial prior models
  List<Model<Observation>> models = createPriorModels();
  // create the initial distribution
  DirichletDistribution distribution = new DirichletDistribution(maxClusters,
      alpha_0, dist);
  // mixture parameters are sampled from the Dirichlet distribution
  Vector mixture = distribution.sample();
  return new DirichletState<Observation>(models, distribution, mixture);
}

/**
 * Create a list of prior models
 *
 * @return a List<Model<Observation>> of prior models
 */
private List<Model<Observation>> createPriorModels() {
  List<Model<Observation>> models = new ArrayList<Model<Observation>>();
  for (int k = 0; k < maxClusters; k++) {
    models.add(modelFactory.sampleFromPrior());
  }
  return models;
}

/**
 * Perform one iteration of the clustering process, updating the state for the next iteration
 *
 * @param state the DirichletState<Observation> of this iteration
 * @param iteration the int iteration number
 * @param clusterSamples a List<List<Model<Observation>>> that will be modified in each iteration
 */
private void iterate(DirichletState<Observation> state, int iteration,
    List<List<Model<Observation>>> clusterSamples) {
  // create new prior models
  List<Model<Observation>> newModels = createPriorModels();
  // initialize vector of membership counts
  Vector counts = new DenseVector(maxClusters);
  counts.assign(alpha_0 / maxClusters);
  // iterate over the samples
  for (Observation x : sampleData) {
    // compute vector of probabilities x is described by each model
    Vector pi = computeProbabilities(state, x);
    // then pick one cluster by sampling a Multinomial distribution based upon them
    // see: http://en.wikipedia.org/wiki/Multinomial_distribution
    int model = dist.rmultinom(pi);
    // ask the selected model to observe the datum
    newModels.get(model).observe(x);
    // record counts for the model
    counts.set(model, counts.get(model) + 1);
  }
  // compute new model parameters based upon observations
  for (Model<Observation> m : newModels)
    m.computeParameters();
  // update the state from the new models and counts
  state.distribution.setAlpha(counts);
  state.mixture = state.distribution.sample();
  state.models = newModels;
  // periodically add models to cluster samples after getting started
  if ((iteration > burnin) && (iteration % thin == 0))
    clusterSamples.add(state.models);
}

/**
 * Compute a normalized vector of probabilities that x is described
 * by each model using the mixture and the model pdfs
 *
 * @param state the DirichletState<Observation> of this iteration
 * @param x an Observation
 * @return the Vector of probabilities
 */
private Vector computeProbabilities(DirichletState<Observation> state,
    Observation x) {
  Vector pi = new DenseVector(maxClusters);
  double max = 0;
  for (int k = 0; k < maxClusters; k++) {
    double p = state.mixture.get(k) * state.models.get(k).pdf(x);
    pi.set(k, p);
    if (max < p)
      max = p;
  }
  // scale the probabilities by the largest observed value; rmultinom normalizes
  pi.assign(new TimesFunction(), 1.0 / max);
  return pi;
}
{code}
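The {{dist.rmultinom(pi)}} call picks one cluster index with probability proportional to the entries of {{pi}}, which is why {{computeProbabilities()}} only needs to scale them rather than normalize to sum 1. A self-contained sketch of such a sampler (the name follows the code above, but the array-based signature and {{Random}} argument are assumptions for illustration):

```java
import java.util.Random;

// Hypothetical sketch of sampling one index from an unnormalized probability
// vector, as dist.rmultinom(pi) does above: draw r uniformly in [0, total)
// and walk the cumulative sums until r is covered.
public class MultinomialSampler {
  public static int rmultinom(double[] pi, Random rng) {
    double total = 0.0;
    for (double p : pi)
      total += p;
    double r = rng.nextDouble() * total;
    double cumulative = 0.0;
    for (int k = 0; k < pi.length; k++) {
      cumulative += pi[k];
      if (r < cumulative)
        return k;
    }
    return pi.length - 1; // guard against floating-point rounding
  }
}
```

Because the draw divides through by the running total implicitly, scaling every entry of {{pi}} by the same constant leaves the sampling distribution unchanged.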
> dirichlet process implementation
> --------------------------------
>
> Key: MAHOUT-30
> URL: https://issues.apache.org/jira/browse/MAHOUT-30
> Project: Mahout
> Issue Type: New Feature
> Components: Clustering
> Reporter: Isabel Drost
> Attachments: MAHOUT-30.patch
>
>
> Copied over from original issue:
> > Further extension can also be made by assuming an infinite mixture model. The implementation is only slightly more difficult and the result is a (nearly)
> > non-parametric clustering algorithm.
--
This message is automatically generated by JIRA.