Tax-Analyzer-Framework

test6.py
Login

File taf/dptests/test6.py from the latest check-in


"""
Test SampleReWeighting class stratified_sample method using private
Malaysia PIT tax return data where the income strata boundaries and
the strata sampling probabilities are defined in the MYI-Tax-Analyzer
"Tax Data Preparation Overview" document, which is at the "tax data
preparation" link on the model's homepage:
https://chiselapp.com/user/WBG/repository/MYI-Tax-Analyzer

If you are just looking for how to use the SampleReWeighting class
stratified_sample method, look at the generate_data_sample function
below.

Note that the stratified sample actually used by the MYI-Tax-Analyzer
was constructed before the SampleReWeighting class was developed and
added to the Tax-Analyzer-Framework, so the stratified samples drawn
here are not exactly the same as the one used in the operation of the
MYI-Tax-Analyzer microsimulation model.

Also, note that this test requires access to the myitaxanalyzer package
and to the full 2018 tax return file (all.csv) in order to generate
aggregate sample results from each stratified sample that is drawn.
"""

import sys
import time
import argparse
import subprocess
import taf  # Tax-Analyzer-Framework

# Year of the tax return data; also passed to each myita model run.
DATA_YEAR = 2018
# Complete tax return data, and the file each drawn sample is written to.
ALL_DATA_FILENAME = 'all.csv'
SAMPLE_DATA_FILENAME = 'smp.csv'
# Name of the variable whose value assigns each record to a sampling stratum.
STRATUM_INCOME = 'gross_income'
# Stratum boundaries and the sampling probability used in each stratum.
# A single-stratum alternative would use edges [-9e99, 9e99] together
# with probabilities [0.09590642].
STRATUM_EDGES = [-9.0e99, 1.0e5, 5.0e5, 9.0e99]
STRATUM_PROBS = [0.05, 0.25, 1.0]
# One random-number seed per stratified sample that can be drawn; the
# number of seeds is the maximum accepted --num_samples value.
SAMPLING_SEEDS = [
    123456789, 234567891, 345678912,
    456789123, 567891234, 678912345,
    789123456, 891234567, 912345678,
]
# When True, main prints its total execution time.
DO_TIMING = False


def main(num_samples):
    """
    Compare model results from the complete data with sample results.

    Runs MYI-Tax-Analyzer once on the complete tax return data and then
    once on each of num_samples stratified samples (using the first
    num_samples entries of SAMPLING_SEEDS).  Afterwards it prints one
    line per run: the run index, the first extracted statistic in
    brackets, and the remaining extracted statistics.  Returns zero.
    """
    started = time.time()

    def clp_output_name(csv_name):
        # the .tab output file is named after the input file with its
        # last four characters (the '.csv' extension) removed
        return f'{csv_name[:-4]}-{DATA_YEAR % 100}-clp.tab'

    alldf = taf.csv2df(ALL_DATA_FILENAME)
    assert alldf.shape == (3860847, 42)
    # SampleReWeighting needs a weights variable (already present in the
    # complete data, equal to one for every record) plus the variable
    # that defines the sampling strata, which is built here.
    alldf[STRATUM_INCOME] = (
        alldf['businc'] + alldf['empinc'] + alldf['rentinc'] +
        alldf['intinc'] + alldf['pioneerinc'] + alldf['totinc_transfer']
    )

    results = []
    # first run: the complete data
    execute_model_run(ALL_DATA_FILENAME)
    extract_run_statistics(clp_output_name(ALL_DATA_FILENAME), results)
    # subsequent runs: one freshly drawn sample per seed
    sample_output = clp_output_name(SAMPLE_DATA_FILENAME)
    for seed in SAMPLING_SEEDS[:num_samples]:
        generate_data_sample(alldf, seed, SAMPLE_DATA_FILENAME)
        execute_model_run(SAMPLE_DATA_FILENAME)
        extract_run_statistics(sample_output, results)

    # one output row per run; values of 10 or more get a wider field
    for run_index, (size, *amounts) in enumerate(results):
        pieces = [f'{run_index:2d} ', f'[{size:4.2f}] ']
        pieces.extend(
            f'{amt:7.2f}' if amt >= 10.0 else f'{amt:6.2f}'
            for amt in amounts
        )
        pieces.append('\n')
        sys.stdout.write(''.join(pieces))

    if DO_TIMING:
        print(f'execution_time= {(time.time() - started):.2f} secs')
    return 0
# end of main function code


