torchdrivesim.goals

Definitions of goal conditions. Currently, we support waypoint conditions.

Attributes

TensorPerAgentType

Classes

WaypointGoal

Waypoints can be given either per agent type or directly as tensors.

Module Contents

torchdrivesim.goals.TensorPerAgentType
class torchdrivesim.goals.WaypointGoal(waypoints: TensorPerAgentType, mask: TensorPerAgentType | None = None)

Waypoints can be given either per agent type or directly as tensors. The waypoint tensor holds N collections of M waypoints each, and the collections are enabled progressively. The provided mask should indicate which waypoints are padding elements. A waypoint is marked as successful when the agent's position comes within a given distance of the waypoint's center (see the threshold argument of step()).

Parameters:
  • waypoints – a functor of BxAxNxMx2 waypoint tensors [x, y]

  • mask – a functor of BxAxNxM boolean tensors indicating whether a given waypoint element is present rather than padding.
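The shapes are easiest to see on a small example. The sketch below builds a single-tensor waypoints input and its padding mask with hypothetical sizes B=1, A=2, N=3, M=2, following the convention above that True marks a real waypoint and False marks padding. The coordinates are made up, and the per-agent-type (functor) form is not shown.

```python
import torch

# Hypothetical sizes: B batch elements, A agents, N collections, M waypoints per collection.
B, A, N, M = 1, 2, 3, 2

# Waypoint coordinates [x, y], shape BxAxNxMx2; padding slots stay zero-filled.
waypoints = torch.zeros(B, A, N, M, 2)
# True = real waypoint, False = padding, shape BxAxNxM.
mask = torch.zeros(B, A, N, M, dtype=torch.bool)

# Agent 0: one waypoint per collection; slot 1 of every collection is padding.
waypoints[0, 0, :, 0] = torch.tensor([[10.0, 0.0], [20.0, 0.0], [30.0, 5.0]])
mask[0, 0, :, 0] = True

# Agent 1: two waypoints in every collection, no padding.
waypoints[0, 1] = torch.tensor([
    [[5.0, 5.0], [6.0, 8.0]],
    [[12.0, 8.0], [15.0, 10.0]],
    [[20.0, 10.0], [22.0, 14.0]],
])
mask[0, 1] = True
```

These two tensors match the documented shapes for WaypointGoal(waypoints, mask).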

_default_mask() → torch.Tensor
_default_state() → torch.Tensor
get_masks()

Returns the waypoint mask according to the current state value.

get_waypoints()

Returns the waypoints according to the current state value.
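Both getters reduce the N dimension to the collection the state currently points at. Assuming the state is the BxAx1 integer index listed in the _update_mask parameters, that selection can be sketched with torch.gather; select_current is a hypothetical helper, not part of torchdrivesim, and the library may implement the lookup differently.

```python
import torch


def select_current(x: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pick the active collection along dim 2.

    x:     BxAxNxM mask or BxAxNxMx2 waypoints
    state: BxAx1 integer index of the active collection per agent
    Returns a BxAxM (or BxAxMx2) tensor.
    """
    trailing = x.shape[3:]  # (M,) for masks, (M, 2) for waypoints
    # Shape the index as BxAx1x1[x1], then broadcast it over the trailing dims.
    idx = state.long().view(*state.shape, *([1] * len(trailing)))
    idx = idx.expand(*state.shape, *trailing)
    return x.gather(2, idx).squeeze(2)


B, A, N, M = 1, 2, 3, 2
waypoints = torch.arange(B * A * N * M * 2, dtype=torch.float32).view(B, A, N, M, 2)
mask = torch.ones(B, A, N, M, dtype=torch.bool)
state = torch.tensor([[[0], [2]]])  # agent 0 on collection 0, agent 1 on collection 2

current_waypoints = select_current(waypoints, state)  # 1x2x2x2
current_mask = select_current(mask, state)            # 1x2x2
```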

copy()

Duplicates this object, allowing for independent subsequent execution.
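Independent execution requires that the duplicate does not share tensor storage with the original. The docstrings call the inputs functors, which suggests a per-agent-type value is either one tensor or a container of tensors. The sketch below assumes that container is a plain dict keyed by agent type name (an assumption, as are the helper name map_per_agent_type and the key "vehicle") and clones every tensor in it.

```python
import torch
from typing import Callable, Dict, Union

# Assumption: a per-agent-type value is either a tensor or a dict keyed by agent type name.
TensorOrDict = Union[torch.Tensor, Dict[str, torch.Tensor]]


def map_per_agent_type(fn: Callable[[torch.Tensor], torch.Tensor], x: TensorOrDict) -> TensorOrDict:
    """Hypothetical helper: apply fn to a tensor, or to each tensor in a dict."""
    if isinstance(x, dict):
        return {key: fn(value) for key, value in x.items()}
    return fn(x)


original = {"vehicle": torch.zeros(1, 2, 3, 2, 2)}  # hypothetical agent type key
duplicate = map_per_agent_type(torch.clone, original)

# Writing into the duplicate leaves the original untouched, because clone copies storage.
duplicate["vehicle"][0, 0, 0, 0] = 1.0
```

Under the same assumption, to(device) could map `lambda t: t.to(device)` over each field and assign the results back to the object, since that method works in place.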

to(device: torch.device)

Modifies the object in-place, putting all tensors on the device provided.

extend(n: int, in_place: bool = True)

Multiplies the first batch dimension by the given number. This is equivalent to introducing an extra batch dimension to the right of the existing one and then flattening the two, so that each batch element is repeated n times consecutively.
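The equivalence stated above can be checked on a plain tensor. This toy example only illustrates the batch-ordering semantics; it does not claim how extend itself is implemented.

```python
import torch

n = 3
x = torch.arange(2 * 4).view(2, 4)  # toy batch: B=2 elements with 4 features each

# Add a new dimension right after the batch dimension, repeat it n times, then flatten.
via_flatten = x.unsqueeze(1).expand(-1, n, *x.shape[1:]).flatten(0, 1)  # (B*n)x4

# Equivalent: each batch element repeated n times consecutively.
via_interleave = x.repeat_interleave(n, dim=0)

# Not equivalent: tiling the whole batch n times yields a different order.
via_tile = x.repeat(n, 1)
```

The ordering matters when the extended batch is later matched with other batched data: element i of the original batch ends up at rows i*n through i*n + n - 1.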

select_batch_elements(idx: torch.Tensor, in_place: bool = True)

Selects the given elements of the batch. The input is a tensor of indices into the batch dimension.
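On a single tensor, selecting batch elements by an index tensor corresponds to integer-tensor indexing along the first dimension. The toy example below only illustrates those indexing semantics; for a WaypointGoal, the same selection would presumably be applied to the waypoints, mask, and state alike, which is an assumption.

```python
import torch

x = torch.arange(4 * 2).view(4, 2)  # toy batch: B=4 elements with 2 features each
idx = torch.tensor([3, 1])          # indices into the batch dimension

selected = x[idx]                   # integer-tensor indexing along dim 0
same = x.index_select(0, idx)       # explicit equivalent
```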

step(agent_states: TensorPerAgentType, time: int = 0, threshold: float = 2.0) → None

Advances the state of the waypoints.

_update_mask(state, mask, new_mask)

Helper function that updates the global mask at the collection selected by the current state, using the new_mask values.

Parameters:
  • state – BxAx1 int tensor

  • mask – BxAxNxM boolean tensor

  • new_mask – BxAxM boolean tensor

Returns:

BxAxNxM boolean tensor updated with the new_mask values
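Going by the documented shapes, the update writes one BxAxM slice into the BxAxNxM mask at the index held in state. The sketch below does that with advanced indexing. update_mask is a hypothetical stand-in for the private helper; it returns a new tensor rather than modifying mask and makes no assumption about what True and False mean in new_mask.

```python
import torch


def update_mask(state: torch.Tensor, mask: torch.Tensor, new_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: write new_mask (BxAxM) into mask (BxAxNxM) at the collection in state (BxAx1)."""
    B, A = state.shape[:2]
    out = mask.clone()  # leave the caller's tensor unchanged
    b = torch.arange(B)[:, None]  # Bx1, broadcasts against a and the state index
    a = torch.arange(A)[None, :]  # 1xA
    # The (b, a, state) index triples broadcast to BxA; each selects one length-M row.
    out[b, a, state[..., 0].long()] = new_mask
    return out


mask = torch.ones(1, 2, 3, 2, dtype=torch.bool)  # B=1, A=2, N=3, M=2
state = torch.tensor([[[0], [1]]])               # agent 0 on collection 0, agent 1 on collection 1
new_mask = torch.tensor([[[False, True], [False, False]]])
updated = update_mask(state, mask, new_mask)
```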

_advance_state(state, agent_overlap)

Helper function to advance the state when agents overlap with waypoints.

Parameters:
  • state – BxAx1 int tensor

  • agent_overlap – BxAxM boolean tensor

Returns:

BxAxM boolean tensor indicating the state of the waypoints
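The docstring does not state the rule used to advance the state, and its Returns entry describes a BxAxM boolean tensor rather than an index. The sketch below assumes one plausible rule instead: an agent moves to its next collection once every waypoint in the current one has been reached, and the last collection stays active after completion. advance_state and num_collections are hypothetical names, and the sketch returns an updated BxAx1 index.

```python
import torch


def advance_state(state: torch.Tensor, reached: torch.Tensor, num_collections: int) -> torch.Tensor:
    """Hypothetical rule, not the library's: step to the next collection once all current waypoints are reached.

    state:   BxAx1 integer index of the active collection
    reached: BxAxM boolean, True where a waypoint is reached (padding assumed already True)
    """
    done = reached.all(dim=-1, keepdim=True)  # BxAx1: every waypoint in the collection reached
    # Assumption: the last collection stays active once completed.
    return torch.clamp(state + done.long(), max=num_collections - 1)


state = torch.tensor([[[0], [2]]])  # B=1, A=2
reached = torch.tensor([[[True, True], [True, True]]])
partial = torch.tensor([[[True, False], [True, True]]])

after_full = advance_state(state, reached, num_collections=3)     # [[[1], [2]]]
after_partial = advance_state(state, partial, num_collections=3)  # [[[0], [2]]]
```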

_agent_waypoint_overlap(agent_states: torch.Tensor, waypoints: torch.Tensor, threshold=2.0) → torch.Tensor

Helper function to detect if agent states overlap with the current waypoints.

Parameters:
  • agent_states – BxAx2 tensor of agent states (x, y)

  • waypoints – BxAxMx2 tensor of waypoints (x, y)

Returns:

BxAxM boolean tensor indicating which agents overlap with the waypoints
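Because a waypoint counts as reached when the agent comes within a distance of its center, the overlap test can be sketched as a Euclidean distance threshold broadcast over the M waypoints. agent_waypoint_overlap is a hypothetical stand-in; whether the library treats the boundary as inclusive (<=) or strict (<) is an assumption.

```python
import torch


def agent_waypoint_overlap(agent_states: torch.Tensor, waypoints: torch.Tensor,
                           threshold: float = 2.0) -> torch.Tensor:
    """BxAx2 agent positions and BxAxMx2 waypoints -> BxAxM boolean overlap."""
    # Insert a waypoint axis so each agent's position broadcasts against its M waypoints.
    diff = waypoints - agent_states[..., None, :]  # BxAxMx2
    dist = torch.sqrt((diff ** 2).sum(dim=-1))     # BxAxM Euclidean distances
    # Assumption: the boundary counts as reached (<=); the library may use < instead.
    return dist <= threshold


agent_states = torch.tensor([[[0.0, 0.0], [10.0, 10.0]]])  # B=1, A=2
waypoints = torch.tensor([[[[1.0, 1.0], [5.0, 0.0]],
                           [[10.0, 11.5], [0.0, 0.0]]]])   # B=1, A=2, M=2
overlap = agent_waypoint_overlap(agent_states, waypoints)  # [[[True, False], [True, False]]]
```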