# Gemini-powered dataset pre-annotation and Label Studio conversion CLI.
import argparse
|
|
import datetime as dt
|
|
import json
|
|
import os
|
|
import secrets
|
|
import string
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
|
|
from google import genai
|
|
from loguru import logger
|
|
|
|
from .enums import TokenLabel
|
|
|
|
# Number of texts sent to Gemini in a single request.
DEFAULT_BATCH_SIZE = 20
# Fallback model name when GEMINI_MODEL is not set in the environment.
DEFAULT_MODEL = "gemini-2.5-flash-lite"
# Fallback path for the raw Gemini response trace log.
DEFAULT_RAW_LOG_PATH = "logs/gemini_raw.log"
# Default directories for convert mode (preannotated -> annotated).
DEFAULT_CONVERT_INPUT_DIR = "datasets/preannotated"
DEFAULT_CONVERT_OUTPUT_DIR = "datasets/annotated"
# Canonical label names and a value -> name map, derived from the
# project's TokenLabel enum.
LABEL_NAMES = [label.name for label in TokenLabel]
VALUE_TO_NAME = {label.value: label.name for label in TokenLabel}
# Characters used for generated Label Studio identifiers.
ID_ALPHABET = string.ascii_uppercase + string.digits
def _chunked(items: list[int], size: int) -> Iterable[list[int]]:
|
|
if size < 1:
|
|
raise ValueError("batch_size must be at least 1.")
|
|
for start in range(0, len(items), size):
|
|
yield items[start : start + size]
|
|
|
|
|
|
def _generate_id() -> str:
    """Return a random identifier shaped like ``ABC-DEF123``.

    Both segments are drawn from ID_ALPHABET using the ``secrets``
    CSPRNG, so ids are safe to use as opaque external identifiers.
    """

    def segment(length: int) -> str:
        return "".join(secrets.choice(ID_ALPHABET) for _ in range(length))

    return f"{segment(3)}-{segment(6)}"
def _all_occurrences(text: str, span: str) -> list[int]:
|
|
occurrences = []
|
|
start = 0
|
|
while True:
|
|
idx = text.find(span, start)
|
|
if idx == -1:
|
|
break
|
|
occurrences.append(idx)
|
|
start = idx + 1
|
|
return occurrences
|
|
|
|
|
|
def _build_prompt(texts: list[str]) -> str:
    """Compose the Gemini instruction prompt for one batch of texts.

    The prompt pins the allowed labels, the required JSON shape, and
    embeds *texts* as a JSON array so the model can echo them back.
    """
    allowed = ", ".join(LABEL_NAMES)
    parts = [
        "You are a token pre-annotator. For each input text, return JSON with tagged ",
        "token/subword/word/span labels.\n",
        f"Allowed labels: {allowed}.\n",
        "Rules:\n",
        "- Output ONLY valid JSON (no markdown).\n",
        "- Return a JSON array with the same length/order as the input.\n",
        "- Each item must be an object: ",
        '{"text": "<original>", "tags": [{"span": "<exact substring>", "label": "<LABEL>"}]}.\n',
        "- The span must be an exact substring of the original text.\n",
        "- Use RAW_PHRASE when no other label applies.\n\n",
        f"Input texts: {json.dumps(texts, ensure_ascii=True)}",
    ]
    return "".join(parts)
def _normalize_label(label: str) -> str:
    """Return the canonical TokenLabel name for *label*.

    Accepts either an enum name (returned unchanged) or an enum value
    (mapped through VALUE_TO_NAME).

    Raises:
        ValueError: If *label* is neither a known name nor a known value.
    """
    if label in LABEL_NAMES:
        return label
    mapped = VALUE_TO_NAME.get(label)
    if mapped is not None:
        return mapped
    raise ValueError(f"Unknown label: {label}")
