# -----------------------------------------------------------------------------.
# MIT License
# Copyright (c) 2024-2026 ximage developers
#
# This file is part of ximage.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----------------------------------------------------------------------------.
"""Labels identification."""
# import dask_image.ndmeasure
# from dask_image.ndmeasure import label as dask_label_image
import dask.array
import dask_image.ndmeasure
import numpy as np
import xarray as xr
from skimage.measure import label as label_image
from skimage.morphology import dilation as skimage_dilation
from skimage.morphology import disk
from ximage.utils.checks import are_all_natural_numbers
# TODO:
# - Enable labelling in n dimensions
#   - (2D + VERTICAL) --> CORE PROFILES
#   - (2D + TIME) --> TRACKING
####--------------------------------------------------------------------------.
def _binary_dilation(mask, footprint):
    """Dilate ``mask`` with ``footprint`` using ``skimage.morphology.dilation``."""
    return skimage_dilation(mask, footprint=footprint)
def _mask_buffer(mask, footprint):
"""Dilate the mask by n pixel in all directions.
If footprint = 0 or None, no dilation occur.
If footprint is a positive integer, it create a disk(footprint)
If footprint is a 2D array, it must represent the neighborhood expressed
as a 2-D array of 1's and 0's.
For more info: https://scikit-image.org/docs/stable/api/skimage.morphology.html#skimage.morphology.binary_dilation
"""
# scikitimage > 0.19
if not isinstance(footprint, (int, np.ndarray, type(None))):
raise TypeError("`footprint` must be an integer, numpy 2D array or None.")
if isinstance(footprint, np.ndarray) and footprint.ndim != 2:
raise ValueError("If providing the footprint for dilation as np.array, it must be 2D.")
if isinstance(footprint, int):
if footprint < 0:
raise ValueError("Footprint must be equal or larger than 1.")
footprint = None if footprint == 0 else disk(radius=footprint)
# Apply dilation
if footprint is not None:
mask = _binary_dilation(mask, footprint=footprint)
return mask
def _check_array(arr):
"""Check array and return a numpy.ndarray."""
shape = arr.shape
if len(shape) != 2:
raise ValueError("Expecting a 2D array.")
if np.any(np.array(shape) == 0):
raise ValueError("Expecting non-zero dimensions.")
# Convert to numpy array
return np.asanyarray(arr)
def _no_labels_result(arr, return_labels_stats):
"""Define results for array without labels."""
labels = np.zeros(arr.shape)
n_labels = 0
values = []
if return_labels_stats:
return labels, n_labels, values
return labels
def _check_sort_by(sort_by):
"""Check ``sort_by`` argument."""
if not (callable(sort_by) or isinstance(sort_by, str)):
raise TypeError("'sort_by' must be a string or a function.")
if isinstance(sort_by, str):
valid_stats = [
"area",
"maximum",
"minimum",
"mean",
"median",
"sum",
"standard_deviation",
"variance",
]
if sort_by not in valid_stats:
raise ValueError(f"Valid 'sort_by' values are: {valid_stats}.")
def _check_stats(stats):
"""Check ``stats`` argument."""
if not (callable(stats) or isinstance(stats, str)):
raise TypeError("'stats' must be a string or a function.")
if isinstance(stats, str):
valid_stats = [
"area",
"maximum",
"minimum",
"mean",
"median",
"sum",
"standard_deviation",
"variance",
]
if stats not in valid_stats:
raise ValueError(f"Valid 'stats' values are: {valid_stats}.")
# TODO: check stats function works on a dummy array (reduce to single value)
return stats
def _get_label_value_stats(arr, label_arr, label_indices=None, stats="area", labeled_comprehension_kwargs=None):
    """Compute, for each label index, a statistic of ``arr`` over the label pixels.

    These values are later used to sort the labels.

    Parameters
    ----------
    arr : array-like
        Array of values.
    label_arr : array-like
        Label array associated to ``arr``.
    label_indices : array-like, optional
        Label indices for which the statistic is computed.
        If None, every unique value of ``label_arr`` (including 0 if present) is used.
    stats : str or callable, optional
        Name of a ``dask_image.ndmeasure`` statistic or a function passed to
        ``dask_image.ndmeasure.labeled_comprehension``. The default is ``"area"``.
    labeled_comprehension_kwargs : dict, optional
        Additional arguments for ``labeled_comprehension`` (used only when
        ``stats`` is callable). The caller's dictionary is never modified.

    Returns
    -------
    numpy.ndarray
        The computed statistic for each entry of ``label_indices``.

    Notes
    -----
    The values returned for index 0 or for indices absent from ``label_arr``
    are decided by dask_image (previously documented here as NaN and 0,
    respectively; verify against the installed dask_image version).
    """
    # Work on a copy: the defaults set below must not leak into the caller's
    # dictionary, which may be shared across calls (e.g. one call per slice).
    labeled_comprehension_kwargs = dict(labeled_comprehension_kwargs or {})
    # Check stats argument and label indices
    stats = _check_stats(stats)
    if label_indices is None:
        label_indices = np.unique(label_arr)
    # Compute labels stats values
    if callable(stats):
        labeled_comprehension_kwargs.setdefault("out_dtype", float)
        labeled_comprehension_kwargs.setdefault("default", None)
        labeled_comprehension_kwargs.setdefault("pass_positions", False)
        values = dask_image.ndmeasure.labeled_comprehension(
            image=arr,
            label_image=label_arr,
            index=label_indices,
            func=stats,
            **labeled_comprehension_kwargs,
        )
    else:
        func = getattr(dask_image.ndmeasure, stats)
        values = func(image=arr, label_image=label_arr, index=label_indices)
    # Compute the (lazy) dask result
    return values.compute()
