RSGISLib CatBoost Image Classification

Training Functions

rsgislib.classification.classcatboost.train_catboost_binary_classifier(mdl_cls_obj, cls1_train_file: str, cls1_valid_file: str, cls1_test_file: str, cls2_train_file: str, cls2_valid_file: str, cls2_test_file: str, cat_cols: Optional[List] = None, out_mdl_file: Optional[str] = None, verbose_training: bool = False)

A function which trains a catboost classifier with two classes (i.e., binary). Class 1 is the class which you are interested in and Class 2 is the ‘other’ class.

This function requires the catboost module to be installed.

  • mdl_cls_obj – The catboost model object.

  • cls1_train_file – Training samples HDF5 file for the primary class (i.e., the one being classified)

  • cls1_valid_file – Validation samples HDF5 file for the primary class (i.e., the one being classified)

  • cls1_test_file – Testing samples HDF5 file for the primary class (i.e., the one being classified)

  • cls2_train_file – Training samples HDF5 file for the ‘other’ class

  • cls2_valid_file – Validation samples HDF5 file for the ‘other’ class

  • cls2_test_file – Testing samples HDF5 file for the ‘other’ class

  • cat_cols – list of indexes for variables which are categorical.

  • out_mdl_file – An optional path for a JSON file to save the catboost model to disk.

  • verbose_training – a boolean to specifying whether a verbose output should be provided during training (Default: False)
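Conceptually, a binary training set is assembled by stacking the class 1 and class 2 samples and giving each class a different label; the same is done separately for the training, validation and testing files. The sketch below illustrates this with plain Python lists standing in for the HDF5 sample arrays. The helper name `build_binary_set` and the label convention (1 for class 1, 0 for class 2) are assumptions for illustration, not RSGISLib's internal implementation.

```python
from typing import List, Sequence, Tuple

Row = Sequence[float]


def build_binary_set(
    cls1_rows: List[Row], cls2_rows: List[Row]
) -> Tuple[List[Row], List[int]]:
    """Stack class 1 and class 2 samples into one feature table with labels.

    Hypothetical helper: assumes label 1 marks the class of interest
    (class 1) and label 0 marks the 'other' class (class 2).
    """
    # Both classes must describe the same set of variables (columns).
    if cls1_rows and cls2_rows and len(cls1_rows[0]) != len(cls2_rows[0]):
        raise ValueError("Both classes must have the same number of variables")
    features = list(cls1_rows) + list(cls2_rows)
    labels = [1] * len(cls1_rows) + [0] * len(cls2_rows)
    return features, labels


# Toy samples: two variables per row (e.g., two image bands).
cls1_train = [[0.71, 0.12], [0.65, 0.18]]
cls2_train = [[0.20, 0.40], [0.25, 0.35], [0.30, 0.33]]

train_x, train_y = build_binary_set(cls1_train, cls2_train)
print(len(train_x), train_y)  # 5 [1, 1, 0, 0, 0]
```

The same call would be repeated for the validation and testing files. In catboost itself, categorical columns are identified to `fit` through its `cat_features` argument, which is the role `cat_cols` appears to play here.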

Classify Functions

rsgislib.classification.classcatboost.apply_catboost_binary_classifier(mdl_cls_obj, in_msk_img: str, img_mask_val: int, img_file_info: List, out_class_img: str, gdalformat: str = 'KEA', out_prob_img: Optional[str] = None)

This function applies a trained binary (i.e., two-class) catboost model. The function train_catboost_binary_classifier can be used to train such a model.

  • mdl_cls_obj – a trained catboost binary model. Can be loaded from disk using the get_catboost_mdl function.

  • in_msk_img – an image file providing a mask which specifies where the classification should be applied. The simplest mask is all the valid data regions (rsgislib.imageutils.gen_valid_mask).

  • img_mask_val – the pixel value within the in_msk_img to limit the region to which the classification is applied. Can be used to create a hierarchical classification.

  • img_file_info – a list of rsgislib.imageutils.ImageBandInfo objects (also used within rsgislib.zonalstats.extract_zone_img_band_values_to_hdf) identifying which images and bands are to be used for the classification, so that the inputs match those used to create the training data.

  • out_class_img – output image file with the hard classification output.

  • gdalformat – the output image format (Default: KEA)

  • out_prob_img – Optional output image which contains the probabilities for the two classes.
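The effect of `in_msk_img` and `img_mask_val` can be pictured per pixel: only pixels whose mask value equals `img_mask_val` are passed to the model, and all other pixels are left unclassified. The sketch below shows this on a small grid using plain Python lists. The function name `apply_masked`, the use of 0 for unclassified pixels, the output codes 1 and 2 for the two classes, and the 0.5 probability threshold are all assumptions for illustration; RSGISLib's own output values may differ.

```python
from typing import Callable, List, Sequence

Grid = List[List[int]]


def apply_masked(
    mask: Grid,
    img_mask_val: int,
    pixels: List[List[Sequence[float]]],
    prob_cls1: Callable[[Sequence[float]], float],
    threshold: float = 0.5,
) -> Grid:
    """Classify only pixels where mask == img_mask_val (hypothetical sketch).

    Assumed output codes: 0 = not classified, 1 = class 1, 2 = class 2.
    """
    out: Grid = []
    for mask_row, pix_row in zip(mask, pixels):
        out_row = []
        for m, values in zip(mask_row, pix_row):
            if m != img_mask_val:
                out_row.append(0)  # outside the region to be classified
            elif prob_cls1(values) >= threshold:
                out_row.append(1)  # assumed code for class 1
            else:
                out_row.append(2)  # assumed code for class 2
        out.append(out_row)
    return out


# Toy 2x3 image with one band; a stand-in "model" whose class 1
# probability is simply the band value.
mask = [[1, 1, 0], [1, 2, 1]]
pixels = [[[0.9], [0.2], [0.8]], [[0.6], [0.9], [0.1]]]
result = apply_masked(mask, 1, pixels, lambda v: v[0])
print(result)  # [[1, 2, 0], [1, 0, 2]]
```

Running a second classifier with the first result as its mask, and `img_mask_val` set to one of the first result's class codes, is the kind of hierarchical chaining the `img_mask_val` description refers to.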


rsgislib.classification.classcatboost.get_catboost_mdl(mdl_file: Optional[str] = None, mdl_format: str = 'json')

A function which creates a default catboost classifier and optionally loads an existing model, if one is available.

  • mdl_file – a path to a saved catboost model.

  • mdl_format – the format of the model file: 'cbm' is the catboost binary format and 'json' is JSON format (Default: 'json').


Returns: a catboost.CatBoostClassifier object
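Because train_catboost_binary_classifier saves out_mdl_file as JSON and get_catboost_mdl defaults to mdl_format='json', a model saved by the training function can be reloaded without changing the format. Where models may be stored in either format, a small helper can choose mdl_format from the file extension. The helper `infer_mdl_format` below is hypothetical and not part of RSGISLib; it assumes the `.cbm` and `.json` extensions are used consistently.

```python
from pathlib import Path

# Hypothetical mapping from file extension to the mdl_format values
# accepted by get_catboost_mdl ('cbm' or 'json').
_EXT_TO_FORMAT = {".cbm": "cbm", ".json": "json"}


def infer_mdl_format(mdl_file: str) -> str:
    """Return 'cbm' or 'json' based on the model file's extension."""
    ext = Path(mdl_file).suffix.lower()
    try:
        return _EXT_TO_FORMAT[ext]
    except KeyError:
        raise ValueError(
            f"Cannot infer catboost model format from '{mdl_file}'"
        ) from None


print(infer_mdl_format("forest_mdl.json"))  # json
print(infer_mdl_format("forest_mdl.CBM"))  # cbm
```

The result could then be passed on as get_catboost_mdl(mdl_file, mdl_format=infer_mdl_format(mdl_file)).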