Source code for ximage.patch.labels_patch

# -----------------------------------------------------------------------------.
# MIT License

# Copyright (c) 2024-2026 ximage developers
#
# This file is part of ximage.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# -----------------------------------------------------------------------------.
"""Functions to extract patch around labels."""
import random
import warnings
from collections.abc import Callable

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from ximage.labels.labels import highlight_label
from ximage.patch.checks import (
    are_all_natural_numbers,
    check_buffer,
    check_kernel_size,
    check_padding,
    check_partitioning_method,
    check_patch_size,
    check_stride,
)
from ximage.patch.plot2d import plot_label_patch_extraction_areas
from ximage.patch.slices import (
    enlarge_slices,
    get_nd_partitions_list_slices,
    get_slice_around_index,
    get_slice_from_idx_bounds,
    pad_slices,
)

# -----------------------------------------------------------------------------.
#### TODOs
## Partitioning
# - Option to bound min_start and max_stop to labels bbox
# - Option to define min_start and max_stop to be divisible by patch_size + stride
# - When tiling ... define start so to center tiles around label_bbox, instead of starting at label bbox start
# - Option: partition_only_when_label_bbox_exceed_patch_size

# - Add option that returns a flag if the point center is the actual identified one,
#   or was close to the boundary !

# -----------------------------------------------------------------------------.
# - Implement dilate option (to subset pixel within partitions).
#   --> slice(start, stop, step=dilate) ... with patch_size redefined at start to patch_size*dilate
#   --> Need updates of enlarge slcies, pad_slices utilities (but first test current usage !)

# -----------------------------------------------------------------------------.

## Image sliding/tiling reconstruction
# - get_index_overlapping_slices
# - trim: bool, keyword only
#   Whether or not to trim stride elements from each block after calling the map function.
#   Set this to False if your mapping function already does this for you.
#   This for when merging !

####--------------------------------------------------------------------------.


def _check_label_arr(label_arr):
    """Check label_arr."""
    # Note: If label array is all zero or nan, labels_id will be []

    # Put label array in memory
    label_arr = np.asanyarray(label_arr)

    # Set 0 label to nan
    label_arr = label_arr.astype(float)  # otherwise if int throw an error when assigning nan
    label_arr[label_arr == 0] = np.nan

    # Check labels_id are natural number >= 1
    valid_labels = np.unique(label_arr[~np.isnan(label_arr)])
    if not are_all_natural_numbers(valid_labels):
        raise ValueError("The label array contains non positive natural numbers.")

    return label_arr


def _check_labels_id(labels_id, label_arr):
    """Check labels_id."""
    # Check labels_id type
    if not isinstance(labels_id, (type(None), int, list, np.ndarray)):
        raise TypeError("labels_id must be None or a list or a np.array.")
    if isinstance(labels_id, int):
        labels_id = [labels_id]
    # Get list of valid labels
    valid_labels = np.unique(label_arr[~np.isnan(label_arr)]).astype(int)
    # If labels_id is None, assign the valid_labels
    if isinstance(labels_id, type(None)):
        return valid_labels
    # If input labels_id is a list, make it a np.array
    labels_id = np.array(labels_id).astype(int)
    # Check labels_id are natural number >= 1
    if np.any(labels_id == 0):
        raise ValueError("labels id must not contain the 0 value.")
    if not are_all_natural_numbers(labels_id):
        raise ValueError("labels id must be positive natural numbers.")
    # Check labels_id are number present in the label_arr
    invalid_labels = labels_id[~np.isin(labels_id, valid_labels)]
    if invalid_labels.size != 0:
        invalid_labels = invalid_labels.astype(int)
        raise ValueError(f"The following labels id are not valid: {invalid_labels}")
    # If no labels, no patch to extract
    n_labels = len(labels_id)
    if n_labels == 0:
        raise ValueError("No labels available.")
    return labels_id


