Skip to content

Commit

Permalink
[ENH] Add warning for categorical variables (#279)
Browse files Browse the repository at this point in the history
* Add warning for categorical variables in Python

* matlab categorical warning

* Fixes for categorical detection

* Remove deprecated functions

* linting

* fix faulty keyword argument

* Fixes for false warnings
  • Loading branch information
ReinderVosDeWael authored Mar 25, 2022
1 parent d80e465 commit f1ea4dc
Show file tree
Hide file tree
Showing 13 changed files with 202 additions and 78 deletions.
27 changes: 1 addition & 26 deletions brainstat/stats/SLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from cmath import sqrt
from pathlib import Path
from pprint import pformat
from typing import Any, Optional, Sequence, Tuple, Union
from typing import Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
Expand All @@ -13,7 +13,6 @@
from nibabel.nifti1 import Nifti1Image

from brainstat._typing import ArrayLike
from brainstat._utils import deprecated
from brainstat.datasets import fetch_parcellation, fetch_template_surface
from brainstat.datasets.base import fetch_template_surface, fetch_yeo_networks_metadata
from brainstat.mesh.utils import _mask_edges, mesh_edges
Expand Down Expand Up @@ -358,30 +357,6 @@ def lat(self, value):
def lat(self):
del self._lat

@deprecated(
"Direct usage of this method is deprecated. Please use the `fit()` method instead."
)
def t_test(self) -> None:
raise NotImplementedError

@deprecated(
"Direct usage of this method is deprecated. Please use the `fit()` method instead."
)
def linear_model(self, Y: Any) -> None:
raise NotImplementedError

@deprecated(
"Direct usage of this method is deprecated. Please use the `fit()` method instead."
)
def fdr(self) -> None:
raise NotImplementedError

@deprecated(
"Direct usage of this method is deprecated. Please use the `fit()` method instead."
)
def random_field_theory(self) -> None:
raise NotImplementedError


def _merge_rft(P1: dict, P2: dict) -> dict:
"""Merge two one-tailed outputs of the random_field_theory function.
Expand Down
161 changes: 121 additions & 40 deletions brainstat/stats/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd

from brainstat._typing import ArrayLike
from brainstat._utils import deprecated
from brainstat._utils import logger


def check_names(
Expand All @@ -22,6 +22,50 @@ def check_names(
return None


def check_categorical_variables(
x: Union[ArrayLike, pd.DataFrame, "FixedEffect"],
names: Optional[Union[str, Sequence[str]]] = None,
) -> None:
"""Checks whether categorical variables were provided as such.
Parameters
----------
x : ArrayLike, pandas.DataFrame
The input array.
names : str, sequence of str or None, optional
Names for each column in `x`. Default is None.
"""
if np.isscalar(x):
return

if isinstance(names, str):
names = [names]

if isinstance(x, pd.DataFrame):
x_df = x
elif isinstance(x, FixedEffect):
x_df = x.m
else:
x_df = pd.DataFrame(x, columns=names)

categorical_warning_threshold = np.minimum(5, x_df.shape[0] - 1)

for i, column in enumerate(x_df):
if not pd.api.types.is_numeric_dtype(x_df[column]):
# Variable is categorical.
continue

unique_numbers = x_df[column].unique()
if 1 < unique_numbers.size < categorical_warning_threshold:
if names is not None:
name = names[i]
else:
name = f"Column {i}"
logger.warning(
f"{name} has {unique_numbers.size} unique values but was supplied as a numeric (i.e. continuous) variable. Should it be a categorical variable? If yes, the easiest way to provide categorical variables is to convert numerics to strings."
)


def to_df(
x: Union[int, ArrayLike, "FixedEffect"],
n: int = 1,
Expand Down Expand Up @@ -246,12 +290,16 @@ def __init__(
x: Optional[Union[ArrayLike, pd.DataFrame]] = None,
names: Optional[Union[str, Sequence[str]]] = None,
add_intercept: bool = True,
_check_categorical: bool = True,
) -> None:

if x is None:
self.m = pd.DataFrame()
return

if _check_categorical:
check_categorical_variables(x, names)

if isinstance(x, FixedEffect):
self.m = x.m
return
Expand All @@ -265,7 +313,9 @@ def __init__(
self.m = to_df(x, names=names).reset_index(drop=True)
if add_intercept and "intercept" not in self.names:
self.m.insert(0, "intercept", 1)
check_duplicate_names(self.m)

if _check_categorical:
check_duplicate_names(self.m)

def _broadcast(
self, t: Union[ArrayLike, "FixedEffect"], idx: Optional[int] = None
Expand All @@ -286,7 +336,7 @@ def _add(
return NotImplemented

if self.empty:
return FixedEffect(t, add_intercept=False)
return FixedEffect(t, add_intercept=False, _check_categorical=False)

idx = None
if check_names(t) is None:
Expand All @@ -307,7 +357,7 @@ def _add(
df = pd.concat(terms, axis=1)
df.columns = names[0] + names[1]
cols = remove_duplicate_columns(df, tol=self.tolerance)
return FixedEffect(df[cols], add_intercept=False)
return FixedEffect(df[cols], add_intercept=False, _check_categorical=False)

def __add__(
self, t: Union[ArrayLike, "FixedEffect", "MixedEffect"]
Expand All @@ -331,7 +381,9 @@ def __sub__(self, t: Union[ArrayLike, "FixedEffect"]) -> "FixedEffect":
m = self.m / self.m.abs().sum(0)
merged = m.T.merge(df.T, how="outer", indicator=True)
mask = (merged._merge.values == "left_only")[: self.m.shape[1]]
return FixedEffect(self.m[self.m.columns[mask]], add_intercept=False)
return FixedEffect(
self.m[self.m.columns[mask]], add_intercept=False, _check_categorical=False
)

def _mul(
self, t: Union[ArrayLike, "FixedEffect", "MixedEffect"], side: str = "left"
Expand All @@ -349,7 +401,9 @@ def _mul(
names = [f"{t}*{k}" for k in self.names]
else:
names = [f"{k}*{t}" for k in self.names]
return FixedEffect(m, names=names, add_intercept=False)
return FixedEffect(
m, names=names, add_intercept=False, _check_categorical=False
)

df = self._broadcast(t)
if df.empty:
Expand All @@ -366,7 +420,7 @@ def _mul(
df = pd.concat(prod, axis=1)
df.columns = names
cols = remove_duplicate_columns(df, tol=self.tolerance)
return FixedEffect(df[cols], add_intercept=False)
return FixedEffect(df[cols], add_intercept=False, _check_categorical=False)

def __mul__(
self, t: Union[ArrayLike, "FixedEffect", "MixedEffect"]
Expand Down Expand Up @@ -467,6 +521,7 @@ def __init__(
ranisvar: bool = False,
add_intercept: bool = True,
add_identity: bool = True,
_check_categorical: bool = True,
) -> None:

if isinstance(ran, MixedEffect):
Expand All @@ -477,6 +532,9 @@ def __init__(
if ran is None:
self.variance = FixedEffect()
else:
if _check_categorical:
check_categorical_variables(ran, name_ran)

ran = to_df(ran)
if not ranisvar:
if ran.size == 1:
Expand All @@ -498,14 +556,22 @@ def __init__(

ran = ran @ ran.T
ran = ran.values.ravel()
self.variance = FixedEffect(ran, names=name_ran, add_intercept=False)
self.mean = FixedEffect(fix, names=name_fix, add_intercept=add_intercept)
self.variance = FixedEffect(
ran, names=name_ran, add_intercept=False, _check_categorical=False
)
self.mean = FixedEffect(
fix,
names=name_fix,
add_intercept=add_intercept,
_check_categorical=_check_categorical,
)

if add_identity:
I = MixedEffect(1, name_ran="I", add_identity=False)
I = MixedEffect(
1, name_ran="I", add_identity=False, _check_categorical=False
)
tmp_mixed = self + I
self.variance = tmp_mixed.variance

self.set_identity_last()

def set_identity_last(self) -> None:
Expand Down Expand Up @@ -540,14 +606,18 @@ def set_identity_last(self) -> None:
def broadcast_to(self, r1: "MixedEffect", r2: "MixedEffect") -> FixedEffect:
if r1.variance.shape[0] == 1:
v = np.eye(max(r2.shape[0], int(np.sqrt(r2.shape[2]))))
return FixedEffect(v.ravel(), names="I", add_intercept=False)
return FixedEffect(
v.ravel(), names="I", add_intercept=False, _check_categorical=False
)
return r1.variance

def _add(
self, r: Union[FixedEffect, "MixedEffect"], side: str = "left"
) -> "MixedEffect":
if not isinstance(r, MixedEffect):
r = MixedEffect(fix=r, add_intercept=False, add_identity=False)
r = MixedEffect(
fix=r, add_intercept=False, add_identity=False, _check_categorical=False
)

r.variance = self.broadcast_to(r, self)
self.variance = self.broadcast_to(self, r)
Expand All @@ -559,7 +629,12 @@ def _add(
fix = r.mean + self.mean

s = MixedEffect(
ran=ran, fix=fix, ranisvar=True, add_intercept=False, add_identity=False
ran=ran,
fix=fix,
ranisvar=True,
add_intercept=False,
add_identity=False,
_check_categorical=False,
)
s.set_identity_last()
return s
Expand All @@ -584,7 +659,12 @@ def _sub(
ran = r.variance - self.variance
fix = r.mean - self.mean
return MixedEffect(
ran=ran, fix=fix, ranisvar=True, add_intercept=False, add_identity=False
ran=ran,
fix=fix,
ranisvar=True,
add_intercept=False,
add_identity=False,
_check_categorical=False,
)

def __sub__(self, r: Union[FixedEffect, "MixedEffect"]) -> "MixedEffect":
Expand All @@ -608,7 +688,12 @@ def _mul(
ran = r.variance * self.variance
fix = r.mean * self.mean
s = MixedEffect(
ran=ran, fix=fix, ranisvar=True, add_intercept=False, add_identity=False
ran=ran,
fix=fix,
ranisvar=True,
add_intercept=False,
add_identity=False,
_check_categorical=False,
)

if self.mean.matrix.values.size > 0:
Expand All @@ -621,6 +706,7 @@ def _mul(
np.outer(x[i], x[j]).T.ravel(),
names=self.mean.names[i],
add_intercept=False,
_check_categorical=False,
)
else:
xs = x[i] + x[j]
Expand All @@ -630,11 +716,17 @@ def _mul(

v = np.outer(xs, xs) / 4
t = t + FixedEffect(
v.ravel(), names=xs_name, add_intercept=False
v.ravel(),
names=xs_name,
add_intercept=False,
_check_categorical=False,
)
v = np.outer(xd, xd) / 4
t = t + FixedEffect(
v.ravel(), names=xd_name, add_intercept=False
v.ravel(),
names=xd_name,
add_intercept=False,
_check_categorical=False,
)

s.variance = s.variance + t * r.variance
Expand All @@ -646,7 +738,9 @@ def _mul(
for j in range(i + 1):
if i == j:
t = t + FixedEffect(
np.outer(x[i], x[j]).ravel(), names=r.mean.names[i]
np.outer(x[i], x[j]).ravel(),
names=r.mean.names[i],
_check_categorical=False,
)
else:
xs = x[i] + x[j]
Expand All @@ -656,11 +750,17 @@ def _mul(

v = np.outer(xs, xs) / 4
t = t + FixedEffect(
v.ravel(), names=xs_name, add_intercept=False
v.ravel(),
names=xs_name,
add_intercept=False,
_check_categorical=False,
)
v = np.outer(xd, xd) / 4
t = t + FixedEffect(
v.ravel(), names=xd_name, add_intercept=False
v.ravel(),
names=xd_name,
add_intercept=False,
_check_categorical=False,
)
s.variance = s.variance + self.variance * t
s.set_identity_last()
Expand Down Expand Up @@ -692,22 +792,3 @@ def _repr_html_(self) -> str:
+ "\n\nVariance:\n"
+ self.variance._repr_html_()
)


## Deprecated functions
@deprecated("Please use FixedEffect instead.")
def Term(x=None, names=None):
return FixedEffect(x=x, names=names, add_intercept=False)


@deprecated("Please use MixedEffect instead.")
def Random(ran=None, fix=None, name_ran=None, name_fix=None, ranisvar=False):
return MixedEffect(
ran=ran,
fix=fix,
name_ran=name_ran,
name_fix=name_fix,
ranisvar=ranisvar,
add_intercept=False,
add_identity=False,
)
Loading

0 comments on commit f1ea4dc

Please sign in to comment.