torchdrivesim.traffic_controls

Definitions of traffic controls. Currently, we support traffic lights, stop signs, and yield signs.

Classes

BaseTrafficControl

Base traffic control class. All traffic controls are static rectangular stoplines with additional discrete state.

TrafficLightControl

Traffic lights, with default allowed states of ['red', 'yellow', 'green'].

YieldControl

Yield sign, indicating that cross traffic has priority.

StopSignControl

Stop sign, indicating that vehicles should stop and yield to cross traffic.

Module Contents

class torchdrivesim.traffic_controls.BaseTrafficControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)

Base traffic control class. All traffic controls are static rectangular stoplines with additional discrete state. States are advanced by calling step. The new state is taken from the recorded history of states (replay_states) while it is available, and is produced by compute_state once the current time step exceeds the recorded history.

Parameters:
  • pos – BxNx5 stoplines tensor [x, y, length/thickness, width, orientation]

  • allowed_states – Allowed values of state, e.g. colors for traffic lights.

  • replay_states – BxNxT tensor of states to replay, represented as indices into allowed_states. Defaults to an empty history (T=0).

  • mask – BxN boolean tensor indicating whether a given traffic control element is present rather than padding.
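To make the expected shapes concrete, the sketch below builds inputs for a batch of B=2 scenes with N=3 stop-line slots and T=4 recorded time steps, using plain torch only. The coordinates are made-up illustrative values, and padding slots are marked with mask=False. It does not call torchdrivesim; it only mirrors the documented tensor layouts.

```python
import torch

B, N, T = 2, 3, 4  # batch size, stop-line slots per scene, recorded time steps
allowed_states = ["red", "yellow", "green"]  # documented default for TrafficLightControl

# pos: BxNx5 stop lines as [x, y, length/thickness, width, orientation]
pos = torch.zeros(B, N, 5)
pos[0, 0] = torch.tensor([10.0, 0.0, 0.5, 4.0, 0.0])
pos[0, 1] = torch.tensor([0.0, 10.0, 0.5, 4.0, 1.5708])

# mask: BxN, False marks padding slots with no real stop line
mask = torch.zeros(B, N, dtype=torch.bool)
mask[0, :2] = True  # scene 0 has two real stop lines, scene 1 has none

# replay_states: BxNxT indices into allowed_states
replay_states = torch.full((B, N, T), allowed_states.index("green"), dtype=torch.long)
replay_states[0, 0, 2:] = allowed_states.index("red")  # light 0 turns red at t=2

# Decode one light's recorded history back into state names
history = [allowed_states[i] for i in replay_states[0, 0].tolist()]
print(history)  # ['green', 'green', 'red', 'red']
```

These tensors line up with the documented constructor arguments pos, allowed_states, replay_states and mask.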

classmethod _default_allowed_states() -> List[str]
_default_replay_states() -> torch.Tensor
_default_mask() -> torch.Tensor
_default_state() -> torch.Tensor
property total_replay_time: int
copy()

Duplicates this object, so that the copy and the original can subsequently be executed independently.

to(device: torch.device)

Modifies the object in place, moving all tensors to the given device.
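The difference between copy() and to() can be illustrated with a hypothetical minimal container (ControlSketch is not part of torchdrivesim). In this sketch, copy() clones every tensor so the duplicate can evolve independently, while to() reassigns the tensor attributes of the same object. The attribute names are assumptions chosen for illustration; this is not the library's implementation.

```python
import torch


class ControlSketch:
    """Hypothetical stand-in for the tensors a traffic control carries."""

    def __init__(self, pos, replay_states, mask, state):
        self.pos = pos
        self.replay_states = replay_states
        self.mask = mask
        self.state = state

    def copy(self):
        # Clone every tensor so later changes to the copy leave the original untouched.
        return ControlSketch(self.pos.clone(), self.replay_states.clone(),
                             self.mask.clone(), self.state.clone())

    def to(self, device: torch.device) -> None:
        # Replace each tensor attribute with its counterpart on `device` (in place).
        for name in ("pos", "replay_states", "mask", "state"):
            setattr(self, name, getattr(self, name).to(device))


original = ControlSketch(
    pos=torch.zeros(1, 2, 5),
    replay_states=torch.zeros(1, 2, 0, dtype=torch.long),  # T=0: no recorded history
    mask=torch.ones(1, 2, dtype=torch.bool),
    state=torch.zeros(1, 2, dtype=torch.long),
)
duplicate = original.copy()
duplicate.state[0, 0] = 2  # change only the duplicate's state
print(original.state.tolist(), duplicate.state.tolist())  # [[0, 0]] [[2, 0]]
original.to(torch.device("cpu"))
```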

extend(n: int, in_place: bool = True)

Multiplies the size of the first (batch) dimension by n. This is equivalent to introducing an extra batch dimension of size n on the right and then flattening.
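The stated equivalence can be checked on a single tensor: inserting a dimension of size n to the right of the batch dimension and then flattening matches torch.repeat_interleave along dimension 0, so each batch element is repeated n times consecutively. The helper extend_batch is hypothetical and only demonstrates the tensor identity; that extend() applies it to each of the control's tensors is an assumption, not shown here.

```python
import torch


def extend_batch(x: torch.Tensor, n: int) -> torch.Tensor:
    # B x ... -> B x n x ... -> (B*n) x ...
    expanded = x.unsqueeze(1).expand(x.shape[0], n, *x.shape[1:])
    return expanded.reshape(x.shape[0] * n, *x.shape[1:])


mask = torch.tensor([[True, False], [True, True]])  # B=2, N=2
extended = extend_batch(mask, 3)
print(extended.shape)  # torch.Size([6, 2])
assert torch.equal(extended, mask.repeat_interleave(3, dim=0))
```

Note that this differs from Tensor.repeat(n, 1), which would tile the whole batch n times instead of repeating each element in place.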

select_batch_elements(idx: torch.Tensor, in_place: bool = True)

Selects the given elements of the batch. The input is a tensor of indices into the batch dimension.
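Selecting batch elements amounts to indexing the leading dimension of every batch-shaped tensor with the same index tensor, which keeps pos, replay_states and mask aligned. The helper select_batch below is hypothetical and works on the documented shapes with plain torch; it is a sketch, not the library's method.

```python
import torch


def select_batch(idx: torch.Tensor, *tensors: torch.Tensor):
    # Index the batch dimension of each tensor with the same indices so they stay aligned.
    return tuple(t[idx] for t in tensors)


B, N, T = 4, 2, 3
pos = torch.arange(B, dtype=torch.float32).view(B, 1, 1).expand(B, N, 5)  # batch i filled with i
replay_states = torch.zeros(B, N, T, dtype=torch.long)
mask = torch.ones(B, N, dtype=torch.bool)

idx = torch.tensor([3, 0, 0])  # plain tensor indexing also permits reordering and repeats
pos_sel, replay_sel, mask_sel = select_batch(idx, pos, replay_states, mask)
print(pos_sel[:, 0, 0].tolist())  # [3.0, 0.0, 0.0]
```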

set_state(state: torch.Tensor) -> None

Sets the control state tensor.

compute_state(time: int) -> torch.Tensor

Computes the state at the given time index, ignoring replay states. By default, it repeats the current state.

step(time: int) -> None

Advances the state of the control to the given time step. If the time value falls within the range of replay states, the recorded state is used; otherwise, compute_state(time) is called to calculate a new state.
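The replay-then-compute rule of step() can be sketched with a small hypothetical class (StepSketch, not the library's code). It assumes time indices start at 0, that the replay horizon is the last dimension of replay_states (as the BxNxT shape suggests), and that compute_state keeps its documented default of repeating the current state.

```python
import torch


class StepSketch:
    """Hypothetical illustration of step(): replay first, then compute_state."""

    def __init__(self, replay_states: torch.Tensor, initial_state: torch.Tensor):
        self.replay_states = replay_states  # BxNxT indices into allowed_states
        self.state = initial_state          # BxN

    @property
    def total_replay_time(self) -> int:
        return self.replay_states.shape[-1]

    def compute_state(self, time: int) -> torch.Tensor:
        # Default behaviour: keep the current state and ignore replay states.
        return self.state

    def step(self, time: int) -> None:
        if time < self.total_replay_time:
            self.state = self.replay_states[..., time]  # use the recorded state
        else:
            self.state = self.compute_state(time)       # history exhausted


# One light: green, green, red as indices into ['red', 'yellow', 'green']
replay = torch.tensor([[[2, 2, 0]]])
ctrl = StepSketch(replay, initial_state=replay[..., 0])
states = []
for t in range(5):
    ctrl.step(t)
    states.append(ctrl.state.item())
print(states)  # [2, 2, 0, 0, 0]
```

Once the three recorded steps are used up, the last replayed state (red) persists because the default compute_state simply repeats it.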

compute_violation(agent_state: torch.Tensor) -> torch.Tensor

Given a collection of agents, computes which agents violate the traffic control. As deciding violations is typically complex, this function provides an approximate answer only. The base class reports no violations.

Parameters:

agent_state – BxAx5 tensor of agent states [x, y, length, width, orientation]

Returns:

BxA boolean tensor indicating which agents violated the traffic control
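The base-class contract of reporting no violations amounts to an all-False BxA tensor matching the agent batch. The function no_violations below is a hypothetical sketch of that behaviour, not the library source.

```python
import torch


def no_violations(agent_state: torch.Tensor) -> torch.Tensor:
    # agent_state: BxAx5 [x, y, length, width, orientation] -> BxA booleans, all False
    batch_size, num_agents = agent_state.shape[:2]
    return torch.zeros(batch_size, num_agents, dtype=torch.bool, device=agent_state.device)


violations = no_violations(torch.randn(2, 7, 5))
print(violations.shape, violations.any().item())  # torch.Size([2, 7]) False
```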

class torchdrivesim.traffic_controls.TrafficLightControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)

Bases: BaseTrafficControl

Traffic lights, with default allowed states of ['red', 'yellow', 'green']. An agent is considered to violate the traffic light if the light is red and the agent's bounding box substantially overlaps the stop line.

violation_rear_factor = 0.1
classmethod _default_allowed_states() -> List[str]
compute_violation(agent_state: torch.Tensor) -> torch.Tensor

Given a collection of agents, computes which agents violate the traffic light. As deciding violations is typically complex, this function provides an approximate answer only: an agent is flagged when the light is red and its bounding box substantially overlaps the stop line.

Parameters:

agent_state – BxAx5 tensor of agent states [x, y, length, width, orientation]

Returns:

BxA boolean tensor indicating which agents violated the traffic control
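The red-light rule can be approximated in a few lines of torch. This hypothetical sketch (approx_red_light_violation is not a torchdrivesim function) uses a cruder criterion than "substantial overlap": it flags an agent when the centre of a red, unmasked stop line falls inside the agent's oriented bounding box. It ignores stop-line length and width and does not model violation_rear_factor, whose exact role is not documented here. The red index 0 follows the default allowed states ['red', 'yellow', 'green'].

```python
import torch


def approx_red_light_violation(agent_state, stoplines, light_state, mask, red_index=0):
    # agent_state: BxAx5, stoplines: BxNx5, light_state: BxN (long), mask: BxN (bool)
    ax, ay, al, aw, ath = agent_state.unbind(-1)        # each BxA
    sx, sy = stoplines[..., 0], stoplines[..., 1]       # each BxN
    dx = sx[:, None, :] - ax[:, :, None]                # BxAxN offsets to stop-line centres
    dy = sy[:, None, :] - ay[:, :, None]
    cos, sin = torch.cos(ath)[:, :, None], torch.sin(ath)[:, :, None]
    local_x = dx * cos + dy * sin                       # offset along the agent's heading
    local_y = -dx * sin + dy * cos                      # offset across the agent
    inside = (local_x.abs() <= al[:, :, None] / 2) & (local_y.abs() <= aw[:, :, None] / 2)
    active = (light_state == red_index) & mask          # BxN: red and not padding
    return (inside & active[:, None, :]).any(dim=-1)    # BxA


agents = torch.tensor([[[10.0, 0.5, 4.0, 2.0, 0.0],    # straddling the stop line
                        [0.0, 0.0, 4.0, 2.0, 0.0]]])   # far away from it
stoplines = torch.tensor([[[10.0, 0.0, 0.5, 4.0, 1.5708]]])
mask = torch.tensor([[True]])
red = approx_red_light_violation(agents, stoplines, torch.tensor([[0]]), mask)
green = approx_red_light_violation(agents, stoplines, torch.tensor([[2]]), mask)
print(red.tolist(), green.tolist())  # [[True, False]] [[False, False]]
```

Checking a single point keeps the sketch short; a closer match to "substantial overlap" would measure the intersection area of the two rotated rectangles.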

class torchdrivesim.traffic_controls.YieldControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)

Bases: BaseTrafficControl

Yield sign, indicating that cross traffic has priority. Violations are not computed.

class torchdrivesim.traffic_controls.StopSignControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)

Bases: BaseTrafficControl

Stop sign, indicating that vehicles should stop and yield to cross traffic. Violations are not computed.