CBT_project/core/scoring.py

266 lines
8.9 KiB
Python
Raw Normal View History

2026-04-10 05:25:27 +00:00
import torch
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_tip_torch
from config.constant import OVERLAP_THRESH
def cl_score_torch_xfr(
cortical_tensor: torch.Tensor,
spine_tensor: torch.Tensor,
cylinder_torch: torch.Tensor,
cylinder_o_torch: torch.Tensor,
intersections: int,
diameter: float = None,
length: float = None,
cylinder_tip_torch: torch.Tensor = None # 新增:尖端 mask
) -> float:
"""
漸進式評分優先確保找到骨頭再改善細節
"""
cyl_total = cylinder_torch.sum().item()
overlap = ((cortical_tensor == 1) & (cylinder_torch == 1)).sum().item()
null_vox = ((cortical_tensor == 0) & (cylinder_torch == 1)).sum().item()
null_vox2 = ((spine_tensor == 1) & (cylinder_o_torch == 1)).sum().item()
in_bone= ((spine_tensor == 1) & (cylinder_torch == 1)).sum().item()
not_in_bone= ((spine_tensor == 0) & (cylinder_torch == 1)).sum().item()
# if cyl_total == 0:
# return float(1000*1000)
# return float(1e9) # 極差的情況
overlap_ratio = overlap / cyl_total
# if cyl_total < 1000:
# return float((1000 - cyl_total)*10000)
score = cyl_total
# if in_bone == 0:
# return float(not_in_bone*200)
score += overlap*30
score += in_bone*10
score -= not_in_bone*1000
score -= null_vox2*1000
return float(-score)
# === 階段 1首要目標是找到骨頭overlap > 0 ===
if overlap == 0:
# 完全沒有 overlap 是最糟糕的情況
score -= 500000 # 超大懲罰
# 如果連 spine 都沒穿過,更糟
if intersections == 0:
score -= 500000
return float(-score)
# === 階段 2有找到骨頭了開始改善品質 ===
# 1. Overlap 獎勵(非線性,鼓勵快速提升)
if overlap_ratio < 0.1:
# 0-10%:每增加 1% 給大量獎勵(鼓勵探索)
score += overlap * 5000 # 很高的單位獎勵
elif overlap_ratio < 0.3:
# 10-30%:中等獎勵
score += overlap * 3000
elif overlap_ratio < 0.5:
# 30-50%:正常獎勵
score += overlap * 2000
else:
# 50%+:獎勵 + 額外比例獎勵
score += overlap * 2000
score += (overlap_ratio - 0.5) * 100000 # 超過 50% 額外大獎
# 2. Intersection 控制(稍微放寬)
if intersections == 1:
score += 20000 # 完美
elif intersections == 0:
score -= 200000 # 嚴重錯誤(但比完全沒 overlap 好)
elif intersections == 2:
score -= 10000 # 可接受但不理想
else:
score -= intersections * 15000
# 3. Null voxel 懲罰(漸進式)
null_ratio = null_vox / cyl_total
if overlap_ratio < 0.2:
# 如果 overlap 還很少,對 null voxel 寬容一點
score -= null_vox * 300
elif overlap_ratio < 0.4:
score -= null_vox * 600
else:
# overlap 夠高了,開始嚴格要求
if null_ratio > 0.5:
score -= null_vox * 1500
else:
score -= null_vox * 800
# 4. 反向圓柱懲罰
score -= null_vox2 * 1000
# 5. 尺寸合理性(放寬)
if diameter is not None and length is not None:
if diameter < 2.5 or diameter > 6.0: # 放寬從 (3.0, 5.5) 到 (2.5, 6.0)
score -= 3000
if length < 25 or length > 60: # 放寬從 (30, 55) 到 (25, 60)
score -= 3000
# 6. 尖端 breach 懲罰
if cylinder_tip_torch is not None:
tip_total = cylinder_tip_torch.sum().item()
if tip_total > 0:
tip_breach = ((cortical_tensor == 0) & (cylinder_tip_torch == 1)).sum().item()
tip_breach_ratio = tip_breach / tip_total
if tip_breach_ratio > 0:
score -= tip_breach * 5000 # 尖端出界懲罰要比一般 null_vox 重很多
return float(-score)
2026-04-10 05:25:27 +00:00
def cl_score_torch(
cortical_tensor: torch.Tensor,
spine_tensor: torch.Tensor,
cylinder_torch: torch.Tensor,
cylinder_o_torch: torch.Tensor,
intersections: int,
diameter: float = None,
length: float = None,
cylinder_tip_torch: torch.Tensor = None # 新增:尖端 mask
) -> float:
"""
漸進式評分優先確保找到骨頭再改善細節
"""
cyl_total = cylinder_torch.sum().item()
overlap = ((cortical_tensor == 1) & (cylinder_torch == 1)).sum().item()
null_vox = ((cortical_tensor == 0) & (cylinder_torch == 1)).sum().item()
null_vox2 = ((spine_tensor == 1) & (cylinder_o_torch == 1)).sum().item()
if cyl_total == 0:
return float(1e9) # 極差的情況
overlap_ratio = overlap / cyl_total
score = 0
# === 階段 1首要目標是找到骨頭overlap > 0 ===
if overlap == 0:
# 完全沒有 overlap 是最糟糕的情況
score -= 500000 # 超大懲罰
# 如果連 spine 都沒穿過,更糟
if intersections == 0:
score -= 500000
return float(-score)
# === 階段 2有找到骨頭了開始改善品質 ===
# 1. Overlap 獎勵(非線性,鼓勵快速提升)
if overlap_ratio < 0.1:
# 0-10%:每增加 1% 給大量獎勵(鼓勵探索)
score += overlap * 5000 # 很高的單位獎勵
elif overlap_ratio < 0.3:
# 10-30%:中等獎勵
score += overlap * 3000
elif overlap_ratio < 0.5:
# 30-50%:正常獎勵
score += overlap * 2000
else:
# 50%+:獎勵 + 額外比例獎勵
score += overlap * 2000
score += (overlap_ratio - 0.5) * 100000 # 超過 50% 額外大獎
# 2. Intersection 控制(稍微放寬)
if intersections == 1:
score += 20000 # 完美
elif intersections == 0:
score -= 200000 # 嚴重錯誤(但比完全沒 overlap 好)
elif intersections == 2:
score -= 10000 # 可接受但不理想
else:
score -= intersections * 15000
# 3. Null voxel 懲罰(漸進式)
null_ratio = null_vox / cyl_total
if overlap_ratio < 0.2:
# 如果 overlap 還很少,對 null voxel 寬容一點
score -= null_vox * 300
elif overlap_ratio < 0.4:
score -= null_vox * 600
else:
# overlap 夠高了,開始嚴格要求
if null_ratio > 0.5:
score -= null_vox * 1500
else:
score -= null_vox * 800
# 4. 反向圓柱懲罰
score -= null_vox2 * 1000
# 5. 尺寸合理性(放寬)
if diameter is not None and length is not None:
if diameter < 2.5 or diameter > 6.0: # 放寬從 (3.0, 5.5) 到 (2.5, 6.0)
score -= 3000
if length < 25 or length > 60: # 放寬從 (30, 55) 到 (25, 60)
score -= 3000
# 6. 尖端 breach 懲罰
if cylinder_tip_torch is not None:
tip_total = cylinder_tip_torch.sum().item()
if tip_total > 0:
tip_breach = ((cortical_tensor == 0) & (cylinder_tip_torch == 1)).sum().item()
tip_breach_ratio = tip_breach / tip_total
if tip_breach_ratio > 0:
score -= tip_breach * 5000 # 尖端出界懲罰要比一般 null_vox 重很多
return float(-score)
def get_overlap_ratio(
position_params: list,
diameter: float,
length: float,
cortical_tensor: torch.Tensor,
image_shape: tuple,
spacing: list,
device: torch.device,
grid=None
) -> float:
"""
計算 Cylinder Cortical Bone 的重疊比例 (%)
"""
# 生成 Cylinder Mask
cyl_mask = generate_cylinder_n_torch(
diameter, length,
position_params[0], position_params[1], position_params[2],
position_params[3], position_params[4],
image_shape, spacing, device, grid
)
# 計算體積 (Voxel count)
cyl_vol = torch.sum(cyl_mask).item()
if cyl_vol == 0:
return 0.0
# 計算重疊部分
# 注意:這裡使用 cortical_tensor (與 cl_score_torch 邏輯一致)
overlap_count = ((cortical_tensor == 1) & (cyl_mask == 1)).sum().item()
return (overlap_count / cyl_vol) * 100.0
def compute_overlap_ratio_from_cylinder_mask(cyl_mask: torch.Tensor,
spine_mask: torch.Tensor,
eps: float = 1e-6) -> float:
"""
一個常見定義 overlap = intersection / cylinder_volume
你也可以改成 intersection / spine_volume Dice依你論文/需求一致即可
cyl_mask, spine_mask: uint8/bool tensor, same shape
"""
cyl = cyl_mask.bool()
sp = spine_mask.bool()
inter = (cyl & sp).sum().item()
denom = cyl.sum().item()
return float(inter) / float(denom + eps)
def is_solution_ok(loss: float, overlap: float, overlap_thresh: float = OVERLAP_THRESH) -> bool:
return (loss <= 0) and (overlap >= overlap_thresh)