TrainMachineLearningModelFromSamplesExample.cxx

Example usage:

./TrainMachineLearningModelFromSamplesExample Output/clSVMModelFromSamples.svm

Example source code (TrainMachineLearningModelFromSamplesExample.cxx):

// This example illustrates the use of the \doxygen{otb}{SVMMachineLearningModel} class, which inherits from the
// \doxygen{otb}{MachineLearningModel} class. This class allows the
// estimation of a classification model (supervised learning) from samples. In this example, we will train an SVM model
// with 4 output classes, from 1000 randomly generated training samples, each of them having 7 components.
// We start by including the appropriate header files.
// Standard library
#include <cstdlib>
#include <iostream>

// List sample generator
#include "otbListSampleGenerator.h"

// Random number generator
#include "itkMersenneTwisterRandomVariateGenerator.h"

// SVM model Estimator
#include "otbSVMMachineLearningModel.h"


int main(int argc, char* argv[])
{
  // Expects exactly one argument: the path of the file the trained model is saved to.
  if (argc != 2)
  {
    std::cerr << "Usage: " << argv[0] << " outputModelFileName" << std::endl;
    return EXIT_FAILURE;
  }

  // The input parameters of the sample generator and of the SVM classifier are initialized.
  const int nbSamples          = 1000;
  const int nbSampleComponents = 7;
  const int nbClasses          = 4;

  const char* outputModelFileName = argv[1];

  // Two lists are generated, each stored in a \subdoxygen{itk}{Statistics}{ListSample}, which is the
  // structure used to hold both the samples and the labels for the machine learning classes derived
  // from \doxygen{otb}{MachineLearningModel}. The first list holds feature vectors representing
  // multi-component samples, and the second one holds the corresponding class labels. Each label
  // is a scalar value.

  // Input related typedefs
  using InputValueType      = float;
  using InputSampleType     = itk::VariableLengthVector<InputValueType>;
  using InputListSampleType = itk::Statistics::ListSample<InputSampleType>;

  // Target related typedefs
  using TargetValueType      = int;
  using TargetSampleType     = itk::FixedArray<TargetValueType, 1>;
  using TargetListSampleType = itk::Statistics::ListSample<TargetSampleType>;

  InputListSampleType::Pointer  InputListSample  = InputListSampleType::New();
  TargetListSampleType::Pointer TargetListSample = TargetListSampleType::New();

  InputListSample->SetMeasurementVectorSize(nbSampleComponents);

  // In this example, the list of multi-component training samples is filled using a random number
  // generator based on the \subdoxygen{itk}{Statistics}{MersenneTwisterRandomVariateGenerator} class.
  // Each component's value is drawn from a normal distribution whose mean is 100 times the sample's
  // class label and whose standard deviation is 10.

  // Standard deviation of each generated component. ITK's GetNormalVariate() takes a *variance*
  // (not a standard deviation) as its second argument, so the square of this value is passed below.
  const double componentStdDev = 10.0;

  itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer randGen =
      itk::Statistics::MersenneTwisterRandomVariateGenerator::GetInstance();

  // Filling the two input training lists. Labels cycle through 1..nbClasses.
  for (int i = 0; i < nbSamples; ++i)
  {
    InputSampleType       sample;
    const TargetValueType label = (i % nbClasses) + 1;

    // Multi-component sample randomly filled from a normal distribution for each component
    sample.SetSize(nbSampleComponents);
    for (int itComp = 0; itComp < nbSampleComponents; ++itComp)
    {
      sample[itComp] = randGen->GetNormalVariate(100.0 * label, componentStdDev * componentStdDev);
    }

    InputListSample->PushBack(sample);
    TargetListSample->PushBack(label);
  }

  // Displays the corresponding values
  for (int i = 0; i < nbSamples; ++i)
  {
    std::cout << i + 1 << "-label = " << TargetListSample->GetMeasurementVector(i) << '\n';
    std::cout << "sample = " << InputListSample->GetMeasurementVector(i) << "\n\n";
  }

  // Once both the sample and label lists are generated, the second step consists in
  // declaring the machine learning classifier. In our case we use an SVM model
  // through the \doxygen{otb}{SVMMachineLearningModel} class, which is
  // derived from the \doxygen{otb}{MachineLearningModel} class.
  // This abstract class is based on the machine learning framework of the
  // OpenCV library (\cite{opencv_library}), which also provides classifiers
  // other than SVM.

  using SVMType = otb::SVMMachineLearningModel<InputValueType, TargetValueType>;

  SVMType::Pointer SVMClassifier = SVMType::New();

  SVMClassifier->SetInputListSample(InputListSample);
  SVMClassifier->SetTargetListSample(TargetListSample);

  SVMClassifier->SetKernelType(CvSVM::LINEAR);

  // Once the classifier has been given both input lists, and keeps its default parameters
  // except for the kernel type set above, the model is trained with the \code{Train} method.
  // Finally, the \code{Save} method exports the model to a file. All the available
  // classifiers based on OpenCV are implemented with these interfaces. As with the SVM model,
  // the other classifiers can be configured with their own specific settings.

  SVMClassifier->Train();
  SVMClassifier->Save(outputModelFileName);

  return EXIT_SUCCESS;
}