"""
Definitions of traffic controls. Currently, we support traffic lights, stop signs, and yield signs.
"""
from typing import List, Optional
import torch
from torch import Tensor
from torchdrivesim._iou_utils import oriented_box_intersection_2d, box2corners_th, box2corners_with_rear_factor
class BaseTrafficControl:
    """
    Base traffic control class. All traffic controls are static rectangular stoplines with additional discrete state.
    States are advanced by calling `step` and can be set either from a given recorded history of states or
    by `compute_state` if current timestep exceeds the given recorded history.

    Args:
        pos: BxNx5 stoplines tensor [x, y, length/thickness, width, orientation]
        allowed_states: Allowed values of state, e.g. colors for traffic lights.
        replay_states: BxNxT tensor of states to replay, represented as indices into `allowed_states`. Default T=0.
        mask: BxN boolean tensor indicating whether a given traffic control element is present and not a padding.
    """

    def __init__(self, pos: Tensor, allowed_states: Optional[List[str]] = None,
                 replay_states: Optional[Tensor] = None, mask: Optional[Tensor] = None):
        # `pos` has to be stored before anything else: the default factories below and the
        # corner computation all read `self.pos`.
        self.pos = pos
        self.allowed_states = allowed_states if allowed_states is not None else self._default_allowed_states()
        self.replay_states = replay_states if replay_states is not None else self._default_replay_states()
        self.mask = mask if mask is not None else self._default_mask()
        # Corner coordinates of every stopline box, computed once since stoplines are static.
        self.corners = box2corners_th(self.pos)
        # Reshape the BxN mask to BxNx1x1 so that it broadcasts over the trailing corner dimensions.
        corner_mask = self.mask.to(self.corners.dtype).reshape(self.mask.shape[0], self.mask.shape[1], 1, 1)
        self.corners = self.corners * corner_mask + (1 - corner_mask) * -1000  # Set masked bboxes far from center
        self.state = self._default_state()

    @classmethod
    def _default_allowed_states(cls) -> List[str]:
        """
        Returns the state names used when `allowed_states` is not given to the constructor.
        """
        return ['none']

    def _default_replay_states(self) -> Tensor:
        """
        Returns an empty (T=0) BxNx0 replay history located on the same device as `pos`.
        """
        return torch.zeros(*self.pos.shape[:2] + (0,), dtype=torch.long, device=self.pos.device)

    def _default_mask(self) -> Tensor:
        """
        Returns an all-True BxN mask, i.e. every traffic control element is treated as present.
        """
        return torch.ones(*self.pos.shape[:2], dtype=torch.bool, device=self.pos.device)

    def _default_state(self) -> Tensor:
        """
        Returns the initial BxN state: the first replay state if a history exists, otherwise index 0 everywhere.
        """
        if self.replay_states.shape[-1] > 0:
            return self.replay_states[..., 0]
        else:
            return torch.zeros(*self.pos.shape[:2], dtype=torch.long, device=self.pos.device)

    @property
    def total_replay_time(self) -> int:
        """
        Number of recorded timesteps T available in `replay_states`.
        """
        return self.replay_states.shape[-1]

    def copy(self):
        """
        Duplicates this object, allowing for independent subsequent execution.
        """
        other = self.__class__(
            pos=self.pos.clone(), allowed_states=self.allowed_states.copy(),
            replay_states=self.replay_states.clone(), mask=self.mask.clone(),
        )
        # The constructor resets the state to its default, so carry the current one over explicitly.
        other.state = self.state.clone()
        return other

    def to(self, device: torch.device):
        """
        Modifies the object in-place, putting all tensors on the device provided.
        Returns the object itself.
        """
        self.pos = self.pos.to(device)
        self.corners = self.corners.to(device)
        self.replay_states = self.replay_states.to(device)
        self.mask = self.mask.to(device)
        self.state = self.state.to(device)
        return self

    def extend(self, n: int, in_place: bool = True):
        """
        Multiplies the first batch dimension by the given number.
        This is equivalent to introducing extra batch dimension on the right and then flattening.

        Args:
            n: Number of times each batch element is repeated.
            in_place: If False, a copy is extended and returned, leaving this object unchanged.
        Returns:
            The extended object (this object when `in_place` is True).
        """
        if not in_place:
            other = self.copy()
            other.extend(n, in_place=True)
            return other

        def enlarge(x: Tensor) -> Tensor:
            # Repeats each batch element n times consecutively: [b0, b0, ..., b1, b1, ...].
            tail = tuple(x.shape[1:])
            return x.unsqueeze(1).expand((x.shape[0], n) + tail).reshape((n * x.shape[0],) + tail)

        self.pos = enlarge(self.pos)
        self.corners = enlarge(self.corners)
        self.mask = enlarge(self.mask)
        self.replay_states = enlarge(self.replay_states)
        self.state = enlarge(self.state)
        return self

    def select_batch_elements(self, idx: Tensor, in_place: bool = True):
        """
        Picks selected elements of the batch.
        The input is a tensor of indices into the batch dimension.

        Args:
            idx: Tensor of indices into the batch dimension.
            in_place: If False, the selection is applied to a copy, leaving this object unchanged.
        Returns:
            The object restricted to the selected batch elements (this object when `in_place` is True).
        """
        if not in_place:
            other = self.copy()
            other.select_batch_elements(idx, in_place=True)
            return other
        self.pos = self.pos[idx]
        self.corners = self.corners[idx]
        self.mask = self.mask[idx]
        self.replay_states = self.replay_states[idx]
        self.state = self.state[idx]
        return self

    def set_state(self, state: Tensor) -> None:
        """
        Sets control state tensor.
        """
        self.state = state

    def compute_state(self, time: int) -> Tensor:
        """
        Computes the state at a given time index. By default, it repeats the current state. Ignores replay states.
        """
        return self.state

    def step(self, time: int) -> None:
        """
        Advances the state of the control given a time step. If the time value is within the range of replayable states
        then this state will be used otherwise the function `compute_state(time)` will be used to calculate a new state.
        """
        if time < self.total_replay_time:
            self.set_state(self.replay_states[..., time])
        else:
            state = self.compute_state(time)
            self.set_state(state)

    def compute_violation(self, agent_state: Tensor) -> Tensor:
        """
        Given a collection of agents, computes which agents violate the traffic control.
        As deciding violations is typically complex, this function provides an approximate answer only.
        The base class reports no violations.

        Args:
            agent_state: BxAx5 tensor of agents states - x,y,length,width,orientation
        Returns:
            BxA boolean tensor indicating which agents violated the traffic control
        """
        return torch.zeros(agent_state.shape[0], agent_state.shape[1], dtype=torch.bool, device=agent_state.device)
