Source code for torchdrivesim._iou_utils

"""
torch implementation of 2d oriented box intersection

author: Lanxiao Li
copied from https://github.com/lilanxiao/Rotated_IoU
2020.8

MIT License

Copyright (c) 2020 Lanxiao Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from typing import Tuple

import numpy as np
import torch

# Small positive offset added to the line-intersection denominator in
# `box_intersection_th` so that the division stays finite for parallel edges.
EPSILON: float = 1e-8
def precision_rounding(x: torch.Tensor, n_digits: int = 6) -> torch.Tensor:
    """Round a tensor to a fixed number of decimal digits.

    Args:
        x: tensor to round (any shape).
        n_digits: number of digits kept after the decimal point.

    Returns:
        Tensor of the same shape as `x`, rounded to `n_digits` decimals.
    """
    scale = 10 ** n_digits
    scaled = x * scale
    return scaled.round() / scale
def box_intersection_th(corners1: torch.Tensor, corners2: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Find the pairwise edge intersection points of two sets of rectangles.

    Convention: if two edges are collinear, there is no intersection point.

    Args:
        corners1: B, N, 4, 2
        corners2: B, N, 4, 2

    Returns:
        A tuple (intersections, mask) where:
            intersections: B, N, 4, 4, 2 (zero wherever `mask` is False)
            mask: B, N, 4, 4; bool
    """
    # Each edge joins a corner with its successor; stored as (x_a, y_a, x_b, y_b).
    successor = [1, 2, 3, 0]
    edges1 = torch.cat([corners1, corners1[:, :, successor, :]], dim=3)  # B, N, 4, 4: Batch, Box, edge, point
    edges2 = torch.cat([corners2, corners2[:, :, successor, :]], dim=3)

    # Pair every edge of box 1 with every edge of box 2.
    # (B, N, 4, 4) -> (B, N, 4, 4, 4): Batch, Box, edge1, edge2, point
    paired1 = edges1.unsqueeze(3).repeat([1, 1, 1, 4, 1])
    paired2 = edges2.unsqueeze(2).repeat([1, 1, 4, 1, 1])
    x1, y1, x2, y2 = paired1.unbind(dim=-1)
    x3, y3, x4, y4 = paired2.unbind(dim=-1)

    # math: https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection
    num = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    den_t = (x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)
    den_u = (x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3)

    # Near-parallel edge pairs get an out-of-range parameter so they never count.
    near_parallel = torch.abs(num) < 1e-4
    t = (den_t / num).masked_fill(near_parallel, -1.)
    u = (-den_u / num).masked_fill(near_parallel, -1.)
    on_segment1 = (t > 0) & (t < 1)  # intersection on line segment 1
    on_segment2 = (u > 0) & (u < 1)  # intersection on line segment 2
    mask = on_segment1 & on_segment2

    # Recompute t with EPSILON in the denominator; otherwise numerically unstable.
    t_stable = den_t / (num + EPSILON)
    points = torch.stack([x1 + t_stable * (x2 - x1), y1 + t_stable * (y2 - y1)], dim=-1)
    intersections = points * mask.float().unsqueeze(-1)
    return intersections, mask
def box1_in_box2(corners1: torch.Tensor, corners2: torch.Tensor) -> torch.Tensor:
    """Check which corners of box1 lie inside box2.

    Convention: a corner lying exactly on an edge of the other box counts as inside.

    Args:
        corners1: (B, N, 4, 2)
        corners2: (B, N, 4, 2)

    Returns:
        c1_in_2: (B, N, 4) Bool
    """
    origin = corners2[:, :, 0:1, :]             # corner a, (B, N, 1, 2)
    side_ab = corners2[:, :, 1:2, :] - origin   # (B, N, 1, 2)
    side_ad = corners2[:, :, 3:4, :] - origin   # (B, N, 1, 2)
    rel = corners1 - origin                     # (B, N, 4, 2)

    # Normalised projections of each corner onto the two sides meeting at a.
    # NOTE: dividing by the squared side length keeps the test stable when the
    # two boxes are exactly the same and across different box scales.
    frac_ab = precision_rounding((side_ab * rel).sum(dim=-1) / (side_ab * side_ab).sum(dim=-1))
    frac_ad = precision_rounding((side_ad * rel).sum(dim=-1) / (side_ad * side_ad).sum(dim=-1))

    within_ab = (frac_ab > -1e-6) & (frac_ab < 1 + 1e-6)   # (B, N, 4)
    within_ad = (frac_ad > -1e-6) & (frac_ad < 1 + 1e-6)   # (B, N, 4)
    return within_ab & within_ad
