from typing import List, Optional, Dict
import random
import torch
from torch import Tensor
import invertedai as iai
from invertedai.common import AgentState, AgentProperties, Point, TrafficLightState, RecurrentState
from torchdrivesim.behavior.common import InitializationFailedError
from torchdrivesim.simulator import NPCController, Simulator, SpawnController
from torchdrivesim.traffic_lights import TrafficLightController
def unpack_attributes(attributes) -> torch.Tensor:
    """
    Pack the geometric attributes of a single agent into a tensor.

    attributes: any object exposing ``length``, ``width`` and ``rear_axis_offset`` attributes.
    Returns a 1-D tensor ``[length, width, rear_axis_offset]``.
    """
    field_names = ('length', 'width', 'rear_axis_offset')
    values = [getattr(attributes, name) for name in field_names]
    return torch.tensor(values)
def agent_attributes_to_basic_agent_properties(agent_attributes: Tensor) -> Dict[str, float]:
    """
    Convert a per-agent attribute vector into keyword arguments for ``AgentProperties``.

    agent_attributes: indexable sequence (e.g. a 1-D tensor) laid out as ``[length, width, rear_axis_offset]``,
        the layout produced by ``agent_properties_to_agent_attributes``; extra trailing entries are ignored.
    Returns a dict with plain Python floats under ``'length'``, ``'width'`` and ``'rear_axis_offset'``.
    """
    # float() turns 0-d tensor elements into Python floats, matching the annotated return type.
    return {
        'length': float(agent_attributes[0]),
        'width': float(agent_attributes[1]),
        'rear_axis_offset': float(agent_attributes[2]),
    }
def agent_properties_to_agent_attributes(agent_properties: Dict[str, float]) -> Tensor:
    """
    Convert IAI agent properties into an attribute vector ``[length, width, rear_axis_offset]``.

    agent_properties: mapping with at least the keys ``'length'``, ``'width'`` and ``'rear_axis_offset'``;
        any other keys are ignored. A ``None`` value is stored as NaN.
    Returns a 1-D tensor of the default floating point dtype (the same dtype ``torch.Tensor(...)`` produces).
    """
    values = [agent_properties[key] for key in ('length', 'width', 'rear_axis_offset')]
    # A None value cannot be stored in a tensor; NaN is the placeholder that IAINPCController
    # checks for (and replaces with 0) when building DRIVE requests.
    values = [float('nan') if value is None else value for value in values]
    return torch.tensor(values, dtype=torch.get_default_dtype())
def iai_initialize(location, agent_count, center=(0, 0),
                   traffic_light_state_history: Optional[List[Dict[int, TrafficLightState]]] = None
                   ) -> (Tensor, Tensor, List[RecurrentState]):
    """
    Initialize a scene with the Inverted AI INITIALIZE API.

    location: IAI location identifier
    agent_count: number of agents requested from the API
    center: passed to the API as ``location_of_interest``
    traffic_light_state_history: optional traffic light state history passed to the API
    Returns a tuple of:
        - agent attributes stacked along dim -2, one ``[length, width, rear_axis_offset]`` row per agent,
        - agent states stacked along dim -2, one ``AgentState.tolist()`` row per agent,
        - the recurrent states from the response.
    Raises InitializationFailedError (chained to the original error) when the API raises InvalidRequestError.
    """
    import invertedai

    # A fresh random seed on every call.
    seed = random.randint(1, 10000)
    try:
        response = invertedai.api.initialize(
            location=location, agent_count=agent_count, location_of_interest=center,
            traffic_light_state_history=traffic_light_state_history,
            random_seed=seed,
        )
    except invertedai.error.InvalidRequestError as err:
        # Keep the API error as __cause__ so the rejection reason is not lost.
        raise InitializationFailedError() from err

    agent_attributes = torch.stack(
        [agent_properties_to_agent_attributes(dict(ap)) for ap in response.agent_properties], dim=-2
    )
    agent_states = torch.stack(
        [torch.tensor(st.tolist()) for st in response.agent_states], dim=-2
    )
    return agent_attributes, agent_states, response.recurrent_states
def iai_drive(location: str, agent_states: Tensor, agent_attributes: Tensor, recurrent_states: List[RecurrentState],
              traffic_lights_states: Optional[Dict[int, TrafficLightState]] = None) -> (Tensor, List[RecurrentState]):
    """
    Advance agents by one step with the Inverted AI DRIVE API.

    location: IAI location identifier
    agent_states: one row per agent; entries 0..3 are read as x, y, orientation and speed
    agent_attributes: one ``[length, width, rear_axis_offset]`` row per agent
    recurrent_states: recurrent states from the previous INITIALIZE/DRIVE response
    traffic_lights_states: optional current traffic light states
    Returns the predicted agent states (one ``AgentState.tolist()`` row per agent, stacked along dim -2)
    and the new recurrent states.
    """
    import math
    import invertedai
    from invertedai.common import AgentState, AgentProperties, Point

    agent_properties = []
    for attributes in agent_attributes:
        properties = agent_attributes_to_basic_agent_properties(attributes)
        # NaN rear axis offsets (presumably unknown values) are not valid request data;
        # send 0 instead, the same replacement IAINPCController uses.
        if math.isnan(float(properties['rear_axis_offset'])):
            properties['rear_axis_offset'] = 0
        agent_properties.append(AgentProperties(**properties))

    states_payload = [
        AgentState(center=Point(x=st[0], y=st[1]), orientation=st[2], speed=st[3])
        for st in agent_states
    ]
    seed = random.randint(1, 10000)
    response = invertedai.api.drive(
        location=location, agent_states=states_payload, agent_properties=agent_properties,
        recurrent_states=recurrent_states,
        traffic_lights_states=traffic_lights_states,
        random_seed=seed
    )
    predicted_states = torch.stack(
        [torch.tensor(st.tolist()) for st in response.agent_states], dim=-2
    )
    return predicted_states, response.recurrent_states
