import torch
from core.cylinder import (
    generate_cylinder_n_torch,
    generate_cylinder_o_torch,
    snap_to_discrete_values,
    generate_cylinder_tip_torch,
)
from core.intersection import center_line_intersections_torch
from core.scoring import cl_score_torch

# ---------------------------------------------------------------------------
# Module-level context read by objective_function().
# Populate it with set_global_context() before running the optimizer.
# ---------------------------------------------------------------------------

# The next five names are not read anywhere in this module; they are kept so
# that any external code importing them keeps working.
image1_array = None  # cortical_nii.gz
image2_array = None  # binarynii.gz
image3_array = None  # roi2.nii.gz
diameter = None
length = None

image2_shape = None
spacing = [0.5, 0.5, 0.5]
device = None
grid = None
USE_TIP_PENALTY = None

# Assigned by set_global_context(). Defined here so that reading them before
# the context is set produces a clear error instead of a NameError.
cortical_tensor = None
spine_tensor = None


def set_global_context(
    cortical, spine, shape, spacing_, device_, grid_, use_tip_penalty=False
):
    """Store the volumes and settings that objective_function() reads.

    Args:
        cortical: Tensor stored as ``cortical_tensor`` and handed to
            ``cl_score_torch``.
        spine: Tensor stored as ``spine_tensor``; used for the center-line
            intersection test and for scoring.
        shape: Stored as ``image2_shape``; forwarded as ``image_shape`` to the
            cylinder generators.
        spacing_: Stored as ``spacing``; forwarded to the cylinder and
            intersection helpers (presumably per-axis voxel size — confirm).
        device_: Torch device stored as ``device``.
        grid_: Stored as ``grid``; forwarded to the cylinder generators.
        use_tip_penalty: Newly added option. When truthy, the loss also builds
            a tip cylinder with ``generate_cylinder_tip_torch`` and passes it
            to ``cl_score_torch``.
    """
    global cortical_tensor, spine_tensor, image2_shape, spacing, device, grid, USE_TIP_PENALTY
    cortical_tensor = cortical
    spine_tensor = spine
    image2_shape = shape
    spacing = spacing_
    device = device_
    grid = grid_
    USE_TIP_PENALTY = use_tip_penalty


def cylinder_circle_line_intersection_loss_deductions_torch(
    diameter: float,
    length: float,
    params: list[float],
    image_shape: tuple[int, int, int],
    cortical_tensor: torch.Tensor,
    spine_tensor: torch.Tensor,
    spacing: list[float],
    device: torch.device,
) -> float:
    """Compute the loss for one cylinder configuration.

    The result is intended for consumption by the PSO optimizer. It is
    returned exactly as produced by ``cl_score_torch``, presumably a Python
    float (verify against ``core.scoring``).

    Args:
        diameter: Cylinder diameter (already snapped by the caller).
        length: Cylinder length (already snapped by the caller).
        params: Exactly five values ``[z, y, x, azimuth, altitude]``.
        image_shape: Shape of the target volume.
        cortical_tensor: Volume passed to ``cl_score_torch``.
        spine_tensor: Volume used for center-line intersections and scoring.
        spacing: Spacing forwarded to the geometry helpers.
        device: Torch device forwarded to the geometry helpers.

    Note:
        Also reads the module globals ``grid`` and ``USE_TIP_PENALTY``,
        which are set by ``set_global_context``.
    """
    position_z, position_y, position_x, azimuth, altitude = params

    # Two cylinders are generated from the same pose; the local names suggest
    # the forward and the opposite-direction variant.
    cyl_fwd = generate_cylinder_n_torch(
        diameter, length, position_z, position_y, position_x,
        float(azimuth), float(altitude), image_shape, spacing, device, grid
    )
    cyl_opp = generate_cylinder_o_torch(
        diameter, length, position_z, position_y, position_x,
        float(azimuth), float(altitude), image_shape, spacing, device, grid
    )

    # Center-line intersections with the spine volume (Torch implementation).
    # Angles are passed through uncast, exactly as before.
    intersections, _ = center_line_intersections_torch(
        position_z, position_y, position_x, azimuth, altitude,
        length, spine_tensor, spacing, device
    )

    # The tip cylinder is only built when the tip penalty is enabled.
    cyl_tip = None
    if USE_TIP_PENALTY:
        cyl_tip = generate_cylinder_tip_torch(
            diameter, length, position_z, position_y, position_x,
            float(azimuth), float(altitude), image_shape, spacing, device, grid
        )

    loss_value = cl_score_torch(
        cortical_tensor, spine_tensor, cyl_fwd, cyl_opp, intersections,
        cylinder_tip_torch=cyl_tip
    )
    return loss_value


def objective_function(params: list[float]) -> float:
    """PSO objective: evaluate one particle with the Torch-based loss.

    ``params`` is ``[z, y, x, azimuth, altitude, diameter_raw, length_raw]``.
    The raw diameter and length are snapped to discrete values before the loss
    is evaluated. Volumes, shape, spacing and device come from the module
    context set by ``set_global_context``.

    Raises:
        RuntimeError: if ``set_global_context`` has not populated the
            cortical/spine tensors (e.g. in a freshly spawned worker process).
    """
    if cortical_tensor is None or spine_tensor is None:
        raise RuntimeError(
            "objective_function called before set_global_context(); "
            "cortical/spine tensors are not set."
        )

    position_params = params[:5]  # [z, y, x, azimuth, altitude]
    diameter_raw = params[5]
    length_raw = params[6]

    # Convert the continuous optimizer values to discrete values.
    diameter_discrete, length_discrete = snap_to_discrete_values(diameter_raw, length_raw)

    loss = cylinder_circle_line_intersection_loss_deductions_torch(
        diameter_discrete, length_discrete, position_params, image2_shape,
        cortical_tensor, spine_tensor, spacing, device
    )
    return loss