def generate_data_sample(fulldf, seed, filename):
    """
    Draw a stratified random sample from fulldf and write it to filename.

    The sample is drawn by taf.SampleReWeighting.stratified_sample using
    the STRATUM_INCOME variable with the STRATUM_EDGES boundaries and the
    STRATUM_PROBS sampling probabilities; seed is passed to it as rnseed.
    The STRATUM_INCOME helper variable is removed before the sample is
    written with taf.df2csv, so it does not appear in the written file.
    fulldf itself is not modified by this function.
    """
    srw = taf.SampleReWeighting()
    smpdf = srw.stratified_sample(
        fulldf,
        STRATUM_INCOME,
        STRATUM_EDGES,
        STRATUM_PROBS,
        rnseed=seed,
        verbose=False
    )
    # Drop the helper variable by rebinding rather than with inplace=True:
    # the returned frame may be derived from fulldf, and in-place mutation
    # of such a frame can trigger pandas' SettingWithCopyWarning.
    smpdf = smpdf.drop(columns=[STRATUM_INCOME])
    taf.df2csv(smpdf, filename)
# end of generate_data_sample function code


def execute_model_run(data_filename):
    """
    Execute a MYI-Tax-Analyzer run using the specified input data_filename.

    Runs the myita command for DATA_YEAR with the clp.json file and the
    --nodump, --noparam and --silent options.  Because check=True is used,
    subprocess.CalledProcessError is raised if myita exits with a nonzero
    status.
    """
    # Build the argument list explicitly instead of splitting a formatted
    # string on whitespace, so a data_filename containing spaces is still
    # passed to myita as a single argument.
    cmd = [
        'myita', data_filename, str(DATA_YEAR),
        'clp.json', '--nodump', '--noparam', '--silent',
    ]
    subprocess.run(cmd, check=True)
# end of execute_model_run function code


def extract_run_statistics(output_filename, run_stats):
    """
    Append one list of run statistics read from output_filename to run_stats.

    The appended list holds twelve floats.  The first is the value in the
    second whitespace-separated column of line 13 (zero-based), which main
    prints as the sample size.  It is followed by the fourth-column values
    of lines 3 through 13: the decile PIT liabilities, then the line-13
    value (presumably the all-deciles total -- TODO confirm against the
    .tab layout written by myita).

    Raises ValueError if output_filename has fewer than 14 lines, instead
    of failing later with an uninformative IndexError.
    """
    first_row = 3   # first decile row in the .tab output
    last_row = 13   # final row read; also holds the sample-size value
    with open(output_filename, 'r', encoding='utf-8') as outfile:
        lines = outfile.read().splitlines()
    if len(lines) <= last_row:
        raise ValueError(
            f'{output_filename} has {len(lines)} lines; '
            f'expected at least {last_row + 1}'
        )
    stats = [float(lines[last_row].split()[1])]
    stats.extend(
        float(line.split()[3])
        for line in lines[first_row:last_row + 1]
    )
    run_stats.append(stats)
# end of extract_run_statistics function code


def process_command_line_arguments():
    """
    Parse the optional --num_samples command-line argument.

    Returns {'num_samples': NUM} when NUM lies in the range from one to
    the number of SAMPLING_SEEDS.  Otherwise writes the error message(s)
    and a usage line to stderr and returns an empty dictionary.
    """
    usage_str = 'python test6.py [--num_samples NUM] [--help]'
    parser = argparse.ArgumentParser(
        prog='',
        usage=usage_str,
        description=('Run test6 using the specified value of the '
                     'one optional command-line parameter.')
    )
    parser.add_argument(
        '--num_samples', metavar='NUM', type=int, default=1,
        help=('option that specifies number of stratified samples to draw '
              '[default=1]')
    )
    num = parser.parse_args().num_samples
    max_num = len(SAMPLING_SEEDS)
    # collect every problem first, then report them in order
    problems = []
    if num < 1:
        problems.append('ERROR: NUM must be at least one\n')
    if num > max_num:
        problems.append(f'ERROR: NUM not in [1,{max_num}] range\n')
    if not problems:
        return {'num_samples': num}
    for message in problems:
        sys.stderr.write(message)
    sys.stderr.write(f'USAGE: {usage_str}\n')
    return {}


if __name__ == '__main__':
    # an empty dict signals invalid arguments, giving exit status one
    cli_args = process_command_line_arguments()
    sys.exit(main(cli_args['num_samples']) if cli_args else 1)