"""
================
Cascade Commands
================
Sequences of cascade operations that together form a cascade command. A command
runs either the whole cascade or a drill (a single parent location plus its children).
"""
from typing import Optional, List
from cascade_at.core.log import get_loggers
from cascade_at.cascade.cascade_stacks import single_fit_with_uncertainty
from cascade_at.cascade.cascade_dags import make_cascade_dag
from cascade_at.cascade.cascade_operations import _CascadeOperation
from cascade_at.inputs.locations import LocationDAG
from cascade_at.inputs.utilities.gbd_ids import SEX_NAME_TO_ID, CascadeConstants
# Module-level logger, created via cascade_at's ``get_loggers`` helper.
LOG = get_loggers(__name__)
[docs]class _CascadeCommand:
"""
Base class for a cascade command.
"""
def __init__(self):
"""
Initializes a task dictionary. All tasks added to this command
in the form of cascade operations are added to the dictionary.
Attributes
----------
self.task_dict
A dictionary of cascade operations, keyed by the command
for that operation. This is so that we can look up the
task later by the exact command.
"""
self.task_dict = {}
[docs] def add_task(self, cascade_operation: _CascadeOperation) -> None:
"""
Adds a cascade operation to the task dictionary.
Parameters
----------
cascade_operation
A cascade operation to add to the command dictionary
"""
self.task_dict.update({
cascade_operation.command: cascade_operation
})
[docs] def get_commands(self) -> List[str]:
"""
Gets a list of commands in sequence so that you can run
them without using jobmon.
Returns
-------
Returns a list of commands that you can run on the command-line.
"""
return list(self.task_dict.keys())
class Drill(_CascadeCommand):
    """
    Cascade command for a drill: a single Dismod-AT model for one parent
    location plus its children.
    """

    def __init__(self, model_version_id: int,
                 drill_parent_location_id: int, drill_sex: int,
                 n_sim: int, n_pool: int = 10,
                 skip_configure: bool = False):
        """
        Builds the tasks for a drill model. The operations come from
        ``single_fit_with_uncertainty`` and are registered in order.

        Parameters
        ----------
        model_version_id
            The model version ID to create the drill for
        drill_parent_location_id
            The parent location ID to start the drill from
        drill_sex
            Which sex to drill for
        n_sim
            The number of simulations to do to get uncertainty at the leaf nodes
        n_pool
            Size of the multiprocessing pool. If this is 1, then it will
            not do multiprocessing.
        skip_configure
            Forwarded to ``single_fit_with_uncertainty``. Presumably this skips
            the initial inputs-pulling step, as in ``TraditionalCascade``
            (TODO confirm). It is intended for developer debugging only.

        Attributes
        ----------
        self.model_version_id
            The model version ID passed in
        self.drill_parent_id
            The parent location ID passed in as ``drill_parent_location_id``
        self.drill_sex
            The sex passed in as ``drill_sex``
        """
        super().__init__()
        self.model_version_id = model_version_id
        self.drill_parent_id = drill_parent_location_id
        self.drill_sex = drill_sex

        fit_options = dict(
            model_version_id=model_version_id,
            location_id=drill_parent_location_id,
            sex_id=drill_sex,
            n_sim=n_sim,
            n_pool=n_pool,
            skip_configure=skip_configure,
        )
        for operation in single_fit_with_uncertainty(**fit_options):
            self.add_task(operation)
class TraditionalCascade(_CascadeCommand):
    """
    Cascade command for the "traditional" dismod cascade over a location
    hierarchy.
    """

    def __init__(self, model_version_id: int, split_sex: bool,
                 dag: LocationDAG, n_sim: int, n_pool: int = 10,
                 location_start: Optional[int] = None,
                 sex: Optional[int] = None, skip_configure: bool = False):
        """
        Builds the tasks for the "traditional" dismod cascade. The
        operations come from ``make_cascade_dag`` and are registered
        in order.

        As designed, the traditional cascade runs fit fixed all the way to
        the leaf nodes of the cascade to save time (rather than fit both).
        To get posterior to prior it uses the coefficient of variation to
        get the variance of the posterior that becomes the prior at the next
        level. At the leaf nodes it runs sample asymptotic to get the final
        posteriors. If sample asymptotic fails due to bad constraints, it
        runs sample simulate instead.

        Parameters
        ----------
        model_version_id
            The model version ID
        split_sex
            Whether or not to split sex
        dag
            A location dag that specifies the structure of the cascade hierarchy
        n_sim
            The number of simulations to do to get uncertainty at the leaf nodes
        n_pool
            Size of the multiprocessing pool. If this is 1, then it will
            not do multiprocessing.
        location_start
            Which location to start the cascade from. When None, defaults
            to ``CascadeConstants.GLOBAL_LOCATION_ID`` (typically 1 = Global).
        sex
            Which sex to run the cascade for. When None, defaults to
            ``SEX_NAME_TO_ID['Both']`` (3 = Both). If it is 3 = Both, then
            it will split sex; if it is 1 or 2, then it will only run for
            that sex.
        skip_configure
            Use this option to skip the initial inputs pulling; should only
            be used in debugging cases by developers.
        """
        super().__init__()
        self.model_version_id = model_version_id

        # Resolve defaults only when the caller left the argument unset.
        start_sex = SEX_NAME_TO_ID['Both'] if sex is None else sex
        start_location = (
            CascadeConstants.GLOBAL_LOCATION_ID
            if location_start is None
            else location_start
        )

        operations = make_cascade_dag(
            model_version_id=model_version_id,
            dag=dag,
            location_start=start_location,
            sex_start=start_sex,
            split_sex=split_sex,
            n_sim=n_sim,
            n_pool=n_pool,
            skip_configure=skip_configure,
        )
        for operation in operations:
            self.add_task(operation)