added base datasets
This commit is contained in:
359
main.py
Normal file
359
main.py
Normal file
@@ -0,0 +1,359 @@
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from google import genai
|
||||
from loguru import logger
|
||||
|
||||
from .enums import TokenLabel
|
||||
|
||||
# How many texts are sent to Gemini in a single request.
DEFAULT_BATCH_SIZE = 20
DEFAULT_MODEL = "gemini-2.5-flash-lite"
# Sink for verbatim model responses (see the raw_gemini logger filter in main).
DEFAULT_RAW_LOG_PATH = "logs/gemini_raw.log"
DEFAULT_CONVERT_INPUT_DIR = "datasets/preannotated"
DEFAULT_CONVERT_OUTPUT_DIR = "datasets/annotated"
# Canonical label names, plus a reverse map from enum values to names, so a
# label coming back from the model in either spelling can be normalized.
LABEL_NAMES = [label.name for label in TokenLabel]
VALUE_TO_NAME = {label.value: label.name for label in TokenLabel}
# Alphabet for Label Studio style random identifiers (e.g. "A1B-C2D3E4").
ID_ALPHABET = string.ascii_uppercase + string.digits
|
||||
|
||||
|
||||
def _chunked(items: list[int], size: int) -> Iterable[list[int]]:
|
||||
if size < 1:
|
||||
raise ValueError("batch_size must be at least 1.")
|
||||
for start in range(0, len(items), size):
|
||||
yield items[start : start + size]
|
||||
|
||||
|
||||
def _generate_id() -> str:
    """Return a random "XXX-XXXXXX" identifier over ID_ALPHABET.

    Uses the secrets module so ids are safe to treat as unguessable.
    """
    prefix = "".join(secrets.choice(ID_ALPHABET) for _ in range(3))
    suffix = "".join(secrets.choice(ID_ALPHABET) for _ in range(6))
    return f"{prefix}-{suffix}"
|
||||
|
||||
|
||||
def _all_occurrences(text: str, span: str) -> list[int]:
|
||||
occurrences = []
|
||||
start = 0
|
||||
while True:
|
||||
idx = text.find(span, start)
|
||||
if idx == -1:
|
||||
break
|
||||
occurrences.append(idx)
|
||||
start = idx + 1
|
||||
return occurrences
|
||||
|
||||
|
||||
def _build_prompt(texts: list[str]) -> str:
    """Assemble the pre-annotation instruction prompt for *texts*.

    The allowed label set comes from LABEL_NAMES; the inputs are embedded
    as an ASCII-safe JSON array at the end of the prompt.
    """
    allowed = ", ".join(LABEL_NAMES)
    instructions = (
        "You are a token pre-annotator. For each input text, return JSON with tagged "
        "token/subword/word/span labels.\n"
        f"Allowed labels: {allowed}.\n"
        "Rules:\n"
        "- Output ONLY valid JSON (no markdown).\n"
        "- Return a JSON array with the same length/order as the input.\n"
        "- Each item must be an object: "
        '{"text": "<original>", "tags": [{"span": "<exact substring>", "label": "<LABEL>"}]}.\n'
        "- The span must be an exact substring of the original text.\n"
        "- Use RAW_PHRASE when no other label applies.\n\n"
    )
    payload = f"Input texts: {json.dumps(texts, ensure_ascii=True)}"
    return instructions + payload
|
||||
|
||||
|
||||
def _normalize_label(label: str) -> str:
    """Map *label* (an enum name or enum value) to its canonical name.

    Raises:
        ValueError: if *label* is neither a known name nor a known value.
    """
    if label in LABEL_NAMES:
        return label
    # VALUE_TO_NAME values are enum-name strings, so None is a safe sentinel.
    canonical = VALUE_TO_NAME.get(label)
    if canonical is not None:
        return canonical
    raise ValueError(f"Unknown label: {label}")