def _check_n_patches_per_partition(n_patches_per_partition, centered_on):
    """
    Check the number of patches to extract from each partition.

    It is used only if centered_on is a callable or 'random'

    Parameters
    ----------
    n_patches_per_partition : int
        Number of patches to extract from each partition.
    centered_on : str or callable
        Method to extract the patch around a label point.

    Returns
    -------
    n_patches_per_partition: int
       The number of patches to extract from each partition.
    """
    if n_patches_per_partition < 1:
        raise ValueError("n_patches_per_partitions must be a positive integer.")
    if isinstance(centered_on, str) and centered_on not in ["random"] and n_patches_per_partition > 1:
        raise ValueError(
            "Only the pre-implemented centered_on='random' method allow n_patches_per_partition values > 1.",
        )
    return n_patches_per_partition


def _check_n_patches(n_patches):
    if n_patches is None:
        n_patches = np.inf
    if n_patches <= 0:
        raise ValueError("n_patches must be a positive integer.")
    return n_patches


def _check_n_patches_per_label(n_patches_per_label, n_patches_per_partition):
    if n_patches_per_label is None:
        n_patches_per_label = np.inf
    if n_patches_per_label <= 0:
        raise ValueError("n_patches_per_label must be a positive integer.")
    if n_patches_per_label < n_patches_per_partition:
        raise ValueError("n_patches_per_label must be equal or larger to n_patches_per_partition.")
    return n_patches_per_label


def _check_callable_centered_on(centered_on):
    """Check validity of callable centered_on."""
    input_shape = (2, 3)
    arr = np.zeros(input_shape)
    point = centered_on(arr)
    if not isinstance(point, (tuple, type(None))):
        raise ValueError("The 'centered_on' function should return a point coordinates tuple or None.")
    if len(point) != len(input_shape):
        raise ValueError(
            "The 'centered_on' function should return point coordinates having same dimensions has input array.",
        )
    for c, max_value in zip(point, input_shape, strict=True):
        if c < 0:
            raise ValueError("The point coordinate must be a positive integer.")
        if c >= max_value:
            raise ValueError("The point coordinate must be inside the array shape.")
        if np.isnan(c):
            raise ValueError("The point coordinate must not be np.nan.")
    # Check case with nan array
    try:
        point = centered_on(arr * np.nan)
    except Exception as err:
        raise ValueError(f"The 'centered_on' function should be able to deal with a np.nan ndarray. Error is {err}.")
    if point is not None:
        raise ValueError("The 'centered_on' function should return None if the input array is a np.nan ndarray.")


def _check_centered_on(centered_on):
    """Check valid centered_on to identify a point in an array."""
    if not (callable(centered_on) or isinstance(centered_on, str)):
        raise TypeError("'centered_on' must be a string or a function.")
    if isinstance(centered_on, str):
        valid_centered_on = [
            "max",
            "min",
            "centroid",
            "center_of_mass",
            "random",
            "label_bbox",  # unfixed patch_size
        ]
        if centered_on not in valid_centered_on:
            raise ValueError(f"Valid 'centered_on' values are: {valid_centered_on}.")

    if callable(centered_on):
        _check_callable_centered_on(centered_on)
    return centered_on


def _get_variable_arr(xr_obj, variable, centered_on):
    """Get variable array (in memory)."""
    if isinstance(xr_obj, xr.DataArray):
        return np.asanyarray(xr_obj.data)
    if centered_on is not None and variable is None and (centered_on in ["max", "min"] or callable(centered_on)):
        raise ValueError("'variable' must be specified if 'centered_on' is specified.")
    return np.asanyarray(xr_obj[variable].data) if variable is not None else None


def _check_variable_arr(variable_arr, label_arr):
    """Check variable array validity."""
    if variable_arr is not None and variable_arr.shape != label_arr.shape:
        raise ValueError("Arrays corresponding to 'variable' and 'label_name' must have same shape.")
    return variable_arr


def _get_point_centroid(arr):
    """Get the coordinate of label bounding box center.

    It assumes that the array has been cropped around the label.
    It returns None if all values are non-finite (i.e. np.nan).
    """
    if np.all(~np.isfinite(arr)):
        return None
    centroid = np.array(arr.shape) / 2.0
    return tuple(centroid.tolist())


def _get_point_random(arr):
    """Get random point with finite value."""
    is_finite = np.isfinite(arr)
    if np.all(~is_finite):
        return None
    points = np.argwhere(is_finite)
    return random.choice(points)


def _get_point_with_max_value(arr):
    """Get point with maximum value."""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        point = np.argwhere(arr == np.nanmax(arr))
    return None if len(point) == 0 else tuple(point[0].tolist())


