17 #ifndef __itkSampleClassifierWithMask_txx
18 #define __itkSampleClassifierWithMask_txx
23 namespace Statistics {
// Presumably the SampleClassifierWithMask constructor. Its class-qualified
// signature and braces are missing from this extract (the embedded original
// line numbers jump from 25 to 29).
// The visible statement sets m_OtherClassLabel to 0. That is the label the
// classification pass assigns to samples whose mask value is not found in
// m_SelectedClassLabels.
25 template<
class TSample,
class TMaskSample >
29 m_OtherClassLabel = 0;
// PrintSelf fragment. The signature and braces are missing from this extract.
// After delegating to Superclass::PrintSelf(os, indent), it writes:
//   - "Mask: " followed by m_Mask, or "not set." when the mask is null;
//   - "SelectedClassLabels: " followed by each selected label, each preceded
//     by a space;
//   - "OtherClassLabel: " followed by m_OtherClassLabel.
// NOTE(review): no std::endl is visible after the selected-label loop.
// Original lines 54-55 are missing here, so confirm whether a newline is
// emitted there.
33 template<
class TSample,
class TMaskSample >
38 Superclass::PrintSelf(os,indent);
40 os << indent <<
"Mask: ";
41 if ( m_Mask.IsNotNull() )
43 os << m_Mask << std::endl;
// Null-mask case (the enclosing `else` is on a line missing from this extract).
47 os <<
"not set." << std::endl;
50 os << indent <<
"SelectedClassLabels: ";
51 for (
unsigned int i = 0; i < m_SelectedClassLabels.size(); ++i )
53 os <<
" " << m_SelectedClassLabels[i];
56 os << indent <<
"OtherClassLabel: " << m_OtherClassLabel << std::endl;
// Only the template header of this definition survives in this extract.
// Its body (original lines 60-69) is not shown here.
59 template<
class TSample,
class TMaskSample >
// Classification pass, presumably GenerateData(). The signature, braces, loop
// headers, iterator increments and several statements are missing from this
// extract (see the gaps in the embedded original line numbers).
//
// Visible behavior:
//  * Throws via itkExceptionMacro when the mask sample and the input sample
//    differ in Size().
//  * For each input measurement vector, component [0] of the current mask
//    measurement is searched for in m_SelectedClassLabels.
//      - Found: every membership function is evaluated on the measurement
//        vector into discriminantScores. rule->Evaluate(discriminantScores)
//        then yields the class label.
//      - Not found (first loop): the label is m_OtherClassLabel.
//  * The label and iter.GetInstanceIdentifier() are passed to
//    output->AddInstance().
//
// NOTE(review): `rule` is used below. The statement that declares it starts
// on a missing line; only its tail `this->GetDecisionRule();` is present.
// NOTE(review): the check comparing classLabels.size() with
// GetNumberOfMembershipFunctions() has no visible body. A second, similar
// loop follows further down, and the code choosing between the two loops is
// not shown.
70 template<
class TSample,
class TMaskSample >
76 typename TSample::ConstIterator iter = this->GetSample()->Begin();
77 typename TSample::ConstIterator end = this->GetSample()->End();
78 typename TSample::MeasurementVectorType measurements;
// NOTE(review): despite the m_ prefix, m_iter is a local iterator over the
// mask sample, not a data member. It is presumably advanced in lockstep with
// `iter`, but the increments are not visible here; confirm.
80 typename TMaskSample::Iterator m_iter = this->GetMask()->Begin();
84 std::vector< double > discriminantScores;
85 unsigned int numberOfClasses = this->GetNumberOfClasses();
86 discriminantScores.resize(numberOfClasses);
88 unsigned int classLabel;
90 this->GetDecisionRule();
91 typename Superclass::ClassLabelVectorType classLabels =
92 this->GetMembershipFunctionClassLabels();
// The mask must have one entry per input sample.
94 if ( this->GetMask()->
Size() != this->GetSample()->
Size() )
96 itkExceptionMacro(
"The sizes of the mask sample and the input sample do not match.");
99 if ( classLabels.size() != this->GetNumberOfMembershipFunctions() )
103 measurements = iter.GetMeasurementVector();
// Only samples whose mask value is a selected label are scored.
// std::find is a linear scan of m_SelectedClassLabels for every sample.
104 if ( std::find(m_SelectedClassLabels.begin(),
105 m_SelectedClassLabels.end(),
106 m_iter.GetMeasurementVector()[0]) !=
107 m_SelectedClassLabels.end() )
109 for (i = 0; i < numberOfClasses; i++)
111 discriminantScores[i] =
112 (this->GetMembershipFunction(i))->Evaluate(measurements);
114 classLabel = rule->Evaluate(discriminantScores);
// Unselected mask value: the sample gets the catch-all label.
118 classLabel = m_OtherClassLabel;
120 output->
AddInstance(classLabel, iter.GetInstanceIdentifier());
// Second classification loop. It repeats the same mask test and scoring
// as above. Its header, else-branch and the statements that finish at
// original lines 142 and 147 are only partly present in this extract.
129 measurements = iter.GetMeasurementVector();
130 if ( std::find(m_SelectedClassLabels.begin(),
131 m_SelectedClassLabels.end(),
132 m_iter.GetMeasurementVector()[0]) !=
133 m_SelectedClassLabels.end() )
135 for (i = 0; i < numberOfClasses; i++)
137 discriminantScores[i] =
138 (this->GetMembershipFunction(i))->Evaluate(measurements);
140 classLabel = rule->Evaluate(discriminantScores);
142 iter.GetInstanceIdentifier());
147 iter.GetInstanceIdentifier());