"""Data check that checks if any of the features are highly correlated with the target by using mutual information or Pearson correlation."""
from woodwork.config import CONFIG_DEFAULTS
from evalml.data_checks import (
    DataCheck,
    DataCheckActionCode,
    DataCheckActionOption,
    DataCheckMessageCode,
    DataCheckWarning,
)
from evalml.utils.woodwork_utils import infer_feature_types
[docs]class TargetLeakageDataCheck(DataCheck):
    """Check if any of the features are highly correlated with the target by using mutual information, Pearson correlation, and other correlation metrics.
    If method='mutual_info', this data check uses mutual information and supports all target and feature types.
    Other correlation metrics only support binary with numeric and boolean dtypes. This method will return a value in [-1, 1] if other correlation metrics are selected
    and will returns a value in [0, 1] if mutual information is selected. Correlation metrics available can be found in Woodwork's
    `dependence_dict method <https://woodwork.alteryx.com/en/stable/generated/woodwork.table_accessor.WoodworkTableAccessor.dependence_dict.html#woodwork.table_accessor.WoodworkTableAccessor.dependence_dict>`_.
    Args:
        pct_corr_threshold (float): The correlation threshold to be considered leakage. Defaults to 0.95.
        method (string): The method to determine correlation. Use 'all' or 'max' for the maximum correlation, or for specific correlation metrics, use their name (ie 'mutual_info' for mutual information, 'pearson' for Pearson correlation, etc).
            possible methods can be found in Woodwork's `config <https://woodwork.alteryx.com/en/stable/guides/setting_config_options.html?highlight=config#Viewing-Config-Settings>`_, under `correlation_metrics`.
            Defaults to 'all'.
    """
    def __init__(self, pct_corr_threshold=0.95, method="all"):
        if pct_corr_threshold < 0 or pct_corr_threshold > 1:
            raise ValueError(
                "pct_corr_threshold must be a float between 0 and 1, inclusive.",
            )
        methods = CONFIG_DEFAULTS["correlation_metrics"]
        if method not in methods:
            raise ValueError(
                f"Method '{method}' not in available correlation methods. Available methods include {methods}",
            )
        self.pct_corr_threshold = pct_corr_threshold
        self.method = method
        self._method_to_check = "max" if method == "all" else method
    def _calculate_dependence(self, X, y):
        highly_corr_cols = []
        X2 = X.ww.copy()
        target_str = "target_y"
        while target_str in list(X2.columns):
            target_str += "_y"
        X2.ww[target_str] = y
        try:
            dep_corr = X2.ww.dependence_dict(
                measures=self.method,
                target_col=target_str,
            )
        except KeyError:
            # keyError raised when the target does not appear due to incompatibility with the metric, return []
            return []
        highly_corr_cols = sorted(
            [
                corr_info["column_1"]
                for corr_info in dep_corr
                if abs(corr_info[self._method_to_check]) >= self.pct_corr_threshold
            ],
            key=lambda x: X2.columns.tolist().index(x),
        )
        return highly_corr_cols
[docs]    def validate(self, X, y):
        """Check if any of the features are highly correlated with the target by using mutual information, Pearson correlation, and/or Spearman correlation.
        If `method='mutual_info'` or `'method='max'`, supports all target and feature types. Other correlation metrics only support binary with numeric and boolean dtypes.
        This method will return a value in [-1, 1] if other correlation metrics are selected and will returns a value in [0, 1] if mutual information is selected.
        Args:
            X (pd.DataFrame, np.ndarray): The input features to check.
            y (pd.Series, np.ndarray): The target data.
        Returns:
            dict (DataCheckWarning): dict with a DataCheckWarning if target leakage is detected.
        Examples:
            >>> import pandas as pd
            Any columns that are strongly correlated with the target will raise a warning. This could be indicative of
            data leakage.
            >>> X = pd.DataFrame({
            ...    "leak": [10, 42, 31, 51, 61] * 15,
            ...    "x": [42, 54, 12, 64, 12] * 15,
            ...    "y": [13, 5, 13, 74, 24] * 15,
            ... })
            >>> y = pd.Series([10, 42, 31, 51, 40] * 15)
            ...
            >>> target_leakage_check = TargetLeakageDataCheck(pct_corr_threshold=0.95)
            >>> assert target_leakage_check.validate(X, y) == [
            ...     {
            ...         "message": "Column 'leak' is 95.0% or more correlated with the target",
            ...         "data_check_name": "TargetLeakageDataCheck",
            ...         "level": "warning",
            ...         "code": "TARGET_LEAKAGE",
            ...         "details": {"columns": ["leak"], "rows": None},
            ...         "action_options": [
            ...             {
            ...                 "code": "DROP_COL",
            ...                 "data_check_name": "TargetLeakageDataCheck",
            ...                 "parameters": {},
            ...                 "metadata": {"columns": ["leak"], "rows": None}
            ...             }
            ...         ]
            ...     }
            ... ]
            The default method can be changed to pearson from mutual_info.
            >>> X["x"] = y / 2
            >>> target_leakage_check = TargetLeakageDataCheck(pct_corr_threshold=0.8, method="pearson")
            >>> assert target_leakage_check.validate(X, y) == [
            ...     {
            ...         "message": "Columns 'leak', 'x' are 80.0% or more correlated with the target",
            ...         "data_check_name": "TargetLeakageDataCheck",
            ...         "level": "warning",
            ...         "details": {"columns": ["leak", "x"], "rows": None},
            ...         "code": "TARGET_LEAKAGE",
            ...         "action_options": [
            ...             {
            ...                 "code": "DROP_COL",
            ...                  "data_check_name": "TargetLeakageDataCheck",
            ...                  "parameters": {},
            ...                  "metadata": {"columns": ["leak", "x"], "rows": None}
            ...             }
            ...         ]
            ...     }
            ... ]
        """
        messages = []
        X = infer_feature_types(X)
        y = infer_feature_types(y)
        highly_corr_cols = self._calculate_dependence(X, y)
        warning_msg_singular = "Column {} is {}% or more correlated with the target"
        warning_msg_plural = "Columns {} are {}% or more correlated with the target"
        if highly_corr_cols:
            if len(highly_corr_cols) == 1:
                warning_msg = warning_msg_singular.format(
                    "'{}'".format(str(highly_corr_cols[0])),
                    self.pct_corr_threshold * 100,
                )
            else:
                warning_msg = warning_msg_plural.format(
                    (", ").join(["'{}'".format(str(col)) for col in highly_corr_cols]),
                    self.pct_corr_threshold * 100,
                )
            messages.append(
                DataCheckWarning(
                    message=warning_msg,
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.TARGET_LEAKAGE,
                    details={"columns": highly_corr_cols},
                    action_options=[
                        DataCheckActionOption(
                            DataCheckActionCode.DROP_COL,
                            data_check_name=self.name,
                            metadata={"columns": highly_corr_cols},
                        ),
                    ],
                ).to_dict(),
            )
        return messages