Source code for evalml.pipelines.components.component_base

"""Base class for all components."""
import copy
from abc import ABC, abstractmethod

import cloudpickle

from evalml.exceptions import MethodPropertyNotFoundError
from evalml.pipelines.components.component_base_meta import ComponentBaseMeta
from evalml.utils import (
    _downcast_nullable_X,
    _downcast_nullable_y,
    classproperty,
    infer_feature_types,
    log_subtitle,
    safe_repr,
)
from evalml.utils.logger import get_logger


class ComponentBase(ABC, metaclass=ComponentBaseMeta):
    """Base class for all components.

    Args:
        parameters (dict): Dictionary of parameters for the component. Defaults to None.
        component_obj (obj): Third-party objects useful in component implementation. Defaults to None.
        random_seed (int): Seed for the random number generator. Defaults to 0.
    """

    # Per-class cache for ``default_parameters``; filled lazily on first access.
    _default_parameters = None
    _can_be_used_for_fast_partial_dependence = True
    # Referring to the pandas nullable dtypes; not just woodwork logical types
    # Each list may contain "X" and/or "y" (see ``_handle_nullable_types``).
    _integer_nullable_incompatibilities = []
    _boolean_nullable_incompatibilities = []

    def __init__(self, parameters=None, component_obj=None, random_seed=0, **kwargs):
        """Initialize the component.

        Args:
            parameters (dict): Dictionary of parameters for the component. Defaults to None.
            component_obj (obj): Third-party objects useful in component implementation. Defaults to None.
            random_seed (int): Seed for the random number generator. Defaults to 0.
            kwargs (Any): Any keyword arguments to pass into the component. They are
                accepted but not used by this base initializer.
        """
        self.random_seed = random_seed
        self._component_obj = component_obj
        self._parameters = parameters or {}
        self._is_fitted = False

    @property
    @classmethod
    @abstractmethod
    def name(cls):
        """Returns string name of this component."""

    @property
    @classmethod
    @abstractmethod
    def modifies_features(cls):
        """Returns whether this component modifies (subsets or transforms) the features variable during transform.

        For Estimator objects, this attribute determines if the return
        value from `predict` or `predict_proba` should be used as
        features or targets.
        """

    @property
    @classmethod
    @abstractmethod
    def modifies_target(cls):
        """Returns whether this component modifies (subsets or transforms) the target variable during transform.

        For Estimator objects, this attribute determines if the return
        value from `predict` or `predict_proba` should be used as
        features or targets.
        """

    @property
    @classmethod
    @abstractmethod
    def training_only(cls):
        """Returns whether or not this component should be evaluated during training-time only, or during both training and prediction time."""

    @classproperty
    def needs_fitting(self):
        """Returns boolean determining if component needs fitting before calling predict, predict_proba, transform, or feature_importances.

        This can be overridden to False for components that do not need to be fit
        or whose fit methods do nothing.

        Returns:
            True.
        """
        return True

    @property
    def parameters(self):
        """Returns the parameters which were used to initialize the component.

        Returns:
            dict: A shallow copy of the component's parameter dictionary.
        """
        return copy.copy(self._parameters)

    @classproperty
    def default_parameters(cls):
        """Returns the default parameters for this component.

        Our convention is that Component.default_parameters == Component().parameters.

        The result is cached per class. Only the class's own namespace is
        consulted for the cache, so a value cached on a parent class is never
        returned for a subclass.

        Returns:
            dict: Default parameters for this component.
        """
        # ``cls._default_parameters`` would follow the MRO and could pick up a
        # parent's cached defaults; ``vars(cls)`` restricts the lookup to ``cls``.
        cached = vars(cls).get("_default_parameters")
        if cached is None:
            cached = cls().parameters
            cls._default_parameters = cached
        return cached

    @classproperty
    def _supported_by_list_API(cls):
        """Returns True when the component does not modify the target."""
        return not cls.modifies_target

    def _handle_partial_dependence_fast_mode(
        self,
        pipeline_parameters,
        X=None,
        target=None,
    ):
        """Determines whether or not a component can be used with partial dependence's fast mode.

        Args:
            pipeline_parameters (dict): Pipeline parameters that will be used to create the pipelines
                used in partial dependence fast mode.
            X (pd.DataFrame, optional): Holdout data being used for partial dependence calculations.
            target (str, optional): The target whose values we are trying to predict.

        Returns:
            dict: ``pipeline_parameters``, returned unchanged by this base implementation.

        Raises:
            TypeError: If the component is flagged as unusable for fast mode.
        """
        if self._can_be_used_for_fast_partial_dependence:
            return pipeline_parameters
        raise TypeError(
            f"Component {self.name} cannot run partial dependence fast mode.",
        )

    def clone(self):
        """Constructs a new component with the same parameters and random state.

        Returns:
            A new instance of this component with identical parameters and random state.
        """
        return self.__class__(**self.parameters, random_seed=self.random_seed)

    def fit(self, X, y=None):
        """Fits component to data.

        Args:
            X (pd.DataFrame): The input training data of shape [n_samples, n_features]
            y (pd.Series, optional): The target training data of length [n_samples]

        Returns:
            self

        Raises:
            MethodPropertyNotFoundError: If component does not have a fit method or a component_obj that implements fit.
        """
        X = infer_feature_types(X)
        if y is not None:
            y = infer_feature_types(y)
        # Guard only the attribute lookup: an AttributeError raised *inside*
        # the wrapped fit is a genuine failure and must not be reported as a
        # missing fit method.
        try:
            fit_method = self._component_obj.fit
        except AttributeError as err:
            raise MethodPropertyNotFoundError(
                "Component requires a fit method or a component_obj that implements fit",
            ) from err
        fit_method(X, y)
        return self

    def describe(self, print_name=False, return_dict=False):
        """Describe a component and its parameters.

        Args:
            print_name(bool, optional): whether to print name of component
            return_dict(bool, optional): whether to return description as dictionary in the format {"name": name, "parameters": parameters}

        Returns:
            None or dict: Returns dictionary if return_dict is True, else None.
        """
        logger = get_logger(f"{__name__}.describe")
        # ``parameters`` copies the dict on every access, so read it once.
        parameters = self.parameters
        if print_name:
            log_subtitle(logger, self.name)
        for parameter, value in parameters.items():
            logger.info(("\t * {} : {}").format(parameter, value))
        if return_dict:
            return {"name": self.name, "parameters": parameters}
        return None

    def save(self, file_path, pickle_protocol=cloudpickle.DEFAULT_PROTOCOL):
        """Saves component at file path.

        Args:
            file_path (str): Location to save file.
            pickle_protocol (int): The pickle data stream format.
        """
        with open(file_path, "wb") as f:
            cloudpickle.dump(self, f, protocol=pickle_protocol)

    @staticmethod
    def load(file_path):
        """Loads component at file path.

        Args:
            file_path (str): Location to load file.

        Returns:
            ComponentBase object
        """
        # SECURITY: unpickling can execute arbitrary code. Only load files
        # that come from a trusted source.
        with open(file_path, "rb") as f:
            return cloudpickle.load(f)

    def __eq__(self, other):
        """Check for equality.

        Two components are equal when ``other`` is an instance of this
        component's class and they share the same random seed, parameters
        and fitted flag.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.random_seed != other.random_seed:
            return False
        for attribute in ("_parameters", "_is_fitted"):
            if getattr(self, attribute) != getattr(other, attribute):
                return False
        return True

    def __str__(self):
        """String representation of a component."""
        return self.name

    def __repr__(self):
        """String representation of a component, in the form ``ClassName(key=value, ...)``."""
        parameters_repr = ", ".join(
            f"{key}={safe_repr(value)}" for key, value in self.parameters.items()
        )
        return f"{type(self).__name__}({parameters_repr})"

    def update_parameters(self, update_dict, reset_fit=True):
        """Updates the parameter dictionary of the component.

        Args:
            update_dict (dict): A dict of parameters to update.
            reset_fit (bool, optional): If True, will set `_is_fitted` to False.
        """
        self._parameters.update(update_dict)
        if reset_fit:
            self._is_fitted = False

    def _handle_nullable_types(self, X=None, y=None):
        """Transforms X and y to remove any incompatible nullable types according to a component's needs.

        Args:
            X (pd.DataFrame, optional): Input data to a component of shape [n_samples, n_features].
                May contain nullable types.
            y (pd.Series or pd.DataFrame, optional): The target of length [n_samples] or the
                unstacked target for a multiseries problem of length [n_samples, n_features*n_series].
                May contain nullable types.

        Returns:
            X, y with any incompatible nullable types downcasted to compatible equivalents.
        """
        X_bool_incompatible = "X" in self._boolean_nullable_incompatibilities
        X_int_incompatible = "X" in self._integer_nullable_incompatibilities
        # Skip the downcast when X is absent or nothing is flagged for it.
        if X is not None and (X_bool_incompatible or X_int_incompatible):
            X = _downcast_nullable_X(
                X,
                handle_boolean_nullable=X_bool_incompatible,
                handle_integer_nullable=X_int_incompatible,
            )

        y_bool_incompatible = "y" in self._boolean_nullable_incompatibilities
        y_int_incompatible = "y" in self._integer_nullable_incompatibilities
        if y is not None and (y_bool_incompatible or y_int_incompatible):
            y = _downcast_nullable_y(
                y,
                handle_boolean_nullable=y_bool_incompatible,
                handle_integer_nullable=y_int_incompatible,
            )

        return X, y