|
||||
|
||||
|
||||
def _normalize_result(text: str, result: dict) -> dict:
    """Validate one Gemini result object and return a clean {"text", "tags"} dict.

    Raises:
        ValueError: when the result does not echo *text*, "tags" is not a
            list, a tag is malformed, or a span is not a substring of *text*.
    """
    if result.get("text") != text:
        raise ValueError("Gemini result text does not match input text.")
    raw_tags = result.get("tags")
    if not isinstance(raw_tags, list):
        raise ValueError("Gemini result tags must be a list.")
    clean: list[dict] = []
    for entry in raw_tags:
        if not isinstance(entry, dict):
            raise ValueError("Each tag must be an object.")
        span, label = entry.get("span"), entry.get("label")
        if not isinstance(span, str) or not isinstance(label, str):
            raise ValueError("Each tag must include string span and label fields.")
        if span not in text:
            raise ValueError(f"Span not found in text: {span}")
        clean.append({"span": span, "label": _normalize_label(label)})
    return {"text": text, "tags": clean}
|
||||
|
||||
|
||||
def preannotate_tokens(
    texts: list[str], client: genai.Client, model: str
) -> list[dict]:
    """Ask Gemini to tag every text in *texts*; return normalized results.

    Args:
        texts: input strings; one request covers all of them.
        client: configured google-genai client.
        model: Gemini model identifier.

    Returns:
        One {"text", "tags"} dict per input, in input order.

    Raises:
        ValueError: empty response, unparseable or mis-shaped JSON, or any
            per-item validation failure from _normalize_result
            (json.JSONDecodeError is a ValueError subclass).
    """
    prompt = _build_prompt(texts)
    response = client.models.generate_content(model=model, contents=prompt)
    if response.text is None:
        raise ValueError("Gemini returned an empty response.")
    # Keep the verbatim model output in the dedicated raw-response log sink.
    logger.bind(raw_gemini=True).trace(response.text)
    # Models sometimes wrap their JSON in markdown code fences despite the
    # "no markdown" instruction; strip an optional fence instead of failing.
    payload = response.text.strip()
    if payload.startswith("```"):
        payload = payload.removeprefix("```json").removeprefix("```")
        payload = payload.removesuffix("```").strip()
    parsed = json.loads(payload)
    if not isinstance(parsed, list) or len(parsed) != len(texts):
        raise ValueError("Gemini response must be a JSON array matching input length.")
    return [_normalize_result(text, result) for text, result in zip(texts, parsed)]
|
||||
|
||||
|
||||
def _load_raw_records(path: Path) -> list[dict]:
|
||||
raw = json.loads(path.read_text(encoding="utf-8"))
|
||||
if not isinstance(raw, list):
|
||||
raise ValueError(f"Dataset {path} must be a JSON array.")
|
||||
for idx, item in enumerate(raw):
|
||||
if not isinstance(item, dict) or "text" not in item:
|
||||
raise ValueError(
|
||||
f"Dataset {path} item {idx} must be an object with a text field."
|
||||
)
|
||||
return raw
|
||||
|
||||
|
||||
def _load_preannotated_records(path: Path) -> list[dict]:
|
||||
raw = json.loads(path.read_text(encoding="utf-8"))
|
||||
if not isinstance(raw, list):
|
||||
raise ValueError(f"Dataset {path} must be a JSON array.")
|
||||
for idx, item in enumerate(raw):
|
||||
if not isinstance(item, dict) or "text" not in item:
|
||||
raise ValueError(
|
||||
f"Dataset {path} item {idx} must be an object with a text field."
|
||||
)
|
||||
if "tags" not in item or not isinstance(item["tags"], list):
|
||||
raise ValueError(
|
||||
f"Dataset {path} item {idx} must include a tags list for conversion."
|
||||
)
|
||||
return raw
|
||||
|
||||
|
||||
def _build_labelstudio_results(text: str, tags: list[dict]) -> list[dict]:
    """Convert span/label tags into Label Studio "labels" result objects.

    Repeated identical spans consume successive occurrences in *text*, so
    two tags with the same span annotate two different positions.

    Raises:
        ValueError: for malformed tags or when a span has no (remaining)
            occurrence in *text*.
    """
    # Remaining character offsets for each distinct span, consumed in order.
    pending: dict[str, list[int]] = {}
    results: list[dict] = []
    for tag in tags:
        if not isinstance(tag, dict):
            raise ValueError("Each tag must be an object.")
        span = tag.get("span")
        label = tag.get("label")
        if not isinstance(span, str) or not isinstance(label, str):
            raise ValueError("Each tag must include string span and label fields.")
        if span not in pending:
            pending[span] = _all_occurrences(text, span)
        if not pending[span]:
            raise ValueError(f"Span not found in text: {span}")
        start = pending[span].pop(0)
        results.append(
            {
                "value": {
                    "start": start,
                    "end": start + len(span),
                    "text": span,
                    "labels": [_normalize_label(label)],
                },
                "id": _generate_id(),
                "from_name": "label",
                "to_name": "text",
                "type": "labels",
                "origin": "manual",
            }
        )
    return results
