Tax-Analyzer-Framework

srw.py
Login

File taf/srw.py from the latest check-in


"""
SampleReWeighting class implements methods that reweight sample data.
One method can transform typical variable-weight sample survey data
into a much larger uniform-weight sample in which each observation has
the same weight, but that is statistically equivalent to the original
survey data.  Another method can randomly draw a stratified sample from
a given data set so that the number of observations is reduced, but
that the stratified sample is statistically equivalent to the original
data.
"""

import numpy as np
import pandas as pd


class SampleReWeighting:
    """
    Reweight sample data held in a pandas.DataFrame.

    The uniform_weight_sample method expands a variable-weight sample into
    a larger sample whose observations all share the same weight, and the
    stratified_sample method draws a smaller stratified random sample from
    uniform-weight data, scaling up the weights of the drawn records.
    """
    def __init__(
            self,
            weight_variable_name='weight'  # name in pandas.DataFrame
    ):
        """
        SampleReWeighting class constructor.

        weight_variable_name: name of the DataFrame column that holds the
            sample weight in the data passed to this object's methods.

        Raises AssertionError if weight_variable_name is not a str.
        """
        # process arguments
        assert isinstance(weight_variable_name, str), \
            'weight_variable_name must be a string'
        self.wght_vname = weight_variable_name

    def uniform_weight_sample(
            self,
            vwdf,  # variable-weight pandas.DataFrame
            uwght,  # target uniform weight expressed as an integer
            verbose=False  # whether or not to report on low weights
    ):
        """
        Return a uniform-weight pandas.DataFrame expanded from vwdf.

        Each row of vwdf is replicated round(weight / uwght) times (np.round
        rounds halves to even), but never fewer than once, so low-weight rows
        are kept as a single copy.  Every row of the result then gets the
        same weight: the sum of the vwdf weights divided by the number of
        result rows, rounded to two decimals.  The result's weight total
        therefore matches the vwdf total up to that rounding.  The result has
        more rows than vwdf but is statistically similar to it.

        vwdf: variable-weight data containing the weight column named by
            weight_variable_name
        uwght: positive int target weight; must be less than the mean
            vwdf weight
        verbose: when True, print a warning if any rows have a weight no
            greater than 0.5*uwght (those rows are kept as one copy)

        Returns a new DataFrame with the vwdf columns in the same order, a
        fresh 0..n-1 index, and column dtypes re-inferred by
        DataFrame.convert_dtypes.

        Raises AssertionError when an argument check fails.
        """
        wvar = self.wght_vname
        # check arguments
        assert isinstance(vwdf, pd.DataFrame), \
            'vwdf must be a pandas.DataFrame'
        assert wvar in vwdf, \
            f'weight_variable_name={wvar} not in vwdf'
        assert isinstance(uwght, int) and uwght > 0, \
            f'uwght={uwght} must be a positive integer'
        vwmean = vwdf[wvar].mean()
        assert uwght < vwmean, \
            f'uwght={uwght} >= vwdf_mean_weight={vwmean}'
        dfcols = list(vwdf)
        # Read the weights as float64 directly from the weight column.
        # vwdf.to_numpy() yields an object array when vwdf has string or
        # nullable-dtype columns, and np.round cannot round object arrays.
        vweights = vwdf[wvar].to_numpy(dtype=float, na_value=np.nan)
        # compute number of times each observation will be replicated
        ratio = vweights / uwght
        # always retain one copy of low-weight observations
        reps = np.clip(np.round(ratio).astype('int'), 1, None)
        if verbose:
            lownum = np.count_nonzero(ratio <= 0.50)
            if lownum > 0:
                print(('WARNING: uniform_weight_sample function is retaining '
                       f'{lownum} observations\n         with weights no '
                       f'greater than (0.5*uwght)={(0.5 * uwght):.1f}\n'))
        # construct uniform-weight sample by replicating vwdf observations
        uwx = np.repeat(vwdf.to_numpy(copy=True), reps, axis=0)
        uwdf = pd.DataFrame(uwx, columns=dfcols)
        # Give every replicated row the same weight, chosen so that the sum
        # of the uniform weights equals the sum of the variable weights
        # (up to the two-decimal rounding).
        exact_uwght = vwdf[wvar].sum() / reps.sum()
        uwdf[wvar] = np.round(exact_uwght, 2)
        uwdf = uwdf.convert_dtypes()
        return uwdf

    def stratified_sample(
            self,
            odf,  # original data in a pandas.DataFrame
            strat_variable_name,  # name of variable used to define strata
            strat_edges,  # list of values defining the sampling strata
            strat_probs,  # list of stratum sampling probabilities
            rnseed=123456789,  # random number seed used to conduct sampling
            verbose=False
    ):
        """
        Return a stratified random sample drawn from uniform-weight odf data.

        Stratum k (k = 1, 2, ...) holds the odf rows whose
        strat_variable_name value lies in the half-open interval
        (strat_edges[k-1], strat_edges[k]].  Rows outside
        (strat_edges[0], strat_edges[-1]] belong to no stratum and are
        excluded from the result.  From stratum k, a fraction
        strat_probs[k-1] of the rows is drawn with DataFrame.sample, using
        random seed rnseed + 1000*(k-1).  The weight of each drawn row is
        then divided by that sampling probability.

        odf: data with a uniform weight column named by weight_variable_name
        strat_variable_name: column whose values define the strata
        strat_edges: list of two or more strictly increasing stratum edges
        strat_probs: list of per-stratum sampling probabilities in (0, 1],
            with one fewer element than strat_edges
        rnseed: base random seed in [1, 999999999]
        verbose: when True, print per-stratum and total record counts

        Returns a new DataFrame containing the stratum samples concatenated
        in stratum order, with a fresh 0..n-1 index.

        Raises AssertionError when an argument check fails.
        """
        # pylint: disable=too-many-arguments,too-many-locals
        wvar = self.wght_vname
        svar = strat_variable_name
        # check arguments
        assert isinstance(odf, pd.DataFrame), 'odf must be a pandas.DataFrame'
        assert wvar in odf, f'weight_variable={wvar} not in odf'
        assert svar in odf, f'stratification variable {svar} not in odf'
        assert isinstance(strat_edges, list), 'strat_edges is not a list'
        assert len(strat_edges) >= 2, 'strat_edges must have two or more edges'
        assert isinstance(strat_probs, list), 'strat_probs is not a list'
        assert len(strat_probs) == len(strat_edges)-1, \
            'number of strat_probs must be one less than number of strat_edges'
        assert 1 <= rnseed <= 999999999, \
            f'rnseed={rnseed} not in [1,999999999] range'
        # check that odf has a uniform weight
        assert np.allclose(
            [odf[wvar].min()],
            [odf[wvar].max()]
        ), f'odf does not have a uniform weight named {wvar}'
        # draw stratified sample from odf
        # ... loop through strata drawing the stratum sample
        smsg = ('stratum {} contains {:7d} records and '
                'uses sampling prob {:.4f}')
        sdfs = []
        for stratum, sprob in enumerate(strat_probs, start=1):
            loedge = strat_edges[stratum-1]
            hiedge = strat_edges[stratum]
            assert hiedge > loedge, \
                f'hiedge={hiedge} must be greater than loedge={loedge}'
            assert 0.0 < sprob <= 1.0, \
                f'sprob={sprob:.4f} not in (0,1] range for stratum {stratum}'
            strat_df = odf[(loedge < odf[svar]) & (odf[svar] <= hiedge)]
            if verbose:
                print(smsg.format(stratum, strat_df.shape[0], sprob))
            # a distinct, reproducible seed for each stratum
            seed = rnseed + 1000 * (stratum - 1)
            sample_df = strat_df.sample(frac=sprob, random_state=seed)
            # Scale up the configured weight column (not a hard-coded name)
            # by 1/sprob, so each sampled record also represents the
            # unsampled records of its stratum.
            sample_df[wvar] = sample_df[wvar] / sprob
            if verbose:
                print(f'   included records {sample_df.shape[0]:7d}')
            sdfs.append(sample_df)
        # ... concatenate the stratum samples into a pandas.DataFrame
        sdf = pd.concat(sdfs, ignore_index=True)
        if verbose:
            print(f'total number of records {sdf.shape[0]:7d}')
        return sdf