torchdrivesim.goals
Definitions of goal conditions. Currently, we support waypoint conditions.
Module Contents
Classes
- WaypointGoal: Waypoints can be given either per agent type or directly as tensors.
Attributes
- class torchdrivesim.goals.WaypointGoal(waypoints: TensorPerAgentType, mask: TensorPerAgentType | None = None)
Waypoints can be given either per agent type or directly as tensors. For each agent, the waypoint tensor holds N collections of M waypoints each, and these collections are enabled progressively. The provided mask should indicate which waypoints are padding elements. A waypoint is marked as successful when the agent state comes within a given distance (the threshold argument of step) of the waypoint's center.
- Parameters:
waypoints – a functor of BxAxNxMx2 waypoint tensors [x, y]
mask – a functor of BxAxNxM boolean tensors indicating whether a given waypoint element is present, i.e., not padding.
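To illustrate the expected layout, the following sketch builds a waypoint tensor and its padding mask by hand with plain PyTorch. The sizes and coordinates are made up for illustration, and the tensors are given directly rather than as a per-agent-type functor.

```python
import torch

# Illustrative sizes: B batch elements, A agents, N collections, M waypoints per collection.
B, A, N, M = 1, 2, 3, 2

# [x, y] coordinates; slots that stay unused are left at zero and marked as padding.
waypoints = torch.zeros(B, A, N, M, 2)
mask = torch.zeros(B, A, N, M, dtype=torch.bool)

# Agent 0: three collections, each with a single real waypoint (slot 1 is padding).
waypoints[0, 0, :, 0] = torch.tensor([[10.0, 0.0], [20.0, 0.0], [30.0, 0.0]])
mask[0, 0, :, 0] = True

# Agent 1: only the first collection is used, and it holds two waypoints.
waypoints[0, 1, 0] = torch.tensor([[0.0, 10.0], [5.0, 10.0]])
mask[0, 1, 0] = True

print(waypoints.shape, int(mask.sum()))  # torch.Size([1, 2, 3, 2, 2]) 5
```

Per the signature above, such tensors would then be passed as WaypointGoal(waypoints, mask). Agent 1's collections 1 and 2 are entirely padding here.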
- to(device: torch.device)
Modifies the object in-place, putting all tensors on the device provided.
- extend(n: int, in_place: bool = True)
Multiplies the size of the first batch dimension by the given number. This is equivalent to inserting an extra dimension of size n immediately after the batch dimension and then flattening the two together.
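The stated equivalence can be checked on a single tensor with plain PyTorch. extend_batch is a hypothetical helper for illustration, not the library's implementation.

```python
import torch

def extend_batch(x: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical helper: repeat each batch element n times along dim 0."""
    b, rest = x.shape[0], x.shape[1:]
    # (B, ...) -> (B, 1, ...) -> (B, n, ...): new dimension right after the batch dimension.
    expanded = x.unsqueeze(1).expand(b, n, *rest)
    # Flatten the two leading dimensions: (B, n, ...) -> (B * n, ...).
    return expanded.reshape(b * n, *rest)

x = torch.arange(6).reshape(3, 2)
y = extend_batch(x, 2)
print(y.tolist())  # [[0, 1], [0, 1], [2, 3], [2, 3], [4, 5], [4, 5]]
```

Copies of the same original element end up adjacent, so the result matches torch.repeat_interleave along dim 0 rather than Tensor.repeat, which would tile the whole batch.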
- select_batch_elements(idx: torch.Tensor, in_place: bool = True)
Keeps only the selected elements of the batch. The input is a tensor of indices into the batch dimension.
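On a single tensor, selecting batch elements by index amounts to indexing along dim 0. The sketch below uses a hypothetical helper, select_batch, on a made-up waypoint tensor; it only illustrates the indexing, not the in_place handling of the library method.

```python
import torch

def select_batch(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: keep the batch elements listed in idx, in that order."""
    return x.index_select(0, idx)

waypoints = torch.randn(4, 2, 3, 2, 2)  # B=4, A=2, N=3, M=2, [x, y]
idx = torch.tensor([2, 0])
subset = select_batch(waypoints, idx)
print(subset.shape)  # torch.Size([2, 2, 3, 2, 2])
```

Advanced indexing, waypoints[idx], gives the same result; index_select just makes the batch dimension explicit.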
- step(agent_states: TensorPerAgentType, time: int, threshold=2.0) → None
Advances the state of the waypoints.
- _update_mask(state, mask, new_mask)
Helper function that updates the global mask at each agent's current state with the new_mask values.
- Parameters:
state – BxAx1 int tensor
mask – BxAxNxM boolean tensor
new_mask – BxAxM boolean tensor
- Returns:
BxAxNxM boolean tensor updated with the new_mask values
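Assuming that state stores, for each agent, the index of the currently enabled collection along the N dimension, the update can be written with a one-hot selector and torch.where. This is a hypothetical reimplementation for illustration only, and it further assumes every state value lies in [0, N).

```python
import torch
import torch.nn.functional as F

def update_mask(state: torch.Tensor, mask: torch.Tensor, new_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: write new_mask (BxAxM) into mask (BxAxNxM) at each agent's state index."""
    n = mask.shape[2]
    # BxAx1 -> BxA indices -> BxAxN one-hot selector -> BxAxNx1 for broadcasting over M.
    current = F.one_hot(state.squeeze(-1).long(), num_classes=n).bool().unsqueeze(-1)
    # Take new_mask (broadcast over N) at the current collection, keep the old mask elsewhere.
    return torch.where(current, new_mask.unsqueeze(2), mask)

mask = torch.ones(1, 2, 3, 2, dtype=torch.bool)             # B=1, A=2, N=3, M=2
state = torch.tensor([[[1], [2]]])                          # agent 0 at collection 1, agent 1 at collection 2
new_mask = torch.tensor([[[False, True], [False, False]]])  # BxAxM
out = update_mask(state, mask, new_mask)
print(out[0, 0].tolist())  # [[True, True], [False, True], [True, True]]
```

This version returns a new tensor and leaves the input mask untouched, which keeps the helper free of side effects.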
- _advance_state(state, agent_overlap)
Helper function to advance the state when agents overlap with waypoints.
- Parameters:
state – BxAx1 int tensor
agent_overlap – BxAxM boolean tensor
- Returns:
BxAxM boolean tensor indicating the state of the waypoints
- _agent_waypoint_overlap(agent_states: torch.Tensor, waypoints: torch.Tensor, threshold=2.0) → torch.Tensor
Helper function to detect if agent states overlap with the current waypoints.
- Parameters:
agent_states – BxAx2 tensor of agent states [x, y]
waypoints – BxAxMx2 tensor of waypoints [x, y]
- Returns:
BxAxM boolean tensor indicating which agents overlap with the waypoints
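The overlap test can be expressed with broadcasting: give each agent position an M axis, take the distance to each of its M waypoints, and compare against the threshold. This is a hypothetical sketch that assumes Euclidean distance and an inclusive comparison; the library may use a different distance or a strict inequality.

```python
import torch

def agent_waypoint_overlap(agent_states: torch.Tensor, waypoints: torch.Tensor,
                           threshold: float = 2.0) -> torch.Tensor:
    """Hypothetical sketch: BxAx2 positions vs BxAxMx2 waypoints -> BxAxM boolean."""
    # Insert an M axis so each agent position broadcasts against all of its M waypoints.
    offsets = waypoints - agent_states.unsqueeze(-2)  # BxAxMx2
    distances = torch.linalg.norm(offsets, dim=-1)    # BxAxM Euclidean distances
    return distances <= threshold

agent_states = torch.tensor([[[0.0, 0.0], [10.0, 10.0]]])  # B=1, A=2
waypoints = torch.tensor([[[[1.0, 1.0], [5.0, 0.0]],       # agent 0: M=2 waypoints
                           [[10.0, 11.0], [20.0, 20.0]]]]) # agent 1
print(agent_waypoint_overlap(agent_states, waypoints).tolist())
# [[[True, False], [True, False]]]
```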