# -----------------------------------------------------------------------
# Copyright 2026 Martin Wieser
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------
"""Batch operations for mapping and projecting across multiple images."""
import logging
from typing import overload
import numpy as np
from pyproj import CRS
from weitsicht.exceptions import (
CoordinateTransformationError,
CRSInputError,
MapperMissingError,
MappingError,
NotGeoreferencedError,
)
from weitsicht.image.base_class import ImageBase
from weitsicht.mapping.base_class import MappingBase
from weitsicht.transform.coordinates_transformer import CoordinateTransformer
from weitsicht.utils import ArrayNx3, Issue, MappingResult, ProjectionResult, ResultFailure, to_array_nx3
__all__ = ["ImageBatch"]
logger = logging.getLogger(__name__)
[docs]
class ImageBatch:
"""Container for running operations on a set of named images."""
[docs]
def __init__(self, images: dict[str, ImageBase], mapper: MappingBase | None = None):
"""Initialize an image batch.
:param images: Dictionary of images.
:type images: dict[str, ImageBase]
:param mapper: Optional default mapper assigned to images without a mapper, defaults to ``None``.
:type mapper: MappingBase | None
"""
self.images: dict[str, ImageBase] = images
self.mapper: MappingBase | None = mapper
if self.mapper is not None:
for name, img in self.images.items():
if img.mapper is None:
self.images[name].mapper = self.mapper
def __len__(self):
"""Return the number of images in the batch."""
return len(self.images)
[docs]
def add_images(self, images: dict[str, ImageBase]):
"""Add images to the batch.
:param images: Images to add.
:type images: dict[str, ImageBase]
:raises KeyError: If any keys already exist in the batch.
"""
equal_keys = list(set(self.images.keys()).intersection(images.keys()))
if len(equal_keys) > 0:
raise KeyError("Identical keys can not be added")
if self.mapper is not None:
for name, img in images.items():
if img.mapper is None:
images[name].mapper = self.mapper
self.images.update(images)
def __setitem__(self, key, image: ImageBase):
"""Set an image in the batch by key.
:param key: Image key.
:param image: Image instance.
:type image: ImageBase
:raises ValueError: If the object is not an image instance.
"""
if image.__class__.__base__ != ImageBase:
raise ValueError("Only a image class can be assigned")
self.images[key] = image
@overload
def __getitem__(self, keys: str) -> ImageBase: ...
@overload
def __getitem__(self, keys: tuple[str] | list[str]) -> dict[str, ImageBase]: ...
def __getitem__(self, keys: str | tuple[str] | list[str]) -> ImageBase | dict[str, ImageBase]:
"""Get images by key or keys.
:param keys: Image key or sequence of keys.
:type keys: str | tuple[str] | list[str]
:return: Single image or a dictionary of images.
:rtype: ImageBase | dict[str, ImageBase]
:raises KeyError: If a requested key is not found.
:raises TypeError: If ``keys`` has an unsupported type.
"""
if isinstance(keys, str):
if keys in self.images.keys():
return self.images[keys]
else:
raise KeyError("Key not found in images")
elif isinstance(keys, tuple) or isinstance(keys, list):
# raise KeyError("Key not found in images")
# convert a simple index and tuples to lists
# if isinstance(keys, tuple)
# keys = list(keys) if type(keys) in [list, tuple] else list([keys])
dict_return = {}
for key in keys:
if key in self.images.keys():
dict_return[key] = self.images[key]
else:
raise KeyError("Keys not found in images")
if len(dict_return) == 1:
return list(dict_return.values())[0]
return dict_return
raise TypeError
@overload
def index(self, indices: int) -> ImageBase: ...
@overload
def index(self, indices: list[int] | tuple[int]) -> dict[str, ImageBase]: ...
[docs]
def index(self, indices: list[int] | tuple[int] | int) -> ImageBase | dict[str, ImageBase]:
"""index get images by indices
Get the images by using indices. Will return the image at the indices of the image dict from ImageBatch.
:param indices: The indices you want to get images for
:type indices: list[int] | tuple[int] | int
:raises TypeError: Indices type not valid
:raises IndexError: One of the indices is not valid
:return: return single image or dict of images for multiple indices
:rtype: ImageBase | dict[str, ImageBase]
"""
if isinstance(indices, int):
_indices = list([indices])
elif isinstance(indices, tuple) or isinstance(indices, list):
_indices = list(indices)
else:
raise TypeError
keys = list(self.images.keys())
try:
dict_return = {keys[x]: self.images[keys[x]] for x in _indices}
except IndexError as err:
raise IndexError("Index not found in images") from err
if len(dict_return) == 1:
return list(dict_return.values())[0]
return dict_return
[docs]
def map_center_point(
self, mapper: MappingBase | None = None, transformer: CoordinateTransformer | None = None
) -> dict[str, MappingResult]:
"""Map center point of image to 3d
:param mapper: Specify Mapper to be used, can be different to the one assigned in the class
:param transformer: A CoordinateTransfoer object which can be passed instead using crs_s
:return: MappingResult for all images
:raises ValueError: Image Batch is empty
:raises NotGeoreferencedError: If an image is not geo-referenced
:raises MapperMissingError: If no mapper is available
:raises CRSInputError: If CRS/transformer input is invalid
:raises CoordinateTransformationError: If coordinate transformation fails
:raises MappingError: If mapping fails
"""
center_p_dict = {}
if self.__len__() == 0:
raise ValueError("ImageBatch is empty")
# one_mapping_worked = False
all_mapping_worked = True
for key, image in self.images.items():
try:
result = image.map_center_point(mapper=self.mapper if mapper is None else mapper, transformer=transformer)
except NotGeoreferencedError as err_georef:
result = ResultFailure(ok=False, error=str(err_georef), issues={Issue.IMAGE_BATCH_ERROR})
except MapperMissingError as err_mapper:
result = ResultFailure(ok=False, error=str(err_mapper), issues={Issue.IMAGE_BATCH_ERROR})
except CRSInputError as err_crs:
result = ResultFailure(ok=False, error=str(err_crs), issues={Issue.IMAGE_BATCH_ERROR})
except CoordinateTransformationError as err_trafo:
result = ResultFailure(ok=False, error=str(err_trafo), issues={Issue.IMAGE_BATCH_ERROR})
except MappingError as err:
result = ResultFailure(ok=False, error=str(err), issues={Issue.IMAGE_BATCH_ERROR})
center_p_dict[key] = result
if result.ok is False:
all_mapping_worked = False
# else:
# one_mapping_worked = True
# if not one_mapping_worked:
# raise MappingError("For none of the images the center point could be mapped")
if not all_mapping_worked:
logger.warning("Some of the center points could not me mapped")
return center_p_dict
[docs]
def project(
self,
coordinates: ArrayNx3,
crs_s: CRS | None = None,
transformer: CoordinateTransformer | None = None,
only_valid: bool = False,
) -> dict[str, ProjectionResult] | None:
"""Calculate projection of 3d points into image. If points was outsize image size and to_distortion is true
the undistorted projection will be returned
:param coordinates: Array of (nx3) with the point positions
:type coordinates: ArrayNx3
:param crs_s: The coordinate system of the input coordinates
:param transformer: A CoordinateTransfoer object which can be passed instead using crs_s
:param only_valid: If true only return images with valid projections.
:return: dict[ImageName, ProjectionResult] or None if for filtered (only_valid) no projections are valid
:raises ValueError: If the image batch is empty.
:raises CRSInputError: If both ``crs_s`` and ``transformer`` are provided.
:raises NotGeoreferencedError: If an image is not geo-referenced.
:raises MapperMissingError: If a required mapper is missing.
:raises CoordinateTransformationError: If coordinate transformation fails.
"""
if self.__len__() == 0:
raise ValueError("ImageBatch is empty")
_coordinates = to_array_nx3(coordinates)
if crs_s is not None and transformer is not None:
raise CRSInputError("Either crs or transformation can be used or both None")
projected_dict: dict[str, ProjectionResult] = {}
for key, image in self.images.items():
try:
result = image.project(_coordinates, crs_s, transformer=transformer)
except NotGeoreferencedError as err_georef:
result = ResultFailure(ok=False, error=str(err_georef), issues={Issue.IMAGE_BATCH_ERROR})
except MapperMissingError as err_mapper:
result = ResultFailure(ok=False, error=str(err_mapper), issues={Issue.IMAGE_BATCH_ERROR})
except CRSInputError as err_crs:
result = ResultFailure(ok=False, error=str(err_crs), issues={Issue.IMAGE_BATCH_ERROR})
except CoordinateTransformationError as err_trafo:
result = ResultFailure(ok=False, error=str(err_trafo), issues={Issue.IMAGE_BATCH_ERROR})
projected_dict[key] = result
if only_valid:
filtered_dict: dict[str, ProjectionResult] = {}
for key, value in projected_dict.items():
if value.ok is True:
if np.any(value.mask):
filtered_dict[key] = value
if len(filtered_dict) == 0:
return None
return filtered_dict
return projected_dict