# Source code for cascade_at.inputs.locations

from typing import List, Optional, Set

import networkx as nx
import numpy as np
import pandas as pd

from cascade_at.inputs.utilities.gbd_ids import CascadeConstants
from cascade_at.core.db import db_queries
from cascade_at.core import CascadeATError
from cascade_at.core.log import get_loggers

# Module-level logger, obtained from the project's get_loggers helper for this module's name.
LOG = get_loggers(__name__)


class LocationDAGError(CascadeATError):
    """Raised by LocationDAG when it is constructed with missing or inconsistent arguments."""


class LocationDAG:
    """
    Location hierarchy stored as a networkx ``DiGraph`` (``self.dag``).

    Nodes are integer location IDs carrying every column of the source row
    as node attributes; edges point from parent to child. The root location
    ID is stored in ``self.dag.graph["root"]``.
    """

    def __init__(self, location_set_version_id: Optional[int] = None,
                 gbd_round_id: Optional[int] = None,
                 df: Optional[pd.DataFrame] = None,
                 root: Optional[int] = None):
        """
        Create a location DAG from the GBD location hierarchy, using networkx graph
        where each node is the location ID, and its properties are all properties
        from db_queries. When the hierarchy is pulled from the database, the root
        of this dag is the global location ID.

        Parameters
        ----------
        location_set_version_id
            The location set version corresponding to the hierarchy
            to pull from the IHME databases
        gbd_round_id
            Which gbd round the location set version is coming from
        df
            An optional df to pass instead of location sets and gbd rounds
            if you'd rather construct the DAG from a pandas data frame.
            It is read through its ``location_id`` and ``parent_id`` columns.
        root
            The location ID to use as the root of the DAG. Required when ``df``
            is passed; ignored (replaced by the global location ID) when the
            hierarchy is pulled from the database.

        Raises
        ------
        LocationDAGError
            If neither ``df`` nor ``location_set_version_id`` is given, if a
            ``location_set_version_id`` is given without a ``gbd_round_id``,
            or if ``df`` is given without a ``root``.
        """
        if df is None:
            if location_set_version_id is None:
                raise LocationDAGError("Need to either pass a location set version ID or a data frame.")
            if gbd_round_id is None:
                raise LocationDAGError(
                    "You must pass a GBD round ID with a location set "
                    "version ID to the db_queries function."
                )
            LOG.info(
                f"Creating a location DAG for location_set_version_id "
                f"{location_set_version_id}")
            self.location_set_version_id = location_set_version_id
            self.df = db_queries.get_location_metadata(
                location_set_version_id=location_set_version_id,
                location_set_id=CascadeConstants.ESTIMATION_LOCATION_HIERARCHY_ID,
                gbd_round_id=gbd_round_id
            )
            root = CascadeConstants.GLOBAL_LOCATION_ID
        else:
            if location_set_version_id is not None:
                # Logger.warn is a deprecated alias; warning() is the supported
                # spelling (assumes LOG is a standard logging.Logger).
                LOG.warning(
                    "You passed both a location set version ID and a dataframe. "
                    "Ignoring the location set version ID and creating the DAG based "
                    "on the data frame."
                )
            if root is None:
                raise LocationDAGError("Must specify a root if you're passing in a data frame.")
            # The version ID is ignored on this path; define the attribute anyway
            # so that it exists regardless of how the DAG was constructed.
            self.location_set_version_id = None
            self.df = df

        self.dag = nx.DiGraph()
        # Every column of a row becomes an attribute of that location's node.
        for _, row in self.df.iterrows():
            self.dag.add_node(int(row['location_id']), **row.to_dict())
        # The root row is skipped here so that no edge points into the root.
        self.dag.add_edges_from([
            (int(row.parent_id), int(row.location_id))
            for row in self.df.loc[self.df.location_id != root].itertuples()
        ])
        self.dag.graph["root"] = root

    def depth(self, location_id: int) -> int:
        """
        Gets the depth of the hierarchy at this location, i.e. the length of
        the shortest path from the root to ``location_id``.
        """
        return nx.shortest_path_length(G=self.dag, source=self.dag.graph["root"], target=location_id)

    def descendants(self, location_id: int) -> Set[int]:
        """
        Gets all descendants (not just direct children) for a location ID.

        :param location_id: (int)
        :return: whatever ``networkx.algorithms.dag.descendants`` returns for
            this node (a set of location IDs); wrap it in ``list`` if a list
            is needed.
        """
        return nx.algorithms.dag.descendants(G=self.dag, source=location_id)

    def children(self, location_id: int) -> List[int]:
        """
        Gets the child location IDs.
        """
        return list(self.dag.successors(location_id))

    def parent_children(self, location_id: int) -> List[int]:
        """
        Gets the parent and the child location IDs, with the parent
        (``location_id`` itself) first.
        """
        return [location_id, *self.dag.successors(location_id)]

    def is_leaf(self, location_id: int) -> bool:
        """
        Checks if a location is a leaf node in the tree (has no successors).
        """
        return not list(self.dag.successors(location_id))

    def to_dataframe(self) -> pd.DataFrame:
        """
        Converts the location DAG to a data frame with location ID and parent ID and name.
        Helpful for debugging, and putting into the dismod database.

        Returns
        -------
        pd.DataFrame
            Columns ``location_id``, ``parent_id`` and ``name``, with rows in
            lexicographical topological order. ``parent_id`` is NaN for any
            node that has no predecessor.
        """
        sorted_locations = list(nx.lexicographical_topological_sort(self.dag))
        parents = []
        names = []
        for loc in sorted_locations:
            parent = list(self.dag.predecessors(loc))
            # Only the first predecessor is reported as the parent.
            parents.append(int(parent[0]) if parent else np.nan)
            names.append(self.dag.nodes[loc]["location_name"])
        return pd.DataFrame(dict(
            location_id=sorted_locations,
            parent_id=parents,
            name=names
        ))
