Code documentation
This document describes the code for the Protocol Deviation (PD) Classifier, which is implemented in Python.
QAPredictor
Description
This class initialises two models, dvcat and dvdecod, together with the corresponding encoders for PDs. Predictions are obtained from each model through a configurable prediction method, set by dvcat_predict_name and dvdecod_predict_name respectively. Input text is encoded with a selected Sentence-BERT embedding model, forming a machine learning pipeline that classifies text into categories.
Parameters
models_file (str): Path to the file containing the serialised models and encoders. The file must contain both the dvcat and dvdecod models and their encoders, stored as dictionaries within the file.
dvcat_predict_name (str): The name of the dvcat model's method used to obtain prediction probabilities. Defaults to predict_proba.
dvdecod_predict_name (str): The name of the dvdecod model's method used to obtain prediction probabilities. Defaults to predict_proba.
embeddings_model (str): The name of the Sentence-BERT model used to embed the input text. Defaults to all-mpnet-base-v2.
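The dvcat_predict_name and dvdecod_predict_name parameters name a method on each model rather than passing a function. A minimal sketch of how such a name could be resolved, assuming QAPredictor looks the method up with getattr; resolve_predict_fn and StubModel are hypothetical names used only for illustration:

```python
from typing import Callable, List, Sequence


class StubModel:
    """Hypothetical stand-in for a fitted classifier that exposes predict_proba."""

    def predict_proba(self, rows: Sequence[Sequence[float]]) -> List[List[float]]:
        # Return the same two-class probability distribution for every input row.
        return [[0.25, 0.75] for _ in rows]


def resolve_predict_fn(model: object, predict_name: str = "predict_proba") -> Callable:
    """Return the bound method called `predict_name` on `model` (assumed lookup strategy)."""
    fn = getattr(model, predict_name, None)
    if not callable(fn):
        raise AttributeError(f"{type(model).__name__} has no callable method {predict_name!r}")
    return fn


predict = resolve_predict_fn(StubModel(), "predict_proba")
print(predict([[0.1, 0.2], [0.3, 0.4]]))  # [[0.25, 0.75], [0.25, 0.75]]
```

Keeping the method name configurable would let the same class wrap any model that exposes its probabilities under a different method name.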
__load_models
__load_models(models_file: str)
Description
This method loads two models, dvcat and dvdecod, along with their respective encoders from a specified file. The file should contain a dictionary in which each model is stored under the key dvcat or dvdecod, each with the sub-keys model and encoder. If a model or encoder is not found, a ValueError is raised.
Parameters
models_file (str): Path to the file containing the serialised model dictionary. The file must exist and be a valid file path.
Return value(s)
None. The models and encoders are assigned to the instance variables _dvcat_model, _dvcat_encoder, _dvdecod_model, and _dvdecod_encoder.
Raises:
AssertionError: If models_file is not a valid file path.
ValueError: If the expected dvcat or dvdecod keys, or their sub-keys model and encoder, are not found in the file.
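A minimal sketch of the loading and validation steps described above, assuming the models file is a pickle of a nested dictionary (the serialisation format is not stated in this document). load_models is a hypothetical helper that returns the model/encoder pairs instead of assigning instance variables:

```python
import os
import pickle
import tempfile
from typing import Any, Dict, Tuple


def load_models(models_file: str) -> Dict[str, Tuple[Any, Any]]:
    """Return {"dvcat": (model, encoder), "dvdecod": (model, encoder)} from a pickled dict."""
    assert os.path.isfile(models_file), f"{models_file!r} is not a valid file path"
    with open(models_file, "rb") as fh:
        payload = pickle.load(fh)  # assumption: the file is a pickle; only load trusted files

    loaded: Dict[str, Tuple[Any, Any]] = {}
    for key in ("dvcat", "dvdecod"):
        entry = payload.get(key) if isinstance(payload, dict) else None
        if not isinstance(entry, dict) or "model" not in entry or "encoder" not in entry:
            raise ValueError(f"{models_file!r} lacks a {key!r} entry with 'model' and 'encoder' sub-keys")
        loaded[key] = (entry["model"], entry["encoder"])
    return loaded


# Usage with a throwaway file; strings stand in for real model and encoder objects.
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
    pickle.dump(
        {"dvcat": {"model": "m1", "encoder": "e1"}, "dvdecod": {"model": "m2", "encoder": "e2"}},
        tmp,
    )
models = load_models(tmp.name)
os.remove(tmp.name)
print(models["dvcat"])  # ('m1', 'e1')
```

Because unpickling can execute arbitrary code, a loader of this kind should only ever be pointed at model files from a trusted source.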
encode
encode(texts: List[str])
Description
This method takes a list of text strings, normalises each string, and then uses a pre-trained embedding model to generate vector representations (embeddings) for each text. Normalisation is performed to ensure consistency in text formatting before encoding.
Parameters
texts (List[str]): A list of strings to be encoded. Each string is normalised before embedding.
Return value(s)
embeddings (numpy.ndarray or torch.Tensor): A collection of embeddings for the input texts, where each embedding corresponds to one normalised input text.
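The sketch below illustrates the normalise-then-embed flow. It accepts any object with an encode method; in practice this would presumably be a sentence-transformers model such as SentenceTransformer("all-mpnet-base-v2"), whose encode returns a NumPy array by default, or a torch tensor when convert_to_tensor=True. The normalise helper is a simplified, hypothetical stand-in for normalize_text, and LengthEmbedder is a stub so the example runs without downloading a model:

```python
import re
from typing import Any, List, Protocol


class Embedder(Protocol):
    """Anything with an encode(list_of_str) method, e.g. a SentenceTransformer."""

    def encode(self, sentences: List[str]) -> Any: ...


def normalise(s: str) -> str:
    # Simplified stand-in for normalize_text: collapse whitespace runs and trim the ends.
    return re.sub(r"\s+", " ", s).strip()


def encode(texts: List[str], embedder: Embedder) -> Any:
    """Normalise every text, then embed the whole list in a single call."""
    normalised = [normalise(t) for t in texts]
    return embedder.encode(normalised)


class LengthEmbedder:
    """Hypothetical stub that 'embeds' each text as a one-element vector: its length."""

    def encode(self, sentences: List[str]) -> List[List[float]]:
        return [[float(len(s))] for s in sentences]


print(encode(["  Missed   visit\n", "Dose error"], LengthEmbedder()))  # [[12.0], [10.0]]
```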
predict
predict(prediction_input: Union[str, List[str]], batch_size: int = 10, num_predictions: int = 1)
Description
This method classifies a single input string, or a batch of strings, using two classification models: dvcat and dvdecod. The input is first normalised and encoded, then passed through the dvcat model for the primary predictions. Based on these initial predictions, further predictions are generated with the dvdecod model. Predictions and probabilities are stored in a structured format for each input string and returned as a list of categorised responses.
Parameters
prediction_input (Union[str, List[str]]): A single string or a list of strings to be classified. A single string is converted to a list internally.
batch_size (int, optional): The number of strings to process at a time. Defaults to 10.
num_predictions (int, optional): The number of predictions to generate per input string. If set to 0, only the highest-probability prediction is returned. Defaults to 1.
Return value(s)
List[QAResponse]: A list of prediction responses, where each response contains the input string, the predicted categories, and the associated probabilities. Each prediction includes both the dvcat and dvdecod categories and their respective probabilities.
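Two mechanical steps of predict can be sketched with NumPy: splitting the inputs into batches of batch_size, and picking the num_predictions highest-probability class indices from a probability matrix such as the output of predict_proba. The helper names as_list, batched, and top_k are hypothetical, and treating num_predictions=0 like 1 is an assumption based on the parameter description above:

```python
from typing import Iterator, List, Sequence, Union

import numpy as np


def as_list(prediction_input: Union[str, List[str]]) -> List[str]:
    """Wrap a single string in a list, as predict does internally."""
    return [prediction_input] if isinstance(prediction_input, str) else list(prediction_input)


def batched(items: Sequence[str], batch_size: int = 10) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def top_k(probabilities: np.ndarray, num_predictions: int = 1) -> np.ndarray:
    """Return, per row, the column indices of the k largest probabilities, best first."""
    k = max(num_predictions, 1)  # assumption: 0 behaves like 1 (best prediction only)
    return np.argsort(probabilities, axis=1)[:, ::-1][:, :k]


print(as_list("Visit outside window"))  # ['Visit outside window']
print(list(batched(["a", "b", "c"], batch_size=2)))  # [['a', 'b'], ['c']]
probs = np.array([[0.1, 0.7, 0.2], [0.5, 0.2, 0.3]])
print(top_k(probs, num_predictions=2).tolist())  # [[1, 2], [0, 2]]
```

The indices returned by top_k would then be mapped back to labels through the corresponding encoder.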
dvdecod_validator
dvdecod_validator(dvdecod_predictions, dvcat_prediction_ids)
Description
This method validates each dvdecod prediction against its corresponding dvcat category label. Each dvcat category has an allowable set of dvdecod labels, defined in QAResponseCategory.LABEL_SPACE. For each dvdecod prediction, the function checks whether the highest-probability labels are valid for the associated dvcat label and stores the valid labels.
Parameters
dvdecod_predictions (ndarray): The raw prediction scores or probabilities from the dvdecod model, where each row represents a sample and each column represents a label's predicted score.
dvcat_prediction_ids (ndarray): Array of dvcat prediction indices, each corresponding to the dvcat category predicted for the input sample.
Return value(s)
List[QAResponse]: A list of prediction responses, where each response contains the input string, the predicted categories, and the associated probabilities. Each prediction includes both the dvcat and dvdecod categories and their respective probabilities.
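A minimal sketch of the filtering performed by dvdecod_validator, assuming the dvcat indices map to the keys of LABEL_SPACE in order and that the dvdecod score columns follow a known label list. The label space below is a hypothetical placeholder, not the real set of PD categories, and the hypothetical valid_dvdecod_labels returns label lists rather than QAResponse objects:

```python
from typing import Dict, List

import numpy as np

# Hypothetical placeholder label space; the real QAResponseCategory.LABEL_SPACE is not reproduced here.
LABEL_SPACE: Dict[str, List[str]] = {
    "CATEGORY_A": ["A1", "A2"],
    "CATEGORY_B": ["B1"],
}
DVCAT_LABELS = list(LABEL_SPACE)  # assumption: dvcat index i refers to the i-th key
DVDECOD_LABELS = ["A1", "A2", "B1"]  # assumption: column order of the dvdecod scores


def valid_dvdecod_labels(
    dvdecod_predictions: np.ndarray, dvcat_prediction_ids: np.ndarray
) -> List[List[str]]:
    """For each sample, list the dvdecod labels allowed under its dvcat, best score first."""
    results: List[List[str]] = []
    for scores, cat_id in zip(dvdecod_predictions, dvcat_prediction_ids):
        allowed = set(LABEL_SPACE[DVCAT_LABELS[int(cat_id)]])
        ranked = np.argsort(scores)[::-1]  # column indices, highest score first
        results.append([DVDECOD_LABELS[i] for i in ranked if DVDECOD_LABELS[i] in allowed])
    return results


scores = np.array([[0.1, 0.2, 0.7], [0.3, 0.1, 0.6]])
print(valid_dvdecod_labels(scores, np.array([0, 1])))  # [['A2', 'A1'], ['B1']]
```

In the first row, B1 has the highest raw score but is discarded because it is not an allowed sub-category of CATEGORY_A.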
QAResponseCategory
Description
QAResponseCategory contains the known categories and sub-categories of PDs within clinical trials, organised as a dictionary, LABEL_SPACE. Each main category maps to a list of the specific deviations that fall under it. An instance of this class represents a single category and deviation code, together with the associated probability score.
Parameters
dvcat (str): The main category label for the deviation type.
dvdecod (str): The specific sub-category or code label for the deviation.
probability (float): The confidence score or probability for the assigned category and sub-category.
validate_dvcat
validate_dvcat(value: str)
Description
Validates the value of the dvcat field against the defined label space, ensuring that it corresponds to a recognised category in the LABEL_SPACE attribute.
Parameters
value (str): The category label to validate, provided by the user or the calling function.
Return value(s)
value (str): The validated category label, returned unchanged if it matches an entry in LABEL_SPACE.
Raises:
ValueError: If the provided category label does not match any key in LABEL_SPACE. The error message lists the accepted values.
validate_dvdecod
validate_dvdecod(value, info: ValidationInfo)
Description
Validates the dvdecod field by ensuring that the provided value is a valid entry within the label space associated with the specified dvcat category.
Parameters
value (str): The dvdecod value to validate, expected to be a string naming a specific sub-category within LABEL_SPACE for the given dvcat.
info (ValidationInfo): An object carrying additional context for the validation, including the data dictionary from which dvcat is retrieved.
Return value(s)
value (str): The validated dvdecod value. If validation fails because dvcat is invalid or undefined, the method may return UNKNOWN_LABEL.
Raises:
ValueError: If dvcat is not defined in the data, a ValueError is raised indicating that the provided dvdecod is invalid. If dvdecod is not among the allowed labels for the specified dvcat, a ValueError listing the accepted values is raised.
validate_proba
validate_proba(value: float)
Description
Validates the probability value of a QAResponseCategory.
Parameters
value (float): The probability value to validate. It represents the likelihood of a specific prediction and should be a float between 0 and 1.
Return value(s)
value (float): The validated probability value, returned if it meets the criterion (between 0 and 1).
Raises:
ValueError: If the value is less than 0 or greater than 1, i.e. not a valid probability.
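The ValidationInfo parameter of validate_dvdecod suggests that QAResponseCategory is a pydantic (v2) model with field validators. The sketch below shows how the three validators could fit together under that assumption. The LABEL_SPACE entries are hypothetical placeholders, and for simplicity the sketch always raises on an undefined dvcat rather than returning UNKNOWN_LABEL:

```python
from typing import ClassVar, Dict, List

from pydantic import BaseModel, ValidationInfo, field_validator


class QAResponseCategory(BaseModel):
    # Hypothetical placeholder labels, not the real PD label space.
    LABEL_SPACE: ClassVar[Dict[str, List[str]]] = {
        "CATEGORY_A": ["A1", "A2"],
        "CATEGORY_B": ["B1"],
    }

    dvcat: str
    dvdecod: str
    probability: float

    @field_validator("dvcat")
    @classmethod
    def validate_dvcat(cls, value: str) -> str:
        if value not in cls.LABEL_SPACE:
            raise ValueError(f"Unknown dvcat {value!r}; accepted values: {list(cls.LABEL_SPACE)}")
        return value

    @field_validator("dvdecod")
    @classmethod
    def validate_dvdecod(cls, value: str, info: ValidationInfo) -> str:
        # info.data holds the fields validated so far; dvcat is absent if it failed validation.
        dvcat = info.data.get("dvcat")
        if dvcat is None:
            raise ValueError(f"Cannot validate dvdecod {value!r} without a valid dvcat")
        allowed = cls.LABEL_SPACE[dvcat]
        if value not in allowed:
            raise ValueError(f"Invalid dvdecod {value!r} for {dvcat!r}; accepted values: {allowed}")
        return value

    @field_validator("probability")
    @classmethod
    def validate_proba(cls, value: float) -> float:
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"probability must be between 0 and 1, got {value}")
        return value


category = QAResponseCategory(dvcat="CATEGORY_A", dvdecod="A2", probability=0.8)
print(category.dvdecod)  # A2
```

Because pydantic validates fields in declaration order, dvcat is already in info.data when validate_dvdecod runs, provided dvcat passed its own validation.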
QAResponse
Description
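QAResponse represents a response from the question-answering system: the input string together with its predicted categories (QAResponseCategory instances) and their probabilities.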
validate_categories
validate_categories(value: List)
Description
This method ensures that every element in the provided list is either an instance of QAResponseCategory or can be parsed successfully from a dictionary. If parsing fails, a ValueError is raised with details of the failure.
Parameters
value (List): The list of categories to validate. Each element is either a QAResponseCategory instance or a dictionary that can be parsed into one.
Return value(s)
validated_categories (List[QAResponseCategory]): The validated categories, sorted by their probability attribute in descending order.
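A stdlib-only sketch of the parse-then-sort behaviour of validate_categories. Category is a hypothetical dataclass stand-in for QAResponseCategory, so field validation is omitted; in a pydantic-based implementation, dictionaries could instead be parsed with QAResponseCategory.model_validate:

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class Category:
    """Hypothetical stand-in for QAResponseCategory (its validators are omitted)."""

    dvcat: str
    dvdecod: str
    probability: float


def validate_categories(value: List[Union[Category, Dict[str, Any]]]) -> List[Category]:
    """Accept instances or dicts, then sort by probability, highest first."""
    parsed: List[Category] = []
    for index, item in enumerate(value):
        if isinstance(item, Category):
            parsed.append(item)
            continue
        try:
            parsed.append(Category(**item))  # fails on missing/unknown keys or non-mappings
        except TypeError as exc:
            raise ValueError(f"Item {index} could not be parsed as a category: {exc}") from exc
    return sorted(parsed, key=lambda c: c.probability, reverse=True)


result = validate_categories([
    {"dvcat": "CATEGORY_A", "dvdecod": "A1", "probability": 0.2},
    Category("CATEGORY_B", "B1", 0.9),
])
print([c.probability for c in result])  # [0.9, 0.2]
```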
Utils
Description
Utility functions.
setup_logger
setup_logger(id: str)
Description
This function configures a logger that writes messages to the console. It sets the logging level to INFO and clears any existing handlers so that the logger is set up afresh each time the function is called. The log format includes the provided ID, the timestamp, the log level, the logger name, and the log message.
Parameters
id (str): An identifier used to prefix log messages, helping to distinguish logs from different sources or components.
Return value(s)
logging.Logger: The configured logger instance, ready for use.
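A minimal sketch with the standard logging module, assuming the ID doubles as the logger name and appears in square brackets at the start of each line; the exact format string is not given in this document:

```python
import logging


def setup_logger(id: str) -> logging.Logger:
    """Return a console logger whose lines are prefixed with `id` (sketch)."""
    logger = logging.getLogger(id)  # assumption: the ID is also the logger name
    logger.setLevel(logging.INFO)
    logger.handlers.clear()  # remove handlers from earlier calls so lines are not duplicated
    handler = logging.StreamHandler()  # writes to sys.stderr by default
    # Escape '%' in the ID so it is not mistaken for a format placeholder.
    prefix = id.replace("%", "%%")
    handler.setFormatter(
        logging.Formatter(f"[{prefix}] %(asctime)s - %(levelname)s - %(name)s - %(message)s")
    )
    logger.addHandler(handler)
    return logger


log = setup_logger("pd-classifier")
log.info("Models loaded")
```

Clearing the handlers on every call is what makes repeated calls safe: without it, each call would attach another handler and every message would be printed once per call.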
normalize_text
normalize_text(s)
Description
Normalises a given text string by performing various transformations to ensure consistent formatting and remove unnecessary whitespace or characters. This function processes the input string to:
— Remove extra spaces and leading/trailing whitespace.
— Replace multiple periods with a single period.
— Eliminate newlines and any unusual punctuation formats.
Parameters
s (str): The input string to be normalised.
Return value(s)
s (str): The normalised string, with standardised spacing and punctuation.
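One possible implementation of these rules with regular expressions. The newline, period, and whitespace steps mirror the list above; removing whitespace before punctuation is an assumed interpretation of "unusual punctuation formats", since the exact rules are not specified:

```python
import re


def normalize_text(s: str) -> str:
    s = s.replace("\r", " ").replace("\n", " ")  # newlines become ordinary spaces
    s = re.sub(r"\.{2,}", ".", s)  # "..." or ".." becomes "."
    s = re.sub(r"\s+([.,;:!?])", r"\1", s)  # assumed rule: no whitespace before punctuation
    s = re.sub(r"\s+", " ", s)  # collapse runs of whitespace to one space
    return s.strip()  # drop leading/trailing whitespace


print(normalize_text("  Patient missed\nvisit ...  Dose  changed ."))
# Patient missed visit. Dose changed.
```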