from typing import List, Optional, Dict
import random
import torch
from torch import Tensor
from typing_extensions import Self
from invertedai.common import TrafficLightState, RecurrentState
from torchdrivesim.behavior.common import InitializationFailedError
from torchdrivesim.simulator import NPCWrapper, SimulatorInterface, TensorPerAgentType, HomogeneousWrapper
from torchdrivesim.traffic_lights import TrafficLightController, current_light_state_tensor_from_controller
[docs]def unpack_attributes(attributes) -> torch.Tensor:
return torch.tensor([attributes.length, attributes.width, attributes.rear_axis_offset])
[docs]def iai_initialize(location, agent_count, center=(0, 0), traffic_light_state_history: Optional[List[Dict[int, TrafficLightState]]] = None) -> (Tensor, Tensor, List[RecurrentState]):
import invertedai
try:
seed = random.randint(1, 10000)
response = invertedai.api.initialize(
location=location, agent_count=agent_count, location_of_interest=center,
traffic_light_state_history=traffic_light_state_history,
random_seed=seed,
)
except invertedai.error.InvalidRequestError:
raise InitializationFailedError()
agent_attributes = torch.stack(
[torch.tensor(unpack_attributes(at)) for at in response.agent_attributes], dim=-2
)
agent_states = torch.stack(
[torch.tensor(st.tolist()) for st in response.agent_states], dim=-2
)
return agent_attributes, agent_states, response.recurrent_states
[docs]def iai_drive(location: str, agent_states: Tensor, agent_attributes: Tensor, recurrent_states: List[RecurrentState], traffic_lights_states: Optional[Dict[int, TrafficLightState]] = None) -> (Tensor, List[RecurrentState]):
import invertedai
from invertedai.common import AgentState, AgentAttributes, Point
agent_attributes = [AgentAttributes(length=at[0], width=at[1], rear_axis_offset=at[2]) for at in agent_attributes]
agent_states = [AgentState(center=Point(x=st[0], y=st[1]), orientation=st[2], speed=st[3]) for st in agent_states]
seed = random.randint(1, 10000)
response = invertedai.api.drive(
location=location, agent_states=agent_states, agent_attributes=agent_attributes,
recurrent_states=recurrent_states,
traffic_lights_states=traffic_lights_states,
random_seed=seed
)
agent_states = torch.stack(
[torch.tensor(st.tolist()) for st in response.agent_states], dim=-2
)
return agent_states, response.recurrent_states
[docs]class IAIWrapper(NPCWrapper):
"""
Uses IAI API to control NPCs, making a call every time step.
Requires IAI_API_KEY to be set.
This wrapper should be applied before any other agent-hiding wrappers.
Args:
simulator: existing simulator to wrap
npc_mask: A functor of tensors with a single dimension of size A, indicating which agents to replay.
The tensors can not have batch dimensions.
recurrent_states: A functor of BxAxN tensors with NPC recurrent states, usually obtained from `iai_initialize`.
rear_axis_offset: A BxAx1 tensor, concatenated across agent types, specifying the rear axis offset parameter
used in the kinematic bicycle model. By default, a realistic value based on the vehicle length is used.
"""
def __init__(self, simulator: SimulatorInterface, npc_mask: TensorPerAgentType,
recurrent_states: List[List], locations: List[str], rear_axis_offset: Optional[Tensor] = None,
replay_states: Optional[Tensor] = None, replay_mask: Optional[Tensor] = None,
traffic_light_controller: Optional[TrafficLightController] = None, traffic_light_ids: Optional[List[int]] = None):
super().__init__(simulator=simulator, npc_mask=npc_mask)
[docs] self._locations = locations
[docs] self._npc_predictions = None
[docs] self._recurrent_states = recurrent_states
[docs] self._replay_states = replay_states
[docs] self._replay_mask = replay_mask
[docs] self._replay_timestep = 0
[docs] self._traffic_light_controller = traffic_light_controller
[docs] self._traffic_light_ids = traffic_light_ids
assert (self._traffic_light_controller is None) == (self._traffic_light_ids is None), \
"Both _traffic_light_controller and _traffic_light_ids should be either None or not None"
[docs] lenwid = HomogeneousWrapper(self.inner_simulator).get_agent_size()[..., :2]
if rear_axis_offset is None:
rear_axis_offset = 1.4 * torch.ones_like(lenwid[..., :1]) # TODO: use value proportional to length
[docs] self._agent_attributes = torch.cat([lenwid, rear_axis_offset], dim=-1)
[docs] def _get_npc_predictions(self):
states, recurrent = [], []
agent_states = HomogeneousWrapper(self.inner_simulator).get_state()
for i in range(self.batch_size):
s, r = iai_drive(location=self._locations[i], agent_states=agent_states[i],
agent_attributes=self._agent_attributes[i], recurrent_states=self._recurrent_states[i],
traffic_lights_states=self._traffic_light_controller.current_state_with_name if self._traffic_light_controller is not None else None)
if self._replay_states is not None:
if self._replay_timestep < self._replay_states.shape[2]:
s[self._replay_mask[i, :, self._replay_timestep]] = self._replay_states[i, self._replay_mask[i, :, self._replay_timestep], self._replay_timestep, :]
states.append(s)
recurrent.append(r)
states = torch.stack(states, dim=0).to(self.get_state().device)
return states, recurrent
[docs] def _npc_teleport_to(self):
return self._npc_predictions
[docs] def step(self, action):
self._npc_predictions, self._recurrent_states = self._get_npc_predictions() # TODO: run async after step
if self._traffic_light_controller is not None:
self._traffic_light_controller.tick(0.1)
self.get_innermost_simulator().traffic_controls['traffic_light'].set_state(
current_light_state_tensor_from_controller(self._traffic_light_controller, self._traffic_light_ids).unsqueeze(0).expand(self.batch_size, -1))
super().step(action)
self._npc_predictions = None
self._replay_timestep += 1
[docs] def to(self, device) -> Self:
super().to(device)
[docs] def copy(self):
inner_copy = self.inner_simulator.copy()
other = self.__class__(inner_copy, npc_mask=self.npc_mask, recurrent_states=self._recurrent_states,
locations=self._locations)
return other
[docs] def select_batch_elements(self, idx, in_place=True):
other = super().select_batch_elements(idx, in_place=in_place)
other.replay_states = other.across_agent_types(lambda x: x[idx], other.replay_states)
other.present_masks = other.across_agent_types(lambda x: x[idx], other.present_masks)
other._batch_size = len(idx)
return other