def _normalize_result(text: str, result: dict) -> dict:
    """Validate one Gemini item for *text* and canonicalize its labels.

    Raises:
        ValueError: If the echoed text differs from *text*, ``tags`` is
            not a list, a tag is malformed, or a span does not occur in
            *text*.
    """
    if result.get("text") != text:
        raise ValueError("Gemini result text does not match input text.")
    raw_tags = result.get("tags")
    if not isinstance(raw_tags, list):
        raise ValueError("Gemini result tags must be a list.")

    def clean(entry) -> dict:
        # Each tag must be {"span": str, "label": str} with span inside text.
        if not isinstance(entry, dict):
            raise ValueError("Each tag must be an object.")
        span = entry.get("span")
        label = entry.get("label")
        if not isinstance(span, str) or not isinstance(label, str):
            raise ValueError("Each tag must include string span and label fields.")
        if span not in text:
            raise ValueError(f"Span not found in text: {span}")
        return {"span": span, "label": _normalize_label(label)}

    return {"text": text, "tags": [clean(entry) for entry in raw_tags]}
def preannotate_tokens(
    texts: list[str], client: genai.Client, model: str
) -> list[dict]:
    """Ask Gemini to tag every text in *texts* and return normalized items.

    Raises:
        ValueError: If the response is empty, is not a JSON array whose
            length matches *texts*, or contains a malformed item.
    """
    response = client.models.generate_content(
        model=model, contents=_build_prompt(texts)
    )
    raw_text = response.text
    if raw_text is None:
        raise ValueError("Gemini returned an empty response.")
    # Mirror the untouched model output into the dedicated raw-log sink.
    logger.bind(raw_gemini=True).trace(raw_text)
    decoded = json.loads(raw_text)
    if not isinstance(decoded, list) or len(decoded) != len(texts):
        raise ValueError("Gemini response must be a JSON array matching input length.")
    return [
        _normalize_result(original, item) for original, item in zip(texts, decoded)
    ]
def _load_raw_records(path: Path) -> list[dict]:
|
|
raw = json.loads(path.read_text(encoding="utf-8"))
|
|
if not isinstance(raw, list):
|
|
raise ValueError(f"Dataset {path} must be a JSON array.")
|
|
for idx, item in enumerate(raw):
|
|
if not isinstance(item, dict) or "text" not in item:
|
|
raise ValueError(
|
|
f"Dataset {path} item {idx} must be an object with a text field."
|
|
)
|
|
return raw
|
|
|
|
|
|
def _load_preannotated_records(path: Path) -> list[dict]:
    """Read *path* and validate pre-annotated records for conversion.

    Delegates the array/"text" validation to ``_load_raw_records`` instead
    of duplicating it, then additionally requires every item to carry a
    ``tags`` list produced by the pre-annotation pass.

    Raises:
        ValueError: On any shape violation.
    """
    records = _load_raw_records(path)
    for idx, item in enumerate(records):
        if "tags" not in item or not isinstance(item["tags"], list):
            raise ValueError(
                f"Dataset {path} item {idx} must include a tags list for conversion."
            )
    return records
def _build_labelstudio_results(text: str, tags: list[dict]) -> list[dict]:
    """Convert span/label tags into Label Studio "labels" result objects.

    Repeated identical spans are mapped to successive occurrences in
    *text*, left to right, so tagging the same substring twice highlights
    two different positions.

    Raises:
        ValueError: On a malformed tag, or when a span has no remaining
            occurrence in *text*.
    """
    # Lazily hand out occurrence indexes per span, in document order.
    pending: dict = {}
    converted = []
    for tag in tags:
        if not isinstance(tag, dict):
            raise ValueError("Each tag must be an object.")
        span = tag.get("span")
        label = tag.get("label")
        if not isinstance(span, str) or not isinstance(label, str):
            raise ValueError("Each tag must include string span and label fields.")
        if span not in pending:
            pending[span] = iter(_all_occurrences(text, span))
        try:
            begin = next(pending[span])
        except StopIteration:
            raise ValueError(f"Span not found in text: {span}") from None
        converted.append(
            {
                "value": {
                    "start": begin,
                    "end": begin + len(span),
                    "text": span,
                    "labels": [_normalize_label(label)],
                },
                "id": _generate_id(),
                "from_name": "label",
                "to_name": "text",
                "type": "labels",
                "origin": "manual",
            }
        )
    return converted
