18 #ifndef __itkErrorBackPropagationLearningWithMomentum_txx
19 #define __itkErrorBackPropagationLearningWithMomentum_txx
// NOTE(review): only this template header is visible for this member
// definition; the declaration and body it introduces are not present in this
// excerpt, so its purpose cannot be documented from this view.
30 template<
class LayerType,
class TTargetVector>
// Momentum-based update step for the input weight set of `layer`.
//
// Visible behaviour in this excerpt:
//  * reads the current and historical DW / DB buffers from the layer's input
//    weight set,
//  * after the first pass, blends the current DW (scaled by `lr`) with the
//    previous DW using m_Momentum,
//  * writes the resulting DW and DB buffers back to the input weight set.
//
// NOTE(review): the function signature and the declarations of `layer`, `lr`,
// DW_temp, DW_temp1, DW_m_1, DW_m_2, DB, DB_temp and DB_m_1 are outside the
// visible lines; `lr` is presumably the learning rate -- confirm against the
// full source.
37 template<
class LayerType,
class TTargetVector>
// Fetch both weight sets attached to the layer.
// NOTE(review): outputweightset is not referenced again in the visible lines.
42 typedef typename LayerInterfaceType::WeightSetType::Pointer WeightSetPointer;
43 WeightSetPointer outputweightset;
44 WeightSetPointer inputweightset;
45 outputweightset = layer->GetOutputWeightSet();
46 inputweightset = layer->GetInputWeightSet();
// Raw value buffers held by the input weight set: previous DW, DW from two
// steps back, the accumulated (total) DW, the accumulated DB and the
// previous DB.
48 typedef typename LayerInterfaceType::ValuePointer InterfaceValuePointer;
49 InterfaceValuePointer DWvalues_m_1 = inputweightset->GetPrevDWValues();
50 InterfaceValuePointer DWvalues_m_2 = inputweightset->GetPrev_m_2DWValues();
51 InterfaceValuePointer currentdeltavalues = inputweightset->GetTotalDeltaValues();
52 InterfaceValuePointer DBValues = inputweightset->GetTotalDeltaBValues();
53 InterfaceValuePointer PrevDBValues = inputweightset->GetPrevDBValues();
// Weight-matrix dimensions: columns = input nodes, rows = output nodes.
// NOTE(review): input_cols / input_rows are not used in the visible lines.
55 int input_cols = inputweightset->GetNumberOfInputNodes();
56 int input_rows = inputweightset->GetNumberOfOutputNodes();
// The DB vectors (presumably bias deltas) hold one entry per output node.
64 DB_temp.set_size(inputweightset->GetNumberOfOutputNodes());
68 DB.set_size(inputweightset->GetNumberOfOutputNodes());
69 DB_m_1.set_size(inputweightset->GetNumberOfOutputNodes());
// The previous DB values are loaded unconditionally.
73 DB_m_1.copy_in(PrevDBValues);
// Historical DW values are loaded only once the weight set reports it is past
// its first (respectively second) pass -- presumably because no history
// exists before then.
75 if (!inputweightset->GetFirstPass())
77 DW_m_1.copy_in(DWvalues_m_1);
79 if (!inputweightset->GetSecondPass())
81 DW_m_2.copy_in(DWvalues_m_2);
// Fragments of size arguments (rows = output nodes, cols = input nodes); the
// calls they belong to are not visible in this excerpt.
84 inputweightset->GetNumberOfOutputNodes(),
85 inputweightset->GetNumberOfInputNodes());
88 inputweightset->GetNumberOfInputNodes());
// After the first pass, apply momentum:
//   DW_temp1 = lr * (1 - m_Momentum) * DW_temp + m_Momentum * DW_m_1
// which blends the current scaled delta with the previous update.
92 if (!inputweightset->GetFirstPass())
94 DW_temp1 = (DW_temp * lr *(1 - m_Momentum)) + (DW_m_1 * m_Momentum);
// First-pass branch (its `else` keyword is not visible here): there is no
// previous update to blend with, so the delta is only scaled by lr.
98 DW_temp1 = DW_temp*lr;
// Write the new DW and DB buffers back into the input weight set.
// NOTE(review): in the visible lines DB_temp receives no momentum term even
// though DB and DB_m_1 are prepared above -- confirm whether the bias update
// is computed in lines missing from this excerpt.
101 inputweightset->SetDWValues(DW_temp1.data_block());
102 inputweightset->SetDBValues(DB_temp.data_block());
// NOTE(review): only this template header is visible for this member
// definition; the declaration and body it introduces are not present in this
// excerpt, so its purpose cannot be documented from this view.
105 template<
class LayerType,
class TTargetVector>
// Writes this object's diagnostic state to `os`, one line at a time:
//  1. the literal "ErrorBackPropagationLearningWithMomentum(" followed by the
//     object address and ")",
//  2. the current value of m_Momentum,
//  3. the superclass's state, via Superclass::PrintSelf.
// NOTE(review): the signature is not visible in this excerpt; `os` and
// `indent` are presumably std::ostream& and itk::Indent -- confirm.
// NOTE(review): ITK code typically calls Superclass::PrintSelf first; here it
// is called last, so the superclass output follows this class's lines.
115 template<
class LayerType,
class TTargetVector>
120 os << indent <<
"ErrorBackPropagationLearningWithMomentum(" <<
this <<
")" << std::endl;
121 os << indent <<
"m_Momentum = " << m_Momentum << std::endl;
122 Superclass::PrintSelf( os, indent );