|
||||
|
||||
|
||||
def _preannotate_dataset(
    input_path: Path,
    output_path: Path,
    client: genai.Client,
    model: str,
    batch_size: int,
) -> None:
    """Pre-annotate one raw dataset file with Gemini and write the result.

    Reads *input_path*, sends its texts to Gemini in batches of
    *batch_size*, attaches the returned tags to a copy of each record, and
    writes the augmented records to *output_path* as pretty-printed JSON.

    Raises:
        ValueError: propagated from loading/annotation, or if any record
            ends up without a result.
    """
    records = _load_raw_records(input_path)
    texts = [record["text"] for record in records]
    # Filled in by index so output order always matches input order.
    tagged: list[dict | None] = [None] * len(records)

    logger.info("Pre-annotating {} records from {}", len(texts), input_path)
    for indices in _chunked(list(range(len(texts))), batch_size):
        batch = [texts[i] for i in indices]
        logger.debug(
            "Sending batch {}-{} (size {}) to Gemini",
            indices[0],
            indices[-1],
            len(batch),
        )
        for i, result in zip(indices, preannotate_tokens(batch, client, model)):
            augmented = dict(records[i])
            augmented["tags"] = result["tags"]
            tagged[i] = augmented

    if any(entry is None for entry in tagged):
        raise ValueError("Pre-annotation failed to produce results for all items.")

    output_path.write_text(
        json.dumps(tagged, indent=2, ensure_ascii=True), encoding="utf-8"
    )
    logger.info("Wrote pre-annotated dataset to {}", output_path)
|
||||
|
||||
|
||||
def _convert_preannotated_dataset(input_path: Path, output_path: Path) -> None:
    """Convert one pre-annotated dataset into Label Studio import JSON.

    Each record becomes one Label Studio task carrying a single annotation
    whose results are built from the record's span/label tags. The task
    list is written to *output_path* as pretty-printed JSON.

    Raises:
        ValueError: propagated from record loading or result building.
    """
    preannotated_items = _load_preannotated_records(input_path)
    # One shared UTC timestamp for every created/updated field in this export.
    now_iso = dt.datetime.now(dt.UTC).isoformat()
    tasks = []
    logger.info("Converting {} records from {}", len(preannotated_items), input_path)
    for index, item in enumerate(preannotated_items):
        text = item["text"]
        tags = item["tags"]
        task_id = _generate_id()
        # Annotation ids are 1-based sequence numbers within this file.
        annotation_id = index + 1
        results = _build_labelstudio_results(text, tags)
        annotation = {
            "id": annotation_id,
            # NOTE(review): hard-coded annotator id — confirm it matches the
            # target Label Studio project's user.
            "completed_by": 2,
            "result": results,
            "was_cancelled": False,
            "ground_truth": False,
            "created_at": now_iso,
            "updated_at": now_iso,
            "draft_created_at": now_iso,
            "lead_time": 0.0,
            "prediction": {},
            "result_count": len(results),
            "unique_id": _generate_id(),
            "import_id": None,
            "last_action": None,
            "bulk_created": False,
            "task": task_id,
            "project": None,
            "updated_by": None,
            "parent_prediction": None,
            "parent_annotation": None,
            "last_created_by": None,
        }
        tasks.append(
            {
                "id": task_id,
                "annotations": [annotation],
                "file_upload": None,
                "drafts": [],
                "predictions": [],
                "data": {"text": text},
                "meta": {},
                "created_at": now_iso,
                "updated_at": now_iso,
                "allow_skip": True,
                "inner_id": index + 1,
                "total_annotations": 1,
                "cancelled_annotations": 0,
                "total_predictions": 0,
                "comment_count": 0,
                "unresolved_comment_count": 0,
                "last_comment_updated_at": None,
                "project": None,
                "updated_by": None,
                "comment_authors": [],
            }
        )

    output_path.write_text(
        json.dumps(tasks, indent=2, ensure_ascii=True), encoding="utf-8"
    )
    logger.info("Wrote converted annotated dataset to {}", output_path)
