Survival prediction with cancer WSIs

Survival prediction with cancer WSIs#

In this tutorial, you will be introduced to building survival prediction models using both machine learning and deep learning approaches with WSIs.

We will use a subset of WSIs from TCGA-READ for illustration purposes only. The provided code and analyses are intended solely as a demonstration of workflow and usage. The results shown here are not validated and should not be used for drawing biological or clinical conclusions.

from huggingface_hub import hf_hub_download

metadata = hf_hub_download(
    "rendeirolab/lazyslide-data",
    "TCGA_READ_survival.csv",
    repo_type="dataset",
    local_dir=".",
)
import pandas as pd

metadata = pd.read_csv(metadata)
metadata.head(5)
PATIENT_ID AGE AJCC_STAGING_EDITION BIOPSY_SITE DAYS_LAST_FOLLOWUP DAYS_TO_BIRTH DAYS_TO_DEATH DISEASE_TYPE ETHNICITY ICD_10 ... SEX VITAL_STATUS YEAR_OF_DIAGNOSIS OS_STATUS OS_MONTHS PROJECT_ID PROJECT_NAME PROJECT_STATE FILE_ID FILE_NAME
0 TCGA-AG-3601 68.0 6th Rectum 0.0 -24837.0 NaN Rectal Adenocarcinoma NaN C19 ... Male Alive 2007.0 0:LIVING 0.000000 TCGA-READ Rectal Adenocarcinoma released d76f3d2c-f30d-4592-a51d-620e17419222 TCGA-AG-3601-01Z-00-DX1.30ac783e-ba70-49ef-8be...
1 TCGA-AF-6136 72.0 7th Rectum 232.0 -26490.0 NaN Rectal Adenocarcinoma NOT HISPANIC OR LATINO C19 ... Female Alive 2011.0 0:LIVING 24.342970 TCGA-READ Rectal Adenocarcinoma released ec1f9ff4-f634-42fb-9c64-dee4950df4df TCGA-AF-6136-01Z-00-DX1.a0e22964-b7b4-43ba-bfa...
2 TCGA-AH-6549 66.0 7th Rectum 6.0 -24337.0 NaN Rectal Adenocarcinoma NaN C20 ... Male Alive 2010.0 0:LIVING 17.477004 TCGA-READ Rectal Adenocarcinoma released 1f07b827-255a-4d32-9eea-f2daff7bd937 TCGA-AH-6549-01Z-00-DX1.38ea40f7-4ebf-49cc-801...
3 TCGA-AG-A01Y 49.0 5th Rectum 0.0 -18112.0 NaN Rectal Adenocarcinoma NaN C20 ... Female Alive 2004.0 0:LIVING 0.000000 TCGA-READ Rectal Adenocarcinoma released b7c07212-1bc0-48b3-8b69-6adffa0fb08f TCGA-AG-A01Y-01Z-00-DX1.3F49940B-3758-419B-89C...
4 TCGA-AG-3883 69.0 6th Rectum 31.0 -25415.0 NaN Rectal Adenocarcinoma NaN C20 ... Male Alive 2008.0 0:LIVING 1.018397 TCGA-READ Rectal Adenocarcinoma released 2d8f32f9-bb04-46cd-8d63-91b591e9b8b3 TCGA-AG-3883-01Z-00-DX1.2a21ffb1-8a60-4424-b74...

5 rows × 31 columns

To get all svs files, run the following code:

slides = snapshot_download(
    "rendeirolab/lazyslide-data",
    repo_type="dataset",
    local_dir="tcga_read",
    allow_patterns=["tcga_read/*.svs"],
)

Now let’s run feature extration on these slides, you can parallel it however you want depends on your infrastructure setup, here is a demo code to parallel across GPU nodes with SLURM.

from dask.distributed import Client
from dask_jobqueue import SLURMCluster

def wsi_feature_extraction(slide):

    from wsidata import open_wsi
    import lazyslide as zs

    wsi = open_wsi(s, attach_thumbnail=False)
    zs.pp.find_tissues(wsi)
    zs.pp.tile_tissues(wsi, 448, mpp=0.5, background_fraction=0.5)

    zs.tl.feature_extraction(wsi, "titan", pbar=False)
    zs.tl.feature_aggregation(wsi, "titan", encoder="titan")
    wsi.write()

cluster = SLURMCluster(
    queue="gpu",
    cores=8,
    processes=1,
    memory="10 GB",
    job_extra_directives=[
        "-q gpu",
        "--gres=gpu:h100pcie:1",
        "--time=1:00:00",
    ],
    worker_extra_args=["--resources GPU=1"],
    log_directory="./dask-logs",
)

client = Client(cluster)
cluster.scale(10)  # Get 10 workers, each with one H100 to run

futures = [
    client.submit(wsi_feature_extraction, f"slides/{slide}", resources={"GPU": 1})
    for slide in matadata["FILE_NAME"]
]

When you finished with the processing, you can aggregate the slide features with

from wsidata import agg_wsi

matadata["slide_path"] = [f"tcga_read/{s}" for s in matadata["FILE_NAME"]]
adata = agg_wsi(matadata, wsi_col="slide_path", feature_key="titan")

We prepared a pre-computed features matrix if you don’t want to run the feature extration.

import anndata as ad

