Files
clint-dataset/main.py
2026-04-07 22:00:40 +05:30

360 lines
12 KiB
Python

import argparse
import datetime as dt
import json
import os
import secrets
import string
from pathlib import Path
from typing import Iterable
from google import genai
from loguru import logger
from .enums import TokenLabel
# Number of texts sent to Gemini per request (overridable with --batch-size).
DEFAULT_BATCH_SIZE = 20
# Gemini model id, used unless GEMINI_MODEL or --model is given.
DEFAULT_MODEL = "gemini-2.5-flash-lite"
# Sink for verbatim Gemini responses (GEMINI_RAW_LOG_FILE / --raw-log-file).
DEFAULT_RAW_LOG_PATH = "logs/gemini_raw.log"
# Directories used by ``--mode convert`` when --input-dir / --output-dir are omitted.
DEFAULT_CONVERT_INPUT_DIR = "datasets/preannotated"
DEFAULT_CONVERT_OUTPUT_DIR = "datasets/annotated"

# Every TokenLabel member name, in enum iteration order; these are the labels
# advertised to Gemini in the prompt and accepted back from it.
LABEL_NAMES = [member.name for member in TokenLabel]
# Reverse lookup so a label returned as an enum *value* maps to its name.
# NOTE(review): labels arrive from Gemini as strings, so this lookup only
# matches if TokenLabel values are strings — confirm against enums.py.
VALUE_TO_NAME = {member.value: member.name for member in TokenLabel}

# Character set for the random ids that _generate_id produces.
ID_ALPHABET = "".join((string.ascii_uppercase, string.digits))
def _chunked(items: list[int], size: int) -> Iterable[list[int]]:
    """Split *items* into consecutive slices of at most *size* elements.

    The size check runs eagerly, at call time, rather than being deferred
    until the first element is pulled from a generator.

    Args:
        items: Sequence to split; it is sliced, not copied element by element.
        size: Maximum slice length; must be at least 1.

    Returns:
        A lazy iterable of list slices, in order; the last may be shorter.

    Raises:
        ValueError: if *size* is less than 1.
    """
    if size < 1:
        raise ValueError("batch_size must be at least 1.")
    return (items[start : start + size] for start in range(0, len(items), size))
def _generate_id() -> str:
    """Return a random id shaped ``AAA-AAAAAA`` (3 chars, dash, 6 chars).

    Each character is drawn from ``ID_ALPHABET`` with ``secrets.choice``.
    """

    def pick(count: int) -> str:
        return "".join(secrets.choice(ID_ALPHABET) for _ in range(count))

    head = pick(3)
    tail = pick(6)
    return f"{head}-{tail}"
def _all_occurrences(text: str, span: str) -> list[int]:
    """Return every index at which *span* starts in *text*, ascending.

    The search resumes one character after each hit, so overlapping matches
    are all reported (``"aaa"`` / ``"aa"`` -> ``[0, 1]``). An empty *span*
    matches at every position from 0 through ``len(text)``.
    """
    positions: list[int] = []
    cursor = text.find(span)
    while cursor != -1:
        positions.append(cursor)
        cursor = text.find(span, cursor + 1)
    return positions
def _build_prompt(texts: list[str]) -> str:
    """Build the Gemini instruction prompt for a batch of *texts*.

    The prompt lists the allowed label names, the required JSON output shape
    (one object per input, same order), and ends with the inputs serialized
    as an ASCII-escaped JSON array.
    """
    allowed = ", ".join(LABEL_NAMES)
    example_item = (
        '{"text": "<original>", "tags": '
        '[{"span": "<exact substring>", "label": "<LABEL>"}]}'
    )
    lines = [
        "You are a token pre-annotator. For each input text, return JSON with "
        "tagged token/subword/word/span labels.",
        f"Allowed labels: {allowed}.",
        "Rules:",
        "- Output ONLY valid JSON (no markdown).",
        "- Return a JSON array with the same length/order as the input.",
        f"- Each item must be an object: {example_item}.",
        "- The span must be an exact substring of the original text.",
        "- Use RAW_PHRASE when no other label applies.",
        "",
        f"Input texts: {json.dumps(texts, ensure_ascii=True)}",
    ]
    return "\n".join(lines)
