class Viewer:
    """Interactive matplotlib viewer for N-dimensional (N >= 2) arrays.

    One 2D plane of the image is shown, spanned by a horizontal and a vertical axis.
    Every other axis is held at an index stored in ``self.cursor``:

    * a left click sets the indices of the displayed axes (the "pivot");
    * the scroll wheel steps along the scroll axis (3D+ images);
    * ``<digit>``, ``ctrl+<digit>`` and ``alt+<digit>`` reassign the scroll, horizontal
      and vertical axes (digits are 1-based axis numbers).

    Call ``show()`` to open the window.
    """
def __init__(
    self,
    nd_image: np.ndarray,
    axlims: dict[str, tuple[float, float]],
    x_axis: "int | str | None" = None,
    y_axis: "int | str | None" = None,
    s_axis: "int | str | None" = None,
):
    """Validate the input and set up the initial view of ``nd_image``.

    Args:
        nd_image: Image with at least two dimensions.
        axlims: Mapping from axis name to the ``(min, max)`` coordinates of that axis,
            one entry per dimension of ``nd_image`` and in the same order. Axis ``i``
            gets the coordinates ``np.linspace(min, max, nd_image.shape[i])``.
        x_axis: Axis drawn horizontally, given as an index (negative values count from
            the end) or as a key of ``axlims``.
        y_axis: Axis drawn vertically; same conventions as ``x_axis``.
        s_axis: Axis stepped through with the scroll wheel; only allowed for 3D+ images.
            When no axis is given at all, the defaults are 0 and 1 (and 2 for 3D+).

    Raises:
        ValueError: if the image has fewer than two dimensions or an empty axis, if
            ``axlims`` does not match the image, or if an axis is missing, unknown,
            out of range or repeated.
        TypeError: if an axis is neither an integer nor a string.
    """
    self.im: np.ndarray = nd_image
    self.shape = nd_image.shape
    ndim = len(self.shape)
    if ndim < 2:
        raise ValueError("Cannot display plot for image with less than 2 dimensions.")
    if ndim != len(axlims):
        raise ValueError("axlims must correspond with nd_image axes.")
    if not all(len(ax) == 2 for ax in axlims.values()):
        raise ValueError(
            "axlims must only have 2 values per axis, the min and max of the axis."
        )
    if 0 in self.shape:
        raise ValueError("Cannot display plot for image with an empty axis.")

    def step_of(n, vmin, vmax):
        # Spacing between neighbouring samples. A length-1 axis has no spacing, so its
        # single pixel gets the whole requested range (or 1.0 if that range is empty)
        # instead of raising ZeroDivisionError.
        if n > 1:
            return (vmax - vmin) / (n - 1)
        return (vmax - vmin) or 1.0

    self.axnames = list(axlims.keys())
    self.axvalues = [
        np.linspace(vmin, vmax, n) for n, (vmin, vmax) in zip(self.shape, axlims.values())
    ]
    self.axsteps = [
        step_of(n, vmin, vmax) for n, (vmin, vmax) in zip(self.shape, axlims.values())
    ]

    def resolve(axis):
        # Normalise an axis given by name or (possibly negative) index to an index in
        # [0, ndim); negative indices would otherwise corrupt the transpose permutation
        # computed from the axis numbers when slicing.
        if axis is None:
            return None
        if isinstance(axis, str):
            if axis not in self.axnames:
                raise ValueError(f"Unknown axis name {axis!r}; expected one of {self.axnames}.")
            return self.axnames.index(axis)
        if not isinstance(axis, (int, np.integer)):
            raise TypeError(f"Axis must be an int index or an axis name, got {axis!r}.")
        if not -ndim <= axis < ndim:
            raise ValueError(f"Axis {axis} is out of range for a {ndim}D image.")
        return int(axis) % ndim

    # bad input handling, treat 2D and 3D+ cases separately
    if ndim == 2:
        if s_axis is not None:
            raise ValueError("No scroll axis can be set for a 2D image.")
        if x_axis is None and y_axis is None:
            x_axis, y_axis = 0, 1
        elif x_axis is None or y_axis is None:
            raise ValueError("x_axis and y_axis must both be given.")
    else:
        missing = [ax is None for ax in (x_axis, y_axis, s_axis)]
        if all(missing):
            x_axis, y_axis, s_axis = 0, 1, 2
        elif any(missing):
            raise ValueError("All of x_axis, y_axis, and s_axis must be given.")
    x_axis, y_axis, s_axis = (resolve(ax) for ax in (x_axis, y_axis, s_axis))
    # Checked after resolving so a name and an index for the same axis are caught too.
    if x_axis == y_axis or (s_axis is not None and s_axis in (x_axis, y_axis)):
        raise ValueError("Supplied axes must be distinct.")
    self.ihax = x_axis
    self.ivax = y_axis
    self.isax = s_axis  # None for 2D images: there is nothing to scroll through
    # One entry per image axis: an int fixes that axis at an index, while FULL_AXIS
    # (defined elsewhere; presumably slice(None)) marks the displayed axes.
    self.cursor = [0] * ndim
    self.cursor[self.ihax] = FULL_AXIS
    self.cursor[self.ivax] = FULL_AXIS
    self.scroll_plane = self.slice_for_scroll()
