Source code for smart_traffic_env.env

import numpy as np
import gymnasium as gym
from gymnasium import spaces
from typing import Tuple, Optional

from .demand import DemandGenerator
from .flow_model import (
    compute_reward,
    compute_service_rate,
    G_MAX,
    G_MIN,
    DEFAULT_ALL_RED,
    DEFAULT_SAT_FLOW,
)


[docs] class UrbanTrafficEnv(gym.Env): """Gym-style env for urban traffic signal control.""" def __init__( self, num_intersections: int = 4, lanes_per_intersection: int = 2, base_green: float = 30.0, delta_max: float = 5.0, control_interval: float = 60.0, episode_length: int = 60, demand_profile: Optional[np.ndarray] = None, seed: int = None, ) -> None: super(UrbanTrafficEnv, self).__init__() self.num_intersections = num_intersections self.lanes_per_intersection = lanes_per_intersection self.base_green = base_green self.delta_max = delta_max self.control_interval = control_interval self.episode_length = episode_length self.demand_profile = demand_profile self.seed = seed self.rng = np.random.RandomState(seed) self.num_lanes = self.num_intersections * self.lanes_per_intersection self.demand_gen = DemandGenerator( num_steps=episode_length, num_lanes=self.num_lanes, rng=self.rng ) N = self.num_intersections M = self.num_lanes # Define constants for observation space bounds MAX_QUEUE_VEHICLES = 5000.0 MAX_DEMAND_VEHICLES_PER_HOUR = 500.0 obs_low = np.concatenate([ np.zeros(M, dtype=np.float32), # Queues np.zeros(M, dtype=np.float32), # Demand np.full(N, G_MIN, dtype=np.float32) # Greens ]) obs_high = np.concatenate([ np.full(M, MAX_QUEUE_VEHICLES, dtype=np.float32), np.full(M, MAX_DEMAND_VEHICLES_PER_HOUR, dtype=np.float32), np.full(N, G_MAX, dtype=np.float32) ]) self.observation_space = spaces.Box( low=obs_low, high=obs_high, dtype=np.float32 ) self.action_space = spaces.Box( low=-self.delta_max, high=self.delta_max, shape=(N,), dtype=np.float32 ) # Env state self.step_count = 0 self.queues = np.zeros(M, dtype=np.float32) self.greens = np.full(N, self.base_green, dtype=np.float32) self.demand_trajectory = None # For diagnostics self.current_reward = 0 self.current_info = {}
[docs] def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, dict]: """Executes one time step within the environment.""" # 1. Clip action to be within [-delta_max, +delta_max] clipped_action = np.clip(action, -self.delta_max, self.delta_max) # 2. Update green times self.greens = np.clip( self.greens + clipped_action, G_MIN, G_MAX ).astype(np.float32) # 3. Fetch current demand arrival = self.demand_trajectory[self.step_count] # 4. Compute service rates service = compute_service_rate( greens=self.greens, num_lanes=self.lanes_per_intersection, all_red_time=DEFAULT_ALL_RED, saturation_flow=DEFAULT_SAT_FLOW, ) # 5. Update queues arrival_rate_per_sec = arrival / 3600.0 self.queues = np.maximum( 0, self.queues + (arrival_rate_per_sec * self.control_interval) - (service * self.control_interval), ).astype(np.float32) # 6. Compute reward reward = compute_reward(self.queues) self.current_reward = reward # 7. Increment step count and check for termination self.step_count += 1 done = self.step_count >= self.episode_length # 8. Assemble next observation and info obs = self._get_obs() info = self._get_info() self.current_info = info return obs, reward, done, False, info # Gymnasium expects (obs, rew, terminated, truncated, info)
[docs] def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple[np.ndarray, dict]: # Reset the state of the environment to an initial state super().reset(seed=seed) if seed is not None: self.rng = np.random.RandomState(seed) self.demand_gen.rng = self.rng self.step_count = 0 self.queues = np.zeros(self.num_lanes, dtype=np.float32) self.greens = np.full(self.num_intersections, self.base_green, dtype=np.float32) if self.demand_profile is not None: self.demand_trajectory = self.demand_profile else: self.demand_trajectory = self.demand_gen.generate() # Initial observation obs = self._get_obs() info = self._get_info() return obs, info
[docs] def render(self, mode: str = "human"): """Prints a one-line summary of the current state.""" if mode == "human": avg_queue = np.mean(self.queues) print( f"Step: {self.step_count}, " f"Avg Queues: {avg_queue:.2f}, " f"Greens: {np.array2string(self.greens, precision=1)}" ) else: super().render(mode=mode) # Let gym handle other modes
[docs] def close(self): """Clean up any resources.""" pass
def _get_obs(self) -> np.ndarray: """Constructs the observation vector.""" # Get next demand slice if self.step_count < self.episode_length: next_demand = self.demand_trajectory[self.step_count] else: next_demand = np.zeros(self.num_lanes) # Concatenate [queues, next_demand, greens] obs = np.concatenate([self.queues, next_demand, self.greens]).astype(np.float32) return obs def _get_info(self) -> dict: """Constructs the info dictionary.""" return {"queues": self.queues, "greens": self.greens}