def _normalize_label(label: str) -> str:
    """Return the TokenLabel *name* for *label*.

    *label* may already be a member name (returned unchanged) or a member
    value (mapped through ``VALUE_TO_NAME``).

    Raises:
        ValueError: if *label* is neither a known name nor a known value.
    """
    if label in LABEL_NAMES:
        return label
    canonical = VALUE_TO_NAME.get(label)
    if canonical is None:
        raise ValueError(f"Unknown label: {label}")
    return canonical
def _normalize_result(text: str, result: dict) -> dict:
    """Validate one item of Gemini's JSON answer and canonicalize its labels.

    Args:
        text: The input text this item is expected to describe.
        result: One element of the parsed response array; expected to be an
            object ``{"text": ..., "tags": [{"span": ..., "label": ...}]}``.

    Returns:
        ``{"text": text, "tags": [...]}`` where each tag keeps its span and
        has its label normalized via ``_normalize_label``.

    Raises:
        ValueError: if *result* is not an object, echoes a different text,
            lacks a tags list, or contains a malformed tag, an empty span, a
            span that is not a substring of *text*, or an unknown label.
    """
    # The model output is untrusted JSON: any element type can show up here,
    # and a non-dict would otherwise fail with AttributeError on .get().
    if not isinstance(result, dict):
        raise ValueError("Each Gemini result must be an object.")
    if result.get("text") != text:
        raise ValueError("Gemini result text does not match input text.")
    tags = result.get("tags")
    if not isinstance(tags, list):
        raise ValueError("Gemini result tags must be a list.")
    normalized_tags = []
    for tag in tags:
        if not isinstance(tag, dict):
            raise ValueError("Each tag must be an object.")
        span = tag.get("span")
        label = tag.get("label")
        if not isinstance(span, str) or not isinstance(label, str):
            raise ValueError("Each tag must include string span and label fields.")
        # "" is a substring of every text, so the check below would accept
        # it; it would later become a zero-width region.
        if not span:
            raise ValueError("Tag spans must not be empty.")
        if span not in text:
            raise ValueError(f"Span not found in text: {span}")
        normalized_tags.append({"span": span, "label": _normalize_label(label)})
    return {"text": text, "tags": normalized_tags}
def _strip_code_fence(payload: str) -> str:
    """Return *payload* without a surrounding Markdown code fence, if any.

    Handles both ```` ```json\\n...\\n``` ```` and a bare ```` ``` ```` fence.
    Valid JSON never starts with a backtick, so unfenced JSON only has its
    surrounding whitespace removed.
    """
    stripped = payload.strip()
    if not (stripped.startswith("```") and stripped.endswith("```")):
        return stripped
    inner = stripped[3:-3]
    # The opening fence may carry a language tag (e.g. "json") that ends at
    # the first newline; everything after it is the fenced body.
    _, newline, body = inner.partition("\n")
    return body.strip() if newline else inner.strip()