|
||||
|
||||
|
||||
def _resolve_input_paths(input_dir: Path, input_file: str | None) -> list[Path]:
    """Return the dataset files to process.

    An explicit *input_file* wins; otherwise every ``*.json`` directly in
    *input_dir*, sorted by name.

    Raises:
        ValueError: when no input datasets are found.
    """
    if input_file:
        paths = [Path(input_file)]
    else:
        paths = sorted(input_dir.glob("*.json"))
    if not paths:
        raise ValueError(f"No input datasets found in {input_dir}.")
    return paths


def main():
    """CLI entry point.

    ``--mode preannotate`` (default) tags raw datasets with Gemini;
    ``--mode convert`` turns pre-annotated datasets into Label Studio
    import JSON. The two modes previously duplicated the input-path
    resolution; that logic now lives in _resolve_input_paths.

    Raises:
        ValueError: missing GEMINI_API_KEY (preannotate mode) or no input
            datasets found.
    """
    parser = argparse.ArgumentParser(description="Pre-annotate datasets with Gemini.")
    parser.add_argument(
        "--mode",
        choices=["preannotate", "convert"],
        default="preannotate",
    )
    parser.add_argument("--input-dir", default=None)
    parser.add_argument("--output-dir", default=None)
    parser.add_argument("--input-file", default=None)
    parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE)
    parser.add_argument("--model", default=os.getenv("GEMINI_MODEL", DEFAULT_MODEL))
    parser.add_argument(
        "--raw-log-file",
        default=os.getenv("GEMINI_RAW_LOG_FILE", DEFAULT_RAW_LOG_PATH),
    )
    args = parser.parse_args()

    if args.mode == "preannotate":
        # Route verbatim Gemini responses to a dedicated TRACE-level sink;
        # only records tagged raw_gemini=True (see preannotate_tokens) match.
        raw_log_path = Path(args.raw_log_file)
        raw_log_path.parent.mkdir(parents=True, exist_ok=True)
        logger.add(
            raw_log_path,
            level="TRACE",
            filter=lambda record: record["extra"].get("raw_gemini") is True,
            format="{message}",
        )

        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY must be set in the environment.")

        client = genai.Client(api_key=api_key)
        input_dir = Path(args.input_dir or "datasets/raw")
        output_dir = Path(args.output_dir or "datasets/preannotated")
        output_dir.mkdir(parents=True, exist_ok=True)
        input_paths = _resolve_input_paths(input_dir, args.input_file)

        logger.info(
            "Starting pre-annotation: model={}, batch_size={}, input_dir={}, output_dir={}",
            args.model,
            args.batch_size,
            input_dir,
            output_dir,
        )
        for input_path in input_paths:
            output_path = output_dir / input_path.name
            _preannotate_dataset(
                input_path, output_path, client, args.model, args.batch_size
            )
        return

    # convert mode
    input_dir = Path(args.input_dir or DEFAULT_CONVERT_INPUT_DIR)
    output_dir = Path(args.output_dir or DEFAULT_CONVERT_OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    input_paths = _resolve_input_paths(input_dir, args.input_file)

    logger.info(
        "Starting conversion: input_dir={}, output_dir={}",
        input_dir,
        output_dir,
    )
    for input_path in input_paths:
        output_path = output_dir / input_path.name
        _convert_preannotated_dataset(input_path, output_path)
|
||||
|
||||
|
||||
# Allow running the module directly as a script.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user