refactor: Refactor and split scenario_builder_tui

This commit is contained in:
2026-04-18 18:19:25 +05:30
parent a38d078d51
commit 352d83ab75
5 changed files with 1143 additions and 1085 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,5 @@
from .entities import EntitiesScreen
from .prompts import NodeTypePrompt, PathPrompt
from .spatial_graph import SpatialGraphScreen
__all__ = ["EntitiesScreen", "NodeTypePrompt", "PathPrompt", "SpatialGraphScreen"]

View File

@@ -0,0 +1,321 @@
import json
from typing import Any, Optional
from textual.app import ComposeResult
from textual.containers import Horizontal, Vertical, VerticalScroll
from textual.screen import Screen
from textual.widgets import (
Button,
Footer,
Header,
Input,
Label,
ListItem,
ListView,
TextArea,
)
from omnia.tools.scenario_builder import DEFAULT_ENTITY
class EntityItem(ListItem):
def __init__(self, index: int, entity: dict[str, Any]) -> None:
super().__init__(Label(self._label(entity)))
self.index = index
@staticmethod
def _label(entity: dict[str, Any]) -> str:
name = str(entity.get("name") or "Unnamed").strip()
entity_id = str(entity.get("id") or "no-id").strip()
return f"{name} ({entity_id})"
class EntitiesScreen(Screen[Optional[dict[str, Any]]]):
CSS = """
Screen {
layout: vertical;
}
#entities-main {
height: 1fr;
}
#entity-left {
width: 35%;
min-width: 24;
border: round $primary;
padding: 1;
}
#entity-right {
width: 1fr;
border: round $primary;
padding: 1;
}
#entity_list {
height: 1fr;
}
#entity-form {
height: 1fr;
}
#entity_stats,
#entity_voice,
#entity_memories {
height: 4;
}
#entity-traits-label,
#entity-stats-label,
#entity-voice-label,
#entity-memories-label {
margin-top: 1;
}
#entity_status {
height: 1;
padding: 0 1;
}
"""
def __init__(self, scenario: dict[str, Any]) -> None:
super().__init__()
self.scenario = scenario
self.entities = self.scenario.setdefault("entities", [])
if not isinstance(self.entities, list):
self.entities = []
self.scenario["entities"] = self.entities
self.selected_index: Optional[int] = None
def compose(self) -> ComposeResult:
yield Header(show_clock=True)
with Horizontal(id="entities-main"):
with Vertical(id="entity-left"):
yield Label("Entities", id="entity-title")
yield ListView(id="entity_list")
with Horizontal(id="entity-actions"):
yield Button("New", id="entity_new")
yield Button("Delete", id="entity_delete")
with Vertical(id="entity-right"):
yield Label("Entity Details", id="entity-current")
with VerticalScroll(id="entity-form"):
yield Label("Entity ID", id="entity-id-label")
yield Input(id="entity_id")
yield Label("Name", id="entity-name-label")
yield Input(id="entity_name")
yield Label("Traits (comma-separated)", id="entity-traits-label")
yield Input(id="entity_traits")
yield Label("Stats (key=value per line)", id="entity-stats-label")
yield TextArea(id="entity_stats")
yield Label("Voice sample", id="entity-voice-label")
yield TextArea(id="entity_voice")
yield Label("Current mood", id="entity-mood-label")
yield Input(id="entity_mood")
yield Label("Location ID", id="entity-location-label")
yield Input(id="entity_location")
yield Label("Spatial descriptor", id="entity-spatial-label")
yield Input(id="entity_spatial")
yield Label(
"Memories (one per line; JSON allowed)",
id="entity-memories-label",
)
yield TextArea(id="entity_memories")
with Horizontal(id="entity-actions-right"):
yield Button("Save", id="entity_save", variant="primary")
yield Button("Apply", id="entity_apply")
yield Button("Cancel", id="entity_cancel")
yield Label("", id="entity_status")
yield Footer()
def on_mount(self) -> None:
self._reload_entity_list(select_index=0 if self.entities else None)
if self.entities:
self._load_entity_into_form(self.entities[0])
self.selected_index = 0
else:
self._load_entity_into_form(self._default_entity())
self.selected_index = None
self._set_status("Editing entities")
def _set_status(self, message: str) -> None:
self.query_one("#entity_status", Label).update(message)
def _default_entity(self) -> dict[str, Any]:
return json.loads(json.dumps(DEFAULT_ENTITY))
def _load_entity_into_form(self, entity: dict[str, Any]) -> None:
self.query_one("#entity_id", Input).value = str(entity.get("id", ""))
self.query_one("#entity_name", Input).value = str(entity.get("name", ""))
traits = entity.get("traits", [])
if isinstance(traits, list):
traits_text = ", ".join(str(t).strip() for t in traits if str(t).strip())
else:
traits_text = str(traits)
self.query_one("#entity_traits", Input).value = traits_text
stats = entity.get("stats", {})
stats_lines = []
if isinstance(stats, dict):
for key, value in stats.items():
stats_lines.append(f"{key}={value}")
self.query_one("#entity_stats", TextArea).text = "\n".join(stats_lines)
self.query_one("#entity_voice", TextArea).text = str(
entity.get("voice_sample", "")
)
self.query_one("#entity_mood", Input).value = str(
entity.get("current_mood", "")
)
metadata = (
entity.get("metadata", {})
if isinstance(entity.get("metadata"), dict)
else {}
)
self.query_one("#entity_location", Input).value = str(
metadata.get("location", "")
)
self.query_one("#entity_spatial", Input).value = str(
metadata.get("spatial_descriptor", "")
)
memories = entity.get("memories", [])
memory_lines = []
if isinstance(memories, list):
for memory in memories:
if isinstance(memory, (dict, list)):
memory_lines.append(json.dumps(memory))
else:
memory_lines.append(str(memory))
self.query_one("#entity_memories", TextArea).text = "\n".join(memory_lines)
def _build_entity_from_form(self) -> Optional[dict[str, Any]]:
entity_id = self.query_one("#entity_id", Input).value.strip()
name = self.query_one("#entity_name", Input).value.strip()
if not entity_id or not name:
self._set_status("Entity requires non-empty ID and name")
return None
entity = self._default_entity()
entity["id"] = entity_id
entity["name"] = name
traits_value = self.query_one("#entity_traits", Input).value
traits = [t.strip() for t in traits_value.split(",") if t.strip()]
entity["traits"] = traits
stats_text = self.query_one("#entity_stats", TextArea).text
stats: dict[str, int] = {}
for line in stats_text.splitlines():
raw = line.strip()
if not raw:
continue
if "=" not in raw:
self._set_status(f"Invalid stats line (missing '='): {raw}")
return None
key, value = raw.split("=", 1)
key = key.strip()
value = value.strip()
if not key or not value:
self._set_status(f"Invalid stats line: {raw}")
return None
try:
stats[key] = int(value)
except ValueError:
self._set_status(f"Stat value must be an integer: {raw}")
return None
entity["stats"] = stats
entity["voice_sample"] = self.query_one("#entity_voice", TextArea).text.strip()
mood = self.query_one("#entity_mood", Input).value.strip()
entity["current_mood"] = mood if mood else "Neutral"
entity["metadata"]["location"] = self.query_one(
"#entity_location", Input
).value.strip()
entity["metadata"]["spatial_descriptor"] = self.query_one(
"#entity_spatial", Input
).value.strip()
memories_text = self.query_one("#entity_memories", TextArea).text
memories: list[Any] = []
for line in memories_text.splitlines():
raw = line.strip()
if not raw:
continue
if raw.startswith(("{", "[")):
try:
memories.append(json.loads(raw))
except json.JSONDecodeError:
self._set_status(f"Invalid memory JSON: {raw}")
return None
else:
memories.append(raw)
entity["memories"] = memories
return entity
def _reload_entity_list(self, select_index: Optional[int] = None) -> None:
list_view = self.query_one("#entity_list", ListView)
list_view.clear()
for index, entity in enumerate(self.entities):
list_view.append(EntityItem(index, entity))
if select_index is not None and self.entities:
clamped = max(0, min(select_index, len(self.entities) - 1))
list_view.index = clamped
self.selected_index = clamped
def on_list_view_selected(self, event: ListView.Selected) -> None:
if isinstance(event.item, EntityItem):
self.selected_index = event.item.index
self._load_entity_into_form(self.entities[event.item.index])
async def on_button_pressed(self, event: Button.Pressed) -> None:
button_id = event.button.id
if button_id == "entity_new":
self.selected_index = None
self._load_entity_into_form(self._default_entity())
self._set_status("Creating new entity")
return
if button_id == "entity_save":
entity = self._build_entity_from_form()
if entity is None:
return
existing_ids = {
e.get("id")
for i, e in enumerate(self.entities)
if i != self.selected_index
}
if entity["id"] in existing_ids:
self._set_status(f"Entity ID '{entity['id']}' already exists")
return
if self.selected_index is None:
self.entities.append(entity)
self.selected_index = len(self.entities) - 1
else:
self.entities[self.selected_index] = entity
self._reload_entity_list(select_index=self.selected_index)
self._set_status("Entity saved")
return
if button_id == "entity_delete":
if self.selected_index is None:
self._set_status("Select an entity to delete")
return
removed = self.entities.pop(self.selected_index)
self._reload_entity_list(select_index=self.selected_index)
if self.entities:
self._load_entity_into_form(
self.entities[min(self.selected_index, len(self.entities) - 1)]
)
else:
self._load_entity_into_form(self._default_entity())
self.selected_index = None
self._set_status(f"Deleted {removed.get('name')}")
return
if button_id == "entity_apply":
self.scenario["entities"] = self.entities
self.dismiss(self.scenario)
return
if button_id == "entity_cancel":
self.dismiss(None)

