diff --git a/src/omnia/tools/scenario_builder_tui.py b/src/omnia/tools/scenario_builder_tui.py index ae9ac5e..fafddac 100644 --- a/src/omnia/tools/scenario_builder_tui.py +++ b/src/omnia/tools/scenario_builder_tui.py @@ -1,13 +1,10 @@ import json -from dataclasses import dataclass from pathlib import Path from typing import Any, Optional -from langchain_core.messages import HumanMessage, SystemMessage from textual.app import App, ComposeResult -from textual.containers import Grid, Horizontal, Vertical, VerticalScroll +from textual.containers import Horizontal, Vertical, VerticalScroll from textual.reactive import reactive -from textual.screen import ModalScreen, Screen from textual.widgets import ( Button, Footer, @@ -16,80 +13,16 @@ from textual.widgets import ( Label, ListItem, ListView, + Rule, Select, TabbedContent, TabPane, TextArea, - Tree, - Rule, ) -from omnia.tools.scenario_builder import DEFAULT_ENTITY, DEFAULT_SCENARIO - - -class PathPrompt(ModalScreen[Optional[str]]): - def __init__(self, title: str, default: str = "") -> None: - super().__init__() - self.title = title - self.default = default - - def compose(self) -> ComposeResult: - yield Label(self.title, id="prompt-title") - yield Input(value=self.default, id="prompt-input") - with Horizontal(id="prompt-buttons"): - yield Button("OK", id="ok", variant="primary") - yield Button("Cancel", id="cancel") - - def on_mount(self) -> None: - self.query_one(Input).focus() - - def on_button_pressed(self, event: Button.Pressed) -> None: - if event.button.id == "ok": - value = self.query_one(Input).value.strip() - self.dismiss(value or None) - else: - self.dismiss(None) - - -class NodeTypePrompt(ModalScreen[Optional[str]]): - CSS = """ - NodeTypePrompt { - align: center middle; - } - #dialog { - layout: grid; - grid-size: 2; - grid-gutter: 1 2; - grid-rows: 1fr 3; - padding: 0 1; - width: 60; - height: 11; - border: thick $background 80%; - background: $surface; - } - #node-type-title { - column-span: 2; - content-align: center middle; - } - """ - - def compose(self) -> ComposeResult: - with Grid(id="dialog"): - yield Label("Select node type", id="node-type-title") - yield Button("πŸ—ΊοΈ Region", id="node_region", variant="primary") - yield Button("πŸ“ Location", id="node_location") - yield Button("⭐ POI", id="node_poi") - yield Button("Cancel", id="node_cancel") - - def on_button_pressed(self, event: Button.Pressed) -> None: - if event.button.id == "node_region": - self.dismiss("region") - elif event.button.id == "node_location": - self.dismiss("location") - elif event.button.id == "node_poi": - self.dismiss("poi") - else: - self.dismiss(None) +from omnia.tools.scenario_builder import DEFAULT_SCENARIO +from omnia.tools.scenario_tui import EntitiesScreen, SpatialGraphScreen +from omnia.tools.scenario_tui.prompts import PathPrompt class ScenarioItem(ListItem): @@ -98,1018 +31,6 @@ class ScenarioItem(ListItem): self.path = path -class EntityItem(ListItem): - def __init__(self, index: int, entity: dict[str, Any]) -> None: - super().__init__(Label(self._label(entity))) - self.index = index - - @staticmethod - def _label(entity: dict[str, Any]) -> str: - name = str(entity.get("name") or "Unnamed").strip() - entity_id = str(entity.get("id") or "no-id").strip() - return f"{name} ({entity_id})" - - -@dataclass -class SpatialNodeRef: - node_type: str - obj: dict[str, Any] - - -class EntitiesScreen(Screen[Optional[dict[str, Any]]]): - CSS = """ - Screen { - layout: vertical; - } - #entities-main { - height: 1fr; - } - #entity-left { - width: 35%; - min-width: 24; - border: round $primary; - padding: 1; - } - #entity-right { - width: 1fr; - border: round $primary; - padding: 1; - } - #entity_list { - height: 1fr; - } - #entity-form { - height: 1fr; - } - #entity_stats, - #entity_voice, - #entity_memories { - height: 4; - } - #entity-traits-label, - #entity-stats-label, - #entity-voice-label, - #entity-memories-label { - margin-top: 1; - } - #entity_status { - height: 1; - padding: 0 1; - } - """ - - def __init__(self, scenario: dict[str, Any]) -> None: - super().__init__() - self.scenario = scenario - self.entities = self.scenario.setdefault("entities", []) - if not isinstance(self.entities, list): - self.entities = [] - self.scenario["entities"] = self.entities - self.selected_index: Optional[int] = None - - def compose(self) -> ComposeResult: - yield Header(show_clock=True) - with Horizontal(id="entities-main"): - with Vertical(id="entity-left"): - yield Label("Entities", id="entity-title") - yield ListView(id="entity_list") - with Horizontal(id="entity-actions"): - yield Button("New", id="entity_new") - yield Button("Delete", id="entity_delete") - with Vertical(id="entity-right"): - yield Label("Entity Details", id="entity-current") - with VerticalScroll(id="entity-form"): - yield Label("Entity ID", id="entity-id-label") - yield Input(id="entity_id") - yield Label("Name", id="entity-name-label") - yield Input(id="entity_name") - yield Label("Traits (comma-separated)", id="entity-traits-label") - yield Input(id="entity_traits") - yield Label("Stats (key=value per line)", id="entity-stats-label") - yield TextArea(id="entity_stats") - yield Label("Voice sample", id="entity-voice-label") - yield TextArea(id="entity_voice") - yield Label("Current mood", id="entity-mood-label") - yield Input(id="entity_mood") - yield Label("Location ID", id="entity-location-label") - yield Input(id="entity_location") - yield Label("Spatial descriptor", id="entity-spatial-label") - yield Input(id="entity_spatial") - yield Label( - "Memories (one per line; JSON allowed)", - id="entity-memories-label", - ) - yield TextArea(id="entity_memories") - with Horizontal(id="entity-actions-right"): - yield Button("Save", id="entity_save", variant="primary") - yield Button("Apply", id="entity_apply") - yield Button("Cancel", id="entity_cancel") - yield Label("", id="entity_status") - yield Footer() - - def on_mount(self) -> None: - self._reload_entity_list(select_index=0 if self.entities else None) - if self.entities: - self._load_entity_into_form(self.entities[0]) - self.selected_index = 0 - else: - self._load_entity_into_form(self._default_entity()) - self.selected_index = None - self._set_status("Editing entities") - - def _set_status(self, message: str) -> None: - self.query_one("#entity_status", Label).update(message) - - def _default_entity(self) -> dict[str, Any]: - return json.loads(json.dumps(DEFAULT_ENTITY)) - - def _load_entity_into_form(self, entity: dict[str, Any]) -> None: - self.query_one("#entity_id", Input).value = str(entity.get("id", "")) - self.query_one("#entity_name", Input).value = str(entity.get("name", "")) - traits = entity.get("traits", []) - if isinstance(traits, list): - traits_text = ", ".join(str(t).strip() for t in traits if str(t).strip()) - else: - traits_text = str(traits) - self.query_one("#entity_traits", Input).value = traits_text - - stats = entity.get("stats", {}) - stats_lines = [] - if isinstance(stats, dict): - for key, value in stats.items(): - stats_lines.append(f"{key}={value}") - self.query_one("#entity_stats", TextArea).text = "\n".join(stats_lines) - - self.query_one("#entity_voice", TextArea).text = str( - entity.get("voice_sample", "") - ) - self.query_one("#entity_mood", Input).value = str( - entity.get("current_mood", "") - ) - - metadata = ( - entity.get("metadata", {}) - if isinstance(entity.get("metadata"), dict) - else {} - ) - self.query_one("#entity_location", Input).value = str( - metadata.get("location", "") - ) - self.query_one("#entity_spatial", Input).value = str( - metadata.get("spatial_descriptor", "") - ) - - memories = entity.get("memories", []) - memory_lines = [] - if isinstance(memories, list): - for memory in memories: - if isinstance(memory, (dict, list)): - memory_lines.append(json.dumps(memory)) - else: - memory_lines.append(str(memory)) - self.query_one("#entity_memories", TextArea).text = "\n".join(memory_lines) - - def _build_entity_from_form(self) -> Optional[dict[str, Any]]: - entity_id = self.query_one("#entity_id", Input).value.strip() - name = self.query_one("#entity_name", Input).value.strip() - if not entity_id or not name: - self._set_status("Entity requires non-empty ID and name") - return None - - entity = self._default_entity() - entity["id"] = entity_id - entity["name"] = name - - traits_value = self.query_one("#entity_traits", Input).value - traits = [t.strip() for t in traits_value.split(",") if t.strip()] - entity["traits"] = traits - - stats_text = self.query_one("#entity_stats", TextArea).text - stats: dict[str, int] = {} - for line in stats_text.splitlines(): - raw = line.strip() - if not raw: - continue - if "=" not in raw: - self._set_status(f"Invalid stats line (missing '='): {raw}") - return None - key, value = raw.split("=", 1) - key = key.strip() - value = value.strip() - if not key or not value: - self._set_status(f"Invalid stats line: {raw}") - return None - try: - stats[key] = int(value) - except ValueError: - self._set_status(f"Stat value must be an integer: {raw}") - return None - entity["stats"] = stats - - entity["voice_sample"] = self.query_one("#entity_voice", TextArea).text.strip() - mood = self.query_one("#entity_mood", Input).value.strip() - entity["current_mood"] = mood if mood else "Neutral" - - entity["metadata"]["location"] = self.query_one( - "#entity_location", Input - ).value.strip() - entity["metadata"]["spatial_descriptor"] = self.query_one( - "#entity_spatial", Input - ).value.strip() - - memories_text = self.query_one("#entity_memories", TextArea).text - memories: list[Any] = [] - for line in memories_text.splitlines(): - raw = line.strip() - if not raw: - continue - if raw.startswith(("{", "[")): - try: - memories.append(json.loads(raw)) - except json.JSONDecodeError: - self._set_status(f"Invalid memory JSON: {raw}") - return None - else: - memories.append(raw) - entity["memories"] = memories - - return entity - - def _reload_entity_list(self, select_index: Optional[int] = None) -> None: - list_view = self.query_one("#entity_list", ListView) - list_view.clear() - for index, entity in enumerate(self.entities): - list_view.append(EntityItem(index, entity)) - if select_index is not None and self.entities: - clamped = max(0, min(select_index, len(self.entities) - 1)) - list_view.index = clamped - self.selected_index = clamped - - def on_list_view_selected(self, event: ListView.Selected) -> None: - if isinstance(event.item, EntityItem): - self.selected_index = event.item.index - self._load_entity_into_form(self.entities[event.item.index]) - - async def on_button_pressed(self, event: Button.Pressed) -> None: - button_id = event.button.id - if button_id == "entity_new": - self.selected_index = None - self._load_entity_into_form(self._default_entity()) - self._set_status("Creating new entity") - return - - if button_id == "entity_save": - entity = self._build_entity_from_form() - if entity is None: - return - - existing_ids = { - e.get("id") - for i, e in enumerate(self.entities) - if i != self.selected_index - } - if entity["id"] in existing_ids: - self._set_status(f"Entity ID '{entity['id']}' already exists") - return - - if self.selected_index is None: - self.entities.append(entity) - self.selected_index = len(self.entities) - 1 - else: - self.entities[self.selected_index] = entity - - self._reload_entity_list(select_index=self.selected_index) - self._set_status("Entity saved") - return - - if button_id == "entity_delete": - if self.selected_index is None: - self._set_status("Select an entity to delete") - return - removed = self.entities.pop(self.selected_index) - self._reload_entity_list(select_index=self.selected_index) - if self.entities: - self._load_entity_into_form( - self.entities[min(self.selected_index, len(self.entities) - 1)] - ) - else: - self._load_entity_into_form(self._default_entity()) - self.selected_index = None - self._set_status(f"Deleted {removed.get('name')}") - return - - if button_id == "entity_apply": - self.scenario["entities"] = self.entities - self.dismiss(self.scenario) - return - - if button_id == "entity_cancel": - self.dismiss(None) - - -class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]): - CSS = """ - Screen { - layout: vertical; - } - #spatial-main { - height: 1fr; - } - #spatial-left { - width: 35%; - min-width: 24; - border: round $primary; - padding: 1; - } - #spatial-right { - width: 1fr; - border: round $primary; - padding: 1; - } - .hidden { - display: none; - } - #spatial-tree { - height: 1fr; - } - #node-description { - height: 5; - } - #spatial-status { - height: 1; - padding: 0 1; - } - #spatial-actions { - height: auto; - width: auto; - align-horizontal: left; - } - """ - - def __init__(self, scenario: dict[str, Any]) -> None: - super().__init__() - self.scenario = scenario - self.selected_node = None - self.node_map: dict[int, Any] = {} - self.parent_option_map: dict[str, tuple[str, dict[str, Any]]] = {} - - def compose(self) -> ComposeResult: - yield Header(show_clock=True) - with Horizontal(id="spatial-main"): - with Vertical(id="spatial-left"): - yield Label("Spatial Graph", id="spatial-title") - yield Tree("🌍 World", id="spatial-tree") - with Horizontal(id="spatial-tree-actions"): - yield Button("Add", id="spatial_add", variant="primary") - yield Button("Delete", id="spatial_delete") - with Horizontal(id="spatial-nesting-actions"): - yield Button("Promote", id="spatial_promote") - yield Button("Demote", id="spatial_demote") - with Vertical(id="spatial-right"): - yield Label("Node Details", id="node-title") - with Vertical(id="node-details"): - yield Label("Type", id="node-type-label") - yield Label("", id="node-type-value") - yield Label("ID", id="node-id-label") - yield Input(id="node_id") - yield Label("Name", id="node-name-label") - yield Input(id="node_name") - yield Label("Description", id="node-desc-label") - yield TextArea(id="node-description") - yield Label("Parent", id="node-parent-label") - yield Select([], prompt="Select parent", id="node_parent") - with Horizontal(id="node-actions"): - yield Button("Generate", id="node_generate") - yield Button("Update", id="node_update", variant="primary") - yield Button("Reset", id="node_reset") - with Horizontal(id="spatial-actions"): - yield Button("Save", id="spatial_save", variant="primary") - yield Button("Cancel", id="spatial_cancel") - yield Label("", id="spatial-status") - yield Footer() - - def on_mount(self) -> None: - self._ensure_spatial_graph() - self._build_tree() - self._clear_form() - self._set_status("Select a node to edit") - - def _set_status(self, message: str) -> None: - self.query_one("#spatial-status", Label).update(message) - - def _toast(self, message: str, severity: str = "warning") -> None: - try: - self.app.notify(message, severity=severity) - except Exception: - self._set_status(message) - - def _ensure_spatial_graph(self) -> None: - spatial_graph = self.scenario.get("spatial_graph") - if spatial_graph is None: - spatial_graph = json.loads(json.dumps(DEFAULT_SCENARIO["spatial_graph"])) - self.scenario["spatial_graph"] = spatial_graph - world = spatial_graph.setdefault("world", {}) - world.setdefault("regions", []) - world.setdefault("name", "World") - world.setdefault("id", "world") - - def _world(self) -> dict[str, Any]: - return self.scenario["spatial_graph"]["world"] - - def _build_tree( - self, - select_obj: Optional[dict[str, Any]] = None, - expanded: Optional[set[int]] = None, - ) -> None: - tree = self.query_one("#spatial-tree", Tree) - world = self._world() - tree.reset(f"🌍 {world.get('name', 'World')}") - tree.root.data = SpatialNodeRef("world", world) - tree.root.expand() - self.selected_node = None - self.node_map = {} - expanded = expanded or set() - - def add_region(region: dict[str, Any]) -> None: - node = tree.root.add( - self._label_for("region", region), data=SpatialNodeRef("region", region) - ) - self.node_map[id(region)] = node - if id(region) in expanded: - node.expand() - for location in region.get("locations", []): - add_location(node, location) - - def add_location(parent_node: Any, location: dict[str, Any]) -> None: - loc_node = parent_node.add( - self._label_for("location", location), - data=SpatialNodeRef("location", location), - ) - self.node_map[id(location)] = loc_node - if id(location) in expanded: - loc_node.expand() - for nested in location.get("locations", []): - add_location(loc_node, nested) - for poi in location.get("pois", []): - poi_node = loc_node.add( - self._label_for("poi", poi), data=SpatialNodeRef("poi", poi) - ) - self.node_map[id(poi)] = poi_node - if id(poi) in expanded: - poi_node.expand() - - for region in world.get("regions", []): - add_region(region) - - if select_obj and id(select_obj) in self.node_map: - self.selected_node = self.node_map[id(select_obj)] - self._expand_ancestors(self.selected_node) - self._load_form_from_node(self.selected_node) - else: - self._clear_form() - - def _capture_expanded_nodes(self) -> set[int]: - tree = self.query_one("#spatial-tree", Tree) - expanded: set[int] = set() - - def walk(node: Any) -> None: - if node.is_expanded and isinstance(node.data, SpatialNodeRef): - expanded.add(id(node.data.obj)) - for child in node.children: - walk(child) - - walk(tree.root) - return expanded - - def _expand_ancestors(self, node: Any) -> None: - current = node - while current is not None: - current.expand() - current = current.parent - - def _label_for(self, node_type: str, obj: dict[str, Any]) -> str: - name = str(obj.get("name") or "Unnamed").strip() - node_id = str(obj.get("id") or "no-id").strip() - icon = {"region": "πŸ—ΊοΈ", "location": "πŸ“", "poi": "⭐"}.get(node_type, "β€’") - return f"{icon} {name} ({node_id})" - - def _clear_form(self) -> None: - self.query_one("#node-type-value", Label).update("") - self.query_one("#node_id", Input).value = "" - self.query_one("#node_name", Input).value = "" - self.query_one("#node-description", TextArea).text = "" - self.query_one("#node_parent", Select).set_options([]) - self.query_one("#node_parent", Select).value = Select.NULL - self.query_one("#node-details", Vertical).add_class("hidden") - self.query_one("#node_generate", Button).disabled = True - self.query_one("#node_update", Button).disabled = True - self.query_one("#node_reset", Button).disabled = True - self.query_one("#spatial_delete", Button).disabled = True - self.query_one("#spatial_promote", Button).disabled = True - self.query_one("#spatial_demote", Button).disabled = True - - def _load_form_from_node(self, node: Any) -> None: - ref = node.data - if not isinstance(ref, SpatialNodeRef) or ref.node_type == "world": - self._clear_form() - return - self.query_one("#node-details", Vertical).remove_class("hidden") - self.query_one("#node-type-value", Label).update(ref.node_type.title()) - self.query_one("#node_id", Input).value = str(ref.obj.get("id", "")) - self.query_one("#node_name", Input).value = str(ref.obj.get("name", "")) - self.query_one("#node-description", TextArea).text = str( - ref.obj.get("description", "") - ) - self.query_one("#node_generate", Button).disabled = False - self.query_one("#node_update", Button).disabled = False - self.query_one("#node_reset", Button).disabled = False - self.query_one("#spatial_delete", Button).disabled = False - is_location = ref.node_type == "location" - can_promote, can_demote = ( - self._location_move_state(node) if is_location else (False, False) - ) - self.query_one("#spatial_promote", Button).disabled = not can_promote - self.query_one("#spatial_demote", Button).disabled = not can_demote - parent_select = self.query_one("#node_parent", Select) - if is_location: - options, selected_value, option_map = self._build_parent_options(ref.obj) - self.parent_option_map = option_map - parent_select.set_options(options) - parent_select.value = ( - selected_value if selected_value is not None else Select.NULL - ) - parent_select.disabled = False - else: - self.parent_option_map = {} - parent_select.set_options([]) - parent_select.value = Select.NULL - parent_select.disabled = True - - def on_tree_node_selected(self, event: Tree.NodeSelected) -> None: - self.selected_node = event.node - self._load_form_from_node(event.node) - - def _open_node_type_prompt(self) -> None: - self.app.push_screen(NodeTypePrompt(), callback=self._handle_node_type) - - def _handle_node_type(self, result: Optional[str]) -> None: - if result: - self._add_node(result) - - def _selected_ref(self) -> Optional[SpatialNodeRef]: - if self.selected_node is None: - return None - ref = self.selected_node.data - if isinstance(ref, SpatialNodeRef): - return ref - return None - - def _selected_parent_node(self) -> Optional[Any]: - if self.selected_node is None: - return None - return self.selected_node.parent - - def _parent_locations_list( - self, parent_node: Any - ) -> Optional[list[dict[str, Any]]]: - if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef): - return None - parent_ref = parent_node.data - if parent_ref.node_type not in {"region", "location"}: - return None - return parent_ref.obj.setdefault("locations", []) - - def _location_move_state(self, node: Any) -> tuple[bool, bool]: - if node is None or not isinstance(node.data, SpatialNodeRef): - return (False, False) - ref = node.data - if ref.node_type != "location": - return (False, False) - parent_node = node.parent - parent_locations = self._parent_locations_list(parent_node) - if parent_locations is None: - return (False, False) - can_promote = False - if parent_node is not None and isinstance(parent_node.data, SpatialNodeRef): - parent_ref = parent_node.data - can_promote = parent_ref.node_type == "location" - can_demote = False - if ref.obj in parent_locations: - index = parent_locations.index(ref.obj) - can_demote = index < len(parent_locations) - 1 - return (can_promote, can_demote) - - def _parent_pois_list(self, parent_node: Any) -> Optional[list[dict[str, Any]]]: - if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef): - return None - parent_ref = parent_node.data - if parent_ref.node_type != "location": - return None - return parent_ref.obj.setdefault("pois", []) - - def _collect_locations(self) -> list[dict[str, Any]]: - locations: list[dict[str, Any]] = [] - - def walk(location_list: list[dict[str, Any]]) -> None: - for location in location_list: - locations.append(location) - walk(location.get("locations", [])) - - for region in self._world().get("regions", []): - walk(region.get("locations", [])) - return locations - - def _location_descendants(self, location: dict[str, Any]) -> set[int]: - descendants: set[int] = set() - - def walk(children: list[dict[str, Any]]) -> None: - for child in children: - descendants.add(id(child)) - walk(child.get("locations", [])) - - walk(location.get("locations", [])) - return descendants - - def _find_location_parent(self, target: dict[str, Any]) -> Optional[tuple[str, dict[str, Any]]]: - world = self._world() - for region in world.get("regions", []): - if target in region.get("locations", []): - return ("region", region) - for location in region.get("locations", []): - found = self._find_location_parent_in_location(location, target) - if found: - return found - return None - - def _find_location_parent_in_location( - self, location: dict[str, Any], target: dict[str, Any] - ) -> Optional[tuple[str, dict[str, Any]]]: - if target in location.get("locations", []): - return ("location", location) - for nested in location.get("locations", []): - found = self._find_location_parent_in_location(nested, target) - if found: - return found - return None - - def _build_parent_options( - self, current: dict[str, Any] - ) -> tuple[list[tuple[str, str]], Optional[str], dict[str, tuple[str, dict[str, Any]]]]: - options: list[tuple[str, str]] = [] - option_map: dict[str, tuple[str, dict[str, Any]]] = {} - forbidden = {id(current)} | self._location_descendants(current) - selected_value: Optional[str] = None - - for region in self._world().get("regions", []): - region_id = str(region.get("id", "")).strip() - region_name = str(region.get("name", "")).strip() - value = f"region@{id(region)}" - label = f"πŸ—ΊοΈ {region_name} ({region_id})" if region_name else f"πŸ—ΊοΈ {region_id or 'region'}" - options.append((label, value)) - option_map[value] = ("region", region) - - for location in self._collect_locations(): - if id(location) in forbidden: - continue - location_id = str(location.get("id", "")).strip() - location_name = str(location.get("name", "")).strip() - value = f"location@{id(location)}" - label = ( - f"πŸ“ {location_name} ({location_id})" - if location_name - else f"πŸ“ {location_id or 'location'}" - ) - options.append((label, value)) - option_map[value] = ("location", location) - - parent_info = self._find_location_parent(current) - if parent_info: - parent_type, parent_obj = parent_info - for value, (entry_type, entry_obj) in option_map.items(): - if entry_type == parent_type and entry_obj is parent_obj: - selected_value = value - break - - return options, selected_value, option_map - - def _update_selected_node(self) -> None: - ref = self._selected_ref() - if ref is None or ref.node_type == "world": - self._set_status("Select a node to update") - return - node_id = self.query_one("#node_id", Input).value.strip() - name = self.query_one("#node_name", Input).value.strip() - description = self.query_one("#node-description", TextArea).text.strip() - if not node_id or not name: - self._set_status("ID and name are required") - return - ref.obj["id"] = node_id - ref.obj["name"] = name - ref.obj["description"] = description - if ref.node_type == "region": - ref.obj.setdefault("locations", []) - elif ref.node_type == "location": - ref.obj.setdefault("locations", []) - ref.obj.setdefault("pois", []) - elif ref.node_type == "poi": - ref.obj.setdefault("connections", []) - if ref.node_type == "location": - parent_select = self.query_one("#node_parent", Select) - if parent_select.value is not Select.NULL: - selected_value = str(parent_select.value) - parent_info = self.parent_option_map.get(selected_value) - if parent_info: - parent_type, parent_obj = parent_info - current_parent = self._find_location_parent(ref.obj) - if current_parent is None or current_parent[1] is not parent_obj: - if current_parent: - current_type, current_obj = current_parent - if current_type == "region": - current_obj.get("locations", []).remove(ref.obj) - elif current_type == "location": - current_obj.get("locations", []).remove(ref.obj) - parent_obj.setdefault("locations", []).append(ref.obj) - expanded = self._capture_expanded_nodes() - self._build_tree(select_obj=ref.obj, expanded=expanded) - self._set_status("Node updated") - return - if self.selected_node: - self.selected_node.label = self._label_for(ref.node_type, ref.obj) - self._set_status("Node updated") - - def _collect_ids(self, node_type: str) -> set[str]: - world = self._world() - ids: set[str] = set() - - if node_type == "region": - for region in world.get("regions", []): - if region_id := str(region.get("id", "")).strip(): - ids.add(region_id) - return ids - - def walk_locations(locations: list[dict[str, Any]]) -> None: - for location in locations: - if node_type == "location": - if loc_id := str(location.get("id", "")).strip(): - ids.add(loc_id) - if node_type == "poi": - for poi in location.get("pois", []): - if poi_id := str(poi.get("id", "")).strip(): - ids.add(poi_id) - walk_locations(location.get("locations", [])) - - for region in world.get("regions", []): - walk_locations(region.get("locations", [])) - - return ids - - def _node_path(self, node: Any) -> str: - parts = [] - current = node - while current is not None and isinstance(current.data, SpatialNodeRef): - ref = current.data - if ref.node_type == "world": - break - label = str(ref.obj.get("name") or ref.obj.get("id") or ref.node_type) - parts.append(f"{ref.node_type}: {label}") - current = current.parent - return " > ".join(reversed(parts)) if parts else "World" - - def _generate_node_fields(self) -> None: - ref = self._selected_ref() - if ref is None or ref.node_type == "world": - self._toast("Select a node to generate") - return - - scenario_meta = self.scenario.get("scenario", {}) - node_type = ref.node_type - system_prompt = """ -You are helping build a game scenario. -Return ONLY a single-line JSON object with keys: -id, name, description. -- id: short snake_case identifier -- description: 1-3 sentences -No commentary. ASCII only. -""".strip() - - user_prompt = f""" -Scenario title: {scenario_meta.get("title", "")} -Scenario description: {scenario_meta.get("description", "")} -Node type: {node_type} -Node path: {self._node_path(self.selected_node)} -Current id: {ref.obj.get("id", "")} -Current name: {ref.obj.get("name", "")} -""".strip() - - from omnia.llm_runtime import invoke_llm - - response_text = invoke_llm( - [SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)] - ) - - try: - payload = json.loads(response_text) - except json.JSONDecodeError: - self._toast("LLM response was not valid JSON", severity="error") - return - - if not isinstance(payload, dict): - self._toast("LLM response must be a JSON object", severity="error") - return - - generated_id = str(payload.get("id", "")).strip() - generated_name = str(payload.get("name", "")).strip() - generated_description = str(payload.get("description", "")).strip() - - if not generated_id or not generated_name: - self._toast("Generated data missing id or name", severity="error") - return - - existing_ids = self._collect_ids(node_type) - if generated_id in existing_ids and generated_id != str(ref.obj.get("id", "")).strip(): - self._toast(f"Generated id '{generated_id}' already exists", severity="warning") - return - - ref.obj["id"] = generated_id - ref.obj["name"] = generated_name - ref.obj["description"] = generated_description - if self.selected_node: - self.selected_node.label = self._label_for(node_type, ref.obj) - self._load_form_from_node(self.selected_node) - self._set_status("Generated fields applied") - - def _reset_selected_node(self) -> None: - if self.selected_node is None: - return - self._load_form_from_node(self.selected_node) - self._set_status("Reverted changes") - - def _add_node(self, node_type: str) -> None: - world = self._world() - expanded = self._capture_expanded_nodes() - if node_type == "region": - region = {"id": "", "name": "", "description": "", "locations": []} - world.setdefault("regions", []).append(region) - self._build_tree(select_obj=region, expanded=expanded) - self._set_status("Region added") - return - - parent_node = self.selected_node - if node_type == "location": - if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef): - self._toast("Select a region or location for the new location") - return - parent_ref = parent_node.data - if parent_ref.node_type not in {"region", "location"}: - self._toast("Select a region or location for the new location") - return - location = { - "id": "", - "name": "", - "description": "", - "locations": [], - "pois": [], - } - parent_ref.obj.setdefault("locations", []).append(location) - self._build_tree(select_obj=location, expanded=expanded) - self._set_status("Location added") - return - - if node_type == "poi": - if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef): - self._toast("Select a location for the new POI") - return - parent_ref = parent_node.data - if parent_ref.node_type == "poi": - parent_node = parent_node.parent - if parent_node is None or not isinstance( - parent_node.data, SpatialNodeRef - ): - self._toast("Select a location for the new POI") - return - parent_ref = parent_node.data - if parent_ref.node_type != "location": - self._toast("Select a location for the new POI") - return - poi = {"id": "", "name": "", "description": "", "connections": []} - parent_ref.obj.setdefault("pois", []).append(poi) - self._build_tree(select_obj=poi, expanded=expanded) - self._set_status("POI added") - - def _delete_node(self) -> None: - ref = self._selected_ref() - if ref is None or ref.node_type == "world": - self._set_status("Select a node to delete") - return - parent_node = self._selected_parent_node() - expanded = self._capture_expanded_nodes() - if ref.node_type == "region": - world = self._world() - world["regions"] = [r for r in world.get("regions", []) if r is not ref.obj] - self._build_tree(expanded=expanded) - self._set_status("Region deleted") - return - if ref.node_type == "location": - parent_locations = self._parent_locations_list(parent_node) - if parent_locations is None: - self._set_status("Unable to delete location") - return - parent_locations[:] = [ - loc for loc in parent_locations if loc is not ref.obj - ] - self._build_tree(expanded=expanded) - self._set_status("Location deleted") - return - if ref.node_type == "poi": - parent_pois = self._parent_pois_list(parent_node) - if parent_pois is None: - self._set_status("Unable to delete POI") - return - parent_pois[:] = [poi for poi in parent_pois if poi is not ref.obj] - self._build_tree(expanded=expanded) - self._set_status("POI deleted") - - def _promote_location(self) -> None: - ref = self._selected_ref() - if ref is None or ref.node_type != "location": - self._set_status("Select a location to promote") - return - parent_node = self._selected_parent_node() - if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef): - self._set_status("Unable to promote location") - return - parent_ref = parent_node.data - if parent_ref.node_type != "location": - self._set_status("Location is already top-level") - return - grandparent_node = parent_node.parent - parent_locations = parent_ref.obj.setdefault("locations", []) - if ref.obj in parent_locations: - parent_locations.remove(ref.obj) - target_list = self._parent_locations_list(grandparent_node) - if target_list is None: - self._set_status("Unable to promote location") - return - target_list.append(ref.obj) - expanded = self._capture_expanded_nodes() - self._build_tree(select_obj=ref.obj, expanded=expanded) - self._set_status("Location promoted") - - def _demote_location(self) -> None: - ref = self._selected_ref() - if ref is None or ref.node_type != "location": - self._set_status("Select a location to demote") - return - parent_node = self._selected_parent_node() - parent_locations = self._parent_locations_list(parent_node) - if parent_locations is None: - self._set_status("Unable to demote location") - return - try: - index = parent_locations.index(ref.obj) - except ValueError: - self._set_status("Unable to demote location") - return - if index >= len(parent_locations) - 1: - self._set_status("No next sibling to demote under") - return - new_parent = parent_locations[index + 1] - new_parent.setdefault("locations", []).append(ref.obj) - parent_locations.remove(ref.obj) - expanded = self._capture_expanded_nodes() - self._build_tree(select_obj=ref.obj, expanded=expanded) - self._set_status("Location demoted") - - def on_button_pressed(self, event: Button.Pressed) -> None: - button_id = event.button.id - if button_id == "spatial_add": - self._open_node_type_prompt() - return - if button_id == "spatial_delete": - self._delete_node() - return - if button_id == "spatial_promote": - self._promote_location() - return - if button_id == "spatial_demote": - self._demote_location() - return - if button_id == "node_generate": - self._generate_node_fields() - return - if button_id == "node_update": - self._update_selected_node() - return - if button_id == "node_reset": - self._reset_selected_node() - return - if button_id == "spatial_save": - self.dismiss(self.scenario) - return - if button_id == "spatial_cancel": - self.dismiss(None) - - class ScenarioBuilderTUI(App): CSS = """ Screen { @@ -1327,7 +248,9 @@ class ScenarioBuilderTUI(App): if location_id and location_id not in seen: seen.add(location_id) label = ( - f"{location_name} ({location_id})" if location_name else location_id + f"{location_name} ({location_id})" + if location_name + else location_id ) options.append((label, location_id)) walk_locations(location.get("locations", [])) diff --git a/src/omnia/tools/scenario_tui/__init__.py b/src/omnia/tools/scenario_tui/__init__.py new file mode 100644 index 0000000..0f61ee2 --- /dev/null +++ b/src/omnia/tools/scenario_tui/__init__.py @@ -0,0 +1,5 @@ +from .entities import EntitiesScreen +from .prompts import NodeTypePrompt, PathPrompt +from .spatial_graph import SpatialGraphScreen + +__all__ = ["EntitiesScreen", "NodeTypePrompt", "PathPrompt", "SpatialGraphScreen"] diff --git a/src/omnia/tools/scenario_tui/entities.py b/src/omnia/tools/scenario_tui/entities.py new file mode 100644 index 0000000..8691c02 --- /dev/null +++ b/src/omnia/tools/scenario_tui/entities.py @@ -0,0 +1,321 @@ +import json +from typing import Any, Optional + +from textual.app import ComposeResult +from textual.containers import Horizontal, Vertical, VerticalScroll +from textual.screen import Screen +from textual.widgets import ( + Button, + Footer, + Header, + Input, + Label, + ListItem, + ListView, + TextArea, +) + +from omnia.tools.scenario_builder import DEFAULT_ENTITY + + +class EntityItem(ListItem): + def __init__(self, index: int, entity: dict[str, Any]) -> None: + super().__init__(Label(self._label(entity))) + self.index = index + + @staticmethod + def _label(entity: dict[str, Any]) -> str: + name = str(entity.get("name") or "Unnamed").strip() + entity_id = str(entity.get("id") or "no-id").strip() + return f"{name} ({entity_id})" + + +class EntitiesScreen(Screen[Optional[dict[str, Any]]]): + CSS = """ + Screen { + layout: vertical; + } + #entities-main { + height: 1fr; + } + #entity-left { + width: 35%; + min-width: 24; + border: round $primary; + padding: 1; + } + #entity-right { + width: 1fr; + border: round $primary; + padding: 1; + } + #entity_list { + height: 1fr; + } + #entity-form { + height: 1fr; + } + #entity_stats, + #entity_voice, + #entity_memories { + height: 4; + } + #entity-traits-label, + #entity-stats-label, + #entity-voice-label, + #entity-memories-label { + margin-top: 1; + } + #entity_status { + height: 1; + padding: 0 1; + } + """ + + def __init__(self, scenario: dict[str, Any]) -> None: + super().__init__() + self.scenario = scenario + self.entities = self.scenario.setdefault("entities", []) + if not isinstance(self.entities, list): + self.entities = [] + self.scenario["entities"] = self.entities + self.selected_index: Optional[int] = None + + def compose(self) -> ComposeResult: + yield Header(show_clock=True) + with Horizontal(id="entities-main"): + with Vertical(id="entity-left"): + yield Label("Entities", id="entity-title") + yield ListView(id="entity_list") + with Horizontal(id="entity-actions"): + yield Button("New", id="entity_new") + yield Button("Delete", id="entity_delete") + with Vertical(id="entity-right"): + yield Label("Entity Details", id="entity-current") + with VerticalScroll(id="entity-form"): + yield Label("Entity ID", id="entity-id-label") + yield Input(id="entity_id") + yield Label("Name", id="entity-name-label") + yield Input(id="entity_name") + yield Label("Traits (comma-separated)", id="entity-traits-label") + yield Input(id="entity_traits") + yield Label("Stats (key=value per line)", id="entity-stats-label") + yield TextArea(id="entity_stats") + yield Label("Voice sample", id="entity-voice-label") + yield TextArea(id="entity_voice") + yield Label("Current mood", id="entity-mood-label") + yield Input(id="entity_mood") + yield Label("Location ID", id="entity-location-label") + yield Input(id="entity_location") + yield Label("Spatial descriptor", id="entity-spatial-label") + yield Input(id="entity_spatial") + yield Label( + "Memories (one per line; JSON allowed)", + id="entity-memories-label", + ) + yield TextArea(id="entity_memories") + with Horizontal(id="entity-actions-right"): + yield Button("Save", id="entity_save", variant="primary") + yield Button("Apply", id="entity_apply") + yield Button("Cancel", id="entity_cancel") + yield Label("", id="entity_status") + yield Footer() + + def on_mount(self) -> None: + self._reload_entity_list(select_index=0 if self.entities else None) + if self.entities: + self._load_entity_into_form(self.entities[0]) + self.selected_index = 0 + else: + self._load_entity_into_form(self._default_entity()) + self.selected_index = None + self._set_status("Editing entities") + + def _set_status(self, message: str) -> None: + self.query_one("#entity_status", Label).update(message) + + def _default_entity(self) -> dict[str, Any]: + return json.loads(json.dumps(DEFAULT_ENTITY)) + + def _load_entity_into_form(self, entity: dict[str, Any]) -> None: + self.query_one("#entity_id", Input).value = str(entity.get("id", "")) + self.query_one("#entity_name", Input).value = str(entity.get("name", "")) + traits = entity.get("traits", []) + if isinstance(traits, list): + traits_text = ", ".join(str(t).strip() for t in traits if str(t).strip()) + else: + traits_text = str(traits) + self.query_one("#entity_traits", Input).value = traits_text + + stats = entity.get("stats", {}) + stats_lines = [] + if isinstance(stats, dict): + for key, value in stats.items(): + stats_lines.append(f"{key}={value}") + self.query_one("#entity_stats", TextArea).text = "\n".join(stats_lines) + + self.query_one("#entity_voice", TextArea).text = str( + entity.get("voice_sample", "") + ) + self.query_one("#entity_mood", Input).value = str( + entity.get("current_mood", "") + ) + + metadata = ( + entity.get("metadata", {}) + if isinstance(entity.get("metadata"), dict) + else {} + ) + self.query_one("#entity_location", Input).value = str( + metadata.get("location", "") + ) + self.query_one("#entity_spatial", Input).value = str( + metadata.get("spatial_descriptor", "") + ) + + memories = entity.get("memories", []) + memory_lines = [] + if isinstance(memories, list): + for memory in memories: + if isinstance(memory, (dict, list)): + memory_lines.append(json.dumps(memory)) + else: + memory_lines.append(str(memory)) + self.query_one("#entity_memories", TextArea).text = "\n".join(memory_lines) + + def _build_entity_from_form(self) -> Optional[dict[str, Any]]: + entity_id = self.query_one("#entity_id", Input).value.strip() + name = self.query_one("#entity_name", Input).value.strip() + if not entity_id or not name: + self._set_status("Entity requires non-empty ID and name") + return None + + entity = self._default_entity() + entity["id"] = entity_id + entity["name"] = name + + traits_value = self.query_one("#entity_traits", Input).value + traits = [t.strip() for t in traits_value.split(",") if t.strip()] + entity["traits"] = traits + + stats_text = self.query_one("#entity_stats", TextArea).text + stats: dict[str, int] = {} + for line in stats_text.splitlines(): + raw = line.strip() + if not raw: + continue + if "=" not in raw: + self._set_status(f"Invalid stats line (missing '='): {raw}") + return None + key, value = raw.split("=", 1) + key = key.strip() + value = value.strip() + if not key or not value: + self._set_status(f"Invalid stats line: {raw}") + return None + try: + stats[key] = int(value) + except ValueError: + self._set_status(f"Stat value must be an integer: {raw}") + return None + entity["stats"] = stats + + entity["voice_sample"] = self.query_one("#entity_voice", TextArea).text.strip() + mood = self.query_one("#entity_mood", Input).value.strip() + entity["current_mood"] = mood if mood else "Neutral" + + entity["metadata"]["location"] = self.query_one( + "#entity_location", Input + ).value.strip() + entity["metadata"]["spatial_descriptor"] = self.query_one( + "#entity_spatial", Input + ).value.strip() + + memories_text = self.query_one("#entity_memories", TextArea).text + memories: list[Any] = [] + for line in memories_text.splitlines(): + raw = line.strip() + if not raw: + continue + if raw.startswith(("{", "[")): + try: + memories.append(json.loads(raw)) + except json.JSONDecodeError: + self._set_status(f"Invalid memory JSON: {raw}") + return None + else: + memories.append(raw) + entity["memories"] = memories + + return entity + + def _reload_entity_list(self, select_index: Optional[int] = None) -> None: + list_view = self.query_one("#entity_list", ListView) + list_view.clear() + for index, entity in enumerate(self.entities): + list_view.append(EntityItem(index, entity)) + if select_index is not None and self.entities: + clamped = max(0, min(select_index, len(self.entities) - 1)) + list_view.index = clamped + self.selected_index = clamped + + def on_list_view_selected(self, event: ListView.Selected) -> None: + if isinstance(event.item, EntityItem): + self.selected_index = event.item.index + self._load_entity_into_form(self.entities[event.item.index]) + + async def on_button_pressed(self, event: Button.Pressed) -> None: + button_id = event.button.id + if button_id == "entity_new": + self.selected_index = None + self._load_entity_into_form(self._default_entity()) + self._set_status("Creating new entity") + return + + if button_id == "entity_save": + entity = self._build_entity_from_form() + if entity is None: + return + + existing_ids = { + e.get("id") + for i, e in enumerate(self.entities) + if i != self.selected_index + } + if entity["id"] in existing_ids: + self._set_status(f"Entity ID '{entity['id']}' already exists") + return + + if self.selected_index is None: + self.entities.append(entity) + self.selected_index = len(self.entities) - 1 + else: + self.entities[self.selected_index] = entity + + self._reload_entity_list(select_index=self.selected_index) + self._set_status("Entity saved") + return + + if button_id == "entity_delete": + if self.selected_index is None: + self._set_status("Select an entity to delete") + return + removed = self.entities.pop(self.selected_index) + self._reload_entity_list(select_index=self.selected_index) + if self.entities: + self._load_entity_into_form( + self.entities[min(self.selected_index, len(self.entities) - 1)] + ) + else: + self._load_entity_into_form(self._default_entity()) + self.selected_index = None + self._set_status(f"Deleted {removed.get('name')}") + return + + if button_id == "entity_apply": + self.scenario["entities"] = self.entities + self.dismiss(self.scenario) + return + + if button_id == "entity_cancel": + self.dismiss(None) diff --git a/src/omnia/tools/scenario_tui/prompts.py b/src/omnia/tools/scenario_tui/prompts.py new file mode 100644 index 0000000..3a280f2 --- /dev/null +++ b/src/omnia/tools/scenario_tui/prompts.py @@ -0,0 +1,71 @@ +from typing import Optional + +from textual.app import ComposeResult +from textual.containers import Grid, Horizontal +from textual.screen import ModalScreen +from textual.widgets import Button, Input, Label + + +class PathPrompt(ModalScreen[Optional[str]]): + def __init__(self, title: str, default: str = "") -> None: + super().__init__() + self.title = title + self.default = default + + def compose(self) -> ComposeResult: + yield Label(self.title, id="prompt-title") + yield Input(value=self.default, id="prompt-input") + with Horizontal(id="prompt-buttons"): + yield Button("OK", id="ok", variant="primary") + yield Button("Cancel", id="cancel") + + def on_mount(self) -> None: + self.query_one(Input).focus() + + def on_button_pressed(self, event: Button.Pressed) -> None: + if event.button.id == "ok": + value = self.query_one(Input).value.strip() + self.dismiss(value or None) + else: + self.dismiss(None) + + +class NodeTypePrompt(ModalScreen[Optional[str]]): + CSS = """ + NodeTypePrompt { + align: center middle; + } + #dialog { + layout: grid; + grid-size: 2; + grid-gutter: 1 2; + grid-rows: 1fr 3; + padding: 0 1; + width: 60; + height: 11; + border: thick $background 80%; + background: $surface; + } + #node-type-title { + column-span: 2; + content-align: center middle; + } + """ + + def compose(self) -> ComposeResult: + with Grid(id="dialog"): + yield Label("Select node type", id="node-type-title") + yield Button("πŸ—ΊοΈ Region", id="node_region", variant="primary") + yield Button("πŸ“ Location", id="node_location") + yield Button("⭐ POI", id="node_poi") + yield Button("Cancel", id="node_cancel") + + def on_button_pressed(self, event: Button.Pressed) -> None: + if event.button.id == "node_region": + self.dismiss("region") + elif event.button.id == "node_location": + self.dismiss("location") + elif event.button.id == "node_poi": + self.dismiss("poi") + else: + self.dismiss(None) diff --git a/src/omnia/tools/scenario_tui/spatial_graph.py b/src/omnia/tools/scenario_tui/spatial_graph.py new file mode 100644 index 0000000..f4c4734 --- /dev/null +++ b/src/omnia/tools/scenario_tui/spatial_graph.py @@ -0,0 +1,738 @@ +import json +from dataclasses import dataclass +from typing import Any, Optional + +from langchain_core.messages import HumanMessage, SystemMessage +from textual.app import ComposeResult +from textual.containers import Horizontal, Vertical +from textual.screen import Screen +from textual.widgets import ( + Button, + Footer, + Header, + Input, + Label, + Select, + TextArea, + Tree, +) + +from omnia.tools.scenario_builder import DEFAULT_SCENARIO +from omnia.tools.scenario_tui.prompts import NodeTypePrompt + + +@dataclass +class SpatialNodeRef: + node_type: str + obj: dict[str, Any] + + +class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]): + CSS = """ + Screen { + layout: vertical; + } + #spatial-main { + height: 1fr; + } + #spatial-left { + width: 35%; + min-width: 24; + border: round $primary; + padding: 1; + } + #spatial-right { + width: 1fr; + border: round $primary; + padding: 1; + } + .hidden { + display: none; + } + #spatial-tree { + height: 1fr; + } + #node-description { + height: 5; + } + #spatial-status { + height: 1; + padding: 0 1; + } + #spatial-actions { + height: auto; + width: auto; + align-horizontal: left; + } + """ + + def __init__(self, scenario: dict[str, Any]) -> None: + super().__init__() + self.scenario = scenario + self.selected_node = None + self.node_map: dict[int, Any] = {} + self.parent_option_map: dict[str, tuple[str, dict[str, Any]]] = {} + + def compose(self) -> ComposeResult: + yield Header(show_clock=True) + with Horizontal(id="spatial-main"): + with Vertical(id="spatial-left"): + yield Label("Spatial Graph", id="spatial-title") + yield Tree("🌍 World", id="spatial-tree") + with Horizontal(id="spatial-tree-actions"): + yield Button("Add", id="spatial_add", variant="primary") + yield Button("Delete", id="spatial_delete") + with Horizontal(id="spatial-nesting-actions"): + yield Button("Promote", id="spatial_promote") + yield Button("Demote", id="spatial_demote") + with Vertical(id="spatial-right"): + yield Label("Node Details", id="node-title") + with Vertical(id="node-details"): + yield Label("Type", id="node-type-label") + yield Label("", id="node-type-value") + yield Label("ID", id="node-id-label") + yield Input(id="node_id") + yield Label("Name", id="node-name-label") + yield Input(id="node_name") + yield Label("Description", id="node-desc-label") + yield TextArea(id="node-description") + yield Label("Parent", id="node-parent-label") + yield Select([], prompt="Select parent", id="node_parent") + with Horizontal(id="node-actions"): + yield Button("Generate", id="node_generate") + yield Button("Update", id="node_update", variant="primary") + yield Button("Reset", id="node_reset") + with Horizontal(id="spatial-actions"): + yield Button("Save", id="spatial_save", variant="primary") + yield Button("Cancel", id="spatial_cancel") + yield Label("", id="spatial-status") + yield Footer() + + def on_mount(self) -> None: + self._ensure_spatial_graph() + self._build_tree() + self._clear_form() + self._set_status("Select a node to edit") + + def _set_status(self, message: str) -> None: + self.query_one("#spatial-status", Label).update(message) + + def _toast(self, message: str, severity: str = "warning") -> None: + try: + self.app.notify(message, severity=severity) + except Exception: + self._set_status(message) + + def _ensure_spatial_graph(self) -> None: + spatial_graph = self.scenario.get("spatial_graph") + if spatial_graph is None: + spatial_graph = json.loads(json.dumps(DEFAULT_SCENARIO["spatial_graph"])) + self.scenario["spatial_graph"] = spatial_graph + world = spatial_graph.setdefault("world", {}) + world.setdefault("regions", []) + world.setdefault("name", "World") + world.setdefault("id", "world") + + def _world(self) -> dict[str, Any]: + return self.scenario["spatial_graph"]["world"] + + def _build_tree( + self, + select_obj: Optional[dict[str, Any]] = None, + expanded: Optional[set[int]] = None, + ) -> None: + tree = self.query_one("#spatial-tree", Tree) + world = self._world() + tree.reset(f"🌍 {world.get('name', 'World')}") + tree.root.data = SpatialNodeRef("world", world) + tree.root.expand() + self.selected_node = None + self.node_map = {} + expanded = expanded or set() + + def add_region(region: dict[str, Any]) -> None: + node = tree.root.add( + self._label_for("region", region), data=SpatialNodeRef("region", region) + ) + self.node_map[id(region)] = node + if id(region) in expanded: + node.expand() + for location in region.get("locations", []): + add_location(node, location) + + def add_location(parent_node: Any, location: dict[str, Any]) -> None: + loc_node = parent_node.add( + self._label_for("location", location), + data=SpatialNodeRef("location", location), + ) + self.node_map[id(location)] = loc_node + if id(location) in expanded: + loc_node.expand() + for nested in location.get("locations", []): + add_location(loc_node, nested) + for poi in location.get("pois", []): + poi_node = loc_node.add( + self._label_for("poi", poi), data=SpatialNodeRef("poi", poi) + ) + self.node_map[id(poi)] = poi_node + if id(poi) in expanded: + poi_node.expand() + + for region in world.get("regions", []): + add_region(region) + + if select_obj and id(select_obj) in self.node_map: + self.selected_node = self.node_map[id(select_obj)] + self._expand_ancestors(self.selected_node) + self._load_form_from_node(self.selected_node) + else: + self._clear_form() + + def _label_for(self, node_type: str, obj: dict[str, Any]) -> str: + name = str(obj.get("name") or "Unnamed").strip() + node_id = str(obj.get("id") or "no-id").strip() + icon = {"region": "πŸ—ΊοΈ", "location": "πŸ“", "poi": "⭐"}.get(node_type, "β€’") + return f"{icon} {name} ({node_id})" + + def _capture_expanded_nodes(self) -> set[int]: + tree = self.query_one("#spatial-tree", Tree) + expanded: set[int] = set() + + def walk(node: Any) -> None: + if node.is_expanded and isinstance(node.data, SpatialNodeRef): + expanded.add(id(node.data.obj)) + for child in node.children: + walk(child) + + walk(tree.root) + return expanded + + def _expand_ancestors(self, node: Any) -> None: + current = node + while current is not None: + current.expand() + current = current.parent + + def _clear_form(self) -> None: + self.query_one("#node-type-value", Label).update("") + self.query_one("#node_id", Input).value = "" + self.query_one("#node_name", Input).value = "" + self.query_one("#node-description", TextArea).text = "" + self.query_one("#node_parent", Select).set_options([]) + self.query_one("#node_parent", Select).value = Select.NULL + self.query_one("#node-details", Vertical).add_class("hidden") + self.query_one("#node_generate", Button).disabled = True + self.query_one("#node_update", Button).disabled = True + self.query_one("#node_reset", Button).disabled = True + self.query_one("#spatial_delete", Button).disabled = True + self.query_one("#spatial_promote", Button).disabled = True + self.query_one("#spatial_demote", Button).disabled = True + + def _load_form_from_node(self, node: Any) -> None: + ref = node.data + if not isinstance(ref, SpatialNodeRef) or ref.node_type == "world": + self._clear_form() + return + self.query_one("#node-details", Vertical).remove_class("hidden") + self.query_one("#node-type-value", Label).update(ref.node_type.title()) + self.query_one("#node_id", Input).value = str(ref.obj.get("id", "")) + self.query_one("#node_name", Input).value = str(ref.obj.get("name", "")) + self.query_one("#node-description", TextArea).text = str( + ref.obj.get("description", "") + ) + self.query_one("#node_generate", Button).disabled = False + self.query_one("#node_update", Button).disabled = False + self.query_one("#node_reset", Button).disabled = False + self.query_one("#spatial_delete", Button).disabled = False + is_location = ref.node_type == "location" + can_promote, can_demote = ( + self._location_move_state(node) if is_location else (False, False) + ) + self.query_one("#spatial_promote", Button).disabled = not can_promote + self.query_one("#spatial_demote", Button).disabled = not can_demote + parent_select = self.query_one("#node_parent", Select) + if is_location: + options, selected_value, option_map = self._build_parent_options(ref.obj) + self.parent_option_map = option_map + parent_select.set_options(options) + parent_select.value = ( + selected_value if selected_value is not None else Select.NULL + ) + parent_select.disabled = False + else: + self.parent_option_map = {} + parent_select.set_options([]) + parent_select.value = Select.NULL + parent_select.disabled = True + + def on_tree_node_selected(self, event: Tree.NodeSelected) -> None: + self.selected_node = event.node + self._load_form_from_node(event.node) + + def _open_node_type_prompt(self) -> None: + self.app.push_screen(NodeTypePrompt(), callback=self._handle_node_type) + + def _handle_node_type(self, result: Optional[str]) -> None: + if result: + self._add_node(result) + + def _selected_ref(self) -> Optional[SpatialNodeRef]: + if self.selected_node is None: + return None + ref = self.selected_node.data + if isinstance(ref, SpatialNodeRef): + return ref + return None + + def _selected_parent_node(self) -> Optional[Any]: + if self.selected_node is None: + return None + return self.selected_node.parent + + def _parent_locations_list( + self, parent_node: Any + ) -> Optional[list[dict[str, Any]]]: + if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef): + return None + parent_ref = parent_node.data + if parent_ref.node_type not in {"region", "location"}: + return None + return parent_ref.obj.setdefault("locations", []) + + def _location_move_state(self, node: Any) -> tuple[bool, bool]: + if node is None or not isinstance(node.data, SpatialNodeRef): + return (False, False) + ref = node.data + if ref.node_type != "location": + return (False, False) + parent_node = node.parent + parent_locations = self._parent_locations_list(parent_node) + if parent_locations is None: + return (False, False) + can_promote = False + if parent_node is not None and isinstance(parent_node.data, SpatialNodeRef): + parent_ref = parent_node.data + can_promote = parent_ref.node_type == "location" + can_demote = False + if ref.obj in parent_locations: + index = parent_locations.index(ref.obj) + can_demote = index < len(parent_locations) - 1 + return (can_promote, can_demote) + + def _parent_pois_list(self, parent_node: Any) -> Optional[list[dict[str, Any]]]: + if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef): + return None + parent_ref = parent_node.data + if parent_ref.node_type != "location": + return None + return parent_ref.obj.setdefault("pois", []) + + def _collect_locations(self) -> list[dict[str, Any]]: + locations: list[dict[str, Any]] = [] + + def walk(location_list: list[dict[str, Any]]) -> None: + for location in location_list: + locations.append(location) + walk(location.get("locations", [])) + + for region in self._world().get("regions", []): + walk(region.get("locations", [])) + return locations + + def _location_descendants(self, location: dict[str, Any]) -> set[int]: + descendants: set[int] = set() + + def walk(children: list[dict[str, Any]]) -> None: + for child in children: + descendants.add(id(child)) + walk(child.get("locations", [])) + + walk(location.get("locations", [])) + return descendants + + def _find_location_parent( + self, target: dict[str, Any] + ) -> Optional[tuple[str, dict[str, Any]]]: + world = self._world() + for region in world.get("regions", []): + if target in region.get("locations", []): + return ("region", region) + for location in region.get("locations", []): + found = self._find_location_parent_in_location(location, target) + if found: + return found + return None + + def _find_location_parent_in_location( + self, location: dict[str, Any], target: dict[str, Any] + ) -> Optional[tuple[str, dict[str, Any]]]: + if target in location.get("locations", []): + return ("location", location) + for nested in location.get("locations", []): + found = self._find_location_parent_in_location(nested, target) + if found: + return found + return None + + def _build_parent_options( + self, current: dict[str, Any] + ) -> tuple[list[tuple[str, str]], Optional[str], dict[str, tuple[str, dict[str, Any]]]]: + options: list[tuple[str, str]] = [] + option_map: dict[str, tuple[str, dict[str, Any]]] = {} + forbidden = {id(current)} | self._location_descendants(current) + selected_value: Optional[str] = None + + for region in self._world().get("regions", []): + region_id = str(region.get("id", "")).strip() + region_name = str(region.get("name", "")).strip() + value = f"region@{id(region)}" + label = ( + f"πŸ—ΊοΈ {region_name} ({region_id})" + if region_name + else f"πŸ—ΊοΈ {region_id or 'region'}" + ) + options.append((label, value)) + option_map[value] = ("region", region) + + for location in self._collect_locations(): + if id(location) in forbidden: + continue + location_id = str(location.get("id", "")).strip() + location_name = str(location.get("name", "")).strip() + value = f"location@{id(location)}" + label = ( + f"πŸ“ {location_name} ({location_id})" + if location_name + else f"πŸ“ {location_id or 'location'}" + ) + options.append((label, value)) + option_map[value] = ("location", location) + + parent_info = self._find_location_parent(current) + if parent_info: + parent_type, parent_obj = parent_info + for value, (entry_type, entry_obj) in option_map.items(): + if entry_type == parent_type and entry_obj is parent_obj: + selected_value = value + break + + return options, selected_value, option_map + + def _update_selected_node(self) -> None: + ref = self._selected_ref() + if ref is None or ref.node_type == "world": + self._set_status("Select a node to update") + return + node_id = self.query_one("#node_id", Input).value.strip() + name = self.query_one("#node_name", Input).value.strip() + description = self.query_one("#node-description", TextArea).text.strip() + if not node_id or not name: + self._set_status("ID and name are required") + return + ref.obj["id"] = node_id + ref.obj["name"] = name + ref.obj["description"] = description + if ref.node_type == "region": + ref.obj.setdefault("locations", []) + elif ref.node_type == "location": + ref.obj.setdefault("locations", []) + ref.obj.setdefault("pois", []) + elif ref.node_type == "poi": + ref.obj.setdefault("connections", []) + if ref.node_type == "location": + parent_select = self.query_one("#node_parent", Select) + if parent_select.value is not Select.NULL: + selected_value = str(parent_select.value) + parent_info = self.parent_option_map.get(selected_value) + if parent_info: + parent_type, parent_obj = parent_info + current_parent = self._find_location_parent(ref.obj) + if current_parent is None or current_parent[1] is not parent_obj: + if current_parent: + current_type, current_obj = current_parent + if current_type == "region": + current_obj.get("locations", []).remove(ref.obj) + elif current_type == "location": + current_obj.get("locations", []).remove(ref.obj) + parent_obj.setdefault("locations", []).append(ref.obj) + expanded = self._capture_expanded_nodes() + self._build_tree(select_obj=ref.obj, expanded=expanded) + self._set_status("Node updated") + return + if self.selected_node: + self.selected_node.label = self._label_for(ref.node_type, ref.obj) + self._set_status("Node updated") + + def _node_path(self, node: Any) -> str: + parts = [] + current = node + while current is not None and isinstance(current.data, SpatialNodeRef): + ref = current.data + if ref.node_type == "world": + break + label = str(ref.obj.get("name") or ref.obj.get("id") or ref.node_type) + parts.append(f"{ref.node_type}: {label}") + current = current.parent + return " > ".join(reversed(parts)) if parts else "World" + + def _collect_ids(self, node_type: str) -> set[str]: + world = self._world() + ids: set[str] = set() + + if node_type == "region": + for region in world.get("regions", []): + if region_id := str(region.get("id", "")).strip(): + ids.add(region_id) + return ids + + def walk_locations(locations: list[dict[str, Any]]) -> None: + for location in locations: + if node_type == "location": + if loc_id := str(location.get("id", "")).strip(): + ids.add(loc_id) + if node_type == "poi": + for poi in location.get("pois", []): + if poi_id := str(poi.get("id", "")).strip(): + ids.add(poi_id) + walk_locations(location.get("locations", [])) + + for region in world.get("regions", []): + walk_locations(region.get("locations", [])) + + return ids + + def _generate_node_fields(self) -> None: + ref = self._selected_ref() + if ref is None or ref.node_type == "world": + self._toast("Select a node to generate") + return + + scenario_meta = self.scenario.get("scenario", {}) + node_type = ref.node_type + system_prompt = """ +You are helping build a game scenario. +Return ONLY a single-line JSON object with keys: +id, name, description. +- id: short snake_case identifier +- description: 1-3 sentences +No commentary. ASCII only. +""".strip() + + user_prompt = f""" +Scenario title: {scenario_meta.get("title", "")} +Scenario description: {scenario_meta.get("description", "")} +Node type: {node_type} +Node path: {self._node_path(self.selected_node)} +Current id: {ref.obj.get("id", "")} +Current name: {ref.obj.get("name", "")} +""".strip() + + from omnia.llm_runtime import invoke_llm + + response_text = invoke_llm( + [SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)] + ) + + try: + payload = json.loads(response_text) + except json.JSONDecodeError: + self._toast("LLM response was not valid JSON", severity="error") + return + + if not isinstance(payload, dict): + self._toast("LLM response must be a JSON object", severity="error") + return + + generated_id = str(payload.get("id", "")).strip() + generated_name = str(payload.get("name", "")).strip() + generated_description = str(payload.get("description", "")).strip() + + if not generated_id or not generated_name: + self._toast("Generated data missing id or name", severity="error") + return + + existing_ids = self._collect_ids(node_type) + if generated_id in existing_ids and generated_id != str(ref.obj.get("id", "")).strip(): + self._toast( + f"Generated id '{generated_id}' already exists", severity="warning" + ) + return + + ref.obj["id"] = generated_id + ref.obj["name"] = generated_name + ref.obj["description"] = generated_description + if self.selected_node: + self.selected_node.label = self._label_for(node_type, ref.obj) + self._load_form_from_node(self.selected_node) + self._set_status("Generated fields applied") + + def _reset_selected_node(self) -> None: + if self.selected_node is None: + return + self._load_form_from_node(self.selected_node) + self._set_status("Reverted changes") + + def _add_node(self, node_type: str) -> None: + world = self._world() + expanded = self._capture_expanded_nodes() + if node_type == "region": + region = {"id": "", "name": "", "description": "", "locations": []} + world.setdefault("regions", []).append(region) + self._build_tree(select_obj=region, expanded=expanded) + self._set_status("Region added") + return + + parent_node = self.selected_node + if node_type == "location": + if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef): + self._toast("Select a region or location for the new location") + return + parent_ref = parent_node.data + if parent_ref.node_type not in {"region", "location"}: + self._toast("Select a region or location for the new location") + return + location = { + "id": "", + "name": "", + "description": "", + "locations": [], + "pois": [], + } + parent_ref.obj.setdefault("locations", []).append(location) + self._build_tree(select_obj=location, expanded=expanded) + self._set_status("Location added") + return + + if node_type == "poi": + if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef): + self._toast("Select a location for the new POI") + return + parent_ref = parent_node.data + if parent_ref.node_type == "poi": + parent_node = parent_node.parent + if parent_node is None or not isinstance( + parent_node.data, SpatialNodeRef + ): + self._toast("Select a location for the new POI") + return + parent_ref = parent_node.data + if parent_ref.node_type != "location": + self._toast("Select a location for the new POI") + return + poi = {"id": "", "name": "", "description": "", "connections": []} + parent_ref.obj.setdefault("pois", []).append(poi) + self._build_tree(select_obj=poi, expanded=expanded) + self._set_status("POI added") + + def _delete_node(self) -> None: + ref = self._selected_ref() + if ref is None or ref.node_type == "world": + self._set_status("Select a node to delete") + return + parent_node = self._selected_parent_node() + expanded = self._capture_expanded_nodes() + if ref.node_type == "region": + world = self._world() + world["regions"] = [r for r in world.get("regions", []) if r is not ref.obj] + self._build_tree(expanded=expanded) + self._set_status("Region deleted") + return + if ref.node_type == "location": + parent_locations = self._parent_locations_list(parent_node) + if parent_locations is None: + self._set_status("Unable to delete location") + return + parent_locations[:] = [ + loc for loc in parent_locations if loc is not ref.obj + ] + self._build_tree(expanded=expanded) + self._set_status("Location deleted") + return + if ref.node_type == "poi": + parent_pois = self._parent_pois_list(parent_node) + if parent_pois is None: + self._set_status("Unable to delete POI") + return + parent_pois[:] = [poi for poi in parent_pois if poi is not ref.obj] + self._build_tree(expanded=expanded) + self._set_status("POI deleted") + + def _promote_location(self) -> None: + ref = self._selected_ref() + if ref is None or ref.node_type != "location": + self._set_status("Select a location to promote") + return + parent_node = self._selected_parent_node() + if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef): + self._set_status("Unable to promote location") + return + parent_ref = parent_node.data + if parent_ref.node_type != "location": + self._set_status("Location is already top-level") + return + grandparent_node = parent_node.parent + parent_locations = parent_ref.obj.setdefault("locations", []) + if ref.obj in parent_locations: + parent_locations.remove(ref.obj) + target_list = self._parent_locations_list(grandparent_node) + if target_list is None: + self._set_status("Unable to promote location") + return + target_list.append(ref.obj) + expanded = self._capture_expanded_nodes() + self._build_tree(select_obj=ref.obj, expanded=expanded) + self._set_status("Location promoted") + + def _demote_location(self) -> None: + ref = self._selected_ref() + if ref is None or ref.node_type != "location": + self._set_status("Select a location to demote") + return + parent_node = self._selected_parent_node() + parent_locations = self._parent_locations_list(parent_node) + if parent_locations is None: + self._set_status("Unable to demote location") + return + try: + index = parent_locations.index(ref.obj) + except ValueError: + self._set_status("Unable to demote location") + return + if index >= len(parent_locations) - 1: + self._set_status("No next sibling to demote under") + return + new_parent = parent_locations[index + 1] + new_parent.setdefault("locations", []).append(ref.obj) + parent_locations.remove(ref.obj) + expanded = self._capture_expanded_nodes() + self._build_tree(select_obj=ref.obj, expanded=expanded) + self._set_status("Location demoted") + + def on_button_pressed(self, event: Button.Pressed) -> None: + button_id = event.button.id + if button_id == "spatial_add": + self._open_node_type_prompt() + return + if button_id == "spatial_delete": + self._delete_node() + return + if button_id == "spatial_promote": + self._promote_location() + return + if button_id == "spatial_demote": + self._demote_location() + return + if button_id == "node_generate": + self._generate_node_fields() + return + if button_id == "node_update": + self._update_selected_node() + return + if button_id == "node_reset": + self._reset_selected_node() + return + if button_id == "spatial_save": + self.dismiss(self.scenario) + return + if button_id == "spatial_cancel": + self.dismiss(None)