RSGISLib Scikit-Learn Pixel Classification Module
These functions allow a classifier from the scikit-learn library (https://scikit-learn.org) to be trained and applied on a per-pixel basis. This requires the following processing steps:
Define training pixels
Extract training pixels
Train classifier
Apply classifier
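To make the per-pixel idea concrete, the four steps can be sketched with NumPy and scikit-learn alone on small synthetic arrays. This is an illustration under simplifying assumptions (the whole image fits in memory and the training pixels are held in an array rather than an HDF5 file). It is not how RSGISLib implements these functions:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Synthetic 3-band, 20x30 pixel image (bands, rows, cols) standing in for real imagery.
rng = np.random.default_rng(0)
img = rng.random((3, 20, 30))

# 1. Define: a training raster where 0 = no sample and 1, 2 = class ids.
train = np.zeros((20, 30), dtype=np.uint8)
train[:5, :5] = 1
train[-5:, -5:] = 2

# 2. Extract: one row per training pixel, one column per band.
sample_mask = train > 0
X = img[:, sample_mask].T  # shape (n_samples, n_bands)
y = train[sample_mask]     # class id of each sample

# 3. Train the classifier on the extracted pixel values.
clf = RandomForestClassifier(n_estimators=20, random_state=0)
clf.fit(X, y)

# 4. Apply: classify every pixel independently, then restore the image shape.
all_pixels = img.reshape(img.shape[0], -1).T  # shape (rows * cols, n_bands)
cls_img = clf.predict(all_pixels).reshape(img.shape[1:])
```

Each pixel is treated as an independent sample described only by its band values; no spatial context is used.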
To define the training data, provide either a single raster with a unique value for each class or multiple binary rasters, one for each class. Commonly the training regions are defined using a vector layer, which then needs to be rasterised:
import rsgislib.vectorutils
sen2_img = 'sen2_srefimg.kea'
mangroves_sample_vec_file = 'mangrove_cls_samples.geojson'
mangroves_sample_vec_lyr = 'mangrove_cls_samples'
mangroves_sample_img = 'mangrove_cls_samples.kea'
rsgislib.vectorutils.rasteriseVecLyr(mangroves_sample_vec_file, mangroves_sample_vec_lyr, sen2_img, mangroves_sample_img, gdalformat='KEA')
other_sample_vec_file = 'other_cls_samples.geojson'
other_sample_vec_lyr = 'other_cls_samples'
other_sample_img = 'other_cls_samples.kea'
rsgislib.vectorutils.rasteriseVecLyr(other_sample_vec_file, other_sample_vec_lyr, sen2_img, other_sample_img, gdalformat='KEA')
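The two ways of defining training data hold the same information. As a minimal NumPy sketch, assuming the binary rasters have already been read into arrays (reading them, e.g. with GDAL, is not shown), several binary class rasters can be merged into a single raster with a unique value per class. The helper `combine_binary_masks` is hypothetical and not an RSGISLib function:

```python
import numpy as np


def combine_binary_masks(masks, class_ids):
    """Merge binary (0/1) class rasters into one raster of class ids.

    Hypothetical helper for illustration. Pixels covered by no mask are 0.
    Where masks overlap, the later mask in the list wins (an arbitrary
    choice made for this sketch).
    """
    out = np.zeros(masks[0].shape, dtype=np.uint8)
    for mask, cls_id in zip(masks, class_ids):
        out[mask == 1] = cls_id
    return out


mangrove = np.array([[1, 1, 0], [0, 0, 0]])
other = np.array([[0, 0, 0], [0, 1, 1]])
combined = combine_binary_masks([mangrove, other], [1, 2])
# combined -> [[1, 1, 0], [0, 2, 2]]
```

In the workflow above each class is kept as its own binary raster and extracted separately, so this conversion is only needed if you prefer a single class raster.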
The image pixel values are extracted and stored within an HDF5 file (see https://portal.hdfgroup.org/display/HDF5/HDF5 for more information). To define which images and bands are to be used for the classification, and therefore from which values need to be extracted, a list of rsgislib.imageutils.ImageBandInfo objects needs to be provided:
import rsgislib.imageutils
imgs_info = []
imgs_info.append(rsgislib.imageutils.ImageBandInfo(fileName='sen2_srefimg.kea', name='sen2', bands=[1,2,3,4,5,6,7,8,9,10]))
imgs_info.append(rsgislib.imageutils.ImageBandInfo(fileName='sen1_dBimg.kea', name='sen1', bands=[1,2]))
mangroves_sample_h5 = 'mangrove_cls_samples.h5'
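# The final argument (1) is the pixel value within the sample image from which band values are extracted
rsgislib.imageutils.extractZoneImageBandValues2HDF(imgs_info, mangroves_sample_img, mangroves_sample_h5, 1)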
other_sample_h5 = 'other_cls_samples.h5'
rsgislib.imageutils.extractZoneImageBandValues2HDF(imgs_info, other_sample_img, other_sample_h5, 1)
If training data are extracted from multiple input images, they need to be merged into a single file per class using the following function:
rsgislib.imageutils.mergeExtractedHDF5Data(['mang_samples_1.h5', 'mang_samples_2.h5'], 'mangrove_cls_samples.h5')
rsgislib.imageutils.mergeExtractedHDF5Data(['other_samples_1.h5', 'other_samples_2.h5'], 'other_cls_samples.h5')
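Conceptually, extraction produces a table with one row per sample pixel and one column per band, and merging stacks those tables row-wise. The following in-memory NumPy sketch illustrates that data layout only. It is not RSGISLib's implementation, which works on image files and writes HDF5. The helper `extract_zone_values` is hypothetical:

```python
import numpy as np


def extract_zone_values(bands, zone_img, zone_val):
    """Return an (n_pixels, n_bands) array of values where zone_img == zone_val.

    Hypothetical helper; bands has shape (n_bands, rows, cols).
    """
    return bands[:, zone_img == zone_val].T


rng = np.random.default_rng(1)
# Two synthetic 12-band stacks (e.g. 10 Sentinel-2 + 2 Sentinel-1 bands), 4x4 pixels each.
img_a = rng.random((12, 4, 4))
img_b = rng.random((12, 4, 4))

zones_a = np.zeros((4, 4), dtype=np.uint8)
zones_a[0, :] = 1  # 4 sample pixels in the first row
zones_b = np.zeros((4, 4), dtype=np.uint8)
zones_b[:, 0] = 1  # 4 sample pixels in the first column

samples_a = extract_zone_values(img_a, zones_a, 1)
samples_b = extract_zone_values(img_b, zones_b, 1)

# Merging is a row-wise concatenation, so the band columns must match between inputs.
merged = np.vstack([samples_a, samples_b])
```

This is also why the same ImageBandInfo list (same images and bands, in the same order) should be used for every extraction that is later merged.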
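The data then need to be split into training and testing datasets. The training data should also normally be balanced so that there is the same number of samples per class. In the example below, 1000 samples per class are randomly selected (using a random seed of 42) for the training file, and the remaining samples are written to the testing file: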
mangroves_sample_train_h5 = 'mangrove_cls_samples_train.h5'
mangroves_sample_test_h5 = 'mangrove_cls_samples_test.h5'
rsgislib.imageutils.splitSampleHDF5File(mangroves_sample_h5, mangroves_sample_train_h5, mangroves_sample_test_h5, 1000, 42)
other_sample_train_h5 = 'other_cls_samples_train.h5'
other_sample_test_h5 = 'other_cls_samples_test.h5'
rsgislib.imageutils.splitSampleHDF5File(other_sample_h5, other_sample_train_h5, other_sample_test_h5, 1000, 42)
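The balancing idea can be sketched in NumPy: draw a fixed number of rows per class at random for training and keep the rest for testing. This is a minimal sketch on synthetic arrays. The helper `split_samples` is hypothetical, and because it uses NumPy's random generator, the same seed will not reproduce the split made by splitSampleHDF5File:

```python
import numpy as np


def split_samples(samples, n_first, seed):
    """Randomly split rows into a sample of n_first rows and the remainder.

    Hypothetical helper for illustration only.
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(samples.shape[0])
    return samples[idx[:n_first]], samples[idx[n_first:]]


# Synthetic per-class sample tables with different sizes (12 bands each).
mangrove_vals = np.random.default_rng(0).random((3000, 12))
other_vals = np.random.default_rng(1).random((5000, 12))

# Taking the same number of training rows from each class balances the training set.
mang_train, mang_test = split_samples(mangrove_vals, 1000, 42)
other_train, other_test = split_samples(other_vals, 1000, 42)
```

The testing sets remain unbalanced, which is usually acceptable because they are used only for accuracy assessment.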
The classifier now needs to be trained, so import the rsgislib.classification and rsgislib.classification.classsklearn modules:
import rsgislib.classification
import rsgislib.classification.classsklearn
There are then two options for training: providing the parameters yourself, or performing a grid search to find the optimal parameters for the classifier given the input data.
To train a classifier with parameters you define yourself, use the following code. The random forest classifier is used here, but any other classifier from the scikit-learn library can be used:
from sklearn.ensemble import RandomForestClassifier
skclf = RandomForestClassifier(n_estimators=100)
cls_train_info = dict()
cls_train_info['Mangroves'] = rsgislib.classification.ClassSimpleInfoObj(id=1, fileH5='mangrove_cls_samples_train.h5', red=0, green=255, blue=0)
cls_train_info['Other'] = rsgislib.classification.ClassSimpleInfoObj(id=2, fileH5='other_cls_samples_train.h5', red=100, green=100, blue=100)
rsgislib.classification.classsklearn.train_sklearn_classifier(cls_train_info, skclf)
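Training from a dictionary of classes amounts to stacking each class's samples into one feature matrix and labelling each row with that class's id. The sketch below assumes the samples are already in memory: `class_samples` is a hypothetical stand-in for cls_train_info (class name mapped to class id and sample array rather than an HDF5 file), and `build_xy` is a hypothetical helper, not RSGISLib code:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Hypothetical in-memory stand-in for cls_train_info: class name -> (class id, samples).
rng = np.random.default_rng(0)
class_samples = {
    'Mangroves': (1, rng.normal(0.0, 1.0, size=(200, 12))),
    'Other': (2, rng.normal(3.0, 1.0, size=(200, 12))),
}


def build_xy(class_samples):
    """Stack per-class sample arrays into X and label each row with its class id."""
    X = np.vstack([vals for _, vals in class_samples.values()])
    y = np.concatenate(
        [np.full(vals.shape[0], cls_id) for cls_id, vals in class_samples.values()]
    )
    return X, y


X, y = build_xy(class_samples)
skclf = RandomForestClassifier(n_estimators=100, random_state=0)
skclf.fit(X, y)
```

Because the labels are the class ids, the fitted classifier's predictions are directly the pixel values written to the output classification.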
To train a classifier using a grid search, you need to define the classifier parameters to be searched and a range of valid values for each of them. Here, 500 samples per class are used for the search. The function returns the optimal classifier, trained using all the training data:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
cls_train_info = dict()
cls_train_info['Mangroves'] = rsgislib.classification.ClassSimpleInfoObj(id=1, fileH5='mangrove_cls_samples_train.h5', red=0, green=255, blue=0)
cls_train_info['Other'] = rsgislib.classification.ClassSimpleInfoObj(id=2, fileH5='other_cls_samples_train.h5', red=100, green=100, blue=100)
grid_search = GridSearchCV(RandomForestClassifier(), param_grid={'n_estimators':[10,20,50,100], 'max_depth':[2,4,8]})
skclf = rsgislib.classification.classsklearn.train_sklearn_classifer_gridsearch(cls_train_info, 500, grid_search)
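The strategy described for this function (search on a per-class random subsample, then return a classifier trained on all the data) can be sketched with scikit-learn directly. This is a minimal sketch on synthetic data from make_classification. The helper `gridsearch_then_train` is hypothetical and is not RSGISLib's implementation:

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV


def gridsearch_then_train(X, y, samples_per_class, grid_search, seed=0):
    """Search parameters on a per-class random subsample, then fit on all data.

    Hypothetical helper for illustration only.
    """
    rng = np.random.default_rng(seed)
    idx = np.concatenate([
        rng.choice(np.flatnonzero(y == c), size=samples_per_class, replace=False)
        for c in np.unique(y)
    ])
    grid_search.fit(X[idx], y[idx])
    # clone() copies the best estimator's parameters without its fitted state.
    best = clone(grid_search.best_estimator_)
    best.fit(X, y)
    return best, grid_search.best_params_


X, y = make_classification(n_samples=2000, n_features=12, random_state=0)
grid = GridSearchCV(RandomForestClassifier(random_state=0),
                    param_grid={'n_estimators': [10, 20], 'max_depth': [2, 4]})
skclf, best_params = gridsearch_then_train(X, y, 500, grid)
```

Searching on a subsample keeps the cost of the many cross-validated fits down; only the final fit uses every sample.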
To apply the trained classifier, use the following function:
out_cls_img = 'mangrove_classification_result.kea'
img_msk = 'valid_area_msk.kea'
rsgislib.classification.classsklearn.apply_sklearn_classifer(cls_train_info, skclf, img_msk, 1, imgs_info, out_cls_img, 'KEA', classClrNames=True)
The output image file name needs to be defined, and an image mask also needs to be provided which defines the parts of the image to be classified (here, the pixels with a value of 1 in valid_area_msk.kea). This is useful because, by using a previous classification result as the mask for another classifier, a hierarchical classification process can be built.
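Masked application and the hierarchical idea can be sketched in NumPy: predict only where the mask equals the chosen value, then reuse one result as the mask for the next level. This is a minimal in-memory sketch with synthetic data. The helper `classify_masked` and its no-data value of 0 are assumptions of this sketch, not RSGISLib behaviour:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier


def classify_masked(clf, bands, mask, mask_val, nodata=0):
    """Predict class ids only where mask == mask_val; other pixels get nodata.

    Hypothetical helper; bands has shape (n_bands, rows, cols).
    """
    out = np.full(mask.shape, nodata, dtype=np.int32)
    sel = mask == mask_val
    if sel.any():
        out[sel] = clf.predict(bands[:, sel].T)
    return out


rng = np.random.default_rng(0)
bands = rng.random((2, 5, 5))
clf = RandomForestClassifier(n_estimators=10, random_state=0)
clf.fit(rng.random((50, 2)), rng.integers(1, 3, size=50))

# Level 1: classify everywhere the valid-data mask is 1.
valid = np.ones((5, 5), dtype=np.uint8)
level1 = classify_masked(clf, bands, valid, 1)

# Level 2 (hierarchical): use the level-1 result as the mask and reclassify only
# the pixels labelled class 1. A different classifier would normally be used here;
# the same one is reused for brevity.
level2 = classify_masked(clf, bands, level1, 1)
```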
Training Functions
rsgislib.classification.classsklearn.train_sklearn_classifer_gridsearch(classTrainInfo, paramSearchSampNum=0, gridSearch=GridSearchCV(estimator=RandomForestClassifier(), param_grid={}))
A function to find the ‘optimal’ parameters for classification using a grid search (http://scikit-learn.org/stable/modules/grid_search.html). The returned classifier instance will be trained using all the input data.
Parameters
classTrainInfo – dict (where the key is the class name) of rsgislib.classification.ClassSimpleInfoObj objects which will be used to train the classifier.
paramSearchSampNum – the number of samples that will be randomly sampled from the training data for each class for applying the grid search (a small sample is typically used because the grid search can take a long time). A value of 500 would use 500 samples per class.
gridSearch – an instance of sklearn.model_selection.GridSearchCV with an instance of the chosen classifier and the parameters to be searched.
rsgislib.classification.classsklearn.train_sklearn_classifier(classTrainInfo, skClassifier)
This function trains the classifier.
Parameters
classTrainInfo – dict (where the key is the class name) of rsgislib.classification.ClassSimpleInfoObj objects which will be used to train the classifier.
skClassifier – an instance of a parameterised scikit-learn classifier (http://scikit-learn.org/stable/supervised_learning.html).
Classify Functions
rsgislib.classification.classsklearn.apply_sklearn_classifer(classTrainInfo, skClassifier, imgMask, imgMaskVal, imgFileInfo, outputImg, gdalformat, classClrNames=True, outScoreImg=None)
This function uses a trained classifier and applies it to the provided input image.
Parameters
classTrainInfo – dict (where the key is the class name) of rsgislib.classification.ClassSimpleInfoObj objects which were used to train the classifier (i.e., with train_sklearn_classifier()), providing the pixel value id and RGB class values.
skClassifier – a trained instance of a scikit-learn classifier (e.g., from train_sklearn_classifier or train_sklearn_classifer_gridsearch).
imgMask – an image file providing a mask which specifies where the classification should be applied. The simplest mask covers all the valid data regions (rsgislib.imageutils.genValidMask).
imgMaskVal – the pixel value within imgMask which limits the region to which the classification is applied. Can be used to create a hierarchical classification.
imgFileInfo – a list of rsgislib.imageutils.ImageBandInfo objects (also used within rsgislib.imageutils.extractZoneImageBandValues2HDF) to identify which images and bands are to be used for the classification so it adheres to the training data.
outputImg – output image file with the classification. Note: by default a colour table and class names column are added to the image. If an error is produced, use the HFA or KEA format.
gdalformat – the output image format; any GDAL-supported format can be used.
classClrNames – default is True, in which case a colour table with the colours specified in classTrainInfo and a ClassName column (using the class names, i.e., the keys of classTrainInfo) will be added to the output file.
outScoreImg – a file path for a score image. If None, no score image is output. Note: this option uses the predict_proba() method of the scikit-learn model, which isn’t available for all classifiers, and therefore might produce an error if called on a model which doesn’t provide it; for example, sklearn.svm.SVC unless it was created with probability=True.
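Before requesting a score image, it can be worth checking that the classifier exposes predict_proba(). This small sketch uses plain scikit-learn; `supports_scores` is a hypothetical helper, not part of RSGISLib:

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC


def supports_scores(clf):
    """True if the estimator exposes predict_proba (needed for outScoreImg)."""
    return hasattr(clf, 'predict_proba')


svc_default = supports_scores(SVC())                  # False: probability=False by default
svc_proba = supports_scores(SVC(probability=True))    # True
rf = supports_scores(RandomForestClassifier())        # True
```

If the check fails, either pass outScoreImg=None or switch to a classifier (or configuration) that provides class probabilities.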
rsgislib.classification.classsklearn.perform_voting_classification(skClassifiers, trainSamplesInfo, imgFileInfo, classAreaMask, classMaskPxlVal, tmpDIR, tmpImgBase, outClassImg, gdalformat='KEA', numCores=-1)
A function which performs a number of classifications and combines them into a single classification by a simple vote. Because a list of classifiers is provided (the length of the list equals the number of votes), the classifier parameters can differ between votes, and the training data are resampled for each classifier. The analysis can be performed using multiple processing cores.
Parameters
skClassifiers – a list of classifiers (from scikit-learn); the number of classifiers defined will be equal to the number of votes.
trainSamplesInfo – a list of rsgislib.classification.classimgutils.SamplesInfoObj objects used to parameterise the classifier and extract the training data.
imgFileInfo – a list of rsgislib.imageutils.ImageBandInfo objects (also used within rsgislib.imageutils.extractZoneImageBandValues2HDF) to identify which images and bands are to be used for the classification so it adheres to the training data.
classAreaMask – a mask image which is used to specify the areas of the scene which are to be classified.
classMaskPxlVal – the pixel value within the classAreaMask image for the areas of the image which are to be classified.
tmpDIR – a temporary file location which will be created and removed during processing.
tmpImgBase – the base name of files written to tmpDIR.
outClassImg – the final output image file.
gdalformat – the output file format for outClassImg.
numCores – the number of processing cores to be used for the analysis (if -1, all cores on the machine will be used).
Example:
import os
import rsgislib.imageutils
from rsgislib.classification import classsklearn
from rsgislib.classification.classimgutils import SamplesInfoObj
from sklearn.ensemble import ExtraTreesClassifier
# imgTmp, img2010dB, imgSRTM, classTrainRegionsMask and classWithinMask are assumed
# to be defined already (a temporary directory and your own input image paths).
classVoteTemp = os.path.join(imgTmp, 'ClassVoteTemp')
imgFileInfo = [rsgislib.imageutils.ImageBandInfo(img2010dB, 'sardb', [1,2]), rsgislib.imageutils.ImageBandInfo(imgSRTM, 'srtm', [1])]
# Training sample definitions, one per class.
trainSamplesInfo = []
trainSamplesInfo.append(SamplesInfoObj(className='Water', classID=1, maskImg=classTrainRegionsMask, maskPxlVal=1, outSampImgFile='WaterSamples.kea', numSamps=500, samplesH5File='WaterSamples_pxlvals.h5', red=0, green=0, blue=255))
trainSamplesInfo.append(SamplesInfoObj(className='Land', classID=2, maskImg=classTrainRegionsMask, maskPxlVal=2, outSampImgFile='LandSamples.kea', numSamps=500, samplesH5File='LandSamples_pxlvals.h5', red=150, green=150, blue=150))
trainSamplesInfo.append(SamplesInfoObj(className='Mangroves', classID=3, maskImg=classTrainRegionsMask, maskPxlVal=3, outSampImgFile='MangroveSamples.kea', numSamps=500, samplesH5File='MangroveSamples_pxlvals.h5', red=0, green=153, blue=0))
# 20 classifiers (and therefore 20 votes) using four different parameterisations.
skClassifiers = []
for _ in range(5):
    skClassifiers.append(ExtraTreesClassifier(n_estimators=50))
for _ in range(5):
    skClassifiers.append(ExtraTreesClassifier(n_estimators=100))
for _ in range(5):
    skClassifiers.append(ExtraTreesClassifier(n_estimators=50, max_depth=2))
for _ in range(5):
    skClassifiers.append(ExtraTreesClassifier(n_estimators=100, max_depth=2))
mangroveRegionClassImg = 'MangroveRegionClass.kea'
classsklearn.perform_voting_classification(skClassifiers, trainSamplesInfo, imgFileInfo, classWithinMask, 1, classVoteTemp, 'ClassImgSample', mangroveRegionClassImg, gdalformat='KEA', numCores=-1)
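A simple per-pixel vote over several classification results can be sketched in NumPy by counting, for each pixel, how many classifiers chose each class and keeping the most frequent one. This is a minimal sketch on in-memory arrays, not RSGISLib's implementation. The helper `vote` is hypothetical, and its tie-breaking rule (smallest class id wins) is an assumption of this sketch:

```python
import numpy as np


def vote(predictions):
    """Per-pixel majority vote over a stack of class images.

    predictions: int array of shape (n_classifiers, rows, cols) holding
    non-negative class ids. Ties go to the smallest class id because
    argmax returns the first maximum.
    """
    n_cls = int(predictions.max()) + 1
    # counts[k, r, c] = number of classifiers that voted for class k at pixel (r, c)
    counts = np.stack([(predictions == k).sum(axis=0) for k in range(n_cls)])
    return counts.argmax(axis=0)


preds = np.array([
    [[1, 2], [3, 3]],
    [[1, 2], [2, 3]],
    [[2, 2], [2, 1]],
])
result = vote(preds)  # [[1, 2], [2, 3]]
```

Using an odd number of classifiers reduces, but does not eliminate, ties when there are more than two classes.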