def _get_labels_stats(
    arr,
    label_arr,
    label_indices=None,
    stats="area",
    sort_decreasing=True,
    labeled_comprehension_kwargs=None,
):
    """Return label indices and label statistics, sorted by statistic value.

    Parameters
    ----------
    arr : array-like
        Array of values.
    label_arr : array-like
        Label array associated to ``arr``.
    label_indices : array-like, optional
        Label indices to consider (any sequence is accepted).
        If None, every unique value of ``label_arr`` is used.
    stats : str or callable, optional
        Statistic used to sort the labels. The default is ``"area"``.
    sort_decreasing : bool, optional
        If True (default), sort by decreasing statistic value.
    labeled_comprehension_kwargs : dict, optional
        Additional arguments used when ``stats`` is a callable.

    Returns
    -------
    label_indices : numpy.ndarray
        Label indices in sorted order.
    values : numpy.ndarray
        Statistic values, aligned with ``label_indices``.
    """
    if labeled_comprehension_kwargs is None:
        labeled_comprehension_kwargs = {}
    # Ensure an ndarray so that the fancy indexing used for sorting also works
    # when the label indices are given as a list or tuple.
    if label_indices is None:
        label_indices = np.unique(label_arr)
    else:
        label_indices = np.asarray(label_indices)
    # Get labels statistic values
    values = _get_label_value_stats(
        arr,
        label_arr=label_arr,
        label_indices=label_indices,
        stats=stats,
        labeled_comprehension_kwargs=labeled_comprehension_kwargs,
    )
    # Get sorting index based on values
    sort_index = np.argsort(values)
    if sort_decreasing:
        sort_index = sort_index[::-1]
    return label_indices[sort_index], values[sort_index]
