import torch
import numpy as np


def bresenham3d(
    x1: float, y1: float, z1: float, x2: float, y2: float, z2: float
) -> list[tuple[int, int, int]]:
    """Return the integer points of the 3D Bresenham line between two points.

    Both endpoints are first rounded with Python's built-in ``round`` (round
    half to even). The returned list starts at the rounded start point, ends at
    the rounded end point, and contains exactly one point per unit step along
    the axis with the largest extent (the "driving" axis), i.e.
    ``max(|dx|, |dy|, |dz|) + 1`` points in total.

    This deliberately runs on the CPU: it is a short, inherently sequential
    integer loop that gains nothing from GPU execution.

    The names x/y/z are only labels for the three coordinates; callers may use
    any axis order as long as the start and end points use the same one.

    Args:
        x1, y1, z1: Start point.
        x2, y2, z2: End point.

    Returns:
        The line's integer points, ordered from start to end.
    """
    x1, y1, z1 = round(x1), round(y1), round(z1)
    x2, y2, z2 = round(x2), round(y2), round(z2)
    points = [(x1, y1, z1)]

    dx = abs(x2 - x1)
    dy = abs(y2 - y1)
    dz = abs(z2 - z1)
    xs = 1 if x2 > x1 else -1
    ys = 1 if y2 > y1 else -1
    zs = 1 if z2 > z1 else -1

    # Driving axis X: step x every iteration; p1/p2 are the error terms that
    # decide when y/z must also advance.
    if dx >= dy and dx >= dz:
        p1 = 2 * dy - dx
        p2 = 2 * dz - dx
        while x1 != x2:
            x1 += xs
            if p1 >= 0:
                y1 += ys
                p1 -= 2 * dx
            if p2 >= 0:
                z1 += zs
                p2 -= 2 * dx
            p1 += 2 * dy
            p2 += 2 * dz
            points.append((x1, y1, z1))

    # Driving axis Y.
    elif dy >= dx and dy >= dz:
        p1 = 2 * dx - dy
        p2 = 2 * dz - dy
        while y1 != y2:
            y1 += ys
            if p1 >= 0:
                x1 += xs
                p1 -= 2 * dy
            if p2 >= 0:
                z1 += zs
                p2 -= 2 * dy
            p1 += 2 * dx
            p2 += 2 * dz
            points.append((x1, y1, z1))

    # Driving axis Z.
    else:
        p1 = 2 * dy - dz
        p2 = 2 * dx - dz
        while z1 != z2:
            z1 += zs
            if p1 >= 0:
                y1 += ys
                p1 -= 2 * dz
            if p2 >= 0:
                x1 += xs
                p2 -= 2 * dz
            p1 += 2 * dy
            p2 += 2 * dx
            points.append((x1, y1, z1))

    return points


def _count_line_intersections(values: np.ndarray) -> int:
    """Apply the difference-based intersection heuristic to a 1D value profile.

    The count is ``#(|d1| == 1) - 2 * #(|d2| == 2) + 2 * #(|d3| == 4)``, where
    d1, d2, d3 are the first, second and third differences of ``values``.
    For a 0/1 profile (assumed - TODO confirm that callers pass binary
    labels): every ``|d1| == 1`` is an entry or exit transition; ``|d2| == 2``
    flags a one-sample spike or gap (0,1,0 or 1,0,1) whose two transitions are
    removed; ``|d3| == 4`` flags an alternating 0,1,0,1 run where two
    overlapping spikes were both removed, so two transitions are added back.

    Args:
        values: 1D integer array of samples along the line (may be empty).

    Returns:
        The heuristic intersection count as a Python int.
    """
    d1 = np.diff(values)
    d2 = np.diff(d1)
    d3 = np.diff(d2)
    return int(
        np.count_nonzero(np.abs(d1) == 1)
        - 2 * np.count_nonzero(np.abs(d2) == 2)
        + 2 * np.count_nonzero(np.abs(d3) == 4)
    )


