"""
Definitions of goal conditions. Currently, we support waypoint conditions.
"""
import copy
import dataclasses
from typing import List, Optional, Union, Dict
import torch
import torch.nn.functional as F
from torch import Tensor
# A "functor" of tensors: either one tensor covering all agents,
# or a dictionary mapping each agent type name to its own tensor.
TensorPerAgentType = Union[Tensor, Dict[str, Tensor]]
class WaypointGoal:
    """
    Goal condition expressed as a sequence of waypoint collections that agents should reach.

    Waypoints can be given either per agent type (a dictionary of tensors) or directly as a single tensor.
    The waypoint tensor contains, for each agent, `N` collections of `M` waypoints that progressively get enabled:
    only the collection selected by the current state is active, and the state moves on to the next collection
    once the agent reaches any present waypoint of the active one. The provided mask should indicate which
    waypoints are present and which are padding elements. A waypoint is marked as successful when the agent
    state reaches the center of the waypoint within some distance.

    Args:
        waypoints: a functor of BxAxNxMx2 waypoint tensors [x, y]
        mask: a functor of BxAxNxM boolean tensor indicating whether a given waypoint element is present
            and not a padding. When omitted, every waypoint is considered present.
    """

    def __init__(self, waypoints: "TensorPerAgentType", mask: Optional["TensorPerAgentType"] = None):
        self.waypoints = waypoints
        self.mask = mask if mask is not None else self._default_mask()
        if isinstance(self.waypoints, dict):
            # Largest number of waypoint collections over all agent types (0 for an empty dictionary).
            self.max_goal_idx = max((v.shape[2] for v in self.waypoints.values()), default=0)
        else:
            self.max_goal_idx = self.waypoints.shape[2]
        self.state = self._default_state()  # BxAx1

    @staticmethod
    def _map(fn, value):
        """
        Applies `fn` to a tensor, or to every tensor of a per-agent-type dictionary (returning a new dictionary).
        """
        if isinstance(value, dict):
            return {k: fn(v) for k, v in value.items()}
        return fn(value)

    @staticmethod
    def _select_current(values: Tensor, state: Tensor) -> Tensor:
        """
        Picks, along dimension 2 of `values`, the entry indexed by `state`.

        Args:
            values: BxAxNx... tensor
            state: BxAx1 int tensor
        Returns:
            BxAx... tensor with the `N` dimension removed
        """
        # Append singleton dimensions so that the index has the same rank as `values`.
        index = state[(Ellipsis,) + (None,) * (values.dim() - state.dim())]
        index = index.expand(-1, -1, -1, *values.shape[3:])
        return torch.gather(values, 2, index).squeeze(2)

    def _default_mask(self) -> "TensorPerAgentType":
        """
        Builds an all-True mask (BxAxNxM) matching the waypoints, i.e. no waypoint is treated as padding.
        """
        return self._map(
            lambda w: torch.ones(tuple(w.shape[:-1]), dtype=torch.bool, device=w.device), self.waypoints
        )

    def _default_state(self) -> "TensorPerAgentType":
        """
        Builds the initial state (BxAx1 long tensor of zeros), pointing every agent at its first collection.
        """
        return self._map(
            lambda w: torch.zeros(tuple(w.shape[:2]) + (1,), dtype=torch.long, device=w.device), self.waypoints
        )

    def get_masks(self):
        """
        Returns the waypoint mask (BxAxM) according to the current state value.
        """
        if isinstance(self.waypoints, dict):
            return {k: self._select_current(m, self.state[k]) for k, m in self.mask.items()}
        return self._select_current(self.mask, self.state)

    def get_waypoints(self):
        """
        Returns the waypoints (BxAxMx2) according to the current state value.
        """
        if isinstance(self.waypoints, dict):
            return {k: self._select_current(w, self.state[k]) for k, w in self.waypoints.items()}
        return self._select_current(self.waypoints, self.state)

    def copy(self):
        """
        Duplicates this object, allowing for independent subsequent execution.
        """
        other = self.__class__(
            waypoints=copy.deepcopy(self.waypoints), mask=copy.deepcopy(self.mask),
        )
        other.state = copy.deepcopy(self.state)
        return other

    def to(self, device: torch.device):
        """
        Modifies the object in-place, putting all tensors on the device provided. Returns the object itself.
        """
        move = lambda x: x.to(device)
        self.waypoints = self._map(move, self.waypoints)
        self.mask = self._map(move, self.mask)
        self.state = self._map(move, self.state)
        return self

    def extend(self, n: int, in_place: bool = True):
        """
        Multiplies the first batch dimension by the given number.
        This is equivalent to introducing extra batch dimension on the right and then flattening.
        With `in_place=False` the object is left untouched and an extended copy is returned.
        """
        if not in_place:
            other = self.copy()
            other.extend(n, in_place=True)
            return other

        def enlarge(x: Tensor) -> Tensor:
            # Each batch element gets repeated `n` times consecutively.
            tail = tuple(x.shape[1:])
            return x.unsqueeze(1).expand((x.shape[0], n) + tail).reshape((n * x.shape[0],) + tail)

        self.waypoints = self._map(enlarge, self.waypoints)
        self.mask = self._map(enlarge, self.mask)
        self.state = self._map(enlarge, self.state)
        return self

    def select_batch_elements(self, idx: Tensor, in_place: bool = True):
        """
        Picks selected elements of the batch.
        The input is a tensor of indices into the batch dimension.
        With `in_place=False` the object is left untouched and a reduced copy is returned.
        """
        if not in_place:
            other = self.copy()
            other.select_batch_elements(idx, in_place=True)
            return other
        pick = lambda x: x[idx]
        self.waypoints = self._map(pick, self.waypoints)
        self.mask = self._map(pick, self.mask)
        self.state = self._map(pick, self.state)
        return self

    def step(self, agent_states: "TensorPerAgentType", time: int = 0, threshold: float = 2.0) -> None:
        """
        Advances the state of the waypoints.

        For every agent, if its position is within `threshold` of any present waypoint of the currently active
        collection, the whole collection is masked out and the state moves to the next collection (staying on
        the last one once it is reached).

        Args:
            agent_states: a functor of BxAxS tensors whose first two state elements are x, y; it must be of
                the same kind (tensor or dictionary with the same keys) as the waypoints
            time: currently unused
            threshold: maximum distance between the agent and the waypoint center to count as reached
        """
        if isinstance(agent_states, dict):
            assert isinstance(self.waypoints, dict)
            assert sum([a.shape[1] for a in agent_states.values()]) == sum(w.shape[1] for w in self.waypoints.values())
            assert list(agent_states.keys()) == list(self.waypoints.keys())
            waypoints = self.get_waypoints()
            masks = self.get_masks()
            # Work on fresh dictionaries, so that a mask dictionary supplied by the caller is never modified
            # and a failure for one agent type does not leave the object half-updated.
            new_mask, new_state = dict(self.mask), dict(self.state)
            for agent_type, agent_state in agent_states.items():
                new_mask[agent_type], new_state[agent_type] = self._step_single(
                    agent_state, waypoints[agent_type], masks[agent_type],
                    self.state[agent_type], self.mask[agent_type],
                    # Clamp with this agent type's own number of collections, not the global maximum.
                    self.waypoints[agent_type].shape[2], threshold,
                )
            self.mask, self.state = new_mask, new_state
        else:
            assert torch.is_tensor(self.waypoints)
            assert agent_states.shape[1] == self.waypoints.shape[1]
            self.mask, self.state = self._step_single(
                agent_states, self.get_waypoints(), self.get_masks(),
                self.state, self.mask, self.waypoints.shape[2], threshold,
            )

    def _step_single(self, agent_state, waypoints, current_mask, state, mask, num_goals, threshold):
        """
        Helper function performing one step for a single tensor (all agents or one agent type).

        Args:
            agent_state: BxAxS tensor of agent states, starting with x, y
            waypoints: BxAxMx2 tensor of currently active waypoints
            current_mask: BxAxM boolean tensor of currently active waypoint presence
            state: BxAx1 int tensor
            mask: BxAxNxM boolean tensor
            num_goals: number `N` of waypoint collections available to these agents
            threshold: maximum distance to count a waypoint as reached
        Returns:
            Tuple of the updated BxAxNxM mask and the updated BxAx1 state
        """
        reached = self._agent_waypoint_overlap(agent_state[..., :2], waypoints, threshold=threshold)
        # Reaching any waypoint of the active collection completes the whole collection.
        reached = reached.any(dim=-1, keepdim=True).expand_as(reached)
        # Padding elements always count as reached, so an all-padding collection is skipped over.
        reached = reached.logical_or(current_mask.logical_not())
        updated_mask = self._update_mask(state, mask, reached.logical_not())
        updated_state = self._advance_state(state, reached, max_goal_idx=num_goals)
        return updated_mask, updated_state

    def _update_mask(self, state, mask, new_mask):
        """
        Helper function to update the global masks on the current state with the `new_mask` values.

        Args:
            state: BxAx1 int tensor
            mask: BxAxNxM boolean tensor
            new_mask: BxAxM boolean tensor
        Returns:
            BxAxNxM boolean tensor updated with the `new_mask` values
        """
        index = state[..., None].expand(-1, -1, -1, mask.shape[-1])
        return torch.scatter(mask, 2, index, new_mask.unsqueeze(2))

    def _advance_state(self, state, agent_overlap, max_goal_idx: Optional[int] = None):
        """
        Helper function to advance the state when agents overlap with waypoints.

        Args:
            state: BxAx1 int tensor
            agent_overlap: BxAxM boolean tensor
            max_goal_idx: number of waypoint collections available to these agents; the resulting state
                never exceeds `max_goal_idx - 1`. Defaults to `self.max_goal_idx`.
        Returns:
            BxAx1 int tensor with the new state, incremented for agents that overlap with any waypoint
        """
        if max_goal_idx is None:
            max_goal_idx = self.max_goal_idx
        advanced = state + agent_overlap.any(dim=-1, keepdim=True)
        return torch.clamp(advanced, min=0, max=max(max_goal_idx - 1, 0))

    def _agent_waypoint_overlap(self, agent_states: Tensor, waypoints: Tensor, threshold=2.0) -> Tensor:
        """
        Helper function to detect if agent states overlap with the current waypoints.

        Args:
            agent_states: BxAx2 tensor of agents states - x,y
            waypoints: BxAxMx2 tensor of waypoint
            threshold: maximum Euclidean distance for an agent to count as overlapping a waypoint
        Returns:
            BxAxM boolean tensor indicating which agents overlap with the waypoints
        """
        agent_states = agent_states[..., None, :].expand_as(waypoints)
        dx = agent_states[..., 0] - waypoints[..., 0]
        dy = agent_states[..., 1] - waypoints[..., 1]
        dist = (dx ** 2 + dy ** 2) ** 0.5
        return dist <= threshold