OTB  6.7.0
Orfeo Toolbox
otbTrainKNN.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbTrainKNN_hxx
22 #define otbTrainKNN_hxx
25 
26 namespace otb
27 {
28 namespace Wrapper
29 {
30 
31  template <class TInputValue, class TOutputValue>
32  void
33  LearningApplicationBase<TInputValue,TOutputValue>
34  ::InitKNNParams()
35  {
36  AddChoice("classifier.knn", "KNN classifier");
37  SetParameterDescription("classifier.knn", "http://docs.opencv.org/modules/ml/doc/k_nearest_neighbors.html");
38 
39  //K parameter
40  AddParameter(ParameterType_Int, "classifier.knn.k", "Number of Neighbors");
41  SetParameterInt("classifier.knn.k",32);
42  SetParameterDescription("classifier.knn.k","The number of neighbors to use.");
43 
44  if (this->m_RegressionFlag)
45  {
46  // Decision rule : mean / median
47  AddParameter(ParameterType_Choice, "classifier.knn.rule", "Decision rule");
48  SetParameterDescription("classifier.knn.rule", "Decision rule for regression output");
49 
50  AddChoice("classifier.knn.rule.mean", "Mean of neighbors values");
51  SetParameterDescription("classifier.knn.rule.mean","Returns the mean of neighbors values");
52 
53  AddChoice("classifier.knn.rule.median", "Median of neighbors values");
54  SetParameterDescription("classifier.knn.rule.median","Returns the median of neighbors values");
55  }
56  }
57 
58  template <class TInputValue, class TOutputValue>
59  void
60  LearningApplicationBase<TInputValue,TOutputValue>
61  ::TrainKNN(typename ListSampleType::Pointer trainingListSample,
62  typename TargetListSampleType::Pointer trainingLabeledListSample,
63  std::string modelPath)
64  {
66  typename KNNType::Pointer knnClassifier = KNNType::New();
67  knnClassifier->SetRegressionMode(this->m_RegressionFlag);
68  knnClassifier->SetInputListSample(trainingListSample);
69  knnClassifier->SetTargetListSample(trainingLabeledListSample);
70  knnClassifier->SetK(GetParameterInt("classifier.knn.k"));
71  if (this->m_RegressionFlag)
72  {
73  std::string decision = this->GetParameterString("classifier.knn.rule");
74  if (decision == "mean")
75  {
76  knnClassifier->SetDecisionRule(KNNType::KNN_MEAN);
77  }
78  else if (decision == "median")
79  {
80  knnClassifier->SetDecisionRule(KNNType::KNN_MEDIAN);
81  }
82  }
83 
84  knnClassifier->Train();
85  knnClassifier->Save(modelPath);
86  }
87 
88 } //end namespace wrapper
89 } //end namespace otb
90 
91 #endif