torchdrivesim.traffic_controls

Definitions of traffic controls. Currently, we support traffic lights, stop signs, and yield signs.

Module Contents

Classes

BaseTrafficControl

Base traffic control class. All traffic controls are static rectangular stoplines with additional discrete state.

TrafficLightControl

Traffic lights, with default allowed states of ['red', 'yellow', 'green'].

YieldControl

Yield sign, indicating that cross traffic has priority.

StopSignControl

Stop sign, indicating the vehicle should stop and yield to cross traffic.

class torchdrivesim.traffic_controls.BaseTrafficControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)

Base traffic control class. All traffic controls are static rectangular stoplines with an additional discrete state. The state is advanced by calling step: while the current time step lies within the recorded history (replay_states), the recorded state is used; once the time step exceeds that history, the state is produced by compute_state instead.

Parameters:
  • pos – BxNx5 stoplines tensor [x, y, length/thickness, width, orientation]

  • allowed_states – Allowed values of state, e.g. colors for traffic lights.

  • replay_states – BxNxT tensor of states to replay, represented as indices into allowed_states. Defaults to an empty history (T=0).

  • mask – BxN boolean tensor indicating whether a given traffic control element is present rather than padding.
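A minimal pure-Python sketch of the tensor layouts listed above, with nested lists standing in for torch tensors (B=1, N=3, T=3, where the third control is padding). The coordinate values and the decode helper are hypothetical; they only show how replay_states indexes into allowed_states and how mask marks padded entries.

```python
from typing import List

allowed_states: List[str] = ["red", "yellow", "green"]

# pos: BxNx5 stoplines, each row is [x, y, length/thickness, width, orientation]
pos = [[
    [10.0, 0.0, 0.5, 3.5, 0.0],
    [0.0, 10.0, 0.5, 3.5, 1.5708],
    [0.0, 0.0, 0.0, 0.0, 0.0],  # padding entry
]]

# replay_states: BxNxT integer indices into allowed_states
replay_states = [[
    [2, 1, 0],  # green -> yellow -> red
    [0, 0, 2],  # red -> red -> green
    [0, 0, 0],  # padding, ignored via mask
]]

# mask: BxN, True where the traffic control is real rather than padding
mask = [[True, True, False]]


def decode(replay: List[List[List[int]]], names: List[str],
           valid: List[List[bool]]) -> List[List[List[str]]]:
    """Hypothetical helper: turn BxNxT indices into state names, dropping padding."""
    return [
        [[names[i] for i in control] for control, ok in zip(batch, keep) if ok]
        for batch, keep in zip(replay, valid)
    ]


decoded = decode(replay_states, allowed_states, mask)
print(decoded[0])
# [['green', 'yellow', 'red'], ['red', 'red', 'green']]
```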

property total_replay_time: int
classmethod _default_allowed_states() List[str]
_default_replay_states() torch.Tensor
_default_mask() torch.Tensor
_default_state() torch.Tensor
copy()

Duplicates this object, allowing for independent subsequent execution.

to(device: torch.device)

Modifies the object in place, moving all tensors to the given device.

extend(n: int, in_place: bool = True)

Multiplies the size of the first (batch) dimension by n. This is equivalent to introducing an extra batch dimension of size n to the right of the existing one and then flattening the two.
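A pure-Python sketch of the batch layout that extend produces, using a list as a stand-in for the batch dimension. The function name extend_batch is hypothetical and the in_place flag is not modeled.

```python
from typing import List, TypeVar

T = TypeVar("T")


def extend_batch(batch: List[T], n: int) -> List[T]:
    """Repeat every batch element n times, keeping the copies adjacent.

    Viewing the batch as shape (B, ...), this adds an axis of size n to the
    right of B, giving (B, n, ...), and flattens it to (B * n, ...).
    """
    return [element for element in batch for _ in range(n)]


print(extend_batch(["scene_a", "scene_b"], 2))
# ['scene_a', 'scene_a', 'scene_b', 'scene_b']
```

For a torch tensor x, the same ordering is given by torch.repeat_interleave(x, n, dim=0); by contrast, Tensor.repeat would tile the whole batch ([a, b, a, b]).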

select_batch_elements(idx: torch.Tensor, in_place: bool = True)

Selects elements of the batch. The input is a tensor of indices into the batch dimension.
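A pure-Python sketch of batch selection by index, with a list standing in for the batch dimension. The helper select_batch is hypothetical and the in_place flag is not modeled.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def select_batch(batch: List[T], idx: Sequence[int]) -> List[T]:
    """Keep the batch elements at the given indices, in the given order."""
    return [batch[i] for i in idx]


scenes = ["scene_a", "scene_b", "scene_c"]
print(select_batch(scenes, [2, 0]))
# ['scene_c', 'scene_a']
```

With torch tensors the equivalent is integer indexing along dimension 0 (x[idx] or torch.index_select(x, 0, idx)). The same indices have to be applied to every batched tensor (pos, replay_states, mask, and the current state) so that they stay aligned.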

set_state(state: torch.Tensor) None

Sets the control state tensor.

compute_state(time: int) torch.Tensor

Computes the state at a given time index. By default, it repeats the current state and ignores the replay states.

step(time: int) None

Advances the state of the control to the given time step. If the time step lies within the range of the recorded replay states, the recorded state is used; otherwise, a new state is calculated by compute_state(time).
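A minimal sketch of the replay-then-compute logic described for step, reduced to a single control whose states are integer indices into allowed_states. The class ReplayingControl is hypothetical; it assumes zero-based time indices and that total_replay_time equals the length of the recorded history.

```python
from typing import List


class ReplayingControl:
    """Hypothetical single-control stand-in for the step() behaviour.

    Assumes time indices start at 0 and total_replay_time == T.
    """

    def __init__(self, replay_states: List[int], initial_state: int) -> None:
        self.replay_states = replay_states  # recorded history, length T
        self.state = initial_state

    @property
    def total_replay_time(self) -> int:
        return len(self.replay_states)

    def compute_state(self, time: int) -> int:
        # Default behaviour: keep the current state; replay data is ignored.
        return self.state

    def step(self, time: int) -> None:
        if time < self.total_replay_time:
            self.state = self.replay_states[time]  # within the recording
        else:
            self.state = self.compute_state(time)  # past the recording


control = ReplayingControl(replay_states=[2, 1, 0], initial_state=2)
history = []
for t in range(5):
    control.step(t)
    history.append(control.state)
print(history)
# [2, 1, 0, 0, 0]
```

Because the fallback goes through compute_state, a subclass can change what happens after the recording ends (for example, cycling light colors) by overriding that single method.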

compute_violation(agent_state: torch.Tensor) torch.Tensor

Given a collection of agents, computes which agents violate the traffic control. Since deciding violations is typically complex, this function provides only an approximate answer. The base class reports no violations.

Parameters:

agent_state – BxAx5 tensor of agent states: [x, y, length, width, orientation]

Returns:

BxA boolean tensor indicating which agents violated the traffic control
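A pure-Python sketch of the base-class contract: a BxAx5 agent_state goes in and a BxA all-False result comes out. The helper no_violations is hypothetical; nested lists stand in for tensors.

```python
from typing import List, Sequence


def no_violations(agent_state: Sequence[Sequence[Sequence[float]]]) -> List[List[bool]]:
    """Return a BxA all-False result for a BxAx5 agent_state."""
    return [[False for _agent in batch] for batch in agent_state]


agents = [[[0.0, 0.0, 4.5, 2.0, 0.0], [5.0, 1.0, 4.5, 2.0, 0.0]]]  # B=1, A=2
print(no_violations(agents))
# [[False, False]]
```

With torch tensors, torch.zeros(B, A, dtype=torch.bool) produces the same all-False result.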

class torchdrivesim.traffic_controls.TrafficLightControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)

Bases: BaseTrafficControl

Traffic lights, with default allowed states of ['red', 'yellow', 'green']. An agent is regarded as violating the traffic light if the light is red and the agent's bounding box substantially overlaps the stop line.

violation_rear_factor = 0.1
classmethod _default_allowed_states() List[str]
compute_violation(agent_state: torch.Tensor) torch.Tensor

Given a collection of agents, computes which agents violate the traffic light. Since deciding violations is typically complex, this function provides only an approximate answer: an agent is flagged when the light is red and its bounding box substantially overlaps the stop line.

Parameters:

agent_state – BxAx5 tensor of agent states: [x, y, length, width, orientation]

Returns:

BxA boolean tensor indicating which agents violated the traffic control
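A simplified sketch of the red-light criterion for a single batch element, using a separating-axis test for two oriented rectangles. This is not the library's algorithm: it assumes orientation is in radians counterclockwise from the x-axis, that length is measured along the orientation direction, and it treats any overlap as a violation, so the "substantially overlaps" threshold and violation_rear_factor are not modeled. All function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Box = Sequence[float]  # [x, y, length, width, orientation]
Point = Tuple[float, float]


def corners(box: Box) -> List[Point]:
    """Four corners of an oriented rectangle (orientation in radians)."""
    x, y, length, width, theta = box
    c, s = math.cos(theta), math.sin(theta)
    ax, ay = (length / 2) * c, (length / 2) * s    # half-extent along length
    bx, by = -(width / 2) * s, (width / 2) * c     # half-extent along width
    return [(x + ax + bx, y + ay + by), (x + ax - bx, y + ay - by),
            (x - ax - bx, y - ay - by), (x - ax + bx, y - ay + by)]


def axes(box: Box) -> List[Point]:
    """The two edge normals of a rectangle are its length and width directions."""
    theta = box[4]
    return [(math.cos(theta), math.sin(theta)), (-math.sin(theta), math.cos(theta))]


def boxes_overlap(a: Box, b: Box) -> bool:
    """Separating axis test: disjoint iff some edge normal separates the projections."""
    ca, cb = corners(a), corners(b)
    for ux, uy in axes(a) + axes(b):
        pa = [px * ux + py * uy for px, py in ca]
        pb = [px * ux + py * uy for px, py in cb]
        if max(pa) < min(pb) or max(pb) < min(pa):
            return False
    return True


def red_light_violations(agents: List[Box], stoplines: List[Box],
                         states: List[str]) -> List[bool]:
    """Flag agents that overlap any stopline whose light is currently red."""
    return [any(state == "red" and boxes_overlap(agent, line)
                for line, state in zip(stoplines, states))
            for agent in agents]


stopline = [10.0, 0.0, 0.5, 3.5, 0.0]
agents = [[10.0, 0.0, 4.5, 2.0, 0.0],  # straddling the stop line
          [0.0, 0.0, 4.5, 2.0, 0.0]]   # well before it
print(red_light_violations(agents, [stopline], ["red"]))    # [True, False]
print(red_light_violations(agents, [stopline], ["green"]))  # [False, False]
```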

class torchdrivesim.traffic_controls.YieldControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)

Bases: BaseTrafficControl

Yield sign, indicating that cross traffic has priority. Violations are not computed.

class torchdrivesim.traffic_controls.StopSignControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)

Bases: BaseTrafficControl

Stop sign, indicating the vehicle should stop and yield to cross traffic. Violations are not computed.