DIY Interactive Segmentation with napari#

napari is a very flexible and “hackable” tool. In this tutorial we will build a custom interactive segmentation tool from scratch and use it on 3D volumetric data hosted on Zebrahub.

Two screenshots: the 3D volumetric image from Zebrahub of a developing zebrafish, and the same volume with some of it segmented (not very well, yet).

Setup#

# this cell is required to run these notebooks on Binder. Make sure that you also have a desktop tab open.
import os
if 'BINDER_SERVICE_HOST' in os.environ:
    os.environ['DISPLAY'] = ':1.0'
!pip install scikit-learn scikit-image ome-zarr
from appdirs import user_data_dir
import os
import zarr
import dask.array as da
import toolz as tz

from sklearn.ensemble import RandomForestClassifier

from skimage import data, segmentation, feature, future
from skimage.feature import multiscale_basic_features
from skimage.io import imread, imshow
import numpy as np
from functools import partial
import napari
import threading

from ome_zarr.io import parse_url
from ome_zarr.reader import Reader

from psygnal import debounced
from superqt import ensure_main_thread

import logging
import sys

LOGGER = logging.getLogger("halfway_to_i2k_2023_america")
LOGGER.setLevel(logging.DEBUG)

streamHandler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
streamHandler.setFormatter(formatter)
LOGGER.addHandler(streamHandler)

Reading in data#

Get the data from Zebrahub.

def open_zebrahub():
    url = "https://public.czbiohub.org/royerlab/zebrahub/imaging/single-objective/ZSNS002.ome.zarr/"

    # open the remote OME-Zarr store for reading
    reader = Reader(parse_url(url, mode="r"))
    # nodes may include images, labels etc
    nodes = list(reader())
    # first node will be the image pixel data
    image_node = nodes[0]

    dask_data = image_node.data

    return dask_data

zebrahub_data = open_zebrahub()
zebrahub_data[3].shape
# Let's choose a crop to work on

crop_3D = zebrahub_data[3][800, 0, :, :, :]
crop_3D.shape
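The OME-Zarr reader hands us a multiscale pyramid: zebrahub_data is a list of dask arrays, one per resolution level, and the crop above takes a single timepoint and channel from level 3 (a downsampled level that is comfortable to work with interactively). A minimal sketch for inspecting the pyramid:

# Print the shape and dtype of every resolution level in the pyramid.
for level, arr in enumerate(zebrahub_data):
    print(f"level {level}: shape={arr.shape}, dtype={arr.dtype}")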

Visualize in napari#

viewer = napari.Viewer()

scale = (9.92, 3.512, 3.512)
contrast_limits = (0, 372)

data_layer = viewer.add_image(crop_3D, scale=scale, contrast_limits=contrast_limits)
data_layer.bounding_box.visible = True
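The scale tuple gives the (z, y, x) voxel spacing of this anisotropic dataset, so napari renders the volume with its true proportions, and contrast_limits sets the displayed intensity range. The bounding box makes the extent of the crop visible in 3D.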

Extracting features#

def extract_features(image, feature_params):
    features_func = partial(
        multiscale_basic_features,
        intensity=feature_params["intensity"],
        edges=feature_params["edges"],
        texture=feature_params["texture"],
        sigma_min=feature_params["sigma_min"],
        sigma_max=feature_params["sigma_max"],
        channel_axis=None,
    )
    # print(f"image shape {image.shape} feature params {feature_params}")
    features = features_func(np.squeeze(image))
    return features

example_feature_params = {
    "sigma_min": 1,
    "sigma_max": 5,
    "intensity": True,
    "edges": True,
    "texture": True,
}

features = extract_features(crop_3D, example_feature_params)
features.shape
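multiscale_basic_features computes per-voxel intensity, edge, and texture features over a range of Gaussian scales between sigma_min and sigma_max, stacked along a new trailing axis. A quick shape check (a minimal sketch):

# The feature array is the crop's (z, y, x) shape plus a trailing feature axis.
assert features.shape[:-1] == crop_3D.shape
print(f"{features.shape[-1]} features per voxel")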

Visualize Features#

What do these features we are extracting look like?

Figure: a set of the features extracted for the Zebrahub image in this tutorial.

def show_features():
    for feature_idx in range(features.shape[-1]):
        viewer.add_image(features[:, :, :, feature_idx])
        
# show_features()
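show_features() is left commented out because it adds one full-size image layer per feature, which can overwhelm the viewer; uncomment the call if you want to browse the features interactively.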

Making the Interactive Segmentation Tool!#

Ok, now we’ve seen:

  • our data

  • some features we can compute for our data

Our goal is to create a label image whose labels correspond to structures in the zebrafish sample.

The approach: whenever we annotate/draw in the painting layer, the segmentation should update automatically.

We will do this using 3 different layers:

  1. Our input image

  2. A layer for painting

  3. A layer for storing the machine learning generated predictions

Due to popular demand, we will use Zarr to store these layers, because that helps this approach scale to very large datasets. However, plain numpy arrays would have worked as well.

Create our painting and prediction layers#

zarr_path = os.path.join(user_data_dir("halfway_to_i2k_2023_america", "napari"), "diy_segmentation.zarr")
print(f"Saving outputs to zarr path: {zarr_path}")

# Create a prediction layer
prediction_data = zarr.open(
    f"{zarr_path}/prediction",
    mode='a',
    shape=crop_3D.shape,
    dtype='i4',
    dimension_separator="/",
)
prediction_layer = viewer.add_labels(prediction_data, name="Prediction", scale=data_layer.scale)

# Create a painting layer
painting_data = zarr.open(
    f"{zarr_path}/painting",
    mode='a',
    shape=crop_3D.shape,
    dtype='i4',
    dimension_separator="/",
)
painting_layer = viewer.add_labels(painting_data, name="Painting", scale=data_layer.scale)
painting_data.shape
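Since both layers are backed by Zarr arrays on disk, annotations persist between sessions. A minimal sketch of reopening the stored labels later (same path as above):

# Reopen the saved painting labels in a later session (read-only).
painting_again = zarr.open(f"{zarr_path}/painting", mode="r")
print(painting_again.shape, painting_again.dtype)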

Let’s make a UI as well#

Figure: a UI widget showing controls for feature size and type.

from qtpy.QtWidgets import (
    QVBoxLayout,
    QHBoxLayout,
    QComboBox,
    QLabel,
    QCheckBox,
    QDoubleSpinBox,
    QGroupBox,
    QWidget,
)

class NapariMLWidget(QWidget):
    def __init__(self, parent=None):
        super(NapariMLWidget, self).__init__(parent)

        self.initUI()

    def initUI(self):
        layout = QVBoxLayout()

        # Dropdown for selecting the model
        model_label = QLabel("Select Model")
        self.model_dropdown = QComboBox()
        self.model_dropdown.addItems(["Random Forest"])
        model_layout = QHBoxLayout()
        model_layout.addWidget(model_label)
        model_layout.addWidget(self.model_dropdown)
        layout.addLayout(model_layout)

        # Select the range of sigma sizes
        self.sigma_start_spinbox = QDoubleSpinBox()
        self.sigma_start_spinbox.setRange(0, 256)
        self.sigma_start_spinbox.setValue(1)

        self.sigma_end_spinbox = QDoubleSpinBox()
        self.sigma_end_spinbox.setRange(0, 256)
        self.sigma_end_spinbox.setValue(5)

        sigma_layout = QHBoxLayout()
        sigma_layout.addWidget(QLabel("Sigma Range: From"))
        sigma_layout.addWidget(self.sigma_start_spinbox)
        sigma_layout.addWidget(QLabel("To"))
        sigma_layout.addWidget(self.sigma_end_spinbox)
        layout.addLayout(sigma_layout)

        # Boolean options for features
        self.intensity_checkbox = QCheckBox("Intensity")
        self.intensity_checkbox.setChecked(True)
        self.edges_checkbox = QCheckBox("Edges")
        self.texture_checkbox = QCheckBox("Texture")
        self.texture_checkbox.setChecked(True)

        features_group = QGroupBox("Features")
        features_layout = QVBoxLayout()
        features_layout.addWidget(self.intensity_checkbox)
        features_layout.addWidget(self.edges_checkbox)
        features_layout.addWidget(self.texture_checkbox)
        features_group.setLayout(features_layout)
        layout.addWidget(features_group)

        # Dropdown for data selection
        data_label = QLabel("Select Data for Model Fitting")
        self.data_dropdown = QComboBox()
        self.data_dropdown.addItems(
            ["Current Displayed Region", "Whole Image"]
        )
        self.data_dropdown.setCurrentText("Current Displayed Region")
        data_layout = QHBoxLayout()
        data_layout.addWidget(data_label)
        data_layout.addWidget(self.data_dropdown)
        layout.addLayout(data_layout)

        # Checkbox for live model fitting
        self.live_fit_checkbox = QCheckBox("Live Model Fitting")
        self.live_fit_checkbox.setChecked(True)
        layout.addWidget(self.live_fit_checkbox)

        # Checkbox for live prediction
        self.live_pred_checkbox = QCheckBox("Live Prediction")
        self.live_pred_checkbox.setChecked(True)
        layout.addWidget(self.live_pred_checkbox)

        self.setLayout(layout)
        
# Let's add this widget to napari

widget = NapariMLWidget()
viewer.window.add_dock_widget(widget, name="halfway to I2K 2023 America")

We have a widget, we have our painting and prediction layers, now what?#

We need to start connecting things together. How? napari emits “events” when something happens in the viewer, and we can register listener functions on them. We want to respond to a few different event types:

  • changes in camera (e.g. camera position and rotation)

  • changes in “dims” (e.g. moving a dimension slider)

  • painting events (e.g. a user clicked, painted, and released their mouse)

When one of these events happens, we want to:

  • update our machine learning model with the new painted data

  • update our prediction with the updated ML model
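Before wiring up the full pipeline, here is a minimal sketch of what listening to a napari event looks like (illustrative only; the _log_step callback is a hypothetical example and is not used below):

# Minimal event example: log whenever a dims slider moves.
def _log_step(event):
    LOGGER.info(f"dims changed, current step: {viewer.dims.current_step}")

viewer.dims.events.current_step.connect(_log_step)
# Disconnect again so it doesn't spam the log for the rest of the tutorial:
viewer.dims.events.current_step.disconnect(_log_step)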

# Let's start with our event listener

# We use "curry" because this allows us to "store" our viewer and widget for later use
@tz.curry
def on_data_change(event, viewer=None, widget=None):
    corner_pixels = data_layer.corner_pixels

    # Ensure the painting layer visual is updated
    painting_layer.refresh()

    # Training the ML model and generating predictions can take time,
    #   so we run these calculations in a thread; otherwise napari
    #   would freeze until the calculations are done
    thread = threading.Thread(
        target=threaded_on_data_change,
        args=(
            event,
            corner_pixels,
            viewer.dims,
            widget.model_dropdown.currentText(),
            {
                "sigma_min": widget.sigma_start_spinbox.value(),
                "sigma_max": widget.sigma_end_spinbox.value(),
                "intensity": widget.intensity_checkbox.isChecked(),
                "edges": widget.edges_checkbox.isChecked(),
                "texture": widget.texture_checkbox.isChecked(),
            },
            widget.live_fit_checkbox.isChecked(),
            widget.live_pred_checkbox.isChecked(),
            widget.data_dropdown.currentText(),
        ),
    )
    thread.start()
    # Wait for the thread so that the prediction layer refresh below
    #   happens after the new predictions have been written
    thread.join()

    # Ensure the prediction layer visual is updated
    prediction_layer.refresh()
# Now for the hard part of the listener

model = None

def threaded_on_data_change(
    event,
    corner_pixels,
    dims,
    model_type,
    feature_params,
    live_fit,
    live_prediction,
    data_choice,
):
    global model
    LOGGER.info(f"Labels data has changed! {event}")

    current_step = dims.current_step

    # Find a mask of indices we will use for fetching our data
    mask_idx = (
        slice(current_step[0], current_step[0] + 1),
        slice(corner_pixels[0, 1], corner_pixels[1, 1]),
        slice(corner_pixels[0, 2], corner_pixels[1, 2]),
    )
    if data_choice == "Whole Image":
        mask_idx = tuple([slice(0, sz) for sz in data_layer.data.shape])

    LOGGER.info(f"mask idx {mask_idx}, image {data_layer.data.shape}")
    active_image = data_layer.data[mask_idx]
    LOGGER.info(
        f"active image shape {active_image.shape} data choice {data_choice} painting_data {painting_data.shape} mask_idx {mask_idx}"
    )

    active_labels = painting_data[mask_idx]

    def compute_features(image, feature_params):
        """Compute features for each channel and concatenate them."""
        features = extract_features(
            image, feature_params
        )

        return features

    training_labels = None

    if data_choice == "Current Displayed Region":
        # Use only the currently displayed region.
        training_features = compute_features(
            active_image, feature_params
        )
        training_labels = np.squeeze(active_labels)

    else:
        raise ValueError(f"Invalid data choice: {data_choice}")

    if (training_labels is None) or (not np.any(training_labels)):
        LOGGER.info("No training data yet. Skipping model update")
    elif live_fit:
        # Retrain model
        LOGGER.info(
            f"training model with labels {training_labels.shape} features {training_features.shape} unique labels {np.unique(training_labels[:])}"
        )
        model = update_model(training_labels, training_features, model_type)

    # Don't do live prediction on whole image, that happens earlier slicewise
    if live_prediction:
        # Update prediction_data
        prediction_features = compute_features(
            active_image, feature_params
        )
        # predict() adds 1 back because of the background label adjustment for the model
        prediction = predict(model, prediction_features, model_type)
        LOGGER.info(
            f"prediction {prediction.shape} prediction layer {prediction_layer.data.shape} prediction {np.transpose(prediction).shape} features {prediction_features.shape}"
        )

        if data_choice == "Whole Image":
            prediction_layer.data[mask_idx] = np.transpose(prediction)
        else:
            prediction_layer.data[mask_idx] = np.transpose(prediction)[
                np.newaxis, :
            ]
# Model training function that respects widget's model choice
def update_model(labels, features, model_type):
    features = features[labels > 0, :]
    # We shift labels - 1 because background is 0 and has special meaning, but models need to start at 0
    labels = labels[labels > 0] - 1
    
    if model_type == "Random Forest":
        clf = RandomForestClassifier(
            n_estimators=50, n_jobs=-1, max_depth=10, max_samples=0.05
        )

    LOGGER.info(
        f"updating model with label shape {labels.shape} feature shape {features.shape} unique labels {np.unique(labels)}"
    )
    
    clf.fit(features, labels)

    return clf


def predict(model, features, model_type):
    # We shift labels + 1 because background is 0 and has special meaning
    prediction = future.predict_segmenter(features.reshape(-1, features.shape[-1]), model).reshape(features.shape[:-1]) + 1

    return np.transpose(prediction)
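The -1/+1 shifting is worth a sanity check: painted labels start at 1 (0 means “unpainted”), the classifier is trained on those labels shifted down by one, and predictions are shifted back up so they line up with the painted values. A tiny synthetic round trip (hypothetical data, just to illustrate the convention):

# Painted labels {1, 2} -> model classes {0, 1} -> predictions {1, 2}.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(100, 3))      # stand-in feature vectors
y_demo = rng.integers(1, 3, size=100)   # stand-in painted labels (1 and 2)
demo_clf = RandomForestClassifier(n_estimators=10).fit(X_demo, y_demo - 1)
assert set(np.unique(demo_clf.predict(X_demo) + 1)) <= {1, 2}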
# Now connect everything together
for listener in [
    viewer.camera.events,
    viewer.dims.events,
    painting_layer.events.paint,
]:
    listener.connect(
        debounced(
            ensure_main_thread(
                on_data_change(
                    viewer=viewer,
                    widget=widget,  # pass the widget instance for easy access to settings
                )
            ),
            timeout=1000,
        )
    )
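That’s it! With everything connected, select the Painting layer, pick a label, and start annotating. Thanks to the debounced listeners, about a second after you stop painting (or move the camera or a dims slider), the model refits on the painted voxels and the Prediction layer updates in place.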