Orfeo Toolbox  3.16
itkErrorBackPropagationLearningWithMomentum.txx
Go to the documentation of this file.
1 /*=========================================================================
2 
3  Program: Insight Segmentation & Registration Toolkit
4  Module: $RCSfile: itkErrorBackPropagationLearningWithMomentum.txx,v $
5  Language: C++
6  Date: $Date: 2009-01-24 21:33:49 $
7  Version: $Revision: 1.8 $
8 
9  Copyright (c) Insight Software Consortium. All rights reserved.
10  See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details.
11 
12  This software is distributed WITHOUT ANY WARRANTY; without even
13  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
14  PURPOSE. See the above copyright notices for more information.
15 
16 =========================================================================*/
17 
18 #ifndef __itkErrorBackPropagationLearningWithMomentum_txx
19 #define __itkErrorBackPropagationLearningWithMomentum_txx
20 
22 #include <fstream>
23 
24 
25 namespace itk
26 {
27 namespace Statistics
28 {
29 
30 template<class LayerType, class TTargetVector>
33 {
34  m_Momentum = 0.9; //Default
35 }
36 
37 template<class LayerType, class TTargetVector>
38 void
41 {
42  typedef typename LayerInterfaceType::WeightSetType::Pointer WeightSetPointer;
43  WeightSetPointer outputweightset;
44  WeightSetPointer inputweightset;
45  outputweightset = layer->GetOutputWeightSet();
46  inputweightset = layer->GetInputWeightSet();
47 
48  typedef typename LayerInterfaceType::ValuePointer InterfaceValuePointer;
49  InterfaceValuePointer DWvalues_m_1 = inputweightset->GetPrevDWValues();
50  InterfaceValuePointer DWvalues_m_2 = inputweightset->GetPrev_m_2DWValues();
51  InterfaceValuePointer currentdeltavalues = inputweightset->GetTotalDeltaValues();
52  InterfaceValuePointer DBValues = inputweightset->GetTotalDeltaBValues();
53  InterfaceValuePointer PrevDBValues = inputweightset->GetPrevDBValues();
54 
55  int input_cols = inputweightset->GetNumberOfInputNodes();
56  int input_rows = inputweightset->GetNumberOfOutputNodes();
57 
58  vnl_matrix<ValueType> DW_m_1(input_rows, input_cols);
59  DW_m_1.fill(0);
60  vnl_matrix<ValueType> DW_m_2(input_rows, input_cols);
61  DW_m_2.fill(0);
62 
63  vnl_vector<ValueType> DB_temp;
64  DB_temp.set_size(inputweightset->GetNumberOfOutputNodes());
65  DB_temp.fill(0);
67  vnl_vector<ValueType> DB_m_1;
68  DB.set_size(inputweightset->GetNumberOfOutputNodes());
69  DB_m_1.set_size(inputweightset->GetNumberOfOutputNodes());
70  DB.fill(0);
71  DB_m_1.fill(0);
72  DB.copy_in(DBValues);
73  DB_m_1.copy_in(PrevDBValues);
74 
75  if (!inputweightset->GetFirstPass())
76  {
77  DW_m_1.copy_in(DWvalues_m_1);
78  }
79  if (!inputweightset->GetSecondPass())
80  {
81  DW_m_2.copy_in(DWvalues_m_2);
82  }
83  vnl_matrix<ValueType> DW_temp(currentdeltavalues,
84  inputweightset->GetNumberOfOutputNodes(),
85  inputweightset->GetNumberOfInputNodes());
86 
87  vnl_matrix<ValueType> DW_temp1(inputweightset->GetNumberOfOutputNodes(),
88  inputweightset->GetNumberOfInputNodes());
89  DW_temp1.fill(0);
90 
91  //Momentum
92  if (!inputweightset->GetFirstPass())
93  {
94  DW_temp1 = (DW_temp * lr *(1 - m_Momentum)) + (DW_m_1 * m_Momentum);
95  }
96  else
97  {
98  DW_temp1 = DW_temp*lr;
99  }
100  DB_temp=(DB*lr);
101  inputweightset->SetDWValues(DW_temp1.data_block());
102  inputweightset->SetDBValues(DB_temp.data_block());
103 }
104 
105 template<class LayerType, class TTargetVector>
106 void
108 ::Learn( LayerInterfaceType * itkNotUsed(layer), TTargetVector itkNotUsed(errors),ValueType itkNotUsed(lr))
109 {
110  //It appears that this interface should not be called.
111  //itkExceptionMacrto(<< "This should never be called");
112 }
113 
115 template<class LayerType, class TTargetVector>
116 void
118 ::PrintSelf( std::ostream& os, Indent indent ) const
119 {
120  os << indent << "ErrorBackPropagationLearningWithMomentum(" << this << ")" << std::endl;
121  os << indent << "m_Momentum = " << m_Momentum << std::endl;
122  Superclass::PrintSelf( os, indent );
123 }
124 
125 } // end namespace Statistics
126 } // end namespace itk
127 
128 #endif

Generated at Sat Feb 2 2013 23:35:57 for Orfeo Toolbox with doxygen 1.8.1.1