import random

import numpy as np
import torch
from scipy.ndimage import map_coordinates

from core.cylinder import (
    generate_cylinder_n_torch,
    generate_cylinder_o_torch,
    generate_cylinder_tip_torch,
    snap_to_discrete_values,
)
from core.intersection import center_line_intersections_torch
from core.scoring import cl_score_torch, cl_score_torch_xfr

# ---------------------------------------------------------------------------
# Module-level context shared by the PSO objective functions.
#
# The optimiser calls the objective with a bare parameter vector, so the
# volumes and sampling settings are stashed here by set_global_context().
# ---------------------------------------------------------------------------

# Legacy names kept for backward compatibility; none of the functions in this
# section read them.
image1_array = None  # cortical_nii.gz
image2_array = None  # binarynii.gz
image3_array = None  # roi2.nii.gz
diameter = None
length = None

# Values populated by set_global_context().
cortical_tensor = None
spine_tensor = None
image2_shape = None
spacing = [0.5, 0.5, 0.5]
device = None
grid = None
USE_TIP_PENALTY = None


def set_global_context(
    cortical,
    spine,
    shape,
    spacing_,
    device_,
    grid_,
    use_tip_penalty=False,  # newly added option
):
    """Store the shared inputs that the objective functions read.

    Args:
        cortical: Stored as ``cortical_tensor`` and forwarded to the scorer.
        spine: Stored as ``spine_tensor``; forwarded to the scorer and to the
            centre-line intersection routine.
        shape: Stored as ``image2_shape`` and passed to the cylinder
            generators as the image shape.
        spacing_: Stored as ``spacing``.
        device_: Stored as ``device``.
        grid_: Stored as ``grid`` and passed to the cylinder generators.
        use_tip_penalty: When truthy, a tip cylinder is generated and handed
            to the scorer as ``cylinder_tip_torch``.
    """
    global cortical_tensor, spine_tensor, image2_shape
    global spacing, device, grid, USE_TIP_PENALTY

    cortical_tensor = cortical
    spine_tensor = spine
    image2_shape = shape
    spacing = spacing_
    device = device_
    grid = grid_
    USE_TIP_PENALTY = use_tip_penalty


def _require_context() -> None:
    """Raise a clear error if set_global_context() has not supplied tensors."""
    if cortical_tensor is None or spine_tensor is None:
        raise RuntimeError(
            "set_global_context() must be called before evaluating an "
            "objective function"
        )


def cylinder_circle_line_intersection_loss_deductions_torch(
    diameter: float,
    length: float,
    params: list[float],
    image_shape: tuple[int, int, int],
    cortical_tensor: torch.Tensor,
    spine_tensor: torch.Tensor,
    spacing: list[float],
    device: torch.device,
) -> float:
    """Score one cylinder placement for the PSO optimiser.

    Builds the forward and opposite cylinders, computes the centre-line
    intersections against ``spine_tensor`` and passes everything to
    ``cl_score_torch_xfr``.

    Note that ``grid`` and ``USE_TIP_PENALTY`` are read from the module
    globals, not from the arguments; the remaining inputs are the parameters
    of this function (which shadow the same-named globals).

    Args:
        diameter: Cylinder diameter passed to the cylinder generators.
        length: Cylinder length passed to the generators and to the
            intersection routine.
        params: ``[position_z, position_y, position_x, azimuth, altitude]``.
        image_shape: Image shape passed to the cylinder generators.
        cortical_tensor: First argument of the scoring function.
        spine_tensor: Second argument of the scoring function; also used for
            the centre-line intersections.
        spacing: Spacing passed to the generators and intersection routine.
        device: Torch device passed to the generators and intersection
            routine.

    Returns:
        Whatever ``cl_score_torch_xfr`` returns (annotated as ``float``;
        NOTE(review): confirm the scorer really yields a Python float).
    """
    position_z, position_y, position_x, azimuth, altitude = params

    # All three cylinder generators take the same argument list.
    cylinder_args = (
        diameter,
        length,
        position_z,
        position_y,
        position_x,
        float(azimuth),
        float(altitude),
        image_shape,
        spacing,
        device,
        grid,
    )
    cyl_fwd = generate_cylinder_n_torch(*cylinder_args)
    cyl_opp = generate_cylinder_o_torch(*cylinder_args)

    # Centre-line intersections, computed in Torch mode; the second return
    # value is not needed here.
    intersections, _ = center_line_intersections_torch(
        position_z,
        position_y,
        position_x,
        azimuth,
        altitude,
        length,
        spine_tensor,
        spacing,
        device,
    )

    # The tip cylinder is only built when the tip penalty is enabled.
    cyl_tip = None
    if USE_TIP_PENALTY:
        cyl_tip = generate_cylinder_tip_torch(*cylinder_args)

    # cl_score_torch is the alternative scorer with the same role.
    loss_value = cl_score_torch_xfr(
        cortical_tensor,
        spine_tensor,
        cyl_fwd,
        cyl_opp,
        intersections,
        cylinder_tip_torch=cyl_tip,
    )
    return loss_value


def objective_function_xfr(params: list[float], y_indices) -> float:
    """PSO objective where ``y`` is looked up instead of optimised.

    Args:
        params: Six values,
            ``[position_z, position_x, azimuth, altitude, diameter, length]``.
            There is no ``position_y`` entry.
        y_indices: 2-D indexable (indexed as ``y_indices[z, x]``) giving the
            ``y`` coordinate for a rounded ``(z, x)`` position.

    Returns:
        The loss from
        ``cylinder_circle_line_intersection_loss_deductions_torch``.

    Raises:
        RuntimeError: If ``set_global_context()`` has not provided tensors.
    """
    _require_context()

    z, x, azimuth, altitude, diameter_raw, length_raw = params

    # Nearest-cell lookup of y.
    # NOTE(review): a negative rounded index wraps around silently and an
    # index past the end raises IndexError; this presumably relies on the
    # optimiser bounds keeping (z, x) inside y_indices - confirm.
    y = y_indices[round(z), round(x)]
    position_params = [z, y, x, azimuth, altitude]

    # Diameter and length are used as-is here (no snapping to discrete
    # values, unlike objective_function).
    diameter_discrete, length_discrete = diameter_raw, length_raw

    loss = cylinder_circle_line_intersection_loss_deductions_torch(
        diameter_discrete,
        length_discrete,
        position_params,
        image2_shape,
        cortical_tensor,
        spine_tensor,
        spacing,
        device,
    )
    return loss


def objective_function(params: list[float]) -> float:
    """PSO objective over the full seven-value parameter vector.

    Args:
        params: ``[position_z, position_y, position_x, azimuth, altitude,
            diameter_raw, length_raw]``.

    Returns:
        The loss from
        ``cylinder_circle_line_intersection_loss_deductions_torch``.

    Raises:
        RuntimeError: If ``set_global_context()`` has not provided tensors.
    """
    _require_context()

    position_params = params[:5]  # [z, y, x, azimuth, altitude]
    diameter_raw = params[5]
    length_raw = params[6]

    # Convert the continuous values to discrete values.
    diameter_discrete, length_discrete = snap_to_discrete_values(
        diameter_raw, length_raw
    )

    loss = cylinder_circle_line_intersection_loss_deductions_torch(
        diameter_discrete,
        length_discrete,
        position_params,
        image2_shape,
        cortical_tensor,
        spine_tensor,
        spacing,
        device,
    )
    return loss