RSGISLib Scikit-Learn Pixel Classification
These functions allow a classifier from the scikit-learn (https://scikit-learn.org) library to be trained and applied on an individual image pixel basis. This requires a number of processing steps to be undertaken:
Define training pixels
Extract training pixels
Train classifier
Apply classifier
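Training and applying a classifier "on an individual image pixel basis" means that every pixel becomes one sample and its band values become the features. scikit-learn expects these as a 2D array with one row per sample. The following is a minimal NumPy sketch of that reshaping. It assumes the image has already been read into an array of shape (bands, rows, cols), and the function names are hypothetical, not part of RSGISLib:

```python
import numpy as np


def image_to_pixel_table(img: np.ndarray) -> np.ndarray:
    """Reshape a (bands, rows, cols) image into a (pixels, bands) feature table."""
    n_bands, n_rows, n_cols = img.shape
    # Flatten each band, then transpose so each row holds one pixel's band values.
    return img.reshape(n_bands, n_rows * n_cols).T


def pixel_table_to_image(values: np.ndarray, n_rows: int, n_cols: int) -> np.ndarray:
    """Reshape one value per pixel (e.g. predicted class ids) back into a 2D image."""
    return values.reshape(n_rows, n_cols)


# Hypothetical 3-band image of 2x2 pixels.
img = np.arange(12).reshape(3, 2, 2)
table = image_to_pixel_table(img)
print(table.shape)        # (4, 3)
print(table[0].tolist())  # [0, 4, 8]: the three band values of the top-left pixel
```

The same row order is used in both directions, so predictions made on the table can be reshaped straight back into an image.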
To define the training data, either a single raster with a unique value for each class or multiple binary rasters (one per class) is needed. Commonly the training regions are defined using a vector layer, which then needs to be rasterised:
import rsgislib.vectorutils
sen2_img = 'sen2_srefimg.kea'
mangroves_sample_vec_file = 'mangrove_cls_samples.geojson'
mangroves_sample_vec_lyr = 'mangrove_cls_samples'
mangroves_sample_img = 'mangrove_cls_samples.kea'
rsgislib.vectorutils.rasteriseVecLyr(mangroves_sample_vec_file, mangroves_sample_vec_lyr, sen2_img, mangroves_sample_img, gdalformat='KEA')
other_sample_vec_file = 'other_cls_samples.geojson'
other_sample_vec_lyr = 'other_cls_samples'
other_sample_img = 'other_cls_samples.kea'
rsgislib.vectorutils.rasteriseVecLyr(other_sample_vec_file, other_sample_vec_lyr, sen2_img, other_sample_img, gdalformat='KEA')
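The two ways of defining training data described above are interchangeable, because a set of binary rasters can be merged into one raster with a unique value per class. Below is a NumPy sketch of that merge. It assumes the rasters have already been read into arrays of the same shape, and combine_binary_masks is a hypothetical helper, not an RSGISLib function:

```python
import numpy as np


def combine_binary_masks(masks, class_ids):
    """Merge per-class binary masks (1 = sample pixel) into one class raster.

    Pixels outside every mask are 0. Where masks overlap, a later class in
    the list overwrites an earlier one, so overlapping samples should be
    resolved beforehand.
    """
    out = np.zeros(masks[0].shape, dtype=np.uint8)  # assumes class ids < 256
    for mask, cls_id in zip(masks, class_ids):
        out[mask == 1] = cls_id
    return out


# Hypothetical rasterised sample masks for two classes on a 2x3 grid.
mangroves = np.array([[1, 1, 0], [0, 0, 0]])
other = np.array([[0, 0, 0], [0, 1, 1]])
print(combine_binary_masks([mangroves, other], [1, 2]).tolist())  # [[1, 1, 0], [0, 2, 2]]
```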
The image pixel values are extracted and stored within an HDF5 file (see https://portal.hdfgroup.org/display/HDF5/HDF5 for more information) using the following functions. To define the images and bands to be used for the classification, and therefore which values need to be extracted, a list of rsgislib.imageutils.ImageBandInfo objects needs to be provided:
import rsgislib.imageutils
imgs_info = []
imgs_info.append(rsgislib.imageutils.ImageBandInfo(fileName='sen2_srefimg.kea', name='sen2', bands=[1,2,3,4,5,6,7,8,9,10]))
imgs_info.append(rsgislib.imageutils.ImageBandInfo(fileName='sen1_dBimg.kea', name='sen1', bands=[1,2]))
mangroves_sample_h5 = 'mangrove_cls_samples.h5'
rsgislib.imageutils.extractZoneImageBandValues2HDF(imgs_info, mangroves_sample_img, mangroves_sample_h5, 1)
other_sample_h5 = 'other_cls_samples.h5'
rsgislib.imageutils.extractZoneImageBandValues2HDF(imgs_info, other_sample_img, other_sample_h5, 1)
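Conceptually, extraction selects every pixel where the sample raster equals the given value (1 above) and records the requested bands of each image as one row. The following NumPy sketch illustrates that idea. It assumes that all images share the same pixel grid and that band numbers are 1-based as in GDAL. It keeps the values in memory rather than writing an HDF5 file, and extract_zone_values is hypothetical:

```python
import numpy as np


def extract_zone_values(images, band_lists, zone_img, zone_val):
    """Return band values for pixels where zone_img == zone_val, one row per pixel.

    images     : list of (bands, rows, cols) arrays sharing one pixel grid
    band_lists : for each image, the 1-based band numbers to extract
    """
    sel = zone_img == zone_val
    # One column per requested band, in image order then band order.
    columns = [img[b - 1][sel] for img, bands in zip(images, band_lists) for b in bands]
    return np.column_stack(columns)


# Hypothetical inputs: a 2-band and a 1-band image on the same 2x2 grid.
img_a = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
img_b = np.array([[[10, 11], [12, 13]]])
zones = np.array([[1, 0], [0, 1]])
samples = extract_zone_values([img_a, img_b], [[1, 2], [1]], zones, 1)
print(samples.tolist())  # [[0, 4, 10], [3, 7, 13]]
```

The column order is fixed by the order of the image and band lists. This is why the same imgs_info must be used again when the classifier is applied.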
If training data is extracted from multiple input images then it will need to be merged using the following function:
rsgislib.imageutils.mergeExtractedHDF5Data(['mang_samples_1.h5', 'mang_samples_2.h5'], 'mangrove_cls_samples.h5')
rsgislib.imageutils.mergeExtractedHDF5Data(['other_samples_1.h5', 'other_samples_2.h5'], 'other_cls_samples.h5')
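Merging only makes sense if every input was extracted with the same list of images and bands, in the same order, so that each column means the same thing in every file. Here is a minimal in-memory sketch of the merge (merge_samples is hypothetical; the RSGISLib function above works on the HDF5 files directly):

```python
import numpy as np


def merge_samples(sample_arrays):
    """Stack sample tables (rows = pixels, columns = bands) from several images."""
    # Only the number of columns can be checked here; keeping the band order
    # consistent between inputs is the caller's responsibility.
    if len({a.shape[1] for a in sample_arrays}) != 1:
        raise ValueError("all inputs must have the same number of band columns")
    return np.vstack(sample_arrays)


merged = merge_samples([np.ones((2, 3)), np.zeros((1, 3))])
print(merged.shape)  # (3, 3)
```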
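The data then need to be split into training and testing datasets. The training data should normally also be balanced, so that there is the same number of samples per class. Here 1000 randomly selected samples per class are written to the training file and the remaining samples to the test file, with 42 used as the random seed: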
mangroves_sample_train_h5 = 'mangrove_cls_samples_train.h5'
mangroves_sample_test_h5 = 'mangrove_cls_samples_test.h5'
rsgislib.imageutils.splitSampleHDF5File(mangroves_sample_h5, mangroves_sample_train_h5, mangroves_sample_test_h5, 1000, 42)
other_sample_train_h5 = 'other_cls_samples_train.h5'
other_sample_test_h5 = 'other_cls_samples_test.h5'
rsgislib.imageutils.splitSampleHDF5File(other_sample_h5, other_sample_train_h5, other_sample_test_h5, 1000, 42)
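The following NumPy sketch shows a seeded random split of this kind, where the first output receives sample_size rows and the second receives the remainder. Using the same sample_size for every class is what balances the training data. The sketch reflects a reading of the arguments above and is not RSGISLib's code: split_samples is hypothetical and will not reproduce RSGISLib's exact selection.

```python
import numpy as np


def split_samples(samples, sample_size, seed):
    """Randomly split rows: sample_size rows go to the first output, the rest to the second."""
    n = samples.shape[0]
    if sample_size > n:
        raise ValueError("sample_size is larger than the number of samples")
    idx = np.random.default_rng(seed).permutation(n)  # same seed -> same split
    return samples[idx[:sample_size]], samples[idx[sample_size:]]


samples = np.arange(20).reshape(10, 2)  # 10 hypothetical pixels, 2 bands
train, test = split_samples(samples, 4, 42)
print(train.shape, test.shape)  # (4, 2) (6, 2)
```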
The classifier now needs to be trained, so import the rsgislib.classification and rsgislib.classification.classsklearn modules:
import rsgislib.classification
import rsgislib.classification.classsklearn
You then have two options for training: providing the parameters yourself, or performing a grid search to find the optimal parameters for the classifier given the input data.
To train a classifier with parameters you have defined yourself, use the following code. Here the random forests classifier is used, but any other classifier from the scikit-learn library can be used instead:
from sklearn.ensemble import RandomForestClassifier
skclf = RandomForestClassifier(n_estimators=100)
cls_train_info = dict()
cls_train_info['Mangroves'] = rsgislib.classification.ClassSimpleInfoObj(id=1, fileH5='mangrove_cls_samples_train.h5', red=0, green=255, blue=0)
cls_train_info['Other'] = rsgislib.classification.ClassSimpleInfoObj(id=2, fileH5='other_cls_samples_train.h5', red=100, green=100, blue=100)
rsgislib.classification.classsklearn.train_sklearn_classifier(cls_train_info, skclf)
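During this training step, the per-class samples are combined into a single feature matrix, with each sample labelled by its class id, and passed to the classifier's fit() method. The following sketch illustrates that step using in-memory arrays in place of the HDF5 files. build_training_arrays and the synthetic data are hypothetical, and this is an illustration rather than RSGISLib's implementation:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier


def build_training_arrays(class_samples):
    """Stack per-class sample tables into X (features) and y (class ids).

    class_samples maps a class id to a (n_samples, n_features) array.
    """
    X = np.vstack(list(class_samples.values()))
    y = np.concatenate([np.full(s.shape[0], cls_id) for cls_id, s in class_samples.items()])
    return X, y


rng = np.random.default_rng(0)
# Hypothetical, well-separated samples: 2 features, 50 pixels per class.
class_samples = {
    1: rng.normal(0.0, 0.1, size=(50, 2)),
    2: rng.normal(1.0, 0.1, size=(50, 2)),
}
X, y = build_training_arrays(class_samples)
clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
print(clf.predict([[0.0, 0.0], [1.0, 1.0]]))  # expected: [1 2]
```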
To train a classifier using a grid search, you need to define the classifier parameters to be searched and a range of valid values for each of them. perform_sklearn_classifier_param_search returns a classifier initialised with the optimal parameters but not trained, so it is then trained using all the training data:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
cls_train_info = dict()
cls_train_info['Mangroves'] = rsgislib.classification.ClassSimpleInfoObj(id=1, fileH5='mangrove_cls_samples_train.h5', red=0, green=255, blue=0)
cls_train_info['Other'] = rsgislib.classification.ClassSimpleInfoObj(id=2, fileH5='other_cls_samples_train.h5', red=100, green=100, blue=100)
grid_search = GridSearchCV(RandomForestClassifier(), param_grid={'n_estimators':[10,20,50,100], 'max_depth':[2,4,8]})
skclf = rsgislib.classification.classsklearn.perform_sklearn_classifier_param_search(cls_train_info, grid_search)
rsgislib.classification.classsklearn.train_sklearn_classifier(cls_train_info, skclf)
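A parameter search of this kind can be illustrated directly with scikit-learn. GridSearchCV cross-validates every combination in param_grid and records the best one in best_params_. sklearn.base.clone() then produces an unfitted classifier with those parameters, which matches what the search function is documented to return. The following self-contained sketch uses hypothetical synthetic data, and RSGISLib's internals may differ:

```python
import numpy as np
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

rng = np.random.default_rng(0)
# Hypothetical two-class training data, 30 samples per class.
X = np.vstack([rng.normal(0.0, 0.1, (30, 2)), rng.normal(1.0, 0.1, (30, 2))])
y = np.array([1] * 30 + [2] * 30)

search = GridSearchCV(RandomForestClassifier(random_state=0),
                      param_grid={'n_estimators': [5, 10], 'max_depth': [2, 4]},
                      cv=3)
search.fit(X, y)  # cross-validates all 4 parameter combinations
print(search.best_params_)

# A fresh, unfitted copy carrying the best parameters, ready to be trained.
untrained = clone(search.best_estimator_)
```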
To apply the trained classifier, use the following function:
out_cls_img = 'mangrove_classification_result.kea'
img_msk = 'valid_area_msk.kea'
rsgislib.classification.classsklearn.apply_sklearn_classifier(cls_train_info, skclf, img_msk, 1, imgs_info, out_cls_img, 'KEA', class_clr_names=True)
The output image file name needs to be defined, and an image mask needs to be provided which defines the parts of the image to be classified (here, pixels with a value of 1 in img_msk). This is useful because a hierarchical classification process can be built by using a previous classification result as the mask for another classifier.
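The mask behaviour, and the hierarchical use of a previous result as a mask, can be sketched in NumPy as follows. The sketch assumes the image is held as a (bands, rows, cols) array and writes unclassified pixels as 0. classify_within_mask and ThresholdClassifier are hypothetical, and any fitted scikit-learn classifier with a predict() method could be passed instead:

```python
import numpy as np


def classify_within_mask(img, mask, mask_val, clf):
    """Classify pixels where mask == mask_val; all other pixels are set to 0.

    img  : (bands, rows, cols) array of the bands used in training, in order
    mask : (rows, cols) array; clf is any object with a predict() method
    """
    sel = mask == mask_val
    out = np.zeros(mask.shape, dtype=np.uint8)  # assumes class ids < 256
    if sel.any():
        X = img[:, sel].T  # (n_selected_pixels, bands)
        out[sel] = clf.predict(X)
    return out


class ThresholdClassifier:
    """Hypothetical stand-in for a trained classifier: class 1 if band 1 > 0.5, else 2."""

    def predict(self, X):
        return np.where(X[:, 0] > 0.5, 1, 2)


img = np.array([[[0.9, 0.1], [0.8, 0.2]]])  # one band, 2x2 pixels
valid = np.array([[1, 1], [0, 1]])          # 0 = outside the valid area
level1 = classify_within_mask(img, valid, 1, ThresholdClassifier())
print(level1.tolist())  # [[1, 2], [0, 2]]
# A second classifier could now be applied only where level1 == 1,
# e.g. classify_within_mask(img, level1, 1, second_clf).
```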
Training Functions
- rsgislib.classification.classsklearn.perform_sklearn_classifier_param_search(cls_train_info: Dict[str, ClassInfoObj], search_obj: BaseSearchCV) -> BaseEstimator
A function to find the ‘optimal’ parameters for classification using a grid search or random search (http://scikit-learn.org/stable/modules/grid_search.html). The validation data are used to identify the optimal parameters, and the returned classifier is initialised with those parameters but not trained.
- Parameters
cls_train_info – dict (where the key is the class name) of rsgislib.classification.ClassInfoObj objects which will be used to train the classifier.
search_obj – an instance of a scikit-learn search object derived from BaseSearchCV (e.g., sklearn.model_selection.GridSearchCV or RandomizedSearchCV), parameterised with an instance of the classifier and the associated parameters to be searched.
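- Returns
an instance of the classifier (sklearn.base.BaseEstimator) initialised with the optimal parameters but not yet trained.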
- rsgislib.classification.classsklearn.train_sklearn_classifier(cls_train_info: Dict[str, ClassInfoObj], sk_classifier: BaseEstimator) -> (float, float)
This function trains the classifier.
- Parameters
cls_train_info – dict (where the key is the class name) of rsgislib.classification.ClassInfoObj objects which will be used to train and test the classifier.
sk_classifier – an instance of a parameterised scikit-learn classifier (http://scikit-learn.org/stable/supervised_learning.html)
- Returns
training and testing accuracies (between 0 and 1)
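An accuracy between 0 and 1 is normally the proportion of correctly classified samples. It is computed once on the training data and once on the held-out test data. Here is a minimal sketch of that measure, which is equivalent to sklearn.metrics.accuracy_score for plain label arrays. It is an assumption that RSGISLib computes exactly this:

```python
import numpy as np


def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted class matches the reference class (0 to 1)."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))


print(accuracy([1, 1, 2, 2], [1, 2, 2, 2]))  # 0.75
```

A training accuracy much higher than the testing accuracy usually indicates that the classifier is overfitting the training samples.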
Classify Functions
- rsgislib.classification.classsklearn.apply_sklearn_classifier(cls_train_info: Dict[str, ClassInfoObj], sk_classifier: BaseEstimator, in_img_mask: str, img_mask_val: int, img_file_info: List[ImageBandInfo], output_img: str, gdalformat: str = 'KEA', class_clr_names: bool = True, out_score_img: Optional[str] = None, ignore_consec_cls_ids: bool = False)
This function uses a trained classifier and applies it to the provided input image.
- Parameters
cls_train_info – dict (where the key is the class name) of rsgislib.classification.ClassInfoObj objects, as used to train the classifier, which provide the pixel value id and RGB colour for each class.
sk_classifier – a trained instance of a scikit-learn classifier
in_img_mask – an image file providing a mask which specifies where the classification should be applied. The simplest mask covers all the valid data regions (rsgislib.imageutils.gen_valid_mask).
img_mask_val – the pixel value within in_img_mask which limits the region to which the classification is applied. Can be used to create a hierarchical classification.
img_file_info – a list of rsgislib.imageutils.ImageBandInfo objects identifying which images and bands are to be used for the classification, so that they match those used for the training data.
output_img – output image file with the classification. Note, by default a colour table and class names column are added to the image if the gdalformat is KEA.
gdalformat – is the output image format
class_clr_names – default is True, in which case a colour table using the colours specified in cls_train_info and a class names column will be added to the output file.
out_score_img – a file path for a score image. If None, no score image is output. Note, this function uses the predict_proba() function from the scikit-learn model, which isn’t available for all classifiers and therefore might produce an error if called on a model which doesn’t have this function; for example, sklearn.svm.SVC only provides it when created with probability=True.
ignore_consec_cls_ids – a boolean specifying whether to ignore the requirement that class ids are consecutive, in which case the out_ids are used to specify other, non-consecutive ids. This carries some risk but allows more flexibility when using the function.
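Because a score image relies on predict_proba(), it can help to check for that method before setting out_score_img. A small sketch of such a check is shown below (supports_scores is a hypothetical helper). In scikit-learn, sklearn.svm.SVC only exposes the method when the model is created with probability=True:

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC


def supports_scores(clf):
    """True if the classifier exposes predict_proba(), which a score image needs."""
    return hasattr(clf, 'predict_proba')


print(supports_scores(RandomForestClassifier()))  # True
print(supports_scores(SVC()))                      # False
print(supports_scores(SVC(probability=True)))      # True
```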
- rsgislib.classification.classsklearn.apply_sklearn_classifier_rat(clumps_img: str, variables: List[str], sk_classifier: BaseEstimator, cls_train_info: Dict[str, ClassInfoObj], out_col_int: str = 'OutClass', out_col_str: str = 'OutClassName', roi_col: Optional[str] = None, roi_val: int = 1, class_colours: bool = True)
A function which will apply a scikit-learn classifier within a Raster Attribute Table (RAT).
- Parameters
clumps_img – is the clumps image on which the classification is to be performed
variables – a list of column names which are to be used for the classification
sk_classifier – a trained instance of a scikit-learn classifier
cls_train_info – dict (where the key is the class name) of rsgislib.classification.ClassInfoObj objects, as used to train the classifier, which provide the pixel value id and RGB colour for each class.
out_col_int – is the output column name for the int class representation (Default: ‘OutClass’)
out_col_str – is the output column name for the class names column (Default: ‘OutClassName’)
roi_col – is a column name for a column which specifies the region to be classified. If None ignored (Default: None)
roi_val – an int value used within the roi_col to select the region to be classified (Default: 1)
class_colours – is a boolean specifying whether the RAT colour table should be updated using the classification colours (default: True)
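In a RAT, each row describes one clump (segment), so classification works on rows rather than pixels. The variables columns form the feature matrix, and the results are written back as an integer column and a class-name column. The following sketch uses plain NumPy columns to stand in for the RAT and omits reading and writing the table. The function, the column names, the rule-based classifier, and leaving rows outside the ROI as 0 with an empty name are all hypothetical:

```python
import numpy as np


def classify_rat_rows(columns, variables, clf, class_names, roi_col=None, roi_val=1):
    """Classify RAT rows (one per clump); return an int id column and a name column.

    columns     : dict mapping column name -> 1D array with one value per clump
    variables   : column names to use as features, in training order
    class_names : dict mapping class id -> class name
    Rows outside the ROI are left as id 0 with an empty name (an assumption).
    """
    n_rows = len(columns[variables[0]])
    out_int = np.zeros(n_rows, dtype=int)
    out_str = np.full(n_rows, '', dtype=object)
    if roi_col is None:
        sel = np.ones(n_rows, dtype=bool)
    else:
        sel = np.asarray(columns[roi_col]) == roi_val
    if sel.any():
        X = np.column_stack([np.asarray(columns[v])[sel] for v in variables])
        out_int[sel] = clf.predict(X)
        out_str[sel] = [class_names[int(i)] for i in out_int[sel]]
    return out_int, out_str


class NDVIRule:
    """Hypothetical stand-in classifier: class 1 if the first feature > 0.3, else 2."""

    def predict(self, X):
        return np.where(X[:, 0] > 0.3, 1, 2)


cols = {'NDVI': np.array([0.6, 0.1, 0.5]), 'ROI': np.array([1, 1, 0])}
ids, names = classify_rat_rows(cols, ['NDVI'], NDVIRule(), {1: 'Mangroves', 2: 'Other'}, roi_col='ROI')
print(ids.tolist(), names.tolist())  # [1, 2, 0] ['Mangroves', 'Other', '']
```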