import json
import re
from functools import partial
from pathlib import Path

import torch
from datasets import Dataset
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
)

from enums import TokenLabel

MODEL_NAME = "microsoft/MiniLM-L12-H384-uncased"
MODEL_DIR = "./model"

LABEL_LIST = [label.name for label in TokenLabel]
LABEL2ID = {label: i for i, label in enumerate(LABEL_LIST)}
ID2LABEL = {i: label for i, label in enumerate(LABEL_LIST)}
VALUE_TO_NAME = {label.value: label.name for label in TokenLabel}


def _normalize_label(label: str) -> str:
    """Map an annotation label (enum name or enum value) onto a TokenLabel name."""
    if label in LABEL2ID:
        return label
    if label in VALUE_TO_NAME:
        return VALUE_TO_NAME[label]
    raise ValueError(f"Unknown label: {label}")


def _text_to_tokens(text: str) -> tuple[list[str], list[tuple[int, int]]]:
    """Split text on whitespace, returning the tokens and their character spans."""
    tokens = []
    spans = []
    for match in re.finditer(r"\S+", text):
        tokens.append(match.group(0))
        spans.append(match.span())
    return tokens, spans


def _extract_spans(task: dict) -> list[tuple[int, int, str]]:
    """Pull (start, end, label) character spans from the first annotation of a task."""
    annotations = task.get("annotations") or []
    if not annotations:
        return []
    results = annotations[0].get("result", []) or []
    spans = []
    for result in results:
        value = result.get("value", {})
        start = value.get("start")
        end = value.get("end")
        labels = value.get("labels") or []
        if start is None or end is None or not labels:
            continue
        spans.append((int(start), int(end), _normalize_label(labels[0])))
    return spans


def parse_annotated_dataset(path: str) -> list[dict]:
    """Convert an annotated JSON export into token/label training examples.

    Each whitespace token that overlaps an annotated span inherits that span's
    label; all other tokens fall back to RAW_PHRASE.
    """
    data = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("Annotated dataset must be a JSON array.")
    examples = []
    for task in data:
        text = task["data"]["text"]
        spans = _extract_spans(task)
        tokens, token_spans = _text_to_tokens(text)
        labels = []
        for token_start, token_end in token_spans:
            assigned = "RAW_PHRASE"
            for span_start, span_end, span_label in spans:
                if token_start < span_end and token_end > span_start:
                    assigned = span_label
                    break
            labels.append(assigned)
        examples.append({"tokens": tokens, "labels": labels})
    return examples


def tokenize_and_align(example, tokenizer, label2id):
    """Tokenize pre-split words and align labels, masking special tokens and
    subword continuations with -100 so they are ignored by the loss."""
    tokenized = tokenizer(example["tokens"], is_split_into_words=True, truncation=True)
    word_ids = tokenized.word_ids()
    labels = []
    prev = None
    for word_id in word_ids:
        if word_id is None:
            labels.append(-100)
        elif word_id != prev:
            labels.append(label2id[example["labels"][word_id]])
        else:
            labels.append(-100)
        prev = word_id
    tokenized["labels"] = labels
    return tokenized


def train(dataset_path: str = "./datasets/annotated/ffmpeg_gemini_v1.json", model_dir: str = MODEL_DIR):
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME, num_labels=len(LABEL_LIST), id2label=ID2LABEL, label2id=LABEL2ID
    )

    examples = parse_annotated_dataset(dataset_path)
    dataset = Dataset.from_list(examples)
    dataset = dataset.map(partial(tokenize_and_align, tokenizer=tokenizer, label2id=LABEL2ID))
    dataset = dataset.train_test_split(test_size=0.1)

    training_args = TrainingArguments(
        output_dir=model_dir,
        learning_rate=3e-5,
        per_device_train_batch_size=8,
        num_train_epochs=5,
        weight_decay=0.01,
        logging_steps=10,
        save_strategy="epoch",
    )
    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=data_collator,
    )
    trainer.train()
    trainer.save_model(model_dir)
    tokenizer.save_pretrained(model_dir)
def _load_model(model_dir: str = MODEL_DIR):
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForTokenClassification.from_pretrained(model_dir)
    return tokenizer, model


def predict(text: str, model_dir: str = MODEL_DIR, top_k: int = 3):
    """Print the top-k label probabilities for each tokenized piece of `text`."""
    tokenizer, model = _load_model(model_dir)
    model.eval()
    inputs = tokenizer(text.split(), return_tensors="pt", is_split_into_words=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits[0]
    probs = torch.softmax(logits, dim=-1)
    k = min(max(1, top_k), probs.shape[-1])
    top_probs, top_ids = torch.topk(probs, k, dim=-1)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    for token, prob_row, id_row in zip(tokens, top_probs, top_ids):
        preds = [
            f"{ID2LABEL.get(int(label_id), 'RAW_PHRASE')}:{float(score):.4f}"
            for label_id, score in zip(id_row, prob_row)
        ]
        print(token, ", ".join(preds))


if __name__ == "__main__":
    train()