def preannotate_tokens(
    texts: list[str], client: genai.Client, model: str
) -> list[dict]:
    """Ask Gemini to tag *texts* and return one validated result per text.

    The raw response text is logged at TRACE level with ``raw_gemini=True``
    before it is parsed.

    Args:
        texts: Input texts, sent together in a single request.
        client: Gemini API client.
        model: Model id passed to ``generate_content``.

    Returns:
        One ``{"text": ..., "tags": [...]}`` dict per input, in input order,
        as produced by ``_normalize_result``.

    Raises:
        ValueError: on an empty response, a body that is not valid JSON, a
            JSON value that is not an array of ``len(texts)`` items, or any
            item rejected by ``_normalize_result``.
    """
    prompt = _build_prompt(texts)
    response = client.models.generate_content(
        model=model,
        contents=prompt,
        # Request a raw JSON body instead of relying only on the prompt's
        # "no markdown" rule, which models do not always follow.
        config={"response_mime_type": "application/json"},
    )
    if response.text is None:
        raise ValueError("Gemini returned an empty response.")
    logger.bind(raw_gemini=True).trace(response.text)
    try:
        parsed = json.loads(_strip_code_fence(response.text))
    except json.JSONDecodeError as err:
        raise ValueError(f"Gemini response is not valid JSON: {err}") from err
    if not isinstance(parsed, list) or len(parsed) != len(texts):
        raise ValueError("Gemini response must be a JSON array matching input length.")
    return [_normalize_result(text, result) for text, result in zip(texts, parsed)]
def _load_raw_records(path: Path) -> list[dict]:
    """Load a dataset file: a JSON array of objects with a string ``text``.

    Items are returned unchanged, including any extra keys.

    Args:
        path: UTF-8 encoded JSON file.

    Returns:
        The parsed list of record dicts.

    Raises:
        ValueError: if the file is not valid JSON (``json.JSONDecodeError`` is
            a ``ValueError``), is not an array, or contains an item that is
            not an object with a string ``text`` field.
    """
    raw = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(raw, list):
        raise ValueError(f"Dataset {path} must be a JSON array.")
    for idx, item in enumerate(raw):
        # Downstream code does substring search on text, so a non-string
        # value must be rejected here rather than fail obscurely later.
        if not isinstance(item, dict) or not isinstance(item.get("text"), str):
            raise ValueError(
                f"Dataset {path} item {idx} must be an object with a string text field."
            )
    return raw
def _load_preannotated_records(path: Path) -> list[dict]:
    """Load a pre-annotated dataset for conversion.

    Applies the same file and ``text`` validation as ``_load_raw_records``
    (reused, not duplicated), then requires every item to carry a ``tags``
    list.

    Raises:
        ValueError: from ``_load_raw_records``, or if an item has no ``tags``
            key or its ``tags`` value is not a list.
    """
    records = _load_raw_records(path)
    for idx, item in enumerate(records):
        # .get() yields None for a missing key, which also fails the check.
        if not isinstance(item.get("tags"), list):
            raise ValueError(
                f"Dataset {path} item {idx} must include a tags list for conversion."
            )
    return records
def _build_labelstudio_results(text: str, tags: list[dict]) -> list[dict]:
    """Turn ``{"span", "label"}`` tags into Label Studio ``labels`` regions.

    Spans are located by string search in *text*. When the same span string
    is tagged several times, the k-th such tag is anchored at the k-th
    position returned by ``_all_occurrences`` for that string. Tags with
    different span strings are placed independently of each other, so a tag
    whose span lies inside another tag's span is not coordinated with it.

    Args:
        text: The task text the offsets refer to.
        tags: Tag objects, each with string ``span`` and ``label`` fields.

    Returns:
        One region dict per tag, in tag order, with ``start``/``end`` offsets,
        the normalized label, and a fresh random id.

    Raises:
        ValueError: if a tag is malformed, has an empty span, names a span
            absent from *text*, or tags a span more times than it occurs.
    """
    # NOTE(review): offsets are Python str (code point) indices; if the Label
    # Studio frontend counts UTF-16 units, text containing astral characters
    # (e.g. emoji) would misalign — verify.
    occurrences_map: dict[str, list[int]] = {}
    occurrence_index: dict[str, int] = {}
    results = []
    for tag in tags:
        if not isinstance(tag, dict):
            raise ValueError("Each tag must be an object.")
        span = tag.get("span")
        label = tag.get("label")
        if not isinstance(span, str) or not isinstance(label, str):
            raise ValueError("Each tag must include string span and label fields.")
        # An empty span would become a zero-width region (start == end).
        if not span:
            raise ValueError("Tag spans must not be empty.")
        if span not in occurrences_map:
            occurrences_map[span] = _all_occurrences(text, span)
        occurrences = occurrences_map[span]
        if not occurrences:
            raise ValueError(f"Span not found in text: {span}")
        idx = occurrence_index.get(span, 0)
        if idx >= len(occurrences):
            raise ValueError(
                f"Span {span!r} is tagged more times than it occurs in text."
            )
        occurrence_index[span] = idx + 1
        start = occurrences[idx]
        results.append(
            {
                "value": {
                    "start": start,
                    "end": start + len(span),
                    "text": span,
                    "labels": [_normalize_label(label)],
                },
                "id": _generate_id(),
                # Presumably these must match the tag names in the Label
                # Studio labeling config — confirm against the project setup.
                "from_name": "label",
                "to_name": "text",
                "type": "labels",
                "origin": "manual",
            }
        )
    return results
