torchdrivesim.goals
===================

.. py:module:: torchdrivesim.goals

.. autoapi-nested-parse::

   Definitions of goal conditions. Currently, only waypoint conditions are supported.



Attributes
----------

.. autoapisummary::

   torchdrivesim.goals.TensorPerAgentType


Classes
-------

.. autoapisummary::

   torchdrivesim.goals.WaypointGoal


Module Contents
---------------

.. py:data:: TensorPerAgentType

.. py:class:: WaypointGoal(waypoints: TensorPerAgentType, mask: Optional[TensorPerAgentType] = None)

   Waypoints can be given either per agent type or directly as tensors.
   The waypoint tensor holds `N` collections of `M` waypoints each, and the
   collections are enabled progressively. The provided mask should indicate
   which waypoint entries are padding elements. A waypoint is marked as
   successful when the agent state comes within a threshold distance of the
   waypoint center.

   :param waypoints: a functor of BxAxNxMx2 waypoint tensors [x, y]
   :param mask: a functor of BxAxNxM boolean tensors indicating whether a given waypoint element is present and not padding.


   .. py:method:: _default_mask() -> torch.Tensor


   .. py:method:: _default_state() -> torch.Tensor


   .. py:method:: get_masks()

      Returns the waypoint mask according to the current state value.


   .. py:method:: get_waypoints()

      Returns the waypoints according to the current state value.


   .. py:method:: copy()

      Duplicates this object, allowing for independent subsequent execution.


   .. py:method:: to(device: torch.device)

      Modifies the object in place, moving all tensors to the provided device.


   .. py:method:: extend(n: int, in_place: bool = True)

      Multiplies the size of the first batch dimension by `n`.
      This is equivalent to introducing an extra batch dimension on the right and then flattening.


   .. py:method:: select_batch_elements(idx: torch.Tensor, in_place: bool = True)

      Picks the selected elements of the batch.
      The input is a tensor of indices into the batch dimension.


   .. py:method:: step(agent_states: TensorPerAgentType, time: int = 0, threshold: float = 2.0) -> None

      Advances the state of the waypoints.


   .. py:method:: _update_mask(state, mask, new_mask)

      Helper function that updates the global mask at the current state with the `new_mask` values.

      :param state: BxAx1 int tensor
      :param mask: BxAxNxM boolean tensor
      :param new_mask: BxAxM boolean tensor
      :returns: BxAxNxM boolean tensor updated with the `new_mask` values


   .. py:method:: _advance_state(state, agent_overlap)

      Helper function that advances the state when agents overlap with waypoints.

      :param state: BxAx1 int tensor
      :param agent_overlap: BxAxM boolean tensor
      :returns: BxAxM boolean tensor indicating the state of the waypoints


   .. py:method:: _agent_waypoint_overlap(agent_states: torch.Tensor, waypoints: torch.Tensor, threshold=2.0) -> torch.Tensor

      Helper function that detects whether agent states overlap with the current waypoints.

      :param agent_states: BxAx2 tensor of agent states (x, y)
      :param waypoints: BxAxMx2 tensor of waypoints
      :returns: BxAxM boolean tensor indicating which agents overlap with the waypoints
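
The overlap test described for `_agent_waypoint_overlap` amounts to a broadcast distance check between each agent and its `M` current waypoints. The following is a minimal sketch in plain PyTorch, not the library's implementation: the function name `agent_waypoint_overlap` is hypothetical, and it assumes that "overlap" means the Euclidean distance between the agent's x, y position and a waypoint center is strictly below `threshold`.

.. code-block:: python

   import torch


   def agent_waypoint_overlap(
       agent_xy: torch.Tensor, waypoints: torch.Tensor, threshold: float = 2.0
   ) -> torch.Tensor:
       """Hypothetical sketch: BxAx2 agent positions vs. BxAxMx2 waypoints -> BxAxM bool."""
       # Add a waypoint axis so each agent position broadcasts against its M waypoints.
       offsets = waypoints - agent_xy.unsqueeze(-2)  # BxAxMx2
       # Euclidean distance from the agent to each waypoint center.
       distances = torch.linalg.norm(offsets, dim=-1)  # BxAxM
       # Assumption: a waypoint counts as reached when strictly closer than the threshold.
       return distances < threshold


   # One batch element, two agents, three waypoints per agent.
   agent_xy = torch.tensor([[[0.0, 0.0], [10.0, 0.0]]])
   waypoints = torch.tensor([[
       [[1.0, 0.0], [5.0, 0.0], [0.0, 1.5]],
       [[10.0, 1.0], [0.0, 0.0], [12.5, 0.0]],
   ]])
   overlap = agent_waypoint_overlap(agent_xy, waypoints)

Here the first agent is within 2.0 of its first and third waypoints, while the second agent only reaches its first one (the third lies 2.5 away).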
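
`get_waypoints` and `get_masks` return values selected by the current state, and `_update_mask` writes a BxAxM mask back into the BxAxNxM one. A minimal sketch of both operations follows, assuming the BxAx1 state holds the index of the active collection along the `N` axis and that the update replaces (rather than combines with) the stored row. `select_current` and `update_mask` are hypothetical names, not part of `torchdrivesim`.

.. code-block:: python

   import torch


   def select_current(x: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
       """Hypothetical helper: pick the active collection along the N axis (dim 2).

       x: BxAxNxM mask or BxAxNxMx2 waypoints; state: BxAx1 integer tensor.
       """
       # Reshape state to BxAx1x1 (or BxAx1x1x1) so it lines up with x's dimensions.
       index = state.long().reshape(*state.shape, *([1] * (x.dim() - 3)))
       # Expand the index over the trailing dims so gather copies whole rows.
       index = index.expand(*x.shape[:2], 1, *x.shape[3:])
       return x.gather(2, index).squeeze(2)


   def update_mask(state: torch.Tensor, mask: torch.Tensor, new_mask: torch.Tensor) -> torch.Tensor:
       """Hypothetical sketch: overwrite row `state` of a BxAxNxM mask with a BxAxM mask."""
       n = mask.shape[2]
       # BxAxN boolean selector that is True only at each agent's active collection.
       is_current = torch.arange(n, device=mask.device) == state
       # Take new_mask where selected, keep the old mask elsewhere; returns a new tensor.
       return torch.where(is_current.unsqueeze(-1), new_mask.unsqueeze(2), mask)


   waypoints = torch.arange(24.0).reshape(1, 2, 2, 3, 2)  # B=1, A=2, N=2, M=3
   mask = torch.ones(1, 2, 2, 3, dtype=torch.bool)
   mask[0, 1, 1, 2] = False  # last waypoint of agent 1's second collection is padding
   state = torch.tensor([[[0], [1]]])  # agent 0 on collection 0, agent 1 on collection 1

   current_waypoints = select_current(waypoints, state)  # 1x2x3x2
   current_mask = select_current(mask, state)  # 1x2x3
   new_mask = torch.tensor([[[False, True, True], [False, False, False]]])
   updated = update_mask(state, mask, new_mask)  # 1x2x2x3

Using `torch.where` with a one-hot selector keeps the update out of place, so the original mask is left unchanged.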
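
This reference does not spell out when `_advance_state` moves an agent to its next waypoint collection. One plausible rule, assumed here purely for illustration, is to advance once no present, unreached waypoint remains in the current collection, clamping at the last collection. The sketch below (hypothetical `advance_state`) implements that rule; it is not taken from the library.

.. code-block:: python

   import torch


   def advance_state(state: torch.Tensor, remaining: torch.Tensor, num_collections: int) -> torch.Tensor:
       """Hypothetical rule: move to the next collection once nothing is left to reach.

       state: BxAx1 integer tensor; remaining: BxAxM bool tensor of waypoints in the
       current collection that are present (not padding) and not yet reached.
       """
       # An agent has finished its current collection when no waypoint remains.
       done = ~remaining.any(dim=-1, keepdim=True)  # BxAx1
       # Advance by one, but never past the last collection.
       return torch.clamp(state + done.long(), max=num_collections - 1)


   # One batch element, three agents, M=3 waypoints per collection, N=2 collections.
   state = torch.tensor([[[0], [1], [0]]])
   current_mask = torch.tensor([[[True, True, False], [True, False, False], [True, True, True]]])
   overlap = torch.tensor([[[True, True, False], [True, False, False], [False, True, True]]])
   # Waypoints still to be reached: present in the mask and not overlapped this step.
   remaining = current_mask & ~overlap
   new_state = advance_state(state, remaining, num_collections=2)

Note that the documented `_advance_state` returns a BxAxM boolean tensor, whereas this sketch returns the updated BxAx1 state directly.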
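
The batch semantics described for `extend` and `select_batch_elements` can be checked on a single tensor. The sketch below assumes each tensor held by the goal has the batch dimension first; `extend_batch` is a hypothetical helper, and the equivalence with `torch.repeat_interleave` is a property of PyTorch, not a documented detail of `WaypointGoal`.

.. code-block:: python

   import torch


   def extend_batch(x: torch.Tensor, n: int) -> torch.Tensor:
       """Hypothetical helper: B x ... -> (B*n) x ..., as described for `extend`."""
       # Introduce an extra dimension of size n right after the batch dimension...
       expanded = x.unsqueeze(1).expand(x.shape[0], n, *x.shape[1:])
       # ...then flatten it into the batch dimension.
       return expanded.reshape(x.shape[0] * n, *x.shape[1:])


   x = torch.arange(6).reshape(3, 2)  # batch of 3 elements with 2 values each
   extended = extend_batch(x, 2)
   # Each batch element now appears n times in a row, matching repeat_interleave.
   assert torch.equal(extended, x.repeat_interleave(2, dim=0))

   # Picking batch elements with an index tensor, as described for select_batch_elements.
   idx = torch.tensor([2, 0])
   selected = x.index_select(0, idx)

Because the new dimension sits to the right of the original batch dimension, copies of the same element end up adjacent after flattening, rather than the whole batch being tiled `n` times.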