class TrafficLightControl(BaseTrafficControl):
    """
    Traffic lights; unless overridden, the allowed states are ['red', 'yellow', 'green'].
    An agent counts as violating a light when that light is currently red and the agent's
    (rear-adjusted) bounding box has a positive intersection with the light's stop line.
    """

    # Passed as `rear_factor` when building agent boxes for the violation check.
    # NOTE(review): presumably shortens the rear of the agent box - confirm against the helper.
    violation_rear_factor = 0.1

    @classmethod
    def _default_allowed_states(cls) -> List[str]:
        """
        Returns the state names used when no `allowed_states` are supplied.
        """
        return ['red', 'yellow', 'green']

    def compute_violation(self, agent_state: Tensor) -> Tensor:
        """
        Flags agents whose box intersects the stop line of a light that is currently red.

        Args:
            agent_state: BxAx5 tensor of agents states - x,y,length,width,orientation
        Returns:
            BxA boolean tensor, True for every agent overlapping at least one red light
        """
        batch_size = agent_state.shape[0]
        num_agents = agent_state.shape[1]
        num_lights = self.corners.shape[1]

        # Nothing to intersect when any of the dimensions is empty.
        if batch_size <= 0 or num_agents <= 0 or num_lights <= 0:
            return torch.zeros(batch_size, num_agents, dtype=torch.bool, device=agent_state.device)

        # Pair every agent with every light: both corner tensors end up as (B*A)xLx4x2.
        agent_boxes = box2corners_with_rear_factor(agent_state, rear_factor=self.violation_rear_factor)
        agent_boxes = agent_boxes.reshape(-1, 1, 4, 2)
        agent_boxes = agent_boxes.expand(-1, num_lights, -1, -1)
        light_boxes = self.corners.unsqueeze(1)
        light_boxes = light_boxes.expand(-1, num_agents, -1, -1, -1)
        light_boxes = light_boxes.reshape(-1, num_lights, 4, 2)

        # The first element returned by the intersection helper is compared against zero.
        # NOTE(review): presumably the intersection area - confirm against the helper.
        intersection = oriented_box_intersection_2d(agent_boxes, light_boxes)[0]
        touches_stopline = intersection > 0

        # Broadcast the per-light "is red" flag across agents to the same (B*A)xL layout.
        red_index = self.allowed_states.index('red')
        light_is_red = self.state.to(touches_stopline.device) == red_index
        light_is_red = light_is_red.unsqueeze(1).expand(-1, num_agents, -1).reshape(-1, num_lights)

        violations = touches_stopline & light_is_red
        return violations.any(dim=-1).reshape(batch_size, num_agents)
class YieldControl(BaseTrafficControl):
    """
    Yield sign: cross traffic has priority over vehicles facing it.
    No behavior of the base class is overridden, so no violations are computed.
    """
class StopSignControl(BaseTrafficControl):
    """
    Stop sign: the vehicle is expected to stop and give way to cross traffic.
    No behavior of the base class is overridden, so no violations are computed.
    """