"""coniferest.session.oracle: automated sessions driven by known (oracle) labels."""

from typing import Callable, Optional

import numpy as np

from coniferest.coniferest import Coniferest
from coniferest.label import Label
from coniferest.session import Session
from coniferest.session.callback import TerminateAfter, TerminateAfterNAnomalies


class OracleSession(Session):
    """Automated session to run experiments with labeled data.

    The human expert is replaced by an "oracle": the true labels are passed
    to `Session` as per-object metadata. The decision callback simply returns
    that metadata, so every asked decision is the true label of the object.

    Parameters
    ----------
    data : np.ndarray, 2-D
        Array with feature values of objects.
    labels : np.ndarray, 1-D
        Array with true labels, of Label or int type. Used as the session
        metadata.
    model : Coniferest
        Anomaly detection model to use.
    max_iterations : int
        Maximum number of asked decisions.
    max_anomalies : int
        Maximum number of anomalies to search for.
    on_refit_callbacks : list of callable, or callable, or None, optional
        Functions to be called when model is refitted, default is empty list.
    on_decision_callbacks : iterable of callable, or callable, or None, optional
        Functions to be called when expert decision is made, default is empty
        list. They are appended after the built-in `TerminateAfter` and
        `TerminateAfterNAnomalies` callbacks.

    Also see methods and attributes from the base `Session` class
    """

    def __init__(
        self,
        data: np.ndarray,
        labels: np.ndarray,
        *,
        model: Coniferest,
        max_iterations: int,
        max_anomalies: int,
        on_refit_callbacks=None,
        on_decision_callbacks=None,
    ):
        # Normalise the user's decision callbacks into a fresh list.
        # A single callable is wrapped. Any other iterable (list, tuple, ...)
        # is copied, so `list + tuple` cannot raise TypeError and the caller's
        # container is never shared or mutated.
        if on_decision_callbacks is None:
            user_decision_callbacks = []
        elif callable(on_decision_callbacks):
            user_decision_callbacks = [on_decision_callbacks]
        else:
            user_decision_callbacks = list(on_decision_callbacks)

        super().__init__(
            data=data,
            metadata=labels,
            model=model,
            # Session.metadata is the labels array, so the candidate's
            # metadata already is its true label: the oracle just returns it.
            decision_callback=lambda label, _features, _self: label,
            on_refit_callbacks=on_refit_callbacks,
            on_decision_callbacks=[
                TerminateAfter(max_iterations),
                TerminateAfterNAnomalies(max_anomalies),
                *user_decision_callbacks,
            ],
        )
def create_oracle_session(
    data: np.ndarray,
    labels: np.ndarray,
    *,
    model: Coniferest,
    max_iterations: Optional[int] = None,
    on_refit_callbacks=None,
    on_decision_callbacks=None,
) -> OracleSession:
    """Create an automated session to run experiments with labeled data.

    The session's `max_anomalies` is set to the number of entries in
    `labels` equal to `Label.ANOMALY`.

    Parameters
    ----------
    data : np.ndarray, 2-D
        Array with feature values of objects.
    labels : np.ndarray, 1-D
        Array with true labels, of Label or int type.
    model : Coniferest
        Anomaly detection model to use.
    max_iterations : int or None, optional
        Maximum number of asked decisions. Default is 5 times the number of
        anomalies, capped at the number of objects (``len(labels)``).
    on_refit_callbacks : list of callable, or callable, or None, optional
        Functions to be called when model is refitted, default is empty list.
    on_decision_callbacks : list of callable, or callable, or None, optional
        Functions to be called when expert decision is made, default is
        empty list.

    Returns
    -------
    OracleSession
    """
    # Count true anomalies as a plain Python int. np.asarray lets a plain
    # sequence of labels work too (a list compared to a scalar would give a
    # single bool rather than an element-wise mask).
    n_anomalies = int(np.count_nonzero(np.asarray(labels) == Label.ANOMALY))
    if max_iterations is None:
        # Integer arithmetic keeps the limit an int; `* 5.0` would make it a float.
        max_iterations = min(5 * n_anomalies, len(labels))
    return OracleSession(
        data,
        labels,
        model=model,
        max_iterations=max_iterations,
        max_anomalies=n_anomalies,
        on_refit_callbacks=on_refit_callbacks,
        on_decision_callbacks=on_decision_callbacks,
    )