17 #ifndef __itkWeightSetBase_txx
18 #define __itkWeightSetBase_txx
// NOTE(review): this listing is corrupted. Original line numbers are fused into
// each line, and function signatures/braces are missing. Comments below describe
// only what the visible fragments do.
// Presumably the constructor body: zeroes both node counts, creates the random
// variate generator and seeds it.
template<
class TMeasurementVector,
class TTargetVector>
37 m_NumberOfInputNodes = 0;
38 m_NumberOfOutputNodes = 0;
39 m_RandomGenerator = RandomVariateGeneratorType::New();
// NOTE(review): randomSeed is defined on a line not visible in this listing;
// confirm where it comes from (a fixed seed makes weight init deterministic).
41 m_RandomGenerator->Initialize( randomSeed );
// NOTE(review): the next two template headers have no visible declaration or
// body in this listing.
template<
class TMeasurementVector,
class TTargetVector>
template<
class TMeasurementVector,
class TTargetVector>
// Input-node count setter fragment: stores n + 1, one node more than requested.
// Presumably the extra node is a bias input; TODO confirm.
template<
class TMeasurementVector,
class TTargetVector>
67 m_NumberOfInputNodes = n + 1;
// Input-node count getter fragment: returns m_NumberOfInputNodes (which the setter
// above stores as n + 1).
template<
class TMeasurementVector,
class TTargetVector>
76 return m_NumberOfInputNodes;
// Output-node count setter fragment: stores n unchanged (no extra node).
template<
class TMeasurementVector,
class TTargetVector>
84 m_NumberOfOutputNodes = n;
// Output-node count getter fragment: returns m_NumberOfOutputNodes.
template<
class TMeasurementVector,
class TTargetVector>
93 return m_NumberOfOutputNodes;
// Allocation fragment: sizes every weight/delta buffer from the current node
// counts. Matrices are m_NumberOfOutputNodes x m_NumberOfInputNodes; the
// m_DB* / m_Delb* vectors have m_NumberOfOutputNodes entries (presumably one
// bias term per output node; TODO confirm).
template<
class TMeasurementVector,
class TTargetVector>
101 m_OutputValues.set_size(m_NumberOfOutputNodes, m_NumberOfInputNodes);
102 m_OutputValues.fill(0);
103 m_WeightMatrix.set_size(m_NumberOfOutputNodes, m_NumberOfInputNodes);
104 m_WeightMatrix.fill(0);
// NOTE(review): only m_OutputValues, m_WeightMatrix and m_InputLayerOutput are
// visibly zero-filled. The buffers below are only sized, and any fill(0) calls
// would be on lines missing from this listing. Per vnl docs, set_size leaves the
// contents unspecified; verify each buffer is initialized before it is read.
106 m_DW.set_size(m_NumberOfOutputNodes, m_NumberOfInputNodes);
108 m_DW_new.set_size(m_NumberOfOutputNodes, m_NumberOfInputNodes);
111 m_DW_m_1.set_size(m_NumberOfOutputNodes, m_NumberOfInputNodes);
113 m_DW_m_2.set_size(m_NumberOfOutputNodes, m_NumberOfInputNodes);
116 m_DB_new.set_size(m_NumberOfOutputNodes);
118 m_DB.set_size(m_NumberOfOutputNodes);
121 m_DB_m_1.set_size(m_NumberOfOutputNodes);
123 m_DB_m_2.set_size(m_NumberOfOutputNodes);
126 m_Del.set_size(m_NumberOfOutputNodes, m_NumberOfInputNodes);
128 m_Del_new.set_size(m_NumberOfOutputNodes, m_NumberOfInputNodes);
130 m_Del_m_1.set_size(m_NumberOfOutputNodes, m_NumberOfInputNodes);
132 m_Del_m_2.set_size(m_NumberOfOutputNodes, m_NumberOfInputNodes);
135 m_Delb.set_size(m_NumberOfOutputNodes);
137 m_Delb_new.set_size(m_NumberOfOutputNodes);
139 m_Delb_m_1.set_size(m_NumberOfOutputNodes);
141 m_Delb_m_2.set_size(m_NumberOfOutputNodes);
// Input-layer output is one row of m_NumberOfInputNodes - 1 values, so the extra
// input node counted in m_NumberOfInputNodes is excluded.
144 m_InputLayerOutput.set_size(1, m_NumberOfInputNodes - 1);
145 m_InputLayerOutput.fill(0);
// Weight-initialization fragment: for each (i, j) of m_WeightMatrix, sets a random
// value from RandomWeightValue(-m_Range, m_Range) where m_ConnectivityMatrix[i][j]
// equals 1, and 0 otherwise.
template<
class TMeasurementVector,
class TTargetVector>
153 unsigned int num_rows = m_WeightMatrix.rows();
154 unsigned int num_cols = m_WeightMatrix.cols();
// NOTE(review): leftover debug output. It writes matrix sizes to std::cout
// unconditionally, and the message misspells "connectivity". Consider removing it
// or routing it through itkDebugMacro.
155 std::cout<<num_rows <<
" "<<num_cols<<std::endl;
156 std::cout<<
"conectivity matrix size = "<<m_ConnectivityMatrix.rows()<<
" "
157 << m_ConnectivityMatrix.cols()<<std::endl;
// NOTE(review): m_ConnectivityMatrix is indexed with m_WeightMatrix's bounds;
// nothing visible here checks that the two matrices have matching sizes.
159 for (
unsigned int i = 0; i < num_rows; i++)
161 for (
unsigned int j = 0; j < num_cols; j++)
163 if(m_ConnectivityMatrix[i][j]==1)
165 m_WeightMatrix(i, j) = RandomWeightValue(-1*m_Range,m_Range);
169 m_WeightMatrix(i, j) = 0;
// Returns a uniform random variate between low and high from m_RandomGenerator,
// converted to ValueType.
template<
class TMeasurementVector,
class TTargetVector>
180 return static_cast<ValueType>(m_RandomGenerator->GetUniformVariate(low,high));
// Fragment: copies m_NumberOfInputNodes - 1 values from inputlayeroutputvalues
// into row 0 of m_InputLayerOutput.
// NOTE(review): copy_in reads exactly layeroutput.size() elements from the
// pointer; the caller must supply at least that many (not checked here).
template<
class TMeasurementVector,
class TTargetVector>
190 layeroutput.set_size(m_NumberOfInputNodes - 1);
191 layeroutput.copy_in(inputlayeroutputvalues);
192 m_InputLayerOutput.set_row(0, layeroutput);
// NOTE(review): template header with no visible body in this listing.
template<
class TMeasurementVector,
class TTargetVector>
// Fragment: adds m_Delb into m_Delb_new element-wise.
template<
class TMeasurementVector,
class TTargetVector>
208 m_Delb_new += m_Delb;
// Fragment: writes vector v into the last column (index m_NumberOfInputNodes-1)
// of both m_Del and m_Del_new.
template<
class TMeasurementVector,
class TTargetVector>
220 m_Del.set_column( m_NumberOfInputNodes-1,v);
221 m_Del_new.set_column( m_NumberOfInputNodes-1,v);
// Fragment: builds an m_NumberOfOutputNodes x m_NumberOfInputNodes temporary and
// assigns it to m_WeightMatrix. The lines that fill W_temp are not visible here.
template<
class TMeasurementVector,
class TTargetVector>
231 W_temp.set_size(m_NumberOfOutputNodes, m_NumberOfInputNodes);
235 m_WeightMatrix = W_temp;
// Fragment: sizes DW_temp like the weight matrix and writes vector v into the
// last column of m_DW.
template<
class TMeasurementVector,
class TTargetVector>
247 DW_temp.set_size(m_NumberOfOutputNodes, m_NumberOfInputNodes);
253 m_DW.set_column( m_NumberOfInputNodes-1,v);
// Fragment: sizes db_temp to m_NumberOfOutputNodes entries.
template<
class TMeasurementVector,
class TTargetVector>
263 db_temp.set_size(m_NumberOfOutputNodes);
// Raw-pointer accessors: each returns data_block() of one member buffer, exposing
// its internal storage directly.
// NOTE(review): the returned pointers presumably become invalid if the buffer is
// later resized (e.g. by set_size); callers should not hold them across
// re-initialization.
template<
class TMeasurementVector,
class TTargetVector>
276 return m_Del_new.data_block();
template<
class TMeasurementVector,
class TTargetVector>
284 return m_Delb_new.data_block();
template<
class TMeasurementVector,
class TTargetVector>
292 return m_Del.data_block();
template<
class TMeasurementVector,
class TTargetVector>
300 return m_Del_m_1.data_block();
template<
class TMeasurementVector,
class TTargetVector>
308 return m_Del_m_2.data_block();
template<
class TMeasurementVector,
class TTargetVector>
316 return m_Delb_m_1.data_block();
template<
class TMeasurementVector,
class TTargetVector>
324 return m_DW.data_block();
template<
class TMeasurementVector,
class TTargetVector>
332 return m_DB_m_1.data_block();
template<
class TMeasurementVector,
class TTargetVector>
340 return m_DW_m_1.data_block();
template<
class TMeasurementVector,
class TTargetVector>
348 return m_DW_m_2.data_block();
template<
class TMeasurementVector,
class TTargetVector>
356 return m_InputLayerOutput.data_block();
template<
class TMeasurementVector,
class TTargetVector>
364 return m_OutputValues.data_block();
template<
class TMeasurementVector,
class TTargetVector>
372 return m_Delb.data_block();
// NOTE(review): this accessor and the next both return m_WeightMatrix's storage;
// confirm the duplication is intended.
template<
class TMeasurementVector,
class TTargetVector>
380 return m_WeightMatrix.data_block();
template<
class TMeasurementVector,
class TTargetVector>
388 return m_WeightMatrix.data_block();
// Setter fragment: assigns c to m_ConnectivityMatrix (the mask consulted when
// initializing weights).
template<
class TMeasurementVector,
class TTargetVector>
396 m_ConnectivityMatrix = c;
// Weight-update fragment.
// Shifts the delta history one step back for both m_Del* and m_Delb*:
// *_m_2 <- *_m_1, then *_m_1 <- *_new.
template<
class TMeasurementVector,
class TTargetVector>
405 m_Del_m_2 = m_Del_m_1;
406 m_Del_m_1 = m_Del_new;
408 m_Delb_m_2 = m_Delb_m_1;
409 m_Delb_m_1 = m_Delb_new;
// Puts m_DB in the last column of m_DW and adds m_DW to m_WeightMatrix.
414 m_DW.set_column(m_NumberOfInputNodes - 1, m_DB);
415 m_WeightMatrix += m_DW;
// NOTE(review): this overwrite happens after the += above, so it does not affect
// this update; confirm it is meant to prepare m_DW for a later step.
416 m_DW.set_column(m_NumberOfInputNodes - 1, m_Delb_new);
// Pass-state tracking. The second branch clears m_SecondPass; the branch bodies
// are only partially visible in this listing.
// NOTE(review): comparing bools against true/false is redundant; `if (m_FirstPass)`
// and `else if (m_SecondPass)` would read the same.
424 if(m_FirstPass ==
true)
428 else if(m_FirstPass ==
false && m_SecondPass==
true)
430 m_SecondPass =
false;
// PrintSelf fragment: prints superclass state, then this object's address and
// each member's value, prefixed by indent.
// NOTE(review): each statement's terminator is on a line not visible in this
// listing. m_DW_m and m_InputErrorValues are printed here but are not referenced
// anywhere else in the visible code; confirm both are declared members.
template<
class TMeasurementVector,
class TTargetVector>
441 Superclass::PrintSelf( os, indent );
443 os << indent <<
"WeightSetBase(" <<
this <<
")"
446 os << indent <<
"m_RandomGenerator = " << m_RandomGenerator
448 os << indent <<
"m_NumberOfInputNodes = " << m_NumberOfInputNodes
450 os << indent <<
"m_NumberOfOutputNodes = " << m_NumberOfOutputNodes
452 os << indent <<
"m_OutputValues = " << m_OutputValues
454 os << indent <<
"m_InputErrorValues = " << m_InputErrorValues
457 os << indent <<
"m_DW = " << m_DW
459 os << indent <<
"m_DW_new = " << m_DW_new
461 os << indent <<
"m_DW_m_1 = " << m_DW_m_1
463 os << indent <<
"m_DW_m_2 = " << m_DW_m_2
465 os << indent <<
"m_DW_m = " << m_DW_m
468 os << indent <<
"m_DB = " << m_DB
470 os << indent <<
"m_DB_new = " << m_DB_new
472 os << indent <<
"m_DB_m_1 = " << m_DB_m_1
474 os << indent <<
"m_DB_m_2 = " << m_DB_m_2
477 os << indent <<
"m_Del = " << m_Del
479 os << indent <<
"m_Del_new = " << m_Del_new
481 os << indent <<
"m_Del_m_1 = " << m_Del_m_1
483 os << indent <<
"m_Del_m_2 = " << m_Del_m_2
486 os << indent <<
"m_Delb = " << m_Delb
488 os << indent <<
"m_Delb_new = " << m_Delb_new
490 os << indent <<
"m_Delb_m_1 = " << m_Delb_m_1
492 os << indent <<
"m_Delb_m_2 = " << m_Delb_m_2
495 os << indent <<
"m_InputLayerOutput = " << m_InputLayerOutput
497 os << indent <<
"m_WeightMatrix = " << m_WeightMatrix
499 os << indent <<
"m_ConnectivityMatrix = " << m_ConnectivityMatrix
502 os << indent <<
"m_Momentum = " << m_Momentum
504 os << indent <<
"m_Bias = " << m_Bias
506 os << indent <<
"m_FirstPass = " << m_FirstPass
508 os << indent <<
"m_SecondPass = " << m_SecondPass
510 os << indent <<
"m_Range = " << m_Range