# Source code for cascade_at.model.var

import numpy as np
from scipy.interpolate import RectBivariateSpline, interp1d

from cascade_at.dismod.constants import PriorKindEnum
from cascade_at.model.age_time_grid import AgeTimeGrid


class Var(AgeTimeGrid):
    """A Var is a function of age and time, defined by values on a grid.

    It linearly interpolates over values defined at grid points in a
    rectangular grid of age and time. This is a single age-time grid. It is
    usually found in :py:class:`cascade.model.DismodGroups` object which is a
    set of age-time grids. The following are ``DismodGroups`` containing
    :py:class:`cascade.model.Var`: the fit, initial guess, truth var, and
    scale var.

    Args:
        ages (List[float]): Points along the age axis.
        times (List[float]): Points in time.
        column_name (str): A var has an internal Pandas DataFrame
            representation, and this column name can be ``mean`` or
            ``meas_value``, depending on which Var is needed.
    """

    def __init__(self, ages, times, column_name="mean"):
        self._column_name = column_name
        super().__init__(ages, times, columns=self._column_name)
        # Interpolating function built lazily by ``__call__``. ``__setitem__``
        # discards it so that changed grid values are picked up.
        self._spline = None

    def check(self, name=None):
        """Raise a :py:class:`ValueError` if any part of the Var is
        uninitialized.

        None of the values in the grid column may be nan, and the only
        standard-deviation multipliers allowed are ``value``, ``dage`` and
        ``dtime``.

        Args:
            name (str, optional): Name of this Var, used only in the error
                message.

        Raises:
            ValueError: If a grid value is nan or there is an unexpected
                mulstd key.
        """
        if not self.grid[self._column_name].notna().all():
            raise ValueError(
                f"Var {name} has {self.grid[self._column_name].isna().sum()} nan values")
        if set(self.mulstd.keys()) - {"value", "dage", "dtime"}:
            raise ValueError(
                f"Var {name} has mulstds besides the three: {list(self.mulstd.keys())}"
            )

    def __setitem__(self, at_slice, value):
        """Set a value on ranges of age and time, or at specific ages and times.

        >>> var = Var([0, 10, 20], [2000])
        >>> var[:, :] = 0.001
        >>> var[5:50, 2000] = 0.01
        >>> var[10, :] = 0.02

        Args:
            at_slice (slice, slice): What to change, given as age and time
                values or ranges of values (as in the examples above).
            value (float): A float or integer.
        """
        super().__setitem__(at_slice, [value])
        # The cached interpolant was built from the old values. Drop it so the
        # next ``__call__`` rebuilds it from the updated grid.
        # NOTE(review): changes made directly to ``self.grid`` bypass this
        # method and will not invalidate the cache.
        self._spline = None

    def __getitem__(self, age_and_time):
        """Get the value of a Var at a single point.

        The point has to be one of the ages and times defined when the var
        was created.

        >>> var = Var([0, 50, 100], [1990, 2000, 2010])
        >>> var[:, :] = 1e-4
        >>> assert var[50, 2000] == 1e-4

        Trying to read from an age and time not in the ages and times of the
        grid will result in a :py:class:`KeyError`.

        An easy way to set values is to use the `age_time` iterator, which
        loops through the ages and times in the underlying grid.

        >>> for age, time in var.age_time():
        ...     var[age, time] = 0.01 * age

        Args:
            age_and_time (age, time): A two-dimensional index of age and time.

        Returns:
            float: The value at this age and time.
        """
        return float(super().__getitem__(age_and_time)[self._column_name])

    def set_mulstd(self, kind, value):
        """Set the value of the multiplier on the standard deviation.

        Kind must be one of "value", "dage", or "dtime". The value should be
        convertible to a float.

        >>> var = Var([50], [2000, 2001, 2002])
        >>> var.set_mulstd("value", 0.4)

        Args:
            kind (str): Which multiplier to set: a member name of
                ``PriorKindEnum``.
            value (float): New multiplier value.

        Raises:
            ValueError: If ``kind`` is not a ``PriorKindEnum`` member name,
                or ``value`` cannot be converted to float.
        """
        sig = "kind is one of value, dage, dtime, and value is a float."
        if kind not in PriorKindEnum.__members__:
            raise ValueError(f"{sig} kind={kind}")
        self.mulstd[kind].loc[:, self._column_name] = float(value)

    def get_mulstd(self, kind):
        """Get the value of a standard deviation multiplier for a Var.

        >>> var = Var([50], [2000, 2001, 2002])
        >>> var.set_mulstd("value", 0.4)
        >>> assert var.get_mulstd("value") == 0.4

        If the standard deviation multiplier wasn't set, then this will
        return a nan.

        >>> assert np.isnan(var.get_mulstd("dage"))

        Args:
            kind (str): Which multiplier to read: a member name of
                ``PriorKindEnum``.

        Returns:
            float: The multiplier value.

        Raises:
            ValueError: If ``kind`` is not a ``PriorKindEnum`` member name.
        """
        if kind not in PriorKindEnum.__members__:
            raise ValueError(f"Argument is one of value, dage, dtime, not {kind}.")
        return float(self.mulstd[kind][self._column_name])

    def __str__(self):
        # Number of ages and number of times, e.g. "Var(3, 2)".
        return f"Var({len(self.ages)}, {len(self.times)})"

    def __call__(self, age, time):
        """A Var is a function of age and time, and this is how to call it.

        >>> var = Var([0, 100], [1990, 2000])
        >>> var[0, 1990] = 0
        >>> var[0, 2000] = 1
        >>> var[100, 1990] = 2
        >>> var[100, 2000] = 3
        >>> for a, t in var.age_time():
        ...     print(f"At corner ({a}, {t}), {var(a, t)}")
        At corner (0.0, 1990.0), 0.0
        At corner (0.0, 2000.0), 1.0
        At corner (100.0, 1990.0), 2.0
        At corner (100.0, 2000.0), 3.0
        >>> for a, t in [[53, 1997], [-5, 2000], [120, 2000], [0, 1900], [0, 2010]]:
        ...     print(f"Anywhere ({a}, {t}), {var(a, t):.2f}")
        Anywhere (53, 1997), 1.76
        Anywhere (-5, 2000), 1.00
        Anywhere (120, 2000), 3.00
        Anywhere (0, 1900), 0.00
        Anywhere (0, 2010), 1.00

        The grid points in a Var represent a continuous function, determined
        by bivariate interpolation. All points outside the grid are equal to
        the nearest point inside the grid.

        The interpolating function is built on first use and cached. It is
        rebuilt after values are changed through ``var[...] = value``.

        Args:
            age: Age at which to evaluate; a scalar or array.
            time: Time at which to evaluate; a scalar or array.

        Returns:
            A Python float when both ``age`` and ``time`` are scalars,
            otherwise whatever the underlying interpolant returns.
        """
        if self._spline is None:
            self._spline = self._as_function()
        result = self._spline(age, time)
        # Result can be a numpy array, so undo that if input wasn't an array.
        if np.isscalar(age) and np.isscalar(time):
            return result.item()  # Numpy array has item().
        else:
            return result

    def _as_function(self):
        """Construct a function which mimics how Dismod-AT turns a field of
        points in age and time into a continuous function.

        * More than one age and more than one time: bilinear interpolation
          via ``RectBivariateSpline`` with ``kx=ky=1``.
        * More than one point along exactly one axis: linear interpolation
          along that axis with ``interp1d``. Outside its range the first/last
          grid value is used, and the other argument is ignored.
        * A single grid point: the same value everywhere.

        Returns:
            function: Of age and time.

        Raises:
            RuntimeError: If there are no ages or no times.
        """
        column = self._column_name
        ordered = self.grid.sort_values(["age", "time"])
        values = ordered[column].values
        # np.unique returns sorted unique values, matching the row order above.
        age = np.unique(self.grid["age"].values)
        time = np.unique(self.grid["time"].values)

        if len(age) > 1 and len(time) > 1:
            # Rows are sorted by age, then time, so this reshape puts the
            # value at (age[i], time[j]) in heights[i, j]. Assumes every
            # age-time combination appears exactly once in the grid.
            heights = values.reshape(len(age), len(time))
            spline = RectBivariateSpline(age, time, heights, kx=1, ky=1)

            def bivariate_function(x, y):
                return spline(x, y)[0]

            return bivariate_function

        if len(age) * len(time) > 1:
            # Exactly one axis has several points. Interpolate along it and
            # hold the end values constant outside its range.
            independent = age if len(age) != 1 else time
            spline = interp1d(
                independent, values, kind="linear",
                bounds_error=False, fill_value=(values[0], values[-1]))

            if len(age) != 1:
                def age_spline(x, _):
                    return spline(x)

                return age_spline

            def time_spline(_, y):
                return spline(y)

            return time_spline

        if len(age) == 1 and len(time) == 1:
            constant = values[0]

            def constant_everywhere(_a, _t):
                return constant

            return constant_everywhere

        raise RuntimeError(f"Cannot interpolate if ages or times are length zero: "
                           f"ages {len(age)} times {len(time)}")