titan_features = hf_hub_download(
    "rendeirolab/lazyslide-data",
    "TCGA_READ_subset_TITAN.h5ad",
    repo_type="dataset",
    local_dir=".",
)

adata = ad.read_h5ad(titan_features)
adata
AnnData object with n_obs × n_vars = 50 × 768
    obs: 'PATIENT_ID', 'AGE', 'AJCC_STAGING_EDITION', 'BIOPSY_SITE', 'DAYS_LAST_FOLLOWUP', 'DAYS_TO_BIRTH', 'DAYS_TO_DEATH', 'DISEASE_TYPE', 'ETHNICITY', 'ICD_10', 'MORPHOLOGY', 'OTHER_PATIENT_ID', 'PATH_M_STAGE', 'PATH_N_STAGE', 'PATH_STAGE', 'PATH_T_STAGE', 'PRIMARY_DIAGNOSIS', 'PRIMARY_SITE_PATIENT', 'PRIOR_MALIGNANCY', 'PRIOR_TREATMENT', 'RACE', 'SEX', 'VITAL_STATUS', 'YEAR_OF_DIAGNOSIS', 'OS_STATUS', 'OS_MONTHS', 'PROJECT_ID', 'PROJECT_NAME', 'PROJECT_STATE', 'FILE_ID', 'FILE_NAME', 'slide_path'

Some data preprocessing to prepare for survival prediction

adata.obs["status"] = (
    adata.obs["OS_STATUS"].map({"0:LIVING": 0, "1:DECEASED": 1}).astype(bool)
)
import scanpy as sc

sc.pp.scale(adata)
sc.pp.pca(adata)
sc.pl.pca(adata, color="OS_STATUS")
../_images/ba62dca3c2fad1780687afecdc66750559c3f29b5b2d03c055b3fed958a3f486.png

Kaplan meier analysis#

Kaplan-Meier analysis is a non-parametric method used to estimate the survival probability over time, even when some data are censored. You may already saw KM plot in many publications.

In real analysis, you will plot the KM plot for different groups and compare them.

import matplotlib.pyplot as plt

from sksurv.nonparametric import kaplan_meier_estimator

time, survival_prob, conf_int = kaplan_meier_estimator(
    adata.obs["status"], adata.obs["OS_MONTHS"], conf_type="log-log"
)
with plt.rc_context({"figure.figsize": (6, 4)}):
    plt.step(time, survival_prob, where="post")
    plt.fill_between(time, conf_int[0], conf_int[1], alpha=0.25, step="post")
    plt.ylim(0, 1)
    plt.ylabel(r"est. probability of survival $\hat{S}(t)$")
    plt.xlabel("time $t$")
../_images/1685661ca8b76cc19e9ac44f1e2925c220b0b739940399409dd45630fd0b2434.png

Machine learning model#

Let’s start with a machine learning based model. We used the model from scikit-survival.

The metric we use to evaluate the model is called Concordance Index (cindex).

The concordance index (cindex) is a measure of how well a survival model predicts the order of events. It quantifies the agreement between the predicted risk scores and the actual observed survival times, with a value of 1.0 indicating perfect prediction and 0.5 indicating random chance. A reasonable model performance should have cindex > 0.7.

from sksurv.linear_model import CoxnetSurvivalAnalysis
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    adata.X,
    adata.obs[["status", "OS_MONTHS"]].to_records(index=False),
    test_size=0.2,
    stratify=adata.obs["status"],
    random_state=10,
)

model = CoxnetSurvivalAnalysis(
    l1_ratio=0.9, alpha_min_ratio=0.01, n_alphas=100, fit_baseline_model=True
)
model.fit(X_train, y_train)
s = model.score(X_test, y_test)
print("cindex:", model.score(X_test, y_test))
cindex: 0.6428571428571429

Nerual network#

If you have a lot of data, you can train a neural network to predict the hazard ratio.

Let’s defined a simple models with the input of slide features and output of hazard ratio.

import torch
import torch.nn as nn
import torch.nn.functional as F


# Define a model
class CoxMLP(nn.Module):
    def __init__(self, in_features, hidden_dim=32, dropout=0.3):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 1)  # single risk score

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)  # risk score, no activation
X_train, X_test, E_train, E_test, T_train, T_test = train_test_split(
    adata.X,
    adata.obs["status"].to_numpy(),
    adata.obs["OS_MONTHS"].to_numpy(),
    test_size=0.2,
    stratify=adata.obs["status"],
    random_state=10,
)

X_train = torch.tensor(X_train)
X_test = torch.tensor(X_test)
E_train = torch.tensor(E_train)
E_test = torch.tensor(E_test)
T_train = torch.tensor(T_train)
T_test = torch.tensor(T_test)

Here we use the torchsurv package for loss function and cindex calculation.

Here we only use a simple training loop for showcase.

from torchsurv.loss import cox
from torchsurv.metrics.cindex import ConcordanceIndex

# The titan features size is 768
model = CoxMLP(768)
torch.manual_seed(0)

# Trained for 10 epoches
for _ in range(10):
    estimate = model(X_train)
    loss = cox.neg_partial_log_likelihood(estimate, E_train, T_train)
    loss.backward()

# Evaluation
with torch.no_grad():
    log_hz = model(X_test)
    cindex = ConcordanceIndex()
    print(cindex(log_hz, E_test, T_test))
tensor(0.8571, dtype=torch.float64)