def _vec_translate(arr, my_dict):
"""Remap array <value> based on the dictionary key-value pairs.
This function is used to redefine label array integer values based on the
label area_size/max_intensity value.
"""
# TODO: Remove keys not in arr to speed up maybe
return np.vectorize(my_dict.__getitem__)(arr)
def _get_labels_with_requested_occurrence(label_arr, vmin, vmax):
"""Get label indices with requested occurrence."""
# Compute label occurrence
label_indices, label_occurrence = np.unique(label_arr, return_counts=True)
# Remove label 0 and associate pixel count if present
if label_indices[0] == 0:
label_indices = label_indices[1:]
label_occurrence = label_occurrence[1:]
# Get index with required occurrence
valid_area_indices = np.where(np.logical_and(label_occurrence >= vmin, label_occurrence <= vmax))[0]
# Return list of valid label indices
return label_indices[valid_area_indices] if len(valid_area_indices) > 0 else []
def _ensure_valid_label_arr(label_arr):
    """Return a validated integer version of ``label_arr``.

    NaN values are converted to 0 in a copy, so the input array is never
    modified. The output array dtype is int.

    Raises
    ------
    ValueError
        If, after NaN removal, ``label_arr`` contains values that are not
        natural numbers (0 allowed).
    """
    # Ensure data are numpy
    label_arr = np.asanyarray(label_arr)
    # Set NaN to 0
    is_nan = np.isnan(label_arr)
    if is_nan.any():
        # np.asanyarray does not copy numpy inputs: copy before assigning so
        # that the caller's data is not silently altered.
        label_arr = label_arr.copy()
        label_arr[is_nan] = 0
    # Check that label arr values are positive integers
    if not are_all_natural_numbers(label_arr.flatten(), zero_allowed=True):
        raise ValueError("The label array must contain only positive integers.")
    # Ensure label array is integer dtype
    return label_arr.astype(int)
def _ensure_valid_label_indices(label_indices):
"""Ensure valid label indices are integers and does not contains 0 and NaN."""
label_indices = np.delete(label_indices, np.where(label_indices == 0)[0].flatten())
label_indices = np.delete(label_indices, np.where(np.isnan(label_indices))[0].flatten())
return label_indices.astype(int)
[docs]
def get_label_indices(arr):
    """Return the sorted unique label indices of an array.

    Accepts numpy.ndarray, dask.Array and xarray.DataArray inputs.
    ``np.nan`` values and the value 0 are discarded; the output dtype is ``int``.
    """
    values = np.asanyarray(arr)
    non_nan = values[~np.isnan(values)]
    # Cast to int before np.unique, otherwise floating-point precision errors
    # may appear in the unique values
    unique_ids = np.unique(non_nan.astype(int))
    return unique_ids[unique_ids != 0].astype(int)
def _check_unique_label_indices(label_indices):
_, c = np.unique(label_indices, return_counts=True)
if np.any(c > 1):
raise ValueError("'label_indices' must be uniques.")
def _get_new_label_value_dict(label_indices, max_label):
"""Create dictionary mapping from current label value to new label value."""
# Initialize dictionary with keys corresponding to all possible labels indices
val_dict = dict.fromkeys(range(0, max_label + 1), 0)
# Update the dictionary keys with the selected label_indices
# - Assume 0 not in label_indices
n_labels = len(label_indices)
label_indices = label_indices.tolist()
label_indices_new = np.arange(1, n_labels + 1, dtype=int).tolist()
val_dict.update(dict(zip(label_indices, label_indices_new, strict=True)))
return val_dict
def _np_redefine_label_array(label_arr, label_indices=None):
    """Relabel a numpy (or dask) array so that labels run from 1 to len(label_indices).

    Parameters
    ----------
    label_arr : array-like
        Label array of natural numbers (NaN values are treated as 0).
    label_indices : array-like, optional
        Labels to keep, in the order of their new value. They must be unique;
        0 and NaN entries are dropped. If None, all unique values of
        ``label_arr`` are used.

    Returns
    -------
    numpy.ndarray
        Relabelled array: the first valid entry of ``label_indices`` becomes 1,
        the second 2, and so on; any other value becomes 0.

    Raises
    ------
    ValueError
        If ``label_indices`` is not unique, if ``label_arr`` contains invalid
        values, or if no label is available.
    """
    label_arr = np.asanyarray(label_arr)
    if label_indices is not None:
        _check_unique_label_indices(label_indices)
    else:
        label_indices = np.unique(label_arr)
    # Integer indices without 0 and NaN
    label_indices = _ensure_valid_label_indices(label_indices)
    # Integer label array without NaN
    label_arr = _ensure_valid_label_arr(label_arr)
    if not len(label_indices):
        raise ValueError("No labels available.")
    highest_label = max(label_indices)
    # Labels above the highest requested index (e.g. removed by mask or area
    # filtering) are discarded so that the lookup dictionary covers every value
    label_arr[label_arr > highest_label] = 0
    return _vec_translate(label_arr, _get_new_label_value_dict(label_indices, highest_label))
