227 lines
7.2 KiB
Python
227 lines
7.2 KiB
Python
#!/usr/bin/env python
|
||
"""
简单BEV分割可视化 - 不依赖mmdet3d库
Simple BEV segmentation visualisation that does not depend on the mmdet3d library.
"""
|
||
import os
|
||
import pickle
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
|
||
def load_results_safely(filepath):
    """Load a pickled inference-results file without raising.

    Unpickling imports whatever modules the stored objects require (for
    example ``torch`` if the file contains tensors). This function does not
    try to block those imports. Temporarily removing packages from
    ``sys.modules`` cannot prevent them, because ``pickle`` simply re-imports
    them. Restoring the removed entries afterwards can also leave a mix of
    old and new module objects.

    Args:
        filepath: path of the ``.pkl`` results file.

    Returns:
        ``(results, True)`` on success, or ``(None, False)`` if the file could
        not be opened or unpickled (the error is printed).
    """
    # SECURITY: pickle can execute arbitrary code while loading; only use this
    # on result files produced by a trusted inference run.
    try:
        with open(filepath, 'rb') as f:
            results = pickle.load(f)
    except Exception as e:  # best-effort loader: report and signal failure to the caller
        print(f"标准加载失败: {e}")
        return None, False
    return results, True
|
||
|
||
def visualize_bev_masks(masks_bev, sample_idx=0, save_dir='bev_visualization'):
|
||
"""可视化BEV分割结果"""
|
||
os.makedirs(save_dir, exist_ok=True)
|
||
|
||
# 类别名称
|
||
class_names = ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider']
|
||
colors = ['green', 'yellow', 'orange', 'red', 'purple', 'black']
|
||
|
||
# 如果是torch tensor,转换为numpy
|
||
if hasattr(masks_bev, 'cpu'):
|
||
masks_bev = masks_bev.cpu().numpy()
|
||
|
||
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
|
||
fig.suptitle(f'BEV分割结果 - 样本 {sample_idx}', fontsize=16)
|
||
|
||
for i in range(6):
|
||
row, col = i // 3, i % 3
|
||
ax = axes[row, col]
|
||
|
||
mask = masks_bev[i]
|
||
if hasattr(mask, 'cpu'):
|
||
mask = mask.cpu().numpy()
|
||
|
||
# 显示mask (0-1之间的概率)
|
||
im = ax.imshow(mask, cmap='viridis', vmin=0, vmax=1)
|
||
ax.set_title(f'{class_names[i]}\n激活像素: {(mask > 0.5).sum()}')
|
||
ax.axis('off')
|
||
|
||
# 计算统计信息
|
||
activated_pixels = (mask > 0.5).sum()
|
||
total_pixels = mask.size
|
||
percentage = activated_pixels / total_pixels * 100
|
||
print("15s")
|
||
|
||
# 添加颜色条
|
||
plt.tight_layout()
|
||
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
|
||
fig.colorbar(im, cax=cbar_ax, label='预测概率')
|
||
|
||
# 保存图像
|
||
save_path = os.path.join(save_dir, f'bev_segmentation_sample_{sample_idx}.png')
|
||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||
print(f"\nBEV分割可视化已保存: {save_path}")
|
||
plt.show()
|
||
|
||
# 创建叠加视图
|
||
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
|
||
ax.set_title(f'BEV分割叠加视图 - 样本 {sample_idx}')
|
||
|
||
# 创建彩色叠加图像
|
||
height, width = masks_bev[0].shape
|
||
overlay = np.zeros((height, width, 3), dtype=np.uint8)
|
||
|
||
for i in range(6):
|
||
mask = masks_bev[i]
|
||
if hasattr(mask, 'cpu'):
|
||
mask = mask.cpu().numpy()
|
||
|
||
# 只显示高置信度区域
|
||
binary_mask = mask > 0.5
|
||
|
||
# 为每个类别分配颜色
|
||
if class_names[i] == 'drivable_area':
|
||
color = [0, 255, 0] # 绿色
|
||
elif class_names[i] == 'ped_crossing':
|
||
color = [255, 255, 0] # 黄色
|
||
elif class_names[i] == 'walkway':
|
||
color = [255, 165, 0] # 橙色
|
||
elif class_names[i] == 'stop_line':
|
||
color = [255, 0, 0] # 红色
|
||
elif class_names[i] == 'carpark_area':
|
||
color = [128, 0, 128] # 紫色
|
||
else: # divider
|
||
color = [0, 0, 0] # 黑色
|
||
|
||
# 应用颜色 (只在mask区域)
|
||
for c in range(3):
|
||
overlay[binary_mask, c] = color[c]
|
||
|
||
ax.imshow(overlay)
|
||
ax.set_title(f'BEV分割叠加 (高置信度区域) - 样本 {sample_idx}')
|
||
ax.axis('off')
|
||
|
||
# 添加图例
|
||
legend_elements = []
|
||
for i, name in enumerate(class_names):
|
||
if class_names[i] == 'drivable_area':
|
||
color = [0, 255, 0]
|
||
elif class_names[i] == 'ped_crossing':
|
||
color = [255, 255, 0]
|
||
elif class_names[i] == 'walkway':
|
||
color = [255, 165, 0]
|
||
elif class_names[i] == 'stop_line':
|
||
color = [255, 0, 0]
|
||
elif class_names[i] == 'carpark_area':
|
||
color = [128, 0, 128]
|
||
else:
|
||
color = [0, 0, 0]
|
||
|
||
import matplotlib.patches as mpatches
|
||
patch = mpatches.Patch(color=np.array(color)/255, label=name)
|
||
legend_elements.append(patch)
|
||
|
||
ax.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc='upper left')
|
||
|
||
overlay_path = os.path.join(save_dir, f'bev_overlay_sample_{sample_idx}.png')
|
||
plt.savefig(overlay_path, dpi=300, bbox_inches='tight')
|
||
print(f"BEV叠加可视化已保存: {overlay_path}")
|
||
plt.show()
|
||
|
||
def analyze_3d_detection(sample, sample_idx=0):
    """Print a summary of the 3D detection outputs of one sample.

    Looks up the optional keys ``'boxes_3d'``, ``'scores_3d'`` and
    ``'labels_3d'`` in *sample*. Missing or empty entries are skipped.

    Args:
        sample: dict-like result for one sample.
        sample_idx: index shown in the printed header only.
    """

    def to_array(value):
        # NOTE(review): box containers that wrap a tensor (e.g. mmdet3d box
        # classes, which expose ``.tensor``) are unwrapped first -- confirm.
        if hasattr(value, 'tensor'):
            value = value.tensor
        if hasattr(value, 'cpu'):
            value = value.cpu().numpy()
        return np.asarray(value)

    print(f"\n=== 3D检测结果分析 - 样本 {sample_idx} ===")

    if 'boxes_3d' in sample:
        boxes_3d = sample['boxes_3d']
        print(f"3D检测框数量: {len(boxes_3d)}")
        if len(boxes_3d) > 0:
            boxes_3d = to_array(boxes_3d)
            print(f"检测框形状: {boxes_3d.shape}")
            # Columns are assumed to be (x, y, z, w, l, h, rot, ...);
            # report the value range of the first three.
            if boxes_3d.ndim == 2 and boxes_3d.shape[1] >= 7:
                for axis, axis_name in enumerate(('X', 'Y', 'Z')):
                    coords = boxes_3d[:, axis]
                    print(f"{axis_name} 范围: [{coords.min():.2f}, {coords.max():.2f}]")

    if 'scores_3d' in sample:
        scores_3d = sample['scores_3d']
        if len(scores_3d) > 0:
            scores_3d = to_array(scores_3d)
            print(f"置信度: 最小 {scores_3d.min():.3f}, 最大 {scores_3d.max():.3f}, "
                  f"平均 {scores_3d.mean():.3f}")

    if 'labels_3d' in sample:
        labels_3d = sample['labels_3d']
        if len(labels_3d) > 0:
            labels_3d = to_array(labels_3d)
            unique_labels, counts = np.unique(labels_3d, return_counts=True)
            print("检测类别分布:")
            for label, count in zip(unique_labels, counts):
                print(f" 类别 {int(label)}: {count} 个")
