TrainMachineLearningModelFromImagesExample.cxxΒΆ

Example usage:

./TrainMachineLearningModelFromImagesExample Input/QB_1_ortho.tif Input/VectorData_QB1.shp Output/clLIBSVMModelQB1.libsvm

Example source code (TrainMachineLearningModelFromImagesExample.cxx):

// This example illustrates the use of the
// \doxygen{otb}{MachineLearningModel} class. This class allows the
// estimation of a classification model (supervised learning) from images. In this example, we will train an SVM
// with 4 classes. We start by including the appropriate header files.
// List sample generator
#include "otbListSampleGenerator.h"

// Extract a ROI of the vectordata
#include "otbVectorDataIntoImageProjectionFilter.h"

// SVM model Estimator
#include "otbSVMMachineLearningModel.h"


// Image
#include "otbVectorImage.h"
#include "otbVectorData.h"

// Reader
#include "otbImageFileReader.h"
#include "otbVectorDataFileReader.h"

// Normalize the samples
//#include "otbShiftScaleSampleListFilter.h"


int main(int itkNotUsed(argc), char* argv[])
{
  const char* inputImageFileName  = argv[1];
  const char* trainingShpFileName = argv[2];
  const char* outputModelFileName = argv[3];

  using InputPixelType         = unsigned int;
  const unsigned int Dimension = 2;
  using InputImageType         = otb::VectorImage<InputPixelType, Dimension>;
  using VectorDataType         = otb::VectorData<double, 2>;
  using InputReaderType        = otb::ImageFileReader<InputImageType>;
  using VectorDataReaderType   = otb::VectorDataFileReader<VectorDataType>;

  // In this framework, we must transform the input samples store in a vector
  // data into a \subdoxygen{itk}{Statistics}{ListSample} which is the structure
  // compatible with the machine learning classes. On the one hand, we are using feature vectors
  // for the characterization of the classes, and on the other hand, the class labels
  // are scalar values. We first re-project the input vector data over the input image, using the
  // \doxygen{otb}{VectorDataIntoImageProjectionFilter} class. To convert the
  // input samples store in a vector data into a
  // \subdoxygen{itk}{Statistics}{ListSample}, we use the
  // \doxygen{otb}{ListSampleGenerator} class.

  // VectorData projection filter
  using VectorDataReprojectionType = otb::VectorDataIntoImageProjectionFilter<VectorDataType, InputImageType>;

  InputReaderType::Pointer inputReader = InputReaderType::New();
  inputReader->SetFileName(inputImageFileName);

  InputImageType::Pointer image = inputReader->GetOutput();
  image->UpdateOutputInformation();

  // Read the Vectordata
  VectorDataReaderType::Pointer vectorReader = VectorDataReaderType::New();
  vectorReader->SetFileName(trainingShpFileName);
  vectorReader->Update();

  VectorDataType::Pointer vectorData = vectorReader->GetOutput();
  vectorData->Update();

  VectorDataReprojectionType::Pointer vdreproj = VectorDataReprojectionType::New();

  vdreproj->SetInputImage(image);
  vdreproj->SetInput(vectorData);
  vdreproj->SetUseOutputSpacingAndOriginFromImage(false);

  vdreproj->Update();

  using ListSampleGeneratorType = otb::ListSampleGenerator<InputImageType, VectorDataType>;
  ListSampleGeneratorType::Pointer sampleGenerator;
  sampleGenerator = ListSampleGeneratorType::New();

  sampleGenerator->SetInput(image);
  sampleGenerator->SetInputVectorData(vdreproj->GetOutput());
  sampleGenerator->SetClassKey("Class");

  sampleGenerator->Update();


  // std::cout << "Number of classes: " << sampleGenerator->GetNumberOfClasses() << std::endl;

  // using ListSampleType = ListSampleGeneratorType::ListSampleType;
  // using ShiftScaleFilterType = otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType>;

  // // Shift scale the samples
  // ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
  // trainingShiftScaleFilter->SetInput(concatenateTrainingSamples->GetOutput());
  // trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
  // trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
  // trainingShiftScaleFilter->Update();


  // Now, we need to declare the machine learning model which will be used by the
  // classifier. In this example, we train an SVM model. The
  // \doxygen{otb}{SVMMachineLearningModel} class inherits from the pure virtual
  // class \doxygen{otb}{MachineLearningModel} which is templated over the type of
  // values used for the measures and the type of pixels used for the labels. Most
  // of the classification and regression algorithms available through this
  // interface in OTB is based on the OpenCV library \cite{opencv_library}. Specific methods
  // can be used to set classifier parameters. In the case of SVM, we set here the type
  // of the kernel. Other parameters are let with their default values.

  using SVMType = otb::SVMMachineLearningModel<InputImageType::InternalPixelType, ListSampleGeneratorType::ClassLabelType>;

  SVMType::Pointer SVMClassifier = SVMType::New();

  SVMClassifier->SetInputListSample(sampleGenerator->GetTrainingListSample());
  SVMClassifier->SetTargetListSample(sampleGenerator->GetTrainingListLabel());

  SVMClassifier->SetKernelType(CvSVM::LINEAR);

  // The machine learning interface is generic and gives access to other classifiers. We now train the
  // SVM model using the \code{Train} and save the model to a text file using the
  // \code{Save} method.

  SVMClassifier->Train();
  SVMClassifier->Save(outputModelFileName);


  // You can now use the \code{Predict} method which takes a
  // \subdoxygen{itk}{Statistics}{ListSample} as input and estimates the label of each
  // input sample using the model. Finally, the
  // \doxygen{otb}{ImageClassificationModel} inherits from the
  // \doxygen{itk}{ImageToImageFilter} and allows classifying pixels in the
  // input image by predicting their labels using a model.
}