def _xr_redefine_label_array(dataarray, label_indices=None):
    """Relabel an xarray.DataArray from 1 to len(label_indices).

    Returns a copy of ``dataarray`` whose data are the relabelled values.
    """
    new_values = _np_redefine_label_array(dataarray.data, label_indices=label_indices)
    relabeled = dataarray.copy()
    relabeled.data = new_values
    return relabeled
[docs]
def redefine_label_array(data, label_indices=None):
    """Redefine the labels of a label array from 1 to len(label_indices).

    - If ``label_indices`` is ``None``, the unique values of ``data`` are used.
    - 0 values in ``label_indices`` are discarded.
    - If ``label_indices`` is not unique, an error is raised.
    - Label values of ``data`` not present in ``label_indices`` are set to 0.
    - The first label in ``label_indices`` becomes 1, the second 2, and so on.

    Parameters
    ----------
    data : numpy.ndarray, dask.array.Array or xarray.DataArray
        Label array.
    label_indices : array-like, optional
        Labels to keep, in the order of their new value.

    Raises
    ------
    TypeError
        If ``data`` is not one of the supported array types.
    """
    if isinstance(data, (np.ndarray, dask.array.Array)):
        return _np_redefine_label_array(data, label_indices=label_indices)
    if isinstance(data, xr.DataArray):
        return _xr_redefine_label_array(data, label_indices=label_indices)
    raise TypeError(f"This method does not accept {type(data)}")
[docs]
def get_data_array(xr_obj, variable=None):
    """Return the DataArray to work with, checking the xarray object and variable.

    Parameters
    ----------
    xr_obj : xarray.Dataset or xarray.DataArray
        Input xarray object.
    variable : str, optional
        Name of the Dataset variable to select. Required for an xarray.Dataset
        and forbidden for an xarray.DataArray.

    Raises
    ------
    TypeError
        If ``xr_obj`` is neither an xarray.Dataset nor an xarray.DataArray.
    ValueError
        If ``variable`` is missing/invalid for a Dataset, or given for a DataArray.
    """
    if isinstance(xr_obj, xr.DataArray):
        if variable is not None:
            raise ValueError("'variable' must not be specified when providing a xr.DataArray.")
        return xr_obj
    if not isinstance(xr_obj, xr.Dataset):
        raise TypeError("'xr_obj' must be a xr.Dataset or xr.DataArray.")
    if variable is None:
        raise ValueError("An xr.Dataset 'variable' must be specified.")
    if variable not in xr_obj.data_vars:
        raise ValueError(f"'{variable}' is not a variable of the xr.Dataset.")
    return xr_obj[variable]
