DIY Interactive Segmentation with napari
napari is a very flexible and “hackable” tool. In this tutorial we will build a custom interactive segmentation tool from scratch and use it on data hosted on Zebrahub.
Setup
# this cell is required to run these notebooks on Binder. Make sure that you also have a desktop tab open.
import os

if 'BINDER_SERVICE_HOST' in os.environ:
    os.environ['DISPLAY'] = ':1.0'
!pip install scikit-learn scikit-image ome-zarr
from appdirs import user_data_dir
import os
import zarr
import dask.array as da
import toolz as tz
from sklearn.ensemble import RandomForestClassifier
from skimage import data, segmentation, feature, future
from skimage.feature import multiscale_basic_features
from skimage.io import imread, imshow
import numpy as np
from functools import partial
import napari
import threading
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
from psygnal import debounced
from superqt import ensure_main_thread
import logging
import sys
LOGGER = logging.getLogger("halfway_to_i2k_2023_america")
LOGGER.setLevel(logging.DEBUG)
streamHandler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
streamHandler.setFormatter(formatter)
LOGGER.addHandler(streamHandler)
Reading in data
Get data from Zebrahub.
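def open_zebrahub():
    url = "https://public.czbiohub.org/royerlab/zebrahub/imaging/single-objective/ZSNS002.ome.zarr/"

    # read the image data (parse_url returns None if nothing is found at the URL)
    location = parse_url(url, mode="r")
    if location is None:
        raise ValueError(f"No OME-Zarr data found at {url}")

    reader = Reader(location)
    # nodes may include images, labels etc
    nodes = list(reader())
    # first node will be the image pixel data
    image_node = nodes[0]

    # for a multiscale image this is a list of dask arrays, one per resolution level
    dask_data = image_node.data

    return dask_data

zebrahub_data = open_zebrahub()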
zebrahub_data[3].shape
# Let's choose a volume to work on: pyramid level 3, time point 800, channel 0
crop_3D = zebrahub_data[3][800, 0, :, :, :]
crop_3D.shape
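`open_zebrahub` returns the image node's `data`, which for a multiscale OME-Zarr image is a list of arrays, one per resolution level. So `zebrahub_data[3]` picks a downsampled level, and `[800, 0, :, :, :]` fixes one time point and one channel. The sketch below mimics that indexing with small NumPy arrays; the `(t, c, z, y, x)` axis order and the 2× downsampling in Y and X per level are assumptions for illustration, not values read from the Zebrahub metadata.

```python
import numpy as np

# Hypothetical stand-in for a multiscale image: level 0 is full resolution,
# each further level keeps every 2**level-th pixel in Y and X (assumed, for illustration only).
# Axis order assumed to be (t, c, z, y, x).
full_res = np.zeros((4, 2, 8, 64, 64), dtype=np.uint16)
pyramid = [full_res[..., ::2**level, ::2**level] for level in range(4)]

for level, arr in enumerate(pyramid):
    print(level, arr.shape)

# Same pattern as zebrahub_data[3][800, 0, :, :, :]: choose a level,
# then fix one time point and one channel to get a 3D (z, y, x) volume.
volume = pyramid[3][2, 0, :, :, :]
print(volume.shape)  # (8, 8, 8)
```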
Visualize in napari
viewer = napari.Viewer()
scale = (9.92, 3.512, 3.512)
contrast_limits = (0, 372)
data_layer = viewer.add_image(crop_3D, scale=scale, contrast_limits=contrast_limits)
data_layer.bounding_box.visible = True
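`scale` gives napari the physical size of one voxel along each axis. For a layer with no translation, rotation or affine transform, a voxel index multiplied by `scale` gives its position in world coordinates. We assume the values are in micrometres; the NumPy sketch below shows the arithmetic on a hypothetical volume shape.

```python
import numpy as np

scale = np.array([9.92, 3.512, 3.512])  # (z, y, x) voxel size, assumed micrometres
shape = np.array([100, 512, 512])       # hypothetical (z, y, x) volume shape

# With no translate/rotate/affine, napari maps data -> world as index * scale.
voxel_index = np.array([10, 256, 256])
world_position = voxel_index * scale
print(world_position)

# Physical extent of the whole volume along each axis.
extent = shape * scale
print(extent)

# The voxels are anisotropic: Z spacing is about 2.8x the in-plane spacing.
print(scale[0] / scale[1])
```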
Extracting features
def extract_features(image, feature_params):
    features_func = partial(
        multiscale_basic_features,
        intensity=feature_params["intensity"],
        edges=feature_params["edges"],
        texture=feature_params["texture"],
        sigma_min=feature_params["sigma_min"],
        sigma_max=feature_params["sigma_max"],
        channel_axis=None,
    )
    # print(f"image shape {image.shape} feature params {feature_params}")
    # np.squeeze drops axes of length 1, e.g. the single Z plane of a displayed region
    features = features_func(np.squeeze(image))

    return features
example_feature_params = {
"sigma_min": 1,
"sigma_max": 5,
"intensity": True,
"edges": True,
"texture": True,
}
features = extract_features(crop_3D, example_feature_params)
features.shape
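Each feature is computed at several Gaussian scales between `sigma_min` and `sigma_max`, so the last axis of `features` grows with both the number of scales and the enabled feature types. The hypothetical `expected_num_features` helper below reproduces that count for a single-channel image. It assumes scikit-image's default of `int(log2(sigma_max) - log2(sigma_min) + 1)` scales and, per scale, one intensity, one edge and `ndim` texture (Hessian eigenvalue) channels, as we understand the current implementation; compare it with `features.shape[-1]` on your own installation.

```python
import math


def expected_num_features(ndim, sigma_min, sigma_max, intensity=True, edges=True, texture=True):
    """Hypothetical helper: count the channels multiscale_basic_features should return
    for a single-channel image, assuming scikit-image's default number of sigmas."""
    num_sigma = int(math.log2(sigma_max) - math.log2(sigma_min) + 1)
    per_sigma = int(intensity) + int(edges) + (ndim if texture else 0)
    return num_sigma * per_sigma


# The 3D crop: 3 scales x (1 intensity + 1 edge + 3 texture) = 15 channels.
print(expected_num_features(3, sigma_min=1, sigma_max=5))
# A 2D slice (the squeezed "Current Displayed Region"): 3 x (1 + 1 + 2) = 12 channels.
print(expected_num_features(2, sigma_min=1, sigma_max=5))
```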
Visualize Features
What do these features we are extracting look like?
def show_features():
    for feature_idx in range(features.shape[-1]):
        viewer.add_image(features[:, :, :, feature_idx])

# show_features()
Making the Interactive Segmentation Tool!
Ok, now we’ve seen our data and some features we can compute for it.
Our goal is to create a label image that marks the zebrafish sample.
Whenever we annotate (paint) in our painting layer, we want the segmentation to update automatically.
We will do this using three napari layers:
Our input image
A labels layer for painting
A labels layer for storing the predictions generated by the machine learning model
Due to popular demand, we store the painting and prediction layers in Zarr arrays. Zarr stores arrays in chunks on disk, which helps this approach scale to very large datasets. Plain NumPy arrays would also work.
Create our painting and prediction layers
zarr_path = os.path.join(user_data_dir("halfway_to_i2k_2023_america", "napari"), "diy_segmentation.zarr")
print(f"Saving outputs to zarr path: {zarr_path}")
# Create a prediction layer
prediction_data = zarr.open(
f"{zarr_path}/prediction",
mode='a',
shape=crop_3D.shape,
dtype='i4',
dimension_separator="/",
)
prediction_layer = viewer.add_labels(prediction_data, name="Prediction", scale=data_layer.scale)
# Create a painting layer
painting_data = zarr.open(
f"{zarr_path}/painting",
mode='a',
shape=crop_3D.shape,
dtype='i4',
dimension_separator="/",
)
painting_layer = viewer.add_labels(painting_data, name="Painting", scale=data_layer.scale)
painting_data.shape
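Zarr splits each array into fixed-size chunks and stores each chunk separately, so a brush stroke or a prediction for the displayed slice only rewrites the chunks it overlaps. The standard-library sketch below computes which chunks a region touches. The chunk shape `(16, 128, 128)` is an illustrative assumption, since `zarr.open` picks a chunk shape automatically when none is given (inspect `painting_data.chunks` for the real one), and `chunks_touched` is a hypothetical helper.

```python
import itertools


def chunks_touched(region, chunk_shape):
    """Hypothetical helper: return the chunk grid indices overlapped by `region`,
    a tuple of (start, stop) pairs per axis (stop exclusive, stop > start)."""
    ranges = [
        range(start // size, (stop - 1) // size + 1)
        for (start, stop), size in zip(region, chunk_shape)
    ]
    return list(itertools.product(*ranges))


chunk_shape = (16, 128, 128)  # assumed chunk shape, for illustration only

# One Z plane and a 300 x 300 pixel window in Y/X, like a displayed region.
region = ((40, 41), (100, 400), (100, 400))
touched = chunks_touched(region, chunk_shape)
print(len(touched), touched[:3])
```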
Let’s make a UI as well
from qtpy.QtWidgets import (
QVBoxLayout,
QHBoxLayout,
QComboBox,
QLabel,
QCheckBox,
QDoubleSpinBox,
QGroupBox,
QWidget,
)
class NapariMLWidget(QWidget):
    def __init__(self, parent=None):
        super().__init__(parent)

        self.initUI()

    def initUI(self):
        layout = QVBoxLayout()

        # Dropdown for selecting the model
        model_label = QLabel("Select Model")
        self.model_dropdown = QComboBox()
        self.model_dropdown.addItems(["Random Forest"])
        model_layout = QHBoxLayout()
        model_layout.addWidget(model_label)
        model_layout.addWidget(self.model_dropdown)
        layout.addLayout(model_layout)

        # Select the range of sigma sizes
        self.sigma_start_spinbox = QDoubleSpinBox()
        self.sigma_start_spinbox.setRange(0, 256)
        self.sigma_start_spinbox.setValue(1)

        self.sigma_end_spinbox = QDoubleSpinBox()
        self.sigma_end_spinbox.setRange(0, 256)
        self.sigma_end_spinbox.setValue(5)

        sigma_layout = QHBoxLayout()
        sigma_layout.addWidget(QLabel("Sigma Range: From"))
        sigma_layout.addWidget(self.sigma_start_spinbox)
        sigma_layout.addWidget(QLabel("To"))
        sigma_layout.addWidget(self.sigma_end_spinbox)
        layout.addLayout(sigma_layout)

        # Boolean options for features
        self.intensity_checkbox = QCheckBox("Intensity")
        self.intensity_checkbox.setChecked(True)
        self.edges_checkbox = QCheckBox("Edges")
        self.texture_checkbox = QCheckBox("Texture")
        self.texture_checkbox.setChecked(True)

        features_group = QGroupBox("Features")
        features_layout = QVBoxLayout()
        features_layout.addWidget(self.intensity_checkbox)
        features_layout.addWidget(self.edges_checkbox)
        features_layout.addWidget(self.texture_checkbox)
        features_group.setLayout(features_layout)
        layout.addWidget(features_group)

        # Dropdown for data selection
        data_label = QLabel("Select Data for Model Fitting")
        self.data_dropdown = QComboBox()
        self.data_dropdown.addItems(
            ["Current Displayed Region", "Whole Image"]
        )
        self.data_dropdown.setCurrentText("Current Displayed Region")
        data_layout = QHBoxLayout()
        data_layout.addWidget(data_label)
        data_layout.addWidget(self.data_dropdown)
        layout.addLayout(data_layout)

        # Checkbox for live model fitting
        self.live_fit_checkbox = QCheckBox("Live Model Fitting")
        self.live_fit_checkbox.setChecked(True)
        layout.addWidget(self.live_fit_checkbox)

        # Checkbox for live prediction
        self.live_pred_checkbox = QCheckBox("Live Prediction")
        self.live_pred_checkbox.setChecked(True)
        layout.addWidget(self.live_pred_checkbox)

        self.setLayout(layout)
# Let's add this widget to napari
widget = NapariMLWidget()
viewer.window.add_dock_widget(widget, name="halfway to I2K 2023 America")
We have a widget, we have our painting and prediction layers, now what?
We need to start connecting things together. How should we do that? napari emits “events” when something changes in the viewer or its layers, and we can connect our own functions to them. We want to respond to a few different event types:
changes in camera (e.g. camera position and rotation)
changes in “dims” (e.g. moving a dimension slider)
painting events (e.g. a user clicked, painted, and released their mouse)
When one of these events happens, we want to:
update our machine learning model with the new painted data
update our prediction with the updated ML model
# Let's start with our event listener
# We use "curry" because this allows us to "store" our viewer and widget for later use
@tz.curry
def on_data_change(event, viewer=None, widget=None):
    corner_pixels = data_layer.corner_pixels

    # Ensure the painting layer visual is updated
    painting_layer.refresh()

    # Training the ML model and generating predictions can take time,
    # so the calculations run in a separate thread. Note that we join()
    # the thread right away, so napari still waits until these
    # calculations are done before continuing.
    thread = threading.Thread(
        target=threaded_on_data_change,
        args=(
            event,
            corner_pixels,
            viewer.dims,
            widget.model_dropdown.currentText(),
            {
                "sigma_min": widget.sigma_start_spinbox.value(),
                "sigma_max": widget.sigma_end_spinbox.value(),
                "intensity": widget.intensity_checkbox.isChecked(),
                "edges": widget.edges_checkbox.isChecked(),
                "texture": widget.texture_checkbox.isChecked(),
            },
            widget.live_fit_checkbox.isChecked(),
            widget.live_pred_checkbox.isChecked(),
            widget.data_dropdown.currentText(),
        ),
    )
    thread.start()
    thread.join()

    # Ensure the prediction layer visual is updated
    prediction_layer.refresh()
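`tz.curry` lets us call `on_data_change(viewer=viewer, widget=widget)` once, up front, and get back a function that still waits for its `event` argument, which is the only thing napari passes to a callback. The standard-library sketch below shows the same idea with `functools.partial`; `handle_event` and the string arguments are hypothetical stand-ins for the real callback, viewer and widget.

```python
from functools import partial


def handle_event(event, viewer=None, widget=None):
    # Hypothetical stand-in for on_data_change: just report what it received.
    return f"{event} handled with viewer={viewer}, widget={widget}"


# Bind the keyword arguments now, the role tz.curry plays in the tutorial...
callback = partial(handle_event, viewer="viewer-0", widget="widget-0")

# ...so that later the event system only has to supply the event.
print(callback("paint-event"))
```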
# Now for the hard part of the listener: updating the model and the prediction
model = None

def threaded_on_data_change(
    event,
    corner_pixels,
    dims,
    model_type,
    feature_params,
    live_fit,
    live_prediction,
    data_choice,
):
    global model
    LOGGER.info(f"Labels data has changed! {event}")

    current_step = dims.current_step

    # Find a mask of indices we will use for fetching our data:
    # the current Z plane and the Y/X region visible on the canvas
    mask_idx = (
        slice(current_step[0], current_step[0] + 1),
        slice(corner_pixels[0, 1], corner_pixels[1, 1]),
        slice(corner_pixels[0, 2], corner_pixels[1, 2]),
    )
    if data_choice == "Whole Image":
        mask_idx = tuple([slice(0, sz) for sz in data_layer.data.shape])

    LOGGER.info(f"mask idx {mask_idx}, image {data_layer.data.shape}")
    active_image = data_layer.data[mask_idx]
    LOGGER.info(
        f"active image shape {active_image.shape} data choice {data_choice} painting_data {painting_data.shape} mask_idx {mask_idx}"
    )

    active_labels = painting_data[mask_idx]

    def compute_features(image, feature_params):
        """Compute the multiscale features of a single-channel image."""
        features = extract_features(
            image, feature_params
        )
        return features

    training_labels = None

    if data_choice in ("Current Displayed Region", "Whole Image"):
        # active_image and active_labels already cover the chosen region
        training_features = compute_features(
            active_image, feature_params
        )
        training_labels = np.squeeze(active_labels)
    else:
        raise ValueError(f"Invalid data choice: {data_choice}")

    # Only train once at least one pixel has been painted
    if (training_labels is None) or not np.any(training_labels > 0):
        LOGGER.info("No training data yet. Skipping model update")
    elif live_fit:
        # Retrain model
        LOGGER.info(
            f"training model with labels {training_labels.shape} features {training_features.shape} unique labels {np.unique(training_labels[:])}"
        )
        model = update_model(training_labels, training_features, model_type)

    # Predict on the same region (displayed slice or whole image),
    # but only once a model has been trained
    if live_prediction and model is not None:
        # Update prediction_data
        prediction_features = compute_features(
            active_image, feature_params
        )
        # predict() adds 1 to undo the background label shift used in training
        prediction = predict(model, prediction_features, model_type)
        LOGGER.info(
            f"prediction {prediction.shape} prediction layer {prediction_layer.data.shape} prediction {np.transpose(prediction).shape} features {prediction_features.shape}"
        )
        # predict() returns a transposed array: transpose it back and restore any
        # length-1 axis removed by np.squeeze (e.g. the single displayed Z plane)
        prediction_layer.data[mask_idx] = np.transpose(prediction).reshape(active_image.shape)
# Model training function that respects the widget's model choice
def update_model(labels, features, model_type):
    features = features[labels > 0, :]
    # Background is 0 and means "unpainted", so we train only on painted pixels and
    # shift labels by -1 so classes start at 0; predict() adds 1 back
    labels = labels[labels > 0] - 1

    if model_type == "Random Forest":
        clf = RandomForestClassifier(
            n_estimators=50, n_jobs=-1, max_depth=10, max_samples=0.05
        )
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    print(
        f"updating model with label shape {labels.shape} feature shape {features.shape} unique labels {np.unique(labels)}"
    )
    clf.fit(features, labels)

    return clf
def predict(model, features, model_type):
    # We shift labels + 1 because background is 0 and has special meaning
    prediction = future.predict_segmenter(
        features.reshape(-1, features.shape[-1]), model
    ).reshape(features.shape[:-1]) + 1
    return np.transpose(prediction)
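In napari labels layers, 0 means "unlabeled". That is why `update_model` trains only on painted pixels and subtracts 1 from their labels, while `predict` adds 1 back so every predicted pixel gets a nonzero label. The NumPy sketch below walks through that round trip on a tiny hypothetical painting; `fake_predict` is a stand-in for the classifier and simply returns the training classes.

```python
import numpy as np

# Hypothetical 3x3 painting: 0 = unpainted, 1 and 2 = user classes.
labels = np.array([[0, 1, 0],
                   [2, 0, 1],
                   [0, 0, 2]])
# Two made-up features per pixel, shape (3, 3, 2).
features = np.stack([labels * 10.0, labels * -1.0], axis=-1)

# Training side (as in update_model): keep painted pixels, shift classes to start at 0.
train_features = features[labels > 0, :]
train_labels = labels[labels > 0] - 1
print(train_features.shape, train_labels)  # (4, 2) [0 1 0 1]

# Prediction side (as in predict): the classifier output is 0-based, so add 1 back.
fake_predict = train_labels  # stand-in for model.predict(train_features)
restored = fake_predict + 1
print(restored)  # [1 2 1 2]
```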
# Now connect everything together
for listener in [
    viewer.camera.events,
    viewer.dims.events,
    painting_layer.events.paint,
]:
    listener.connect(
        debounced(
            ensure_main_thread(
                on_data_change(
                    viewer=viewer,
                    widget=widget,  # pass the widget instance for easy access to settings
                )
            ),
            timeout=1000,
        )
    )
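Wrapping the callback in psygnal's `debounced(..., timeout=1000)` means a burst of events (for example every small camera move while panning) leads to a single call to `on_data_change`, made once no new event has arrived for 1000 ms. The sketch below is a hypothetical, standard-library-only `simple_debounce` built on `threading.Timer` to illustrate the idea; it is not psygnal's implementation. In a timer-based sketch like this the callback fires on a background thread, and GUI updates have to happen on the main thread, which is what `ensure_main_thread` takes care of in the tutorial code.

```python
import threading
import time


def simple_debounce(func, timeout_ms):
    """Hypothetical debounce: call `func` only after `timeout_ms` ms without new calls."""
    lock = threading.Lock()
    timer = None

    def wrapper(*args, **kwargs):
        nonlocal timer
        with lock:
            if timer is not None:
                timer.cancel()  # a newer event supersedes the pending one
            timer = threading.Timer(timeout_ms / 1000, func, args, kwargs)
            timer.start()

    return wrapper


calls = []
debounced_record = simple_debounce(calls.append, timeout_ms=200)

# Ten rapid "events", e.g. while dragging a slider...
for i in range(10):
    debounced_record(i)
    time.sleep(0.01)

time.sleep(0.5)  # wait for the quiet period to pass
print(calls)  # only the last event is handled
```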