# Source code for cascade_at.saver.results_handler

"""
The results of a Cascade-AT model need to be saved to the IHME epi databases.
This module wrangles the draw files from a completed model and uploads summaries
to the epi databases for visualization in EpiViz.

Eventually, this module should be replaced by something like ``save_results_at``.
"""

import os
from pathlib import Path
import pandas as pd
from typing import List

from cascade_at.core.db import db_tools
from cascade_at.core.log import get_loggers
from cascade_at.core import CascadeATError
from cascade_at.dismod.api.dismod_extractor import ExtractorCols

LOG = get_loggers(__name__)

# Epi database tables that ResultsHandler.upload_summaries accepts as an
# upload target; any other table name is rejected with a ResultsError.
VALID_TABLES = ['model_estimate_final', 'model_estimate_fit', 'model_prior']


class UiCols:
    """
    Names of the uncertainty-interval summary columns and the quantiles
    used to compute the lower and upper bounds from draws.
    """
    # Summary column names.
    MEAN: str = 'mean'
    LOWER: str = 'lower'
    UPPER: str = 'upper'
    # Quantiles of the draws used for the LOWER / UPPER columns.
    LOWER_QUANTILE: float = 0.025
    UPPER_QUANTILE: float = 0.975


class ResultsError(CascadeATError):
    """
    Raised when there is an error with uploading or validating the results.
    """
