18 #ifndef __itkGaussianMixtureModelComponent_txx
19 #define __itkGaussianMixtureModelComponent_txx
26 namespace Statistics {
// Fragment of a template member definition. The signature and braces are
// not visible in this excerpt; presumably this is the
// GaussianMixtureModelComponent constructor (name taken from the include
// guard) -- TODO confirm against the full file.
28 template<
class TSample >
// Create the helper objects this component owns: a mean estimator, a
// covariance estimator and the Gaussian density function.
32 m_MeanEstimator = MeanEstimatorType::New();
33 m_CovarianceEstimator = CovarianceEstimatorType::New();
34 m_GaussianDensityFunction = NativeMembershipFunctionType::New();
// Tail of a call whose opening line is not visible here; it receives the
// raw pointer held by m_GaussianDensityFunction.
36 m_GaussianDensityFunction.GetPointer());
// Start from an identity covariance matrix.
38 m_Covariance.SetIdentity();
// Fragment of a template member definition (signature and braces not
// visible); the call to Superclass::PrintSelf indicates this is the
// PrintSelf override. It writes the component's state to the stream `os`,
// one field per line, each prefixed with `indent`.
41 template<
class TSample >
// Let the superclass print its own members first.
46 Superclass::PrintSelf(os, indent);
// Current parameter values.
48 os << indent <<
"Mean: " << m_Mean << std::endl;
49 os << indent <<
"Covariance: " << m_Covariance << std::endl;
// Helper objects (streamed via their operator<<).
50 os << indent <<
"Mean Estimator: " << m_MeanEstimator << std::endl;
51 os << indent <<
"Covariance Estimator: " << m_CovarianceEstimator << std::endl;
52 os << indent <<
"GaussianDensityFunction: " << m_GaussianDensityFunction << std::endl;
// Fragment of a template member definition (signature and braces not
// visible); the call to Superclass::SetSample indicates this is SetSample.
// It wires the given sample into the estimators and (re)initialises the
// mean/covariance storage for the sample's measurement-vector size.
55 template<
class TSample >
// Store the sample in the superclass.
60 Superclass::SetSample(sample);
// Both estimators operate on the same input sample.
62 m_MeanEstimator->SetInputSample(sample);
63 m_CovarianceEstimator->SetInputSample(sample);
// Both estimators share the same weights. The declaration of `weights` is
// not visible in this excerpt.
66 m_MeanEstimator->SetWeights(weights);
67 m_CovarianceEstimator->SetWeights(weights);
// Tail of a statement whose first line is not visible; presumably it
// initialises `measurementVectorLength` -- TODO confirm.
69 sample->GetMeasurementVectorSize();
70 m_GaussianDensityFunction->SetMeasurementVectorSize(
71 measurementVectorLength );
// The covariance is a square matrix with one row/column per measurement
// component.
73 m_Covariance.SetSize( measurementVectorLength, measurementVectorLength );
// Fill mean and covariance with NumericTraits<double>::NonpositiveMin() as
// their initial content.
74 m_Mean.Fill(NumericTraits< double >::NonpositiveMin());
75 m_Covariance.Fill(NumericTraits< double >::NonpositiveMin());
// The covariance estimator and the density function are both handed a
// pointer to this object's m_Mean member.
76 m_CovarianceEstimator->SetMean(&m_Mean);
77 m_GaussianDensityFunction->SetMean(&m_Mean);
// Fragment of a template member definition (signature and braces not
// visible); the call to Superclass::SetParameters indicates this is
// SetParameters. It copies values from the flat `parameters` array into
// m_Mean and then m_Covariance, touching only entries that differ.
80 template<
class TSample >
// Store the full parameter array in the superclass.
85 Superclass::SetParameters(parameters);
// Running index into the flat `parameters` array. The statements that
// advance it are not visible in this excerpt.
87 unsigned int paramIndex = 0;
// Tail of a statement whose first line is not visible; presumably it
// initialises `measurementVectorSize` -- TODO confirm.
93 this->GetSample()->GetMeasurementVectorSize();
// Mean entries: overwrite a component only when the incoming value differs.
95 for ( i = 0; i < measurementVectorSize; i++)
97 if ( m_Mean[i] != parameters[paramIndex] )
99 m_Mean[i] = parameters[paramIndex];
// Covariance entries, visited with i as the outer index and j as the inner
// index; again only differing elements are written.
105 for ( i = 0; i < measurementVectorSize; i++ )
107 for ( j = 0; j < measurementVectorSize; j++ )
109 if ( m_Covariance.GetVnlMatrix().get(i, j) !=
110 parameters[paramIndex] )
112 m_Covariance.GetVnlMatrix().put(i, j, parameters[paramIndex]);
// Hand the density function a pointer to the updated covariance member.
118 m_GaussianDensityFunction->SetCovariance(&m_Covariance);
// Record the modification state. The declaration and updates of `changed`
// are not visible here; presumably it is set when any entry was
// overwritten -- TODO confirm.
119 this->AreParametersModified(changed);
// Fragment of a template member definition. The signature, braces and
// return statement are not visible; NOTE(review): this looks like the
// routine that measures how far the stored parameters are from the current
// estimator outputs (presumably CalculateParametersChange) -- confirm.
// Visible logic: sum the squared differences over all mean components and
// all covariance elements, then take the square root of that sum.
123 template<
class TSample >
// Copy the current outputs of the two estimators.
130 MeanType meanEstimate = *(m_MeanEstimator->GetOutput());
131 CovarianceType covEstimate = *(m_CovarianceEstimator->GetOutput());
// Accumulator for the sum of squared differences.
134 double changes = 0.0;
// Tail of a statement whose first line is not visible; presumably it
// initialises `measurementVectorSize` -- TODO confirm.
136 this->GetSample()->GetMeasurementVectorSize();
// Mean contribution.
138 for ( i = 0; i < measurementVectorSize; i++)
140 temp = m_Mean[i] - meanEstimate[i];
141 changes += temp * temp;
// Covariance contribution. The right-hand operand of the subtraction is on
// a line that is not visible in this excerpt.
144 for ( i = 0; i < measurementVectorSize; i++ )
146 for ( j = 0; j < measurementVectorSize; j++ )
148 temp = m_Covariance.GetVnlMatrix().get(i, j) -
150 changes += temp * temp;
// Square root of the accumulated sum of squares.
154 changes = vcl_sqrt(changes);
// Fragment of a template member definition. The signature, braces, several
// declarations and the bodies of the threshold checks are not visible;
// NOTE(review): presumably this is GenerateData -- confirm.
// Visible logic: re-run the mean and covariance estimators, compare each
// new value with the stored one element by element, and where the code
// adopts the new estimate, copy it into `parameters` and flag the
// parameters as modified.
158 template<
class TSample >
// Tail of a statement whose first line is not visible; presumably it
// initialises `measurementVectorSize` -- TODO confirm.
164 this->GetSample()->GetMeasurementVectorSize();
// Clear the modified flag before re-estimating.
166 this->AreParametersModified(
false);
// Recompute the mean estimate.
168 m_MeanEstimator->Update();
// Local flag; the statements that set it are not visible in this excerpt.
173 bool changed =
false;
// Per-component test: sqrt(temp * temp) is the absolute difference between
// the stored and the estimated mean component, compared against
// GetMinimalParametersChange(). The action taken when the threshold is
// exceeded is not visible (presumably sets `changed`) -- TODO confirm.
177 MeanType meanEstimate = *(m_MeanEstimator->GetOutput());
178 for ( i = 0; i < measurementVectorSize; i++)
180 temp = m_Mean[i] - meanEstimate[i];
181 changes = temp * temp;
182 changes = vcl_sqrt(changes);
183 if ( changes > this->GetMinimalParametersChange() )
// Adopt the new mean, mirror it into `parameters`, and flag the parameters
// as modified. The condition guarding these statements is not visible.
192 m_Mean = *(m_MeanEstimator->GetOutput());
193 for ( i = 0; i < measurementVectorSize; i++)
195 parameters[paramIndex] = meanEstimate[i];
198 this->AreParametersModified(
true);
// Reposition the index just past the mean entries before the covariance
// entries are handled.
202 paramIndex = measurementVectorSize;
// Recompute the covariance estimate and copy it.
205 m_CovarianceEstimator->Update();
206 CovarianceType covEstimate = *(m_CovarianceEstimator->GetOutput());
// Same per-element threshold test for the covariance. The line that
// computes `temp` is not visible in this excerpt.
208 for ( i = 0; i < measurementVectorSize; i++ )
210 for ( j = 0; j < measurementVectorSize; j++ )
214 changes = temp * temp;
215 changes = vcl_sqrt(changes);
216 if ( changes > this->GetMinimalParametersChange() )
// Adopt the new covariance, mirror it into `parameters` (i outer, j inner),
// and flag the parameters as modified. The guarding condition is not
// visible.
225 m_Covariance = *(m_CovarianceEstimator->GetOutput());
226 for ( i = 0; i < measurementVectorSize; i++ )
228 for ( j = 0; j < measurementVectorSize; j++ )
230 parameters[paramIndex] = covEstimate.
GetVnlMatrix().get(i, j);
234 this->AreParametersModified(
true);
// Store the parameter array in the superclass.
237 Superclass::SetParameters(parameters);
// Hand the density function a pointer to the covariance member.
239 m_GaussianDensityFunction->SetCovariance( &m_Covariance );