def center_line_intersections_torch(
    position_z: float,
    position_y: float,
    position_x: float,
    azimuth: float,
    altitude: float,
    length: float,
    image_tensor: torch.Tensor,
    spacing: list[float],
    device: torch.device,
) -> tuple[int, torch.Tensor]:
    """Count intersections along a 3D center line and return the line's mask.

    The line passes through (position_z, position_y, position_x) along the
    unit direction
    ``(cos(alt), sin(alt) * sin(az), sin(alt) * cos(az))`` in (z, y, x) order.
    ``altitude`` is therefore the polar angle from the +z axis, and
    ``azimuth`` is the angle in the y-x plane measured from +x toward +y; both
    are in degrees. The line extends ``length / spacing`` voxels on each side
    of the center.

    The voxel path is generated on the CPU with ``bresenham3d``. The in-bounds
    voxels are then written into the mask and read from ``image_tensor`` with
    one batched indexing operation each, so there is no per-voxel
    host/device synchronization.

    Args:
        position_z, position_y, position_x: Line center. It is used directly as
            an index into ``image_tensor``, i.e. in voxel coordinates.
        azimuth: Azimuth angle in degrees.
        altitude: Altitude (polar) angle in degrees.
        length: Half-length of the line, presumably in the same physical units
            as ``spacing`` (TODO confirm). It is divided by the voxel size to
            obtain a half-length in voxels.
        image_tensor: 3D tensor indexed as (z, y, x). Sampled values are cast
            to int32 (floats truncate toward zero) before counting.
        spacing: Three equal, positive voxel sizes (a list, tuple or array).
        device: Device on which the returned mask is allocated.

    Returns:
        ``(intersections, line_mask_torch)``. ``intersections`` is the
        heuristic count from ``_count_line_intersections`` over the in-bounds
        samples, in line order. ``line_mask_torch`` is a uint8 tensor with
        ``image_tensor``'s shape, set to 1 at every in-bounds line voxel.

    Raises:
        ValueError: If ``spacing`` is not three equal positive values.
    """
    spacing_values = [float(s) for s in spacing]
    if (
        len(spacing_values) != 3
        or not all(s == spacing_values[0] for s in spacing_values)
        or not spacing_values[0] > 0  # also rejects NaN
    ):
        raise ValueError(f"Unsupported spacing: {spacing}")
    # Half-length in voxels. Dividing by the isotropic voxel size reproduces
    # the former special cases exactly (1 -> length, 0.5 -> length * 2).
    length2 = length / spacing_values[0]

    azimuth_rad = np.radians(azimuth)
    altitude_rad = np.radians(altitude)
    direction_z = np.cos(altitude_rad)
    direction_y = np.sin(altitude_rad) * np.sin(azimuth_rad)
    direction_x = np.sin(altitude_rad) * np.cos(azimuth_rad)

    # The two ends of the line, symmetric about the center, rounded to voxels.
    end_point = tuple(
        int(v)
        for v in np.round(
            (
                position_z + length2 * direction_z,
                position_y + length2 * direction_y,
                position_x + length2 * direction_x,
            )
        )
    )
    start_opposite = tuple(
        int(v)
        for v in np.round(
            (
                position_z - length2 * direction_z,
                position_y - length2 * direction_y,
                position_x - length2 * direction_x,
            )
        )
    )

    line_points = bresenham3d(*start_opposite, *end_point)

    shape = image_tensor.shape  # indexed as (z, y, x) below
    line_mask_torch = torch.zeros(shape, device=device, dtype=torch.uint8)

    # Keep the line order: the intersection heuristic depends on it.
    inside = [
        (z, y, x)
        for z, y, x in line_points
        if 0 <= z < shape[0] and 0 <= y < shape[1] and 0 <= x < shape[2]
    ]
    if not inside:
        # The whole line lies outside the volume: nothing sampled, nothing marked.
        return 0, line_mask_torch

    coords = torch.tensor(inside, dtype=torch.long)  # (n_points, 3)
    line_mask_torch[tuple(coords.to(line_mask_torch.device).unbind(dim=1))] = 1
    sampled = image_tensor[tuple(coords.to(image_tensor.device).unbind(dim=1))]
    # A single device->host transfer for all samples, cast to int32 for counting.
    point_values_np = sampled.detach().cpu().to(torch.int32).numpy()

    return _count_line_intersections(point_values_np), line_mask_torch