torchdrivesim.traffic_controls
Definitions of traffic controls. Currently, we support traffic lights, stop signs, and yield signs.
Module Contents
Classes
| Class | Description |
| --- | --- |
| BaseTrafficControl | Base traffic control class. All traffic controls are static rectangular stop lines with additional discrete state. |
| TrafficLightControl | Traffic lights, with default allowed states of ['red', 'yellow', 'green']. |
| YieldControl | Yield sign, indicating that cross traffic has priority. |
| StopSignControl | Stop sign, indicating the vehicle should stop and yield to cross traffic. |
- class torchdrivesim.traffic_controls.BaseTrafficControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)
Base traffic control class. All traffic controls are static rectangular stop lines with additional discrete state. The state is advanced by calling step(). While the time index lies within the recorded history given by replay_states, the state is read from that history. Once the time index exceeds the recorded history, the state is computed by compute_state().
- Parameters:
pos – BxNx5 tensor of stop lines, each given as [x, y, length/thickness, width, orientation].
allowed_states – Allowed values of state, e.g. colors for traffic lights.
replay_states – BxNxT tensor of states to replay, represented as indices into allowed_states. Defaults to an empty history (T=0).
mask – BxN boolean tensor indicating whether a given traffic control element is present (True) or is padding (False).
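The snippet below illustrates the expected inputs. It builds tensors with the documented shapes for a single scene that contains two traffic lights and one padding slot, then decodes the replayed state indices back into names. The coordinates, the values of B, N and T, and the state_names helper are hypothetical illustrations, not part of torchdrivesim.

```python
import torch

# Hypothetical scene: one batch element (B=1) with N=3 slots, the last one padding.
B, N, T = 1, 3, 4
allowed_states = ["red", "yellow", "green"]  # documented default for traffic lights

# Stop lines as [x, y, length/thickness, width, orientation], shape BxNx5.
pos = torch.tensor([[[10.0, 0.0, 0.5, 3.5, 0.0],
                     [0.0, 10.0, 0.5, 3.5, 1.5708],
                     [0.0, 0.0, 0.0, 0.0, 0.0]]])  # padding row

# Replayed states as indices into allowed_states, shape BxNxT.
replay_states = torch.tensor([[[2, 2, 1, 0],    # green, green, yellow, red
                               [0, 0, 0, 2],    # red, red, red, green
                               [0, 0, 0, 0]]])  # padding, values are irrelevant

# True for real traffic controls, False for padding, shape BxN.
mask = torch.tensor([[True, True, False]])

assert pos.shape == (B, N, 5)
assert replay_states.shape == (B, N, T)
assert mask.shape == (B, N)


def state_names(states: torch.Tensor, t: int) -> list:
    """Decode the states of batch element 0 at time t into names, skipping padding."""
    idx = states[0, :, t][mask[0]]
    return [allowed_states[i] for i in idx.tolist()]


print(state_names(replay_states, 3))  # ['red', 'green']
```

Tensors with these shapes match the documented constructor signature, e.g. TrafficLightControl(pos, replay_states=replay_states, mask=mask). The integer dtype the library expects for replay_states is not stated here. This sketch assumes the default int64.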
- to(device: torch.device)
Modifies the object in-place, putting all tensors on the device provided.
- extend(n: int, in_place: bool = True)
Multiplies the first batch dimension by the given number. This is equivalent to introducing an extra batch dimension on the right and then flattening.
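The equivalence stated above can be checked on a single plain tensor. The extend_batch helper is a hypothetical illustration of the described semantics, not the library's implementation. It adds a dimension of size n to the right of the batch dimension and then flattens the two, which gives the same result as repeat_interleave along dimension 0.

```python
import torch


def extend_batch(x: torch.Tensor, n: int) -> torch.Tensor:
    """Multiply the first (batch) dimension of x by n.

    A new dimension of size n is inserted right after the batch dimension and
    then flattened with it, so each batch element is repeated n times in a row.
    """
    b, *rest = x.shape
    return x.unsqueeze(1).expand(b, n, *rest).reshape(b * n, *rest)


# A batch of B=2 stop-line tensors with N=1 control each (shape BxNx5).
pos = torch.arange(10.0).reshape(2, 1, 5)
extended = extend_batch(pos, 3)

print(extended.shape)  # torch.Size([6, 1, 5])
# Identical to repeating each batch element consecutively.
assert torch.equal(extended, pos.repeat_interleave(3, dim=0))
```

Because the copies of each element are consecutive, batch element i of the original ends up at indices i*n through i*n + n - 1 of the extended batch.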
- select_batch_elements(idx: torch.Tensor, in_place: bool = True)
Picks selected elements of the batch. The input is a tensor of indices into the batch dimension.
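On a single plain tensor, selecting batch elements by an index tensor corresponds to integer indexing along dimension 0. This is a sketch of the described semantics only. The snippet shows that plain tensor indexing can reorder and repeat elements, but the documentation above does not say whether select_batch_elements accepts repeated indices.

```python
import torch

# Batch of B=3 masks, each with N=2 traffic controls (shape BxN).
mask = torch.tensor([[True, False],
                     [True, True],
                     [False, False]])

# Indices into the batch dimension; plain tensor indexing allows reordering and repeats.
idx = torch.tensor([2, 0, 0])

selected = mask[idx]  # same as mask.index_select(0, idx)
print(selected.shape)  # torch.Size([3, 2])
```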
- compute_state(time: int) → torch.Tensor
Computes the state at a given time index. The default implementation repeats the current state and ignores replay states.
- step(time: int) → None
Advances the state of the control to the given time index. If the time index is within the range of replayable states, the replayed state is used; otherwise, compute_state(time) is called to calculate a new state.
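The replay-then-compute rule described above can be sketched with a hypothetical, stripped-down class. It assumes that step(t) sets the state for time index t and that states are integer indices in the default traffic-light order (0 = 'red', 1 = 'yellow', 2 = 'green'). Its compute_state mimics the documented default, which holds the current state. ReplayThenHold is not a torchdrivesim class.

```python
import torch


class ReplayThenHold:
    """Hypothetical stand-in for the step/compute_state logic (not the library class).

    Tracks only a BxN tensor of current state indices and a BxNxT replay history.
    """

    def __init__(self, replay_states: torch.Tensor, initial_state: torch.Tensor):
        self.replay_states = replay_states  # BxNxT state indices
        self.state = initial_state          # BxN current state indices

    def compute_state(self, time: int) -> torch.Tensor:
        # Mimics the documented default: repeat the current state, ignore replay states.
        return self.state

    def step(self, time: int) -> None:
        if time < self.replay_states.shape[-1]:
            # Inside the recorded history: use the replayed state.
            self.state = self.replay_states[..., time]
        else:
            # Past the end of the history: fall back to compute_state.
            self.state = self.compute_state(time)


# One batch element, one light, replayed as green (2), yellow (1), red (0).
control = ReplayThenHold(torch.tensor([[[2, 1, 0]]]), torch.tensor([[2]]))
history = []
for t in range(5):
    control.step(t)
    history.append(control.state.item())
print(history)  # [2, 1, 0, 0, 0]
```

Subclasses that model a light cycle would override compute_state, and step would keep its replay-first behaviour.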
- compute_violation(agent_state: torch.Tensor) → torch.Tensor
Given a collection of agents, computes which agents violate the traffic control. Because deciding violations exactly is typically complex, this function only provides an approximate answer. The base class reports no violations.
- Parameters:
agent_state – BxAx5 tensor of agent states [x, y, length, width, orientation]
- Returns:
BxA boolean tensor indicating which agents violated the traffic control
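The base-class contract takes BxAx5 agent states, returns a BxA boolean tensor and reports no violations. The sketch below expresses that contract using no_violations, a hypothetical helper that is not a torchdrivesim function.

```python
import torch


def no_violations(agent_state: torch.Tensor) -> torch.Tensor:
    """Return an all-False BxA tensor, matching the documented base-class behaviour."""
    batch_size, num_agents = agent_state.shape[:2]
    return torch.zeros(batch_size, num_agents, dtype=torch.bool, device=agent_state.device)


# B=2 batch elements, A=4 agents, each [x, y, length, width, orientation].
agent_state = torch.randn(2, 4, 5)
violations = no_violations(agent_state)
print(violations.shape, violations.any().item())  # torch.Size([2, 4]) False
```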
- class torchdrivesim.traffic_controls.TrafficLightControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)
Bases: BaseTrafficControl
Traffic lights, with default allowed states of ['red', 'yellow', 'green']. An agent is considered to violate the traffic light if the light is red and the agent's bounding box substantially overlaps the stop line.
- compute_violation(agent_state: torch.Tensor) → torch.Tensor
Given a collection of agents, computes which agents violate the traffic light. Because deciding violations exactly is typically complex, this function only provides an approximate answer: an agent is flagged when the light is red and the agent's bounding box substantially overlaps the stop line.
- Parameters:
agent_state – BxAx5 tensor of agent states [x, y, length, width, orientation]
- Returns:
BxA boolean tensor indicating which agents violated the traffic control
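A rough sketch of a red-light check follows. It deliberately swaps in a simpler test than the documented criterion: instead of measuring how much each agent's bounding box overlaps the stop line, it flags an agent whose centre lies inside the rotated stop-line rectangle of any present light that is currently red. The function red_light_violations is hypothetical. The value red_index=0 assumes the default ['red', 'yellow', 'green'] ordering. This is not how torchdrivesim computes violations.

```python
import torch


def red_light_violations(stoplines, light_states, mask, agent_state, red_index=0):
    """Hypothetical centre-in-rectangle approximation of red-light violations.

    stoplines:    BxNx5 [x, y, length/thickness, width, orientation]
    light_states: BxN indices into allowed_states (red_index marks 'red')
    mask:         BxN bool, True for real (non-padding) lights
    agent_state:  BxAx5 [x, y, length, width, orientation]
    Returns a BxA bool tensor.
    """
    # Offsets from every stop-line centre to every agent centre, shape BxAxN.
    dx = agent_state[..., 0:1] - stoplines[..., 0].unsqueeze(1)
    dy = agent_state[..., 1:2] - stoplines[..., 1].unsqueeze(1)
    cos = torch.cos(stoplines[..., 4]).unsqueeze(1)
    sin = torch.sin(stoplines[..., 4]).unsqueeze(1)
    # Rotate the offsets into each stop line's local frame.
    local_x = cos * dx + sin * dy
    local_y = -sin * dx + cos * dy
    inside_x = local_x.abs() <= stoplines[..., 2].unsqueeze(1) / 2
    inside_y = local_y.abs() <= stoplines[..., 3].unsqueeze(1) / 2
    # Only real lights that are currently red count, shape Bx1xN.
    active_red = ((light_states == red_index) & mask).unsqueeze(1)
    return (inside_x & inside_y & active_red).any(dim=-1)


# One stop line at the origin, 1 m thick along x and 4 m wide along y.
stoplines = torch.tensor([[[0.0, 0.0, 1.0, 4.0, 0.0]]])
mask = torch.tensor([[True]])
agents = torch.tensor([[[0.2, 1.0, 4.5, 1.8, 0.0],    # centre on the stop line
                        [8.0, 0.0, 4.5, 1.8, 0.0]]])  # well past it
red = torch.tensor([[0]])
green = torch.tensor([[2]])
print(red_light_violations(stoplines, red, mask, agents).tolist())    # [[True, False]]
print(red_light_violations(stoplines, green, mask, agents).tolist())  # [[False, False]]
```

Testing only the centre ignores agent length and width. A closer match to "substantially overlaps" would use the area of intersection between the two rotated rectangles.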
- class torchdrivesim.traffic_controls.YieldControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)
Bases: BaseTrafficControl
Yield sign, indicating that cross traffic has priority. Violations are not computed.
- class torchdrivesim.traffic_controls.StopSignControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)
Bases: BaseTrafficControl
Stop sign, indicating the vehicle should stop and yield to cross traffic. Violations are not computed.