"""
Represents covariates in the model.
"""
from numbers import Number
from numpy import isnan
from typing import Optional
from cascade_at.core.log import get_loggers
# Module-level logging handle obtained from the project's logging helper,
# keyed by this module's name. NOTE(review): not referenced by the code
# visible in this section -- confirm whether it is still needed.
LOG = get_loggers(__name__)
class Covariate:
    """A named covariate column with an optional reference and cutoff.

    Instances compare equal, and hash equally, when their name, reference
    and max difference all match.
    """

    def __init__(self, column_name: str,
                 reference: Optional[float] = None,
                 max_difference: Optional[float] = None):
        """
        Establishes a reference value for a covariate column on input data
        and in output data. It is possible to create a covariate column with
        nothing but a name, but it must have a reference value before it
        can be used in a model.

        Parameters
        ----------
        column_name
            Name of the column in the input data.
        reference
            Reference where covariate has no effect.
        max_difference
            If a data point's covariate is farther than `max_difference`
            from the reference value, then this data point is excluded
            from the calculation. Must be greater than or equal to zero.

        Raises
        ------
        TypeError
            If `column_name` is not a string.
        ValueError
            If `column_name` is empty or `max_difference` is negative.
        """
        self._name = None
        self._reference = None
        self._max_difference = None

        # Route every value through its property setter so that validation
        # lives in exactly one place.
        self.name = column_name
        if reference is not None:
            # The reference setter coerces with float(), which rejects None,
            # so "no reference yet" is expressed by skipping the assignment.
            self.reference = reference
        self.max_difference = max_difference

    @property
    def name(self) -> str:
        """Name of the covariate column."""
        return self._name

    @name.setter
    def name(self, nom) -> None:
        if not isinstance(nom, str):
            raise TypeError(f"Covariate name must be a string, not {nom}")
        if len(nom) < 1:
            raise ValueError("Covariate name must not be empty string")
        self._name = nom

    @property
    def reference(self) -> Optional[float]:
        """Reference value, or None if it has not been set."""
        return self._reference

    @reference.setter
    def reference(self, ref) -> None:
        self._reference = float(ref)

    @property
    def max_difference(self) -> Optional[float]:
        """Maximum allowed distance from the reference, or None if unset."""
        return self._max_difference

    @max_difference.setter
    def max_difference(self, difference) -> None:
        # Both None and a numeric NaN are stored as None.
        if difference is None or (
                isinstance(difference, Number) and isnan(difference)):
            self._max_difference = None
        else:
            diff = float(difference)
            if diff < 0:
                raise ValueError(
                    f"max difference for a covariate must be greater than "
                    f"or equal to zero, not {difference}")
            self._max_difference = diff

    def __hash__(self) -> int:
        # Hashes the same three fields that __eq__ compares. The fields are
        # mutable through the setters, so changing one after the instance
        # has been placed in a set or dict changes its hash.
        return hash((self._name, self._reference, self._max_difference))

    def __repr__(self) -> str:
        return f"Covariate({self.name}, {self.reference}, {self.max_difference})"

    def __eq__(self, other):
        if not isinstance(other, Covariate):
            # Returning NotImplemented (rather than raising) lets Python try
            # the reflected comparison and fall back to "not equal", so that
            # ``cov == None`` and ``cov in some_list`` do not raise.
            return NotImplemented
        return (self._name == other._name
                and self._reference == other._reference
                and self._max_difference == other._max_difference)