View File

@@ -0,0 +1,71 @@
from typing import Optional
from textual.app import ComposeResult
from textual.containers import Grid, Horizontal
from textual.screen import ModalScreen
from textual.widgets import Button, Input, Label
class PathPrompt(ModalScreen[Optional[str]]):
def __init__(self, title: str, default: str = "") -> None:
super().__init__()
self.title = title
self.default = default
def compose(self) -> ComposeResult:
yield Label(self.title, id="prompt-title")
yield Input(value=self.default, id="prompt-input")
with Horizontal(id="prompt-buttons"):
yield Button("OK", id="ok", variant="primary")
yield Button("Cancel", id="cancel")
def on_mount(self) -> None:
self.query_one(Input).focus()
def on_button_pressed(self, event: Button.Pressed) -> None:
if event.button.id == "ok":
value = self.query_one(Input).value.strip()
self.dismiss(value or None)
else:
self.dismiss(None)
class NodeTypePrompt(ModalScreen[Optional[str]]):
CSS = """
NodeTypePrompt {
align: center middle;
}
#dialog {
layout: grid;
grid-size: 2;
grid-gutter: 1 2;
grid-rows: 1fr 3;
padding: 0 1;
width: 60;
height: 11;
border: thick $background 80%;
background: $surface;
}
#node-type-title {
column-span: 2;
content-align: center middle;
}
"""
def compose(self) -> ComposeResult:
with Grid(id="dialog"):
yield Label("Select node type", id="node-type-title")
yield Button("🗺️ Region", id="node_region", variant="primary")
yield Button("📍 Location", id="node_location")
yield Button("⭐ POI", id="node_poi")
yield Button("Cancel", id="node_cancel")
def on_button_pressed(self, event: Button.Pressed) -> None:
if event.button.id == "node_region":
self.dismiss("region")
elif event.button.id == "node_location":
self.dismiss("location")
elif event.button.id == "node_poi":
self.dismiss("poi")
else:
self.dismiss(None)

