RSGISLib Scikit-Learn Pixel Classification
These functions allow a classifier from the scikit-learn (https://scikit-learn.org) library to be trained and applied on an individual image pixel basis. This requires a number of processing steps to be undertaken:
Extract training
Split training: Training, Validation, Testing
Train Classifier and Optimise Hyperparameters
Apply Classifier
First, however, we’ll create a couple of directories for our outputs and intermediate files:
import os
out_dir = "baseline_cls_skl_rf"
if not os.path.exists(out_dir):
    os.mkdir(out_dir)
tmp_dir = "tmp_skl_rf"
if not os.path.exists(tmp_dir):
    os.mkdir(tmp_dir)
We will also define the input file path and the list of ImageBandInfo objects, which specify the images and bands used for the analysis:
import rsgislib.imageutils
input_img = "./LS5TM_19970716_vmsk_mclds_topshad_rad_srefdem_stdsref_subset.tif"
imgs_info = []
imgs_info.append(
rsgislib.imageutils.ImageBandInfo(
file_name=input_img, name="ls97", bands=[1, 2, 3, 4, 5, 6]
)
)
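Conceptually, each ImageBandInfo entry selects a subset of bands from one image, and the selected bands form the feature vector for each pixel. The NumPy sketch below illustrates band selection on a hypothetical in-memory array. It assumes 1-based band numbers (the GDAL convention), and the helper `select_bands` is hypothetical, not part of rsgislib.

```python
import numpy as np


def select_bands(img_arr: np.ndarray, bands: list[int]) -> np.ndarray:
    """Return the requested bands from a (n_bands, rows, cols) array.

    Band numbers are 1-based (GDAL convention), so band 1 is array index 0.
    """
    return img_arr[[b - 1 for b in bands]]


# Hypothetical 7-band image of 2 x 3 pixels where every value in band k equals k.
img = np.stack([np.full((2, 3), k, dtype=np.uint16) for k in range(1, 8)])

# Same selection as bands=[1, 2, 3, 4, 5, 6] in the ImageBandInfo above.
feats = select_bands(img, [1, 2, 3, 4, 5, 6])
print(feats.shape)  # (6, 2, 3)
print(feats[:, 0, 0].tolist())  # [1, 2, 3, 4, 5, 6]
```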
When applying a classifier, a mask image needs to be provided, in which a specified pixel value marks the pixels to be classified. While defining the input image, we can also create that valid mask image using the rsgislib.imageutils.gen_valid_mask function, which simply creates a mask of the pixels that are not ‘no data’:
vld_msk_img = os.path.join(out_dir, "LS5TM_19970716_vmsk.kea")
rsgislib.imageutils.gen_valid_mask(
input_img, output_img=vld_msk_img, gdalformat="KEA", no_data_val=0.0
)
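As a rough illustration of what a valid mask contains, the NumPy sketch below marks a pixel as valid (1) when at least one band differs from the no-data value, and invalid (0) otherwise. That rule is an assumption made for illustration; check the rsgislib.imageutils.gen_valid_mask documentation for the exact behaviour. The helper name `valid_mask` is hypothetical.

```python
import numpy as np


def valid_mask(img_arr: np.ndarray, no_data_val: float) -> np.ndarray:
    """Return a uint8 mask (1 = valid, 0 = no data) for a (bands, rows, cols) array.

    Assumed rule: a pixel is valid if any of its bands differs from no_data_val.
    """
    return np.any(img_arr != no_data_val, axis=0).astype(np.uint8)


# Hypothetical 2-band, 2 x 2 image; the top-left and bottom-right pixels are 0 in every band.
img = np.array(
    [
        [[0, 5], [3, 0]],
        [[0, 7], [0, 0]],
    ],
    dtype=np.uint16,
)
print(valid_mask(img, no_data_val=0.0).tolist())  # [[0, 1], [1, 0]]
```

The value 1 in this mask is what img_msk_val=1 refers to when the classifier is applied later.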
To define the training data, you need either a raster with a unique value for each class or multiple binary rasters, one for each class. Commonly, the training regions are defined using a vector layer, which then needs to be rasterised:
import rsgislib.vectorutils.createrasters
mangrove_vec_file = "./training/mangroves.geojson"
mangrove_vec_lyr = "mangroves"
mangrove_smpls_img = os.path.join(tmp_dir, "mangrove_smpls.kea")
rsgislib.vectorutils.createrasters.rasterise_vec_lyr(
vec_file=mangrove_vec_file,
vec_lyr=mangrove_vec_lyr,
input_img=input_img,
output_img=mangrove_smpls_img,
gdalformat="KEA",
burn_val=1,
)
other_terrestrial_vec_file = "./training/other_terrestrial.geojson"
other_terrestrial_vec_lyr = "other_terrestrial"
other_terrestrial_smpls_img = os.path.join(tmp_dir, "other_terrestrial_smpls.kea")
rsgislib.vectorutils.createrasters.rasterise_vec_lyr(
vec_file=other_terrestrial_vec_file,
vec_lyr=other_terrestrial_vec_lyr,
input_img=input_img,
output_img=other_terrestrial_smpls_img,
gdalformat="KEA",
burn_val=1,
)
water_vec_file = "./training/water.geojson"
water_vec_lyr = "water"
water_smpls_img = os.path.join(tmp_dir, "water_smpls.kea")
rsgislib.vectorutils.createrasters.rasterise_vec_lyr(
vec_file=water_vec_file,
vec_lyr=water_vec_lyr,
input_img=input_img,
output_img=water_smpls_img,
gdalformat="KEA",
burn_val=1,
)
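The three rasters above are binary (1 inside the training polygons, 0 elsewhere). The other representation, a single raster with a unique value per class, can be derived from such binary rasters. The NumPy sketch below does this with hypothetical in-memory arrays. Where polygons from different classes overlap, it keeps the first class and counts the conflicting pixels. That is a design choice for the sketch, not rsgislib behaviour; overlapping training regions are usually worth fixing in the vector data.

```python
import numpy as np


def combine_binary_masks(masks: dict[int, np.ndarray]) -> tuple[np.ndarray, int]:
    """Merge binary class masks into one raster of class values.

    masks maps a non-zero class value (e.g. 1, 2, 3) to a 0/1 array. Pixels not
    covered by any mask stay 0. Returns the combined raster and the number of
    pixels claimed by more than one class (the first class in the dict wins).
    """
    shape = next(iter(masks.values())).shape
    combined = np.zeros(shape, dtype=np.uint8)
    coverage = np.zeros(shape, dtype=np.uint8)
    for cls_val, msk in masks.items():
        hit = msk == 1
        combined[hit & (combined == 0)] = cls_val
        coverage += hit.astype(np.uint8)
    return combined, int(np.count_nonzero(coverage > 1))


# Hypothetical 2 x 3 binary training rasters.
mangrove = np.array([[1, 1, 0], [0, 0, 0]])
other_terrestrial = np.array([[0, 1, 1], [0, 0, 0]])
water = np.array([[0, 0, 0], [1, 1, 0]])

cls_img, n_overlap = combine_binary_masks({1: mangrove, 2: other_terrestrial, 3: water})
print(cls_img.tolist())  # [[1, 1, 2], [3, 3, 0]]
print(n_overlap)  # 1
```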
The following functions extract the image pixel values within each set of training regions and store them in an HDF5 file (see https://portal.hdfgroup.org/display/HDF5/HDF5 for more information). The images and associated bands used for the classification, and therefore the ones from which values need to be extracted, are specified by a list of rsgislib.imageutils.ImageBandInfo objects:
import rsgislib.zonalstats
mangrove_all_smpls_h5_file = os.path.join(out_dir, "mangrove_all_smpls.h5")
rsgislib.zonalstats.extract_zone_img_band_values_to_hdf(
imgs_info,
in_msk_img=mangrove_smpls_img,
out_h5_file=mangrove_all_smpls_h5_file,
mask_val=1,
datatype=rsgislib.TYPE_16UINT,
)
other_terrestrial_all_smpls_h5_file = os.path.join(
out_dir, "other_terrestrial_all_smpls.h5"
)
rsgislib.zonalstats.extract_zone_img_band_values_to_hdf(
imgs_info,
in_msk_img=other_terrestrial_smpls_img,
out_h5_file=other_terrestrial_all_smpls_h5_file,
mask_val=1,
datatype=rsgislib.TYPE_16UINT,
)
water_all_smpls_h5_file = os.path.join(out_dir, "water_all_smpls.h5")
rsgislib.zonalstats.extract_zone_img_band_values_to_hdf(
imgs_info,
in_msk_img=water_smpls_img,
out_h5_file=water_all_smpls_h5_file,
mask_val=1,
datatype=rsgislib.TYPE_16UINT,
)
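Conceptually, the extraction collects the values of the selected bands for every pixel where the mask equals mask_val. The result is a table with one row per pixel and one column per band. The NumPy sketch below shows that reshaping on hypothetical in-memory arrays. The HDF5 writing step is omitted, and `extract_zone_values` is a hypothetical helper, not the rsgislib implementation.

```python
import numpy as np


def extract_zone_values(img_arr: np.ndarray, msk_arr: np.ndarray, mask_val: int) -> np.ndarray:
    """Return an (n_pixels, n_bands) array of the band values under the mask.

    img_arr has shape (n_bands, rows, cols); msk_arr has shape (rows, cols).
    """
    return img_arr[:, msk_arr == mask_val].T


# Hypothetical 3-band, 2 x 2 image and a mask selecting the two diagonal pixels.
img = np.arange(12, dtype=np.uint16).reshape(3, 2, 2)
msk = np.array([[1, 0], [0, 1]])

smpls = extract_zone_values(img, msk, mask_val=1)
print(smpls.tolist())  # [[0, 4, 8], [3, 7, 11]]
```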
If training data are extracted from multiple input images, the resulting HDF5 files need to be merged using the following function. Here, for illustration, we’ll merge the other terrestrial and water samples:
other_all_smpls_h5_file = os.path.join(out_dir, "other_all_smpls.h5")
rsgislib.zonalstats.merge_extracted_hdf5_data(
h5_files=[other_terrestrial_all_smpls_h5_file, water_all_smpls_h5_file],
out_h5_file=other_all_smpls_h5_file,
datatype=rsgislib.TYPE_16UINT,
)
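Merging amounts to stacking the sample rows from each file into a single table. This only makes sense if every input has the same bands in the same order. Below is a minimal NumPy sketch on hypothetical in-memory arrays (HDF5 reading and writing omitted); `merge_samples` is a hypothetical helper.

```python
import numpy as np


def merge_samples(sample_sets: list[np.ndarray]) -> np.ndarray:
    """Stack (n_pixels, n_bands) sample arrays into a single array."""
    n_bands = {s.shape[1] for s in sample_sets}
    if len(n_bands) != 1:
        raise ValueError(f"Inconsistent number of bands: {sorted(n_bands)}")
    return np.vstack(sample_sets)


# Hypothetical 3-band samples for two classes.
other_terrestrial = np.array([[120, 340, 210], [118, 352, 205]], dtype=np.uint16)
water = np.array([[60, 40, 20]], dtype=np.uint16)

other = merge_samples([other_terrestrial, water])
print(other.shape)  # (3, 3)
```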
To split the extracted samples into training, validation and testing sets, you can use the rsgislib.classification.split_sample_train_valid_test function. Note that this function also standardises the number of samples used to train the classifier, so the training data are balanced across classes:
import rsgislib.classification
mangrove_train_smpls_h5_file = os.path.join(out_dir, "mangrove_train_smpls.h5")
mangrove_valid_smpls_h5_file = os.path.join(out_dir, "mangrove_valid_smpls.h5")
mangrove_test_smpls_h5_file = os.path.join(out_dir, "mangrove_test_smpls.h5")
rsgislib.classification.split_sample_train_valid_test(
in_h5_file=mangrove_all_smpls_h5_file,
train_h5_file=mangrove_train_smpls_h5_file,
valid_h5_file=mangrove_valid_smpls_h5_file,
test_h5_file=mangrove_test_smpls_h5_file,
test_sample=10000,
valid_sample=10000,
train_sample=35000,
rnd_seed=42,
datatype=rsgislib.TYPE_16UINT,
)
other_terrestrial_train_smpls_h5_file = os.path.join(
out_dir, "other_terrestrial_train_smpls.h5"
)
other_terrestrial_valid_smpls_h5_file = os.path.join(
out_dir, "other_terrestrial_valid_smpls.h5"
)
other_terrestrial_test_smpls_h5_file = os.path.join(
out_dir, "other_terrestrial_test_smpls.h5"
)
rsgislib.classification.split_sample_train_valid_test(
in_h5_file=other_terrestrial_all_smpls_h5_file,
train_h5_file=other_terrestrial_train_smpls_h5_file,
valid_h5_file=other_terrestrial_valid_smpls_h5_file,
test_h5_file=other_terrestrial_test_smpls_h5_file,
test_sample=10000,
valid_sample=10000,
train_sample=35000,
rnd_seed=42,
datatype=rsgislib.TYPE_16UINT,
)
water_train_smpls_h5_file = os.path.join(out_dir, "water_train_smpls.h5")
water_valid_smpls_h5_file = os.path.join(out_dir, "water_valid_smpls.h5")
water_test_smpls_h5_file = os.path.join(out_dir, "water_test_smpls.h5")
rsgislib.classification.split_sample_train_valid_test(
in_h5_file=water_all_smpls_h5_file,
train_h5_file=water_train_smpls_h5_file,
valid_h5_file=water_valid_smpls_h5_file,
test_h5_file=water_test_smpls_h5_file,
test_sample=10000,
valid_sample=10000,
train_sample=35000,
rnd_seed=42,
datatype=rsgislib.TYPE_16UINT,
)
Note
Training samples are used to train the classifier. Validation samples are used to test the accuracy of the classifier during the parameter optimisation process; they are therefore part of the training process and not independent. Testing samples are completely independent of the training process and are used to assess the overall accuracy of the classifier.
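One way to picture the split is as shuffling the extracted rows with a fixed seed and then taking consecutive, non-overlapping chunks for the test, validation and training sets. With the counts used above, each class then needs at least 10,000 + 10,000 + 35,000 = 55,000 extracted pixels. The NumPy sketch below follows that idea. The chunk order and the error raised when there are too few samples are assumptions of the sketch, not documented rsgislib behaviour, and `split_samples` is hypothetical.

```python
import numpy as np


def split_samples(smpls, train_n, valid_n, test_n, rnd_seed=42):
    """Randomly split an (n_pixels, n_bands) array into non-overlapping subsets."""
    needed = train_n + valid_n + test_n
    if smpls.shape[0] < needed:
        raise ValueError(f"Need {needed} samples but only {smpls.shape[0]} available")
    rng = np.random.default_rng(rnd_seed)  # fixed seed -> reproducible split
    idx = rng.permutation(smpls.shape[0])
    test = smpls[idx[:test_n]]
    valid = smpls[idx[test_n : test_n + valid_n]]
    train = smpls[idx[test_n + valid_n : needed]]
    return train, valid, test


# Hypothetical pool of 60,000 single-band samples (values are just row numbers).
pool = np.arange(60_000).reshape(-1, 1)
train, valid, test = split_samples(pool, train_n=35_000, valid_n=10_000, test_n=10_000)
print(len(train), len(valid), len(test))  # 35000 10000 10000
```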
Apply a Scikit-Learn Random Forests Classifier
To train a multi-class classifier you first need to specify the reference samples as a dict of rsgislib.classification.ClassInfoObj objects:
import rsgislib.classification
cls_info_dict = dict()
cls_info_dict["Mangrove"] = rsgislib.classification.ClassInfoObj(
id=0,
out_id=1,
train_file_h5=mangrove_train_smpls_h5_file,
test_file_h5=mangrove_test_smpls_h5_file,
valid_file_h5=mangrove_valid_smpls_h5_file,
red=0,
green=255,
blue=0,
)
cls_info_dict["Other Terrestrial"] = rsgislib.classification.ClassInfoObj(
id=1,
out_id=2,
train_file_h5=other_terrestrial_train_smpls_h5_file,
test_file_h5=other_terrestrial_test_smpls_h5_file,
valid_file_h5=other_terrestrial_valid_smpls_h5_file,
red=100,
green=100,
blue=100,
)
cls_info_dict["Water"] = rsgislib.classification.ClassInfoObj(
id=2,
out_id=3,
train_file_h5=water_train_smpls_h5_file,
test_file_h5=water_test_smpls_h5_file,
valid_file_h5=water_valid_smpls_h5_file,
red=0,
green=0,
blue=255,
)
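The ignore_consec_cls_ids parameter, described in the reference section, suggests that the id values are expected to be consecutive (0, 1, 2 in the example above) and act as the classifier's internal labels. By the same reading, out_id is the pixel value written to the output image. The sketch below checks that assumption and builds the id-to-out_id lookup and a colour table. So that it runs without rsgislib, it uses a hypothetical stand-in dataclass `ClassInfo` instead of rsgislib.classification.ClassInfoObj.

```python
from dataclasses import dataclass


@dataclass
class ClassInfo:  # hypothetical stand-in for rsgislib.classification.ClassInfoObj
    id: int
    out_id: int
    red: int
    green: int
    blue: int


def build_lookups(cls_info: dict[str, ClassInfo]):
    """Return (id -> out_id, out_id -> (name, rgb)) after checking ids are 0..n-1."""
    ids = sorted(c.id for c in cls_info.values())
    if ids != list(range(len(ids))):
        raise ValueError(f"Class ids are not consecutive from 0: {ids}")
    id_to_out = {c.id: c.out_id for c in cls_info.values()}
    colour_table = {c.out_id: (name, (c.red, c.green, c.blue)) for name, c in cls_info.items()}
    return id_to_out, colour_table


cls_info = {
    "Mangrove": ClassInfo(id=0, out_id=1, red=0, green=255, blue=0),
    "Other Terrestrial": ClassInfo(id=1, out_id=2, red=100, green=100, blue=100),
    "Water": ClassInfo(id=2, out_id=3, red=0, green=0, blue=255),
}
id_to_out, colour_table = build_lookups(cls_info)
print(id_to_out)  # {0: 1, 1: 2, 2: 3}
print(colour_table[3])  # ('Water', (0, 0, 255))
```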
To train the Random Forest classifier we need to first optimise the algorithm parameters. For this we’ll use a Grid Search:
import rsgislib.classification.classsklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
grid_search = GridSearchCV(
RandomForestClassifier(),
param_grid={"n_estimators": [10, 20, 50, 100], "max_depth": [2, 4, 8]},
)
skl_rf_clf_obj = (
rsgislib.classification.classsklearn.perform_sklearn_classifier_param_search(
cls_train_info=cls_info_dict, search_obj=grid_search
)
)
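A grid search fits and scores the classifier for every combination of values in param_grid. The standard-library sketch below enumerates the candidates for the grid above: 4 values of n_estimators times 3 values of max_depth gives 12 candidates. With GridSearchCV's default 5-fold cross-validation (scikit-learn 0.22 and later), that means 60 cross-validation fits, plus one final refit of the best candidate because refit=True by default. Larger grids therefore become expensive quickly; RandomizedSearchCV samples a fixed number of candidates instead. sklearn.model_selection.ParameterGrid produces the same set of combinations, although its ordering may differ from this sketch.

```python
from itertools import product

param_grid = {"n_estimators": [10, 20, 50, 100], "max_depth": [2, 4, 8]}

# Expand the grid into one dict per candidate parameter combination.
keys = list(param_grid)
candidates = [dict(zip(keys, values)) for values in product(*param_grid.values())]

n_folds = 5  # GridSearchCV's default cv in scikit-learn >= 0.22
print(len(candidates))  # 12
print(candidates[0])  # {'n_estimators': 10, 'max_depth': 2}
print(len(candidates) * n_folds)  # 60 cross-validation fits
```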
Once we have an instance of the classifier with the optimal parameters we can then train the classifier:
rsgislib.classification.classsklearn.train_sklearn_classifier(
cls_train_info=cls_info_dict, sk_classifier=skl_rf_clf_obj
)
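According to its reference entry, train_sklearn_classifier returns the training and testing accuracies (between 0 and 1). The call above discards them; they could be kept with `train_acc, test_acc = rsgislib.classification.classsklearn.train_sklearn_classifier(...)`. Conceptually, the per-class sample sets are stacked into one feature matrix, each row is labelled with its class id, and accuracy is the fraction of correctly predicted rows. The NumPy sketch below illustrates this. The helper names, samples and predictions are hypothetical.

```python
import numpy as np


def stack_labelled(samples_by_id: dict[int, np.ndarray]) -> tuple[np.ndarray, np.ndarray]:
    """Stack per-class (n_pixels, n_bands) arrays into features X and labels y."""
    ids = sorted(samples_by_id)
    X = np.vstack([samples_by_id[i] for i in ids])
    y = np.concatenate([np.full(len(samples_by_id[i]), i) for i in ids])
    return X, y


def accuracy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Fraction of correctly predicted samples, between 0 and 1."""
    return float(np.mean(y_true == y_pred))


samples = {
    0: np.array([[30, 80], [32, 85]]),  # hypothetical mangrove samples
    1: np.array([[60, 70]]),  # hypothetical other terrestrial samples
    2: np.array([[10, 5], [12, 6]]),  # hypothetical water samples
}
X, y = stack_labelled(samples)
y_pred = np.array([0, 1, 1, 2, 2])  # hypothetical classifier output
print(X.shape, y.tolist())  # (5, 2) [0, 0, 1, 2, 2]
print(accuracy(y, y_pred))  # 0.8
```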
We can then apply the trained classifier to the image data:
out_cls_img = os.path.join(out_dir, "LS5TM_19970716_skl_rf_cls_img.kea")
out_score_img = os.path.join(out_dir, "LS5TM_19970716_skl_rf_cls_score_img.kea")
rsgislib.classification.classsklearn.apply_sklearn_classifier(
cls_train_info=cls_info_dict,
sk_classifier=skl_rf_clf_obj,
in_msk_img=vld_msk_img,
img_msk_val=1,
img_file_info=imgs_info,
out_class_img=out_cls_img,
gdalformat="KEA",
class_clr_names=True,
out_score_img=out_score_img,
ignore_consec_cls_ids=False,
)
The output image file name needs to be defined, and an image mask also needs to be provided to define the parts of the image to be classified. This is useful because, by using a previous classification result as the mask for another classifier, a hierarchical classification process can be built.
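The core of per-pixel application can be sketched as follows. First, reshape the bands into one row per pixel. Then predict only the rows inside the mask, translate the predicted ids into out_id values, and leave everything else as 0. The NumPy sketch below uses a hypothetical nearest-centroid rule in place of a trained scikit-learn model so that it runs on its own; with a real model, that call would be its predict() method. The id-to-out_id translation reflects the reading of ClassInfoObj given above, which is an assumption. The last lines show how one class from a result could become the mask for a second, hierarchical step.

```python
import numpy as np


def classify_masked(img_arr, msk_arr, msk_val, predict, id_to_out):
    """Classify pixels where msk_arr == msk_val; return a uint8 class image (0 = unclassified)."""
    n_bands, rows, cols = img_arr.shape
    feats = img_arr.reshape(n_bands, -1).T  # one row per pixel
    sel = msk_arr.ravel() == msk_val
    out = np.zeros(rows * cols, dtype=np.uint8)
    pred_ids = predict(feats[sel])
    out[sel] = np.array([id_to_out[int(i)] for i in pred_ids], dtype=np.uint8)
    return out.reshape(rows, cols)


# Hypothetical class centroids (2 bands) standing in for a trained model.
centroids = np.array([[30.0, 80.0], [60.0, 70.0], [10.0, 5.0]])


def nearest_centroid(x):
    dists = np.linalg.norm(x[:, None, :] - centroids[None, :, :], axis=2)
    return dists.argmin(axis=1)


# Hypothetical 2-band, 2 x 2 image; the bottom-right pixel is outside the mask.
img = np.array([[[31, 59], [11, 0]], [[82, 71], [4, 0]]], dtype=float)
vld_msk = np.array([[1, 1], [1, 0]])
cls_img = classify_masked(img, vld_msk, 1, nearest_centroid, {0: 1, 1: 2, 2: 3})
print(cls_img.tolist())  # [[1, 2], [3, 0]]

# Hierarchical step: only pixels labelled 1 (Mangrove) would be passed to a second classifier.
mangrove_msk = (cls_img == 1).astype(np.uint8)
print(mangrove_msk.tolist())  # [[1, 0], [0, 0]]
```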
Training Functions
- rsgislib.classification.classsklearn.perform_sklearn_classifier_param_search(cls_train_info: Dict[str, ClassInfoObj], search_obj: BaseSearchCV) -> BaseEstimator
A function to find the ‘optimal’ parameters for classification using a grid search or random search (http://scikit-learn.org/stable/modules/grid_search.html). The validation data will be used to identify the optimal parameters, and the returned classifier will be initialised with those parameters but not trained.
- Parameters:
cls_train_info – dict (where the key is the class name) of rsgislib.classification.ClassInfoObj objects which will be used to train the classifier.
search_obj – an instance of sklearn.model_selection.BaseSearchCV (e.g., GridSearchCV or RandomizedSearchCV) parameterised with an instance of the classifier and the associated parameters to be searched.
- Returns:
An instance of the classifier initialised with the optimal parameters (not yet trained).
- rsgislib.classification.classsklearn.train_sklearn_classifier(cls_train_info: Dict[str, ClassInfoObj], sk_classifier: BaseEstimator) -> (float, float)
This function trains the classifier.
- Parameters:
cls_train_info – dict (where the key is the class name) of rsgislib.classification.ClassInfoObj objects which will be used to train and test the classifier.
sk_classifier – an instance of a parameterised scikit-learn classifier (http://scikit-learn.org/stable/supervised_learning.html)
- Returns:
the training and testing accuracies (between 0 and 1)
Classify Functions
- rsgislib.classification.classsklearn.apply_sklearn_classifier(cls_train_info: Dict[str, ClassInfoObj], sk_classifier: BaseEstimator, in_msk_img: str, img_msk_val: int, img_file_info: List[ImageBandInfo], out_class_img: str, gdalformat: str = 'KEA', class_clr_names: bool = True, out_score_img: str = None, ignore_consec_cls_ids: bool = False)
This function uses a trained classifier and applies it to the provided input image.
- Parameters:
cls_train_info – dict (where the key is the class name) of rsgislib.classification.ClassInfoObj objects which were used to train the classifier and which provide the pixel value ids and RGB class colours.
sk_classifier – a trained instance of a scikit-learn classifier
in_msk_img – an image file providing a mask to specify which regions should be classified. The simplest mask covers all the valid data regions (rsgislib.imageutils.gen_valid_mask).
img_msk_val – the pixel value within in_msk_img which limits the region to which the classification is applied. Can be used to create a hierarchical classification.
img_file_info – a list of rsgislib.imageutils.ImageBandInfo objects to identify which images and bands are to be used for the classification so it adheres to the training data.
out_class_img – output image file with the classification. Note: by default, a colour table and a class names column are added to the image if the gdalformat is KEA.
gdalformat – the output image format
class_clr_names – default is True, in which case a colour table with the colours specified in cls_train_info and a class names column will be added to the output file.
out_score_img – a file path for a score image. If None, no score image is output. Note that this option uses the predict_proba() function of the scikit-learn model, which isn’t available for all classifiers, so it might produce an error if used with a model that doesn’t have this function; for example, sklearn.svm.SVC unless it was created with probability=True.
ignore_consec_cls_ids – a boolean specifying whether to ignore the requirement that the class ids are consecutive, with the out_ids used to specify other, non-consecutive ids. This carries some risk but gives more flexibility when using the function.
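Because out_score_img relies on predict_proba(), it can be worth checking for that method before requesting a score image. The minimal sketch below uses two hypothetical stand-in model classes rather than real scikit-learn estimators, and the helper `score_img_or_none` is also hypothetical.

```python
def score_img_or_none(model, out_score_img):
    """Return out_score_img if the model supports predict_proba(), else None."""
    if out_score_img is not None and not hasattr(model, "predict_proba"):
        print(f"{type(model).__name__} has no predict_proba(); skipping score image")
        return None
    return out_score_img


class ProbModel:  # hypothetical model exposing predict_proba()
    def predict_proba(self, x):
        return [[1.0] for _ in x]


class NoProbModel:  # hypothetical model without predict_proba()
    pass


print(score_img_or_none(ProbModel(), "score.kea"))  # score.kea
print(score_img_or_none(NoProbModel(), "score.kea"))  # None (after the skip message)
```

With scikit-learn itself, sklearn.svm.SVC only exposes predict_proba() when constructed with probability=True, so the hasattr() check returns False otherwise.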
- rsgislib.classification.classsklearn.apply_sklearn_classifier_rat(clumps_img: str, variables: List[str], sk_classifier: BaseEstimator, cls_train_info: Dict[str, ClassInfoObj], out_col_int: str = 'OutClass', out_col_str: str = 'OutClassName', roi_col: str = None, roi_val: int = 1, class_colours: bool = True)
A function which will apply a scikit-learn classifier within a Raster Attribute Table (RAT).
- Parameters:
clumps_img – is the clumps image on which the classification is to be performed
variables – is an array of column names which are to be used for the classification
sk_classifier – a trained instance of a scikit-learn classifier
cls_train_info – dict (where the key is the class name) of rsgislib.classification.ClassInfoObj objects which were used to train the classifier and which provide the pixel value ids and RGB class colours.
out_col_int – is the output column name for the int class representation (Default: ‘OutClass’)
out_col_str – is the output column name for the class names column (Default: ‘OutClassName’)
roi_col – is the name of a column which specifies the region to be classified. If None, it is ignored (Default: None)
roi_val – is an int value used within the roi_col to select the region to be classified (Default: 1)
class_colours – is a boolean specifying whether the RAT colour table should be updated using the classification colours (default: True)
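In a RAT, each row describes one clump (segment), so the classification operates on table columns rather than on pixels. The sketch below mimics that with NumPy arrays standing in for RAT columns. It stacks the variables columns into a feature matrix and restricts prediction to rows where roi_col equals roi_val. It then fills the integer and string output columns. The column names, the values and the fixed prediction rule are hypothetical; the real function reads and writes the RAT of clumps_img.

```python
import numpy as np

# Hypothetical RAT columns: one entry per clump.
rat = {
    "NIR_Mean": np.array([0.0, 0.30, 0.05, 0.25]),
    "Red_Mean": np.array([0.0, 0.04, 0.03, 0.06]),
    "ROI": np.array([0, 1, 1, 0]),
}
variables = ["NIR_Mean", "Red_Mean"]
out_names = {1: "Mangrove", 2: "Other Terrestrial", 3: "Water"}


def predict_out_ids(x):
    # Hypothetical rule standing in for sk_classifier.predict() plus the id -> out_id mapping.
    return np.where(x[:, 0] > 0.1, 1, 3)


feats = np.column_stack([rat[v] for v in variables])
sel = rat["ROI"] == 1  # roi_col="ROI", roi_val=1
out_int = np.zeros(len(sel), dtype=np.int64)
out_int[sel] = predict_out_ids(feats[sel])
out_str = np.array([out_names.get(int(v), "") for v in out_int], dtype=object)
rat["OutClass"], rat["OutClassName"] = out_int, out_str
print(rat["OutClass"].tolist())  # [0, 1, 3, 0]
print(rat["OutClassName"].tolist())  # ['', 'Mangrove', 'Water', '']
```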
Other Utility Functions
- rsgislib.classification.classsklearn.feat_sel_sklearn_multiclass_borutashap(sk_classifier: BaseEstimator, cls_info_dict: Dict[str, ClassInfoObj], out_csv_file: str, n_trials: int = 100, sub_train_smpls: int | float = None, rnd_seed: int = None, feat_names: List[str] = None)