def locations_by_drill(drill_location_start: int, drill_location_end: List[int], dag: LocationDAG):
    """
    Work out which locations belong to a drill model.

    Parameters
    ----------
    drill_location_start
        The parent location of the drill. May be falsy, meaning no start.
    drill_location_end
        The child locations of the drill. May be falsy, meaning no
        end locations were specified.
    dag
        The location DAG used to look up descendants and children.

    Returns
    -------
    A ``(drill_locations_all, drill_locations)`` tuple:

    * start only: the start plus all of its descendants, and the start plus
      its direct children;
    * start and end: the start plus each end location and that location's
      descendants, and the start plus the end locations themselves;
    * neither: ``(None, None)``.

    Raises
    ------
    ValueError
        If end locations are given without a start location.
    """
    if not drill_location_start:
        if drill_location_end:
            raise ValueError(
                "A location_drill_start must be specified in order "
                "to perform a location drill.")
        # No start and no end: nothing to restrict.
        return None, None

    if not drill_location_end:
        message = (
            f"This is a DRILL model, so only going to pull data "
            f"associated with drill location start "
            f"{drill_location_start} and its descendants."
        )
        LOG.info(message)
        everything = [drill_location_start] + list(
            dag.descendants(location_id=drill_location_start))
        direct = list(dag.parent_children(drill_location_start))
        return everything, direct

    message = (
        f"This is a DRILL model, so only data for "
        f"{drill_location_start} (the parent) and descendents "
        f"of {drill_location_end} (the children) will be pulled."
    )
    LOG.info(message)
    everything = [drill_location_start]
    for end_location in drill_location_end:
        # Each end location is followed by its own descendants.
        everything.append(end_location)
        everything.extend(dag.descendants(location_id=end_location))
    direct = [drill_location_start] + drill_location_end
    return everything, direct