diff --git a/drdid/ipwd_did.py b/drdid/ipwd_did.py
index 44d5fda..9b5f43b 100644
--- a/drdid/ipwd_did.py
+++ b/drdid/ipwd_did.py
@@ -1,104 +1,152 @@
from .utils import *
-def std_ipw_did_panel(
- y1: ndarray, y0: ndarray, D: ndarray,
- covariates, i_weights, boot = False):
-
- n = len(D)
- delta_y = y1 - y0
-
- int_cov = has_intercept(int_cov)
- i_weights = has_weights(i_weights)
-
- _, _, w_cont, _, _, asy_lin_rep_ps =\
- fit_ps(D, int_cov, i_weights)
-
- w_treat = i_weights * D
-
- att_treat = w_treat * delta_y
- att_cont = w_cont * delta_y
-
- eta_treat = np.mean(att_treat) / np.mean(w_treat)
- eta_cont = np.mean(att_cond) / np.mean(w_cont)
-
- ipw_att = eta_treat - eta_cont
-
- inf_treat = (att_treat - w_treat * eta_treat) / np.mean(w_treat)
- inf_cont_1 = att_cont - w_cont * eta_cont
- w_ref = w_cont * (delta_y - eta_cont)
- M2 = np.mean(w_ref[:, n_x] * int_cov, axis=0)
- inf_control = asy_lin_rep_ps * M2
-
- att_inf_func = inf_treat - inf_control
-
- se_att = None
- if not boot:
- se_att = np.std(att_inf_func) / np.sqrt(n)
-
- return (ipw_att, att_inf_func, se_att)
-# ----------------- RC
-
-def std_ipw_did_rc(
- y: ndarray, post: ndarray, D: ndarray, covariates = None, i_weights = None
- , boot = False):
-
- n = len(D)
- post1 = (1 - post)
- d1 = (1 - D)
-
- int_cov = has_intercept(covariates, n)
- i_weights = has_weights(i_weights, n)
-
- ps_fit, w_cont_pre, _, w_cont_post, _, asy_lin_rep_ps = \
- fit_ps(D, int_cov, i_weights, post)
-
- w_i = i_weights * D
- w_treat_pre = w_i * post1
- w_treat_post = w_i * post
-
- diff_w = i_weights * ps_fit * (1 - D)
-
- def eta_form(vect, y_ = y):
- return vect * y_ / np.mean(vect)
-
- eta_treat_pre = eta_form(w_treat_pre)
- eta_treat_post = eta_form(w_treat_post)
- eta_cont_pre = eta_form(w_cont_pre)
- eta_cont_post = eta_form(w_cont_post)
-
- att_treat_pre = np.mean(eta_treat_pre)
- att_treat_post = np.mean(eta_treat_post)
- att_cont_pre = np.mean(eta_cont_pre)
- att_cont_post = np.mean(eta_cont_post)
-
- ipw_att = att_treat_post - att_treat_pre - (att_cont_post - att_cont_pre)
-
- inf_treat_pre = eta_treat_pre - \
- w_treat_pre * att_treat_pre / np.mean(w_treat_pre)
- inf_treat_post = eta_treat_post -\
- w_treat_post * att_treat_post / np.mean(w_treat_post)
- inf_treat = inf_treat_post - inf_treat_pre
-
- inf_cont_pre = eta_cont_pre - \
- w_cont_pre * att_cont_pre / np.mean(w_cont_pre)
- inf_cont_post = eta_cont_post -\
- w_cont_post * att_cont_post / np.mean(w_cont_post)
+def std_ipw_did_panel(y1, y0, D, covariates, i_weights = None):
+ D = np.asarray(D).flatten()
+ n = len(D)
+ delta_y = np.asarray(y1 - y0).flatten()
+ int_cov = np.ones((n, 1))
+
+ if covariates is not None:
+ covariates = np.asarray(covariates)
+ if np.all(covariates[:, 0] == 1):
+ int_cov = covariates
+ else:
+ int_cov = np.column_stack((np.ones(n), covariates))
+
+ if i_weights is None:
+ i_weights = np.ones(n)
+ elif np.min(i_weights) < 0:
+ raise ValueError("i_weights must be non-negative")
+
+ i_weights = i_weights / np.mean(i_weights)
+ pscore_model = sm.GLM(D, int_cov, family=sm.families.Binomial(), freq_weights=i_weights)
+ pscore_results = pscore_model.fit()
+ # print(D.mean())
+ # print(pscore_results.summary2())
+ if not pscore_results.converged:
+ print("Warning: glm algorithm did not converge")
+ if np.any(np.isnan(pscore_results.params)):
+ raise ValueError("Propensity score model coefficients have NA components. \n Multicollinearity (or lack of variation) of covariates is a likely reason.")
+ ps_fit = pscore_results.predict()
+ ps_fit = np.minimum(ps_fit, 1 - 1e-16)
+
+ w_treat = i_weights * D
+ w_cont = i_weights * ps_fit * (1 - D) / (1 - ps_fit)
+
+ att_treat = w_treat * delta_y
+ att_cont = w_cont * delta_y
+
+ eta_treat = mean(att_treat) / mean(w_treat)
+ eta_cont = mean(att_cont) / mean(w_cont)
+
+ ipw_att = eta_treat - eta_cont
+
+ score_ps = i_weights[:, np.newaxis] * (D - ps_fit)[:, np.newaxis] * int_cov
+ Hessian_ps = pscore_results.cov_params() * n
+ asy_lin_rep_ps = np.dot(score_ps, Hessian_ps)
+
+ inf_treat = (att_treat - w_treat * eta_treat) / mean(w_treat)
+ inf_cont_1 = att_cont - w_cont * eta_cont
+ pre_m2 = w_cont * (delta_y - eta_cont)
+ M2 = np.mean(pre_m2[:, np.newaxis] * int_cov, axis = 0)
+ print(M2)
+ inf_cont_2 = np.dot(asy_lin_rep_ps, M2)
+
+ inf_control = (inf_cont_1 + inf_cont_2) / np.mean(w_cont)
+ att_inf_func = inf_treat - inf_control
+ print(np.std(att_inf_func) / np.sqrt(n))
+ return ipw_att, att_inf_func
+
+def std_ipw_did_rc(y, post, D, covariates, i_weights = None):
+ D = np.asarray(D).flatten()
+ y = np.asarray(y).flatten()
+ post = np.asarray(post).flatten()
+ n = len(D)
+ if covariates is None:
+ int_cov = np.ones((n, 1))
+ else:
+ covariates = np.asarray(covariates)
+ if np.all(covariates[:, 0] == 1):
+ int_cov = covariates
+ else:
+ int_cov = np.column_stack((np.ones(n), covariates))
+
+ # Pesos
+ if i_weights is None:
+ i_weights = np.ones(n)
+ else:
+ i_weights = np.asarray(i_weights)
+ if np.min(i_weights) < 0:
+ raise ValueError("i_weights must be non-negative")
+
+ # Normalizar pesos
+ i_weights = np.asarray(i_weights).flatten()
+ i_weights = i_weights / np.mean(i_weights)
+
+ pscore_model = sm.GLM(D, int_cov, family=sm.families.Binomial(), freq_weights=i_weights)
+ pscore_results = pscore_model.fit()
+ if not pscore_results.converged:
+ print("Warning: glm algorithm did not converge")
+ if np.any(np.isnan(pscore_results.params)):
+ raise ValueError("Propensity score model coefficients have NA components. \n Multicollinearity (or lack of variation) of covariates is a likely reason.")
+ ps_fit = pscore_results.predict()
+ ps_fit = np.minimum(ps_fit, 1 - 1e-16)
+
+
+ w_treat_pre = i_weights * D * (1 - post)
+ w_treat_post = i_weights * D * post
+ # print(np.mean(w_treat_pre))
+
+ w_cont_pre = i_weights * ps_fit * (1 - D) * (1 - post)/(1 - ps_fit)
+ w_cont_post = i_weights * ps_fit * (1 - D) * post/(1 - ps_fit)
+
+ # Elements of the influence function (summands)
+ eta_treat_pre = w_treat_pre * y / np.mean(w_treat_pre)
+ eta_treat_post = w_treat_post * y / np.mean(w_treat_post)
+ # print(eta_treat_pre)
+
+ eta_cont_pre = w_cont_pre * y / np.mean(w_cont_pre)
+ eta_cont_post = w_cont_post * y / np.mean(w_cont_post)
+
+ # Estimator of each component
+ att_treat_pre = np.mean(eta_treat_pre)
+ att_treat_post = np.mean(eta_treat_post)
+ att_cont_pre = np.mean(eta_cont_pre)
+ att_cont_post = np.mean(eta_cont_post)
+ ipw_att = (att_treat_post - att_treat_pre) - (att_cont_post - att_cont_pre)
+
+ score_ps = (i_weights * (D - ps_fit))[:, np.newaxis] * int_cov
+ Hessian_ps = pscore_results.cov_params() * n
+ asy_lyn_rep_ps = np.dot(score_ps, Hessian_ps)
+
+ inf_treat_pre = eta_treat_pre - w_treat_pre * att_treat_pre/np.mean(w_treat_pre)
+ inf_treat_post = eta_treat_post - w_treat_post * att_treat_post/np.mean(w_treat_post)
+ inf_treat = inf_treat_post - inf_treat_pre
+ # Now, get the influence function of control component
+ # Leading term of the influence function: no estimation effect
+ inf_cont_pre = eta_cont_pre - w_cont_pre * att_cont_pre/np.mean(w_cont_pre)
+ inf_cont_post = eta_cont_post - w_cont_post * att_cont_post/np.mean(w_cont_post)
+ inf_cont = inf_cont_post - inf_cont_pre
+
+ # Estimation effect from gamma hat (pscore)
+ # Derivative matrix (k x 1 vector)
- def simple_rep(a, b, y_ = y, cov = int_cov):
- return (a * (y - b))[:, n_x] * cov / np.mean(a)
-
- M2_pre = np.mean(simple_rep(w_cont_pre, att_cont_pre), axis=0)
- M2_post = np.mean(simple_rep(w_cont_post, att_treat_post), axis=0)
+ M2_pre = np.mean((w_cont_pre *(y - att_cont_pre))[:, np.newaxis] * int_cov, axis = 0)/np.mean(w_cont_pre)
+ M2_post = np.mean((w_cont_post *(y - att_cont_post))[:, np.newaxis] * int_cov, axis = 0)/np.mean(w_cont_post)
- inf_cont_ps = np.dot(asy_lin_rep_ps, M2_post - M2_pre)
- inf_cont = inf_cont + inf_cont_ps
- att_inf_func = inf_treat - inf_cont
+ # Now the influence function related to estimation effect of pscores
+ M2 = M2_post - M2_pre
+ # print()
- if not boot:
- se_att = np.std(att_inf_func) / np.sqrt(n)
+ inf_cont_ps = np.dot(asy_lyn_rep_ps, M2)
- return(ipw_att, att_inf_func, se_att)
+ # Influence function for the control component
+ inf_cont = inf_cont + inf_cont_ps
+ #get the influence function of the DR estimator (put all pieces together)
+ att_inf_func = inf_treat - inf_cont
+ # print(np.std(att_inf_func) / np.sqrt(n))
+ return ipw_att, att_inf_func
diff --git a/drdid/panel.ipynb b/drdid/panel.ipynb
new file mode 100644
index 0000000..c943bb0
--- /dev/null
+++ b/drdid/panel.ipynb
@@ -0,0 +1,644 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " treated | \n",
+ " age | \n",
+ " educ | \n",
+ " black | \n",
+ " married | \n",
+ " nodegree | \n",
+ " dwincl | \n",
+ " re74 | \n",
+ " re75 | \n",
+ " re78 | \n",
+ " hisp | \n",
+ " early_ra | \n",
+ " sample | \n",
+ " experimental | \n",
+ " i_w | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " NaN | \n",
+ " 46 | \n",
+ " 14 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " NaN | \n",
+ " 391.853363 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0 | \n",
+ " NaN | \n",
+ " 2 | \n",
+ " 0 | \n",
+ " 0.155489 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " NaN | \n",
+ " 35 | \n",
+ " 14 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " NaN | \n",
+ " 25862.322266 | \n",
+ " 16823.662109 | \n",
+ " 12059.726562 | \n",
+ " 0 | \n",
+ " NaN | \n",
+ " 2 | \n",
+ " 0 | \n",
+ " 0.754220 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " treated age educ black married nodegree dwincl re74 \\\n",
+ "0 NaN 46 14 0 1 0 NaN 391.853363 \n",
+ "1 NaN 35 14 0 1 0 NaN 25862.322266 \n",
+ "\n",
+ " re75 re78 hisp early_ra sample experimental i_w \n",
+ "0 0.000000 0.000000 0 NaN 2 0 0.155489 \n",
+ "1 16823.662109 12059.726562 0 NaN 2 0 0.754220 "
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# from drdid.reg_did import reg_panel_att_if\n",
+ "import pandas as pd, numpy as np\n",
+ "\n",
+ "data = pd.read_csv(\"../csdid_comparing/data/R_panel.csv\")\n",
+ "data.head(2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "y1 = data['re78'].to_numpy()\n",
+ "y0 = data['re75'].to_numpy()\n",
+ "d = data['experimental'].to_numpy()\n",
+ "w = np.array(data['i_w'])\n",
+ "x = data[['age']].to_numpy()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[6931.2961773 -168.67231538]\n",
+ "252651.52672451953\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "(475.214004116089, 594.302830865747)"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "from scipy import stats\n",
+ "import numpy as np\n",
+ "import statsmodels.api as sm\n",
+ "\n",
+ "lm = sm.WLS\n",
+ "glm = sm.GLM\n",
+ "n_x = np.newaxis\n",
+ "qr_solver = np.linalg.pinv\n",
+ "binomial = sm.families.Binomial()\n",
+ "mean = np.mean\n",
+ "\n",
+ "def reg_did_panel(y1, y0, D, covariates, i_weights=None):\n",
+ " D = np.asarray(D).flatten()\n",
+ " n = len(D)\n",
+ " deltaY = np.asarray(y1 - y0).flatten()\n",
+ " int_cov = np.ones((n, 1))\n",
+ " \n",
+ " if covariates is not None:\n",
+ " covariates = np.asarray(covariates)\n",
+ " if np.all(covariates[:, 0] == 1):\n",
+ " int_cov = covariates\n",
+ " else:\n",
+ " int_cov = np.column_stack((np.ones(n), covariates))\n",
+ " \n",
+ " if i_weights is None:\n",
+ " i_weights = np.ones(n)\n",
+ " elif np.min(i_weights) < 0:\n",
+ " raise ValueError(\"i_weights must be non-negative\")\n",
+ " \n",
+ " i_weights = i_weights / np.mean(i_weights)\n",
+ " \n",
+ " mask = D == 0\n",
+ " X = int_cov[mask]\n",
+ " y = deltaY[mask]\n",
+ " w = i_weights[mask]\n",
+ " \n",
+ " # reg_coeff = np.linalg.lstsq(X * w[:, np.newaxis], y * w, rcond=None)[0]\n",
+ " reg_coeff = lm(y, X, weights=w).fit().params\n",
+ " print(reg_coeff)\n",
+ " \n",
+ " if np.any(np.isnan(reg_coeff)):\n",
+ " raise ValueError(\"Outcome regression model coefficients have NA components. \\n Multicollinearity (or lack of variation) of covariates is probably the reason for it.\")\n",
+ " \n",
+ " out_delta = np.dot(int_cov, reg_coeff)\n",
+ " w_treat = i_weights * D\n",
+ " w_cont = i_weights * (1 - D)\n",
+ " reg_att_treat = w_treat * deltaY\n",
+ " reg_att_cont = w_cont * out_delta\n",
+ " eta_treat = np.mean(reg_att_treat) / np.mean(w_treat)\n",
+ " eta_cont = np.mean(reg_att_cont) / np.mean(w_cont)\n",
+ " reg_att = eta_treat - eta_cont\n",
+ " \n",
+ " weights_ols = i_weights * (1 - D)\n",
+ " wols_x = weights_ols[:, np.newaxis] * int_cov\n",
+ " wols_eX = weights_ols[:, np.newaxis] * (deltaY - out_delta)[:, np.newaxis] * int_cov\n",
+ " XpX_inv = np.linalg.inv(np.dot(wols_x.T, int_cov) / n)\n",
+ " asy_lin_rep_ols = np.dot(wols_eX, XpX_inv)\n",
+ " \n",
+ " inf_treat = (reg_att_treat - w_treat * eta_treat) / np.mean(w_treat)\n",
+ " print(np.sum(w_treat * eta_treat))\n",
+ " \n",
+ " inf_cont_1 = (reg_att_cont - w_cont * eta_cont)\n",
+ " M1 = np.mean(w_cont[:, np.newaxis] * int_cov, axis=0)\n",
+ " inf_cont_2 = np.dot(asy_lin_rep_ols, M1)\n",
+ " inf_control = (inf_cont_1 + inf_cont_2) / np.mean(w_cont)\n",
+ " \n",
+ " reg_att_inf_func = (inf_treat - inf_control)\n",
+ " se_reg_att = np.std(reg_att_inf_func) / np.sqrt(n)\n",
+ " \n",
+ " return reg_att, se_reg_att\n",
+ "reg_did_panel(y1, y0, d, x, w)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(475.21400411608613,\n",
+ " array([ 423.37312598, 9421.37209022, 14394.03516237, ...,\n",
+ " 2231.7013017 , 25987.48210317, 1468.3600495 ]))"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "import statsmodels.api as sm\n",
+ "\n",
+ "def drdid_panel(y1, y0, D, covariates, i_weights=None, boot=False, boot_type=\"weighted\", nboot=None, inffunc=False):\n",
+ " # Convert inputs to numpy arrays\n",
+ " D = np.asarray(D).flatten()\n",
+ " n = len(D)\n",
+ " deltaY = np.asarray(y1 - y0).flatten()\n",
+ " \n",
+ " # Add constant to covariate matrix\n",
+ " if covariates is None:\n",
+ " int_cov = np.ones((n, 1))\n",
+ " else:\n",
+ " covariates = np.asarray(covariates)\n",
+ " if np.all(covariates[:, 0] == 1):\n",
+ " int_cov = covariates\n",
+ " else:\n",
+ " int_cov = np.column_stack((np.ones(n), covariates))\n",
+ " \n",
+ " # Weights\n",
+ " if i_weights is None:\n",
+ " i_weights = np.ones(n)\n",
+ " elif np.min(i_weights) < 0:\n",
+ " raise ValueError(\"i_weights must be non-negative\")\n",
+ " \n",
+ " # Normalize weights\n",
+ " i_weights = i_weights / np.mean(i_weights)\n",
+ " # print(D.mean())\n",
+ " \n",
+ " # Compute the Pscore by MLE\n",
+ " pscore_model = sm.GLM(D, int_cov, family=sm.families.Binomial(), freq_weights=i_weights)\n",
+ " pscore_results = pscore_model.fit()\n",
+ " # print(D.mean())\n",
+ " # print(pscore_results.summary2())\n",
+ " if not pscore_results.converged:\n",
+ " print(\"Warning: glm algorithm did not converge\")\n",
+ " if np.any(np.isnan(pscore_results.params)):\n",
+ " raise ValueError(\"Propensity score model coefficients have NA components. \\n Multicollinearity (or lack of variation) of covariates is a likely reason.\")\n",
+ " ps_fit = pscore_results.predict()\n",
+ " ps_fit = np.minimum(ps_fit, 1 - 1e-16)\n",
+ " # print(ps_fit)\n",
+ " \n",
+ " # Compute the Outcome regression for the control group using wols\n",
+ " mask = D == 0\n",
+ " reg_model = sm.WLS(deltaY[mask], int_cov[mask], weights=i_weights[mask])\n",
+ " reg_results = reg_model.fit()\n",
+ " if np.any(np.isnan(reg_results.params)):\n",
+ " raise ValueError(\"Outcome regression model coefficients have NA components. \\n Multicollinearity (or lack of variation) of covariates is a likely reason.\")\n",
+ " out_delta = np.dot(int_cov, reg_results.params)\n",
+ " \n",
+ " # Compute Traditional Doubly Robust DiD estimators\n",
+ " w_treat = i_weights * D\n",
+ " w_cont = i_weights * ps_fit * (1 - D) / (1 - ps_fit)\n",
+ " dr_att_treat = w_treat * (deltaY - out_delta)\n",
+ " dr_att_cont = w_cont * (deltaY - out_delta)\n",
+ " \n",
+ " eta_treat = np.mean(dr_att_treat) / np.mean(w_treat)\n",
+ " eta_cont = np.mean(dr_att_cont) / np.mean(w_cont)\n",
+ " \n",
+ " dr_att = eta_treat - eta_cont\n",
+ " \n",
+ " # Compute influence function\n",
+ " weights_ols = i_weights * (1 - D)\n",
+ " wols_x = weights_ols[:, np.newaxis] * int_cov\n",
+ " wols_eX = weights_ols[:, np.newaxis] * (deltaY - out_delta)[:, np.newaxis] * int_cov\n",
+ " XpX_inv = np.linalg.inv(np.dot(wols_x.T, int_cov) / n)\n",
+ " asy_lin_rep_wols = np.dot(wols_eX, XpX_inv)\n",
+ " \n",
+ " score_ps = i_weights[:, np.newaxis] * (D - ps_fit)[:, np.newaxis] * int_cov\n",
+ " Hessian_ps = pscore_results.cov_params() * n\n",
+ " asy_lin_rep_ps = np.dot(score_ps, Hessian_ps)\n",
+ " \n",
+ " inf_treat_1 = dr_att_treat - w_treat * eta_treat\n",
+ " M1 = np.mean(w_treat[:, np.newaxis] * int_cov, axis=0)\n",
+ " inf_treat_2 = np.dot(asy_lin_rep_wols, M1)\n",
+ " inf_treat = (inf_treat_1 - inf_treat_2) / np.mean(w_treat)\n",
+ " \n",
+ " inf_cont_1 = dr_att_cont - w_cont * eta_cont\n",
+ " M2 = np.mean(w_cont[:, np.newaxis] * (deltaY - out_delta - eta_cont)[:, np.newaxis] * int_cov, axis=0)\n",
+ " inf_cont_2 = np.dot(asy_lin_rep_ps, M2)\n",
+ " M3 = np.mean(w_cont[:, np.newaxis] * int_cov, axis=0)\n",
+ " inf_cont_3 = np.dot(asy_lin_rep_wols, M3)\n",
+ " inf_control = (inf_cont_1 + inf_cont_2 - inf_cont_3) / np.mean(w_cont)\n",
+ " \n",
+ " dr_att_inf_func = inf_treat - inf_control\n",
+ " \n",
+ " return dr_att, dr_att_inf_func\n",
+ "\n",
+ "drdid_panel(y1, y0, d, None, w)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import statsmodels.api as sm\n",
+ "\n",
+ "nopanel = pd.read_csv(\"../csdid_comparing/data/R_nopanel_reg.csv\")\n",
+ "# y = nopanel['y']\n",
+ "# post = nopanel['post']\n",
+ "# D = nopanel['d']\n",
+ "# w = nopanel['i_w']\n",
+ "\n",
+ "def reg_did_rc(y, post, D, covariates, i_weights=None):\n",
+ " D = np.asarray(D).flatten()\n",
+ " post = np.asarray(post).flatten()\n",
+ " n = len(D)\n",
+ " y = np.asarray(y).flatten()\n",
+ " i_weights = np.asarray(i_weights).flatten()\n",
+ " int_cov = np.ones((n, 1))\n",
+ " \n",
+ " if covariates is not None:\n",
+ " covariates = np.asarray(covariates)\n",
+ " if np.all(covariates[:, 0] == 1):\n",
+ " int_cov = covariates\n",
+ " else:\n",
+ " int_cov = np.column_stack((np.ones(n), covariates))\n",
+ " \n",
+ " if i_weights is None:\n",
+ " i_weights = np.ones(n)\n",
+ " elif np.min(i_weights) < 0:\n",
+ " raise ValueError(\"i_weights must be non-negative\")\n",
+ " \n",
+ " i_weights = i_weights / np.mean(i_weights)\n",
+ " \n",
+ " # Pre-treatment regression\n",
+ " mask_pre = (D == 0) & (post == 0)\n",
+ " X_pre = int_cov[mask_pre]\n",
+ " y_pre = y[mask_pre]\n",
+ " w_pre = i_weights[mask_pre]\n",
+ " model_pre = sm.WLS(y_pre, X_pre, weights=w_pre)\n",
+ " results_pre = model_pre.fit()\n",
+ " reg_coeff_pre = results_pre.params\n",
+ " \n",
+ " if np.any(np.isnan(reg_coeff_pre)):\n",
+ " raise ValueError(\"Outcome regression model coefficients have NA components. \\n Multicollinearity of covariates is probably the reason for it.\")\n",
+ " \n",
+ " out_y_pre = np.dot(int_cov, reg_coeff_pre)\n",
+ " \n",
+ " # Post-treatment regression\n",
+ " mask_post = (D == 0) & (post == 1)\n",
+ " X_post = int_cov[mask_post]\n",
+ " y_post = y[mask_post]\n",
+ " w_post = i_weights[mask_post]\n",
+ " model_post = sm.WLS(y_post, X_post, weights=w_post)\n",
+ " results_post = model_post.fit()\n",
+ " reg_coeff_post = results_post.params\n",
+ " \n",
+ " if np.any(np.isnan(reg_coeff_post)):\n",
+ " raise ValueError(\"Outcome regression model coefficients have NA components. \\n Multicollinearity (or lack of variation) of covariates is probably the reason for it.\")\n",
+ " \n",
+ " out_y_post = np.dot(int_cov, reg_coeff_post)\n",
+ " \n",
+ " w_treat_pre = i_weights * D * (1 - post)\n",
+ " w_treat_post = i_weights * D * post\n",
+ " w_cont = i_weights * D\n",
+ " reg_att_treat_pre = w_treat_pre * y\n",
+ " reg_att_treat_post = w_treat_post * y\n",
+ " reg_att_cont = w_cont * (out_y_post - out_y_pre)\n",
+ " eta_treat_pre = np.mean(reg_att_treat_pre) / np.mean(w_treat_pre)\n",
+ " eta_treat_post = np.mean(reg_att_treat_post) / np.mean(w_treat_post)\n",
+ " eta_cont = np.mean(reg_att_cont) / np.mean(w_cont)\n",
+ " reg_att = (eta_treat_post - eta_treat_pre) - eta_cont\n",
+ " \n",
+ " weights_ols_pre = i_weights * (1 - D) * (1 - post)\n",
+ " # print(weights_ols_pre.reshape((n, 1)))\n",
+ " wols_x_pre = weights_ols_pre[:, np.newaxis] * int_cov\n",
+ " wols_eX_pre = weights_ols_pre[:, np.newaxis] * (y - out_y_pre)[:, np.newaxis] * int_cov\n",
+ " XpX_inv_pre = np.linalg.inv(np.dot(wols_x_pre.T, int_cov) / n)\n",
+ " asy_lin_rep_ols_pre = np.dot(wols_eX_pre, XpX_inv_pre)\n",
+ " \n",
+ " weights_ols_post = i_weights * (1 - D) * post\n",
+ " wols_x_post = weights_ols_post[:, np.newaxis] * int_cov\n",
+ " wols_eX_post = weights_ols_post[:, np.newaxis] * (y - out_y_post)[:, np.newaxis] * int_cov\n",
+ " XpX_inv_post = np.linalg.inv(np.dot(wols_x_post.T, int_cov) / n)\n",
+ " asy_lin_rep_ols_post = np.dot(wols_eX_post, XpX_inv_post)\n",
+ " \n",
+ " inf_treat_pre = (reg_att_treat_pre - w_treat_pre * eta_treat_pre) / np.mean(w_treat_pre)\n",
+ " inf_treat_post = (reg_att_treat_post - w_treat_post * eta_treat_post) / np.mean(w_treat_post)\n",
+ " inf_treat = inf_treat_post - inf_treat_pre\n",
+ " \n",
+ " inf_cont_1 = (reg_att_cont - w_cont * eta_cont)\n",
+ " M1 = np.mean(w_cont[:, np.newaxis] * int_cov, axis=0)\n",
+ " inf_cont_2_post = np.dot(asy_lin_rep_ols_post, M1)\n",
+ " inf_cont_2_pre = np.dot(asy_lin_rep_ols_pre, M1)\n",
+ " inf_control = (inf_cont_1 + inf_cont_2_post - inf_cont_2_pre) / np.mean(w_cont)\n",
+ " \n",
+ " reg_att_inf_func = (inf_treat - inf_control)\n",
+ " se_reg_att = np.std(reg_att_inf_func) / np.sqrt(n)\n",
+ " \n",
+ " return reg_att, reg_att_inf_func, se_reg_att\n",
+ "# reg_did_rc(y, post, D, None, w)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "594.302814443883\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "(475.21400407689976,\n",
+ " array([ 423.36996588, 9421.35676191, 14394.01780432, ...,\n",
+ " 2231.68464404, 25987.47241037, 1468.3490895 ]))"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "from scipy import stats\n",
+ "import statsmodels.api as sm\n",
+ "\n",
+ "def ipw_did_panel(y1, y0, D, covariates, i_weights=None):\n",
+ " # Convert inputs to numpy arrays\n",
+ " y1 = np.asarray(y1)\n",
+ " y0 = np.asarray(y0)\n",
+ " D = np.asarray(D)\n",
+ " \n",
+ " # Sample size\n",
+ " n = len(D)\n",
+ " \n",
+ " # Generate deltaY\n",
+ " deltaY = y1 - y0\n",
+ " \n",
+ " # Add constant to covariate vector\n",
+ " if covariates is None:\n",
+ " int_cov = np.ones((n, 1))\n",
+ " else:\n",
+ " covariates = np.asarray(covariates)\n",
+ " if np.all(covariates[:, 0] == 1):\n",
+ " int_cov = covariates\n",
+ " else:\n",
+ " int_cov = np.column_stack((np.ones(n), covariates))\n",
+ " \n",
+ " # Weights\n",
+ " if i_weights is None:\n",
+ " i_weights = np.ones(n)\n",
+ " elif np.min(i_weights) < 0:\n",
+ " raise ValueError(\"i_weights must be non-negative\")\n",
+ " \n",
+ " # Normalize weights\n",
+ " i_weights = i_weights / np.mean(i_weights)\n",
+ " \n",
+ " # Pscore estimation (logit) and its fitted values\n",
+ " PS = sm.GLM(D, int_cov, family=sm.families.Binomial(), freq_weights=i_weights).fit()\n",
+ " if not PS.converged:\n",
+ " print(\"Warning: glm algorithm did not converge\")\n",
+ " if np.any(np.isnan(PS.params)):\n",
+ " raise ValueError(\"Propensity score model coefficients have NA components. \\n Multicollinearity (or lack of variation) of covariates is a likely reason.\")\n",
+ " ps_fit = PS.predict()\n",
+ " ps_fit = np.minimum(ps_fit, 1 - 1e-16)\n",
+ " \n",
+ " # Compute IPW estimator\n",
+ " w_treat = i_weights * D\n",
+ " w_cont = i_weights * ps_fit * (1 - D) / (1 - ps_fit)\n",
+ " att_treat = w_treat * deltaY\n",
+ " att_cont = w_cont * deltaY\n",
+ " eta_treat = np.mean(att_treat) / np.mean(i_weights * D)\n",
+ " eta_cont = np.mean(att_cont) / np.mean(i_weights * D)\n",
+ " ipw_att = eta_treat - eta_cont\n",
+ " \n",
+ " # Get the influence function to compute standard error\n",
+ " score_ps = i_weights[:, np.newaxis] * (D - ps_fit)[:, np.newaxis] * int_cov\n",
+ " Hessian_ps = PS.cov_params() * n\n",
+ " asy_lin_rep_ps = np.dot(score_ps, Hessian_ps)\n",
+ " \n",
+ " # Get the influence function of control component\n",
+ " att_lin1 = att_treat - att_cont\n",
+ " mom_logit = np.mean(att_cont[:, np.newaxis] * int_cov, axis=0)\n",
+ " att_lin2 = np.dot(asy_lin_rep_ps, mom_logit)\n",
+ " \n",
+ " # Get the influence function of the IPW estimator\n",
+ " att_inf_func = (att_lin1 - att_lin2 - i_weights * D * ipw_att) / np.mean(i_weights * D)\n",
+ " print(np.std(att_inf_func) / np.sqrt(n))\n",
+ " return ipw_att, att_inf_func\n",
+ "ipw_did_panel(y1, y0, d, None, w)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[1.34605216e-14]\n",
+ "594.3028308657471\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "(475.2140041160869,\n",
+ " array([ 423.37312598, 9421.37209022, 14394.03516237, ...,\n",
+ " 2231.7013017 , 25987.48210317, 1468.3600495 ]))"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def std_ipw_did_panel(y1, y0, D, covariates, i_weights = None):\n",
+ " D = np.asarray(D).flatten()\n",
+ " n = len(D)\n",
+ " delta_y = np.asarray(y1 - y0).flatten()\n",
+ " int_cov = np.ones((n, 1))\n",
+ " \n",
+ " if covariates is not None:\n",
+ " covariates = np.asarray(covariates)\n",
+ " if np.all(covariates[:, 0] == 1):\n",
+ " int_cov = covariates\n",
+ " else:\n",
+ " int_cov = np.column_stack((np.ones(n), covariates))\n",
+ " \n",
+ " if i_weights is None:\n",
+ " i_weights = np.ones(n)\n",
+ " elif np.min(i_weights) < 0:\n",
+ " raise ValueError(\"i_weights must be non-negative\")\n",
+ " \n",
+ " i_weights = i_weights / np.mean(i_weights)\n",
+ " pscore_model = sm.GLM(D, int_cov, family=sm.families.Binomial(), freq_weights=i_weights)\n",
+ " pscore_results = pscore_model.fit()\n",
+ " # print(D.mean())\n",
+ " # print(pscore_results.summary2())\n",
+ " if not pscore_results.converged:\n",
+ " print(\"Warning: glm algorithm did not converge\")\n",
+ " if np.any(np.isnan(pscore_results.params)):\n",
+ " raise ValueError(\"Propensity score model coefficients have NA components. \\n Multicollinearity (or lack of variation) of covariates is a likely reason.\")\n",
+ " ps_fit = pscore_results.predict()\n",
+ " ps_fit = np.minimum(ps_fit, 1 - 1e-16)\n",
+ "\n",
+ " w_treat = i_weights * D\n",
+ " w_cont = i_weights * ps_fit * (1 - D) / (1 - ps_fit)\n",
+ " \n",
+ " att_treat = w_treat * delta_y\n",
+ " att_cont = w_cont * delta_y\n",
+ "\n",
+ " eta_treat = mean(att_treat) / mean(w_treat)\n",
+ " eta_cont = mean(att_cont) / mean(w_cont)\n",
+ "\n",
+ " ipw_att = eta_treat - eta_cont\n",
+ "\n",
+ " score_ps = i_weights[:, np.newaxis] * (D - ps_fit)[:, np.newaxis] * int_cov\n",
+ " Hessian_ps = pscore_results.cov_params() * n\n",
+ " asy_lin_rep_ps = np.dot(score_ps, Hessian_ps)\n",
+ "\n",
+ " inf_treat = (att_treat - w_treat * eta_treat) / mean(w_treat)\n",
+ " inf_cont_1 = att_cont - w_cont * eta_cont\n",
+ " pre_m2 = w_cont * (delta_y - eta_cont)\n",
+ " M2 = np.mean(pre_m2[:, np.newaxis] * int_cov, axis = 0)\n",
+ " print(M2)\n",
+ " inf_cont_2 = np.dot(asy_lin_rep_ps, M2)\n",
+ "\n",
+ " inf_control = (inf_cont_1 + inf_cont_2) / np.mean(w_cont)\n",
+ " att_inf_func = inf_treat - inf_control\n",
+ " print(np.std(att_inf_func) / np.sqrt(n))\n",
+ " return ipw_att, att_inf_func\n",
+ "\n",
+ "std_ipw_did_panel(y1, y0, d, None, w)\n",
+ " "
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/drdid/rc.ipynb b/drdid/rc.ipynb
new file mode 100644
index 0000000..e4b1718
--- /dev/null
+++ b/drdid/rc.ipynb
@@ -0,0 +1,1266 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "no_panel = pd.read_csv(\"../csdid_comparing/data/R_nopanel_reg.csv\")\n",
+ "y = no_panel['y']\n",
+ "post = no_panel['post']\n",
+ "d = no_panel['d']\n",
+ "x = None\n",
+ "w = no_panel['i_w']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(-32.44762929840418,\n",
+ " array([ 4.20279333e+02, 4.12088741e+02, -9.76295144e+01, 4.13877211e+02,\n",
+ " -1.33180878e+02, 3.46372794e+02, 3.33751840e+01, -2.36632190e+02,\n",
+ " -2.87278416e+02, -5.01885265e+01, -2.56919638e+01, -4.36605269e+02,\n",
+ " -5.82711118e+02, -1.74750149e+02, -1.05917795e+02, 3.17576284e+02,\n",
+ " 5.22646939e+02, -3.92097558e+02, -1.74362318e+02, 3.25937355e+02,\n",
+ " 1.15121568e+01, 9.67115269e+01, 2.37900939e+02, -2.31235308e+02,\n",
+ " -1.67406181e+02, 2.49385835e+02, 7.23160188e+02, 1.39005288e+02,\n",
+ " -5.82824544e+01, -2.09472417e+02, -1.67529498e+02, -1.98673017e+02,\n",
+ " -1.10899807e+02, -2.23694179e+01, -8.20674968e+02, 5.35943555e+02,\n",
+ " 6.10833988e+02, 4.23615577e+02, 1.39299854e+02, 2.29206492e+01,\n",
+ " -6.33676385e+02, -3.95707388e+02, 1.89663298e+01, 1.01334683e+02,\n",
+ " 6.97094075e+01, -1.95271048e+02, -3.44512247e+01, -6.72134538e+02,\n",
+ " -2.95689468e+02, 2.80919019e+01, -2.30572474e+02, 3.06118680e+02,\n",
+ " -2.59103267e+02, 8.90688138e+01, -1.17068460e+02, -1.74855945e+02,\n",
+ " -4.21579168e+02, -3.49463513e+02, 3.58548856e+02, 2.74070866e+02,\n",
+ " 8.94619601e+02, 2.72925232e+02, 1.97267771e+02, 6.14555614e+02,\n",
+ " -2.12822803e+02, -7.44529056e+02, -2.50036755e+02, 1.11879526e+02,\n",
+ " -3.64114609e+01, 1.26722740e+01, 2.26052626e+02, 3.52279091e+02,\n",
+ " -1.58602401e+02, -2.14067743e+02, 5.13864039e+02, 5.90513691e+01,\n",
+ " 2.11973267e+01, -2.98048648e+02, 9.06978843e+01, -2.29889520e+02,\n",
+ " -1.04690681e+02, -6.31612943e+01, -8.17161981e+01, 5.40279596e+02,\n",
+ " -7.61069251e+02, -1.26693439e+02, 1.08962854e+03, 4.88141074e+01,\n",
+ " -1.79310459e+02, -1.91501515e+02, -3.82569248e+00, 5.77276974e+01,\n",
+ " -3.82459487e+02, -1.47957299e+02, -2.42341055e+02, -3.11245840e+02,\n",
+ " 8.91754223e+00, -3.57126608e+02, -1.50323548e+02, -8.92144522e+02,\n",
+ " 3.07419995e+01, 1.01969269e+02, -1.04012348e+01, -2.97688581e+02,\n",
+ " 5.50136302e+02, -1.30448299e+02, 1.41973622e+02, 1.12012366e+01,\n",
+ " 7.30720715e+02, -1.02790672e+01, 3.55718639e+01, -2.16537343e+01,\n",
+ " -3.71137958e+02, 7.26427434e+00, 9.35600864e+01, -1.25298182e+02,\n",
+ " 1.08647879e+02, -7.50123676e+01, 4.18921461e+01, -5.15391747e+01,\n",
+ " 5.38422193e+02, 1.37600027e+02, 2.20954413e+02, -3.04370282e+01,\n",
+ " 9.07984789e+01, -4.98343301e+01, 4.58732880e+02, -1.92490096e+02,\n",
+ " -9.38128679e+01, -1.15051818e+02, 1.31099359e+02, 2.87583768e+02,\n",
+ " 3.07189365e+01, 3.53247809e+00, -8.18715843e+01, -2.04526732e+02,\n",
+ " 2.32926404e+02, -2.36087460e+02, -1.37678998e+02, 1.78550004e+00,\n",
+ " 4.87878523e+02, 3.71738576e+02, -8.84080383e+02, -6.24049505e+00,\n",
+ " 8.29056945e+01, 4.57399586e+01, 2.97905782e+02, 5.91714970e+02,\n",
+ " 4.48105143e+01, -6.92797977e+01, 2.55354734e+02, -7.89225035e+01,\n",
+ " 1.43838620e+02, -3.93631793e+02, 4.76133717e+02, 9.16183723e+02,\n",
+ " -1.44540113e+02, -2.57889488e+02, -3.95310404e+02, 4.25819137e+01,\n",
+ " 4.38843307e+01, 9.34049278e+01, -4.17071476e+01, -8.38207367e+01,\n",
+ " 5.60322689e+01, 1.06349442e+01, 3.09546182e+02, 3.48801424e+02,\n",
+ " -1.05768649e+02, 7.43660156e+01, -2.19060274e+02, 1.99052090e+02,\n",
+ " -1.17751964e+02, 1.02684457e+02, 4.49686979e+02, 3.90078176e+01,\n",
+ " 3.22412138e+02, -1.38490567e+03, -1.48277066e+02, 2.18563981e+01,\n",
+ " 6.36586969e+02, 1.66128829e+01, 3.98819578e+01, 5.43968673e+02,\n",
+ " 1.70087616e+01, -1.12945500e+03, 8.16595074e+01, 2.92711225e+00,\n",
+ " 6.03544081e+01, -1.14685923e+01, 2.77633051e+02, 6.06500732e+02,\n",
+ " 4.56211154e+02, 1.75456319e+02, -1.12916110e+03, 1.31566141e+02,\n",
+ " 1.84878012e+01, 8.49839785e+02, 1.75981779e+02, 1.43690437e+02,\n",
+ " -3.04227043e+01, 2.46340357e+02, 2.13050478e+02, -1.62891702e+02,\n",
+ " -1.37749443e+02, 3.04928038e+02, -2.08060148e+02, 3.08937749e+02,\n",
+ " 3.24723186e+02, -1.98816285e+02, 7.95471735e+01, 1.90358746e+02,\n",
+ " 1.83465649e+02, -5.38050094e+02, -1.12880154e+02, -3.62295270e+02,\n",
+ " 1.62922507e+01, -1.85170208e+02, 9.86646641e+01, 2.98685240e+01,\n",
+ " 2.01606241e+02, -7.98828254e+00, -3.61456958e+02, 1.11270776e+03,\n",
+ " 2.66621054e+01, 3.01496066e+02, -1.25142347e+03, 5.28971079e+02,\n",
+ " -2.52826572e+02, 8.86681627e+01, 6.21972418e+01, -4.72960721e+01,\n",
+ " -6.85708127e+02, -2.30109786e-01, 7.36786880e+01, -6.28705485e+02,\n",
+ " 1.19764430e+02, 1.06302137e+02, -4.01034981e+01, -2.61794027e+01,\n",
+ " 3.89238246e+02, -5.44269320e+01, 4.80295667e+02, -5.80919478e+01,\n",
+ " 6.61172014e+01, 1.28778609e+02, -1.82567742e+02, 1.57025548e+02,\n",
+ " 3.59861570e+01, -2.75698382e+00, 2.38948770e+02, 5.33809115e+02,\n",
+ " 1.85202542e+02, 7.37137754e+01, -2.74456697e+01, -6.41486558e+01,\n",
+ " 1.64255000e+02, -2.83741759e+02, -3.24654203e+01, -5.25302612e+01,\n",
+ " -2.75642049e+02, 1.07046194e+02, 5.51808554e+00, -1.22072567e+03,\n",
+ " -1.73135777e+02, 5.47295474e+01, 9.73685499e-01, -6.55737638e+02,\n",
+ " 1.64032793e+02, -1.38059072e+02, 5.82306220e+02, -5.71054927e+02,\n",
+ " -2.20443680e+02, -6.59930511e+02, 9.73798237e+01, 4.41430648e+01,\n",
+ " -2.63824326e+02, 2.80282041e+02, -1.60419255e+02, 8.76880647e+01,\n",
+ " -5.71382813e+02, 4.80383968e+01, 7.52916372e+02, 4.18394491e+01,\n",
+ " 3.69554916e+01, 3.22772613e+02, 7.54799751e+02, -1.52723120e+02,\n",
+ " 7.21342420e+00, -3.11212384e+01, 9.21140746e+01, 7.29970195e+01,\n",
+ " -3.02675395e+02, -3.06709950e+01, -1.92026723e+02, 3.28174587e+02,\n",
+ " -1.85300133e+01, 2.86472860e+02, 1.89005851e+02, 2.10833505e+01,\n",
+ " -3.76490155e+01, -1.04587893e+02, -3.08966353e+02, -4.19383505e+02,\n",
+ " -9.62418025e+01, -6.79089974e+01, 4.15009663e+01, 5.80745707e+02,\n",
+ " -2.20347325e+02, -1.77342315e+02, 1.04110341e+03, -2.16501200e+02,\n",
+ " 6.78717380e+01, -7.46661995e+01, -1.39958107e+02, 5.77063258e+01,\n",
+ " 2.45207141e+02, 1.94375988e+02, -9.38592562e+01, -1.06004405e+02,\n",
+ " 6.08026697e+01, 7.18244505e+01, 1.44416539e+02, 3.50710007e+00,\n",
+ " -8.44465217e+01, 2.04115121e+01, 2.58869463e+02, -7.55866500e+01,\n",
+ " -3.34409317e+02, -1.49632569e+02, -1.94744007e+01, -3.08896588e+02,\n",
+ " 3.78444598e+00, 8.54656724e+01, 7.52081951e+01, -1.37559812e+01,\n",
+ " -1.79173346e+02, -1.26028340e+03, -2.27194696e+01, 5.23688677e+01,\n",
+ " -1.66735091e+01, 4.85622033e+02, 3.10564652e+02, -1.58757232e+01,\n",
+ " -3.31621347e+02, -2.95531899e+02, 1.81221712e+02, 5.44508260e+02,\n",
+ " 4.27005201e+02, -5.30307615e+01, -5.31205211e+01, -4.85459228e+00,\n",
+ " 2.04815006e+01, 3.31166265e+01, -4.76947792e+02, 1.58542450e+02,\n",
+ " -2.13128664e+02, 9.27735792e+01, 9.66676771e+01, 6.14797866e+02,\n",
+ " -7.97564986e+01, 1.96958210e+03, 1.00142737e+02, -7.69715849e+01,\n",
+ " 2.77540953e+02, 1.26623564e+01, 3.40581064e+02, 4.55149595e+01,\n",
+ " -8.45653748e+01, 4.80985850e+02, -4.05184099e+01, 8.89421433e+01,\n",
+ " 1.27895179e+02, 1.11252558e+02, -7.57096911e+02, 2.44395909e+02,\n",
+ " 3.72071963e+02, -5.65043484e+01, -5.79231700e+01, -3.43390489e+02,\n",
+ " 4.39757382e+02, -6.23852608e+02, 5.01436050e+02, -3.15742453e+02,\n",
+ " 3.15395399e+00, 2.18964255e+02, 4.35718769e+02, 2.29079288e+00,\n",
+ " 3.90011337e+02, -5.54336886e+01, -1.85429839e+02, 1.13745778e+03,\n",
+ " 3.29118076e+02, -8.54123978e+02, -1.51382280e+02, -3.97047777e+00,\n",
+ " -1.06355206e+02, 7.71569957e+01, -1.56360634e+02, 5.34456838e+01,\n",
+ " 1.12573895e+02, 2.04149311e+02, 7.59839845e+01, 5.80588764e+01,\n",
+ " -4.66659536e+01, 3.34146533e+02, -4.81443136e+02, 3.77281573e+02,\n",
+ " 1.34721590e+02, -1.44255644e+01, -6.46960548e+02, 2.64408950e+02,\n",
+ " -3.47214931e+00, 1.19992507e+03, -2.11893490e+02, -1.10071695e+02,\n",
+ " -7.01521184e+01, -2.94239248e+02, 1.83645643e+02, -1.53451490e+02,\n",
+ " -1.22644517e+03, 2.26048089e-01, -1.42655169e+02, 2.79850492e+01,\n",
+ " 7.66303443e+01, -9.47729029e+01, 1.97191524e+01, 4.86517029e+02,\n",
+ " 2.23495596e+02, -3.73850230e+02, 2.63377040e+02, 2.09272457e+02,\n",
+ " -5.88289035e+01, -4.84050079e+01, 2.70740960e+02, -7.99068709e+02,\n",
+ " -3.66878057e+02, -7.63973385e+02, -4.13292537e+02, 1.98086211e+02,\n",
+ " -1.03536004e+02, 1.23697088e+01, -4.17482670e+01, 2.95590024e+02,\n",
+ " -1.69574644e+02, 3.79833701e+02, 1.34128663e+01, 1.84745002e+02,\n",
+ " -1.01469190e+02, -3.08288157e+02, 3.65354380e+02, -2.14962852e+02,\n",
+ " -8.16359137e+02, -4.53838509e+00, 3.52296414e+02, -1.96579018e+02,\n",
+ " -1.72003357e+02, 1.62263555e+02, -9.15909341e+01, -5.25450352e+02,\n",
+ " 8.28896229e+01, 3.64925822e+02, -4.15238930e+02, -3.00551889e+01,\n",
+ " 1.73624960e+02, -6.73581968e+01, 8.56118071e+01, 2.44054679e+02,\n",
+ " 2.65423530e+02, -7.84664065e+01, -5.63917200e+02, -2.26535357e+02,\n",
+ " -3.39408860e+02, 7.67487880e+02, 5.25857778e+02, -2.07862652e+02,\n",
+ " -7.90111584e+01, -6.42927035e+01, -1.50921781e+02, -2.48296982e+02,\n",
+ " 1.40104644e+01, -1.13081160e+02, -4.03496516e+02, -1.17400980e+02,\n",
+ " 2.21279937e+01, 7.98486666e+02, -5.69456778e+02, -2.41940064e+01,\n",
+ " 1.52367607e+01, -4.39050709e+02, -8.43244926e+01, -8.38253452e+02,\n",
+ " -1.77774862e+02, 4.31605119e+01, -1.65602382e+02, -7.68777733e+00,\n",
+ " 7.61283887e+02, -6.00505098e+01, 2.62589767e+01, 2.63702108e+02,\n",
+ " -6.06462876e+00, 5.90112292e+02, 9.99913222e+01, -2.17212078e+02,\n",
+ " -9.88708978e+01, 6.52856038e+01, -2.81793258e+02, -1.11179922e+02,\n",
+ " -5.42097873e+01, -3.90954654e+01, 3.74594954e+02, 2.72618163e+02,\n",
+ " 3.61725770e+00, 9.37238289e+02, -1.86511741e+02, -2.62908249e+02,\n",
+ " -6.39473433e+01, -3.42312827e+01, -2.28387976e+02, 1.08426708e+03,\n",
+ " -4.22935998e+02, 2.08515221e+01, -2.56221345e+02, 8.32881054e+01,\n",
+ " 3.97289413e+01, -2.33463296e+02, 3.40873141e+02, 3.04041387e+01,\n",
+ " -6.00880346e+01, -1.22311002e+02, -6.29413687e+01, 3.92335543e+02,\n",
+ " 1.57091140e+02, 6.03117352e+02, -1.93660362e+02, -1.76073525e+02,\n",
+ " -1.64792918e+02, 1.02192153e+01, -3.95728672e+01, 2.25835832e+02,\n",
+ " 9.13616106e+00, 3.55868569e+02, 3.96693384e+02, -1.90870573e+01,\n",
+ " -1.60780134e+02, -2.45235488e+02, -1.48860482e+03, 7.43950379e+01,\n",
+ " -3.34775598e+01, 2.26194525e+02, -4.95994232e+01, 3.93951431e+01,\n",
+ " -4.85774377e+02, 2.52694946e+02, 3.78179188e+02, -5.38009201e+02,\n",
+ " -6.77073195e+02, -2.16132871e+02, -9.02103410e+01, -1.39579233e+02,\n",
+ " -5.68728596e+02, 3.54164345e+02, 4.17933784e+02, 3.65128843e+02,\n",
+ " 7.10504011e+02, -7.13660438e+02, -1.95515893e+02, -6.82693278e+01,\n",
+ " 7.19606895e+02, 2.33966257e+02, -2.91438767e+02, 2.13310113e+02,\n",
+ " -4.04302008e+01, -4.69323798e+01, 2.25558720e+02, -1.47377346e+02,\n",
+ " 5.14975226e+01, -6.83498434e+02, 8.97311036e+01, 9.05280226e+01,\n",
+ " -6.01453081e+01, -5.25376670e+01, -3.34957656e+01, 2.44167522e+02,\n",
+ " 4.07360806e+01, -6.29154610e+01, -2.85165709e+01, 1.98980830e+02,\n",
+ " 3.36941256e+02, 3.36992960e+02, 3.61621697e+01, 2.57201878e+02,\n",
+ " 8.94476921e+01, 6.25883246e+01, 1.15355213e-01, -1.60152028e+02,\n",
+ " -7.65209479e+01, 2.77812887e+02, -1.76437507e+02, -4.89874615e+02,\n",
+ " 1.03260434e+02, 1.15217603e+01, -6.86051203e+01, -1.01725767e+02,\n",
+ " -6.73476290e+01, -4.01050336e+01, -3.87783949e+01, 2.98653000e+02,\n",
+ " -1.21809694e+02, -1.55348806e+02, -4.74086689e+01, 7.28132277e-01,\n",
+ " 7.39277402e+02, -1.78882057e+02, -1.39357645e+02, 4.20879411e+02,\n",
+ " 2.47663214e+02, 3.87762849e+02, -3.53553239e+02, 1.53811079e+01,\n",
+ " -5.98409627e+01, 7.90043706e+02, 3.35673586e+02, -5.50396116e+01,\n",
+ " -1.43100045e+02, 1.05812230e+02, -1.80425632e+02, 1.55410674e+01,\n",
+ " -7.65281383e+01, -1.95451382e+01, 7.68200620e+02, -7.88180714e+02,\n",
+ " -4.96891673e+02, 2.67819744e+01, 1.08721925e+03, -8.91722486e+02,\n",
+ " 4.40010959e+02, -1.25558258e+02, 1.03061925e+02, -4.22449953e+02,\n",
+ " 5.44713306e+00, -4.06047163e+01, -1.63793221e+01, -3.13656357e+02,\n",
+ " -6.08353989e+01, 3.54361002e+02, 1.76285933e+01, 1.51923724e+02,\n",
+ " 1.92162101e+01, -2.62252449e+02, 2.02090670e+02, -1.23609033e+02,\n",
+ " -6.88516934e+01, 2.02429496e+02, -1.38312387e+02, -1.60926571e+02,\n",
+ " 4.46518113e+01, 7.15780963e+02, 4.45396724e+01, 5.87099470e+01,\n",
+ " 1.04298963e+02, 3.74076581e+00, 6.79639872e+02, -4.52498268e+02,\n",
+ " -8.28534260e+02, -1.75377723e+02, -3.62640190e+02, 1.35476358e+02,\n",
+ " 1.39786683e+02, -2.30201779e+02, 1.07177174e+02, 9.06074582e+01,\n",
+ " -4.91419617e+02, 1.19215574e+02, -1.36430281e+02, 1.12422664e+02,\n",
+ " -3.91660651e+02, -2.32839084e+02, 1.04064666e+02, -1.65936318e+02,\n",
+ " 5.29146537e-01, -1.22682970e+03, 3.89356937e+01, -4.27925229e+02,\n",
+ " 4.76398283e+02, 1.64435764e+02, -5.29973771e+02, 3.95660885e+02,\n",
+ " -2.12435245e+01, -4.03862861e+02, -2.56973074e+02, 9.68488657e+01,\n",
+ " 1.90932283e+02, -1.37530842e+02, 3.95660795e+01, -4.58660771e+02,\n",
+ " 4.37097085e+02, 2.23017611e+02, 1.22532720e+01, 4.43078364e+01,\n",
+ " 5.40731243e+02, 2.43631329e+02, -1.18018489e+02, 6.16205626e+02,\n",
+ " -3.31632365e+01, 2.02470941e+02, 4.75694668e+02, -1.06257617e+02,\n",
+ " 4.13072327e+01, 4.49723049e+00, -1.26222173e+02, -8.85746700e+02,\n",
+ " 6.88777057e+02, 2.34166799e+02, -1.06900940e+02, -2.41058391e+02,\n",
+ " 8.03518036e+02, 5.92568792e+01, -2.21402522e+01, -1.82033496e+02,\n",
+ " 2.99036012e+02, 6.37655125e+02, -2.80906713e+02, 2.09324200e+02,\n",
+ " 1.83220503e+02, -4.30854094e+02, -5.44780240e+01, -4.48988206e+01,\n",
+ " 6.05954499e+02, -8.21326067e+01, -6.88605527e+01, -1.21427939e+01,\n",
+ " -9.00152897e+02, 3.73341568e+02, 4.02988610e+01, -2.90163692e+01,\n",
+ " -2.33802619e+02, 2.47549209e+02, 1.02884364e+02, -8.05420329e+01,\n",
+ " -5.54586137e+02, -1.99658437e+02, -8.78917263e+02, 9.52892481e+01,\n",
+ " 6.46478189e+02, 1.04216103e+02, 1.69069879e+02, -1.67364855e+02,\n",
+ " 1.03502358e+02, -5.57970384e+01, 1.81731216e+01, 3.41954874e+02,\n",
+ " 1.59680564e+02, -2.82358640e+01, 2.61037844e+02, 4.52394889e+00,\n",
+ " 9.81410028e+01, 2.71677980e+02, -1.05795831e+02, 3.66139152e+02,\n",
+ " 8.95496317e+00, 1.00358159e+01, 2.17295623e+02, -2.31537428e+02,\n",
+ " -2.15978970e+01, -2.79191160e+02, 2.19354022e+02, -6.59585594e+01,\n",
+ " -5.58244609e+01, -5.51836688e+01, -1.12465523e+03, -7.02358334e+01,\n",
+ " 3.10641629e+01, -1.13231174e+02, 3.15678592e+01, 1.96942607e+02,\n",
+ " 5.48713097e+02, -1.40086969e+02, -2.22228352e+01, 1.17398879e+02,\n",
+ " 4.37152834e+02, -6.21798981e+01, -1.64561657e+02, -2.57025128e+02,\n",
+ " 4.11699329e+01, 2.60826779e+02, -1.45360428e+02, -1.61680577e+03,\n",
+ " 2.47927887e+01, -1.41960640e+01, -2.70745623e+02, -2.79026588e+02,\n",
+ " -3.78512694e+02, -1.41108300e+02, -2.24785998e+00, 3.30243719e+02,\n",
+ " -7.03044201e+02, 1.64922389e+02, 7.33348145e+02, 1.47303755e+02,\n",
+ " 7.06221907e+02, -6.94255302e+00, 3.38171017e+02, 5.04749210e+02,\n",
+ " 3.22640657e+01, 1.68166245e+02, -2.14026135e+02, -1.62346444e+02,\n",
+ " -1.46112249e+02, 1.54855691e+02, 3.71393526e+01, -1.04704816e+02,\n",
+ " 2.16268193e+01, 3.38807780e+02, -3.04966547e+02, 2.58482151e+02,\n",
+ " -3.31159399e+00, -7.49718016e-01, -1.62154661e+02, -3.83102076e+01,\n",
+ " 2.42428727e+02, -1.53997929e+02, -3.45879227e+02, -3.07630457e+01,\n",
+ " 3.46320238e+01, 1.00201240e+02, 2.11664303e+02, 5.90214281e+02,\n",
+ " -5.96766676e+02, -8.23427152e+01, 1.63642467e+02, 4.74277645e+02,\n",
+ " -1.70621908e+02, 7.62292684e+01, 3.56639038e+01, -4.60248962e+02,\n",
+ " 2.48614482e+02, 1.21550784e+02, 5.58497353e+01, -5.89683481e+01,\n",
+ " -5.02041594e+02, 1.96406334e+01, -3.06529530e+02, -5.09020350e+01,\n",
+ " -1.23702466e+02, -1.72287679e+02, -3.65591655e+01, -5.50474970e+02,\n",
+ " -6.94567467e+01, -5.34370089e+02, -3.45226236e+00, -1.03854627e+02,\n",
+ " 2.36692364e+02, 7.02179892e+02, -7.61028259e-01, -4.07904227e+02,\n",
+ " -9.75531587e+01, 2.82244405e+02, 1.41124894e+01, 1.09383800e+02,\n",
+ " 5.63751727e+01, -7.11996539e+00, 1.98468044e+02, 2.41003767e+01,\n",
+ " 2.92926186e+01, 1.21397993e+01, -3.89796531e+01, -1.04284892e+01,\n",
+ " -1.79177479e+02, 2.15083225e+02, -3.78225255e+02, 1.70771921e+02,\n",
+ " -2.82337234e+01, 2.67594644e+02, -5.56561383e+02, 5.37376020e+00,\n",
+ " -4.10702199e+01, -4.17935263e+02, -1.11855531e+02, 1.54278095e+01,\n",
+ " -6.38703309e+01, -9.66506324e+01, -2.11503155e+01, -3.61509327e+02,\n",
+ " 2.95453414e+02, -1.31214255e+02, -9.17096183e+02, 1.45599246e+02,\n",
+ " 1.28740325e+02, -5.60369466e+02, 3.07651833e+02, -2.62805720e+02,\n",
+ " 1.76900643e+01, 6.23344322e+01, 4.04255109e+02, 4.35344447e+02,\n",
+ " -2.94036451e+01, 1.90840844e+02, -6.69218772e+02, -1.65055963e+01,\n",
+ " 9.08236248e+01, 8.04144817e+01, -2.34742058e+02, -1.99511017e+00,\n",
+ " 1.78905650e+01, -1.85270017e+01, 8.34858627e+01, -2.71532337e+01,\n",
+ " -1.61634039e+02, -2.18522363e+01, 6.41564793e+01, -1.17988135e+02,\n",
+ " -7.91123583e+01, 3.50110433e+02, 2.85223270e+00, -2.57747131e+01,\n",
+ " 5.71105511e+02, 4.87136833e+02, 6.20880696e+02, 4.07236369e+01,\n",
+ " 1.55305175e+02, 4.07285221e+02, -1.54810757e+02, -1.14964559e+02,\n",
+ " -3.23552775e+01, -2.23286233e+01, 2.14934372e+02, -2.59321590e+02,\n",
+ " 9.60100864e+00, -2.24871393e+01, 4.33309302e+02, 1.23502172e+02,\n",
+ " -3.93325256e+02, 3.66260206e+02, -1.03455591e+02, 2.19379838e+01,\n",
+ " -5.38480193e+02, -7.84929402e+02, 3.65232859e+01, 1.10127050e+02,\n",
+ " 7.26935632e+02, -4.98308610e+02, 4.05822748e+01, -2.23530352e+02,\n",
+ " 1.78072720e+01, -1.37540927e+02, -9.90726995e+01, 1.04018270e+02,\n",
+ " -2.91227335e+02, -8.22906489e+01, -5.01549706e+02, 1.20884921e+02,\n",
+ " 8.48964873e+01, 2.24099340e+02, -8.24509118e+01, 1.77549928e+02,\n",
+ " -1.17315958e+02, -6.76135021e+01, -1.88175420e+02, -1.44918835e+02,\n",
+ " 5.92201253e+02, -1.01987421e+03, 1.09268459e+03, 9.33464461e+00,\n",
+ " 2.01651359e+02, 2.37049175e+01, -6.92801624e+02, -6.83773347e+01,\n",
+ " 2.18194258e+02, 3.00241782e+02, -3.06167625e+02, 1.61619599e+02,\n",
+ " -1.79403848e+03, 1.02841722e+00, 4.01435084e+02, -1.20904004e+03,\n",
+ " -2.60210821e+02, -5.86922855e+01, 1.29339085e+01, 7.17676622e+02,\n",
+ " -7.26463985e+02, 6.19181585e+01, 2.06696166e+02, -1.44014718e+01,\n",
+ " -4.62302557e+01, 3.53429384e+01, 9.51814255e+01, 4.81736997e+01,\n",
+ " 8.03370733e+02, -1.21790083e+02, -5.13789734e+01, -2.13694805e+02,\n",
+ " 2.89897409e+02, 5.45625619e+01, -5.39667776e+01, 1.81784077e+01,\n",
+ " 4.08197700e+01, -2.49524624e+02, -1.78084064e+01, -3.85273563e+02]))"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "import statsmodels.api as sm\n",
+ "\n",
+ "def drdid_rc(y, post, D, covariates, i_weights=None):\n",
+ " # Convert inputs to numpy arrays\n",
+ " D = np.asarray(D).flatten()\n",
+ " n = len(D)\n",
+ " y = np.asarray(y).flatten()\n",
+ " post = np.asarray(post).flatten()\n",
+ " \n",
+ " # Add constant to covariate matrix\n",
+ " if covariates is None:\n",
+ " int_cov = np.ones((n, 1))\n",
+ " else:\n",
+ " covariates = np.asarray(covariates)\n",
+ " if np.all(covariates[:, 0] == 1):\n",
+ " int_cov = covariates\n",
+ " else:\n",
+ " int_cov = np.column_stack((np.ones(n), covariates))\n",
+ " \n",
+ " # Normalize weights\n",
+ " if i_weights is None:\n",
+ " i_weights = np.ones(n)\n",
+ " i_weights = np.asanyarray(i_weights).flatten()\n",
+ " i_weights = i_weights / np.mean(i_weights)\n",
+ " \n",
+ " # Compute the Pscore by MLE\n",
+ " pscore_model = sm.GLM(D, int_cov, family=sm.families.Binomial(), freq_weights=i_weights)\n",
+ " pscore_results = pscore_model.fit()\n",
+ " if not pscore_results.converged:\n",
+ " print(\"Warning: glm algorithm did not converge\")\n",
+ " if np.any(np.isnan(pscore_results.params)):\n",
+ " raise ValueError(\"Propensity score model coefficients have NA components. \\n Multicollinearity (or lack of variation) of covariates is a likely reason.\")\n",
+ " ps_fit = pscore_results.predict()\n",
+ " ps_fit = np.minimum(ps_fit, 1 - 1e-16)\n",
+ " \n",
+ " # Compute the Outcome regression for the control group at the pre-treatment period\n",
+ " mask_cont_pre = (D == 0) & (post == 0)\n",
+ " reg_cont_pre = sm.WLS(y[mask_cont_pre], int_cov[mask_cont_pre], weights=i_weights[mask_cont_pre]).fit()\n",
+ " if np.any(np.isnan(reg_cont_pre.params)):\n",
+ " raise ValueError(\"Outcome regression model coefficients have NA components. \\n Multicollinearity (or lack of variation) of covariates is a likely reason.\")\n",
+ " out_y_cont_pre = np.dot(int_cov, reg_cont_pre.params)\n",
+ " \n",
+ " # Compute the Outcome regression for the control group at the post-treatment period\n",
+ " mask_cont_post = (D == 0) & (post == 1)\n",
+ " reg_cont_post = sm.WLS(y[mask_cont_post], int_cov[mask_cont_post], weights=i_weights[mask_cont_post]).fit()\n",
+ " if np.any(np.isnan(reg_cont_post.params)):\n",
+ " raise ValueError(\"Outcome regression model coefficients have NA components. \\n Multicollinearity (or lack of variation) of covariates is a likely reason.\")\n",
+ " out_y_cont_post = np.dot(int_cov, reg_cont_post.params)\n",
+ " \n",
+ " # Combine the ORs for control group\n",
+ " out_y_cont = post * out_y_cont_post + (1 - post) * out_y_cont_pre\n",
+ " \n",
+ " # Compute the Outcome regression for the treated group at the pre-treatment period\n",
+ " mask_treat_pre = (D == 1) & (post == 0)\n",
+ " reg_treat_pre = sm.WLS(y[mask_treat_pre], int_cov[mask_treat_pre], weights=i_weights[mask_treat_pre]).fit()\n",
+ " out_y_treat_pre = np.dot(int_cov, reg_treat_pre.params)\n",
+ " \n",
+ " # Compute the Outcome regression for the treated group at the post-treatment period\n",
+ " mask_treat_post = (D == 1) & (post == 1)\n",
+ " reg_treat_post = sm.WLS(y[mask_treat_post], int_cov[mask_treat_post], weights=i_weights[mask_treat_post]).fit()\n",
+ " out_y_treat_post = np.dot(int_cov, reg_treat_post.params)\n",
+ " \n",
+ " # Compute weights\n",
+ " w_treat_pre = i_weights * D * (1 - post)\n",
+ " w_treat_post = i_weights * D * post\n",
+ " w_cont_pre = i_weights * ps_fit * (1 - D) * (1 - post) / (1 - ps_fit)\n",
+ " w_cont_post = i_weights * ps_fit * (1 - D) * post / (1 - ps_fit)\n",
+ " w_d = i_weights * D\n",
+ " w_dt1 = i_weights * D * post\n",
+ " w_dt0 = i_weights * D * (1 - post)\n",
+ " \n",
+ " # Elements of the influence function (summands)\n",
+ " eta_treat_pre = w_treat_pre * (y - out_y_cont) / np.mean(w_treat_pre)\n",
+ " eta_treat_post = w_treat_post * (y - out_y_cont) / np.mean(w_treat_post)\n",
+ " eta_cont_pre = w_cont_pre * (y - out_y_cont) / np.mean(w_cont_pre)\n",
+ " eta_cont_post = w_cont_post * (y - out_y_cont) / np.mean(w_cont_post)\n",
+ " \n",
+ " # Extra elements for the locally efficient DRDID\n",
+ " eta_d_post = w_d * (out_y_treat_post - out_y_cont_post) / np.mean(w_d)\n",
+ " eta_dt1_post = w_dt1 * (out_y_treat_post - out_y_cont_post) / np.mean(w_dt1)\n",
+ " eta_d_pre = w_d * (out_y_treat_pre - out_y_cont_pre) / np.mean(w_d)\n",
+ " eta_dt0_pre = w_dt0 * (out_y_treat_pre - out_y_cont_pre) / np.mean(w_dt0)\n",
+ " \n",
+ " # Estimator of each component\n",
+ " att_treat_pre = np.mean(eta_treat_pre)\n",
+ " att_treat_post = np.mean(eta_treat_post)\n",
+ " att_cont_pre = np.mean(eta_cont_pre)\n",
+ " att_cont_post = np.mean(eta_cont_post)\n",
+ " att_d_post = np.mean(eta_d_post)\n",
+ " att_dt1_post = np.mean(eta_dt1_post)\n",
+ " att_d_pre = np.mean(eta_d_pre)\n",
+ " att_dt0_pre = np.mean(eta_dt0_pre)\n",
+ " \n",
+ " # ATT estimator\n",
+ " dr_att = (att_treat_post - att_treat_pre) - (att_cont_post - att_cont_pre) + \\\n",
+ " (att_d_post - att_dt1_post) - (att_d_pre - att_dt0_pre)\n",
+ " \n",
+ " # Compute influence functions\n",
+ " weights_ols_pre = i_weights * (1 - D) * (1 - post)\n",
+ " weights_ols_pre = np.array(weights_ols_pre)\n",
+ " wols_x_pre = weights_ols_pre[:, np.newaxis] * int_cov\n",
+ " wols_eX_pre = weights_ols_pre[:, np.newaxis] * (y - out_y_cont_pre)[:, np.newaxis] * int_cov\n",
+ " XpX_inv_pre = np.linalg.inv(np.dot(wols_x_pre.T, int_cov) / n)\n",
+ " asy_lin_rep_ols_pre = np.dot(wols_eX_pre, XpX_inv_pre)\n",
+ " \n",
+ " weights_ols_post = i_weights * (1 - D) * post\n",
+ " \n",
+ " wols_x_post = weights_ols_post[:, np.newaxis] * int_cov\n",
+ " wols_eX_post = weights_ols_post[:, np.newaxis] * (y - out_y_cont_post)[:, np.newaxis] * int_cov\n",
+ " XpX_inv_post = np.linalg.inv(np.dot(wols_x_post.T, int_cov) / n)\n",
+ " asy_lin_rep_ols_post = np.dot(wols_eX_post, XpX_inv_post)\n",
+ " \n",
+ " weights_ols_pre_treat = i_weights * D * (1 - post)\n",
+ " wols_x_pre_treat = weights_ols_pre_treat[:, np.newaxis] * int_cov\n",
+ " wols_eX_pre_treat = weights_ols_pre_treat[:, np.newaxis] * (y - out_y_treat_pre)[:, np.newaxis] * int_cov\n",
+ " XpX_inv_pre_treat = np.linalg.inv(np.dot(wols_x_pre_treat.T, int_cov) / n)\n",
+ " asy_lin_rep_ols_pre_treat = np.dot(wols_eX_pre_treat, XpX_inv_pre_treat)\n",
+ " \n",
+ " weights_ols_post_treat = i_weights * D * post\n",
+ " wols_x_post_treat = weights_ols_post_treat[:, np.newaxis] * int_cov\n",
+ " wols_eX_post_treat = weights_ols_post_treat[:, np.newaxis] * (y - out_y_treat_post)[:, np.newaxis] * int_cov\n",
+ " XpX_inv_post_treat = np.linalg.inv(np.dot(wols_x_post_treat.T, int_cov) / n)\n",
+ " asy_lin_rep_ols_post_treat = np.dot(wols_eX_post_treat, XpX_inv_post_treat)\n",
+ " \n",
+ " score_ps = i_weights[:, np.newaxis] * (D - ps_fit)[:, np.newaxis] * int_cov\n",
+ " Hessian_ps = pscore_results.cov_params() * n\n",
+ " asy_lin_rep_ps = np.dot(score_ps, Hessian_ps)\n",
+ " \n",
+ " # Influence function components\n",
+ " inf_treat_pre = eta_treat_pre - w_treat_pre * att_treat_pre / np.mean(w_treat_pre)\n",
+ " inf_treat_post = eta_treat_post - w_treat_post * att_treat_post / np.mean(w_treat_post)\n",
+ " \n",
+ " M1_post = -np.mean(w_treat_post[:, np.newaxis] * post[:, np.newaxis] * int_cov, axis=0) / np.mean(w_treat_post)\n",
+ " M1_pre = -np.mean(w_treat_pre[:, np.newaxis] * (1 - post)[:, np.newaxis] * int_cov, axis=0) / np.mean(w_treat_pre)\n",
+ " \n",
+ " inf_treat_or_post = np.dot(asy_lin_rep_ols_post, M1_post)\n",
+ " inf_treat_or_pre = np.dot(asy_lin_rep_ols_pre, M1_pre)\n",
+ " inf_treat_or = inf_treat_or_post + inf_treat_or_pre\n",
+ " \n",
+ " inf_treat = inf_treat_post - inf_treat_pre + inf_treat_or\n",
+ " \n",
+ " inf_cont_pre = eta_cont_pre - w_cont_pre * att_cont_pre / np.mean(w_cont_pre)\n",
+ " inf_cont_post = eta_cont_post - w_cont_post * att_cont_post / np.mean(w_cont_post)\n",
+ " \n",
+ " M2_pre = np.mean(w_cont_pre[:, np.newaxis] * (y - out_y_cont - att_cont_pre)[:, np.newaxis] * int_cov, axis=0) / np.mean(w_cont_pre)\n",
+ " M2_post = np.mean(w_cont_post[:, np.newaxis] * (y - out_y_cont - att_cont_post)[:, np.newaxis] * int_cov, axis=0) / np.mean(w_cont_post)\n",
+ " \n",
+ " inf_cont_ps = np.dot(asy_lin_rep_ps, M2_post - M2_pre)\n",
+ " \n",
+ " M3_post = -np.mean(w_cont_post[:, np.newaxis] * post[:, np.newaxis] * int_cov, axis=0) / np.mean(w_cont_post)\n",
+ " M3_pre = -np.mean(w_cont_pre[:, np.newaxis] * (1 - post)[:, np.newaxis] * int_cov, axis=0) / np.mean(w_cont_pre)\n",
+ " \n",
+ " inf_cont_or_post = np.dot(asy_lin_rep_ols_post, M3_post)\n",
+ " inf_cont_or_pre = np.dot(asy_lin_rep_ols_pre, M3_pre)\n",
+ " inf_cont_or = inf_cont_or_post + inf_cont_or_pre\n",
+ " \n",
+ " inf_cont = inf_cont_post - inf_cont_pre + inf_cont_ps + inf_cont_or\n",
+ " \n",
+ " dr_att_inf_func1 = inf_treat - inf_cont\n",
+ " \n",
+ " inf_eff1 = eta_d_post - w_d * att_d_post / np.mean(w_d)\n",
+ " inf_eff2 = eta_dt1_post - w_dt1 * att_dt1_post / np.mean(w_dt1)\n",
+ " inf_eff3 = eta_d_pre - w_d * att_d_pre / np.mean(w_d)\n",
+ " inf_eff4 = eta_dt0_pre - w_dt0 * att_dt0_pre / np.mean(w_dt0)\n",
+ " inf_eff = (inf_eff1 - inf_eff2) - (inf_eff3 - inf_eff4)\n",
+ " \n",
+ " mom_post = np.mean((w_d[:, np.newaxis] / np.mean(w_d) - w_dt1[:, np.newaxis] / np.mean(w_dt1)) * int_cov, axis=0)\n",
+ " mom_pre = np.mean((w_d[:, np.newaxis] / np.mean(w_d) - w_dt0[:, np.newaxis] / np.mean(w_dt0)) * int_cov, axis=0)\n",
+ " inf_or_post = np.dot(asy_lin_rep_ols_post_treat - asy_lin_rep_ols_post, mom_post)\n",
+ " inf_or_pre = np.dot(asy_lin_rep_ols_pre_treat - asy_lin_rep_ols_pre, mom_pre)\n",
+ " inf_or = inf_or_post - inf_or_pre\n",
+ " \n",
+ " dr_att_inf_func = dr_att_inf_func1 + inf_eff + inf_or\n",
+ " \n",
+ " return dr_att, dr_att_inf_func\n",
+ "\n",
+ "drdid_rc(y, post, d, x, w)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "11.18310945627825\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "(-32.44762929840434,\n",
+ " array([ 4.20279333e+02, 4.12088741e+02, -9.76295144e+01, 4.13877211e+02,\n",
+ " -1.33180878e+02, 3.46372794e+02, 3.33751840e+01, -2.36632190e+02,\n",
+ " -2.87278416e+02, -5.01885265e+01, -2.56919638e+01, -4.36605269e+02,\n",
+ " -5.82711118e+02, -1.74750149e+02, -1.05917795e+02, 3.17576284e+02,\n",
+ " 5.22646939e+02, -3.92097558e+02, -1.74362318e+02, 3.25937355e+02,\n",
+ " 1.15121568e+01, 9.67115269e+01, 2.37900939e+02, -2.31235308e+02,\n",
+ " -1.67406181e+02, 2.49385835e+02, 7.23160188e+02, 1.39005288e+02,\n",
+ " -5.82824544e+01, -2.09472417e+02, -1.67529498e+02, -1.98673017e+02,\n",
+ " -1.10899807e+02, -2.23694179e+01, -8.20674968e+02, 5.35943555e+02,\n",
+ " 6.10833988e+02, 4.23615577e+02, 1.39299854e+02, 2.29206492e+01,\n",
+ " -6.33676385e+02, -3.95707388e+02, 1.89663298e+01, 1.01334683e+02,\n",
+ " 6.97094075e+01, -1.95271048e+02, -3.44512247e+01, -6.72134538e+02,\n",
+ " -2.95689468e+02, 2.80919019e+01, -2.30572474e+02, 3.06118680e+02,\n",
+ " -2.59103267e+02, 8.90688138e+01, -1.17068460e+02, -1.74855945e+02,\n",
+ " -4.21579168e+02, -3.49463513e+02, 3.58548856e+02, 2.74070866e+02,\n",
+ " 8.94619601e+02, 2.72925232e+02, 1.97267771e+02, 6.14555614e+02,\n",
+ " -2.12822803e+02, -7.44529056e+02, -2.50036755e+02, 1.11879526e+02,\n",
+ " -3.64114609e+01, 1.26722740e+01, 2.26052626e+02, 3.52279091e+02,\n",
+ " -1.58602401e+02, -2.14067743e+02, 5.13864039e+02, 5.90513691e+01,\n",
+ " 2.11973267e+01, -2.98048648e+02, 9.06978843e+01, -2.29889520e+02,\n",
+ " -1.04690681e+02, -6.31612943e+01, -8.17161981e+01, 5.40279596e+02,\n",
+ " -7.61069251e+02, -1.26693439e+02, 1.08962854e+03, 4.88141074e+01,\n",
+ " -1.79310459e+02, -1.91501515e+02, -3.82569248e+00, 5.77276974e+01,\n",
+ " -3.82459487e+02, -1.47957299e+02, -2.42341055e+02, -3.11245840e+02,\n",
+ " 8.91754223e+00, -3.57126608e+02, -1.50323548e+02, -8.92144522e+02,\n",
+ " 3.07419995e+01, 1.01969269e+02, -1.04012348e+01, -2.97688581e+02,\n",
+ " 5.50136302e+02, -1.30448299e+02, 1.41973622e+02, 1.12012366e+01,\n",
+ " 7.30720715e+02, -1.02790672e+01, 3.55718639e+01, -2.16537343e+01,\n",
+ " -3.71137958e+02, 7.26427434e+00, 9.35600864e+01, -1.25298182e+02,\n",
+ " 1.08647879e+02, -7.50123676e+01, 4.18921461e+01, -5.15391747e+01,\n",
+ " 5.38422193e+02, 1.37600027e+02, 2.20954413e+02, -3.04370282e+01,\n",
+ " 9.07984789e+01, -4.98343301e+01, 4.58732880e+02, -1.92490096e+02,\n",
+ " -9.38128679e+01, -1.15051818e+02, 1.31099359e+02, 2.87583768e+02,\n",
+ " 3.07189365e+01, 3.53247809e+00, -8.18715843e+01, -2.04526732e+02,\n",
+ " 2.32926404e+02, -2.36087460e+02, -1.37678998e+02, 1.78550004e+00,\n",
+ " 4.87878523e+02, 3.71738576e+02, -8.84080383e+02, -6.24049505e+00,\n",
+ " 8.29056945e+01, 4.57399586e+01, 2.97905782e+02, 5.91714970e+02,\n",
+ " 4.48105143e+01, -6.92797977e+01, 2.55354734e+02, -7.89225035e+01,\n",
+ " 1.43838620e+02, -3.93631793e+02, 4.76133717e+02, 9.16183723e+02,\n",
+ " -1.44540113e+02, -2.57889488e+02, -3.95310404e+02, 4.25819137e+01,\n",
+ " 4.38843307e+01, 9.34049278e+01, -4.17071476e+01, -8.38207367e+01,\n",
+ " 5.60322689e+01, 1.06349442e+01, 3.09546182e+02, 3.48801424e+02,\n",
+ " -1.05768649e+02, 7.43660156e+01, -2.19060274e+02, 1.99052090e+02,\n",
+ " -1.17751964e+02, 1.02684457e+02, 4.49686979e+02, 3.90078176e+01,\n",
+ " 3.22412138e+02, -1.38490567e+03, -1.48277066e+02, 2.18563981e+01,\n",
+ " 6.36586969e+02, 1.66128829e+01, 3.98819578e+01, 5.43968673e+02,\n",
+ " 1.70087616e+01, -1.12945500e+03, 8.16595074e+01, 2.92711225e+00,\n",
+ " 6.03544081e+01, -1.14685923e+01, 2.77633051e+02, 6.06500732e+02,\n",
+ " 4.56211154e+02, 1.75456319e+02, -1.12916110e+03, 1.31566141e+02,\n",
+ " 1.84878012e+01, 8.49839785e+02, 1.75981779e+02, 1.43690437e+02,\n",
+ " -3.04227043e+01, 2.46340357e+02, 2.13050478e+02, -1.62891702e+02,\n",
+ " -1.37749443e+02, 3.04928038e+02, -2.08060148e+02, 3.08937749e+02,\n",
+ " 3.24723186e+02, -1.98816285e+02, 7.95471735e+01, 1.90358746e+02,\n",
+ " 1.83465649e+02, -5.38050094e+02, -1.12880154e+02, -3.62295270e+02,\n",
+ " 1.62922507e+01, -1.85170208e+02, 9.86646641e+01, 2.98685240e+01,\n",
+ " 2.01606241e+02, -7.98828254e+00, -3.61456958e+02, 1.11270776e+03,\n",
+ " 2.66621054e+01, 3.01496066e+02, -1.25142347e+03, 5.28971079e+02,\n",
+ " -2.52826572e+02, 8.86681627e+01, 6.21972418e+01, -4.72960721e+01,\n",
+ " -6.85708127e+02, -2.30109786e-01, 7.36786880e+01, -6.28705485e+02,\n",
+ " 1.19764430e+02, 1.06302137e+02, -4.01034981e+01, -2.61794027e+01,\n",
+ " 3.89238246e+02, -5.44269320e+01, 4.80295667e+02, -5.80919478e+01,\n",
+ " 6.61172014e+01, 1.28778609e+02, -1.82567742e+02, 1.57025548e+02,\n",
+ " 3.59861570e+01, -2.75698382e+00, 2.38948770e+02, 5.33809115e+02,\n",
+ " 1.85202542e+02, 7.37137754e+01, -2.74456697e+01, -6.41486558e+01,\n",
+ " 1.64255000e+02, -2.83741759e+02, -3.24654203e+01, -5.25302612e+01,\n",
+ " -2.75642049e+02, 1.07046194e+02, 5.51808554e+00, -1.22072567e+03,\n",
+ " -1.73135777e+02, 5.47295474e+01, 9.73685499e-01, -6.55737638e+02,\n",
+ " 1.64032793e+02, -1.38059072e+02, 5.82306220e+02, -5.71054927e+02,\n",
+ " -2.20443680e+02, -6.59930511e+02, 9.73798237e+01, 4.41430648e+01,\n",
+ " -2.63824326e+02, 2.80282041e+02, -1.60419255e+02, 8.76880647e+01,\n",
+ " -5.71382813e+02, 4.80383968e+01, 7.52916372e+02, 4.18394491e+01,\n",
+ " 3.69554916e+01, 3.22772613e+02, 7.54799751e+02, -1.52723120e+02,\n",
+ " 7.21342420e+00, -3.11212384e+01, 9.21140746e+01, 7.29970195e+01,\n",
+ " -3.02675395e+02, -3.06709950e+01, -1.92026723e+02, 3.28174587e+02,\n",
+ " -1.85300133e+01, 2.86472860e+02, 1.89005851e+02, 2.10833505e+01,\n",
+ " -3.76490155e+01, -1.04587893e+02, -3.08966353e+02, -4.19383505e+02,\n",
+ " -9.62418025e+01, -6.79089974e+01, 4.15009663e+01, 5.80745707e+02,\n",
+ " -2.20347325e+02, -1.77342315e+02, 1.04110341e+03, -2.16501200e+02,\n",
+ " 6.78717380e+01, -7.46661995e+01, -1.39958107e+02, 5.77063258e+01,\n",
+ " 2.45207141e+02, 1.94375988e+02, -9.38592562e+01, -1.06004405e+02,\n",
+ " 6.08026697e+01, 7.18244505e+01, 1.44416539e+02, 3.50710007e+00,\n",
+ " -8.44465217e+01, 2.04115121e+01, 2.58869463e+02, -7.55866500e+01,\n",
+ " -3.34409317e+02, -1.49632569e+02, -1.94744007e+01, -3.08896588e+02,\n",
+ " 3.78444598e+00, 8.54656724e+01, 7.52081951e+01, -1.37559812e+01,\n",
+ " -1.79173346e+02, -1.26028340e+03, -2.27194696e+01, 5.23688677e+01,\n",
+ " -1.66735091e+01, 4.85622033e+02, 3.10564652e+02, -1.58757232e+01,\n",
+ " -3.31621347e+02, -2.95531899e+02, 1.81221712e+02, 5.44508260e+02,\n",
+ " 4.27005201e+02, -5.30307615e+01, -5.31205211e+01, -4.85459228e+00,\n",
+ " 2.04815006e+01, 3.31166265e+01, -4.76947792e+02, 1.58542450e+02,\n",
+ " -2.13128664e+02, 9.27735792e+01, 9.66676771e+01, 6.14797866e+02,\n",
+ " -7.97564986e+01, 1.96958210e+03, 1.00142737e+02, -7.69715849e+01,\n",
+ " 2.77540953e+02, 1.26623564e+01, 3.40581064e+02, 4.55149595e+01,\n",
+ " -8.45653748e+01, 4.80985850e+02, -4.05184099e+01, 8.89421433e+01,\n",
+ " 1.27895179e+02, 1.11252558e+02, -7.57096911e+02, 2.44395909e+02,\n",
+ " 3.72071963e+02, -5.65043484e+01, -5.79231700e+01, -3.43390489e+02,\n",
+ " 4.39757382e+02, -6.23852608e+02, 5.01436050e+02, -3.15742453e+02,\n",
+ " 3.15395399e+00, 2.18964255e+02, 4.35718769e+02, 2.29079288e+00,\n",
+ " 3.90011337e+02, -5.54336886e+01, -1.85429839e+02, 1.13745778e+03,\n",
+ " 3.29118076e+02, -8.54123978e+02, -1.51382280e+02, -3.97047777e+00,\n",
+ " -1.06355206e+02, 7.71569957e+01, -1.56360634e+02, 5.34456838e+01,\n",
+ " 1.12573895e+02, 2.04149311e+02, 7.59839845e+01, 5.80588764e+01,\n",
+ " -4.66659536e+01, 3.34146533e+02, -4.81443136e+02, 3.77281573e+02,\n",
+ " 1.34721590e+02, -1.44255644e+01, -6.46960548e+02, 2.64408950e+02,\n",
+ " -3.47214931e+00, 1.19992507e+03, -2.11893490e+02, -1.10071695e+02,\n",
+ " -7.01521184e+01, -2.94239248e+02, 1.83645643e+02, -1.53451490e+02,\n",
+ " -1.22644517e+03, 2.26048089e-01, -1.42655169e+02, 2.79850492e+01,\n",
+ " 7.66303443e+01, -9.47729029e+01, 1.97191524e+01, 4.86517029e+02,\n",
+ " 2.23495596e+02, -3.73850230e+02, 2.63377040e+02, 2.09272457e+02,\n",
+ " -5.88289035e+01, -4.84050079e+01, 2.70740960e+02, -7.99068709e+02,\n",
+ " -3.66878057e+02, -7.63973385e+02, -4.13292537e+02, 1.98086211e+02,\n",
+ " -1.03536004e+02, 1.23697088e+01, -4.17482670e+01, 2.95590024e+02,\n",
+ " -1.69574644e+02, 3.79833701e+02, 1.34128663e+01, 1.84745002e+02,\n",
+ " -1.01469190e+02, -3.08288157e+02, 3.65354380e+02, -2.14962852e+02,\n",
+ " -8.16359137e+02, -4.53838509e+00, 3.52296414e+02, -1.96579018e+02,\n",
+ " -1.72003357e+02, 1.62263555e+02, -9.15909341e+01, -5.25450352e+02,\n",
+ " 8.28896229e+01, 3.64925822e+02, -4.15238930e+02, -3.00551889e+01,\n",
+ " 1.73624960e+02, -6.73581968e+01, 8.56118071e+01, 2.44054679e+02,\n",
+ " 2.65423530e+02, -7.84664065e+01, -5.63917200e+02, -2.26535357e+02,\n",
+ " -3.39408860e+02, 7.67487880e+02, 5.25857778e+02, -2.07862652e+02,\n",
+ " -7.90111584e+01, -6.42927035e+01, -1.50921781e+02, -2.48296982e+02,\n",
+ " 1.40104644e+01, -1.13081160e+02, -4.03496516e+02, -1.17400980e+02,\n",
+ " 2.21279937e+01, 7.98486666e+02, -5.69456778e+02, -2.41940064e+01,\n",
+ " 1.52367607e+01, -4.39050709e+02, -8.43244926e+01, -8.38253452e+02,\n",
+ " -1.77774862e+02, 4.31605119e+01, -1.65602382e+02, -7.68777733e+00,\n",
+ " 7.61283887e+02, -6.00505098e+01, 2.62589767e+01, 2.63702108e+02,\n",
+ " -6.06462876e+00, 5.90112292e+02, 9.99913222e+01, -2.17212078e+02,\n",
+ " -9.88708978e+01, 6.52856038e+01, -2.81793258e+02, -1.11179922e+02,\n",
+ " -5.42097873e+01, -3.90954654e+01, 3.74594954e+02, 2.72618163e+02,\n",
+ " 3.61725770e+00, 9.37238289e+02, -1.86511741e+02, -2.62908249e+02,\n",
+ " -6.39473433e+01, -3.42312827e+01, -2.28387976e+02, 1.08426708e+03,\n",
+ " -4.22935998e+02, 2.08515221e+01, -2.56221345e+02, 8.32881054e+01,\n",
+ " 3.97289413e+01, -2.33463296e+02, 3.40873141e+02, 3.04041387e+01,\n",
+ " -6.00880346e+01, -1.22311002e+02, -6.29413687e+01, 3.92335543e+02,\n",
+ " 1.57091140e+02, 6.03117352e+02, -1.93660362e+02, -1.76073525e+02,\n",
+ " -1.64792918e+02, 1.02192153e+01, -3.95728672e+01, 2.25835832e+02,\n",
+ " 9.13616106e+00, 3.55868569e+02, 3.96693384e+02, -1.90870573e+01,\n",
+ " -1.60780134e+02, -2.45235488e+02, -1.48860482e+03, 7.43950379e+01,\n",
+ " -3.34775598e+01, 2.26194525e+02, -4.95994232e+01, 3.93951431e+01,\n",
+ " -4.85774377e+02, 2.52694946e+02, 3.78179188e+02, -5.38009201e+02,\n",
+ " -6.77073195e+02, -2.16132871e+02, -9.02103410e+01, -1.39579233e+02,\n",
+ " -5.68728596e+02, 3.54164345e+02, 4.17933784e+02, 3.65128843e+02,\n",
+ " 7.10504011e+02, -7.13660438e+02, -1.95515893e+02, -6.82693278e+01,\n",
+ " 7.19606895e+02, 2.33966257e+02, -2.91438767e+02, 2.13310113e+02,\n",
+ " -4.04302008e+01, -4.69323798e+01, 2.25558720e+02, -1.47377346e+02,\n",
+ " 5.14975226e+01, -6.83498434e+02, 8.97311036e+01, 9.05280226e+01,\n",
+ " -6.01453081e+01, -5.25376670e+01, -3.34957656e+01, 2.44167522e+02,\n",
+ " 4.07360806e+01, -6.29154610e+01, -2.85165709e+01, 1.98980830e+02,\n",
+ " 3.36941256e+02, 3.36992960e+02, 3.61621697e+01, 2.57201878e+02,\n",
+ " 8.94476921e+01, 6.25883246e+01, 1.15355213e-01, -1.60152028e+02,\n",
+ " -7.65209479e+01, 2.77812887e+02, -1.76437507e+02, -4.89874615e+02,\n",
+ " 1.03260434e+02, 1.15217603e+01, -6.86051203e+01, -1.01725767e+02,\n",
+ " -6.73476290e+01, -4.01050336e+01, -3.87783949e+01, 2.98653000e+02,\n",
+ " -1.21809694e+02, -1.55348806e+02, -4.74086689e+01, 7.28132277e-01,\n",
+ " 7.39277402e+02, -1.78882057e+02, -1.39357645e+02, 4.20879411e+02,\n",
+ " 2.47663214e+02, 3.87762849e+02, -3.53553239e+02, 1.53811079e+01,\n",
+ " -5.98409627e+01, 7.90043706e+02, 3.35673586e+02, -5.50396116e+01,\n",
+ " -1.43100045e+02, 1.05812230e+02, -1.80425632e+02, 1.55410674e+01,\n",
+ " -7.65281383e+01, -1.95451382e+01, 7.68200620e+02, -7.88180714e+02,\n",
+ " -4.96891673e+02, 2.67819744e+01, 1.08721925e+03, -8.91722486e+02,\n",
+ " 4.40010959e+02, -1.25558258e+02, 1.03061925e+02, -4.22449953e+02,\n",
+ " 5.44713306e+00, -4.06047163e+01, -1.63793221e+01, -3.13656357e+02,\n",
+ " -6.08353989e+01, 3.54361002e+02, 1.76285933e+01, 1.51923724e+02,\n",
+ " 1.92162101e+01, -2.62252449e+02, 2.02090670e+02, -1.23609033e+02,\n",
+ " -6.88516934e+01, 2.02429496e+02, -1.38312387e+02, -1.60926571e+02,\n",
+ " 4.46518113e+01, 7.15780963e+02, 4.45396724e+01, 5.87099470e+01,\n",
+ " 1.04298963e+02, 3.74076581e+00, 6.79639872e+02, -4.52498268e+02,\n",
+ " -8.28534260e+02, -1.75377723e+02, -3.62640190e+02, 1.35476358e+02,\n",
+ " 1.39786683e+02, -2.30201779e+02, 1.07177174e+02, 9.06074582e+01,\n",
+ " -4.91419617e+02, 1.19215574e+02, -1.36430281e+02, 1.12422664e+02,\n",
+ " -3.91660651e+02, -2.32839084e+02, 1.04064666e+02, -1.65936318e+02,\n",
+ " 5.29146537e-01, -1.22682970e+03, 3.89356937e+01, -4.27925229e+02,\n",
+ " 4.76398283e+02, 1.64435764e+02, -5.29973771e+02, 3.95660885e+02,\n",
+ " -2.12435245e+01, -4.03862861e+02, -2.56973074e+02, 9.68488657e+01,\n",
+ " 1.90932283e+02, -1.37530842e+02, 3.95660795e+01, -4.58660771e+02,\n",
+ " 4.37097085e+02, 2.23017611e+02, 1.22532720e+01, 4.43078364e+01,\n",
+ " 5.40731243e+02, 2.43631329e+02, -1.18018489e+02, 6.16205626e+02,\n",
+ " -3.31632365e+01, 2.02470941e+02, 4.75694668e+02, -1.06257617e+02,\n",
+ " 4.13072327e+01, 4.49723049e+00, -1.26222173e+02, -8.85746700e+02,\n",
+ " 6.88777057e+02, 2.34166799e+02, -1.06900940e+02, -2.41058391e+02,\n",
+ " 8.03518036e+02, 5.92568792e+01, -2.21402522e+01, -1.82033496e+02,\n",
+ " 2.99036012e+02, 6.37655125e+02, -2.80906713e+02, 2.09324200e+02,\n",
+ " 1.83220503e+02, -4.30854094e+02, -5.44780240e+01, -4.48988206e+01,\n",
+ " 6.05954499e+02, -8.21326067e+01, -6.88605527e+01, -1.21427939e+01,\n",
+ " -9.00152897e+02, 3.73341568e+02, 4.02988610e+01, -2.90163692e+01,\n",
+ " -2.33802619e+02, 2.47549209e+02, 1.02884364e+02, -8.05420329e+01,\n",
+ " -5.54586137e+02, -1.99658437e+02, -8.78917263e+02, 9.52892481e+01,\n",
+ " 6.46478189e+02, 1.04216103e+02, 1.69069879e+02, -1.67364855e+02,\n",
+ " 1.03502358e+02, -5.57970384e+01, 1.81731216e+01, 3.41954874e+02,\n",
+ " 1.59680564e+02, -2.82358640e+01, 2.61037844e+02, 4.52394889e+00,\n",
+ " 9.81410028e+01, 2.71677980e+02, -1.05795831e+02, 3.66139152e+02,\n",
+ " 8.95496317e+00, 1.00358159e+01, 2.17295623e+02, -2.31537428e+02,\n",
+ " -2.15978970e+01, -2.79191160e+02, 2.19354022e+02, -6.59585594e+01,\n",
+ " -5.58244609e+01, -5.51836688e+01, -1.12465523e+03, -7.02358334e+01,\n",
+ " 3.10641629e+01, -1.13231174e+02, 3.15678592e+01, 1.96942607e+02,\n",
+ " 5.48713097e+02, -1.40086969e+02, -2.22228352e+01, 1.17398879e+02,\n",
+ " 4.37152834e+02, -6.21798981e+01, -1.64561657e+02, -2.57025128e+02,\n",
+ " 4.11699329e+01, 2.60826779e+02, -1.45360428e+02, -1.61680577e+03,\n",
+ " 2.47927887e+01, -1.41960640e+01, -2.70745623e+02, -2.79026588e+02,\n",
+ " -3.78512694e+02, -1.41108300e+02, -2.24785998e+00, 3.30243719e+02,\n",
+ " -7.03044201e+02, 1.64922389e+02, 7.33348145e+02, 1.47303755e+02,\n",
+ " 7.06221907e+02, -6.94255302e+00, 3.38171017e+02, 5.04749210e+02,\n",
+ " 3.22640657e+01, 1.68166245e+02, -2.14026135e+02, -1.62346444e+02,\n",
+ " -1.46112249e+02, 1.54855691e+02, 3.71393526e+01, -1.04704816e+02,\n",
+ " 2.16268193e+01, 3.38807780e+02, -3.04966547e+02, 2.58482151e+02,\n",
+ " -3.31159399e+00, -7.49718016e-01, -1.62154661e+02, -3.83102076e+01,\n",
+ " 2.42428727e+02, -1.53997929e+02, -3.45879227e+02, -3.07630457e+01,\n",
+ " 3.46320238e+01, 1.00201240e+02, 2.11664303e+02, 5.90214281e+02,\n",
+ " -5.96766676e+02, -8.23427152e+01, 1.63642467e+02, 4.74277645e+02,\n",
+ " -1.70621908e+02, 7.62292684e+01, 3.56639038e+01, -4.60248962e+02,\n",
+ " 2.48614482e+02, 1.21550784e+02, 5.58497353e+01, -5.89683481e+01,\n",
+ " -5.02041594e+02, 1.96406334e+01, -3.06529530e+02, -5.09020350e+01,\n",
+ " -1.23702466e+02, -1.72287679e+02, -3.65591655e+01, -5.50474970e+02,\n",
+ " -6.94567467e+01, -5.34370089e+02, -3.45226236e+00, -1.03854627e+02,\n",
+ " 2.36692364e+02, 7.02179892e+02, -7.61028259e-01, -4.07904227e+02,\n",
+ " -9.75531587e+01, 2.82244405e+02, 1.41124894e+01, 1.09383800e+02,\n",
+ " 5.63751727e+01, -7.11996539e+00, 1.98468044e+02, 2.41003767e+01,\n",
+ " 2.92926186e+01, 1.21397993e+01, -3.89796531e+01, -1.04284892e+01,\n",
+ " -1.79177479e+02, 2.15083225e+02, -3.78225255e+02, 1.70771921e+02,\n",
+ " -2.82337234e+01, 2.67594644e+02, -5.56561383e+02, 5.37376020e+00,\n",
+ " -4.10702199e+01, -4.17935263e+02, -1.11855531e+02, 1.54278095e+01,\n",
+ " -6.38703309e+01, -9.66506324e+01, -2.11503155e+01, -3.61509327e+02,\n",
+ " 2.95453414e+02, -1.31214255e+02, -9.17096183e+02, 1.45599246e+02,\n",
+ " 1.28740325e+02, -5.60369466e+02, 3.07651833e+02, -2.62805720e+02,\n",
+ " 1.76900643e+01, 6.23344322e+01, 4.04255109e+02, 4.35344447e+02,\n",
+ " -2.94036451e+01, 1.90840844e+02, -6.69218772e+02, -1.65055963e+01,\n",
+ " 9.08236248e+01, 8.04144817e+01, -2.34742058e+02, -1.99511017e+00,\n",
+ " 1.78905650e+01, -1.85270017e+01, 8.34858627e+01, -2.71532337e+01,\n",
+ " -1.61634039e+02, -2.18522363e+01, 6.41564793e+01, -1.17988135e+02,\n",
+ " -7.91123583e+01, 3.50110433e+02, 2.85223270e+00, -2.57747131e+01,\n",
+ " 5.71105511e+02, 4.87136833e+02, 6.20880696e+02, 4.07236369e+01,\n",
+ " 1.55305175e+02, 4.07285221e+02, -1.54810757e+02, -1.14964559e+02,\n",
+ " -3.23552775e+01, -2.23286233e+01, 2.14934372e+02, -2.59321590e+02,\n",
+ " 9.60100864e+00, -2.24871393e+01, 4.33309302e+02, 1.23502172e+02,\n",
+ " -3.93325256e+02, 3.66260206e+02, -1.03455591e+02, 2.19379838e+01,\n",
+ " -5.38480193e+02, -7.84929402e+02, 3.65232859e+01, 1.10127050e+02,\n",
+ " 7.26935632e+02, -4.98308610e+02, 4.05822748e+01, -2.23530352e+02,\n",
+ " 1.78072720e+01, -1.37540927e+02, -9.90726995e+01, 1.04018270e+02,\n",
+ " -2.91227335e+02, -8.22906489e+01, -5.01549706e+02, 1.20884921e+02,\n",
+ " 8.48964873e+01, 2.24099340e+02, -8.24509118e+01, 1.77549928e+02,\n",
+ " -1.17315958e+02, -6.76135021e+01, -1.88175420e+02, -1.44918835e+02,\n",
+ " 5.92201253e+02, -1.01987421e+03, 1.09268459e+03, 9.33464461e+00,\n",
+ " 2.01651359e+02, 2.37049175e+01, -6.92801624e+02, -6.83773347e+01,\n",
+ " 2.18194258e+02, 3.00241782e+02, -3.06167625e+02, 1.61619599e+02,\n",
+ " -1.79403848e+03, 1.02841722e+00, 4.01435084e+02, -1.20904004e+03,\n",
+ " -2.60210821e+02, -5.86922855e+01, 1.29339085e+01, 7.17676622e+02,\n",
+ " -7.26463985e+02, 6.19181585e+01, 2.06696166e+02, -1.44014718e+01,\n",
+ " -4.62302557e+01, 3.53429384e+01, 9.51814255e+01, 4.81736997e+01,\n",
+ " 8.03370733e+02, -1.21790083e+02, -5.13789734e+01, -2.13694805e+02,\n",
+ " 2.89897409e+02, 5.45625619e+01, -5.39667776e+01, 1.81784077e+01,\n",
+ " 4.08197700e+01, -2.49524624e+02, -1.78084064e+01, -3.85273563e+02]))"
+ ]
+ },
+ "execution_count": 50,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def std_ipw_did_rc(y, post, D, covariates, i_weights = None):\n",
+ " D = np.asarray(D).flatten()\n",
+ " y = np.asarray(y).flatten()\n",
+ " post = np.asarray(post).flatten()\n",
+ " n = len(D)\n",
+ " if covariates is None:\n",
+ " int_cov = np.ones((n, 1))\n",
+ " else:\n",
+ " covariates = np.asarray(covariates)\n",
+ " if np.all(covariates[:, 0] == 1):\n",
+ " int_cov = covariates\n",
+ " else:\n",
+ " int_cov = np.column_stack((np.ones(n), covariates))\n",
+ " \n",
+ " # Pesos\n",
+ " if i_weights is None:\n",
+ " i_weights = np.ones(n)\n",
+ " else:\n",
+ " i_weights = np.asarray(i_weights)\n",
+ " if np.min(i_weights) < 0:\n",
+ " raise ValueError(\"i_weights must be non-negative\")\n",
+ " \n",
+ " # Normalizar pesos\n",
+ " i_weights = np.asarray(i_weights).flatten()\n",
+ " i_weights = i_weights / np.mean(i_weights)\n",
+ "\n",
+ " pscore_model = sm.GLM(D, int_cov, family=sm.families.Binomial(), freq_weights=i_weights)\n",
+ " pscore_results = pscore_model.fit()\n",
+ " if not pscore_results.converged:\n",
+ " print(\"Warning: glm algorithm did not converge\")\n",
+ " if np.any(np.isnan(pscore_results.params)):\n",
+ " raise ValueError(\"Propensity score model coefficients have NA components. \\n Multicollinearity (or lack of variation) of covariates is a likely reason.\")\n",
+ " ps_fit = pscore_results.predict()\n",
+ " ps_fit = np.minimum(ps_fit, 1 - 1e-16)\n",
+ "\n",
+ "\n",
+ " w_treat_pre = i_weights * D * (1 - post)\n",
+ " w_treat_post = i_weights * D * post\n",
+ " # print(np.mean(w_treat_pre))\n",
+ "\n",
+ " w_cont_pre = i_weights * ps_fit * (1 - D) * (1 - post)/(1 - ps_fit)\n",
+ " w_cont_post = i_weights * ps_fit * (1 - D) * post/(1 - ps_fit)\n",
+ "\n",
+ " # Elements of the influence function (summands)\n",
+ " eta_treat_pre = w_treat_pre * y / np.mean(w_treat_pre)\n",
+ " eta_treat_post = w_treat_post * y / np.mean(w_treat_post)\n",
+ " # print(eta_treat_pre)\n",
+ "\n",
+ " eta_cont_pre = w_cont_pre * y / np.mean(w_cont_pre)\n",
+ " eta_cont_post = w_cont_post * y / np.mean(w_cont_post)\n",
+ "\n",
+ " # Estimator of each component\n",
+ " att_treat_pre = np.mean(eta_treat_pre)\n",
+ " att_treat_post = np.mean(eta_treat_post)\n",
+ " att_cont_pre = np.mean(eta_cont_pre)\n",
+ " att_cont_post = np.mean(eta_cont_post)\n",
+ " ipw_att = (att_treat_post - att_treat_pre) - (att_cont_post - att_cont_pre)\n",
+ "\n",
+ " score_ps = (i_weights * (D - ps_fit))[:, np.newaxis] * int_cov\n",
+ " Hessian_ps = pscore_results.cov_params() * n\n",
+ " asy_lyn_rep_ps = np.dot(score_ps, Hessian_ps)\n",
+ "\n",
+ " inf_treat_pre = eta_treat_pre - w_treat_pre * att_treat_pre/np.mean(w_treat_pre)\n",
+ " inf_treat_post = eta_treat_post - w_treat_post * att_treat_post/np.mean(w_treat_post)\n",
+ " inf_treat = inf_treat_post - inf_treat_pre\n",
+ " # Now, get the influence function of control component\n",
+ " # Leading term of the influence function: no estimation effect\n",
+ " inf_cont_pre = eta_cont_pre - w_cont_pre * att_cont_pre/np.mean(w_cont_pre)\n",
+ " inf_cont_post = eta_cont_post - w_cont_post * att_cont_post/np.mean(w_cont_post)\n",
+ " inf_cont = inf_cont_post - inf_cont_pre\n",
+ "\n",
+ " # Estimation effect from gamma hat (pscore)\n",
+ " # Derivative matrix (k x 1 vector)\n",
+ " \n",
+ " M2_pre = np.mean((w_cont_pre *(y - att_cont_pre))[:, np.newaxis] * int_cov, axis = 0)/np.mean(w_cont_pre)\n",
+ " M2_post = np.mean((w_cont_post *(y - att_cont_post))[:, np.newaxis] * int_cov, axis = 0)/np.mean(w_cont_post)\n",
+ "\n",
+ " # Now the influence function related to estimation effect of pscores\n",
+ " M2 = M2_post - M2_pre\n",
+ " # print()\n",
+ "\n",
+ " inf_cont_ps = np.dot(asy_lyn_rep_ps, M2)\n",
+ "\n",
+ " # Influence function for the control component\n",
+ " inf_cont = inf_cont + inf_cont_ps\n",
+ "\n",
+ " #get the influence function of the DR estimator (put all pieces together)\n",
+ " att_inf_func = inf_treat - inf_cont\n",
+ " # print(np.std(att_inf_func) / np.sqrt(n))\n",
+ " return ipw_att, att_inf_func\n",
+ "\n",
+ "std_ipw_did_rc(y, post, d, x, w)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "58.43262881229079\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "(-93.44964719410004,\n",
+ " array([-1771.69299088, 2734.43642941, -665.22558666, -1301.21607242,\n",
+ " -996.11260525, -2158.69871885, -504.89793626, 2423.86068757,\n",
+ " 1877.09030231, -719.10776276, 1235.01984081, -3215.35303448,\n",
+ " 2386.43646889, -2175.74096025, -1086.57851202, 2126.26135247,\n",
+ " -1623.59474401, 1273.22968698, 894.95899003, 1285.88717946,\n",
+ " 2013.14752053, -1315.44015778, -1026.47833569, -2502.18129446,\n",
+ " 1564.16066321, -1200.80571295, 3687.54887082, -2585.69577723,\n",
+ " 1737.14743007, 2156.46916393, -1131.66288093, 1072.00424587,\n",
+ " 356.97367802, 351.36368266, 1595.18126677, -1323.98704545,\n",
+ " -1255.93969133, -1506.35891117, -1584.95043603, 479.97949402,\n",
+ " -3282.15157163, 710.78455687, -1417.11138741, 2501.57210618,\n",
+ " 961.07317252, -1683.69307958, 296.03996423, 879.46613051,\n",
+ " 2340.35863086, -1611.71256146, 1097.07919423, -1940.81707908,\n",
+ " 480.70920993, -2786.26114977, -1162.92145886, -1472.05666335,\n",
+ " -2044.25184954, 2286.32069912, 1583.0849025 , -2028.31435029,\n",
+ " 3387.80665047, 1110.71952821, -2057.90124649, -1406.70400881,\n",
+ " -2678.62125054, -3218.86415708, 652.36081771, 3045.04627911,\n",
+ " -779.92134306, -918.74720028, -1008.58184304, -1638.01020249,\n",
+ " 1044.4985155 , -2854.04327601, 3081.17614328, -745.83804107,\n",
+ " 396.63514827, 2072.74250936, -2722.50028153, 2746.04710334,\n",
+ " 1801.34758228, -532.35606628, 741.95434407, -1596.78037149,\n",
+ " -3440.17039739, -1382.50573776, 3897.58714348, -363.61264523,\n",
+ " 2661.17983552, 1127.77692574, -1421.84132741, 1423.90313571,\n",
+ " -1279.7410341 , -839.98810695, 1002.53754141, -1717.11890703,\n",
+ " 2595.86686095, -2817.55589014, -1252.75607073, -3241.93280724,\n",
+ " 2923.11415872, -506.47139114, 637.97958752, 567.53280982,\n",
+ " -1910.88512632, -1943.88181593, -1723.88632985, -889.4782608 ,\n",
+ " 3307.81074752, -430.22712697, -2175.07180988, -2366.77239088,\n",
+ " 1071.63637084, -2592.81371937, -1204.87458143, -761.22689489,\n",
+ " 955.06351376, -1032.88264276, 2702.83072888, 2740.5690063 ,\n",
+ " -1551.13859133, -471.81781504, 860.49229409, 1899.83981132,\n",
+ " 876.9599662 , 1328.49935293, -1910.36968312, 2249.22471819,\n",
+ " -1017.36949851, -1714.11930645, 1132.73999958, 2888.07614525,\n",
+ " -659.80356625, 1173.16566146, 595.16853259, -2558.92960279,\n",
+ " -2264.23696816, 519.2876513 , -1916.37016546, 3118.98680212,\n",
+ " -2034.85921318, 2095.69321434, -3617.08419665, -2631.39553407,\n",
+ " -2571.03485392, -910.25542094, 1913.24207773, 3007.01570665,\n",
+ " 3025.61889462, -2151.74511183, 2767.1282508 , 1406.43983444,\n",
+ " 2156.09969918, 1994.23513548, 3000.74347992, 3206.93715676,\n",
+ " -1432.32320602, 1352.73592906, 1081.38155669, -2177.61249715,\n",
+ " -675.7825555 , -1084.11343382, 671.41465633, 1144.15683623,\n",
+ " -848.69247153, 386.14175933, 1450.88910053, 2569.72607252,\n",
+ " -1742.31181667, 603.72965801, -1113.60962826, -2370.21161651,\n",
+ " 2305.00989528, -342.88359848, 2745.99316412, 1234.36827675,\n",
+ " -2204.85554135, -3649.96418942, 2644.48638617, 1640.7989497 ,\n",
+ " -1303.29901877, -963.13849624, 650.26714886, 2344.49205591,\n",
+ " -2444.03887982, -3490.54710953, 1657.83573886, -429.76691451,\n",
+ " 971.43999219, 1444.16142726, -1332.6761587 , -753.80278181,\n",
+ " -1295.66288667, -2203.06274595, -3739.74584698, 2908.85465725,\n",
+ " 788.80966951, -846.9578562 , -2006.21858656, 1062.26460843,\n",
+ " 1308.97334956, 2622.57838299, 2281.59544952, -2610.41236121,\n",
+ " -2570.43192466, 2085.31576822, -1089.99962694, 3009.50651827,\n",
+ " 3190.08242902, 764.85346096, 2166.44286327, -1314.31850717,\n",
+ " 1529.8660305 , 533.55919085, 1116.50987114, -1798.17159756,\n",
+ " -867.24220256, -2044.24229115, -2265.82185981, 714.85212421,\n",
+ " 643.65564431, 2969.32169063, -1467.93924844, 3203.34567165,\n",
+ " -799.35535782, -766.01252567, -2851.68746347, 2830.6823963 ,\n",
+ " 1489.70089497, -1690.40329366, 1916.25499695, 1560.59373446,\n",
+ " 464.75321798, -407.90706297, 1968.50451478, 2158.57889193,\n",
+ " -1245.62717245, -2712.67424608, -1575.87355314, 609.06412613,\n",
+ " -1759.56301462, -1539.59640257, 2102.14988929, -2585.69716551,\n",
+ " 1948.35691064, -959.7909776 , -1303.17153457, 724.79954286,\n",
+ " -914.26648645, -1740.50412275, 808.26769145, 1393.42103493,\n",
+ " 1454.41915472, 1512.98477357, 1630.14522196, 256.1792697 ,\n",
+ " 1724.28939298, 2475.32212828, 451.65470165, -907.42403718,\n",
+ " -2189.70112717, 460.74715837, -421.22513225, 308.73990714,\n",
+ " -2213.25438291, -603.96167028, -734.84133319, 1243.6661329 ,\n",
+ " -1420.68978323, -1876.64977892, -1801.50602087, -2145.50279721,\n",
+ " -2555.98350273, -3278.33626392, 395.77242611, 1913.10276 ,\n",
+ " -1680.82368859, 2655.53100648, -1527.69643643, -1477.53090956,\n",
+ " 1821.37533877, 1504.82500691, -1299.47404806, 672.45750055,\n",
+ " -365.56147886, 1140.0531407 , -1758.25206157, 2356.32694942,\n",
+ " -424.46646898, -1204.27183857, 1081.15991819, -1956.55465717,\n",
+ " 450.51448263, -2615.50361028, -1092.3959512 , 2793.68079598,\n",
+ " -1053.32122609, -1517.06659907, 905.11156725, 2643.29640911,\n",
+ " 1495.81294626, 1085.28964795, -2485.92006447, -2173.97112266,\n",
+ " -2381.42989767, 1258.42796446, -1332.37806862, 2765.94944066,\n",
+ " 2088.2365377 , -3003.55317999, 4173.01276016, -1678.95094716,\n",
+ " -455.41604475, -1318.60508958, 2403.99022211, 504.96662006,\n",
+ " -462.95551358, 1660.26835734, 1114.11366235, 983.13432758,\n",
+ " -1161.18648989, 588.48849558, -429.87369026, 341.12429041,\n",
+ " 2662.61161797, -536.42169105, 759.12786241, 868.05783143,\n",
+ " -2778.53688011, -2884.55161881, 553.62567528, -2018.53788345,\n",
+ " -2858.37482622, -1433.53267886, -1778.64792077, 485.02888305,\n",
+ " 1698.3260984 , -3575.30266263, -2403.34176964, 339.80723682,\n",
+ " 586.40604317, 3130.8978188 , 651.03726464, -580.79436868,\n",
+ " 2353.15871734, -2445.08321516, 3082.38448848, -362.89805554,\n",
+ " 2875.3473652 , 2204.04535845, -688.42084612, -1864.67857861,\n",
+ " 2900.2178106 , 378.46948172, 937.00118321, 3194.60239319,\n",
+ " -1032.75332662, -536.05481574, 2405.77212303, 2985.79823855,\n",
+ " -784.06420919, 4609.19316311, -2189.82760429, 1214.18891704,\n",
+ " -812.09432391, 1172.85932583, -1766.4023341 , -900.96381475,\n",
+ " 2242.57313862, 2355.53766688, 575.65560024, -1565.24322515,\n",
+ " -1331.95499438, -1304.9072576 , 1506.03805617, 2376.65417998,\n",
+ " 2703.48351205, 442.53250644, 2003.14669993, 1271.05057831,\n",
+ " -1538.75926454, 856.16519565, 3624.06062374, -1764.8704627 ,\n",
+ " 847.16068531, 2340.86345312, 3199.28740498, 1475.26770119,\n",
+ " -905.48513765, 948.98156662, -822.15967202, -1153.1708414 ,\n",
+ " 1541.01197327, -3311.16047334, -2661.21217243, -977.42835906,\n",
+ " -1443.49749562, -1887.01249207, 1845.96233941, -1179.54920403,\n",
+ " -375.01127251, -2568.90591656, 351.56782753, -2224.62551787,\n",
+ " -598.64519739, 1545.03596115, -2542.72902914, -1599.94376529,\n",
+ " -2052.16500216, -603.26139478, 2463.15994025, -1625.09633424,\n",
+ " 333.32346256, -1561.75623817, 2156.56043675, 2510.56220988,\n",
+ " -693.89422315, 2601.16518784, 606.82485323, -1983.401254 ,\n",
+ " 1754.43848215, -410.28645049, 1836.57321914, 1047.35474469,\n",
+ " 2385.83891047, 1092.15603067, 507.46176433, 3062.09623731,\n",
+ " -545.54371933, -2334.39724187, 2410.80683371, -1352.66570528,\n",
+ " 2542.46775923, 524.69553963, -765.339492 , -2263.89864301,\n",
+ " -1710.19982826, -3646.41440575, 1664.13224425, -1885.57754739,\n",
+ " -1046.96997217, 587.09563654, 2669.69295403, -767.73946223,\n",
+ " 2161.55583907, 2287.47135902, 2369.25282143, 1667.64698903,\n",
+ " 1559.33170716, -1178.39071442, 1650.65992341, -799.17153424,\n",
+ " -2868.57228524, 444.47063642, 2803.79276211, 1895.16914585,\n",
+ " 696.29703029, 2566.55376603, 495.34810817, -2186.26279686,\n",
+ " 677.99602616, -1816.09040711, -2416.70597798, -2553.72783302,\n",
+ " -478.28438083, 754.45182451, -2748.81453694, -639.64827681,\n",
+ " -2535.42889686, 1177.98353196, -3068.71399971, -1019.61515784,\n",
+ " -3048.25211509, -1609.76461417, -1888.30365064, -2762.19248673,\n",
+ " -2528.53702823, -1186.43064416, -865.59633802, -1957.89339654,\n",
+ " -442.05069064, -899.91738041, 1495.42879268, -1804.98548029,\n",
+ " 2446.11646989, 2657.68154663, 1043.36466273, 552.90529896,\n",
+ " 311.32975222, -2455.00385902, -508.36435874, 2139.19610847,\n",
+ " -647.88162837, -1202.79276093, 1697.10789673, 981.00194261,\n",
+ " 3170.69024745, -2628.37497646, 356.01933459, 1634.03648656,\n",
+ " 2412.19222403, -719.76688947, -2312.19500813, 2744.36935991,\n",
+ " -735.77553181, -2274.10892643, 2565.18991986, 2532.44181051,\n",
+ " 684.50392772, -1410.45291576, -1994.33913686, -1446.32324674,\n",
+ " 2822.83702553, 3937.06593083, 2468.07405753, 1791.03567507,\n",
+ " -756.08749761, -2199.33347498, -1676.60694465, 3491.89175595,\n",
+ " -3105.77111045, 970.48961188, -946.67778364, 2920.33278331,\n",
+ " -1159.77126982, 1504.4459745 , -1354.11130131, -580.52625076,\n",
+ " -1222.95118144, 2708.57543379, 2202.56498598, 1622.58859998,\n",
+ " -1983.03234192, 3329.45612628, -1864.86168765, 1378.55737431,\n",
+ " 2041.72272001, -826.98003232, 1577.02595384, -2289.64597967,\n",
+ " -856.11886911, 2562.60117477, -2144.81715063, 353.35405689,\n",
+ " 908.56860009, 1987.47285836, 1502.37986964, -697.1416503 ,\n",
+ " -1734.38123668, -1823.57289635, 294.70877546, -2388.92607047,\n",
+ " 1017.25631052, -1140.19323544, 2656.20026254, -2829.71657145,\n",
+ " 642.54135073, -1011.70165571, -1647.27548212, -2308.71174696,\n",
+ " 2575.02774246, -866.04851779, -1719.62779586, -1874.43270201,\n",
+ " 3681.34145332, -3322.72703815, -3024.31679437, -2532.43180244,\n",
+ " 2737.34291185, -660.74945542, 2183.84062155, -2658.1067039 ,\n",
+ " -2596.11057461, 809.02154594, -650.23456129, -1624.33639548,\n",
+ " -1164.74135733, 1499.08242271, -1013.25201658, -1060.74435259,\n",
+ " 750.91869435, -1150.31342707, -1473.68830892, -1088.46282605,\n",
+ " -2365.35503201, 1111.76443617, 1315.14836662, 1863.34049266,\n",
+ " -1899.20814734, -988.24238211, -879.23713785, 1316.9636844 ,\n",
+ " -2290.80759974, 569.66592613, -2092.63661774, 1442.33823625,\n",
+ " -1576.19856555, 2617.11949261, 2030.79638168, 2369.54069801,\n",
+ " -1847.84465731, -741.1120735 , 323.20649614, -651.82527562,\n",
+ " 728.76217588, -1269.07621354, -607.5497075 , 1416.8947446 ,\n",
+ " 454.80499906, -2505.18504699, 530.65238647, 830.56199117,\n",
+ " 3433.65677962, 1425.0187603 , 714.52859263, -1132.43794541,\n",
+ " 2300.08751858, 1826.27724093, 1267.35917873, -441.30955772,\n",
+ " -475.7591027 , 2861.20592043, -1484.1117347 , 1913.40673523,\n",
+ " -814.54259857, 469.18855574, 891.96081434, 377.0791031 ,\n",
+ " -1035.84329488, 906.99092104, 2486.31144075, -3295.88774417,\n",
+ " 2555.93528014, -494.01730112, -1108.2264043 , 1048.84097862,\n",
+ " 2913.52968468, -1624.32995923, -1796.52527911, -1372.06527286,\n",
+ " -765.25621949, -722.7471685 , 569.36631321, 2710.84957177,\n",
+ " -1458.16249395, 1716.51616268, -550.09465093, 2837.75069768,\n",
+ " 2016.72399953, 2257.73249603, -1708.8209968 , 849.37799004,\n",
+ " -2831.08801977, -2351.49913077, 1996.3935787 , -1693.96751042,\n",
+ " 794.03089047, -1673.57803373, -741.35710937, -1951.26893057,\n",
+ " 1561.07075703, 441.3302186 , -1567.98078095, 950.83826894,\n",
+ " 1661.30145139, 921.12211371, 1882.1189721 , -1625.38348328,\n",
+ " 653.50050604, 2473.2760197 , 2231.9105723 , 2783.81258029,\n",
+ " 2414.88965068, 1835.90500441, -1662.06288499, -1519.51805278,\n",
+ " 2741.25734525, -2277.9679907 , -2060.33342696, -1824.66604277,\n",
+ " 302.3272455 , 1215.04485318, 1168.224856 , 915.47045394,\n",
+ " -2122.63277088, -1410.46872314, -3336.77744929, -1745.75890027,\n",
+ " 860.07675975, 1258.10506568, 2577.31652841, 1681.81791276,\n",
+ " 1309.25233734, -783.43695593, -448.60634624, 1996.51678873,\n",
+ " 3078.01361844, 2265.95704372, 2528.44053527, 559.93178238,\n",
+ " 2944.28585989, 1826.43705531, 957.49292267, -1469.44176626,\n",
+ " -2003.65534759, -1040.86656379, 3115.46640107, 527.03633508,\n",
+ " 781.00070477, 487.05912006, -513.64148896, -3396.40469512,\n",
+ " -1856.83752144, -2108.12042935, 2923.28653429, -3006.52285973,\n",
+ " -1212.78161963, -928.93614961, 404.872869 , -762.16508895,\n",
+ " -1833.6419319 , -1919.10471841, -2897.34765784, 3361.32067031,\n",
+ " -1955.86951684, 1734.33682574, 994.0640385 , 2386.79849447,\n",
+ " -1735.24480106, 2597.82838399, -788.81322366, 370.18836641,\n",
+ " -3327.20076265, -1798.21335112, -2824.14666615, 680.54087364,\n",
+ " -1113.6554537 , 1884.28686177, -510.1376147 , -2470.79456479,\n",
+ " -2927.71232689, 1710.53465483, -2856.8622993 , -652.02321833,\n",
+ " -1802.82632925, -1283.31201584, 941.3521787 , -1344.1623153 ,\n",
+ " -1675.74071669, -975.09704843, -434.70286598, 3513.47256683,\n",
+ " 2777.66740579, -441.26195892, -2169.81461542, -564.7000461 ,\n",
+ " -494.71483343, -1860.05799051, 2307.80025783, -1756.46255846,\n",
+ " 642.91109776, 1232.7316785 , 2530.94424384, -1571.77413697,\n",
+ " -540.85685014, -1593.04834694, 2573.78401449, 311.21633873,\n",
+ " 1501.80362327, -2079.77918752, -2436.46543536, -1401.19459548,\n",
+ " 453.63997673, -617.65832347, 919.36365701, -1473.90999317,\n",
+ " -1284.48864767, -2021.09052361, 1532.70052602, 963.77131646,\n",
+ " 931.49932067, 921.17721886, 1265.23297379, -2248.82552704,\n",
+ " 919.54082419, -1359.44096363, 1853.1752116 , 1363.02067761,\n",
+ " -770.62764697, 1270.73401122, -2595.04839828, 1596.44932557,\n",
+ " 1464.85098141, -1474.04261825, 417.41885847, 1469.5597139 ,\n",
+ " -2985.64231407, 1964.54960175, -1791.98505515, -527.52890739,\n",
+ " -1162.83120919, 427.60263726, -1242.71304596, 2575.90846261,\n",
+ " -785.79583731, 700.98451479, 659.91797646, 2602.51687481,\n",
+ " -1832.62997604, -582.64374442, 2072.16886831, -1606.82356069,\n",
+ " -2155.54339974, -1622.96028635, -752.5546088 , 3164.16272958,\n",
+ " -2873.74267056, -436.74210099, 2829.62246934, -1319.40937026,\n",
+ " -767.4190147 , -1726.35412953, 1314.02278506, -518.05271912,\n",
+ " 956.73527718, 2780.08228437, 1067.59408967, -1347.81798416,\n",
+ " 2488.54528479, 2388.17361517, -1932.92084631, 2392.93029926,\n",
+ " 691.98121942, 2863.45441941, 1290.97173989, -2213.26160583,\n",
+ " 1042.88699367, -1679.1695782 , -1821.95672579, 1945.0379768 ,\n",
+ " 1335.76129379, -1446.96325048, 1208.82534726, 2325.9144496 ,\n",
+ " 1426.56423789, 2560.07670029, -1130.71735917, 2147.1279272 ,\n",
+ " -1885.66796193, -3171.54041311, -515.97146345, -1934.95957315,\n",
+ " -1130.18405187, 1814.85040605, 396.42445474, 1153.22336642,\n",
+ " -2387.38838234, -984.2983658 , -1656.79194671, 512.97694184,\n",
+ " -464.40722209, -670.85793403, -1380.10152566, -1075.59505428,\n",
+ " 1020.43896491, -696.74128739, -1981.59398368, -564.85351604,\n",
+ " 471.65862672, 2045.94510625, 2561.04321739, -2476.73595212,\n",
+ " -2184.53211076, -1590.7768613 , 1056.19643327, 326.89026411,\n",
+ " -786.49596962, -1803.26291599, -1788.53823174, -2569.97154594,\n",
+ " -1157.77515717, 2762.93260682, -2643.43033439, -2121.08092166,\n",
+ " 1839.50106707, -1524.93779041, 1217.75295729, 1849.85201945,\n",
+ " 2445.65124093, -2809.46516162, -1000.19327694, -1200.04118698,\n",
+ " -1395.01561181, 1451.33231139, -2237.96197975, 1664.89527685,\n",
+ " 372.72464219, 1406.94334946, -2601.80452742, -1131.10906965,\n",
+ " -413.59031804, -430.98571375, 2428.44444729, -577.42581008,\n",
+ " 460.88030944, -399.57537384, 361.63566071, -739.53978394,\n",
+ " 2068.4774553 , 317.50373164, 453.76941584, -2051.34309943,\n",
+ " 2501.20300174, -1872.30264744, -605.33685932, -895.12840218,\n",
+ " 3457.74512372, -874.26408381, -2000.5332622 , -375.61529654,\n",
+ " -1771.732175 , 3513.39079202, 1898.53613931, -2123.43982302,\n",
+ " -1855.55908987, 403.26111243, -2186.35008812, -2010.10870557,\n",
+ " 1670.73175174, -682.83932791, -1824.05052817, 1315.67758341,\n",
+ " 2238.63825971, -754.07248481, -2128.8740116 , -1117.82814274,\n",
+ " 1275.4669288 , -2534.56242268, 346.23040278, -778.57175669,\n",
+ " 2884.451503 , 2067.40660077, -2539.8165362 , 1673.97197473,\n",
+ " -520.57617721, 2652.97964586, 2304.44022309, -526.46684596,\n",
+ " -1309.85679365, 2055.42385252, 2041.40731586, -1610.16123777,\n",
+ " -992.10449941, 2816.87554037, 1081.38695232, 1547.77978543,\n",
+ " -1557.21338002, 641.54991226, -1471.45433787, 1901.0064566 ,\n",
+ " 3213.54494144, 1919.02613766, 4075.31988024, -449.24630456,\n",
+ " 2442.58946249, -2637.33989739, 510.31545751, -719.86175251,\n",
+ " -1165.06095127, -2464.13179681, 2520.23301323, -1140.28681108,\n",
+ " -4411.26836428, -553.01310133, -1211.9205527 , -3114.2194403 ,\n",
+ " 2030.91966949, 1425.94939427, 1351.07431513, -2025.98523337,\n",
+ " -3237.20378818, -2167.40583612, 3159.90648033, -1769.30458385,\n",
+ " -831.58913516, -1495.04530792, -1899.34493108, -1122.85293137,\n",
+ " 2281.7356933 , -1364.98996535, -2494.84612274, 1148.39416309,\n",
+ " 2718.9386729 , 1476.66330731, 541.5016123 , 394.50840119,\n",
+ " 2058.86333601, -2794.09787517, 1051.44981086, 1157.3380883 ]))"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "import statsmodels.api as sm\n",
+ "\n",
+ "def ipw_did_rc(y, post, D, covariates=None, i_weights=None):\n",
+ " # D, y, post como arrays numpy\n",
+ " D = np.asarray(D).flatten()\n",
+ " y = np.asarray(y).flatten()\n",
+ " post = np.asarray(post).flatten()\n",
+ " n = len(D)\n",
+ " \n",
+ " # Agregar constante a covariates si existe\n",
+ " if covariates is None:\n",
+ " int_cov = np.ones((n, 1))\n",
+ " else:\n",
+ " covariates = np.asarray(covariates)\n",
+ " if np.all(covariates[:, 0] == 1):\n",
+ " int_cov = covariates\n",
+ " else:\n",
+ " int_cov = np.column_stack((np.ones(n), covariates))\n",
+ " \n",
+ " # Pesos\n",
+ " if i_weights is None:\n",
+ " i_weights = np.ones(n)\n",
+ " else:\n",
+ " i_weights = np.asarray(i_weights)\n",
+ " if np.min(i_weights) < 0:\n",
+ " raise ValueError(\"i_weights must be non-negative\")\n",
+ " \n",
+ " # Normalizar pesos\n",
+ " i_weights = np.asarray(i_weights).flatten()\n",
+ " i_weights = i_weights / np.mean(i_weights)\n",
+ " \n",
+ " # Estimación de Propensity Score (logit)\n",
+ " pscore_model = sm.GLM(D, int_cov, family=sm.families.Binomial(), freq_weights=i_weights)\n",
+ " pscore_results = pscore_model.fit()\n",
+ " if not pscore_results.converged:\n",
+ " print(\"Warning: glm algorithm did not converge\")\n",
+ " if np.any(np.isnan(pscore_results.params)):\n",
+ " raise ValueError(\"Propensity score model coefficients have NA components. \\n Multicollinearity (or lack of variation) of covariates is a likely reason.\")\n",
+ " ps_fit = pscore_results.predict()\n",
+ " ps_fit = np.minimum(ps_fit, 1 - 1e-16)\n",
+ " \n",
+ " # Calcular IPW estimator\n",
+ " w_treat_pre = i_weights * D * (1 - post)\n",
+ " w_treat_post = i_weights * D * post\n",
+ " w_cont_pre = i_weights * ps_fit * (1 - D) * (1 - post) / (1 - ps_fit)\n",
+ " w_cont_post = i_weights * ps_fit * (1 - D) * post / (1 - ps_fit)\n",
+ " \n",
+ " Pi_hat = np.mean(i_weights * D)\n",
+ " lambda_hat = np.mean(i_weights * post)\n",
+ " one_minus_lambda_hat = np.mean(i_weights * (1 - post))\n",
+ " \n",
+ " eta_treat_pre = w_treat_pre * y / (Pi_hat * one_minus_lambda_hat)\n",
+ " eta_treat_post = w_treat_post * y / (Pi_hat * lambda_hat)\n",
+ " eta_cont_pre = w_cont_pre * y / (Pi_hat * one_minus_lambda_hat)\n",
+ " eta_cont_post = w_cont_post * y / (Pi_hat * lambda_hat)\n",
+ " \n",
+ " att_treat_pre = np.mean(eta_treat_pre)\n",
+ " att_treat_post = np.mean(eta_treat_post)\n",
+ " att_cont_pre = np.mean(eta_cont_pre)\n",
+ " att_cont_post = np.mean(eta_cont_post)\n",
+ " \n",
+ " ipw_att = (att_treat_post - att_treat_pre) - (att_cont_post - att_cont_pre)\n",
+ " \n",
+ " # Cálculo de la función de influencia\n",
+ " score_ps = i_weights[:, np.newaxis] * (D - ps_fit)[:, np.newaxis] * int_cov\n",
+ " Hessian_ps = pscore_results.cov_params() * n\n",
+ " asy_lin_rep_ps = np.dot(score_ps, Hessian_ps)\n",
+ " \n",
+ " inf_treat_post = (eta_treat_post - att_treat_post) - \\\n",
+ " (i_weights * D - Pi_hat) * att_treat_post / Pi_hat - \\\n",
+ " (i_weights * post - lambda_hat) * att_treat_post / lambda_hat\n",
+ " \n",
+ " inf_treat_pre = (eta_treat_pre - att_treat_pre) - \\\n",
+ " (i_weights * D - Pi_hat) * att_treat_pre / Pi_hat - \\\n",
+ " (i_weights * (1 - post) - one_minus_lambda_hat) * att_treat_pre / one_minus_lambda_hat\n",
+ " \n",
+ " inf_cont_post = (eta_cont_post - att_cont_post) - \\\n",
+ " (i_weights * D - Pi_hat) * att_cont_post / Pi_hat - \\\n",
+ " (i_weights * post - lambda_hat) * att_cont_post / lambda_hat\n",
+ " \n",
+ " inf_cont_pre = (eta_cont_pre - att_cont_pre) - \\\n",
+ " (i_weights * D - Pi_hat) * att_cont_pre / Pi_hat - \\\n",
+ " (i_weights * (1 - post) - one_minus_lambda_hat) * att_cont_pre / one_minus_lambda_hat\n",
+ " \n",
+ " mom_logit_pre = -eta_cont_pre[:, np.newaxis] * int_cov\n",
+ " mom_logit_pre = np.mean(mom_logit_pre, axis=0)\n",
+ " \n",
+ " mom_logit_post = -eta_cont_post[:, np.newaxis] * int_cov\n",
+ " mom_logit_post = np.mean(mom_logit_post, axis=0)\n",
+ " \n",
+ " inf_logit = asy_lin_rep_ps @ (mom_logit_post - mom_logit_pre)\n",
+ " \n",
+ " att_inf_func = (inf_treat_post - inf_treat_pre) - (inf_cont_post - inf_cont_pre) + inf_logit\n",
+ " print(np.std(att_inf_func) / np.sqrt(n))\n",
+ " return ipw_att, att_inf_func\n",
+ "ipw_did_rc(y, post, d, x, w)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0 0.683004\n",
+ "1 0.658181\n",
+ "2 0.151920\n",
+ "3 0.528905\n",
+ "4 0.246862\n",
+ " ... \n",
+ "995 0.127359\n",
+ "996 0.606165\n",
+ "997 0.856390\n",
+ "998 0.333366\n",
+ "999 0.438967\n",
+ "Name: i_w, Length: 1000, dtype: float64"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "w"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}