def _preannotate_dataset(
    input_path: Path,
    output_path: Path,
    client: genai.Client,
    model: str,
    batch_size: int,
) -> None:
    """Pre-annotate one raw dataset file and write the tagged copy.

    Reads records from *input_path*, sends their texts to Gemini in
    batches of *batch_size*, attaches the returned ``tags`` to each
    record, and writes the result to *output_path* as indented JSON.

    Raises:
        ValueError: If any record fails to receive a result.
    """
    source_items = _load_raw_records(input_path)
    all_texts = [record["text"] for record in source_items]
    completed: dict[int, dict] = {}

    logger.info("Pre-annotating {} records from {}", len(all_texts), input_path)
    for batch_indices in _chunked(list(range(len(all_texts))), batch_size):
        batch_texts = [all_texts[i] for i in batch_indices]
        logger.debug(
            "Sending batch {}-{} (size {}) to Gemini",
            batch_indices[0],
            batch_indices[-1],
            len(batch_texts),
        )
        batch_results = preannotate_tokens(batch_texts, client, model)
        for i, tagged in zip(batch_indices, batch_results):
            completed[i] = {**source_items[i], "tags": tagged["tags"]}

    if len(completed) != len(source_items):
        raise ValueError("Pre-annotation failed to produce results for all items.")

    ordered = [completed[i] for i in range(len(source_items))]
    output_path.write_text(
        json.dumps(ordered, indent=2, ensure_ascii=True), encoding="utf-8"
    )
    logger.info("Wrote pre-annotated dataset to {}", output_path)
def _convert_preannotated_dataset(input_path: Path, output_path: Path) -> None:
    """Convert one pre-annotated dataset into Label Studio task JSON.

    Each record becomes a task with exactly one annotation whose results
    are built from the record's span/label tags. The output is written
    to *output_path* as indented JSON.
    """
    preannotated_items = _load_preannotated_records(input_path)
    # One timestamp shared by every task/annotation in this file.
    # dt.UTC requires Python 3.11+.
    now_iso = dt.datetime.now(dt.UTC).isoformat()
    tasks = []
    logger.info("Converting {} records from {}", len(preannotated_items), input_path)
    for index, item in enumerate(preannotated_items):
        text = item["text"]
        tags = item["tags"]
        # Random external id for the task; annotation ids are sequential.
        task_id = _generate_id()
        annotation_id = index + 1
        results = _build_labelstudio_results(text, tags)
        annotation = {
            "id": annotation_id,
            # NOTE(review): hard-coded annotator account id — presumably a
            # fixed Label Studio user; confirm against the target project.
            "completed_by": 2,
            "result": results,
            "was_cancelled": False,
            "ground_truth": False,
            "created_at": now_iso,
            "updated_at": now_iso,
            "draft_created_at": now_iso,
            "lead_time": 0.0,
            "prediction": {},
            "result_count": len(results),
            "unique_id": _generate_id(),
            "import_id": None,
            "last_action": None,
            "bulk_created": False,
            "task": task_id,
            "project": None,
            "updated_by": None,
            "parent_prediction": None,
            "parent_annotation": None,
            "last_created_by": None,
        }
        tasks.append(
            {
                "id": task_id,
                "annotations": [annotation],
                "file_upload": None,
                "drafts": [],
                "predictions": [],
                # The actual text being labeled lives under data.text.
                "data": {"text": text},
                "meta": {},
                "created_at": now_iso,
                "updated_at": now_iso,
                "allow_skip": True,
                "inner_id": index + 1,
                "total_annotations": 1,
                "cancelled_annotations": 0,
                "total_predictions": 0,
                "comment_count": 0,
                "unresolved_comment_count": 0,
                "last_comment_updated_at": None,
                "project": None,
                "updated_by": None,
                "comment_authors": [],
            }
        )

    output_path.write_text(
        json.dumps(tasks, indent=2, ensure_ascii=True), encoding="utf-8"
    )
    logger.info("Wrote converted annotated dataset to {}", output_path)