class IAINPCController(NPCController):
    """NPC controller that advances non-ego agents with the Inverted AI DRIVE API."""

    # Payloads with at least this many agents (ego included) are sent through ``iai.large_drive``.
    _LARGE_DRIVE_THRESHOLD = 100
    # Step passed to ``TrafficLightController.tick`` on every NPC update (presumably seconds).
    _TRAFFIC_LIGHT_DT = 0.1

    def __init__(self, npc_size, npc_state, npc_lr, location, drive_model_version: Optional[str] = None,
                 npc_present_mask: Optional[torch.Tensor] = None, time: int = 0,
                 npc_types: Optional[Tensor] = None, agent_type_names: Optional[List[str]] = None,
                 traffic_light_controller: Optional[TrafficLightController] = None,
                 light_states_all_timesteps: Optional[List[Dict[int, TrafficLightState]]] = None,
                 spawn_controller: Optional[SpawnController] = None):
        """
        Calls IAI API to predict the next states of NPCs. To make the NPCs aware of ego, ego states are also
        passed during the call, but removed from the result afterwards.

        npc_size: BxAx2 tensor with the length and width of agents
        npc_state: BxAx4 tensor with the states of agents
        npc_lr: BxA tensor with the rear axis offsets of agents
        location: IAI location identifier used for every DRIVE call
        drive_model_version: optional ``api_model_version`` for the DRIVE call
        npc_present_mask: BxA boolean tensor with the presence masks of agents
        time: current step index; incremented before each update and used to index ``light_states_all_timesteps``
        npc_types: BxA tensor with the types of agents (indices into ``agent_type_names``)
        agent_type_names: agent type names; defaults to ``['vehicle']``
        traffic_light_controller: ticked on every update when ``light_states_all_timesteps`` is not given
        light_states_all_timesteps: precomputed traffic light states, one entry per step
        spawn_controller: optional controller used by ``spawn_despawn_npcs``
        """
        super().__init__(npc_size, npc_state, npc_present_mask, npc_types, agent_type_names, spawn_controller)
        self.location = location
        self.agent_type_names = agent_type_names
        if self.agent_type_names is None:
            self.agent_type_names = ['vehicle']
        self.npc_attribute = torch.cat([npc_size, npc_lr.unsqueeze(-1)], dim=-1)  # BxAx3: length, width, lr
        self.recurrent_state = None
        self.drive_model_version = drive_model_version
        self.npc_state_shape = self.npc_state.shape
        self._traffic_light_controller = traffic_light_controller
        self.light_states_all_timesteps = light_states_all_timesteps
        # Store the step counter explicitly; it used to be dropped, so copy() restarted from the wrong step.
        self.time = time

    def _iai_agent_type(self, npc_type) -> str:
        """Map an index into ``agent_type_names`` to an IAI agent type: 'pedestrian' is kept, anything else is 'car'."""
        agent_type = self.agent_type_names[npc_type]
        return agent_type if agent_type == "pedestrian" else "car"

    def _traffic_light_states(self):
        """Return the traffic light states to send with the current DRIVE call, or ``None`` if there are none."""
        if self.light_states_all_timesteps is not None:
            return self.light_states_all_timesteps[self.time]
        if self._traffic_light_controller is not None:
            return self._traffic_light_controller.current_state_with_name
        return None

    def _drive_present_npcs(self, ego_state: Tensor, ego_size: Tensor, ego_type: Tensor):
        """
        Query DRIVE for the NPCs selected by ``npc_present_mask[0]``, with the ego agent prepended.

        Only batch element 0 is sent to the API. Returns the predicted NPC states (ego removed) stacked along
        dim -2, and the recurrent states of the response (which include the ego at index 0). When no NPC is
        present the API is not called and the empty selection of ``npc_state`` is returned with ``[]``.
        """
        present = self.npc_present_mask[0]
        agent_states = self.npc_state[:, present, :]
        agent_attributes = self.npc_attribute[:, present, :]
        npc_types = self.npc_types[:, present]
        if agent_states.shape[-2] == 0 or agent_attributes.shape[-2] == 0:
            return agent_states, []

        # The ego goes first in the payload so that NPCs react to it; it is dropped from the result below.
        # No rear axis offset is read from the simulator for the ego, so a quarter of its length is used.
        properties_payload = [AgentProperties(length=ego_size[0, 0, 0],
                                              width=ego_size[0, 0, 1],
                                              rear_axis_offset=ego_size[0, 0, 0].item() / 4,
                                              agent_type=self._iai_agent_type(ego_type[0, 0]))]
        states_payload = [AgentState(center=Point(x=ego_state[0, 0, 0], y=ego_state[0, 0, 1]),
                                     orientation=ego_state[0, 0, 2],
                                     speed=ego_state[0, 0, 3])]
        for i in range(agent_attributes.shape[-2]):
            attrs = agent_attributes[0][i]
            # NaN rear axis offsets (presumably unknown values) are sent as 0.
            rear_axis_offset = attrs[2] if not attrs[2].isnan() else 0
            properties_payload.append(AgentProperties(length=attrs[0],
                                                      width=attrs[1],
                                                      rear_axis_offset=rear_axis_offset,
                                                      agent_type=self._iai_agent_type(npc_types[0][i])))
        for i in range(agent_states.shape[-2]):
            state = agent_states[0][i]
            states_payload.append(AgentState(center=Point(x=state[0], y=state[1]),
                                             orientation=state[2],
                                             speed=state[3]))

        seed = random.randint(1, 10000)
        traffic_light_states = self._traffic_light_states()
        drive = iai.drive if len(states_payload) < self._LARGE_DRIVE_THRESHOLD else iai.large_drive
        response = drive(location=self.location, agent_states=states_payload, agent_properties=properties_payload,
                         recurrent_states=self.recurrent_state, random_seed=seed,
                         api_model_version=self.drive_model_version, traffic_lights_states=traffic_light_states)
        npc_states = torch.stack([torch.tensor(st.tolist()) for st in response.agent_states[1:]], dim=-2)
        return npc_states, response.recurrent_states

    def advance_npcs(self, simulator: Simulator) -> None:
        """
        Advance all present NPCs by one step with the IAI DRIVE API, then spawn/despawn NPCs.

        The first agent of ``simulator`` is treated as the ego. Absent NPCs are given all-zero states.
        """
        ego_state = simulator.get_state()
        ego_size = simulator.get_agent_size()
        ego_type = simulator.get_agent_type()
        self.time += 1
        if self.light_states_all_timesteps is None and self._traffic_light_controller is not None:
            # Tick on every step (not only when an API call is made) so that the lights keep advancing
            # in step with the simulation even while no NPC is present.
            self._traffic_light_controller.tick(self._TRAFFIC_LIGHT_DT)

        predicted_npc_state, predicted_npc_recurrent_state = self._drive_present_npcs(ego_state, ego_size, ego_type)
        device = self.npc_state.device
        present_npc_state = predicted_npc_state.clone().detach().unsqueeze(0).to(device)
        reconstructed_state = torch.zeros(self.npc_state_shape, device=device)
        reconstructed_state[:, self.npc_present_mask[0], :] = present_npc_state
        self.npc_state = reconstructed_state
        self.recurrent_state = predicted_npc_recurrent_state
        self.spawn_despawn_npcs(simulator)

    def copy(self):
        """
        Create a copy of the controller.

        Tensors and the spawn controller are copied, while the traffic light controller, the light state schedule
        and the recurrent states are shared with the original by reference.
        """
        other = self.__class__(
            npc_size=self.npc_size.clone(),
            npc_state=self.npc_state.clone(),
            npc_lr=self.npc_attribute[..., 2].clone(),
            location=self.location,
            drive_model_version=self.drive_model_version,
            npc_present_mask=self.npc_present_mask.clone() if self.npc_present_mask is not None else None,
            time=self.time,
            npc_types=self.npc_types.clone() if self.npc_types is not None else None,
            agent_type_names=self.agent_type_names.copy() if self.agent_type_names is not None else None,
            traffic_light_controller=self._traffic_light_controller,
            light_states_all_timesteps=self.light_states_all_timesteps,
            spawn_controller=self.spawn_controller.copy() if self.spawn_controller is not None else None
        )
        other.npc_attribute = self.npc_attribute.clone()
        other.recurrent_state = self.recurrent_state
        other.npc_state_shape = self.npc_state_shape
        return other

    def to(self, device):
        """Move the controller's tensors (and the spawn controller, if any) to ``device``; returns ``self``."""
        super().to(device)
        self.npc_size = self.npc_size.to(device)
        self.npc_state = self.npc_state.to(device)
        # npc_attribute is boolean-indexed with npc_present_mask, so both must live on the same device.
        self.npc_attribute = self.npc_attribute.to(device)
        if self.npc_present_mask is not None:
            self.npc_present_mask = self.npc_present_mask.to(device)
        if self.npc_types is not None:
            self.npc_types = self.npc_types.to(device)
        if self.spawn_controller is not None:
            self.spawn_controller.to(device)
        return self