def _preannotate_dataset(
    input_path: Path,
    output_path: Path,
    client: genai.Client,
    model: str,
    batch_size: int,
) -> None:
    """Tag every record of one raw dataset with Gemini and save the result.

    Records are loaded with ``_load_raw_records`` and sent to Gemini in
    batches of at most *batch_size* texts. Each output record is a shallow
    copy of its input record with ``"tags"`` set from the model's answer, and
    output order matches input order. *output_path* is written only after
    every batch has succeeded, so an exception leaves it untouched.
    """
    records = _load_raw_records(input_path)
    all_texts = [record["text"] for record in records]
    # One slot per input record, filled in as batches come back.
    slots: list[dict | None] = [None for _ in records]
    logger.info("Pre-annotating {} records from {}", len(all_texts), input_path)
    positions = list(range(len(all_texts)))
    for chunk in _chunked(positions, batch_size):
        chunk_texts = [all_texts[pos] for pos in chunk]
        logger.debug(
            "Sending batch {}-{} (size {}) to Gemini",
            chunk[0],
            chunk[-1],
            len(chunk_texts),
        )
        tagged = preannotate_tokens(chunk_texts, client, model)
        for pos, outcome in zip(chunk, tagged):
            slots[pos] = {**records[pos], "tags": outcome["tags"]}
    # Defensive: preannotate_tokens already enforces one result per text.
    if None in slots:
        raise ValueError("Pre-annotation failed to produce results for all items.")
    payload = json.dumps(slots, indent=2, ensure_ascii=True)
    output_path.write_text(payload, encoding="utf-8")
    logger.info("Wrote pre-annotated dataset to {}", output_path)
def _labelstudio_annotation(
    annotation_id: int, task_id: str, results: list[dict], timestamp: str
) -> dict:
    """Build one Label Studio annotation record wrapping *results*."""
    return {
        "id": annotation_id,
        # NOTE(review): fixed annotator id; presumably an existing Label
        # Studio user id — confirm.
        "completed_by": 2,
        "result": results,
        "was_cancelled": False,
        "ground_truth": False,
        "created_at": timestamp,
        "updated_at": timestamp,
        "draft_created_at": timestamp,
        "lead_time": 0.0,
        "prediction": {},
        "result_count": len(results),
        "unique_id": _generate_id(),
        "import_id": None,
        "last_action": None,
        "bulk_created": False,
        "task": task_id,
        "project": None,
        "updated_by": None,
        "parent_prediction": None,
        "parent_annotation": None,
        "last_created_by": None,
    }


def _labelstudio_task(
    task_id: str, position: int, text: str, annotation: dict, timestamp: str
) -> dict:
    """Build one Label Studio task record holding *text* and one annotation."""
    return {
        "id": task_id,
        "annotations": [annotation],
        "file_upload": None,
        "drafts": [],
        "predictions": [],
        "data": {"text": text},
        "meta": {},
        "created_at": timestamp,
        "updated_at": timestamp,
        "allow_skip": True,
        "inner_id": position,
        "total_annotations": 1,
        "cancelled_annotations": 0,
        "total_predictions": 0,
        "comment_count": 0,
        "unresolved_comment_count": 0,
        "last_comment_updated_at": None,
        "project": None,
        "updated_by": None,
        "comment_authors": [],
    }