class ResultsHandler:
    """
    Handles all of the DisMod-AT results including draw saving
    and uploading to the epi database.

    NOTE(review): every ``to_csv`` call below writes the DataFrame index as
    an extra leading column (pandas' default). Confirm that the db_tools
    Infiles loader used in ``upload_summaries`` tolerates that column.
    """

    def __init__(self):
        """
        Attributes
        ----------
        self.draw_keys
            The keys of the draw data frames
        self.summary_cols
            The columns that need to be present in all summary files
        """
        self.draw_keys: List[str] = [
            'measure_id', 'year_id', 'age_group_id',
            'location_id', 'sex_id', 'model_version_id',
        ]
        self.summary_cols: List[str] = [UiCols.MEAN, UiCols.LOWER, UiCols.UPPER]

    def _validate_results(self, df: pd.DataFrame) -> None:
        """
        Validates the input draw files. Put any additional validations here.

        Parameters
        ----------
        df
            An input data frame with draws

        Raises
        ------
        ResultsError
            If any of ``self.draw_keys`` is not a column of ``df``.
        """
        missing_cols = [x for x in self.draw_keys if x not in df.columns]
        if missing_cols:
            raise ResultsError(f"Missing id columns {missing_cols} for saving the results.")

    def _validate_summaries(self, df: pd.DataFrame) -> None:
        """
        Validates that a summary data frame has all of ``self.summary_cols``.

        Parameters
        ----------
        df
            A data frame of summaries

        Raises
        ------
        ResultsError
            If any summary column is not a column of ``df``.
        """
        missing_cols = [x for x in self.summary_cols if x not in df.columns]
        if missing_cols:
            raise ResultsError(f"Missing summary columns {missing_cols} for saving the results.")

    @staticmethod
    def _split_by_location_sex(df: pd.DataFrame, directory: Path):
        """
        Yields ``(location_id, sex_id, subset)`` for every location/sex pair
        that actually occurs in ``df``. It creates ``directory / str(location_id)``
        before yielding each subset.

        A single groupby replaces a filter for every location crossed with
        every sex in the frame. Pairs with no rows therefore produce no
        (empty) files, and the frame is scanned once. Rows whose location_id
        or sex_id is missing are dropped by the groupby. They could never
        match an equality filter anyway.
        """
        groups = df.groupby(['location_id', 'sex_id'], observed=True)
        for (loc, sex), subset in groups:
            (directory / str(loc)).mkdir(parents=True, exist_ok=True)
            yield loc, sex, subset.copy()

    def summarize_results(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Summarizes results from either mean or draw cols to get mean, upper, and lower cols.

        If ``ExtractorCols.VALUE_COL_FIT`` is a column, mean, lower and upper
        are all set to that column. Otherwise every column whose name starts
        with ``ExtractorCols.VALUE_COL_SAMPLES`` is treated as a draw. The
        mean and the ``UiCols.LOWER_QUANTILE`` / ``UiCols.UPPER_QUANTILE``
        quantiles are then taken across those columns, row by row.

        The input data frame is not modified.

        Parameters
        ----------
        df
            A data frame with draw columns or just a mean column

        Returns
        -------
        A new data frame with ``self.draw_keys`` followed by the mean, lower,
        and upper columns, indexed like ``df``.

        Raises
        ------
        KeyError
            If any of ``self.draw_keys`` is missing from ``df``.
        ResultsError
            If ``df`` has neither a fit column nor any draw columns. Without
            this check the summaries would silently be all NaN.
        """
        summary = df[self.draw_keys].copy()
        if ExtractorCols.VALUE_COL_FIT in df.columns:
            fit = df[ExtractorCols.VALUE_COL_FIT]
            summary[UiCols.MEAN] = fit
            summary[UiCols.LOWER] = fit
            summary[UiCols.UPPER] = fit
        else:
            # str() guards against non-string column labels (e.g. integers).
            draw_cols = [
                col for col in df.columns
                if str(col).startswith(ExtractorCols.VALUE_COL_SAMPLES)
            ]
            if not draw_cols:
                raise ResultsError(
                    f"No '{ExtractorCols.VALUE_COL_FIT}' column and no draw columns "
                    f"starting with '{ExtractorCols.VALUE_COL_SAMPLES}' to summarize."
                )
            draws = df[draw_cols]
            summary[UiCols.MEAN] = draws.mean(axis=1)
            summary[UiCols.LOWER] = draws.quantile(q=UiCols.LOWER_QUANTILE, axis=1)
            summary[UiCols.UPPER] = draws.quantile(q=UiCols.UPPER_QUANTILE, axis=1)
        return summary

    def save_draw_files(self, df: pd.DataFrame, model_version_id: int,
                        directory: Path, add_summaries: bool) -> None:
        """
        Saves a data frame by location and sex in .csv files.
        This currently saves the summaries, but when we get
        save_results working it will save draws and then
        summaries as part of that.

        Each file is written to ``directory/<location_id>/<location_id>_<sex_id>.csv``.
        Only location/sex pairs present in ``df`` are written. The caller's
        ``df`` is not modified.

        Parameters
        ----------
        df
            Data frame with the following columns:
            ['location_id', 'year_id', 'age_group_id', 'sex_id',
            'measure_id', 'mean' OR 'draw']
        model_version_id
            The model version to attach to the data
        directory
            Path to save the files to
        add_summaries
            Save an additional file with summaries to upload

        Raises
        ------
        ResultsError
            If required id columns are missing, or summaries are requested
            but there is nothing to summarize.
        """
        LOG.info(f"Saving results to {directory.absolute()}")
        # assign() returns a new frame, so the caller's df is left untouched.
        df = df.assign(model_version_id=model_version_id)
        self._validate_results(df=df)

        for loc, sex, subset in self._split_by_location_sex(df=df, directory=directory):
            subset.to_csv(directory / str(loc) / f'{loc}_{sex}.csv')
            if add_summaries:
                summary = self.summarize_results(df=subset)
                self.save_summary_files(
                    df=summary,
                    model_version_id=model_version_id,
                    directory=directory
                )

    def save_summary_files(self, df: pd.DataFrame, model_version_id: int,
                           directory: Path) -> None:
        """
        Saves a data frame with summaries by location and sex in summary.csv files.

        Each file is written to
        ``directory/<location_id>/<location_id>_<sex_id>_summary.csv``. That
        file name matches the ``*summary.csv`` glob used by ``upload_summaries``.
        The caller's ``df`` is not modified.

        Parameters
        ----------
        df
            Data frame with the following columns:
            ['location_id', 'year_id', 'age_group_id', 'sex_id',
            'measure_id', 'mean', 'lower', and 'upper']
        model_version_id
            The model version to attach to the data
        directory
            Path to save the files to

        Raises
        ------
        ResultsError
            If required id or summary columns are missing.
        """
        LOG.info(f"Saving results to {directory.absolute()}")
        # assign() returns a new frame. This avoids mutating the caller's df,
        # which may be a column slice of another frame.
        df = df.assign(model_version_id=model_version_id)
        self._validate_results(df=df)
        self._validate_summaries(df=df)

        for loc, sex, subset in self._split_by_location_sex(df=df, directory=directory):
            subset.to_csv(directory / str(loc) / f'{loc}_{sex}_summary.csv')

    @staticmethod
    def upload_summaries(directory: Path, conn_def: str, table: str) -> None:
        """
        Uploads every ``*summary.csv`` file found one level below ``directory``
        to ``table`` in the ``epi`` schema of the database specified by the
        ``conn_def`` argument.

        In the future, this will probably be replaced by save_results_dismod
        but we don't have draws to work with so we're just uploading summaries
        for now directly.

        Parameters
        ----------
        directory
            Directory where files are saved
        conn_def
            Connection to a database to be used with db_tools.ezfuncs
        table
            Which table to upload to; must be one of ``VALID_TABLES``

        Raises
        ------
        ResultsError
            If ``table`` is not in ``VALID_TABLES``.
        """
        if table not in VALID_TABLES:
            raise ResultsError("Don't know how to upload to table "
                               f"{table}. Valid tables are {VALID_TABLES}.")

        session = db_tools.ezfuncs.get_session(conn_def=conn_def)
        loader = db_tools.loaders.Infiles(table=table, schema='epi', session=session)

        generic_file = (directory / '*' / '*summary.csv').absolute()
        LOG.info(f"Loading all files to {conn_def} that match {generic_file} glob.")
        loader.indir(path=str(generic_file), commit=True, with_replace=True)