def _resolve_input_paths(input_file: str | None, input_dir: Path) -> list[Path]:
    """Return the dataset files to process.

    An explicit ``--input-file`` wins; otherwise every ``*.json`` in
    *input_dir*, sorted for a deterministic processing order.

    Raises:
        ValueError: If no dataset files are found.
    """
    if input_file:
        return [Path(input_file)]
    paths = sorted(input_dir.glob("*.json"))
    if not paths:
        raise ValueError(f"No input datasets found in {input_dir}.")
    return paths


def _run_preannotate(args: argparse.Namespace) -> None:
    """Handle --mode=preannotate: tag every raw dataset via Gemini."""
    # Route raw Gemini responses (bound with extra["raw_gemini"]) into a
    # dedicated TRACE-level sink for post-run inspection.
    raw_log_path = Path(args.raw_log_file)
    raw_log_path.parent.mkdir(parents=True, exist_ok=True)
    logger.add(
        raw_log_path,
        level="TRACE",
        filter=lambda record: record["extra"].get("raw_gemini") is True,
        format="{message}",
    )

    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY must be set in the environment.")

    client = genai.Client(api_key=api_key)
    input_dir = Path(args.input_dir or "datasets/raw")
    output_dir = Path(args.output_dir or "datasets/preannotated")
    output_dir.mkdir(parents=True, exist_ok=True)
    input_paths = _resolve_input_paths(args.input_file, input_dir)

    logger.info(
        "Starting pre-annotation: model={}, batch_size={}, input_dir={}, output_dir={}",
        args.model,
        args.batch_size,
        input_dir,
        output_dir,
    )
    for input_path in input_paths:
        _preannotate_dataset(
            input_path,
            output_dir / input_path.name,
            client,
            args.model,
            args.batch_size,
        )


def _run_convert(args: argparse.Namespace) -> None:
    """Handle --mode=convert: turn pre-annotated files into Label Studio tasks."""
    input_dir = Path(args.input_dir or DEFAULT_CONVERT_INPUT_DIR)
    output_dir = Path(args.output_dir or DEFAULT_CONVERT_OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    input_paths = _resolve_input_paths(args.input_file, input_dir)

    logger.info(
        "Starting conversion: input_dir={}, output_dir={}",
        input_dir,
        output_dir,
    )
    for input_path in input_paths:
        _convert_preannotated_dataset(input_path, output_dir / input_path.name)


def main():
    """CLI entry point: pre-annotate raw datasets or convert them.

    The duplicated input-path resolution of the two modes is factored
    into ``_resolve_input_paths``; each mode runs in its own helper.
    """
    parser = argparse.ArgumentParser(description="Pre-annotate datasets with Gemini.")
    parser.add_argument(
        "--mode",
        choices=["preannotate", "convert"],
        default="preannotate",
        help="pre-annotate raw datasets with Gemini, or convert pre-annotated ones",
    )
    parser.add_argument(
        "--input-dir", default=None, help="override the mode's default input directory"
    )
    parser.add_argument(
        "--output-dir", default=None, help="override the mode's default output directory"
    )
    parser.add_argument(
        "--input-file", default=None, help="process a single dataset file instead"
    )
    parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE)
    parser.add_argument("--model", default=os.getenv("GEMINI_MODEL", DEFAULT_MODEL))
    parser.add_argument(
        "--raw-log-file",
        default=os.getenv("GEMINI_RAW_LOG_FILE", DEFAULT_RAW_LOG_PATH),
    )
    args = parser.parse_args()

    # choices= guarantees the mode is one of exactly these two.
    if args.mode == "preannotate":
        _run_preannotate(args)
    else:
        _run_convert(args)
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|