OTB  9.0.0
Orfeo Toolbox
otbTrainKNN.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbTrainKNN_hxx
22 #define otbTrainKNN_hxx
#include "otbLearningApplicationBase.h"
#include "otbKNearestNeighborsMachineLearningModel.h"

26 namespace otb
27 {
28 namespace Wrapper
29 {
30 
31 template <class TInputValue, class TOutputValue>
32 void LearningApplicationBase<TInputValue, TOutputValue>::InitKNNParams()
33 {
34  AddChoice("classifier.knn", "KNN classifier");
35  SetParameterDescription("classifier.knn", "http://docs.opencv.org/modules/ml/doc/k_nearest_neighbors.html");
36 
37  // K parameter
38  AddParameter(ParameterType_Int, "classifier.knn.k", "Number of Neighbors");
39  SetParameterInt("classifier.knn.k", 32);
40  SetParameterDescription("classifier.knn.k", "The number of neighbors to use.");
41 
42  if (this->m_RegressionFlag)
43  {
44  // Decision rule : mean / median
45  AddParameter(ParameterType_Choice, "classifier.knn.rule", "Decision rule");
46  SetParameterDescription("classifier.knn.rule", "Decision rule for regression output");
47 
48  AddChoice("classifier.knn.rule.mean", "Mean of neighbors values");
49  SetParameterDescription("classifier.knn.rule.mean", "Returns the mean of neighbors values");
50 
51  AddChoice("classifier.knn.rule.median", "Median of neighbors values");
52  SetParameterDescription("classifier.knn.rule.median", "Returns the median of neighbors values");
53  }
54 }
55 
56 template <class TInputValue, class TOutputValue>
57 void LearningApplicationBase<TInputValue, TOutputValue>::TrainKNN(typename ListSampleType::Pointer trainingListSample,
58  typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
59 {
61  typename KNNType::Pointer knnClassifier = KNNType::New();
62  knnClassifier->SetRegressionMode(this->m_RegressionFlag);
63  knnClassifier->SetInputListSample(trainingListSample);
64  knnClassifier->SetTargetListSample(trainingLabeledListSample);
65  knnClassifier->SetK(GetParameterInt("classifier.knn.k"));
66  if (this->m_RegressionFlag)
67  {
68  std::string decision = this->GetParameterString("classifier.knn.rule");
69  if (decision == "mean")
70  {
71  knnClassifier->SetDecisionRule(KNNType::KNN_MEAN);
72  }
73  else if (decision == "median")
74  {
75  knnClassifier->SetDecisionRule(KNNType::KNN_MEDIAN);
76  }
77  }
78 
79  knnClassifier->Train();
80  knnClassifier->Save(modelPath);
81 }
82 
83 } // end namespace wrapper
84 } // end namespace otb
85 
86 #endif
otb::Wrapper::ParameterType_Choice
@ ParameterType_Choice
Definition: otbWrapperTypes.h:47
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otbLearningApplicationBase.h
otb::Wrapper::ParameterType_Int
@ ParameterType_Int
Definition: otbWrapperTypes.h:38
otb::KNearestNeighborsMachineLearningModel
Definition: otbKNearestNeighborsMachineLearningModel.h:35
otbKNearestNeighborsMachineLearningModel.h