CBT_project/visualization/res_plot_3d.py

405 lines
15 KiB
Python
Raw Normal View History

2026-04-10 05:25:27 +00:00
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from datetime import datetime
import csv
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values
from core.intersection import center_line_intersections_torch
from core.scoring import cl_score_torch, compute_overlap_ratio_from_cylinder_mask, cl_score_torch_xfr
2026-04-10 05:25:27 +00:00
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
from utils.helpers import save_with_unique_name
def set_axes_equal_3d(ax):
    """Give a 3D axis the same scale on x, y and z.

    Every axis is re-centred on its current midpoint and widened to half of
    the largest current span on either side, so all three axes cover the same
    extent. Spheres then render as spheres and cubes as cubes.

    The box aspect is also forced to 1:1:1 when the axis supports
    ``set_box_aspect``. Older matplotlib versions without that method are
    tolerated silently.
    """
    getters = (ax.get_xlim3d, ax.get_ylim3d, ax.get_zlim3d)
    setters = (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d)
    limits = [read_limits() for read_limits in getters]

    # One common half-width, driven by the widest axis.
    half_span = 0.5 * max(abs(upper - lower) for lower, upper in limits)

    for apply_limits, (lower, upper) in zip(setters, limits):
        centre = np.mean((lower, upper))
        apply_limits([centre - half_span, centre + half_span])

    try:
        ax.set_box_aspect([1, 1, 1])
    except AttributeError:
        # set_box_aspect is missing on older matplotlib; limits alone will do.
        pass

def _res_plt_screw_masks(diameter, length, position, image_shape, spacing, device, grid=None):
    """Build the ``_n`` / ``_o`` cylinder mask pair for one screw.

    ``generate_cylinder_n_torch`` and ``generate_cylinder_o_torch`` are called
    with identical arguments. The first five entries of ``position`` are
    forwarded in order, followed by ``image_shape``, ``spacing``, ``device``
    and ``grid``.

    Returns:
        ``(cyl_n, cyl_o)``: the two generated masks, in that order.
    """
    pose = [position[i] for i in range(5)]
    common = (diameter, length, *pose, image_shape, spacing, device, grid)
    return generate_cylinder_n_torch(*common), generate_cylinder_o_torch(*common)


def _res_plt_centerline(position, length, spine_tensor, spacing, device):
    """Return ``(intersections, line_mask)`` for one screw's centre line.

    ``length`` is truncated with ``int()`` before being forwarded, as the
    original call site did.
    """
    pose = [position[i] for i in range(5)]
    return center_line_intersections_torch(*pose, int(length), spine_tensor, spacing, device)


def _res_plt_voxels_xyz(mask):
    """Return ``(x, y, z)`` index arrays of the voxels equal to 1 in ``mask``.

    The mask's axes are read as (z, y, x) and reordered to (x, y, z) for
    matplotlib's ``scatter``.
    """
    z, y, x = np.where(mask.cpu().numpy() == 1)
    return x, y, z


def _res_plt_overlap_stats(cortical_tensor, spine_tensor, cyl):
    """Summarise how one cylinder mask overlaps the cortical and spine masks.

    Returns:
        ``(cyl_points, cortical_pct, vertebral_pct, cb_ratio)``, where:

        - ``cyl_points`` is the sum of ``cyl``.
        - The two percentages are overlap voxel counts relative to
          ``cyl_points``.
        - ``cb_ratio`` is ``cortical_pct / vertebral_pct``.

        Instead of raising ``ZeroDivisionError``, the percentages fall back
        to 0.0 for an empty cylinder, and the ratio falls back to 0 when the
        vertebral overlap is zero.
    """
    cyl_points = torch.sum(cyl).item()
    overlap_cortical = ((cortical_tensor == 1) & (cyl == 1)).sum().item()
    overlap_bone = ((spine_tensor == 1) & (cyl == 1)).sum().item()
    if cyl_points:
        cortical_pct = (overlap_cortical / cyl_points) * 100
        vertebral_pct = (overlap_bone / cyl_points) * 100
    else:
        cortical_pct = vertebral_pct = 0.0
    cb_ratio = cortical_pct / vertebral_pct if vertebral_pct != 0 else 0
    return cyl_points, cortical_pct, vertebral_pct, cb_ratio


def _res_plt_draw_view(ax, layers, view=None):
    """Scatter every layer on a 3D axis, label the axes and equalise the scale.

    ``layers`` is a sequence of ``((x, y, z), scatter_kwargs)`` pairs drawn in
    order. ``view`` is an optional ``(elev, azim)`` camera applied before
    drawing.
    """
    if view is not None:
        elev, azim = view
        ax.view_init(elev=elev, azim=azim, roll=0)
    for (xs, ys, zs), style in layers:
        ax.scatter(xs, ys, zs, **style)
    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')
    ax.set_zlabel('Z-axis')
    set_axes_equal_3d(ax)