def _convert_preannotated_dataset(input_path: Path, output_path: Path) -> None:
    """Convert a pre-annotated dataset file into a JSON list of Label Studio tasks.

    Every record becomes one task carrying exactly one annotation whose
    regions come from ``_build_labelstudio_results``. Task ids are random
    (``_generate_id``). Annotation ids and ``inner_id`` are the 1-based record
    position. All timestamps share a single UTC conversion time. The record
    layout presumably mirrors Label Studio's JSON export — confirm against
    the target Label Studio version.
    """
    records = _load_preannotated_records(input_path)
    timestamp = dt.datetime.now(dt.UTC).isoformat()
    converted: list[dict] = []
    logger.info("Converting {} records from {}", len(records), input_path)
    for position, record in enumerate(records, start=1):
        text = record["text"]
        task_id = _generate_id()
        results = _build_labelstudio_results(text, record["tags"])
        annotation = _labelstudio_annotation(position, task_id, results, timestamp)
        converted.append(
            _labelstudio_task(task_id, position, text, annotation, timestamp)
        )
    serialized = json.dumps(converted, indent=2, ensure_ascii=True)
    output_path.write_text(serialized, encoding="utf-8")
    logger.info("Wrote converted annotated dataset to {}", output_path)
def main():
parser = argparse.ArgumentParser(description="Pre-annotate datasets with Gemini.")
parser.add_argument(
"--mode",
choices=["preannotate", "convert"],
default="preannotate",
)
parser.add_argument("--input-dir", default=None)
parser.add_argument("--output-dir", default=None)
parser.add_argument("--input-file", default=None)
parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE)
parser.add_argument("--model", default=os.getenv("GEMINI_MODEL", DEFAULT_MODEL))
parser.add_argument(
"--raw-log-file",
default=os.getenv("GEMINI_RAW_LOG_FILE", DEFAULT_RAW_LOG_PATH),
)
args = parser.parse_args()
if args.mode == "preannotate":
raw_log_path = Path(args.raw_log_file)
raw_log_path.parent.mkdir(parents=True, exist_ok=True)
logger.add(
raw_log_path,
level="TRACE",
filter=lambda record: record["extra"].get("raw_gemini") is True,
format="{message}",
)
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
raise ValueError("GEMINI_API_KEY must be set in the environment.")
client = genai.Client(api_key=api_key)
input_dir = Path(args.input_dir or "datasets/raw")
output_dir = Path(args.output_dir or "datasets/preannotated")
output_dir.mkdir(parents=True, exist_ok=True)
if args.input_file:
input_paths = [Path(args.input_file)]
else:
input_paths = sorted(input_dir.glob("*.json"))
if not input_paths:
raise ValueError(f"No input datasets found in {input_dir}.")
logger.info(
"Starting pre-annotation: model={}, batch_size={}, input_dir={}, output_dir={}",
args.model,
args.batch_size,
input_dir,
output_dir,
)
for input_path in input_paths:
output_path = output_dir / input_path.name
_preannotate_dataset(
input_path, output_path, client, args.model, args.batch_size
)
return
input_dir = Path(args.input_dir or DEFAULT_CONVERT_INPUT_DIR)
output_dir = Path(args.output_dir or DEFAULT_CONVERT_OUTPUT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
if args.input_file:
input_paths = [Path(args.input_file)]
else:
input_paths = sorted(input_dir.glob("*.json"))
if not input_paths:
raise ValueError(f"No input datasets found in {input_dir}.")
logger.info(
"Starting conversion: input_dir={}, output_dir={}",
input_dir,
output_dir,
)
for input_path in input_paths:
output_path = output_dir / input_path.name
_convert_preannotated_dataset(input_path, output_path)
if __name__ == "__main__":
main()