Zero-shot learning

In this tutorial we explore how the general-purpose features learned by foundation models can be used for tasks the models were never explicitly trained on, without any retraining (zero-shot learning).

import lazyslide as zs
wsi = zs.datasets.lung_carcinoma(with_data=False)
wsi
WSI: /home/runner/.cache/huggingface/hub/datasets--RendeiroLab--LazySlide-data/snapshots/9644d886889040fa10e757d912f249bbf936a979/lung_carcinoma.ndpi
Reader: openslide
Dimensions: 15616×16384 (h×w), 8 Pyramids
Pixel physical size: 0.23 MPP (40X)
SpatialData object
└── Images
      └── 'wsi_thumbnail': DataArray[cyx] (3, 1817, 1906)
with coordinate systems:
    ▸ 'global', with elements:
        wsi_thumbnail (Images)
zs.pp.find_tissues(wsi)
zs.pp.tile_tissues(wsi, 512, background_fraction=0.95, mpp=0.5)
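
To make the tiling parameters concrete, the sketch below works out what a 512-pixel tile requested at 0.5 MPP covers on this slide, whose native resolution is 0.23 MPP. It is plain arithmetic rather than a LazySlide call, and the helper name `tile_footprint` is hypothetical.

```python
def tile_footprint(tile_px: int, target_mpp: float, native_mpp: float) -> tuple[float, float]:
    """Return (tile edge in micrometres, tile edge in native-resolution pixels)."""
    # Physical edge length: pixels times micrometres-per-pixel at the target resolution
    edge_um = tile_px * target_mpp
    # The same physical length expressed in pixels at the slide's native resolution
    edge_native_px = edge_um / native_mpp
    return edge_um, edge_native_px


edge_um, edge_native_px = tile_footprint(512, target_mpp=0.5, native_mpp=0.23)
print(f"{edge_um:.0f} µm per side, about {edge_native_px:.0f} native pixels")
```

Each tile therefore spans 256 µm per side, which corresponds to roughly 1113 pixels at the slide's native 0.23 MPP.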

Zero-shot classification

We will first try zero-shot classification with CONCH, a text-image model.

zs.tl.feature_extraction(wsi, "conch")

Next, we define the set of candidate classes, embed them as text, and compute the probability of each class for every tile.

classes = ["lung carcinoma", "breast carcinoma", "normal tissue"]
embeds = zs.tl.text_embedding(classes, model="conch")
zs.tl.text_image_similarity(wsi, embeds, model="conch", softmax=True)
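
To give an intuition for what a text-image similarity with `softmax=True` produces, here is a minimal NumPy sketch on toy data. It assumes the usual CLIP-style recipe (cosine similarity between tile and text embeddings, then a softmax over classes); whether LazySlide computes it exactly this way is an assumption, and `similarity_probs` and its `scale` parameter are hypothetical names.

```python
import numpy as np


def similarity_probs(tile_feats: np.ndarray, text_feats: np.ndarray, scale: float = 1.0) -> np.ndarray:
    """Cosine similarity between tiles and class prompts, softmaxed over classes."""
    # L2-normalise both sides so the dot product equals the cosine similarity
    tiles = tile_feats / np.linalg.norm(tile_feats, axis=1, keepdims=True)
    texts = text_feats / np.linalg.norm(text_feats, axis=1, keepdims=True)
    logits = scale * (tiles @ texts.T)  # shape: (n_tiles, n_classes)
    # Subtract the row-wise maximum for numerical stability before exponentiating
    logits = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
tile_feats = rng.normal(size=(5, 8))  # 5 toy tiles with 8-dimensional embeddings
text_feats = rng.normal(size=(3, 8))  # 3 toy class prompts
probs = similarity_probs(tile_feats, text_feats)
print(probs.shape)  # (5, 3)
```

Each row of `probs` is a probability distribution over the classes for one tile, so the rows sum to 1.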

We then use top-k scoring to aggregate the per-tile probabilities into a single slide-level score for each class.

scores = zs.metrics.topk_score(wsi["conch_tiles_text_similarity"], agg_method="max")
for c, s in zip(classes, scores / scores.sum()):
    print(f"{c}: {s:.2f}")
lung carcinoma: 0.36
breast carcinoma: 0.32
normal tissue: 0.33
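
The idea behind top-k scoring can be sketched in a few lines of NumPy. This is one common variant (averaging the k highest tile probabilities per class) and is not meant to reproduce `zs.metrics.topk_score` or its `agg_method` option; the helper `topk_mean` and the toy probabilities are hypothetical.

```python
import numpy as np


def topk_mean(tile_probs: np.ndarray, k: int = 2) -> np.ndarray:
    """For each class (column), average its k highest tile probabilities."""
    k = min(k, tile_probs.shape[0])
    # Sort every column in ascending order and keep the last k rows
    top = np.sort(tile_probs, axis=0)[-k:]
    return top.mean(axis=0)


# Toy input: 4 tiles (rows) x 3 classes (columns)
tile_probs = np.array([
    [0.7, 0.2, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.2, 0.6],
    [0.5, 0.4, 0.1],
])
toy_classes = ["lung carcinoma", "breast carcinoma", "normal tissue"]
toy_scores = topk_mean(tile_probs, k=2)
for c, s in zip(toy_classes, toy_scores / toy_scores.sum()):
    print(f"{c}: {s:.2f}")
```

Focusing on the highest-scoring tiles keeps a small but confident region from being averaged away by the many uninformative tiles on a slide.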

Alternatively, we can use a slide-level model such as PRISM or TITAN to query the class probabilities directly.

The input to PRISM's slide encoder must be tile features extracted with the Virchow model.

zs.tl.feature_extraction(wsi, "virchow")
zs.tl.feature_aggregation(wsi, feature_key="virchow", encoder="prism", device="cpu")

zs.tl.zero_shot_score(wsi, classes, feature_key="virchow_tiles", device="cpu")
   lung carcinoma  breast carcinoma  normal tissue
0        0.912962          0.069559       0.017479
results = zs.tl.slide_caption(
    wsi, ["what is the diagnosis of the slide?"], feature_key="virchow_tiles", device="cpu", model="prism"
)
results['caption'][0]
['</s>what is the diagnosis of the slide? </s>In situ squamous cell carcinoma with positive p16 and high-risk HPV. </s>']
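
The raw caption still contains the `</s>` separator tokens and echoes the prompt. A minimal sketch for cleaning it up, assuming the output always follows the format shown above; `clean_caption` is a hypothetical helper, not part of LazySlide.

```python
def clean_caption(raw: str, prompt: str, sep: str = "</s>") -> str:
    """Strip separator tokens and the echoed prompt from a generated caption."""
    # Split on the separator and drop empty or whitespace-only pieces
    parts = [p.strip() for p in raw.split(sep) if p.strip()]
    # Remove the piece that merely repeats the prompt
    answers = [p for p in parts if p != prompt.strip()]
    return " ".join(answers)


raw = (
    "</s>what is the diagnosis of the slide? </s>"
    "In situ squamous cell carcinoma with positive p16 and high-risk HPV. </s>"
)
caption = clean_caption(raw, "what is the diagnosis of the slide?")
print(caption)
```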