minilm testing
This commit is contained in:
10
README.md
10
README.md
@@ -4,6 +4,16 @@ Dataset for labelling queries containing tasks in natural language, with a focus
|
||||
|
||||
These queries were generated by prompting various commercially available LLMs and were pre-annotated using Gemini-2.5-flash-lite. They were then converted to a Label Studio-supported format, and the annotations were manually revised.
|
||||
|
||||
## Training
|
||||
|
||||
```bash
|
||||
uv run mini_lm.py
|
||||
```
|
||||
|
||||
```bash
|
||||
uv run inference.py --top-k 3 "your query here"
|
||||
```
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
|
||||
@@ -55,7 +55,16 @@ if __name__ == "__main__":
|
||||
# number based menu to match file
|
||||
for file in annotated_dataset_list:
|
||||
print(f"{annotated_dataset_list.index(file)}: {file}")
|
||||
_ = input("Enter the number of the annotated dataset to analyze: ")
|
||||
path = annotated_dataset_list[int(_)]
|
||||
counts = parse_annotated(path)
|
||||
print(json.dumps(counts, indent=2))
|
||||
print("a: all annotated datasets")
|
||||
selection = input("Enter the number of the annotated dataset to analyze: ")
|
||||
if selection.lower() == "a":
|
||||
combined_counts = {label.name: 0 for label in TokenLabel}
|
||||
for dataset_path in annotated_dataset_list:
|
||||
counts = parse_annotated(dataset_path)
|
||||
for label, value in counts.items():
|
||||
combined_counts[label] += value
|
||||
print(json.dumps(combined_counts, indent=2))
|
||||
else:
|
||||
path = annotated_dataset_list[int(selection)]
|
||||
counts = parse_annotated(path)
|
||||
print(json.dumps(counts, indent=2))
|
||||
|
||||
90
dataset_convertor.py
Normal file
90
dataset_convertor.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import json
|
||||
|
||||
|
||||
def parse_json_dataset(path):
|
||||
with open(path, "r") as f:
|
||||
tasks = json.load(f)
|
||||
|
||||
for task in tasks:
|
||||
label_line = ""
|
||||
text_line = ""
|
||||
|
||||
full_text = task["data"]["text"]
|
||||
def append_segment(segment_text, segment_label):
|
||||
label = f"{segment_label} "
|
||||
text = f"{segment_text} "
|
||||
if len(text) > len(label):
|
||||
label_line_part = f"{label + ' ' * (len(text) - len(label))}"
|
||||
text_line_part = text
|
||||
elif len(text) < len(label):
|
||||
text_line_part = f"{text + ' ' * (len(label) - len(text))}"
|
||||
label_line_part = label
|
||||
else:
|
||||
text_line_part = text
|
||||
label_line_part = label
|
||||
return text_line_part, label_line_part
|
||||
|
||||
def append_gap(gap_text):
|
||||
segments = []
|
||||
if not gap_text:
|
||||
return segments
|
||||
start = 0
|
||||
end = len(gap_text)
|
||||
while start < end and gap_text[start].isspace():
|
||||
start += 1
|
||||
while end > start and gap_text[end - 1].isspace():
|
||||
end -= 1
|
||||
leading = gap_text[:start]
|
||||
middle = gap_text[start:end]
|
||||
trailing = gap_text[end:]
|
||||
if leading:
|
||||
text_line_append = leading
|
||||
label_line_append = " " * len(leading)
|
||||
segments.append((text_line_append, label_line_append))
|
||||
if middle:
|
||||
text_part, label_part = append_segment(middle, "0")
|
||||
segments.append((text_part, label_part))
|
||||
if trailing:
|
||||
text_line_append = trailing
|
||||
label_line_append = " " * len(trailing)
|
||||
segments.append((text_line_append, label_line_append))
|
||||
return segments
|
||||
|
||||
results = []
|
||||
annotations = task.get("annotations") or []
|
||||
if annotations:
|
||||
results = annotations[0].get("result", [])
|
||||
results = sorted(results, key=lambda item: item["value"]["start"])
|
||||
|
||||
cursor = 0
|
||||
for annotation in results:
|
||||
start = annotation["value"]["start"]
|
||||
end = annotation["value"]["end"]
|
||||
label = annotation["value"]["labels"][0]
|
||||
|
||||
if cursor < start:
|
||||
gap_text = full_text[cursor:start]
|
||||
for text_part, label_part in append_gap(gap_text):
|
||||
text_line += text_part
|
||||
label_line += label_part
|
||||
|
||||
text_segment = full_text[start:end]
|
||||
text_part, label_part = append_segment(text_segment, label)
|
||||
text_line += text_part
|
||||
label_line += label_part
|
||||
cursor = end
|
||||
|
||||
if cursor < len(full_text):
|
||||
gap_text = full_text[cursor:]
|
||||
for text_part, label_part in append_gap(gap_text):
|
||||
text_line += text_part
|
||||
label_line += label_part
|
||||
print(text_line)
|
||||
print(label_line)
|
||||
print("\n")
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parse_json_dataset("./datasets/annotated/ffmpeg_gemini_v1.json")
|
||||
File diff suppressed because one or more lines are too long
13
inference.py
Normal file
13
inference.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import argparse
|
||||
|
||||
from mini_lm import predict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run token classification inference.")
|
||||
parser.add_argument("text", nargs="*", help="Input text to label.")
|
||||
parser.add_argument("--top-k", type=int, default=3)
|
||||
args = parser.parse_args()
|
||||
|
||||
text = " ".join(args.text).strip() or "convert foo.mp4 to mkv"
|
||||
predict(text, top_k=args.top_k)
|
||||
165
mini_lm.py
Normal file
165
mini_lm.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import json
|
||||
import re
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from transformers import (
|
||||
AutoModelForTokenClassification,
|
||||
AutoTokenizer,
|
||||
DataCollatorForTokenClassification,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from enums import TokenLabel
|
||||
|
||||
MODEL_NAME = "microsoft/MiniLM-L12-H384-uncased"
|
||||
MODEL_DIR = "./model"
|
||||
|
||||
LABEL_LIST = [label.name for label in TokenLabel]
|
||||
LABEL2ID = {label: i for i, label in enumerate(LABEL_LIST)}
|
||||
ID2LABEL = {i: label for i, label in enumerate(LABEL_LIST)}
|
||||
VALUE_TO_NAME = {label.value: label.name for label in TokenLabel}
|
||||
|
||||
|
||||
def _normalize_label(label: str) -> str:
|
||||
if label in LABEL2ID:
|
||||
return label
|
||||
if label in VALUE_TO_NAME:
|
||||
return VALUE_TO_NAME[label]
|
||||
raise ValueError(f"Unknown label: {label}")
|
||||
|
||||
|
||||
def _text_to_tokens(text: str) -> tuple[list[str], list[tuple[int, int]]]:
|
||||
tokens = []
|
||||
spans = []
|
||||
for match in re.finditer(r"\S+", text):
|
||||
tokens.append(match.group(0))
|
||||
spans.append(match.span())
|
||||
return tokens, spans
|
||||
|
||||
|
||||
def _extract_spans(task: dict) -> list[tuple[int, int, str]]:
|
||||
annotations = task.get("annotations") or []
|
||||
if not annotations:
|
||||
return []
|
||||
results = annotations[0].get("result", []) or []
|
||||
spans = []
|
||||
for result in results:
|
||||
value = result.get("value", {})
|
||||
start = value.get("start")
|
||||
end = value.get("end")
|
||||
labels = value.get("labels") or []
|
||||
if start is None or end is None or not labels:
|
||||
continue
|
||||
spans.append((int(start), int(end), _normalize_label(labels[0])))
|
||||
return spans
|
||||
|
||||
|
||||
def parse_annotated_dataset(path: str) -> list[dict]:
|
||||
data = json.loads(Path(path).read_text(encoding="utf-8"))
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("Annotated dataset must be a JSON array.")
|
||||
|
||||
examples = []
|
||||
for task in data:
|
||||
text = task["data"]["text"]
|
||||
spans = _extract_spans(task)
|
||||
tokens, token_spans = _text_to_tokens(text)
|
||||
labels = []
|
||||
for token_start, token_end in token_spans:
|
||||
assigned = "RAW_PHRASE"
|
||||
for span_start, span_end, span_label in spans:
|
||||
if token_start < span_end and token_end > span_start:
|
||||
assigned = span_label
|
||||
break
|
||||
labels.append(assigned)
|
||||
examples.append({"tokens": tokens, "labels": labels})
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
def tokenize_and_align(example, tokenizer, label2id):
|
||||
tokenized = tokenizer(example["tokens"], is_split_into_words=True, truncation=True)
|
||||
word_ids = tokenized.word_ids()
|
||||
labels = []
|
||||
prev = None
|
||||
|
||||
for word_id in word_ids:
|
||||
if word_id is None:
|
||||
labels.append(-100)
|
||||
elif word_id != prev:
|
||||
labels.append(label2id[example["labels"][word_id]])
|
||||
else:
|
||||
labels.append(-100)
|
||||
prev = word_id
|
||||
|
||||
tokenized["labels"] = labels
|
||||
return tokenized
|
||||
|
||||
|
||||
def train(dataset_path: str = "./datasets/annotated/ffmpeg_gemini_v1.json", model_dir: str = MODEL_DIR):
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModelForTokenClassification.from_pretrained(
|
||||
MODEL_NAME, num_labels=len(LABEL_LIST), id2label=ID2LABEL, label2id=LABEL2ID
|
||||
)
|
||||
|
||||
examples = parse_annotated_dataset(dataset_path)
|
||||
dataset = Dataset.from_list(examples)
|
||||
dataset = dataset.map(partial(tokenize_and_align, tokenizer=tokenizer, label2id=LABEL2ID))
|
||||
dataset = dataset.train_test_split(test_size=0.1)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir=model_dir,
|
||||
learning_rate=3e-5,
|
||||
per_device_train_batch_size=8,
|
||||
num_train_epochs=5,
|
||||
weight_decay=0.01,
|
||||
logging_steps=10,
|
||||
save_strategy="epoch",
|
||||
)
|
||||
|
||||
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset["test"],
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.save_model(model_dir)
|
||||
tokenizer.save_pretrained(model_dir)
|
||||
|
||||
|
||||
def _load_model(model_dir: str = MODEL_DIR):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
model = AutoModelForTokenClassification.from_pretrained(model_dir)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
def predict(text: str, model_dir: str = MODEL_DIR, top_k: int = 3):
|
||||
tokenizer, model = _load_model(model_dir)
|
||||
inputs = tokenizer(text.split(), return_tensors="pt", is_split_into_words=True)
|
||||
outputs = model(**inputs)
|
||||
|
||||
logits = outputs.logits[0]
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
k = min(max(1, top_k), probs.shape[-1])
|
||||
top_probs, top_ids = torch.topk(probs, k, dim=-1)
|
||||
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
|
||||
|
||||
for token, prob_row, id_row in zip(tokens, top_probs, top_ids):
|
||||
preds = [
|
||||
f"{ID2LABEL.get(int(label_id), 'RAW_PHRASE')}:{float(score):.4f}"
|
||||
for label_id, score in zip(id_row, prob_row)
|
||||
]
|
||||
print(token, ", ".join(preds))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
@@ -5,8 +5,13 @@ description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"accelerate>=1.13.0",
|
||||
"datasets>=4.8.4",
|
||||
"git-filter-repo>=2.47.0",
|
||||
"google-genai>=1.70.0",
|
||||
"label-studio>=1.23.0",
|
||||
"label-studio-ml>=1.0.9",
|
||||
"loguru>=0.7.3",
|
||||
"torch>=2.11.0",
|
||||
"transformers>=5.5.0",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user