def _get_point_with_min_value(arr):
    """Get point with minimum value."""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        point = np.argwhere(arr == np.nanmin(arr))
    return None if len(point) == 0 else tuple(point[0].tolist())


def _get_point_center_of_mass(arr, integer_index=True):
    """Get the coordinate of the label center of mass.

    It uses all cells which have finite values.
    If `0` value should be a non-label area, mask before with `np.nan`.
    It returns `None` if all values are non-finite (i.e. ``np.nan``).
    """
    indices = np.argwhere(np.isfinite(arr))
    if len(indices) == 0:
        return None
    center_of_mass = np.nanmean(indices, axis=0)
    if integer_index:
        center_of_mass = center_of_mass.round().astype(int)
    return tuple(center_of_mass.tolist())


[docs] def find_point(arr, centered_on: str | Callable = "max"): """Find a specific point coordinate of the array. If the coordinate can't be find, return ``None``. """ centered_on = _check_centered_on(centered_on) if centered_on == "max": point = _get_point_with_max_value(arr) elif centered_on == "min": point = _get_point_with_min_value(arr) elif centered_on == "centroid": point = _get_point_centroid(arr) elif centered_on == "center_of_mass": point = _get_point_center_of_mass(arr) elif centered_on == "random": point = _get_point_random(arr) else: # callable centered_on point = centered_on(arr) if point is not None: point = tuple(int(p) for p in point) return point
def _get_labels_bbox_slices(arr): """ Compute the bounding box slices of non-zero elements in a n-dimensional numpy array. Assume that only one unique non-zero elements values is present in the array. Assume that NaN and Inf have been replaced by zeros. Other implementations: scipy.ndimage.find_objects Parameters ---------- arr : numpy.ndarray n-dimensional numpy array. Returns ------- list_slices : list List of slices to extract the region with non-zero elements in the input array. """ # Return None if all values are zeros if not np.any(arr): return None ndims = arr.ndim coords = np.nonzero(arr) return [get_slice_from_idx_bounds(np.min(coords[i]), np.max(coords[i])) for i in range(ndims)] def _get_patch_list_slices_around_label_point( label_arr, label_id, variable_arr, patch_size, centered_on, ): """Get list_slices to extract patch around a label point. Assume ``label_arr`` must match ``variable_arr`` shape. Assume ``patch_size`` shape must match ``variable_arr`` shape . """ # Subset variable_arr around label list_slices = _get_labels_bbox_slices(label_arr == label_id) if list_slices is None: return None label_subset_arr = label_arr[tuple(list_slices)] variable_subset_arr = variable_arr[tuple(list_slices)] variable_subset_arr = np.asarray(variable_subset_arr) # if dask, make numpy # Mask variable arr outside the label variable_subset_arr[label_subset_arr != label_id] = np.nan # Find point of subset array point_subset_arr = find_point(arr=variable_subset_arr, centered_on=centered_on) # Define patch list_slices if point_subset_arr is not None: # Find point in original array point = [slc.start + c for slc, c in zip(list_slices, point_subset_arr, strict=True)] # Find patch list slices patch_list_slices = [ get_slice_around_index(p, size=size, min_start=0, max_stop=shape) for p, size, shape in zip(point, patch_size, variable_arr.shape, strict=True) ] # TODO: also return a flag if the p midpoint is conserved (by +/- 1) or not else: patch_list_slices = None return patch_list_slices def _get_patch_list_slices_around_label(label_arr, label_id, padding, min_patch_size): """Get list_slices to extract patch around a label.""" # Get label bounding box slices list_slices = _get_labels_bbox_slices(label_arr == label_id) if list_slices is None: return None # Apply padding to the slices list_slices = pad_slices(list_slices, padding=padding, valid_shape=label_arr.shape) # Increase slices to match min_patch_size return enlarge_slices(list_slices, min_size=min_patch_size, valid_shape=label_arr.shape) def _get_patch_list_slices(label_arr, label_id, variable_arr, patch_size, centered_on, padding): """Get patch n-dimensional list slices.""" if not callable(centered_on) and centered_on == "label_bbox": list_slices = _get_patch_list_slices_around_label( label_arr=label_arr, label_id=label_id, padding=padding, min_patch_size=patch_size, ) else: list_slices = _get_patch_list_slices_around_label_point( label_arr=label_arr, label_id=label_id, variable_arr=variable_arr, patch_size=patch_size, centered_on=centered_on, ) return list_slices def _get_masked_arrays(label_arr, variable_arr, partition_list_slices): """Mask labels and variable arrays outside the partitions area.""" masked_partition_label_arr = np.zeros(label_arr.shape) * np.nan masked_partition_label_arr[tuple(partition_list_slices)] = label_arr[tuple(partition_list_slices)] if variable_arr is not None: masked_partition_variable_arr = np.zeros(variable_arr.shape) * np.nan masked_partition_variable_arr[tuple(partition_list_slices)] = variable_arr[tuple(partition_list_slices)] else: masked_partition_variable_arr = None return masked_partition_label_arr, masked_partition_variable_arr def _get_patches_from_partitions_list_slices( partitions_list_slices, label_arr, variable_arr, label_id, patch_size, centered_on, n_patches_per_partition, padding, verbose=False, ): """Return patches list slices from list of partitions `list_slices`. ``n_patches_per_partition`` is 1 unless ``centered_on`` is 'random' or a callable. """ patches_list_slices = [] for partition_list_slices in partitions_list_slices: if verbose: print(f" - partition: {partition_list_slices}") masked_label_arr, masked_variable_arr = _get_masked_arrays( label_arr=label_arr, variable_arr=variable_arr, partition_list_slices=partition_list_slices, ) n = 0 for n in range(n_patches_per_partition): patch_list_slices = _get_patch_list_slices( label_arr=masked_label_arr, variable_arr=masked_variable_arr, label_id=label_id, patch_size=patch_size, centered_on=centered_on, padding=padding, ) if patch_list_slices is not None and patch_list_slices not in patches_list_slices: n += 1 # noqa PLW2901 patches_list_slices.append(patch_list_slices) return patches_list_slices def _get_list_isel_dicts(patches_list_slices, dims): """Return a list with isel dictionaries.""" return [dict(zip(dims, patch_list_slices, strict=True)) for patch_list_slices in patches_list_slices] def _extract_xr_patch(xr_obj, isel_dict, label_name, label_id, highlight_label_id): """Extract a xarray patch.""" # Extract xarray patch around label xr_obj_patch = xr_obj.isel(isel_dict) # If asked, set label array to 0 except for label_id if highlight_label_id: xr_obj_patch = highlight_label(xr_obj_patch, label_name=label_name, label_id=label_id) return xr_obj_patch def _get_patches_isel_dict_generator( xr_obj, label_name, patch_size, variable=None, # Output options n_patches=None, n_labels=None, labels_id=None, grouped_by_labels_id=False, # (Tile) label patch extraction padding=0, centered_on="max", n_patches_per_label=None, n_patches_per_partition=1, debug=False, # Label Tiling/Sliding Options partitioning_method=None, n_partitions_per_label=None, kernel_size=None, buffer=0, stride=None, include_last=True, ensure_slice_size=True, verbose=False, ): # Get label array information label_arr = xr_obj[label_name].data dims = xr_obj[label_name].dims shape = label_arr.shape # Check input arguments if n_labels is not None and labels_id is not None: raise ValueError("Specify either n_labels or labels_id.") if kernel_size is None: kernel_size = patch_size patch_size = check_patch_size(patch_size, dims, shape) buffer = check_buffer(buffer, dims, shape) padding = check_padding(padding, dims, shape) partitioning_method = check_partitioning_method(partitioning_method) stride = check_stride(stride, dims, shape, partitioning_method) kernel_size = check_kernel_size(kernel_size, dims, shape) centered_on = _check_centered_on(centered_on) n_patches = _check_n_patches(n_patches) n_patches_per_partition = _check_n_patches_per_partition(n_patches_per_partition, centered_on) n_patches_per_label = _check_n_patches_per_label(n_patches_per_label, n_patches_per_partition) label_arr = _check_label_arr(label_arr) # output is np.array ! labels_id = _check_labels_id(labels_id=labels_id, label_arr=label_arr) variable_arr = _get_variable_arr(xr_obj, variable, centered_on) # if required variable_arr = _check_variable_arr(variable_arr, label_arr) # Define number of labels from which to extract patches available_n_labels = len(labels_id) n_labels = min(available_n_labels, n_labels) if n_labels else available_n_labels if verbose: print(f"Extracting patches from {n_labels} labels.") # -------------------------------------------------------------------------. # Extract patch(es) around the label patch_counter = 0 break_flag = False for i, label_id in enumerate(labels_id[0:n_labels]): if verbose: print(f"Label ID: {label_id} ({i}/{n_labels})") # Subset label_arr around the given label label_bbox_slices = _get_labels_bbox_slices(label_arr == label_id) # Apply padding to the label bounding box label_bbox_slices = pad_slices(label_bbox_slices, padding=padding.values(), valid_shape=label_arr.shape) # --------------------------------------------------------------------. # Retrieve partitions list_slices if partitioning_method is not None: partitions_list_slices = get_nd_partitions_list_slices( label_bbox_slices, arr_shape=label_arr.shape, method=partitioning_method, kernel_size=list(kernel_size.values()), stride=list(stride.values()), buffer=list(buffer.values()), include_last=include_last, ensure_slice_size=ensure_slice_size, ) if n_partitions_per_label is not None: n_to_select = min(len(partitions_list_slices), n_partitions_per_label) partitions_list_slices = partitions_list_slices[0:n_to_select] else: partitions_list_slices = [label_bbox_slices] # --------------------------------------------------------------------. # Retrieve patches list_slices from partitions list slices patches_list_slices = _get_patches_from_partitions_list_slices( partitions_list_slices=partitions_list_slices, label_arr=label_arr, variable_arr=variable_arr, label_id=label_id, patch_size=list(patch_size.values()), centered_on=centered_on, n_patches_per_partition=n_patches_per_partition, padding=list(padding.values()), verbose=verbose, ) # ---------------------------------------------------------------------. # Retrieve patches isel_dictionaries partitions_isel_dicts = _get_list_isel_dicts(partitions_list_slices, dims=dims) patches_isel_dicts = _get_list_isel_dicts(patches_list_slices, dims=dims) n_to_select = min(len(patches_isel_dicts), n_patches_per_label) patches_isel_dicts = patches_isel_dicts[0:n_to_select] # --------------------------------------------------------------------. # If debug=True, plot patches boundaries if debug and label_arr.ndim == 2: _ = plot_label_patch_extraction_areas( xr_obj, label_name=label_name, patches_isel_dicts=patches_isel_dicts, partitions_isel_dicts=partitions_isel_dicts, ) plt.show() # ---------------------------------------------------------------------. # Return isel_dicts if grouped_by_labels_id: patch_counter += 1 if patch_counter > n_patches: break_flag = True else: yield label_id, patches_isel_dicts else: for isel_dict in patches_isel_dicts: patch_counter += 1 if patch_counter > n_patches: break_flag = True else: yield label_id, isel_dict if break_flag: break # ---------------------------------------------------------------------.
[docs] def get_patches_isel_dict_from_labels( xr_obj, label_name, patch_size, variable=None, # Output options n_patches=None, n_labels=None, labels_id=None, # Label Patch Extraction Settings centered_on="max", padding=0, n_patches_per_label=None, n_patches_per_partition=1, # Label Tiling/Sliding Options partitioning_method=None, n_partitions_per_label=None, kernel_size=None, buffer=0, stride=None, include_last=True, ensure_slice_size=True, debug=False, verbose=False, ): """ Returnisel-dictionaries to extract patches around labels. The isel-dictionaries are grouped by ``label_id`` and returned in a dictionary. Please refer to ``ximage.patch.get_patches_from_labels`` for a detailed description of the function arguments. Return ------ dict A dictionary of the form: ``{label_id: list_isel_dicts}``. """ gen = _get_patches_isel_dict_generator( xr_obj=xr_obj, label_name=label_name, patch_size=patch_size, variable=variable, n_patches=n_patches, n_labels=n_labels, labels_id=labels_id, grouped_by_labels_id=True, # Patch extraction options centered_on=centered_on, padding=padding, n_patches_per_label=n_patches_per_label, n_patches_per_partition=n_patches_per_partition, # Tiling/Sliding settings partitioning_method=partitioning_method, n_partitions_per_label=n_partitions_per_label, kernel_size=kernel_size, buffer=buffer, stride=stride, include_last=include_last, ensure_slice_size=ensure_slice_size, debug=debug, verbose=verbose, ) return {int(label_id): list_isel_dicts for label_id, list_isel_dicts in gen}
[docs] def get_patches_from_labels( xr_obj, label_name, patch_size, variable=None, # Output options n_patches=None, n_labels=None, labels_id=None, highlight_label_id=True, # Label Patch Extraction Options centered_on="max", padding=0, n_patches_per_label=None, n_patches_per_partition=1, # Label Tiling/Sliding Options partitioning_method=None, n_partitions_per_label=None, kernel_size=None, buffer=0, stride=None, include_last=True, ensure_slice_size=True, debug=False, verbose=False, ): """ Routines to extract patches around labels. Create a generator extracting (from a prelabeled xarray.Dataset) a patch around: - a label point - a label bounding box If ``centered_on`` is specified, output patches are guaranteed to have equal shape ! If ``centered_on`` is not specified, output patches are guaranteed to have only have a minimum shape ! If you want to extract the patch around the label bounding box, ``centered_on`` must not be specified. If you want to extract the patch around a label point, the ``centered_on`` method must be specified. If the identified point is close to an array boundary, the patch is expanded toward the valid directions. Tiling or sliding enables to split/slide over each label and extract multiple patch for each tile. ``tiling=True`` - ``centered_on = "centroid"`` (tiling around labels bbox) - ``centered_on = "center_of_mass"`` (better coverage around label) ``sliding=True`` - ``centered_on = "center_of_mass"`` (better coverage around label) (further data coverage) Only one parameter between ``n_patches`` and ``labels_id`` can be specified. Parameters ---------- xr_obj : xarray.Dataset xarray.Dataset with a label array named ``label_name``. label_name : str Name of the variable/coordinate representing the label array. patch_size : int or tuple The dimensions of the n-dimensional patch to extract. Only positive values (>1) are allowed. The value -1 can be used to specify the full array dimension shape. If the ``centered_on`` method is not ``'label_bbox'``, all output patches are ensured to have the same shape. Otherwise, if ``centered_on='label_bbox'``, the ``patch_size`` argument defines defined the minimum n-dimensional shape of the output patches. If ``int``, the value is applied to all label array dimensions. If ``list`` or ``tuple``, the length must match the number of dimensions of the array. If a ``dict``, the dictionary must have has keys the label array dimensions. n_patches : int, optional Maximum number of patches to extract. The default (``None``) enable to extract all available patches allowed by the specified patch extraction criteria. labels_id : list, optional List of labels for which to extract the patch. If ``None``, it extracts the patches by label order ``(1, 2, 3, ...)`` The default is ``None``. n_labels : int, optional The number of labels for which extract patches. If ``None`` (the default), it extract patches for all labels. This argument can be specified only if ``labels_id`` is unspecified ! highlight_label_id : bool, optional If ``True``, the ``label_name`` array of each patch is modified to contain only the ``label_id`` used to select the patch. variable : str, optional Dataset variable to use to identify the patch center when centered_on is defined. This is required only for ``centered_on='max'``, ``centered_on='min'`` or the custom function. centered_on : str or callable, optional The centered_on method characterize the point around which the patch is extracted. Valid pre-implemented centered_on methods are ``'label_bbox'``, ``'max'``, ``'min'``, ``'centroid'``, ``'center_of_mass'``, ``'random'``. The default method is ``'max'``. If ``label_bbox`` it extract the patches around the (padded) bounding box of the label. If ``label_bbox``, the output patch sizes are only ensured to have a minimum ``patch_size``, and will likely be of different size. Otherwise, the other methods guarantee that the output patches have a common shape. If ``centered_on`` is ``'max'``, ``'min'`` or a custom function, the ``variable`` argument must be specified. If ``centered_on`` is a custom function, it must: - return ``None`` if all array values are non-finite (i.e ``np.nan``) - return a tuple with same length as the array shape. padding : int, tuple or dict, optional The padding to apply in each direction around a label prior to partitioning (tiling/sliding) or direct patch extraction. The default, 0, applies 0 padding in every dimension. Negative padding values are allowed ! If ``int``, the value is applied to all label array dimensions. If ``list`` or ``tuple``, the length must match the number of dimensions of the array. If a ``dict``, the dictionary must have has keys the label array dimensions. n_patches_per_label: int, optional The maximum number of patches to extract for each label. The default (``None``) enables to extract all the available patches per label. If specified, ``n_patches_per_label`` must be larger than ``n_patches_per_partition`` ! n_patches_per_partition, int, optional The maximum number of patches to extract from each label partition. The default values is 1. This method can be specified only if ``centered_on='random'`` or a callable. partitioning_method : str Whether to retrieve ``'tiling'`` or ``'sliding'`` slices. If ``'tiling'``, partition start slices are separated by ``stride`` + ``kernel_size``. If ``'sliding'``, partition start slices are separated by stride. n_partitions_per_label : int, optional The maximum number of partitions to extract for each label. The default (``None``) enables to extract all the available partitions per label. stride : int, tuple or dict, optional If ``partitioning_method = 'sliding'``, default ``stride`` is set to 1. If ``partitioning_method = 'tiling'``, default ``stride`` is set to 0. Step size between slices. When ``partitioning_method='tiling'``, a positive stride make partition slices to not overlap and not touch, while a negative stride make partition slices to overlap by ``stride`` amount. If ``stride=0``, the partition slices are contiguous (no spacing between partitions). When ``partitioning_method='sliding'``, only a positive stride (>= 1) is allowed. If ``int``, the value is applied to all label array dimensions. If ``list`` or ``tuple``, the length must match the number of dimensions of the array. If a ``dict``, the dictionary must have has keys the label array dimensions. kernel_size: int, tuple or dict, optional The shape of the desired partitions. Only positive values (>1) are allowed. The value ``-1`` can be used to specify the full array dimension shape. If ``int``, the value is applied to all label array dimensions. If ``list`` or ``tuple``, the length must match the number of dimensions of the array. If a ``dict``, the dictionary must have has keys the label array dimensions. buffer: int, tuple or dict, optional The default is ``0``. Value by which to enlarge a partition on each side. The final partition size should be ``kernel_size`` + ``buffer``. If ``partitioning_method='tiling'`` and ``stride=0``, a positive buffer value corresponds to the amount of overlap between each partition. Depending on ``min_start`` and ``max_stop`` values, buffering might cause border partitions to not have same sizes. If ``int``, the value is applied to all label array dimensions. If ``list`` or ``tuple``, the length must match the number of dimensions of the array. If a ``dict``, the dictionary must have has keys the label array dimensions. include_last : bool, optional Whether to include the last partition if it does not match the ``kernel_size``. The default is ``True``. ensure_slice_size : bool, optional Used only if include_last is ``True``. If ``False``, the last partition will not have the specified ``kernel_size``. If ``True``, the last partition is enlarged to the specified ``kernel_size`` by tentatively expanding it on both sides (accounting for ``min_start`` and ``max_stop``). Yields ------ (xarray.Dataset or xarray.DataArray) A xarray object patch. """ # Define patches isel dictionary generator patches_isel_dicts_gen = _get_patches_isel_dict_generator( xr_obj=xr_obj, label_name=label_name, patch_size=patch_size, variable=variable, n_patches=n_patches, n_labels=n_labels, labels_id=labels_id, grouped_by_labels_id=False, # Label Patch Extraction Options centered_on=centered_on, padding=padding, n_patches_per_label=n_patches_per_label, n_patches_per_partition=n_patches_per_partition, # Tiling/Sliding Options partitioning_method=partitioning_method, n_partitions_per_label=n_partitions_per_label, kernel_size=kernel_size, buffer=buffer, stride=stride, include_last=include_last, ensure_slice_size=ensure_slice_size, debug=debug, verbose=verbose, ) # Extract the patches for label_id, isel_dict in patches_isel_dicts_gen: xr_obj_patch = _extract_xr_patch( xr_obj=xr_obj, label_name=label_name, isel_dict=isel_dict, label_id=label_id, highlight_label_id=highlight_label_id, ) # Return the patch around the label yield label_id, xr_obj_patch