|
||
|
||
def main(result_file='/data/infer_test/20251120_124755/one_batch_results.pkl',
         sample_idx=0, save_dir='bev_visualization'):
    """Load an inference-results pickle, then visualise and analyse one sample.

    Every parameter defaults to the value the script previously hard-coded, so
    ``main()`` with no arguments behaves as before.

    Args:
        result_file: path of the pickled results. Presumably a sequence of
            dict-like per-sample results; it is indexed and each sample's
            ``.keys()`` is read below.
        sample_idx: index of the sample to inspect; must be non-negative.
        save_dir: directory where the BEV figures are written.
    """
    print("开始分析推理结果...")
    print(f"结果文件: {result_file}")

    results, success = load_results_safely(result_file)
    if not success or results is None:
        print("无法加载结果文件")
        return

    print(f"成功加载! 包含 {len(results)} 个样本")

    # Reject negative indices too: they would silently select from the end.
    if not 0 <= sample_idx < len(results):
        print(f"样本索引 {sample_idx} 超出范围")
        return

    sample = results[sample_idx]
    print(f"\n样本 {sample_idx} 包含字段: {list(sample.keys())}")

    # BEV segmentation visualisation.
    if 'masks_bev' in sample:
        print("\n开始BEV分割可视化...")
        visualize_bev_masks(sample['masks_bev'], sample_idx, save_dir=save_dir)

    # 3D detection summary.
    analyze_3d_detection(sample, sample_idx)

    # Ground truth, if the results file includes it.
    if 'gt_masks_bev' in sample:
        gt_masks = sample['gt_masks_bev']
        print(f"\nGround Truth BEV分割形状: {gt_masks.shape}")
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the analysis with its default settings.
    main()
|