"""
Spatial Graph System for hierarchical location management and perception propagation.

Implements:
- Hierarchical location graph (world → region → location → POI)
- Perception bubble-up algorithm with optional, pluggable (e.g. LLM-powered) filtering
- Portal-based information propagation (vision/sound)
- Spatial metadata for entity positioning within locations
"""

import json
import logging
from dataclasses import dataclass, field, asdict, replace
from typing import Optional, Dict, List, Tuple, Any, Callable
from enum import Enum

logger = logging.getLogger(__name__)


class PropagationType(Enum):
    """Types of information that propagate through space."""

    VISION = "vision"
    SOUND = "sound"
    ACTION = "action"


@dataclass
class PortalConnection:
    """A connection from one spatial node to another, with propagation properties."""

    target: str  # Target location_id
    portal: Optional[str] = None  # Portal type (door, window, peephole, etc.)
    # Current state of the portal (open, closed, tinted, etc.)
    portal_state_descriptor: Optional[str] = None
    vision_prop: int = 5  # 0-10 scale: 0=no vision, 10=clear vision
    sound_prop: int = 5  # 0-10 scale: 0=no sound, 10=all sound propagates
    # When True, SpatialGraph.get_connected_locations also lets `target` reach
    # the node that declares this connection (the connection is symmetric).
    bidirectional: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation of every field."""
        return asdict(self)


@dataclass
class PerceptionInfo:
    """Filtered/transformed perception of an action for one recipient entity."""

    recipient_id: str
    original_action: str
    transformed_action: Optional[str]  # None if blocked
    # How the perception arrived: "immediate", "portal" or "escalated_from_sibling"
    propagation_path: str
    vision_clarity: int  # 0-10
    sound_clarity: int  # 0-10
    perceivable: bool  # Whether the entity perceives anything


@dataclass
class SpatialNode:
    """A node in the spatial hierarchy (world, region, location, POI)."""

    id: str
    name: str
    description: str
    node_type: str  # "world", "region", "location", "poi"
    parent_id: Optional[str] = None
    children: List[str] = field(default_factory=list)
    connections: List[PortalConnection] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation.

        `asdict` recurses into nested dataclasses, so `connections` is already
        a list of plain dicts (identical to `PortalConnection.to_dict`).
        """
        return asdict(self)


@dataclass
class EntityPosition:
    """Spatial metadata for an entity's position."""

    entity_id: str
    location_id: str  # Leaf-level location (POI or location)
    spatial_descriptor: str  # e.g. "Leaning against the far left edge, near the kegs"