# === SLICING METHODS ===
def slice_for_scroll(self):
    """Return the sub-array spanned by the displayed axes and the scroll axis.

    Every other axis is fixed at its cursor index. The result's axes are ordered
    (horizontal, vertical[, scroll]) whatever their order in ``self.im``.

    Raises:
        ValueError: if the cursor keeps more axes whole than are being selected here,
            i.e. a non-displayed axis has no pivot index yet.
    """
    cursor = list(self.cursor)
    kept = [self.ihax, self.ivax]
    # Compare with None: axis 0 is a valid scroll axis but is falsy.
    if self.isax is not None:
        kept.append(self.isax)
    for axis in kept:
        cursor[axis] = FULL_AXIS
    if sum(i == FULL_AXIS for i in cursor) > len(kept):
        raise ValueError("pivot not set!")
    # Basic indexing keeps the surviving axes in ascending order; the ranks of ``kept``
    # (argsort of argsort, the inverse permutation) reorder them as listed in ``kept``.
    axes = np.argsort(np.argsort(kept))
    return self.im[tuple(cursor)].transpose(axes)

def slice_for_plot(self):
    """Return the 2D (horizontal, vertical) plane that is currently displayed.

    For 2D images this is the scroll plane itself; otherwise the scroll plane is taken
    at the cursor index along its last (scroll) axis.
    """
    # Compare with None: axis 0 is a valid scroll axis but is falsy.
    if self.isax is None:
        return self.scroll_plane
    return self.scroll_plane[..., self.cursor[self.isax]]
# === EVENT METHODS ===
def update_plot(self):
    """Redraw the displayed plane, the title and the axis labels from the viewer state.

    If the displayed axes changed since the last draw (e.g. ``on_keypress`` swapped
    them), the image extent and aspect are recomputed too. Both were derived from the
    previous axes' coordinates and would otherwise mislabel the new ones. This replaces
    any ``extent``/``aspect`` that was passed to ``show`` for the old axes.
    """
    plot_plane = self.slice_for_plot()
    xname = self.axnames[self.ihax]
    yname = self.axnames[self.ivax]
    # The labels hold the names of the axes drawn last (they are set here and when the
    # figure is first plotted), so comparing them detects an axis swap.
    axes_changed = (self.ax.get_xlabel(), self.ax.get_ylabel()) != (str(xname), str(yname))
    self.viewer_image.set_data(plot_plane.T)
    if axes_changed:
        xvalues = self.axvalues[self.ihax]
        yvalues = self.axvalues[self.ivax]
        xmin, xmax = xvalues[0], xvalues[-1]
        ymin, ymax = yvalues[0], yvalues[-1]
        xstep = self.axsteps[self.ihax]
        ystep = self.axsteps[self.ivax]
        # Same convention as the initial plot: each pixel is centred on its coordinate.
        extent = (xmin - xstep / 2, xmax + xstep / 2, ymin - ystep / 2, ymax + ystep / 2)
        xspan, yspan = abs(xmax - xmin), abs(ymax - ymin)
        if not (xspan and yspan):
            # A length-1 or zero-range axis has no span; use the padded extent instead.
            xspan, yspan = abs(extent[1] - extent[0]), abs(extent[3] - extent[2])
        self.viewer_image.set_extent(extent)
        self.ax.set_aspect(xspan / yspan if xspan and yspan else 1.0)
    # Title lists the cursor coordinate of every axis that is not displayed.
    info = [
        f"{ax}={values[vi]}"
        for i, (ax, values, vi) in enumerate(zip(self.axnames, self.axvalues, self.cursor))
        if i not in [self.ihax, self.ivax]
    ]
    self.ax.set_title(", ".join(info))
    self.ax.set_xlabel(xname)
    self.ax.set_ylabel(yname)
    # update figure
    self.fig.canvas.draw()
    self.fig.canvas.flush_events()
