# Reproducing SurvITE: Learning Heterogeneous Treatment Effects from Time-to-Event Data
Authored by: Alexander Schnapp, Eduard Ma, Deepali Prabhu
---
<br>
Machine learning has shown great success in forecasting future outcomes from historical data across many domains. To predict future events, it is important to understand both which factors in the historical data to use and what exactly the prediction target is. Predicting whether an event occurs can be straightforward, with past occurrences used to produce a single future estimate, but such an approach may not be robust enough for real-world use. When measuring the effect of an intervention on a population, for example, it is important to account for patient-specific characteristics, censoring and covariate shift in order to manage risk effectively. Survival analysis, also called time-to-event analysis, is a family of techniques designed for exactly this setting.

<center> <small> <i>source: https://infocenter.informationbuilders.com/wf80/index.jsp?topic=%2Fpubdocs%2FRStat16%2Fsource%2Ftopic45.htm</i></small> </center>
Survival analysis can be tricky because of the nature of time-to-event data and the presence of censoring and covariate shift. Covariate shift arises mainly from three sources: non-randomized treatment assignment (confounding), informative censoring and event-induced covariate shift. Established methods such as the Cox proportional hazards model [3], Random Survival Forests [4] and DeepHit [5] do not explicitly account for these factors. SurvITE [1] is a deep learning model that estimates treatment-specific hazard functions and counters covariate shift by enforcing balanced representations.
The overall architecture of SurvITE is shown below. In contrast to its predecessors, SurvITE consists of two modules. The first is a fully connected neural network that learns a balanced representation of the input by penalizing the discrepancy between the baseline distribution and each at-risk, treatment-specific distribution, which tackles all three sources of shift simultaneously. The second is also a fully connected neural network and estimates the final hazard function. It has a separate output head for each combination of treatment and time step, and all heads use the same input representation to predict the hazard.
<br>

<center> <i>Architecture of SurvITE [1]</i></center>
<br>
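To make the two-module structure concrete, here is a minimal NumPy sketch of a forward pass; it is not the authors' TensorFlow implementation. It assumes a single ReLU layer for the representation and one sigmoid output per (treatment, time step) pair, and it uses a simple linear-kernel MMD (squared distance between mean representations) as a stand-in for the discrepancy penalty. Treating the whole sample as the baseline distribution is also an assumption, and the names `SurvITESketch`, `linear_mmd` and `balance_penalty` are hypothetical.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


class SurvITESketch:
    """Hypothetical two-module network: shared representation + per-(arm, time) hazard heads."""

    def __init__(self, d_in, d_rep, n_times, seed=0):
        rng = np.random.default_rng(seed)
        # Module 1: representation network (a single ReLU layer here, for brevity)
        self.W_rep = rng.normal(scale=0.1, size=(d_in, d_rep))
        # Module 2: one output head per (treatment arm, time step) pair
        self.W_head = rng.normal(scale=0.1, size=(2, n_times, d_rep))
        self.b_head = np.zeros((2, n_times))

    def represent(self, X):
        return relu(X @ self.W_rep)  # shape (n, d_rep)

    def hazards(self, X):
        phi = self.represent(X)
        # every head reads the same representation phi
        logits = np.einsum("nr,atr->nat", phi, self.W_head) + self.b_head
        return sigmoid(logits)  # shape (n, 2, n_times), values in (0, 1)


def linear_mmd(phi_a, phi_b):
    """Squared distance between mean representations (linear-kernel MMD)."""
    return float(np.sum((phi_a.mean(axis=0) - phi_b.mean(axis=0)) ** 2))


def balance_penalty(phi, a, time, n_times):
    """Sum of discrepancies between the baseline (all samples) and each
    at-risk, treatment-specific subset (arm `arm`, still at risk at step t)."""
    total = 0.0
    for arm in (0, 1):
        for t in range(1, n_times + 1):
            at_risk = (a == arm) & (time >= t)
            if at_risk.any():
                total += linear_mmd(phi, phi[at_risk])
    return total
```

In training, a penalty of this kind would typically be weighted and added to the hazard likelihood loss; that loss is omitted from the sketch.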
In this blog post, we aim to reproduce the results presented in the SurvITE paper and assess its scope and utility. To this end, we set three main objectives:
1. Reproduce the **synthetic data generation process** described in the paper, which lets us generate separate time-to-event datasets with different combinations of censoring and covariate shift.
2. Fix and upgrade the existing code base released by the SurvITE authors so that it runs successfully, and then use it to **evaluate the model** for robustness to various censoring effects and covariate shifts.
3. **Refactor the code base** to reduce redundancy and improve readability.
## Reproduction of the synthetic data generation process
To gauge the efficacy of the hazard predictors, the paper presents a synthetic data generation process that emulates scenarios involving all three types of covariate shift described above. Unlike real-world data, synthetic data makes it possible to isolate each type of shift and measure its effect on the performance of the underlying model. The SurvITE paper defines a known covariate distribution and ground-truth hazard functions, which together determine the data used to train and test the models.
$$ X \sim \mathcal{N}(0,\Sigma), \quad \text{where } \Sigma = (1-\rho)I+\rho\mathbf{1}\mathbf{1}^T $$
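A minimal NumPy sketch of this sampling step is shown below. It assumes $d = 10$ covariates (the dimension is not stated here), and `sample_covariates` is a hypothetical helper name.

```python
import numpy as np


def sample_covariates(n, d=10, rho=0.2, seed=None):
    """Draw n rows of X ~ N(0, Sigma) with Sigma = (1 - rho) I + rho 11^T."""
    sigma = (1.0 - rho) * np.eye(d) + rho * np.ones((d, d))
    rng = np.random.default_rng(seed)
    return rng.multivariate_normal(mean=np.zeros(d), cov=sigma, size=n)


X = sample_covariates(5000, seed=0)  # shape (5000, 10)
```

The diagonal of $\Sigma$ equals $(1-\rho)+\rho = 1$, so every covariate has unit variance and every pair of covariates has correlation $\rho$.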
The data-generating process combines two counting processes: $N(t)$ for the occurrence of the event of interest and $C(t)$ for the occurrence of censoring. The observed time is whichever of the two occurs first, i.e. the minimum of the event time and the censoring time.
The hazard underlying $N(t)$ is:
$$
\lambda^a(t \mid x) = \begin{cases}
0.1\,\sigma\left(-5x_1^2 - a\left(\mathbf{1}\{x_3>0\}+0.5\right)\right) & \text{if } t<0 \\
0.1\,\sigma\left(10x_2 - a\left(\mathbf{1}\{x_3>0\}+0.5\right)\right) & \text{otherwise.}
\end{cases}
$$
$$ \text{where } a \sim \mathrm{Bern}\left(\sigma\left(\xi \sum_{p\in P} x_p\right)\right) $$
The hazard underlying $C(t)$ is:
$$ \lambda_C(t|x) = 0.01\sigma(10x_4^2) $$
Here, $\rho$ is set to 0.2. $\xi$ is the selection strength: a higher value makes treatment assignment depend more strongly on the covariates and therefore introduces more covariate shift through non-randomized treatment. The set $P$ determines whether the covariates that drive treatment assignment overlap with the event-inducing covariates. Namely, $x_1$ and $x_2$ affect the event hazard directly and $x_4$ affects the censoring hazard, $x_3$ only modulates the treatment effect, and the remaining covariates enter neither hazard.
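The hazards and the treatment assignment can be sketched in NumPy as below. The helper names are hypothetical, covariate $x_j$ is stored in column $j-1$, and the set $P$ is passed as a list of column indices. The propensity is assumed to be $\sigma(\xi \sum_{p \in P} x_p)$, so that it remains a valid probability for values such as $\xi = 3$ (with $\xi = 0$ it gives a 50/50 randomized assignment). The change point between the two branches of $\lambda^a$ is passed in as `t_switch` rather than hard-coded.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def assign_treatment(X, xi, P, rng):
    """a ~ Bern(sigma(xi * sum_{p in P} x_p)); P holds 0-based column indices."""
    p = sigmoid(xi * X[:, P].sum(axis=1))
    return rng.binomial(1, p)


def event_hazard(X, a, t, t_switch):
    """Discrete-time event hazard lambda^a(t|x) for every row of X at step t."""
    effect = a * ((X[:, 2] > 0).astype(float) + 0.5)  # x_3 lives in column 2
    if t < t_switch:
        return 0.1 * sigmoid(-5.0 * X[:, 0] ** 2 - effect)  # x_1 in column 0
    return 0.1 * sigmoid(10.0 * X[:, 1] - effect)          # x_2 in column 1


def censoring_hazard(X):
    """Informative censoring hazard lambda_C(t|x) = 0.01 * sigma(10 x_4^2)."""
    return 0.01 * sigmoid(10.0 * X[:, 3] ** 2)             # x_4 in column 3
```

Because $\sigma(10x_4^2) \ge 0.5$, the censoring hazard always lies between 0.005 and 0.01, while the event hazard never exceeds 0.1.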
Using these definitions, we generate data for each scenario as follows:
* We evaluate the event hazard and the censoring hazard at every time point $t$ from 1 to 29.
* At each time point, each hazard is compared with a number drawn uniformly at random from [0, 1]: if the random number is larger, the event (or censoring) does not occur at that step; if it is smaller, it occurs.
* If neither the event nor censoring occurs between $t=1$ and $t=29$, we set the time to $t=30$ and mark the sample as censored.
* For shift 2 and shift 3, if censoring and the event of interest occur at the same time step, we discard the result and draw from the uniform distribution again.
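The procedure above can be sketched as follows. The sketch assumes the hazards have already been evaluated into arrays of shape `(n, 29)` for $t = 1, \dots, 29$, and that the event and censoring processes each receive their own independent uniform draw per step; the function name is hypothetical. Ties are always redrawn here, which matches the procedure because in scenarios without censoring the censoring hazard is zero and no ties can occur.

```python
import numpy as np


def simulate_times(event_haz, cens_haz, rng):
    """Discrete-time simulation of observed times.

    event_haz, cens_haz: arrays of shape (n, T) holding the hazards for t = 1..T.
    Returns (time, event): event = 1 if the event of interest was observed,
    0 if the sample was censored (informatively, or at t = T + 1).
    """
    n, T = event_haz.shape
    time = np.full(n, T + 1)        # default: nothing happened by step T
    event = np.zeros(n, dtype=int)  # default: censored
    for i in range(n):
        for k in range(T):
            # Redraw while both processes fire at the same step (a tie).
            # Assumes the two hazards are not both 1, otherwise this never ends.
            while True:
                ev = rng.uniform() < event_haz[i, k]
                ce = rng.uniform() < cens_haz[i, k]
                if not (ev and ce):
                    break
            if ev or ce:
                time[i] = k + 1
                event[i] = int(ev)
                break
    return time, event
```

With the hazards defined above (at most 0.1 and 0.01), a tie has probability at most 0.001 per step, so the redraw loop almost always exits on its first pass.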
## Training and Reproducing Figures
### Getting existing code to run
We first moved all the code from the authors' GitHub repository [1] into a Google Colab notebook. The code targets an old version of TensorFlow that we could not run on Google Colab, so we updated it to TensorFlow 2 manually: we repeatedly ran the code, hit an error, fixed it, and repeated until all TensorFlow errors were gone.
After fixing all the errors, we successfully trained a model using the code in `tutorial.ipynb` and the data in the GitHub repository.
We trained four models on four datasets, generated as described in the previous section:
* **S1**: No treatment selection bias and no censoring.
* **S2**: No treatment selection bias and informative censoring.
* **S3**: Non-randomized (biased) treatment assignment and no censoring.
* **S4**: Non-randomized (biased) treatment assignment and informative censoring.
For each scenario, we generated two datasets, one for training and one for testing, each with 5,000 samples. **Each training session took approximately an hour and a half on a free Google Colab instance using a TPU.**
For the Cox and RSF baselines, we used the scikit-survival package [2] and transformed the data into a compatible format.
### Refactoring the code
While evaluating the code, we found several blocks in the SurvITE class that repeated each other. The original `fcnet` code was repeated for `tmp_A1` and `tmp_A0`; we merged it into a function called `tmpfcnetgen` and call that function wherever it is needed. Similarly, the `tf.cond` blocks for `idx0` and `idx1` were merged into a function called `gentfcond`. Finally, `predict_survival_A0` and `predict_survival_A1`, as well as `predict_hazard_A0` and `predict_hazard_A1`, were near-duplicates with small modifications, so each pair was merged into a single function.
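The pattern behind the last merge can be illustrated with a small, framework-free sketch in which one method takes the treatment arm as an argument instead of having one method per arm. The class name and the `hazard_fns` dictionary are hypothetical stand-ins for the TensorFlow graph, and in this sketch the survival curve is derived from discrete-time hazards via $S(t|x) = \prod_{k \le t}(1 - \lambda(k|x))$.

```python
import numpy as np


class SurvivalPredictorSketch:
    """Hypothetical stand-in for the refactored prediction methods."""

    def __init__(self, hazard_fns):
        # hazard_fns[arm] maps an (n, d) covariate array to (n, T) hazards
        self.hazard_fns = hazard_fns

    def predict_hazard(self, X, arm):
        """One method for both arms instead of predict_hazard_A0 / _A1."""
        if arm not in self.hazard_fns:
            raise ValueError(f"unknown treatment arm: {arm}")
        return self.hazard_fns[arm](X)

    def predict_survival(self, X, arm):
        """One method instead of predict_survival_A0 / _A1:
        S(t|x) = cumulative product over k <= t of (1 - lambda(k|x))."""
        return np.cumprod(1.0 - self.predict_hazard(X, arm), axis=1)
```

Adding a new arm then only requires a new entry in the dictionary rather than another pair of copied methods.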
## Results
We attempted to reproduce Figure 3 and Table 1 of [1] by generating the synthetic data, implementing the baseline models, and computing the errors between the ground-truth and predicted hazard and survival functions. Estimating the treatment-specific hazard functions, and from them the treatment-specific survival functions, allows several heterogeneous treatment effects to be estimated; two of them are used, one in Figure 3 and one in Table 1.
The synthetic data generation process described in the paper aims to represent covariate shift from each of the three sources individually. Visually inspecting the generated data, we observe that the distributions of the data by hazard function do not differ much among the four scenarios considered above. However, the frequencies by treatment and by censoring show the expected imbalance in each scenario, which indicates that the process does represent the distinct sources of covariate shift appropriately. After regenerating all four scenarios as described in [1], we obtained a reasonably close approximation of the data the authors produced.
<br>

<center> <i>Distributions of synthetically generated data for different scenarios</i></center>
<br>
The first heterogeneous treatment effect of interest is the difference in treatment-specific survival probabilities at time $\tau$, $HTE_{surv}(\tau|x)$:
$$ HTE_{surv}(\tau | x) = S^{1}(\tau | x) - S^{0}(\tau | x) $$
This quantity can be computed both from the ground-truth survival functions ($HTE_{surv}$) and from the survival functions estimated by the SurvITE model ($HTE^*_{surv}$). The estimation error is then the root mean squared error (RMSE) between the two, taken over all $n$ test samples:
$$ \epsilon_{HTE_{surv}}(t) = \sqrt{\frac{1}{n} \sum_{i=1}^{n}\left(HTE_{surv}(t|x_i) - HTE^*_{surv}(t|x_i)\right)^2} $$
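A minimal NumPy sketch of this metric follows, assuming the survival curves of both arms are stored as arrays of shape `(n, T)` (rows are test samples, columns are time steps); the function names are hypothetical. The RMSE is taken over samples separately for each column, which yields one error value per time step $t$.

```python
import numpy as np


def hte_surv(s1, s0):
    """HTE_surv(t|x) = S^1(t|x) - S^0(t|x), elementwise on (n, T) arrays."""
    return np.asarray(s1, dtype=float) - np.asarray(s0, dtype=float)


def rmse_over_samples(true_effect, est_effect):
    """sqrt(1/n * sum_i (true_i - est_i)^2), computed per time step (axis 0 = samples)."""
    diff = np.asarray(true_effect) - np.asarray(est_effect)
    return np.sqrt(np.mean(diff ** 2, axis=0))
```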
The figures below show the results of evaluating the baseline models and SurvITE on the different synthetic datasets. For S1 and S2, SurvITE performs slightly worse than RSF. However, as more sources of covariate shift are added to the datasets, SurvITE performs much better than both RSF and Cox. This suggests that, for more complex time-to-event data containing multiple sources of covariate shift, SurvITE is able to use its representation module to enforce a more balanced representation and counteract the shift. Counterintuitively, SurvITE does worse than the baselines in the absence of such covariate shifts. We would not expect this: the input representation would already be balanced and could pass to the hazard estimator without being modified to account for covariate shift.
<br>

<center> <i>Plots of reproduced results</i></center>
<br>
The second heterogeneous treatment effect of interest is the difference in restricted mean survival time (RMST) up to time $L$, $HTE_{rmst}(x)$, defined as:
$$ HTE_{rmst}(x) = \sum_{t_k \leq L}\left(S^{1}(t_k|x) - S^{0}(t_k|x)\right)(t_k - t_{k-1}) $$
Its ground truth can therefore be derived from the ground-truth survival functions, which in turn follow from the ground-truth hazard functions. The model outputs estimated survival functions for both treatment arms (with and without treatment), from which $HTE^*_{rmst}$ is computed in the same way. The RMSE between the ground truth and the prediction for horizon $L$, $\epsilon_{HTE_{rmst}}(L)$, is then:
$$ \epsilon_{HTE_{rmst}}(L) = \sqrt{\frac{1}{n} \sum_{i=1}^{n}\left(HTE_{rmst}(x_i;L) - HTE^*_{rmst}(x_i;L)\right)^2} $$
where, again, $n$ is the number of samples in the test set and $L$ is the time point up to which the RMST is computed. We report the mean $\epsilon_{HTE_{rmst}}(L)$ with its 95% confidence interval for $L=10$ and $L=20$, which are chosen as the $25^{th}$ and $75^{th}$ percentiles of the event times.
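The RMST effect and its error can be computed from discretized survival curves as sketched below. The sketch assumes survival arrays of shape `(n, T)` evaluated at the increasing time points `times`, with $t_0 = 0$; the function names are hypothetical. It uses the treated-minus-control convention $S^1 - S^0$; flipping the sign would not change the RMSE.

```python
import numpy as np


def hte_rmst(s1, s0, times, L):
    """HTE_rmst(x; L) = sum over t_k <= L of (S^1(t_k|x) - S^0(t_k|x)) * (t_k - t_{k-1}).

    Assumes t_0 = 0, so the first interval width is times[0].
    """
    times = np.asarray(times, dtype=float)
    widths = np.diff(np.concatenate(([0.0], times)))  # t_k - t_{k-1}
    keep = times <= L
    diff = np.asarray(s1, dtype=float) - np.asarray(s0, dtype=float)
    return (diff[:, keep] * widths[keep]).sum(axis=1)


def eps_hte_rmst(s1_true, s0_true, s1_est, s0_est, times, L):
    """RMSE over samples between ground-truth and estimated HTE_rmst."""
    err = hte_rmst(s1_true, s0_true, times, L) - hte_rmst(s1_est, s0_est, times, L)
    return float(np.sqrt(np.mean(err ** 2)))
```

For example, with $t_k = 1, \dots, 29$, $S^1 \equiv 1$ and $S^0 \equiv 0$, the effect equals $L$ itself (10 for $L = 10$).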
Comparing our reproduced results with those in the paper shows that our SurvITE results are close to the published ones. However, the gap between the baseline models and SurvITE is much smaller in our reproduction than in the paper: our Cox and RSF baselines achieve lower errors than reported in every column, and in S3 at $L=10$ our RSF (0.222) even slightly outperforms SurvITE (0.231). Since we used the same hyperparameters and settings for the baselines, we expected their results to be similar to the published ones.
<br>
<style type="text/css">
.tg {border-collapse:collapse;border-spacing:0;margin:0px auto;}
.tg td{border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;
overflow:hidden;padding:10px 5px;word-break:normal;}
.tg th{border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;
font-weight:normal;overflow:hidden;padding:10px 5px;word-break:normal;}
.tg .tg-4vkt{background-color:#D9E1F2;text-align:center;vertical-align:bottom}
.tg .tg-nrix{text-align:center;vertical-align:middle}
.tg .tg-cixx{background-color:#E2EFDA;text-align:center;vertical-align:bottom}
.tg .tg-ykua{background-color:#DDEBF7;text-align:center;vertical-align:bottom}
.tg .tg-7zrl{text-align:left;vertical-align:bottom}
.tg .tg-0lax{text-align:left;vertical-align:top}
</style>
<table class="tg">
<thead>
<tr>
<th class="tg-nrix" rowspan="3">Method</th>
<th class="tg-cixx" colspan="2">Reproduction</th>
<th class="tg-4vkt" colspan="2">Original</th>
<th class="tg-cixx" colspan="2">Reproduction</th>
<th class="tg-ykua" colspan="2">Original</th>
</tr>
<tr>
<th class="tg-cixx" colspan="2">S3 (ξ=3, no overlap)</th>
<th class="tg-4vkt" colspan="2">S3 (ξ=3, no overlap)</th>
<th class="tg-cixx" colspan="2">S4 (ξ=3, no overlap)</th>
<th class="tg-ykua" colspan="2">S4 (ξ=3, no overlap)</th>
</tr>
<tr>
<th class="tg-cixx">L=10</th>
<th class="tg-cixx">L=20</th>
<th class="tg-4vkt">L=10</th>
<th class="tg-4vkt">L=20</th>
<th class="tg-cixx">L=10</th>
<th class="tg-cixx">L=20</th>
<th class="tg-ykua">L=10</th>
<th class="tg-ykua">L=20</th>
</tr>
</thead>
<tbody>
<tr>
<td class="tg-7zrl">Cox</td>
<td class="tg-7zrl">0.307 ± 0.196</td>
<td class="tg-7zrl">0.814 ± 0.273</td>
<td class="tg-7zrl">0.434 ± 0.03</td>
<td class="tg-7zrl">1.073 ± 0.05</td>
<td class="tg-7zrl">0.292 ± 0.186</td>
<td class="tg-7zrl">0.846 ± 0.299</td>
<td class="tg-7zrl">0.424 ± 0.02</td>
<td class="tg-7zrl">1.047 ± 0.04</td>
</tr>
<tr>
<td class="tg-7zrl">RSF</td>
<td class="tg-0lax">0.222 ± 0.136</td>
<td class="tg-7zrl">0.638 ± 0.229</td>
<td class="tg-0lax">0.328 ± 0.02</td>
<td class="tg-7zrl">1.027 ± 0.03</td>
<td class="tg-7zrl">0.297 ± 0.176</td>
<td class="tg-7zrl">0.844 ± 0.302</td>
<td class="tg-7zrl">0.332 ± 0.02</td>
<td class="tg-7zrl">1.058 ± 0.03</td>
</tr>
<tr>
<td class="tg-7zrl">SurvITE</td>
<td class="tg-7zrl">0.231 ± 0.145</td>
<td class="tg-7zrl">0.673 ± 0.240</td>
<td class="tg-7zrl">0.225 ± 0.03</td>
<td class="tg-7zrl">0.687 ± 0.08</td>
<td class="tg-7zrl">0.211 ± 0.136</td>
<td class="tg-7zrl">0.643 ± 0.235</td>
<td class="tg-7zrl">0.237 ± 0.03</td>
<td class="tg-7zrl">0.703 ± 0.06</td>
</tr>
</tbody>
</table>
<br>
## Challenges
The main challenges, in chronological order, were updating TensorFlow, understanding the synthetic data generation algorithm, and getting the baseline Cox model to work with the synthetically generated data.
* Although TensorFlow provides an automatic upgrade script (`tf_upgrade_v2`), we could not use it in Colab, and without a local TensorFlow environment we had to update the code manually. Some parts are still not fully updated, and we still get deprecation warnings.
* The synthetic data generation algorithm is spread across several parts of the paper, and we had to assemble it into one codebase. After working out what most of the algorithm means, we still had to decide how to determine whether censoring occurred at a given time step; we chose to compare the hazard with a uniform random draw, as described in the data generation section.
* Setting up the baseline models was a challenge because we had to follow a two-learner approach, fitting a separate model for each treatment arm. The input format expected by the baseline models also differed slightly from the format of the generated data, so we had to write adapter functions.
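The two-learner setup and the adapter can be sketched as below. scikit-survival estimators expect the target as a NumPy structured array with a boolean event field followed by a float time field (the layout produced by `sksurv.util.Surv.from_arrays`). The helper names are hypothetical, and any estimator with scikit-learn-style `fit` and `predict_survival_function` methods can be plugged in through `make_model`.

```python
import numpy as np


def to_sksurv_target(time, event):
    """Adapter: pack (time, event) into the structured array layout scikit-survival expects."""
    y = np.empty(len(time), dtype=[("event", bool), ("time", float)])
    y["event"] = np.asarray(event, dtype=bool)
    y["time"] = np.asarray(time, dtype=float)
    return y


def fit_two_learner(make_model, X, a, time, event):
    """Fit one survival model per treatment arm (a two-learner / T-learner setup)."""
    models = {}
    for arm in (0, 1):
        mask = a == arm
        models[arm] = make_model().fit(X[mask], to_sksurv_target(time[mask], event[mask]))
    return models


def hte_surv_from_models(models, X, tau):
    """HTE_surv(tau|x) = S^1(tau|x) - S^0(tau|x), using each arm's survival functions.
    With scikit-survival, tau must lie inside the time range seen during training."""
    s1 = np.array([fn(tau) for fn in models[1].predict_survival_function(X)], dtype=float)
    s0 = np.array([fn(tau) for fn in models[0].predict_survival_function(X)], dtype=float)
    return s1 - s0
```

For example, `fit_two_learner(CoxPHSurvivalAnalysis, X, a, time, event)` or `fit_two_learner(lambda: RandomSurvivalForest(random_state=0), X, a, time, event)` would fit the two baselines, with the estimators imported from `sksurv.linear_model` and `sksurv.ensemble`.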
## Conclusion
All in all, we reproduced Figure 3 and parts of Table 1 of [1] reasonably well. While the results and inferences were mostly similar, there were two deviations:
1. SurvITE performs worse than the baseline models in the absence of the covariate shifts introduced in the paper.
2. The performance gap between SurvITE and the baseline models is much smaller than reported in [1].
These discrepancies suggest that the paper may have trained a different version of the baseline models, although this is speculation rather than a firm conclusion.
This brings us to the end of this blog post. We hope to have given a thorough explanation of our reproduction process and results. See ya!
# References
[1] A. Curth, C. Lee, and M. van der Schaar, "SurvITE: Learning heterogeneous treatment effects from time-to-event data," Advances in Neural Information Processing Systems, vol. 34, pp. 26740–26753, 2021.
[2] scikit-survival documentation, https://scikit-survival.readthedocs.io/en/stable/
[3] D. R. Cox, "Regression models and life-tables," Journal of the Royal Statistical Society. Series B (Methodological), vol. 34, no. 2, pp. 187–220, 1972. JSTOR, http://www.jstor.org/stable/2985181. Accessed 24 Apr. 2023.
[4] H. Ishwaran, U. B. Kogalur, E. H. Blackstone, and M. S. Lauer, "Random survival forests," The Annals of Applied Statistics, vol. 2, no. 3, pp. 841–860, 2008.
[5] C. Lee, W. Zame, J. Yoon, and M. van der Schaar, "DeepHit: A deep learning approach to survival analysis with competing risks," in Proceedings of the AAAI Conference on Artificial Intelligence, vol. 32, no. 1, 2018.
# Team
Alexander Schnapp (5818583): New code variant, Reproduction of results
Deepali Prabhu (5732166): New Data, Reproduction of results
Eduard Ma (4660668): New Data, Reproduction of results