470 lines
16 KiB
Python
470 lines
16 KiB
Python
"""
Spatial Graph System for hierarchical location management and perception propagation.

Implements:
- Hierarchical location graph (world → region → location → POI)
- Perception bubble-up algorithm with optional LLM-powered filtering
- Portal-based information propagation (vision/sound)
- Spatial metadata for entity positioning within locations
"""
|
|
|
|
import json
|
|
import logging
|
|
from dataclasses import dataclass, field, asdict
|
|
from typing import Optional, Dict, List, Tuple, Any
|
|
from enum import Enum
|
|
|
|
# Module-level logger; every diagnostic emitted by this module goes through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PropagationType(Enum):
    """Kinds of information that can travel between spatial nodes.

    Each member's value is its lowercase name, so serialized data can be
    mapped back with ``PropagationType("sound")``.
    """

    # Sight-based information.
    VISION = "vision"
    # Audible information.
    SOUND = "sound"
    # An action as a whole.
    ACTION = "action"
|
|
|
|
|
|
@dataclass
class PortalConnection:
    """Edge from one spatial node to another, carrying propagation strengths.

    ``vision_prop`` and ``sound_prop`` are on a 0-10 scale: 0 means the
    channel is fully blocked, 10 means it passes through unimpeded.
    """

    target: str  # location_id of the node on the other side
    portal: Optional[str] = None  # kind of opening: door, window, peephole, ...
    portal_state_descriptor: Optional[str] = None  # e.g. open, closed, tinted
    vision_prop: int = 5  # 0 = no vision, 10 = clear vision
    sound_prop: int = 5  # 0 = no sound, 10 = all sound propagates
    bidirectional: bool = True  # whether the edge is symmetric

    def to_dict(self) -> Dict[str, Any]:
        """Return every field of the connection as a plain dict."""
        serialized = asdict(self)
        return serialized
|
|
|
|
|
|
@dataclass
class PerceptionInfo:
    """What a single recipient perceives of an action after spatial filtering."""

    recipient_id: str  # entity receiving the perception
    original_action: str  # action exactly as the actor performed it
    transformed_action: Optional[str]  # filtered/rewritten text; None when blocked
    propagation_path: str  # how it arrived, e.g. "direct", "through_wall"
    vision_clarity: int  # 0-10
    sound_clarity: int  # 0-10
    perceivable: bool  # whether the recipient perceives anything at all
