import pickle
from os import path as osp

import mmcv
import numpy as np
from mmcv import track_iter_progress
from mmcv.ops import roi_align
from pycocotools import mask as maskUtils
from pycocotools.coco import COCO

from mmdet3d.core.bbox import box_np_ops as box_np_ops
from mmdet3d.datasets import build_dataset
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps


def _poly2mask(mask_ann, img_h, img_w):
    """Decode a COCO-style mask annotation (polygon, uncompressed RLE or RLE)
    into a binary ``(img_h, img_w)`` mask."""
    if isinstance(mask_ann, list):
        # polygon -- a single object might consist of multiple parts
        # we merge all parts into one mask rle code
        rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
        rle = maskUtils.merge(rles)
    elif isinstance(mask_ann["counts"], list):
        # uncompressed RLE
        rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
    else:
        # rle
        rle = mask_ann
    mask = maskUtils.decode(rle)
    return mask

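
# Illustrative sketch (not part of the original module): _poly2mask above accepts
# the three annotation formats used by COCO-style datasets. A hypothetical polygon
# annotation could be decoded like this:
#
#     toy_polygon = [[0.0, 0.0, 4.0, 0.0, 4.0, 4.0, 0.0, 4.0]]  # one square part
#     binary_mask = _poly2mask(toy_polygon, img_h=8, img_w=8)   # uint8 array, shape (8, 8)
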

def _parse_coco_ann_info(ann_info):
    """Collect boxes, ignore-boxes and segmentations from raw COCO annotations."""
    gt_bboxes = []
    gt_labels = []
    gt_bboxes_ignore = []
    gt_masks_ann = []

    for i, ann in enumerate(ann_info):
        if ann.get("ignore", False):
            continue
        x1, y1, w, h = ann["bbox"]
        if ann["area"] <= 0:
            continue
        bbox = [x1, y1, x1 + w, y1 + h]
        if ann.get("iscrowd", False):
            gt_bboxes_ignore.append(bbox)
        else:
            gt_bboxes.append(bbox)
            gt_masks_ann.append(ann["segmentation"])

    if gt_bboxes:
        gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
        gt_labels = np.array(gt_labels, dtype=np.int64)
    else:
        gt_bboxes = np.zeros((0, 4), dtype=np.float32)
        gt_labels = np.array([], dtype=np.int64)

    if gt_bboxes_ignore:
        gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
    else:
        gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

    ann = dict(bboxes=gt_bboxes, bboxes_ignore=gt_bboxes_ignore, masks=gt_masks_ann)

    return ann

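
# Note on _parse_coco_ann_info above: only ``bboxes``, ``bboxes_ignore`` and
# ``masks`` are returned; ``gt_labels`` is assembled but not returned, since the
# caller in this file (create_groundtruth_database) only consumes boxes and masks.
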

def crop_image_patch_v2(pos_proposals, pos_assigned_gt_inds, gt_masks):
    """Crop fixed-size (28x28) mask patches for the given proposals via RoIAlign."""
    import torch
    from torch.nn.modules.utils import _pair

    device = pos_proposals.device
    num_pos = pos_proposals.size(0)
    fake_inds = torch.arange(num_pos, device=device).to(dtype=pos_proposals.dtype)[
        :, None
    ]
    rois = torch.cat([fake_inds, pos_proposals], dim=1)  # Nx5
    mask_size = _pair(28)
    rois = rois.to(device=device)
    gt_masks_th = (
        torch.from_numpy(gt_masks)
        .to(device)
        .index_select(0, pos_assigned_gt_inds)
        .to(dtype=rois.dtype)
    )
    # Using RoIAlign can apparently accelerate training (~0.1 s/iter).
    targets = roi_align(gt_masks_th, rois, mask_size[::-1], 1.0, 0, True).squeeze(1)
    return targets

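
# crop_image_patch_v2 above is the RoIAlign-based variant; within this file it is
# only referenced by a commented-out call inside create_groundtruth_database, and
# crop_image_patch below is the helper that is actually used. If it is ever
# revived, it may need adjustments (e.g. mmcv's roi_align expects an N x C x H x W
# input, so the stacked masks would likely need an explicit channel dimension).
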

def crop_image_patch(pos_proposals, gt_masks, pos_assigned_gt_inds, org_img):
    """Crop the masked image patch and the mask patch for each proposal box."""
    num_pos = pos_proposals.shape[0]
    masks = []
    img_patches = []
    for i in range(num_pos):
        gt_mask = gt_masks[pos_assigned_gt_inds[i]]
        bbox = pos_proposals[i, :].astype(np.int32)
        x1, y1, x2, y2 = bbox
        w = np.maximum(x2 - x1 + 1, 1)
        h = np.maximum(y2 - y1 + 1, 1)

        mask_patch = gt_mask[y1 : y1 + h, x1 : x1 + w]
        masked_img = gt_mask[..., None] * org_img
        img_patch = masked_img[y1 : y1 + h, x1 : x1 + w]

        img_patches.append(img_patch)
        masks.append(mask_patch)
    return img_patches, masks

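
# For each proposal, crop_image_patch above returns the bounding-box crop of the
# image masked by its matched instance mask plus the corresponding binary mask
# crop; both are written as .png files next to the per-object point cloud below.
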

def create_groundtruth_database(
    dataset_class_name,
    data_path,
    info_prefix,
    info_path=None,
    mask_anno_path=None,
    used_classes=None,
    database_save_path=None,
    db_info_save_path=None,
    relative_path=True,
    add_rgb=False,
    lidar_only=False,
    bev_only=False,
    coors_range=None,
    with_mask=False,
    load_augmented=None,
):
    """Given the raw data, generate the ground truth database.

    Args:
        dataset_class_name (str): Name of the input dataset.
        data_path (str): Path of the data.
        info_prefix (str): Prefix of the info file.
        info_path (str): Path of the info file.
            Default: None.
        mask_anno_path (str): Path of the 2D COCO-format mask annotation file,
            relative to ``data_path``. Default: None.
        used_classes (list[str]): Classes to keep in the database; objects of
            other classes are skipped. Default: None (keep all classes).
        database_save_path (str): Path to save database.
            Default: None.
        db_info_save_path (str): Path to save db_info.
            Default: None.
        relative_path (bool): Whether to use relative path.
            Default: True.
        add_rgb (bool): Currently unused by this function. Default: False.
        lidar_only (bool): Currently unused by this function. Default: False.
        bev_only (bool): Currently unused by this function. Default: False.
        coors_range (list | None): Currently unused by this function.
            Default: None.
        with_mask (bool): Whether to use mask.
            Default: False.
        load_augmented (str | None): If set, the ``NuScenesDataset`` pipeline
            loads 16-dim augmented points by passing ``load_augmented`` to the
            point-loading transforms. Default: None.
    """
    print(f"Create GT Database of {dataset_class_name}")
    dataset_cfg = dict(
        type=dataset_class_name, dataset_root=data_path, ann_file=info_path
    )
    if dataset_class_name == "KittiDataset":
        dataset_cfg.update(
            test_mode=False,
            split="training",
            modality=dict(
                use_lidar=True,
                use_depth=False,
                use_lidar_intensity=True,
                use_camera=with_mask,
            ),
            pipeline=[
                dict(
                    type="LoadPointsFromFile",
                    coord_type="LIDAR",
                    load_dim=4,
                    use_dim=4,
                ),
                dict(
                    type="LoadAnnotations3D",
                    with_bbox_3d=True,
                    with_label_3d=True,
                ),
            ],
        )

    elif dataset_class_name == "NuScenesDataset":
        if not load_augmented:
            dataset_cfg.update(
                use_valid_flag=True,
                pipeline=[
                    dict(
                        type="LoadPointsFromFile",
                        coord_type="LIDAR",
                        load_dim=5,
                        use_dim=5,
                    ),
                    dict(
                        type="LoadPointsFromMultiSweeps",
                        sweeps_num=10,
                        use_dim=[0, 1, 2, 3, 4],
                        pad_empty_sweeps=True,
                        remove_close=True,
                    ),
                    dict(
                        type="LoadAnnotations3D", with_bbox_3d=True, with_label_3d=True
                    ),
                ],
            )
        else:
            dataset_cfg.update(
                use_valid_flag=True,
                pipeline=[
                    dict(
                        type="LoadPointsFromFile",
                        coord_type="LIDAR",
                        load_dim=16,
                        use_dim=list(range(16)),
                        load_augmented=load_augmented,
                    ),
                    dict(
                        type="LoadPointsFromMultiSweeps",
                        sweeps_num=10,
                        load_dim=16,
                        use_dim=list(range(16)),
                        pad_empty_sweeps=True,
                        remove_close=True,
                        load_augmented=load_augmented,
                    ),
                    dict(
                        type="LoadAnnotations3D", with_bbox_3d=True, with_label_3d=True
                    ),
                ],
            )

    elif dataset_class_name == "WaymoDataset":
        dataset_cfg.update(
            test_mode=False,
            split="training",
            modality=dict(
                use_lidar=True,
                use_depth=False,
                use_lidar_intensity=True,
                use_camera=False,
            ),
            pipeline=[
                dict(
                    type="LoadPointsFromFile",
                    coord_type="LIDAR",
                    load_dim=6,
                    use_dim=5,
                ),
                dict(
                    type="LoadAnnotations3D",
                    with_bbox_3d=True,
                    with_label_3d=True,
                ),
            ],
        )

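    # Build the dataset with the loading pipeline configured above; each sample is
    # then visited once to carve out per-object point clouds (and, when with_mask
    # is set, image and mask patches) for the GT-sampling database.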
    dataset = build_dataset(dataset_cfg)

    if database_save_path is None:
        database_save_path = osp.join(data_path, f"{info_prefix}_gt_database")
    if db_info_save_path is None:
        db_info_save_path = osp.join(data_path, f"{info_prefix}_dbinfos_train.pkl")
    mmcv.mkdir_or_exist(database_save_path)
    all_db_infos = dict()
    if with_mask:
        coco = COCO(osp.join(data_path, mask_anno_path))
        imgIds = coco.getImgIds()
        file2id = dict()
        for i in imgIds:
            info = coco.loadImgs([i])[0]
            file2id.update({info["file_name"]: i})

    group_counter = 0
    for j in track_iter_progress(list(range(len(dataset)))):
        input_dict = dataset.get_data_info(j)
        dataset.pre_pipeline(input_dict)
        example = dataset.pipeline(input_dict)
        annos = example["ann_info"]
        image_idx = example["sample_idx"]
        points = example["points"].tensor.numpy()
        gt_boxes_3d = annos["gt_bboxes_3d"].tensor.numpy()
        names = annos["gt_names"]
        group_dict = dict()
        if "group_ids" in annos:
            group_ids = annos["group_ids"]
        else:
            group_ids = np.arange(gt_boxes_3d.shape[0], dtype=np.int64)
        difficulty = np.zeros(gt_boxes_3d.shape[0], dtype=np.int32)
        if "difficulty" in annos:
            difficulty = annos["difficulty"]

        num_obj = gt_boxes_3d.shape[0]
        point_indices = box_np_ops.points_in_rbbox(points, gt_boxes_3d)

        if with_mask:
            # prepare masks
            gt_boxes = annos["gt_bboxes"]
            img_path = osp.split(example["img_info"]["filename"])[-1]
            if img_path not in file2id:
                print(f"skip image {img_path} for empty mask")
                continue
            img_id = file2id[img_path]
            kins_annIds = coco.getAnnIds(imgIds=img_id)
            kins_raw_info = coco.loadAnns(kins_annIds)
            kins_ann_info = _parse_coco_ann_info(kins_raw_info)
            h, w = annos["img_shape"][:2]
            gt_masks = [_poly2mask(mask, h, w) for mask in kins_ann_info["masks"]]
            # get mask inds based on iou mapping
            bbox_iou = bbox_overlaps(kins_ann_info["bboxes"], gt_boxes)
            mask_inds = bbox_iou.argmax(axis=0)
            valid_inds = bbox_iou.max(axis=0) > 0.5

            # mask the image
            # use more precise crop when it is ready
            # object_img_patches = np.ascontiguousarray(
            #     np.stack(object_img_patches, axis=0).transpose(0, 3, 1, 2))
            # crop image patches using roi_align
            # object_img_patches = crop_image_patch_v2(
            #     torch.Tensor(gt_boxes),
            #     torch.Tensor(mask_inds).long(), object_img_patches)
            object_img_patches, object_masks = crop_image_patch(
                gt_boxes, gt_masks, mask_inds, annos["img"]
            )

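        # For every annotated box: gather the points inside it, shift them to the
        # box center, and dump them to a per-object .bin file; the matching db_info
        # entry (path, class name, 3D box, #points, ...) is grouped by class name.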
        for i in range(num_obj):
            filename = f"{image_idx}_{names[i]}_{i}.bin"
            abs_filepath = osp.join(database_save_path, filename)
            rel_filepath = osp.join(f"{info_prefix}_gt_database", filename)

            # save point clouds and image patches for each object
            gt_points = points[point_indices[:, i]]
            gt_points[:, :3] -= gt_boxes_3d[i, :3]

            if with_mask:
                if object_masks[i].sum() == 0 or not valid_inds[i]:
                    # Skip object for empty or invalid mask
                    continue
                img_patch_path = abs_filepath + ".png"
                mask_patch_path = abs_filepath + ".mask.png"
                mmcv.imwrite(object_img_patches[i], img_patch_path)
                mmcv.imwrite(object_masks[i], mask_patch_path)

            # write the raw point data in binary mode so it round-trips on every platform
            with open(abs_filepath, "wb") as f:
                gt_points.tofile(f)

            if (used_classes is None) or names[i] in used_classes:
                db_info = {
                    "name": names[i],
                    "path": rel_filepath,
                    "image_idx": image_idx,
                    "gt_idx": i,
                    "box3d_lidar": gt_boxes_3d[i],
                    "num_points_in_gt": gt_points.shape[0],
                    "difficulty": difficulty[i],
                }
                local_group_id = group_ids[i]
                # if local_group_id >= 0:
                if local_group_id not in group_dict:
                    group_dict[local_group_id] = group_counter
                    group_counter += 1
                db_info["group_id"] = group_dict[local_group_id]
                if "score" in annos:
                    db_info["score"] = annos["score"][i]
                if with_mask:
                    db_info.update({"box2d_camera": gt_boxes[i]})
                if names[i] in all_db_infos:
                    all_db_infos[names[i]].append(db_info)
                else:
                    all_db_infos[names[i]] = [db_info]

    for k, v in all_db_infos.items():
        print(f"load {len(v)} {k} database infos")

    with open(db_info_save_path, "wb") as f:
        pickle.dump(all_db_infos, f)
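

# Illustrative usage sketch (not part of the original module): in mmdet3d-style
# repos this function is normally driven by a data-preparation script such as
# tools/create_data.py. The dataset root and info-file paths below are hypothetical.
if __name__ == "__main__":
    create_groundtruth_database(
        dataset_class_name="NuScenesDataset",
        data_path="./data/nuscenes",  # hypothetical dataset root
        info_prefix="nuscenes",
        info_path="./data/nuscenes/nuscenes_infos_train.pkl",  # hypothetical info file
    )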