def box_in_box_th(corners1: torch.Tensor, corners2: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Check, in both directions, which corners of one box lie inside the other.

    Args:
        corners1: (B, N, 4, 2)
        corners2: (B, N, 4, 2)

    Returns:
        A tuple (c1_in_2, c2_in_1) where:
            c1_in_2: (B, N, 4) Bool. i-th corner of box1 in box2
            c2_in_1: (B, N, 4) Bool. i-th corner of box2 in box1
    """
    return box1_in_box2(corners1, corners2), box1_in_box2(corners2, corners1)
def build_vertices(corners1: torch.Tensor, corners2: torch.Tensor,
                   c1_in_2: torch.Tensor, c2_in_1: torch.Tensor,
                   inters: torch.Tensor, mask_inter: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Collect all candidate vertices of the intersection polygon.

    Args:
        corners1: (B, N, 4, 2)
        corners2: (B, N, 4, 2)
        c1_in_2: Bool, (B, N, 4)
        c2_in_1: Bool, (B, N, 4)
        inters: (B, N, 4, 4, 2)
        mask_inter: (B, N, 4, 4)

    Returns:
        A tuple (vertices, mask) where:
            vertices: (B, N, 24, 2) candidate vertices, ordered as the 4 corners of
                box1, the 4 corners of box2, then the 16 edge intersections.
                Only some elements are valid.
            mask: (B, N, 24) Bool, marks the valid elements of `vertices`
    """
    # NOTE: invalid entries of `inters` are zero and carry zero gradient (they were
    # masked by multiplying with 0), which later code uses as padding.
    B, N = corners1.shape[:2]
    flat_inters = inters.view([B, N, -1, 2])        # (B, N, 16, 2)
    flat_inter_mask = mask_inter.view([B, N, -1])   # (B, N, 16)
    vertices = torch.cat((corners1, corners2, flat_inters), dim=2)   # (B, N, 4+4+16, 2)
    mask = torch.cat((c1_in_2, c2_in_1, flat_inter_mask), dim=2)     # Bool (B, N, 4+4+16)
    return vertices, mask
@torch.no_grad()
def sort_indices(vertices: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Order the valid vertices of each intersection polygon by polar angle.

    Args:
        vertices: float (B, N, 24, 2)
        mask: bool (B, N, 24). Not modified by this function.

    Returns:
        Integer index tensor (B, N, 9) into the vertex dimension of `vertices`.

    Note:
        why 9? the polygon has maximal 8 vertices. +1 to duplicate the first element.
        the index has the following structure:
            (A, B, C, ... , A, X, X, X)
        where X is the index of an arbitrary invalid element among the last 16
        entries (intersections, not corners). Those entries have value 0 and mask
        False, so they contribute zero value and zero gradient.
        Polygons with fewer than 3 valid vertices are filled entirely with X.
    """
    B, N = vertices.shape[:2]
    flat_vertices = vertices.reshape(-1, 24, 2)
    # Work on a private copy: vertices are dropped from the mask below and the
    # caller's tensor must not change as a side effect.
    flat_mask = mask.reshape(-1, 24).clone()
    num_valid = flat_mask.sum(dim=1)  # (B*N,)

    # Polar angle in [0, 2*pi) of each vertex around the mean of the valid vertices.
    weights = flat_mask.float().unsqueeze(-1)
    center = (flat_vertices * weights).sum(dim=1, keepdim=True) / num_valid.view(-1, 1, 1)
    offset = flat_vertices - center
    radius = torch.sqrt(offset.pow(2).sum(dim=-1))
    acos = torch.arccos(offset[..., 0] / radius)
    angles = torch.where(offset[..., 1] > 0, acos, 2 * np.pi - acos)
    angles[~flat_mask] = float('inf')  # invalid vertices sort to the end
    inds_sorted = torch.argsort(angles, descending=False)

    # More than 8 valid vertices can only come from vertices lying (almost) on top
    # of each other. Repeatedly drop, per offending polygon, the first vertex of
    # the closest pair of angular neighbours until at most 8 remain.
    pair_slot = torch.arange(23, device=flat_mask.device).unsqueeze(0)
    too_many = num_valid > 8
    while too_many.any():
        gather_idx = inds_sorted.unsqueeze(-1).expand(*inds_sorted.shape, 2)
        sorted_vertices = torch.gather(flat_vertices, dim=1, index=gather_idx)
        dist = (sorted_vertices[:, :-1] - sorted_vertices[:, 1:]).norm(dim=-1)
        # Only pairs made of two valid vertices may be selected.
        dist[pair_slot >= (num_valid - 1).unsqueeze(1)] = float('inf')
        closest = dist.argmin(dim=-1, keepdim=True)
        to_remove = torch.gather(inds_sorted, dim=1, index=closest).squeeze(-1)

        rows = torch.nonzero(too_many, as_tuple=True)[0]
        cols = to_remove[rows]
        assert bool(flat_mask[rows, cols].all()), "attempted to remove an invalid vertex"
        flat_mask[rows, cols] = False

        num_valid = flat_mask.sum(dim=1)
        angles[~flat_mask] = float('inf')
        inds_sorted = torch.argsort(angles, descending=False)
        too_many = num_valid > 8

    index = inds_sorted[:, :9]
    # First invalid entry among the 16 intersection slots, used as padding index.
    pad_values = torch.argmin(flat_mask[:, 8:].float(), dim=-1) + 8

    degenerate = num_valid < 3  # no polygon: use padding everywhere
    if degenerate.any():
        index[degenerate] = pad_values[degenerate].unsqueeze(-1)
    trailing = torch.arange(9, device=flat_mask.device) >= num_valid.unsqueeze(-1)
    if trailing.any():
        index[trailing] = pad_values.unsqueeze(-1).expand_as(index)[trailing]

    # Close the polygon: repeat the first index right after the last valid one.
    first_elements = index[:, 0].clone()
    index[torch.arange(len(index), device=index.device), num_valid] = first_elements
    return index.view(B, N, 9)
def calculate_area(idx_sorted: torch.Tensor, vertices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the intersection polygon area with the shoelace formula.

    Args:
        idx_sorted (torch.Tensor): (B, N, 9) indices of the polygon vertices in order
        vertices (torch.Tensor): (B, N, 24, 2)

    Returns:
        Tuple of (area, selected), where
            area: (B, N), area of intersection
            selected: (B, N, 9, 2), vertices of polygon with zero padding
    """
    # Use the same index for the x and the y coordinate of every selected vertex.
    gather_idx = idx_sorted.unsqueeze(-1).repeat([1, 1, 1, 2])
    selected = torch.gather(vertices, 2, gather_idx)

    current = selected[:, :, 0:-1]
    following = selected[:, :, 1:]
    cross = current[..., 0] * following[..., 1] - current[..., 1] * following[..., 0]
    area = torch.abs(cross.sum(dim=2)) / 2
    return area, selected
def oriented_box_intersection_2d(corners1: torch.Tensor, corners2: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the intersection area of pairs of 2D oriented rectangles.

    Args:
        corners1: (B, N, 4, 2)
        corners2: (B, N, 4, 2)

    Returns:
        Tuple of (area, selected):
            area: (B, N), area of intersection
            selected: (B, N, 9, 2), vertices of polygon with zero padding
    """
    # Candidate vertices: edge/edge intersections plus corners contained in the other box.
    edge_points, edge_mask = box_intersection_th(corners1, corners2)
    c1_in_2, c2_in_1 = box_in_box_th(corners1, corners2)
    vertices, valid = build_vertices(corners1, corners2, c1_in_2, c2_in_1, edge_points, edge_mask)

    # Order the valid vertices around the polygon, then apply the shoelace formula.
    order = sort_indices(vertices, valid)
    area, selected = calculate_area(order, vertices)
    return area, selected
def box2corners_th(box: torch.Tensor) -> torch.Tensor:
    """Convert box parameters to corner coordinates.

    The corners are produced in the order (+w/2, +h/2), (-w/2, +h/2),
    (-w/2, -h/2), (+w/2, -h/2) in the box frame, then rotated by `alpha`
    and translated to the box centre.

    Args:
        box: (B, N, 5) with x, y, w, h, alpha. Any number of leading batch
            dimensions (including none or an empty batch) is accepted.

    Returns:
        (B, N, 4, 2) corners, with the same leading dimensions, dtype and
        device as `box`.
    """
    center = box[..., 0:2]          # (..., 2)
    half_w = box[..., 2:3] / 2      # (..., 1)
    half_h = box[..., 3:4] / 2
    alpha = box[..., 4:5]

    # Building the corner offsets from the box itself keeps them on the box's
    # device and dtype, with no per-call host allocation or float32 template.
    local_x = torch.cat([half_w, -half_w, -half_w, half_w], dim=-1)   # (..., 4)
    local_y = torch.cat([half_h, half_h, -half_h, -half_h], dim=-1)   # (..., 4)
    corners = torch.stack([local_x, local_y], dim=-1)                 # (..., 4, 2)

    sin = torch.sin(alpha)
    cos = torch.cos(alpha)
    row1 = torch.cat([cos, sin], dim=-1)
    row2 = torch.cat([-sin, cos], dim=-1)       # (..., 2)
    rot_T = torch.stack([row1, row2], dim=-2)   # (..., 2, 2), transposed rotation

    rotated = torch.matmul(corners, rot_T)      # (..., 4, 2)
    return rotated + center.unsqueeze(-2)
def box2corners_with_rear_factor(box: torch.Tensor, rear_factor: float = 1) -> torch.Tensor:
    """Convert box parameters to the corners of the rear part of the box.

    The returned rectangle starts at the rear edge of the box (the -w side in
    the box frame) and covers `rear_factor * w` of its length. If the rear
    factor is 1, the whole bounding box is returned.

    Args:
        box: (B, N, 5) with x, y, w, h, alpha. Any number of leading batch
            dimensions (including none or an empty batch) is accepted.
        rear_factor: The relative amount of the bounding box that is preserved,
            starting from the rear corners and up to rear_factor * w.
            If the factor is 1, then the whole bounding box is preserved.

    Returns:
        (B, N, 4, 2) corners, with the same leading dimensions, dtype and
        device as `box`.
    """
    center = box[..., 0:2]          # (..., 2)
    w = box[..., 2:3]
    h = box[..., 3:4]
    alpha = box[..., 4:5]

    # Corner offsets are derived from the box itself so that they share its
    # device and dtype, with no per-call host allocation or float32 template.
    half_len = (w / 2) * rear_factor   # half length of the preserved part
    half_h = h / 2
    local_x = torch.cat([half_len, -half_len, -half_len, half_len], dim=-1)   # (..., 4)
    local_y = torch.cat([half_h, half_h, -half_h, -half_h], dim=-1)           # (..., 4)
    corners = torch.stack([local_x, local_y], dim=-1)                         # (..., 4, 2)

    sin = torch.sin(alpha)
    cos = torch.cos(alpha)
    row1 = torch.cat([cos, sin], dim=-1)
    row2 = torch.cat([-sin, cos], dim=-1)       # (..., 2)
    rot_T = torch.stack([row1, row2], dim=-2)   # (..., 2, 2), transposed rotation

    rotated = torch.matmul(corners, rot_T)      # (..., 4, 2)

    # The shortened rectangle is centred on the box centre; shift it backwards
    # along the box's length axis so that its rear edge matches the box's.
    center_correction = torch.cat([(w * (1 - rear_factor)) / 2, torch.zeros_like(h)], dim=-1)
    center_correction = torch.matmul(center_correction.unsqueeze(-2), rot_T)   # (..., 1, 2)
    return rotated + (center.unsqueeze(-2) - center_correction)
def iou_differentiable_fast(box1: torch.Tensor, box2: torch.Tensor) -> torch.Tensor:
    """Compute a differentiable (approximate) IoU between two sets of boxes.

    B is the batch size and N is the number of bounding boxes. A bounding box is
    described by 5 values in this order: (x, y, w, h, alpha). The overlap of the
    two rectangles is a convex polygon whose area is obtained with the shoelace
    formula (https://en.wikipedia.org/wiki/Shoelace_formula). This offers fast
    computation at the expense of some IoU accuracy.

    Args:
        box1: (B, N, 5)
        box2: (B, N, 5)

    Returns:
        iou: (B, N)
    """
    inter_area, _ = oriented_box_intersection_2d(
        box2corners_th(box1), box2corners_th(box2)
    )  # (B, N)
    # Box areas follow directly from width * height.
    union = box1[:, :, 2] * box1[:, :, 3] + box2[:, :, 2] * box2[:, :, 3] - inter_area
    return inter_area / union
def iou_non_differentiable(boxes: torch.Tensor) -> torch.Tensor:
    """Compute the pairwise IoU between all boxes using pytorch3d.

    Each 2D box is extruded into a 3D box spanning z = 0 to z = 1 and passed
    to `pytorch3d.ops.box3d_overlap`.

    Args:
        boxes: (N, 5)

    Returns:
        iou: (N, N)
    """
    from pytorch3d.ops import box3d_overlap

    corners = box2corners_th(boxes.unsqueeze(0)).squeeze(0)              # (N, 4, 2)
    bottom_face = torch.nn.functional.pad(corners, (0, 1))               # z = 0
    top_face = torch.nn.functional.pad(corners, (0, 1), value=1.0)       # z = 1
    corners_3d = torch.cat([bottom_face, top_face], dim=1)               # (N, 8, 3)
    _, iou = box3d_overlap(corners_3d, corners_3d)
    return iou