Twinify example: NapsuMQ usage#

%%capture

import jax.numpy as jnp
import numpy as np
import pandas as pd
import d3p.random
from twinify.napsu_mq.napsu_mq import NapsuMQModel
from twinify.napsu_mq import rubins_rules
from tests.napsu_mq.binary_logistic_regression_generator import BinaryLogisticRegressionDataGenerator
import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.api as sm

# d3p wraps a cryptographically secure RNG. Split the key so that data
# generation and inference use independent randomness, and convert the
# data-generation key to a plain JAX key for the (non-private) generator.
rng = d3p.random.PRNGKey(42526709)
data_gen_rng, inference_rng = d3p.random.split(rng)
data_gen_rng = d3p.random.convert_to_jax_rng_key(data_gen_rng)

Prepare dataset#

Generate 10,000 rows of data for a logistic regression task with true coefficients (1, 3, -4, 2, 1). Each row contains five binary features and a binary label, so the resulting dataframe has six columns.

coefficients = [1.0, 3.0, -4.0, 2.0, 1.0]
dataset = BinaryLogisticRegressionDataGenerator(jnp.array(coefficients)).generate_data(n=10000, rng_key=data_gen_rng)
orig_df = pd.DataFrame(dataset, dtype="int")
n, d = orig_df.shape
orig_df
      0  1  2  3  4  5
0     1  0  1  1  1  1
1     1  0  1  0  0  0
2     1  1  0  0  1  1
3     1  1  1  0  0  0
4     1  0  0  1  0  1
...  .. .. .. .. .. ..
9995  1  0  0  1  0  1
9996  1  1  1  0  1  0
9997  1  1  0  0  0  1
9998  1  1  0  0  1  1
9999  1  1  0  1  1  1

10000 rows × 6 columns
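
Later we ask the model to preserve specific pairwise marginals, so it helps to see what such a marginal looks like in the raw data. A quick sanity-check sketch (not part of the original example) tabulating the joint distribution of columns 0 and 5 with plain pandas:

# Empirical pairwise marginal of columns 0 and 5 in the original data;
# NapsuMQ will later be asked to preserve marginals like this one.
print(pd.crosstab(orig_df[0], orig_df[5], normalize=True))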

Create a NapsuMQModel and fit the probabilistic model with DP#

# Fitting the model takes a few minutes

rng = d3p.random.PRNGKey(74249069)
inference_rng, sampling_rng = d3p.random.split(rng)

# We can specify pairwise column marginals that the model must preserve
required_marginals = [(0, 1), (0, 5), (3, 5)]

model = NapsuMQModel(required_marginals=required_marginals, use_laplace_approximation=False)
result = model.fit(
    data=orig_df,
    rng=inference_rng,
    epsilon=1,          # DP privacy budget
    delta=(n ** (-2))   # common heuristic: delta well below 1/n
)
sample: 100%|██████████| 2800/2800 [00:27<00:00, 101.81it/s, 31 steps of size 1.95e-01. acc. prob=0.90] 
sample: 100%|██████████| 2800/2800 [00:02<00:00, 1339.43it/s, 31 steps of size 1.56e-01. acc. prob=0.94]
sample: 100%|██████████| 2800/2800 [00:02<00:00, 1373.09it/s, 31 steps of size 1.72e-01. acc. prob=0.92]
sample: 100%|██████████| 2800/2800 [00:01<00:00, 1406.52it/s, 15 steps of size 1.89e-01. acc. prob=0.90]
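
Fitting is the expensive step, so it is worth persisting the result. A minimal sketch, assuming NapsuMQResult lives alongside NapsuMQModel and implements the store/load interface of twinify's inference results (check the exact class location and signatures against the twinify documentation):

# Assumption: NapsuMQResult supports store()/load() like other
# twinify inference results.
from twinify.napsu_mq.napsu_mq import NapsuMQResult

result.store("napsu_result.dill")                 # persist the fitted result
result = NapsuMQResult.load("napsu_result.dill")  # restore it in a later session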

Create synthetic datasets using the probabilistic model#

# Draw 10 parameter samples from the posterior and generate 10,000 rows
# from each, returning a list of 10 synthetic dataframes
synthetic_dfs = result.generate(
    rng=sampling_rng,
    num_data_per_parameter_sample=10000,
    num_parameter_samples=10,
    single_dataframe=False
)
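
The rubins_rules and statsmodels imports above hint at the intended downstream analysis: fit the logistic regression on each synthetic dataset and pool the estimates across datasets with Rubin's rules. As an illustration, the following sketch implements the standard pooling formulas inline with numpy and statsmodels instead of calling the rubins_rules helper; it assumes the label is the last column, as in the data generator, and fits without an intercept since the generating model has none.

# Fit the regression on each synthetic dataset, then pool with Rubin's rules:
# q_bar = mean estimate, W = mean within-dataset variance,
# B = between-dataset variance, total variance T = W + (1 + 1/m) * B.
estimates, variances = [], []
for df in synthetic_dfs:
    X = df.iloc[:, :-1].astype(float)   # assumed: features in the first five columns
    y = df.iloc[:, -1].astype(float)    # assumed: label in the last column
    fit = sm.Logit(y, X).fit(disp=0)
    estimates.append(fit.params.values)
    variances.append(fit.bse.values ** 2)

m = len(estimates)
q_bar = np.mean(estimates, axis=0)
W = np.mean(variances, axis=0)
B = np.var(estimates, axis=0, ddof=1)
T = W + (1 + 1 / m) * B
print("pooled coefficients:", q_bar)
print("pooled standard errors:", np.sqrt(T))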

Plot and compare the results#

def value_count_list(df, col, label):
    """Return (value, count, label) triples for one column of df."""
    counts = df[col].value_counts()
    return [(value, counts[value], label) for value in counts.index]

def value_count_df(orig_df, synthetic_dfs, col):
    """Stack value counts of a column from the original and all synthetic dataframes."""
    vl_list = value_count_list(orig_df, col, "Original")
    for synthetic_df in synthetic_dfs:
        vl_list += value_count_list(synthetic_df, col, "Synthetic")

    return pd.DataFrame.from_records(vl_list, columns=["Value", "Count", "Type"])

value_count_dfs = { col: value_count_df(orig_df, synthetic_dfs, col) for col in orig_df.columns }

fig, axes = plt.subplots(1, len(value_count_dfs), figsize=(3 * len(value_count_dfs), 5.5))
for i, (key, value) in enumerate(value_count_dfs.items()):
    ax = axes[i]
    ax.set_title(key)
    ax.tick_params("x", labelrotation=90)
    # estimator=np.mean averages the counts over the 10 synthetic dataframes
    sns.barplot(x="Value", y="Count", data=value, hue="Type", estimator=np.mean, ax=ax)
    if i < len(value_count_dfs) - 1:
        ax.get_legend().remove()
    else:
        ax.legend(loc="upper left", bbox_to_anchor=(1.00, 1))

plt.tight_layout()
plt.show()
[Figure: per-column bar plots comparing value counts in the original data with mean value counts over the synthetic datasets]
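
Beyond the visual comparison, a quick numeric check is the total variation distance (TVD) between the original and synthetic per-column marginals. A small sketch over the pooled synthetic data (not part of the original example):

# TVD between original and pooled-synthetic marginals per column;
# 0 means identical marginal distributions, 1 means disjoint support.
pooled = pd.concat(synthetic_dfs, ignore_index=True)
for col in orig_df.columns:
    p = orig_df[col].value_counts(normalize=True)
    q = pooled[col].value_counts(normalize=True)
    tvd = 0.5 * p.subtract(q, fill_value=0).abs().sum()
    print(f"column {col}: TVD = {tvd:.4f}")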