def on_left_click(self, event):
    """Set the pivot (cursor index) of both displayed axes from a click position.

    Clicks outside the axes, or outside the image extent (reachable after panning or
    zooming), are ignored so the cursor never points past the end of an axis. Only
    the cursor is updated; nothing is redrawn.
    """

    def to_index(values, step, coord):
        # Index of the pixel containing ``coord``; pixel k spans values[k] +/- step / 2.
        # Dividing by the signed step also handles descending coordinates.
        # Returns None when ``coord`` lies outside the image.
        if step == 0:
            return 0  # degenerate axis: every sample shares one coordinate
        frac = (coord - values[0]) / step
        if not -0.5 <= frac <= len(values) - 0.5:
            return None
        return min(int(np.floor(frac + 0.5)), len(values) - 1)

    if event.xdata is None or event.ydata is None:
        return
    xi = to_index(self.axvalues[self.ihax], self.axsteps[self.ihax], event.xdata)
    yi = to_index(self.axvalues[self.ivax], self.axsteps[self.ivax], event.ydata)
    if xi is None or yi is None:
        return
    self.cursor[self.ihax] = xi
    self.cursor[self.ivax] = yi
def on_scroll(self, event):
    """Move the scroll-axis cursor by the wheel step, clamped to the axis, and redraw.

    Does nothing for 2D images, which have no scroll axis.
    """
    # Compare with None: axis 0 is a valid scroll axis but is falsy.
    if self.isax is None:
        return
    last = self.shape[self.isax] - 1
    target = self.cursor[self.isax] + int(event.step)
    self.cursor[self.isax] = max(0, min(last, target))
    self.update_plot()
def on_keypress(self, event):
    """Reassign the displayed or scrolled axes in response to a key press.

    Key bindings (digits are 1-based axis numbers):

    * ``<d>``: make axis ``d`` the scroll axis (3D+ images; ignored if ``d`` is
      currently displayed).
    * ``ctrl+<d>``: show axis ``d`` horizontally. If ``d`` was the vertical or scroll
      axis, the previous horizontal axis takes over that role.
    * ``alt+<d>``: the same for the vertical axis.

    Digits that do not name an existing axis are ignored, as are other modifiers.
    Otherwise the key and any resulting changes are printed and the plot is redrawn.

    Raises:
        ValueError: for ``ctrl``/``alt`` combinations while some cursor entry is still
            FULL_AXIS (no pivot clicked yet), since the axis leaving the display needs
            a fixed index.
    """
    key = event.key
    # matplotlib may report keys it cannot name as None.
    if key is None:
        return
    ndim = len(self.shape)
    log = [f"keypress {key}"]
    if key.isdecimal() and self.isax is not None:
        axis_index = int(key) - 1
        # "0" would map to -1 and indices >= ndim do not exist; both are ignored.
        if not 0 <= axis_index < ndim or axis_index in (self.ihax, self.ivax):
            return
        self.isax = axis_index
        self.scroll_plane = self.slice_for_scroll()
        log.append(f"scroll axis set to {self.axnames[self.isax]}")
    if key.count("+") == 1:
        modifier, digit = key.split("+")
        if not digit.isdecimal() or modifier not in ("ctrl", "alt"):
            return
        if any(i == FULL_AXIS for i in self.cursor):
            raise ValueError("pivot not set, click on the plotting region to set a pivot")
        axis_index = int(digit) - 1
        if not 0 <= axis_index < ndim:
            return
        if modifier == "ctrl":
            # The chosen axis leaves its old role; the old horizontal axis fills it.
            if self.isax == axis_index:
                self.isax = self.ihax  # swap axes
                log.append(f"scroll axis set to {self.axnames[self.isax]}")
            if self.ivax == axis_index:
                self.ivax = self.ihax  # swap axes
                log.append(f"y axis set to {self.axnames[self.ivax]}")
            self.ihax = axis_index
            self.scroll_plane = self.slice_for_scroll()
            log.append(f"x axis set to {self.axnames[self.ihax]}")
        else:  # "alt"
            # The chosen axis leaves its old role; the old vertical axis fills it.
            if self.isax == axis_index:
                self.isax = self.ivax  # swap axes
                log.append(f"scroll axis set to {self.axnames[self.isax]}")
            if self.ihax == axis_index:
                self.ihax = self.ivax  # swap axes
                log.append(f"x axis set to {self.axnames[self.ihax]}")
            self.ivax = axis_index
            self.scroll_plane = self.slice_for_scroll()
            log.append(f"y axis set to {self.axnames[self.ivax]}")
    print(", ".join(log))
    self.update_plot()