def res_plt_2_torch(
    spine_tensor: torch.Tensor,
    cortical_tensor: torch.Tensor,
    image_shape: tuple[int, int, int],
    image2_path: str,
    base_folder: str,
    label_str: str,
    diameter_l: float,
    length_l: float,
    diameter_r: float,
    length_r: float,
    best_position_l: list[float],
    best_position_r: list[float],
    swarm_size: int,
    max_iter: int,
    total_time: float,
    spacing: list[float],
    CBT: bool,
    device: torch.device,
    grid=None
) -> None:
    """Score, log and plot the optimised left/right screw placements.

    For each side the function:

    1. Rebuilds the cylinder masks and centre line from ``best_position_*``
       with torch.
    2. Scores them and measures their cortical/vertebral overlap.
    3. Appends one CSV row per side to
       ``<base_folder>/<YYYYMMDD>/<patient_id>/output.csv``.
    4. Saves a four-view 3D scatter figure into the same folder.

    Args:
        spine_tensor: Vertebra mask. Voxels equal to 1 count as bone.
        cortical_tensor: Mask whose voxels equal to 1 count as cortical overlap.
        image_shape: Forwarded to the cylinder generators.
        image2_path: Image passed to ``azimuth_rotation`` and
            ``analyze_vertebral_tilt_contour``. Its parent directory name is
            used as the patient id.
        base_folder: Root output directory.
        label_str: Label used in the CSV, the figure title and the file name.
        diameter_l, length_l, diameter_r, length_r: Screw sizes (captioned
            as mm in the figure).
        best_position_l, best_position_r: At least five values. Indices 0-2
            are reported reversed as the XYZ position, index 3 as the raw
            azimuth and index 4 as the raw altitude.
        swarm_size, max_iter, total_time: Optimiser metadata, logged as-is
            (``total_time`` is captioned in seconds).
        spacing, device, grid: Forwarded to the mask/centre-line generators.
        CBT: Selects the ``'CBT'`` versus ``'TPS'`` tag passed to
            ``save_with_unique_name``.

    CSV write failures are reported with ``print`` and do not abort plotting.
    """
    cyl_l, cyl_lo = _res_plt_screw_masks(
        diameter_l, length_l, best_position_l, image_shape, spacing, device, grid)
    cyl_r, cyl_ro = _res_plt_screw_masks(
        diameter_r, length_r, best_position_r, image_shape, spacing, device, grid)

    intersections_l, line_mask_l = _res_plt_centerline(
        best_position_l, length_l, spine_tensor, spacing, device)
    loss_l = cl_score_torch(cortical_tensor, spine_tensor, cyl_l, cyl_lo, intersections_l)
    intersections_r, line_mask_r = _res_plt_centerline(
        best_position_r, length_r, spine_tensor, spacing, device)
    # NOTE(review): the left side is scored with cl_score_torch but the right
    # side with cl_score_torch_xfr (the cl_score_torch call for the right side
    # was commented out upstream). Confirm this asymmetry is intended.
    loss_r = cl_score_torch_xfr(cortical_tensor, spine_tensor, cyl_r, cyl_ro, intersections_r)

    azi = azimuth_rotation(image2_path)
    tilt = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = tilt['superior']['tilt_angle_deg']

    # The output folder is resolved once, so the CSV and the figure always land
    # in the same dated folder, even if the run crosses midnight.
    patient_id = os.path.basename(os.path.dirname(image2_path))
    output_folder = os.path.join(base_folder, datetime.now().strftime("%Y%m%d"), patient_id)
    os.makedirs(output_folder, exist_ok=True)

    sides = (
        ('L', 'Left', diameter_l, length_l, best_position_l, intersections_l, loss_l,
         _res_plt_overlap_stats(cortical_tensor, spine_tensor, cyl_l)),
        ('R', 'Right', diameter_r, length_r, best_position_r, intersections_r, loss_r,
         _res_plt_overlap_stats(cortical_tensor, spine_tensor, cyl_r)),
    )

    # Build the CSV row and the figure caption for each side from the same values.
    rows, captions = [], []
    for side, side_name, diameter, length, pos, intersections, loss, stats in sides:
        cyl_points, cortical_pct, vertebral_pct, cb_ratio = stats
        # Positions are stored as (z, y, x); report them as (x, y, z).
        position_xyz = f"({pos[2]:.2f}, {pos[1]:.2f}, {pos[0]:.2f})"
        user_azimuth = 90 - pos[3] - azi
        user_altitude = 90 - pos[4] - alt
        rows.append([
            label_str,
            side,
            diameter,
            length,
            swarm_size,
            max_iter,
            position_xyz,
            f"{pos[3]:.2f}",
            f"{pos[3]-azi:.2f}",
            f"{pos[4]:.2f}",
            f"{pos[4]-alt:.2f}",
            intersections,
            f"{loss:.2f}",
            cyl_points,
            f"{cortical_pct:.2f}",
            f"{vertebral_pct:.2f}",
            f"{cb_ratio:.2f}",
            f"{user_azimuth:.2f}",
            f"{user_altitude:.2f}",
            f"{total_time:.2f}",
        ])
        captions.append(
            f'{side_name} : Position = {position_xyz}, '
            f'Azimuth = {user_azimuth:.2f}, Altitude = {user_altitude:.2f}, '
            f'Intersection = {intersections}, Score = {cortical_pct:.2f} / {vertebral_pct:.2f} / {cb_ratio:.2f}'
        )

    csv_path = os.path.join(output_folder, 'output.csv')
    # The header is written only when the file does not exist yet.
    write_header = not os.path.isfile(csv_path)
    headers = [
        'Label', 'Side', 'Diameter', 'Length', 'Swarm_Size', 'Max_Iter',
        'Position_XYZ', 'Raw_Azimuth', 'Azimuth_Diff', 'Raw_Altitude', 'Altitude_Diff',
        'Intersections', 'Best_Loss', 'cyl_points', 'Overlap_Cortical', 'Overlap_Bone',
        'Cortical_Bone_Ratio', 'User_Azimuth', 'User_Altitude', 'Total_Time'
    ]
    try:
        with open(csv_path, 'a', newline='') as csvfile:
            writer = csv.writer(csvfile)
            if write_header:
                writer.writerow(headers)
            writer.writerows(rows)
        print(f"[CSV Saved] {csv_path}")
    except Exception as e:
        # Best-effort logging: a failed CSV write must not block the figure.
        print(f"[Error] Failed to write CSV: {e}")

    layers = (
        (_res_plt_voxels_xyz(line_mask_l), dict(c='r', marker='o', s=1)),
        (_res_plt_voxels_xyz(line_mask_r), dict(c='r', marker='o', s=1)),
        (_res_plt_voxels_xyz(cyl_l), dict(c='darkcyan', marker='o', label='Cylinder(L)')),
        (_res_plt_voxels_xyz(cyl_lo), dict(c='pink', marker='o')),
        (_res_plt_voxels_xyz(cyl_r), dict(c='blue', marker='o', label='Cylinder(R)')),
        (_res_plt_voxels_xyz(cyl_ro), dict(c='pink', marker='o')),
        (_res_plt_voxels_xyz(spine_tensor), dict(c='lightblue', marker='+', alpha=0.04, label='Spine')),
    )
    # (subplot, camera (elev, azim) or None for the default view, show legend)
    views = (
        (221, None, False),
        (222, (90, -90), True),
        (223, (0, 90), False),
        (224, (0, 0), False),
    )

    fig = plt.figure(figsize=(12, 12))
    try:
        for subplot, view, show_legend in views:
            ax = fig.add_subplot(subplot, projection='3d')
            _res_plt_draw_view(ax, layers, view)
            if show_legend:
                ax.legend()

        fig.text(0.5, 0.98, f'{label_str} Best Position', ha='center', fontsize=15)
        fig.text(
            0.5, 0.44,
            f'L: Diameter = {diameter_l} mm, {length_l} mm, '
            f'R: Diameter = {diameter_r} mm, {length_r} mm, '
            f'Swarm size = {swarm_size}, Iteration = {max_iter}, Total time = {total_time:.2f} s',
            ha='center', fontsize=12
        )
        fig.text(0.5, 0.03, captions[0], ha='center', fontsize=9)
        fig.text(0.5, 0.01, captions[1], ha='center', fontsize=9)
        fig.tight_layout()

        way = 'CBT' if CBT else 'TPS'
        path = save_with_unique_name(output_folder, label_str, way,
                                     diameter_l, length_l, diameter_r, length_r,
                                     swarm_size, max_iter)
        fig.savefig(path, dpi=200, bbox_inches="tight")
        print("[Saved figure]", path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

def eval_overlap_from_position(
pos,
optimize_size: bool,
spine_tensor: torch.Tensor,
image_shape,
spacing,
device: torch.device,
grid=None,
fixed_diameter: float | None = None,
fixed_length: float | None = None,
):
"""
根據 position 生成 cylinder mask再算 overlap ratio
"""
if optimize_size:
d, L = snap_to_discrete_values(pos[5], pos[6])
params_5 = pos[:5]
else:
if fixed_diameter is None or fixed_length is None:
raise ValueError("fixed_diameter and fixed_length must be provided when optimize_size=False")
d, L = fixed_diameter, fixed_length
params_5 = pos
z, y, x, az, alt = params_5
cyl_mask = generate_cylinder_n_torch(
d, L,
z, y, x,
az, alt,
image_shape, spacing,
device=device,
grid=grid
)
overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
return overlap, d, L