class SpatialGraph:
    """
    Manages hierarchical spatial structure and perception propagation.

    Hierarchy: World → Region → Location → POI (Point of Interest)

    Entities have:
    - location_id: which node they are in (must be a "location" or "poi" node)
    - spatial_descriptor: free-text position within that node
    """

    def __init__(self) -> None:
        self.nodes: Dict[str, SpatialNode] = {}
        self.entity_positions: Dict[str, EntityPosition] = {}
        # location_id → path to root. NOTE(review): not populated by any method
        # in this module; kept for compatibility with external users.
        self.node_hierarchy: Dict[str, List[str]] = {}

    # ------------------------------------------------------------------ loading

    def load_from_json(self, data: Dict[str, Any]) -> None:
        """Load the graph from a nested JSON structure.

        `data["world"]` is the root node dict. Each node dict has an "id" and
        optional "name", "description" and "connections". Its children are
        listed under "regions" (world), "locations" (region), "pois"
        (location) or "items" (poi).
        """
        logger.info("Loading spatial graph from JSON")
        self._build_hierarchy(data["world"], None)
        logger.info("Spatial graph loaded: %d nodes", len(self.nodes))

    def _build_hierarchy(
        self, node_data: Dict[str, Any], parent_id: Optional[str]
    ) -> None:
        """Recursively create `SpatialNode`s for `node_data` and its children."""
        node_id = node_data.get("id")
        node_type = self._determine_node_type(parent_id)

        connections = [
            PortalConnection(
                target=conn_data["target"],
                portal=conn_data.get("portal"),
                portal_state_descriptor=conn_data.get("portal_state_descriptor"),
                vision_prop=conn_data.get("vision_prop", 5),
                sound_prop=conn_data.get("sound_prop", 5),
                bidirectional=conn_data.get("bidirectional", True),
            )
            for conn_data in node_data.get("connections", [])
        ]

        node = SpatialNode(
            id=node_id,
            name=node_data.get("name", node_id),
            description=node_data.get("description", ""),
            node_type=node_type,
            parent_id=parent_id,
            connections=connections,
        )
        # Register the node before recursing: children look up their parent's
        # node_type in _determine_node_type.
        self.nodes[node_id] = node

        children_key = self._get_children_key(node_type)
        for child_data in node_data.get(children_key, []):
            node.children.append(child_data.get("id"))
            self._build_hierarchy(child_data, node_id)

    def _determine_node_type(self, parent_id: Optional[str]) -> str:
        """Derive a node's type from its parent's type.

        No parent gives "world". A world parent gives "region", a region parent
        gives "location", and anything deeper is "poi".
        """
        if parent_id is None:
            return "world"
        parent = self.nodes.get(parent_id)
        if parent.node_type == "world":
            return "region"
        elif parent.node_type == "region":
            return "location"
        else:
            return "poi"

    def _get_children_key(self, node_type: str) -> str:
        """Return the JSON key holding the children of a node of `node_type`."""
        mapping = {
            "world": "regions",
            "region": "locations",
            "location": "pois",
            "poi": "items",
        }
        return mapping.get(node_type, "children")

    # ---------------------------------------------------------------- entities

    def set_entity_position(
        self, entity_id: str, location_id: str, spatial_descriptor: str = ""
    ) -> None:
        """Place an entity at a "location" or "poi" node.

        Unknown nodes and non-leaf types are logged as warnings and ignored.
        """
        if location_id not in self.nodes:
            logger.warning("Unknown location_id: %s", location_id)
            return

        node = self.nodes[location_id]
        if node.node_type not in ("location", "poi"):
            logger.warning(
                "Entity position must be at leaf level (location or poi), got %s",
                node.node_type,
            )
            return

        self.entity_positions[entity_id] = EntityPosition(
            entity_id=entity_id,
            location_id=location_id,
            spatial_descriptor=spatial_descriptor,
        )
        logger.info(
            "Entity %s positioned at %s: %s", entity_id, location_id, spatial_descriptor
        )

    def get_entity_position(self, entity_id: str) -> Optional[EntityPosition]:
        """Return the entity's position, or None if it has not been placed."""
        return self.entity_positions.get(entity_id)

    def get_entities_in_location(self, location_id: str) -> List[str]:
        """Return ids of entities positioned exactly at `location_id`."""
        return [
            entity_id
            for entity_id, pos in self.entity_positions.items()
            if pos.location_id == location_id
        ]

    # --------------------------------------------------------------- traversal

    def get_path_to_root(self, node_id: str) -> List[str]:
        """Return [node_id, parent, ..., root].

        Stops at an unknown node, and stops (instead of looping forever) if the
        parent links form a cycle.
        """
        path: List[str] = []
        seen = set()
        current = node_id
        while current and current not in seen:
            seen.add(current)
            path.append(current)
            node = self.nodes.get(current)
            if not node:
                break
            current = node.parent_id
        return path

    def get_immediate_children(self, node_id: str) -> List[str]:
        """Return the direct children of a node ([] if the node is unknown)."""
        node = self.nodes.get(node_id)
        return node.children if node else []

    def get_all_descendants(self, node_id: str) -> List[str]:
        """Return all descendants in depth-first pre-order."""
        descendants: List[str] = []
        node = self.nodes.get(node_id)
        if not node:
            return descendants
        for child_id in node.children:
            descendants.append(child_id)
            descendants.extend(self.get_all_descendants(child_id))
        return descendants

    def get_sibling_locations(self, node_id: str) -> List[str]:
        """Return the other children of this node's parent."""
        node = self.nodes.get(node_id)
        if not node or not node.parent_id:
            return []
        parent = self.nodes.get(node.parent_id)
        if not parent:
            return []
        return [child for child in parent.children if child != node_id]

    def get_connected_locations(
        self, node_id: str
    ) -> List[Tuple[str, PortalConnection]]:
        """
        Get all locations reachable from `node_id` through a portal.

        Includes the node's own connections whose target exists in the graph.
        Also includes every other node that declares a *bidirectional*
        connection to `node_id`. For those reverse connections a mirrored copy
        is returned, with `target` set to the other node, so the tuple's first
        element always equals `conn.target`. A node's own connection to a
        target takes precedence over that target's reverse connection.

        Returns:
            List of (location_id, PortalConnection)
        """
        node = self.nodes.get(node_id)
        if not node:
            return []

        connections: List[Tuple[str, PortalConnection]] = [
            (conn.target, conn)
            for conn in node.connections
            if conn.target in self.nodes
        ]
        seen_targets = {target for target, _ in connections}

        # Honour `bidirectional` on connections declared by the other endpoint.
        for other_id, other in self.nodes.items():
            if other_id == node_id or other_id in seen_targets:
                continue
            for conn in other.connections:
                if conn.target == node_id and conn.bidirectional:
                    connections.append((other_id, replace(conn, target=other_id)))
                    seen_targets.add(other_id)
                    break
        return connections

    # -------------------------------------------------------------- broadcasts

    def broadcast_to_immediate_location(self, location_id: str) -> List[str]:
        """
        Step 1: Immediate Leaf Check
        Return all entities in the same leaf location (no loss).
        """
        return self.get_entities_in_location(location_id)

    def broadcast_to_siblings_with_portals(
        self,
        location_id: str,
        action: str,
        llm_filter: Optional[Callable[..., Optional[str]]] = None,
    ) -> Dict[str, PerceptionInfo]:
        """
        Step 2: Sibling Check (Horizontal Propagation)
        Propagate perception to portal-connected locations.

        llm_filter: optional callable invoked as
            llm_filter(action=..., source_id=..., target_id=..., portal_conn=...).
            It returns the transformed action, or None to block it. If it
            raises, the error is logged and the untransformed action is used.

        Returns:
            Dict mapping entity_id → PerceptionInfo
        """
        perceptions: Dict[str, PerceptionInfo] = {}

        for target_id, portal_conn in self.get_connected_locations(location_id):
            target_entities = self.get_entities_in_location(target_id)

            transformed_action: Optional[str] = action
            if llm_filter:
                try:
                    transformed_action = llm_filter(
                        action=action,
                        source_id=location_id,
                        target_id=target_id,
                        portal_conn=portal_conn,
                    )
                except Exception as e:
                    logger.error(f"LLM filter failed: {e}")
                    transformed_action = action

            for entity_id in target_entities:
                perceptions[entity_id] = PerceptionInfo(
                    recipient_id=entity_id,
                    original_action=action,
                    transformed_action=transformed_action,
                    propagation_path="portal",
                    vision_clarity=portal_conn.vision_prop,
                    sound_clarity=portal_conn.sound_prop,
                    perceivable=transformed_action is not None,
                )

        logger.info("Sibling broadcast: %d entities reached", len(perceptions))
        return perceptions

    def broadcast_to_parent_children(
        self,
        location_id: str,
        action: str,
        escalation_check: Optional[Callable[[str, str, str], bool]] = None,
    ) -> Dict[str, PerceptionInfo]:
        """
        Step 3: Parent Check (Vertical Propagation)
        Broadcast to entities in the parent's other children if the action
        warrants escalation.

        escalation_check: optional callable (action, location_id, parent_id) -> bool.
            Without one, the action always escalates. If it raises, the error
            is logged and the action escalates.

        Returns:
            Dict mapping entity_id → PerceptionInfo. Escalated perceptions
            have fixed vision 3 / sound 4 clarity.
        """
        perceptions: Dict[str, PerceptionInfo] = {}

        node = self.nodes.get(location_id)
        if not node or not node.parent_id:
            return perceptions

        parent = self.nodes.get(node.parent_id)
        if not parent:
            return perceptions

        should_escalate = True
        if escalation_check:
            try:
                should_escalate = escalation_check(action, location_id, parent.id)
            except Exception as e:
                logger.error(f"Escalation check failed: {e}")

        if not should_escalate:
            logger.info("Action did not warrant escalation to parent")
            return perceptions

        for sibling_id in parent.children:
            if sibling_id == location_id:
                continue
            for entity_id in self.get_entities_in_location(sibling_id):
                perceptions[entity_id] = PerceptionInfo(
                    recipient_id=entity_id,
                    original_action=action,
                    transformed_action=f"[From {location_id}] {action}",
                    propagation_path="escalated_from_sibling",
                    vision_clarity=3,
                    sound_clarity=4,
                    perceivable=True,
                )

        logger.info("Parent escalation: %d entities reached", len(perceptions))
        return perceptions

    @staticmethod
    def _merge_perceptions(
        target: Dict[str, PerceptionInfo],
        incoming: Dict[str, PerceptionInfo],
        exclude_id: Optional[str] = None,
    ) -> None:
        """Merge `incoming` into `target` without downgrading earlier results.

        An existing entry is replaced only if it was not perceivable and the
        incoming one is. This keeps the more direct channel (immediate before
        portal before escalation) when both are perceivable. `exclude_id`
        (the actor) is never added.
        """
        for entity_id, info in incoming.items():
            if entity_id == exclude_id:
                continue
            existing = target.get(entity_id)
            if existing is None or (not existing.perceivable and info.perceivable):
                target[entity_id] = info

    def bubble_up_broadcast(
        self,
        location_id: str,
        action: str,
        actor_id: str,
        llm_filter: Optional[Callable[..., Optional[str]]] = None,
        escalation_check: Optional[Callable[[str, str, str], bool]] = None,
    ) -> Dict[str, PerceptionInfo]:
        """
        Full bubble-up algorithm:
        1. Immediate leaf check (all entities in the same location, full clarity)
        2. Sibling check (portal propagation)
        3. Parent check (escalation if warranted)

        The actor never perceives its own action. An entity reached by
        several steps keeps the earliest perceivable perception. A later step
        only replaces an earlier perception that was blocked (not perceivable).

        Returns:
            Dict mapping every perceiving entity_id → PerceptionInfo
        """
        logger.info("Bubble-up broadcast from %s: '%s'", location_id, action)

        all_perceptions: Dict[str, PerceptionInfo] = {}

        # Step 1: Immediate location (full perception)
        for entity_id in self.broadcast_to_immediate_location(location_id):
            if entity_id == actor_id:  # Don't perceive own action
                continue
            all_perceptions[entity_id] = PerceptionInfo(
                recipient_id=entity_id,
                original_action=action,
                transformed_action=action,
                propagation_path="immediate",
                vision_clarity=10,
                sound_clarity=10,
                perceivable=True,
            )

        # Step 2: Sibling propagation through portals
        self._merge_perceptions(
            all_perceptions,
            self.broadcast_to_siblings_with_portals(location_id, action, llm_filter),
            actor_id,
        )

        # Step 3: Parent escalation (must not clobber clearer portal perceptions)
        self._merge_perceptions(
            all_perceptions,
            self.broadcast_to_parent_children(location_id, action, escalation_check),
            actor_id,
        )

        logger.info("Total entities perceiving: %d", len(all_perceptions))
        return all_perceptions

    # ----------------------------------------------------------- serialisation

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the graph (nodes and entity positions) to plain dicts."""
        return {
            "nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
            "entity_positions": {
                entity_id: asdict(pos)
                for entity_id, pos in self.entity_positions.items()
            },
        }

    def save_to_json(self, filepath: str) -> None:
        """Save the graph to a JSON file (best-effort: failures are logged)."""
        try:
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(self.to_dict(), f, indent=2)
            logger.info("Spatial graph saved to %s", filepath)
        except Exception as e:
            logger.exception(f"Failed to save spatial graph: {e}")

    @classmethod
    def load_from_json_file(cls, filepath: str) -> "SpatialGraph":
        """Load a graph previously written by `save_to_json`.

        Best-effort: on any error the failure is logged (with traceback) and
        an empty graph is returned.
        """
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)

            graph = cls()

            for node_id, node_data in data.get("nodes", {}).items():
                connections = [
                    PortalConnection(**conn)
                    for conn in node_data.get("connections", [])
                ]
                graph.nodes[node_id] = SpatialNode(
                    id=node_data["id"],
                    name=node_data["name"],
                    description=node_data["description"],
                    node_type=node_data["node_type"],
                    parent_id=node_data.get("parent_id"),
                    children=node_data.get("children", []),
                    connections=connections,
                )

            for entity_id, pos_data in data.get("entity_positions", {}).items():
                graph.entity_positions[entity_id] = EntityPosition(**pos_data)

            logger.info("Spatial graph loaded from %s", filepath)
            return graph
        except Exception as e:
            logger.exception(f"Failed to load spatial graph: {e}")
            return cls()