torchdrivesim.traffic_controls
==============================

.. py:module:: torchdrivesim.traffic_controls

.. autoapi-nested-parse::

   Definitions of traffic controls. Currently, we support traffic lights, stop signs, and yield signs.


Classes
-------

.. autoapisummary::

   torchdrivesim.traffic_controls.BaseTrafficControl
   torchdrivesim.traffic_controls.TrafficLightControl
   torchdrivesim.traffic_controls.YieldControl
   torchdrivesim.traffic_controls.StopSignControl


Module Contents
---------------

.. py:class:: BaseTrafficControl(pos: torch.Tensor, allowed_states: Optional[List[str]] = None, replay_states: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None)

   Base traffic control class. All traffic controls are static rectangular stop lines with an additional discrete state.
   The state is advanced by calling `step`. While the time index lies within the recorded history of states, the
   recorded state is used. Once the time index exceeds that history, the state is obtained from `compute_state`.

   :param pos: BxNx5 tensor of stop lines, each given as [x, y, length/thickness, width, orientation].
   :param allowed_states: Allowed values of the state, e.g. colors for traffic lights.
   :param replay_states: BxNxT tensor of states to replay, represented as indices into `allowed_states`. Defaults to T=0, i.e. no recorded history.
   :param mask: BxN boolean tensor indicating whether a given traffic control element is present rather than padding.


   .. py:attribute:: pos


   .. py:attribute:: allowed_states
      :value: None


   .. py:attribute:: replay_states
      :value: None


   .. py:attribute:: mask
      :value: None


   .. py:attribute:: corners


   .. py:attribute:: state


   .. py:method:: _default_allowed_states() -> List[str]
      :classmethod:


   .. py:method:: _default_replay_states() -> torch.Tensor


   .. py:method:: _default_mask() -> torch.Tensor


   .. py:method:: _default_state() -> torch.Tensor


   .. py:property:: total_replay_time
      :type: int


   .. py:method:: copy()

      Duplicates this object, allowing for independent subsequent execution.


   .. py:method:: to(device: torch.device)

      Modifies the object in place, moving all tensors to the provided device.


   .. py:method:: extend(n: int, in_place: bool = True)

      Multiplies the first (batch) dimension by the given number.
      This is equivalent to introducing an extra batch dimension on the right and then flattening.


   .. py:method:: select_batch_elements(idx: torch.Tensor, in_place: bool = True)

      Selects elements of the batch, given a tensor of indices into the batch dimension.


   .. py:method:: set_state(state: torch.Tensor) -> None

      Sets the control state tensor.


   .. py:method:: compute_state(time: int) -> torch.Tensor

      Computes the state at a given time index, ignoring replay states.
      By default, it repeats the current state.


   .. py:method:: step(time: int) -> None

      Advances the state of the control for the given time index.
      If the time index falls within the range of replay states, the recorded state is used;
      otherwise, the new state is calculated with `compute_state(time)`.


   .. py:method:: compute_violation(agent_state: torch.Tensor) -> torch.Tensor

      Given a collection of agents, computes which agents violate the traffic control.
      As deciding violations is typically complex, this function provides only an approximate answer.
      The base class reports no violations.

      :param agent_state: BxAx5 tensor of agent states [x, y, length, width, orientation]
      :returns: BxA boolean tensor indicating which agents violated the traffic control


.. py:class:: TrafficLightControl(pos: torch.Tensor, allowed_states: Optional[List[str]] = None, replay_states: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None)

   Bases: :py:obj:`BaseTrafficControl`

   Traffic lights, with default allowed states of ['red', 'yellow', 'green'].
   An agent is regarded as violating the traffic light if the light is red and the agent's bounding box
   substantially overlaps the stop line.


   .. py:attribute:: violation_rear_factor
      :value: 0.1


   .. py:method:: _default_allowed_states() -> List[str]
      :classmethod:


   .. py:method:: compute_violation(agent_state: torch.Tensor) -> torch.Tensor

      Given a collection of agents, computes which agents violate the traffic light.
      As deciding violations is typically complex, this function provides only an approximate answer:
      an agent is flagged when the light is red and its bounding box substantially overlaps the stop line.

      :param agent_state: BxAx5 tensor of agent states [x, y, length, width, orientation]
      :returns: BxA boolean tensor indicating which agents violated the traffic control


.. py:class:: YieldControl(pos: torch.Tensor, allowed_states: Optional[List[str]] = None, replay_states: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None)

   Bases: :py:obj:`BaseTrafficControl`

   Yield sign, indicating that cross traffic has priority. Violations are not computed.


.. py:class:: StopSignControl(pos: torch.Tensor, allowed_states: Optional[List[str]] = None, replay_states: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None)

   Bases: :py:obj:`BaseTrafficControl`

   Stop sign, indicating that the vehicle should stop and yield to cross traffic. Violations are not computed.
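
The replay-then-compute behavior of `step` can be illustrated with a minimal, dependency-free sketch. It is not the library's implementation: the class name `ReplayingControlSketch` and its `state_names` helper are hypothetical, nested Python lists stand in for the BxNxT `replay_states` tensor, and the initial state (index 0 for every element) is an assumption.

.. code-block:: python

   from typing import List


   class ReplayingControlSketch:
       """Hypothetical stand-in that mirrors the documented `step` semantics.

       States are integer indices into `allowed_states`. Nested Python lists
       (B x N x T) replace the `replay_states` tensor.
       """

       def __init__(self, allowed_states: List[str], replay_states: List[List[List[int]]]) -> None:
           self.allowed_states = allowed_states
           self.replay_states = replay_states
           # Assumption: start every element at state index 0.
           self.state = [[0 for _ in batch] for batch in replay_states]

       @property
       def total_replay_time(self) -> int:
           # T, the length of the recorded history (0 when nothing is recorded).
           if not self.replay_states or not self.replay_states[0]:
               return 0
           return len(self.replay_states[0][0])

       def set_state(self, state: List[List[int]]) -> None:
           self.state = state

       def compute_state(self, time: int) -> List[List[int]]:
           # Base-class default: repeat the current state, ignoring replay states.
           return [list(batch) for batch in self.state]

       def step(self, time: int) -> None:
           if time < self.total_replay_time:
               # Within the recorded history: use the recorded state at `time`.
               new_state = [[element[time] for element in batch] for batch in self.replay_states]
           else:
               # Beyond the recorded history: fall back to compute_state.
               new_state = self.compute_state(time)
           self.set_state(new_state)

       def state_names(self) -> List[List[str]]:
           # Hypothetical helper: map state indices back to names.
           return [[self.allowed_states[i] for i in batch] for batch in self.state]


   # One scene (B=1) with one light (N=1) and three recorded steps (T=3).
   lights = ReplayingControlSketch(["red", "yellow", "green"], [[[2, 1, 0]]])
   history = []
   for t in range(5):
       lights.step(t)
       history.append(lights.state_names()[0][0])
   print(history)  # ['green', 'yellow', 'red', 'red', 'red']

Once the recorded history is exhausted at t=3, the sketch keeps the last replayed state, because the default `compute_state` simply repeats the current state. Subclasses that override `compute_state` would change only that tail.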
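
The batch-manipulation semantics of `extend` and `select_batch_elements` can be sketched on a plain Python list standing in for the batch dimension. The helpers `extend_batch` and `select_batch` are hypothetical and only mirror the documented behavior. In particular, adding the new dimension on the right before flattening means each batch element is repeated `n` times consecutively, rather than the whole batch being tiled.

.. code-block:: python

   from typing import List, Sequence, TypeVar

   T = TypeVar("T")


   def extend_batch(batch: Sequence[T], n: int) -> List[T]:
       # Add a new axis of size n to the right of the batch axis (shape B x n)...
       expanded = [[item] * n for item in batch]
       # ...then flatten it row-major into a batch of size B * n.
       return [item for row in expanded for item in row]


   def select_batch(batch: Sequence[T], idx: Sequence[int]) -> List[T]:
       # Pick the batch elements at the given indices, in the given order.
       return [batch[i] for i in idx]


   scenes = ["scene_a", "scene_b"]  # B = 2, placeholder per-scene payloads
   extended = extend_batch(scenes, 3)
   print(extended)  # ['scene_a', 'scene_a', 'scene_a', 'scene_b', 'scene_b', 'scene_b']
   print(select_batch(extended, [0, 3]))  # ['scene_a', 'scene_b']

With tensors, this ordering corresponds to ``torch.repeat_interleave(x, n, dim=0)``. It means that, after extending by ``n``, indices ``0, n, 2n, ...`` select one copy of each original batch element.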