# === MPL CONFIGURATION METHODS ===
def _clear_garbage(self):
    """Remove every artist queued in ``self.garbage`` from the figure and empty the queue.

    The queue must be emptied after each pass. Otherwise the same artists would be
    removed again on the next event (matplotlib raises ValueError for that), and the
    set would keep them alive indefinitely.
    """
    for artist in self.garbage:
        try:
            artist.remove()
        except ValueError:
            # Already detached from its axes (e.g. removed elsewhere); nothing to do.
            pass
    self.garbage.clear()

def _mpl_button_press_event(self, event):
    """Matplotlib ``button_press_event`` callback.

    Clears queued artists, then dispatches left clicks (button 1) to ``on_left_click``
    and right clicks (button 3) to ``on_right_click``.
    """
    self._clear_garbage()
    if event.button == 1:
        self.on_left_click(event)
    if event.button == 3:
        self.on_right_click(event)

def _mpl_scroll_event(self, event):
    """Matplotlib ``scroll_event`` callback: clear queued artists, then ``on_scroll``."""
    self._clear_garbage()
    self.on_scroll(event)
def _mpl_key_press_event(self, event):
    """Matplotlib ``key_press_event`` callback; hands the event to ``on_keypress``."""
    handler = self.on_keypress
    handler(event)
def _mpl_add_connections(self):
    """Register the viewer's mouse and keyboard callbacks on the figure canvas.

    The ids returned by ``mpl_connect`` are stored in ``self.connections`` in
    registration order (presumably so the callbacks can be disconnected later).
    """
    self.connections = []
    canvas = self.fig.canvas
    bindings = (
        ("button_press_event", self._mpl_button_press_event),
        ("scroll_event", self._mpl_scroll_event),
        ("key_press_event", self._mpl_key_press_event),
    )
    for signal, handler in bindings:
        self.connections.append(canvas.mpl_connect(signal, handler))
def show(self, **imshow_kwargs: Any):
    """
    Show the interactive plot.

    Draws the initial plane, connects the mouse/keyboard callbacks and calls
    ``plt.show()``.

    Args:
        **imshow_kwargs: Additional keyword arguments for plt.imshow; they override the
            defaults chosen in ``plot_axes``.
    """
    self.plot_axes(**imshow_kwargs)
    # Artists queued here are removed from the figure on the next click or scroll.
    self.garbage: set = set()
    # connect methods to figure
    self._mpl_add_connections()
    plt.show()
# === INITIAL PLOTTING METHOD ===
def plot_axes(self, **imshow_kwargs: Any):
    """
    Create the figure and draw the initial plane.

    Args:
        **imshow_kwargs: Additional keyword arguments for imshow. They override the
            defaults used here: ``origin="lower"``, ``cmap="inferno"``, colour limits
            0..1, an extent that centres each pixel on its axis coordinate, and an
            aspect ratio that makes the image roughly square.
    """
    self.fig = plt.figure(figsize=(6, 6))
    self.ax = self.fig.add_subplot(111)
    # Title lists the cursor coordinate of every axis that is not displayed.
    info = [
        f"{ax}={values[vi]}"
        for i, (ax, values, vi) in enumerate(zip(self.axnames, self.axvalues, self.cursor))
        if i not in [self.ihax, self.ivax]
    ]
    self.ax.set_title(", ".join(info))
    xvalues = self.axvalues[self.ihax]
    yvalues = self.axvalues[self.ivax]
    # Index the end points: ``first, *_, last`` unpacking fails on length-1 axes.
    xmin, xmax = xvalues[0], xvalues[-1]
    ymin, ymax = yvalues[0], yvalues[-1]
    xstep = self.axsteps[self.ihax]
    ystep = self.axsteps[self.ivax]
    extent = (xmin - xstep / 2, xmax + xstep / 2, ymin - ystep / 2, ymax + ystep / 2)
    # abs() keeps the aspect positive for descending axes. A length-1 or zero-range
    # axis has no span, so fall back to the padded extent rather than divide by zero.
    xspan, yspan = abs(xmax - xmin), abs(ymax - ymin)
    if not (xspan and yspan):
        xspan, yspan = abs(extent[1] - extent[0]), abs(extent[3] - extent[2])
    kwargs = dict(
        origin="lower",
        cmap="inferno",
        vmin=0,
        vmax=1,
        extent=extent,
        aspect=xspan / yspan if xspan and yspan else 1.0,
    )
    kwargs.update(imshow_kwargs)
    plot_plane = self.slice_for_plot()
    self.ax.set_xlabel(self.axnames[self.ihax])
    self.ax.set_ylabel(self.axnames[self.ivax])
    self.viewer_image = self.ax.imshow(plot_plane.T, **kwargs)