Source code for torchdrivesim.behavior.replay

import os
from typing import Optional, List

import numpy as np
import pandas as pd
import torch
from torch import Tensor

from torchdrivesim.behavior.common import InitializationFailedError
from torchdrivesim.simulator import NPCController, Simulator, SpawnController


[docs]def interaction_replay(location, dataset_path, initial_frame=1, segment_length=40, recording=0): recording_path = os.path.join(dataset_path, 'recorded_trackfiles', location, 'vehicle_tracks_{:03d}.csv'.format(recording)) df = pd.read_csv(recording_path) final_frame = initial_frame + segment_length - 1 for frame in [initial_frame, final_frame]: if frame not in df.frame_id.unique(): raise InitializationFailedError(f'Frame {frame} not available in {recording_path}') df = df[(df.frame_id >= initial_frame) & (df.frame_id <= final_frame)].copy() df = df.sort_values(['track_id', 'frame_id']) df['rear_offset'] = 1.4 agent_ids = sorted(df.track_id.unique()) agent_attributes = [] for agent_id in agent_ids: attr = df[df.track_id == agent_id][['length', 'width', 'rear_offset']].to_numpy() attr = torch.from_numpy(attr).mean(dim=-2) agent_attributes.append(attr) agent_attributes = torch.stack(agent_attributes, dim=-2).unsqueeze(0) df['present'] = True df['speed'] = np.sqrt(df.vx ** 2 + df.vy ** 2) frame_ids = sorted(df.frame_id.unique()) dense_index = pd.MultiIndex.from_product([agent_ids, frame_ids], names=["track_id", "frame_id"]) padding = pd.DataFrame(index=dense_index, data=dict(x=0.0, y=0.0, psi_rad=0.0, speed=0.0, present=False)) daug = df.set_index(['track_id', 'frame_id']).reindex(dense_index).combine_first(padding) agent_states = torch.from_numpy(daug[['x', 'y', 'psi_rad', 'speed']].to_numpy()).unsqueeze(0) agent_states = agent_states.reshape(1, len(agent_ids), len(frame_ids), 4) present_mask = torch.from_numpy(daug['present'].astype(bool).to_numpy()).unsqueeze(0) present_mask = present_mask.reshape(1, len(agent_ids), len(frame_ids)) return agent_attributes, agent_states, present_mask
[docs]class ReplayController(NPCController): def __init__(self, npc_size, npc_states, npc_present_masks: Optional[torch.Tensor] = None, time: int = 0, npc_types: Optional[Tensor] = None, agent_type_names: Optional[List[str]] = None, spawn_controller: Optional[SpawnController] = None):
[docs] self.time = time
[docs] self.npc_states = npc_states
[docs] self.npc_present_masks = npc_present_masks
if self.npc_present_masks is None: self.npc_present_masks = torch.ones_like(self.npc_states[..., 0], dtype=torch.bool) super().__init__(npc_size, self.npc_states[..., self.time, :], self.npc_present_masks[..., self.time], npc_types, agent_type_names, spawn_controller)
[docs] def advance_npcs(self, simulator: Simulator) -> None: self.time += 1 if self.time == self.npc_states.shape[-2]: self.time = 0 self.npc_state = self.npc_states[..., self.time, :] self.npc_present_mask = self.npc_present_masks[..., self.time] self.spawn_despawn_npcs(simulator)
[docs] def to(self, device): self.npc_size = self.npc_size.to(device) self.npc_state = self.npc_state.to(device) self.npc_present_mask = self.npc_present_mask.to(device) self.npc_types = self.npc_types.to(device) self.npc_states = self.npc_states.to(device) self.npc_present_masks = self.npc_present_masks.to(device) self.spawn_controller.to(device) return self
[docs] def copy(self): obj = self.__class__(self.npc_size, self.npc_states, self.npc_present_masks, self.time, self.npc_types, self.agent_type_names, self.spawn_controller.copy()) obj.npc_state = self.npc_state.clone() obj.npc_present_mask = self.npc_present_mask.clone() return obj
[docs] def extend(self, n, in_place=True): if not in_place: other = self.copy() other.extend(n, in_place=True) return other enlarge = lambda x: x.unsqueeze(1).expand((x.shape[0], n) + x.shape[1:]).reshape((n * x.shape[0],) + x.shape[1:]) self.npc_size = enlarge(self.npc_size) self.npc_state = enlarge(self.npc_state) self.npc_present_mask = enlarge(self.npc_present_mask) self.npc_types = enlarge(self.npc_types) self.npc_states = enlarge(self.npc_states) self.npc_present_masks = enlarge(self.npc_present_masks) self.spawn_controller.extend(n, in_place=True) return self
[docs] def select_batch_elements(self, idx, in_place=True): if not in_place: return self.copy().select_batch_elements(idx, in_place=True) self.npc_size = self.npc_size[idx] self.npc_state = self.npc_state[idx] self.npc_present_mask = self.npc_present_mask[idx] self.npc_types = self.npc_types[idx] self.npc_states = self.npc_states[idx] self.npc_present_masks = self.npc_present_masks[idx] self.spawn_controller.select_batch_elements(idx, in_place=True) return self