[docs]
def check_core_dims(core_dims, data_array):
    """Check the ``core_dims`` argument and infer it if needed.

    Parameters
    ----------
    core_dims : sequence of str or None
        Names of the two dimensions along which labelling is applied.
        If None, it is inferred from ``data_array.dims`` when ``data_array`` is 2D.
    data_array : xarray.DataArray
        The DataArray to be labelled.

    Returns
    -------
    tuple of str
        The validated core dimensions.

    Raises
    ------
    ValueError
        If ``core_dims`` is not provided for a non-2D DataArray, does not contain
        exactly two distinct dimensions, or contains dimensions not in ``data_array``.
    """
    # Infer core_dims only for 2D arrays; otherwise they must be specified
    if core_dims is None:
        if data_array.ndim != 2:
            raise ValueError(
                "For DataArray with ndim > 2, `core_dims` must be specified.",
            )
        core_dims = data_array.dims
    core_dims = tuple(core_dims)
    # Check core_dims are two (currently) !
    if len(core_dims) != 2:
        raise ValueError("`core_dims` must contain exactly two dimensions. 3D-array labelling not yet implemented.")
    # A repeated dimension would otherwise only fail later inside xr.apply_ufunc
    if core_dims[0] == core_dims[1]:
        raise ValueError(f"`core_dims` must contain two distinct dimensions. Got {core_dims}.")
    # Check valid core_dims
    missing = set(core_dims) - set(data_array.dims)
    if missing:
        raise ValueError(
            f"`core_dims` {core_dims} are not all dimensions of the DataArray. Missing: {missing}",
        )
    return core_dims
def _get_labels(
    arr,
    min_value_threshold=-np.inf,
    max_value_threshold=np.inf,
    min_area_threshold=1,
    max_area_threshold=np.inf,
    footprint=None,
    sort_by="area",
    sort_decreasing=True,
    labeled_comprehension_kwargs=None,
    return_labels_stats=True,
):
    """
    Derive the labels array and associated labels statistics of a 2D array.

    Parameters
    ----------
    arr : numpy.ndarray
        2D array to be labelled.
    min_value_threshold : float, optional
        The minimum value to define the interior of a label.
        The default is -np.inf.
    max_value_threshold : float, optional
        The maximum value to define the interior of a label.
        The default is np.inf.
    min_area_threshold : float, optional
        The minimum number of connected pixels to be defined as a label.
        The default is 1.
    max_area_threshold : float, optional
        The maximum number of connected pixels to be defined as a label.
        The default is np.inf.
    footprint : (int, numpy.ndarray or None), optional
        This argument enables dilation of the mask derived after applying
        min_value_threshold and max_value_threshold.
        If footprint = 0 or None, no dilation occurs.
        If footprint is a positive integer, it creates a disk(footprint).
        If footprint is a 2D array, it must represent the neighborhood expressed
        as a 2-D array of 1's and 0's.
        The default is None (no dilation).
    sort_by : (callable or str), optional
        A function or statistic defining the order of the labels.
        Valid string statistics are "area", "maximum", "minimum", "mean",
        "median", "sum", "standard_deviation", "variance".
        The default is "area".
    sort_decreasing : bool, optional
        If True, sort labels by decreasing 'sort_by' value.
        The default is True.
    labeled_comprehension_kwargs : dict, optional
        Additional arguments to be passed to dask_image.ndmeasure.labeled_comprehension
        if sort_by is a callable. May contain:
            out_dtype : dtype, optional
                Dtype to use for result.
                The default is float.
            default : (int, float or None), optional
                Default return value when an element of index does not exist in the label array.
                The default is None.
            pass_positions : bool, optional
                If True, pass linear indices to 'sort_by' as a second argument.
                The default is False.
        The default is None (no additional arguments).
    return_labels_stats : bool, optional
        Whether to return label statistics. The default is True.
        If False, it returns just the labelled array.

    Returns
    -------
    labels_arr : numpy.ndarray
        Label array. 0 values correspond to no label. Labels are numbered
        from 1 to n_labels following the ``sort_by`` ordering.
    n_labels : int
        Number of labels in the labels array.
    values : numpy.ndarray
        Array of length n_labels with the ``sort_by`` statistic of each label.

    If no label is found, an all-zero label array is returned (together with
    ``n_labels = 0`` and an empty set of values if ``return_labels_stats=True``).
    """
    # TODO: this could be extended to work with dask >2D array
    # - dask_image.ndmeasure.label https://image.dask.org/en/latest/dask_image.ndmeasure.html
    # - dask_image.ndmorph.binary_dilation https://image.dask.org/en/latest/dask_image.ndmorph.html#dask_image.ndmorph.binary_dilation
    if labeled_comprehension_kwargs is None:
        labeled_comprehension_kwargs = {}
    # Check array validity (2D, non-empty) and convert to numpy
    arr = _check_array(arr)

    # Define masks
    # - mask_native: True where values lie between the min and max thresholds
    # - mask_nan: True where values are not finite (inf or nan)
    mask_native = np.logical_and(arr >= min_value_threshold, arr <= max_value_threshold)
    mask_nan = ~np.isfinite(arr)

    # Dilate (buffer) the native mask
    # - This enables nearby mask_native areas to be assigned to the same label
    mask = _mask_buffer(mask_native, footprint=footprint)

    # Label the connected regions of the (dilated) mask
    # - 0 represents the outer area
    # - dask_image.ndmeasure.label would allow labelling in n-dimensions
    label_arr = label_image(mask)

    # Exit early if no region was found.
    # Check the label values rather than counting the unique values: when the
    # (dilated) mask covers the whole array, label_arr contains a single
    # label 1 and no 0, which is a valid label and not an empty result.
    if not np.any(label_arr):
        return _no_labels_result(arr, return_labels_stats=return_labels_stats)

    # Restrict labels to the native (non-dilated) mask and to finite pixels
    label_arr[~mask_native] = 0
    label_arr[mask_nan] = 0

    # Filter labels by area (number of native, finite pixels)
    label_indices = _get_labels_with_requested_occurrence(
        label_arr=label_arr,
        vmin=min_area_threshold,
        vmax=max_area_threshold,
    )
    if len(label_indices) == 0:
        return _no_labels_result(arr, return_labels_stats=return_labels_stats)

    # Sort labels by statistics (i.e. label area, label max value ...)
    label_indices, values = _get_labels_stats(
        arr=arr,
        label_arr=label_arr,
        label_indices=label_indices,
        stats=sort_by,
        sort_decreasing=sort_decreasing,
        labeled_comprehension_kwargs=labeled_comprehension_kwargs,
    )

    # TODO: optionally here calculate a list of label_stats
    # --> values would be a n_label_stats x n_labels array !
    # --> dask_image.ndmeasure.labeled_comprehension

    # Relabel labels array (from 1 to n_labels, following the sorting order)
    labels_arr = redefine_label_array(label_arr, label_indices=label_indices)
    n_labels = len(label_indices)

    # Return results
    if return_labels_stats:
        return labels_arr, n_labels, values
    return labels_arr