View File

@@ -0,0 +1,738 @@
import json
from dataclasses import dataclass
from typing import Any, Optional
from langchain_core.messages import HumanMessage, SystemMessage
from textual.app import ComposeResult
from textual.containers import Horizontal, Vertical
from textual.screen import Screen
from textual.widgets import (
Button,
Footer,
Header,
Input,
Label,
Select,
TextArea,
Tree,
)
from omnia.tools.scenario_builder import DEFAULT_SCENARIO
from omnia.tools.scenario_tui.prompts import NodeTypePrompt
@dataclass
class SpatialNodeRef:
node_type: str
obj: dict[str, Any]
class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]):
CSS = """
Screen {
layout: vertical;
}
#spatial-main {
height: 1fr;
}
#spatial-left {
width: 35%;
min-width: 24;
border: round $primary;
padding: 1;
}
#spatial-right {
width: 1fr;
border: round $primary;
padding: 1;
}
.hidden {
display: none;
}
#spatial-tree {
height: 1fr;
}
#node-description {
height: 5;
}
#spatial-status {
height: 1;
padding: 0 1;
}
#spatial-actions {
height: auto;
width: auto;
align-horizontal: left;
}
"""
def __init__(self, scenario: dict[str, Any]) -> None:
super().__init__()
self.scenario = scenario
self.selected_node = None
self.node_map: dict[int, Any] = {}
self.parent_option_map: dict[str, tuple[str, dict[str, Any]]] = {}
def compose(self) -> ComposeResult:
yield Header(show_clock=True)
with Horizontal(id="spatial-main"):
with Vertical(id="spatial-left"):
yield Label("Spatial Graph", id="spatial-title")
yield Tree("🌍 World", id="spatial-tree")
with Horizontal(id="spatial-tree-actions"):
yield Button("Add", id="spatial_add", variant="primary")
yield Button("Delete", id="spatial_delete")
with Horizontal(id="spatial-nesting-actions"):
yield Button("Promote", id="spatial_promote")
yield Button("Demote", id="spatial_demote")
with Vertical(id="spatial-right"):
yield Label("Node Details", id="node-title")
with Vertical(id="node-details"):
yield Label("Type", id="node-type-label")
yield Label("", id="node-type-value")
yield Label("ID", id="node-id-label")
yield Input(id="node_id")
yield Label("Name", id="node-name-label")
yield Input(id="node_name")
yield Label("Description", id="node-desc-label")
yield TextArea(id="node-description")
yield Label("Parent", id="node-parent-label")
yield Select([], prompt="Select parent", id="node_parent")
with Horizontal(id="node-actions"):
yield Button("Generate", id="node_generate")
yield Button("Update", id="node_update", variant="primary")
yield Button("Reset", id="node_reset")
with Horizontal(id="spatial-actions"):
yield Button("Save", id="spatial_save", variant="primary")
yield Button("Cancel", id="spatial_cancel")
yield Label("", id="spatial-status")
yield Footer()
def on_mount(self) -> None:
self._ensure_spatial_graph()
self._build_tree()
self._clear_form()
self._set_status("Select a node to edit")
def _set_status(self, message: str) -> None:
self.query_one("#spatial-status", Label).update(message)
def _toast(self, message: str, severity: str = "warning") -> None:
try:
self.app.notify(message, severity=severity)
except Exception:
self._set_status(message)
def _ensure_spatial_graph(self) -> None:
spatial_graph = self.scenario.get("spatial_graph")
if spatial_graph is None:
spatial_graph = json.loads(json.dumps(DEFAULT_SCENARIO["spatial_graph"]))
self.scenario["spatial_graph"] = spatial_graph
world = spatial_graph.setdefault("world", {})
world.setdefault("regions", [])
world.setdefault("name", "World")
world.setdefault("id", "world")
def _world(self) -> dict[str, Any]:
return self.scenario["spatial_graph"]["world"]
def _build_tree(
self,
select_obj: Optional[dict[str, Any]] = None,
expanded: Optional[set[int]] = None,
) -> None:
tree = self.query_one("#spatial-tree", Tree)
world = self._world()
tree.reset(f"🌍 {world.get('name', 'World')}")
tree.root.data = SpatialNodeRef("world", world)
tree.root.expand()
self.selected_node = None
self.node_map = {}
expanded = expanded or set()
def add_region(region: dict[str, Any]) -> None:
node = tree.root.add(
self._label_for("region", region), data=SpatialNodeRef("region", region)
)
self.node_map[id(region)] = node
if id(region) in expanded:
node.expand()
for location in region.get("locations", []):
add_location(node, location)
def add_location(parent_node: Any, location: dict[str, Any]) -> None:
loc_node = parent_node.add(
self._label_for("location", location),
data=SpatialNodeRef("location", location),
)
self.node_map[id(location)] = loc_node
if id(location) in expanded:
loc_node.expand()
for nested in location.get("locations", []):
add_location(loc_node, nested)
for poi in location.get("pois", []):
poi_node = loc_node.add(
self._label_for("poi", poi), data=SpatialNodeRef("poi", poi)
)
self.node_map[id(poi)] = poi_node
if id(poi) in expanded:
poi_node.expand()
for region in world.get("regions", []):
add_region(region)
if select_obj and id(select_obj) in self.node_map:
self.selected_node = self.node_map[id(select_obj)]
self._expand_ancestors(self.selected_node)
self._load_form_from_node(self.selected_node)
else:
self._clear_form()
def _label_for(self, node_type: str, obj: dict[str, Any]) -> str:
name = str(obj.get("name") or "Unnamed").strip()
node_id = str(obj.get("id") or "no-id").strip()
icon = {"region": "🗺️", "location": "📍", "poi": ""}.get(node_type, "")
return f"{icon} {name} ({node_id})"
def _capture_expanded_nodes(self) -> set[int]:
tree = self.query_one("#spatial-tree", Tree)
expanded: set[int] = set()
def walk(node: Any) -> None:
if node.is_expanded and isinstance(node.data, SpatialNodeRef):
expanded.add(id(node.data.obj))
for child in node.children:
walk(child)
walk(tree.root)
return expanded
def _expand_ancestors(self, node: Any) -> None:
current = node
while current is not None:
current.expand()
current = current.parent
def _clear_form(self) -> None:
self.query_one("#node-type-value", Label).update("")
self.query_one("#node_id", Input).value = ""
self.query_one("#node_name", Input).value = ""
self.query_one("#node-description", TextArea).text = ""
self.query_one("#node_parent", Select).set_options([])
self.query_one("#node_parent", Select).value = Select.NULL
self.query_one("#node-details", Vertical).add_class("hidden")
self.query_one("#node_generate", Button).disabled = True
self.query_one("#node_update", Button).disabled = True
self.query_one("#node_reset", Button).disabled = True
self.query_one("#spatial_delete", Button).disabled = True
self.query_one("#spatial_promote", Button).disabled = True
self.query_one("#spatial_demote", Button).disabled = True
def _load_form_from_node(self, node: Any) -> None:
ref = node.data
if not isinstance(ref, SpatialNodeRef) or ref.node_type == "world":
self._clear_form()
return
self.query_one("#node-details", Vertical).remove_class("hidden")
self.query_one("#node-type-value", Label).update(ref.node_type.title())
self.query_one("#node_id", Input).value = str(ref.obj.get("id", ""))
self.query_one("#node_name", Input).value = str(ref.obj.get("name", ""))
self.query_one("#node-description", TextArea).text = str(
ref.obj.get("description", "")
)
self.query_one("#node_generate", Button).disabled = False
self.query_one("#node_update", Button).disabled = False
self.query_one("#node_reset", Button).disabled = False
self.query_one("#spatial_delete", Button).disabled = False
is_location = ref.node_type == "location"
can_promote, can_demote = (
self._location_move_state(node) if is_location else (False, False)
)
self.query_one("#spatial_promote", Button).disabled = not can_promote
self.query_one("#spatial_demote", Button).disabled = not can_demote
parent_select = self.query_one("#node_parent", Select)
if is_location:
options, selected_value, option_map = self._build_parent_options(ref.obj)
self.parent_option_map = option_map
parent_select.set_options(options)
parent_select.value = (
selected_value if selected_value is not None else Select.NULL
)
parent_select.disabled = False
else:
self.parent_option_map = {}
parent_select.set_options([])
parent_select.value = Select.NULL
parent_select.disabled = True
def on_tree_node_selected(self, event: Tree.NodeSelected) -> None:
self.selected_node = event.node
self._load_form_from_node(event.node)
def _open_node_type_prompt(self) -> None:
self.app.push_screen(NodeTypePrompt(), callback=self._handle_node_type)
def _handle_node_type(self, result: Optional[str]) -> None:
if result:
self._add_node(result)
def _selected_ref(self) -> Optional[SpatialNodeRef]:
if self.selected_node is None:
return None
ref = self.selected_node.data
if isinstance(ref, SpatialNodeRef):
return ref
return None
def _selected_parent_node(self) -> Optional[Any]:
if self.selected_node is None:
return None
return self.selected_node.parent
def _parent_locations_list(
self, parent_node: Any
) -> Optional[list[dict[str, Any]]]:
if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef):
return None
parent_ref = parent_node.data
if parent_ref.node_type not in {"region", "location"}:
return None
return parent_ref.obj.setdefault("locations", [])
def _location_move_state(self, node: Any) -> tuple[bool, bool]:
if node is None or not isinstance(node.data, SpatialNodeRef):
return (False, False)
ref = node.data
if ref.node_type != "location":
return (False, False)
parent_node = node.parent
parent_locations = self._parent_locations_list(parent_node)
if parent_locations is None:
return (False, False)
can_promote = False
if parent_node is not None and isinstance(parent_node.data, SpatialNodeRef):
parent_ref = parent_node.data
can_promote = parent_ref.node_type == "location"
can_demote = False
if ref.obj in parent_locations:
index = parent_locations.index(ref.obj)
can_demote = index < len(parent_locations) - 1
return (can_promote, can_demote)
def _parent_pois_list(self, parent_node: Any) -> Optional[list[dict[str, Any]]]:
if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef):
return None
parent_ref = parent_node.data
if parent_ref.node_type != "location":
return None
return parent_ref.obj.setdefault("pois", [])
def _collect_locations(self) -> list[dict[str, Any]]:
locations: list[dict[str, Any]] = []
def walk(location_list: list[dict[str, Any]]) -> None:
for location in location_list:
locations.append(location)
walk(location.get("locations", []))
for region in self._world().get("regions", []):
walk(region.get("locations", []))
return locations
def _location_descendants(self, location: dict[str, Any]) -> set[int]:
descendants: set[int] = set()
def walk(children: list[dict[str, Any]]) -> None:
for child in children:
descendants.add(id(child))
walk(child.get("locations", []))
walk(location.get("locations", []))
return descendants
def _find_location_parent(
self, target: dict[str, Any]
) -> Optional[tuple[str, dict[str, Any]]]:
world = self._world()
for region in world.get("regions", []):
if target in region.get("locations", []):
return ("region", region)
for location in region.get("locations", []):
found = self._find_location_parent_in_location(location, target)
if found:
return found
return None
def _find_location_parent_in_location(
self, location: dict[str, Any], target: dict[str, Any]
) -> Optional[tuple[str, dict[str, Any]]]:
if target in location.get("locations", []):
return ("location", location)
for nested in location.get("locations", []):
found = self._find_location_parent_in_location(nested, target)
if found:
return found
return None
def _build_parent_options(
self, current: dict[str, Any]
) -> tuple[list[tuple[str, str]], Optional[str], dict[str, tuple[str, dict[str, Any]]]]:
options: list[tuple[str, str]] = []
option_map: dict[str, tuple[str, dict[str, Any]]] = {}
forbidden = {id(current)} | self._location_descendants(current)
selected_value: Optional[str] = None
for region in self._world().get("regions", []):
region_id = str(region.get("id", "")).strip()
region_name = str(region.get("name", "")).strip()
value = f"region@{id(region)}"
label = (
f"🗺️ {region_name} ({region_id})"
if region_name
else f"🗺️ {region_id or 'region'}"
)
options.append((label, value))
option_map[value] = ("region", region)
for location in self._collect_locations():
if id(location) in forbidden:
continue
location_id = str(location.get("id", "")).strip()
location_name = str(location.get("name", "")).strip()
value = f"location@{id(location)}"
label = (
f"📍 {location_name} ({location_id})"
if location_name
else f"📍 {location_id or 'location'}"
)
options.append((label, value))
option_map[value] = ("location", location)
parent_info = self._find_location_parent(current)
if parent_info:
parent_type, parent_obj = parent_info
for value, (entry_type, entry_obj) in option_map.items():
if entry_type == parent_type and entry_obj is parent_obj:
selected_value = value
break
return options, selected_value, option_map
def _update_selected_node(self) -> None:
ref = self._selected_ref()
if ref is None or ref.node_type == "world":
self._set_status("Select a node to update")
return
node_id = self.query_one("#node_id", Input).value.strip()
name = self.query_one("#node_name", Input).value.strip()
description = self.query_one("#node-description", TextArea).text.strip()
if not node_id or not name:
self._set_status("ID and name are required")
return
ref.obj["id"] = node_id
ref.obj["name"] = name
ref.obj["description"] = description
if ref.node_type == "region":
ref.obj.setdefault("locations", [])
elif ref.node_type == "location":
ref.obj.setdefault("locations", [])
ref.obj.setdefault("pois", [])
elif ref.node_type == "poi":
ref.obj.setdefault("connections", [])
if ref.node_type == "location":
parent_select = self.query_one("#node_parent", Select)
if parent_select.value is not Select.NULL:
selected_value = str(parent_select.value)
parent_info = self.parent_option_map.get(selected_value)
if parent_info:
parent_type, parent_obj = parent_info
current_parent = self._find_location_parent(ref.obj)
if current_parent is None or current_parent[1] is not parent_obj:
if current_parent:
current_type, current_obj = current_parent
if current_type == "region":
current_obj.get("locations", []).remove(ref.obj)
elif current_type == "location":
current_obj.get("locations", []).remove(ref.obj)
parent_obj.setdefault("locations", []).append(ref.obj)
expanded = self._capture_expanded_nodes()
self._build_tree(select_obj=ref.obj, expanded=expanded)
self._set_status("Node updated")
return
if self.selected_node:
self.selected_node.label = self._label_for(ref.node_type, ref.obj)
self._set_status("Node updated")
def _node_path(self, node: Any) -> str:
parts = []
current = node
while current is not None and isinstance(current.data, SpatialNodeRef):
ref = current.data
if ref.node_type == "world":
break
label = str(ref.obj.get("name") or ref.obj.get("id") or ref.node_type)
parts.append(f"{ref.node_type}: {label}")
current = current.parent
return " > ".join(reversed(parts)) if parts else "World"
def _collect_ids(self, node_type: str) -> set[str]:
world = self._world()
ids: set[str] = set()
if node_type == "region":
for region in world.get("regions", []):
if region_id := str(region.get("id", "")).strip():
ids.add(region_id)
return ids
def walk_locations(locations: list[dict[str, Any]]) -> None:
for location in locations:
if node_type == "location":
if loc_id := str(location.get("id", "")).strip():
ids.add(loc_id)
if node_type == "poi":
for poi in location.get("pois", []):
if poi_id := str(poi.get("id", "")).strip():
ids.add(poi_id)
walk_locations(location.get("locations", []))
for region in world.get("regions", []):
walk_locations(region.get("locations", []))
return ids
def _generate_node_fields(self) -> None:
ref = self._selected_ref()
if ref is None or ref.node_type == "world":
self._toast("Select a node to generate")
return
scenario_meta = self.scenario.get("scenario", {})
node_type = ref.node_type
system_prompt = """
You are helping build a game scenario.
Return ONLY a single-line JSON object with keys:
id, name, description.
- id: short snake_case identifier
- description: 1-3 sentences
No commentary. ASCII only.
""".strip()
user_prompt = f"""
Scenario title: {scenario_meta.get("title", "")}
Scenario description: {scenario_meta.get("description", "")}
Node type: {node_type}
Node path: {self._node_path(self.selected_node)}
Current id: {ref.obj.get("id", "")}
Current name: {ref.obj.get("name", "")}
""".strip()
from omnia.llm_runtime import invoke_llm
response_text = invoke_llm(
[SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
)
try:
payload = json.loads(response_text)
except json.JSONDecodeError:
self._toast("LLM response was not valid JSON", severity="error")
return
if not isinstance(payload, dict):
self._toast("LLM response must be a JSON object", severity="error")
return
generated_id = str(payload.get("id", "")).strip()
generated_name = str(payload.get("name", "")).strip()
generated_description = str(payload.get("description", "")).strip()
if not generated_id or not generated_name:
self._toast("Generated data missing id or name", severity="error")
return
existing_ids = self._collect_ids(node_type)
if generated_id in existing_ids and generated_id != str(ref.obj.get("id", "")).strip():
self._toast(
f"Generated id '{generated_id}' already exists", severity="warning"
)
return
ref.obj["id"] = generated_id
ref.obj["name"] = generated_name
ref.obj["description"] = generated_description
if self.selected_node:
self.selected_node.label = self._label_for(node_type, ref.obj)
self._load_form_from_node(self.selected_node)
self._set_status("Generated fields applied")
def _reset_selected_node(self) -> None:
if self.selected_node is None:
return
self._load_form_from_node(self.selected_node)
self._set_status("Reverted changes")
def _add_node(self, node_type: str) -> None:
world = self._world()
expanded = self._capture_expanded_nodes()
if node_type == "region":
region = {"id": "", "name": "", "description": "", "locations": []}
world.setdefault("regions", []).append(region)
self._build_tree(select_obj=region, expanded=expanded)
self._set_status("Region added")
return
parent_node = self.selected_node
if node_type == "location":
if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef):
self._toast("Select a region or location for the new location")
return
parent_ref = parent_node.data
if parent_ref.node_type not in {"region", "location"}:
self._toast("Select a region or location for the new location")
return
location = {
"id": "",
"name": "",
"description": "",
"locations": [],
"pois": [],
}
parent_ref.obj.setdefault("locations", []).append(location)
self._build_tree(select_obj=location, expanded=expanded)
self._set_status("Location added")
return
if node_type == "poi":
if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef):
self._toast("Select a location for the new POI")
return
parent_ref = parent_node.data
if parent_ref.node_type == "poi":
parent_node = parent_node.parent
if parent_node is None or not isinstance(
parent_node.data, SpatialNodeRef
):
self._toast("Select a location for the new POI")
return
parent_ref = parent_node.data
if parent_ref.node_type != "location":
self._toast("Select a location for the new POI")
return
poi = {"id": "", "name": "", "description": "", "connections": []}
parent_ref.obj.setdefault("pois", []).append(poi)
self._build_tree(select_obj=poi, expanded=expanded)
self._set_status("POI added")
def _delete_node(self) -> None:
ref = self._selected_ref()
if ref is None or ref.node_type == "world":
self._set_status("Select a node to delete")
return
parent_node = self._selected_parent_node()
expanded = self._capture_expanded_nodes()
if ref.node_type == "region":
world = self._world()
world["regions"] = [r for r in world.get("regions", []) if r is not ref.obj]
self._build_tree(expanded=expanded)
self._set_status("Region deleted")
return
if ref.node_type == "location":
parent_locations = self._parent_locations_list(parent_node)
if parent_locations is None:
self._set_status("Unable to delete location")
return
parent_locations[:] = [
loc for loc in parent_locations if loc is not ref.obj
]
self._build_tree(expanded=expanded)
self._set_status("Location deleted")
return
if ref.node_type == "poi":
parent_pois = self._parent_pois_list(parent_node)
if parent_pois is None:
self._set_status("Unable to delete POI")
return
parent_pois[:] = [poi for poi in parent_pois if poi is not ref.obj]
self._build_tree(expanded=expanded)
self._set_status("POI deleted")
def _promote_location(self) -> None:
ref = self._selected_ref()
if ref is None or ref.node_type != "location":
self._set_status("Select a location to promote")
return
parent_node = self._selected_parent_node()
if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef):
self._set_status("Unable to promote location")
return
parent_ref = parent_node.data
if parent_ref.node_type != "location":
self._set_status("Location is already top-level")
return
grandparent_node = parent_node.parent
parent_locations = parent_ref.obj.setdefault("locations", [])
if ref.obj in parent_locations:
parent_locations.remove(ref.obj)
target_list = self._parent_locations_list(grandparent_node)
if target_list is None:
self._set_status("Unable to promote location")
return
target_list.append(ref.obj)
expanded = self._capture_expanded_nodes()
self._build_tree(select_obj=ref.obj, expanded=expanded)
self._set_status("Location promoted")
def _demote_location(self) -> None:
ref = self._selected_ref()
if ref is None or ref.node_type != "location":
self._set_status("Select a location to demote")
return
parent_node = self._selected_parent_node()
parent_locations = self._parent_locations_list(parent_node)
if parent_locations is None:
self._set_status("Unable to demote location")
return
try:
index = parent_locations.index(ref.obj)
except ValueError:
self._set_status("Unable to demote location")
return
if index >= len(parent_locations) - 1:
self._set_status("No next sibling to demote under")
return
new_parent = parent_locations[index + 1]
new_parent.setdefault("locations", []).append(ref.obj)
parent_locations.remove(ref.obj)
expanded = self._capture_expanded_nodes()
self._build_tree(select_obj=ref.obj, expanded=expanded)
self._set_status("Location demoted")
def on_button_pressed(self, event: Button.Pressed) -> None:
button_id = event.button.id
if button_id == "spatial_add":
self._open_node_type_prompt()
return
if button_id == "spatial_delete":
self._delete_node()
return
if button_id == "spatial_promote":
self._promote_location()
return
if button_id == "spatial_demote":
self._demote_location()
return
if button_id == "node_generate":
self._generate_node_fields()
return
if button_id == "node_update":
self._update_selected_node()
return
if button_id == "node_reset":
self._reset_selected_node()
return
if button_id == "spatial_save":
self.dismiss(self.scenario)
return
if button_id == "spatial_cancel":
self.dismiss(None)