|
|
|
|
|
|
@dataclass
class SpatialNode:
    """One node of the spatial tree: a world, region, location or POI."""

    id: str
    name: str
    description: str
    node_type: str  # "world", "region", "location" or "poi"
    parent_id: Optional[str] = None  # None for the root node
    children: List[str] = field(default_factory=list)  # ids of child nodes
    connections: List[PortalConnection] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the node, letting each connection serialize itself."""
        serialized_connections = [portal.to_dict() for portal in self.connections]
        result = asdict(self)
        result["connections"] = serialized_connections
        return result
|
|
|
|
|
|
@dataclass
class EntityPosition:
    """Where an entity stands: a leaf node plus free-text placement within it."""

    entity_id: str
    location_id: str  # leaf-level node id (location or POI)
    spatial_descriptor: str  # e.g. "Leaning against the far left edge, near the kegs"
|
|
|
|
|
|
class SpatialGraph:
    """
    Manages hierarchical spatial structure and perception propagation.

    Hierarchy: World → Region → Location → POI (Point of Interest)

    Entities have:
    - location_id: Which leaf node they're in
    - spatial_descriptor: Precise position within that node
    """

    def __init__(self):
        self.nodes: Dict[str, SpatialNode] = {}
        self.entity_positions: Dict[str, EntityPosition] = {}
        # location_id → path to root.
        # NOTE(review): nothing in this class populates this map;
        # get_path_to_root() computes paths on demand instead.
        self.node_hierarchy: Dict[str, List[str]] = {}

    def load_from_json(self, data: Dict[str, Any]) -> None:
        """Load the spatial hierarchy from a world-definition dict (e.g. demo.json).

        ``data["world"]`` is the root node. Descendants are nested under
        ``"regions"`` / ``"locations"`` / ``"pois"`` / ``"items"`` depending on
        depth (see ``_get_children_key``). Nodes are added to any nodes already
        loaded; existing state is not cleared.

        Raises:
            KeyError: if ``data`` has no ``"world"`` key.
            ValueError: if any node in the tree lacks an ``"id"``.
        """
        logger.info("Loading spatial graph from JSON")
        self._build_hierarchy(data["world"], None)
        logger.info(f"Spatial graph loaded: {len(self.nodes)} nodes")

    def _build_hierarchy(
        self, node_data: Dict[str, Any], parent_id: Optional[str]
    ) -> None:
        """Recursively create SpatialNodes for ``node_data`` and its subtree.

        Raises:
            ValueError: if ``node_data`` has no ``"id"``. Without an id the node
                would be stored under ``None``, and its children, whose
                parent_id would then be ``None``, would be typed as "world".
        """
        node_id = node_data.get("id")
        if node_id is None:
            raise ValueError(
                f"Spatial node under parent {parent_id!r} is missing an 'id'"
            )
        node_type = self._determine_node_type(parent_id)

        connections = [
            PortalConnection(
                target=conn_data["target"],
                portal=conn_data.get("portal"),
                portal_state_descriptor=conn_data.get("portal_state_descriptor"),
                vision_prop=conn_data.get("vision_prop", 5),
                sound_prop=conn_data.get("sound_prop", 5),
                bidirectional=conn_data.get("bidirectional", True),
            )
            for conn_data in node_data.get("connections", [])
        ]

        node = SpatialNode(
            id=node_id,
            name=node_data.get("name", node_id),
            description=node_data.get("description", ""),
            node_type=node_type,
            parent_id=parent_id,
            connections=connections,
        )
        # Register before recursing: children derive their type from this node.
        self.nodes[node_id] = node

        children_key = self._get_children_key(node_type)
        for child_data in node_data.get(children_key, []):
            node.children.append(child_data.get("id"))
            self._build_hierarchy(child_data, node_id)

    def _determine_node_type(self, parent_id: Optional[str]) -> str:
        """Return the node type implied by the parent's type.

        world → region → location → poi; anything below a location is a "poi".
        """
        if parent_id is None:
            return "world"

        parent = self.nodes.get(parent_id)
        if parent.node_type == "world":
            return "region"
        elif parent.node_type == "region":
            return "location"
        else:
            return "poi"

    def _get_children_key(self, node_type: str) -> str:
        """Return the JSON key holding a node's children, given its type."""
        mapping = {
            "world": "regions",
            "region": "locations",
            "location": "pois",
            "poi": "items",
        }
        return mapping.get(node_type, "children")

    def set_entity_position(
        self, entity_id: str, location_id: str, spatial_descriptor: str = ""
    ) -> None:
        """Place an entity at a location or POI node.

        Unknown ids and non-leaf nodes (world/region) are rejected with a
        warning, and the entity's position is left unchanged.
        """
        if location_id not in self.nodes:
            logger.warning(f"Unknown location_id: {location_id}")
            return

        node = self.nodes[location_id]
        if node.node_type not in ("location", "poi"):
            logger.warning(
                f"Entity position must be at leaf level (location or poi), got {node.node_type}"
            )
            return

        self.entity_positions[entity_id] = EntityPosition(
            entity_id=entity_id,
            location_id=location_id,
            spatial_descriptor=spatial_descriptor,
        )
        logger.info(
            f"Entity {entity_id} positioned at {location_id}: {spatial_descriptor}"
        )

    def get_entity_position(self, entity_id: str) -> Optional[EntityPosition]:
        """Return the entity's current position, or None if it was never placed."""
        return self.entity_positions.get(entity_id)

    def get_entities_in_location(self, location_id: str) -> List[str]:
        """Return the ids of all entities positioned exactly at ``location_id``."""
        return [
            entity_id
            for entity_id, pos in self.entity_positions.items()
            if pos.location_id == location_id
        ]

    def get_path_to_root(self, node_id: str) -> List[str]:
        """Return ``[node_id, parent, grandparent, ..., root]``.

        An unknown id appears once and ends the walk. A parent cycle, which
        is possible with hand-edited saved files, is cut at the first repeated
        node instead of looping forever.
        """
        path: List[str] = []
        seen = set()
        current = node_id
        while current:
            if current in seen:
                logger.warning(f"Cycle detected in spatial hierarchy at {current}")
                break
            seen.add(current)
            path.append(current)
            node = self.nodes.get(current)
            if not node:
                break
            current = node.parent_id
        return path

    def get_immediate_children(self, node_id: str) -> List[str]:
        """Return the node's child ids (the live list), or [] for an unknown id."""
        node = self.nodes.get(node_id)
        return node.children if node else []

    def get_all_descendants(self, node_id: str) -> List[str]:
        """Return all descendant ids in depth-first pre-order."""
        descendants: List[str] = []
        node = self.nodes.get(node_id)
        if not node:
            return descendants

        for child_id in node.children:
            descendants.append(child_id)
            descendants.extend(self.get_all_descendants(child_id))
        return descendants

    def get_sibling_locations(self, node_id: str) -> List[str]:
        """Return the other children of this node's parent."""
        node = self.nodes.get(node_id)
        if not node or not node.parent_id:
            return []

        parent = self.nodes.get(node.parent_id)
        if not parent:
            return []

        return [child for child in parent.children if child != node_id]

    def get_connected_locations(
        self, node_id: str
    ) -> List[Tuple[str, PortalConnection]]:
        """
        Get all locations connected via portals.

        This includes the node's own connections to known nodes, plus any
        ``bidirectional`` connection that another node declares towards this
        one. Reverse connections are returned as a mirrored PortalConnection
        whose ``target`` is the other node. When both sides declare a
        connection, this node's own declaration wins.

        Returns: List of (location_id, PortalConnection)
        """
        node = self.nodes.get(node_id)
        if not node:
            return []

        connections: List[Tuple[str, PortalConnection]] = [
            (portal_conn.target, portal_conn)
            for portal_conn in node.connections
            if portal_conn.target in self.nodes
        ]
        covered = {target for target, _ in connections}

        for other_id, other in self.nodes.items():
            if other_id == node_id or other_id in covered:
                continue
            for portal_conn in other.connections:
                if portal_conn.target == node_id and portal_conn.bidirectional:
                    mirrored = PortalConnection(
                        target=other_id,
                        portal=portal_conn.portal,
                        portal_state_descriptor=portal_conn.portal_state_descriptor,
                        vision_prop=portal_conn.vision_prop,
                        sound_prop=portal_conn.sound_prop,
                        bidirectional=portal_conn.bidirectional,
                    )
                    connections.append((other_id, mirrored))
                    covered.add(other_id)
                    break

        return connections

    def broadcast_to_immediate_location(self, location_id: str) -> List[str]:
        """
        Step 1: Immediate Leaf Check
        Get all entities in the same leaf location (no loss).
        """
        return self.get_entities_in_location(location_id)

    def broadcast_to_siblings_with_portals(
        self, location_id: str, action: str, llm_filter=None
    ) -> Dict[str, PerceptionInfo]:
        """
        Step 2: Sibling Check (Horizontal Propagation)
        Propagate perception to portal-connected locations.

        llm_filter: Optional callable invoked as
            ``llm_filter(action=..., source_id=..., target_id=..., portal_conn=...)``.
            It returns the transformed action text, or None when the portal
            blocks the action entirely. If it raises, the untransformed action
            is used. It is not called for targets that contain no entities.
        Returns: Dict mapping entity_id → PerceptionInfo
        """
        perceptions: Dict[str, PerceptionInfo] = {}

        for target_id, portal_conn in self.get_connected_locations(location_id):
            target_entities = self.get_entities_in_location(target_id)
            if not target_entities:
                # Nobody there to perceive it; skip the potentially costly filter.
                continue

            transformed_action: Optional[str] = action
            if llm_filter:
                try:
                    transformed_action = llm_filter(
                        action=action,
                        source_id=location_id,
                        target_id=target_id,
                        portal_conn=portal_conn,
                    )
                except Exception as e:
                    logger.error(f"LLM filter failed: {e}")
                    transformed_action = action

            for entity_id in target_entities:
                perceptions[entity_id] = PerceptionInfo(
                    recipient_id=entity_id,
                    original_action=action,
                    transformed_action=transformed_action,
                    propagation_path="portal",
                    vision_clarity=portal_conn.vision_prop,
                    sound_clarity=portal_conn.sound_prop,
                    perceivable=transformed_action is not None,
                )

        logger.info(f"Sibling broadcast: {len(perceptions)} entities reached")
        return perceptions

    def broadcast_to_parent_children(
        self, location_id: str, action: str, escalation_check=None
    ) -> Dict[str, PerceptionInfo]:
        """
        Step 3: Parent Check (Vertical Propagation)
        Check if action warrants escalation to parent's other children.

        escalation_check: Optional callable ``(action, location_id, parent_id)``
            returning truthy to escalate. If it is omitted or raises, the action
            is escalated. It is not called when no other child of the parent
            contains any entity.
        Returns: Dict mapping entity_id → PerceptionInfo
        """
        perceptions: Dict[str, PerceptionInfo] = {}

        node = self.nodes.get(location_id)
        if not node or not node.parent_id:
            return perceptions

        parent = self.nodes.get(node.parent_id)
        if not parent:
            return perceptions

        # Gather potential recipients first so the (possibly LLM-backed)
        # escalation check is skipped when nobody could receive the action.
        recipients_by_sibling = {
            sibling_id: self.get_entities_in_location(sibling_id)
            for sibling_id in parent.children
            if sibling_id != location_id
        }
        if not any(recipients_by_sibling.values()):
            return perceptions

        should_escalate = True
        if escalation_check:
            try:
                should_escalate = escalation_check(action, location_id, parent.id)
            except Exception as e:
                logger.error(f"Escalation check failed: {e}")

        if not should_escalate:
            logger.info("Action did not warrant escalation to parent")
            return perceptions

        for sibling_entities in recipients_by_sibling.values():
            for entity_id in sibling_entities:
                perceptions[entity_id] = PerceptionInfo(
                    recipient_id=entity_id,
                    original_action=action,
                    transformed_action=f"[From {location_id}] {action}",
                    propagation_path="escalated_from_sibling",
                    vision_clarity=3,
                    sound_clarity=4,
                    perceivable=True,
                )

        logger.info(f"Parent escalation: {len(perceptions)} entities reached")
        return perceptions

    def bubble_up_broadcast(
        self,
        location_id: str,
        action: str,
        actor_id: str,
        llm_filter=None,
        escalation_check=None,
    ) -> Dict[str, PerceptionInfo]:
        """
        Full bubble-up algorithm:
        1. Immediate leaf check (all entities in same location)
        2. Sibling check (portal propagation)
        3. Parent check (escalation if warranted)

        An entity keeps the perception from the earliest (most direct) stage
        that reaches it. Later stages never overwrite it, so a generic
        escalated message cannot replace a portal perception, including one
        the portal filter blocked. The actor never perceives its own action.

        Returns: Dict mapping all perceiving entity_id → PerceptionInfo
        """
        logger.info(f"Bubble-up broadcast from {location_id}: '{action}'")
        all_perceptions: Dict[str, PerceptionInfo] = {}

        # Step 1: Immediate location (full perception)
        for entity_id in self.broadcast_to_immediate_location(location_id):
            if entity_id != actor_id:  # Don't perceive own action
                all_perceptions[entity_id] = PerceptionInfo(
                    recipient_id=entity_id,
                    original_action=action,
                    transformed_action=action,
                    propagation_path="immediate",
                    vision_clarity=10,
                    sound_clarity=10,
                    perceivable=True,
                )

        # Step 2: Sibling propagation through portals
        sibling_perceptions = self.broadcast_to_siblings_with_portals(
            location_id, action, llm_filter
        )
        self._merge_without_override(all_perceptions, sibling_perceptions, actor_id)

        # Step 3: Parent escalation
        parent_perceptions = self.broadcast_to_parent_children(
            location_id, action, escalation_check
        )
        self._merge_without_override(all_perceptions, parent_perceptions, actor_id)

        logger.info(f"Total entities perceiving: {len(all_perceptions)}")
        return all_perceptions

    @staticmethod
    def _merge_without_override(
        target: Dict[str, PerceptionInfo],
        incoming: Dict[str, PerceptionInfo],
        actor_id: str,
    ) -> None:
        """Add ``incoming`` perceptions to ``target``, skipping the actor and entities already present."""
        for entity_id, perception in incoming.items():
            if entity_id != actor_id and entity_id not in target:
                target[entity_id] = perception

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the spatial graph (nodes and entity positions) to a dict."""
        return {
            "nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
            "entity_positions": {
                entity_id: asdict(pos)
                for entity_id, pos in self.entity_positions.items()
            },
        }

    def save_to_json(self, filepath: str) -> None:
        """Save the spatial graph to a JSON file. Failures are logged, not raised."""
        try:
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(self.to_dict(), f, indent=2)
            logger.info(f"Spatial graph saved to {filepath}")
        except Exception as e:
            logger.error(f"Failed to save spatial graph: {e}")

    @classmethod
    def load_from_json_file(cls, filepath: str) -> "SpatialGraph":
        """Load a graph previously written by ``save_to_json``.

        On any failure (missing file, malformed JSON, bad schema) the error is
        logged and an empty graph is returned.
        """
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)

            graph = cls()

            # Rebuild nodes
            for node_id, node_data in data.get("nodes", {}).items():
                connections = [
                    PortalConnection(**conn)
                    for conn in node_data.get("connections", [])
                ]
                graph.nodes[node_id] = SpatialNode(
                    id=node_data["id"],
                    name=node_data["name"],
                    description=node_data["description"],
                    node_type=node_data["node_type"],
                    parent_id=node_data.get("parent_id"),
                    children=node_data.get("children", []),
                    connections=connections,
                )

            # Rebuild entity positions
            for entity_id, pos_data in data.get("entity_positions", {}).items():
                graph.entity_positions[entity_id] = EntityPosition(**pos_data)

            logger.info(f"Spatial graph loaded from {filepath}")
            return graph
        except Exception as e:
            logger.error(f"Failed to load spatial graph: {e}")
            return cls()
|