[docs]
def label(
    xr_obj,
    *,
    variable=None,
    core_dims=None,
    min_value_threshold=-np.inf,
    max_value_threshold=np.inf,
    min_area_threshold=1,
    max_area_threshold=np.inf,
    footprint=None,
    sort_by="area",
    sort_decreasing=True,
    labeled_comprehension_kwargs=None,
    label_name="label",
):
    """
    Compute labels and add them as a coordinate to an xarray object.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        xarray object.
    variable : str, optional
        Dataset variable used to derive the labels array.
        Must be specified only if the input object is an `xarray.Dataset`.
    core_dims : tuple of str, optional
        Names of the two dimensions along which the labeling is applied.
        If the xarray DataArray is two-dimensional and ``core_dims`` is not provided,
        the core dimensions are inferred automatically from DataArray.dims.
        If the xarray DataArray has more than two dimensions, ``core_dims`` must be
        specified explicitly. In this case, labeling is applied independently
        over all remaining (non-core) dimensions.
        Example: for a 3D DataArray with dimensions ``(x, y, time)``,
        use ``core_dims=("x", "y")`` to apply labeling to each timestep.
    min_value_threshold : float, optional
        The minimum value to define the interior of a label.
        The default is ``-np.inf``.
    max_value_threshold : float, optional
        The maximum value to define the interior of a label.
        The default is ``np.inf``.
    min_area_threshold : float, optional
        The minimum number of connected pixels to be defined as a label.
        The default is 1.
    max_area_threshold : float, optional
        The maximum number of connected pixels to be defined as a label.
        The default is ``np.inf``.
    footprint : int, numpy.ndarray or None, optional
        This argument enables dilation of the mask derived after applying
        min_value_threshold and max_value_threshold.
        If ``footprint = 0`` or ``None``, no dilation occurs.
        If ``footprint`` is a positive integer, it creates a ``disk(footprint)``.
        If ``footprint`` is a 2D array, it must represent the neighborhood expressed
        as a 2-D array of 1's and 0's.
        The default is ``None`` (no dilation).
    sort_by : callable or str, optional
        A function or statistic defining the order of the labels.
        Valid string statistics are ``"area"``, ``"maximum"``, ``"minimum"``, ``"mean"``,
        ``"median"``, ``"sum"``, ``"standard_deviation"``, ``"variance"``.
        The default is ``"area"``.
    sort_decreasing : bool, optional
        If ``True``, sort labels by decreasing ``sort_by`` value.
        The default is ``True``.
    labeled_comprehension_kwargs : dict, optional
        Additional arguments passed to `dask_image.ndmeasure.labeled_comprehension`
        when ``sort_by`` is a callable. The default is ``None``.
    label_name : str, optional
        Name of the label coordinate added to ``xr_obj``.
        The default is ``"label"``.

    Returns
    -------
    xr_obj : xarray.DataArray or xarray.Dataset
        xarray object with the new label coordinate.
        In the label coordinate, non-label values are set to np.nan.

    Raises
    ------
    ValueError
        For instance if no label is identified with the given parameters.
    """
    # Retrieve the DataArray to label and validate the arguments
    data_array = get_data_array(xr_obj=xr_obj, variable=variable)
    _check_sort_by(sort_by)
    core_dims = check_core_dims(core_dims, data_array)

    # Arguments forwarded to _get_labels for every 2D slice
    get_labels_kwargs = {
        "min_value_threshold": min_value_threshold,
        "max_value_threshold": max_value_threshold,
        "min_area_threshold": min_area_threshold,
        "max_area_threshold": max_area_threshold,
        "footprint": footprint,
        "sort_by": sort_by,
        "sort_decreasing": sort_decreasing,
        "labeled_comprehension_kwargs": {} if labeled_comprehension_kwargs is None else labeled_comprehension_kwargs,
        "return_labels_stats": False,
    }

    # Label each 2D slice spanned by core_dims, looping over the other dimensions
    da_labels = xr.apply_ufunc(
        _get_labels,
        data_array,
        kwargs=get_labels_kwargs,
        input_core_dims=[list(core_dims)],
        output_core_dims=[list(core_dims)],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[float],
        dask_gufunc_kwargs={"output_sizes": {"parameters": 3}},
    )

    # NOTE(review): xarray.DataArray appears to always expose `.chunks` (None
    # for in-memory data), so this presumably always runs; compute() should be
    # a no-op for numpy-backed data.
    if hasattr(data_array, "chunks"):
        da_labels = da_labels.compute()
    if da_labels.max() == 0:
        raise ValueError("No labels identified. You might want to change the labeling parameters.")

    da_labels.name = f"labels_{sort_by}"
    da_labels.attrs = {}
    # Label 0 (no label) becomes NaN, which is convenient for plotting
    return xr_obj.assign_coords({label_name: da_labels.where(da_labels > 0)})
[docs]
def highlight_label(xr_obj, label_name, label_id):
    """Return a copy of ``xr_obj`` where all ``label_name`` values except ``label_id`` are set to 0.

    The input object is left untouched because a deep copy is modified.
    """
    highlighted = xr_obj.copy(deep=True)
    labels = highlighted[label_name].data
    is_other = labels != label_id
    labels[is_other] = 0
    highlighted[label_name].data = labels
    return highlighted