diff --git a/.gitignore b/.gitignore
index e2880ce..5fc5ffc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,7 @@ build/
 dist/
 wheels/
 *.egg-info
+model/*
 
 # Virtual environments
 .venv
diff --git a/1.txt b/1.txt
new file mode 100644
index 0000000..c0c338e
--- /dev/null
+++ b/1.txt
@@ -0,0 +1 @@
+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
\ No newline at end of file
diff --git a/mini_lm.py b/mini_lm.py
index 90fddb7..6cd014b 100644
--- a/mini_lm.py
+++ b/mini_lm.py
@@ -1,5 +1,4 @@
 import json
-import re
 from functools import partial
 from pathlib import Path
 
@@ -32,16 +31,7 @@ def _normalize_label(label: str) -> str:
     raise ValueError(f"Unknown label: {label}")
 
 
-def _text_to_tokens(text: str) -> tuple[list[str], list[tuple[int, int]]]:
-    tokens = []
-    spans = []
-    for match in re.finditer(r"\S+", text):
-        tokens.append(match.group(0))
-        spans.append(match.span())
-    return tokens, spans
-
-
-def _extract_spans(task: dict) -> list[tuple[int, int, str]]:
+def _extract_spans(task: dict) -> list[dict]:
     annotations = task.get("annotations") or []
     if not annotations:
         return []
@@ -54,7 +44,13 @@ def _extract_spans(task: dict) -> list[tuple[int, int, str]]:
         labels = value.get("labels") or []
         if start is None or end is None or not labels:
             continue
-        spans.append((int(start), int(end), _normalize_label(labels[0])))
+        spans.append(
+            {
+                "start": int(start),
+                "end": int(end),
+                "label": _normalize_label(labels[0]),
+            }
+        )
     return spans
 
 
@@ -67,36 +63,42 @@ def parse_annotated_dataset(path: str) -> list[dict]:
     for task in data:
         text = task["data"]["text"]
         spans = _extract_spans(task)
-        tokens, token_spans = _text_to_tokens(text)
-        labels = []
-        for token_start, token_end in token_spans:
-            assigned = "RAW_PHRASE"
-            for span_start, span_end, span_label in spans:
-                if token_start < span_end and token_end > span_start:
-                    assigned = span_label
-                    break
-            labels.append(assigned)
-        examples.append({"tokens": tokens, "labels": labels})
+        examples.append({"text": text, "spans": spans})
     return examples
 
 
-def tokenize_and_align(example, tokenizer, label2id):
-    tokenized = tokenizer(example["tokens"], is_split_into_words=True, truncation=True)
-    word_ids = tokenized.word_ids()
-    labels = []
-    prev = None
+def _label_for_offset(start: int, end: int, spans: list[dict]) -> str:
+    best_label = "RAW_PHRASE"
+    best_overlap = 0
+    for span in spans:
+        span_start = span["start"]
+        span_end = span["end"]
+        span_label = span["label"]
+        overlap = min(end, span_end) - max(start, span_start)
+        if overlap > best_overlap:
+            best_overlap = overlap
+            best_label = span_label
+    return best_label
 
-    for word_id in word_ids:
-        if word_id is None:
+
+def tokenize_and_align(example, tokenizer, label2id):
+    tokenized = tokenizer(
+        example["text"],
+        truncation=True,
+        return_offsets_mapping=True,
+    )
+    labels = []
+
+    for start, end in tokenized["offset_mapping"]:
+        if start == end:
             labels.append(-100)
-        elif word_id != prev:
-            labels.append(label2id[example["labels"][word_id]])
         else:
-            labels.append(-100)
-        prev = word_id
+            label_name = _label_for_offset(start, end, example["spans"])
+            labels.append(label2id[label_name])
 
     tokenized["labels"] = labels
+    tokenized.pop("offset_mapping")
     return tokenized
 
 
@@ -144,7 +146,7 @@ def _load_model(model_dir: str = MODEL_DIR):
 
 def predict(text: str, model_dir: str = MODEL_DIR, top_k: int = 3):
     tokenizer, model = _load_model(model_dir)
-    inputs = tokenizer(text.split(), return_tensors="pt", is_split_into_words=True)
+    inputs = tokenizer(text, return_tensors="pt")
     outputs = model(**inputs)
     logits = outputs.logits[0]
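For reference, a minimal sketch of how the reworked offset-based alignment behaves. It assumes a fast tokenizer (return_offsets_mapping is only available on fast tokenizers); the checkpoint name, the INGREDIENT label, and the label2id mapping are illustrative, not part of the patch:

from transformers import AutoTokenizer

from mini_lm import tokenize_and_align

# Illustrative setup: any fast tokenizer works here; labels other than
# RAW_PHRASE are hypothetical stand-ins for the project's real label set.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
label2id = {"RAW_PHRASE": 0, "INGREDIENT": 1}

example = {
    "text": "2 cups flour, sifted",
    "spans": [{"start": 7, "end": 12, "label": "INGREDIENT"}],  # "flour"
}

encoded = tokenize_and_align(example, tokenizer, label2id)
# Special tokens carry empty (0, 0) offsets and are masked with -100;
# every real token takes the label of the annotated span it overlaps most,
# so the subword pieces of "flour" map to INGREDIENT even when subword
# boundaries do not line up exactly with the annotated span.
print(encoded["labels"])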