Source code for fanoutqa.retrieval

"""This module contains a baseline implementation of a retriever for use with long Wikipedia articles"""

from dataclasses import dataclass
from typing import Iterable

try:
    import numpy as np
    from rank_bm25 import BM25Plus
except ImportError as e:
    raise ImportError(
        "Using the baseline retriever requires the rank_bm25 package. Use `pip install fanoutqa[retrieval]`."
    ) from e

from .models import Evidence
from .norm import normalize
from .wiki import wiki_content


@dataclass
class RetrievalResult:
    title: str
    """The title of the article this fragment comes from."""
    content: str
    """The content of the fragment."""


class Corpus:
    """
    A corpus of wiki docs. Indexes the docs on creation, normalizing the text beforehand with lemmatization.

    Splits the documents into chunks no longer than a given length, preferring splitting on paragraph and sentence
    boundaries. Documents will be converted to Markdown.

    Uses BM25+ (Lv and Zhai, 2011), a TF-IDF based approach to retrieve document fragments.

    To retrieve chunks corresponding to a query, iterate over ``Corpus.best(query)``.

    .. code-block:: python

        # example of how to use in the Evidence Provided setting
        prompt = "..."
        corpus = fanoutqa.retrieval.Corpus(q.necessary_evidence)
        for fragment in corpus.best(q.question):
            # use your own structured prompt format here
            prompt += f"# {fragment.title}\\n{fragment.content}\\n\\n"
    """

    def __init__(self, documents: list[Evidence], doc_len: int = 2048):
        """
        :param documents: The list of evidences to index
        :param doc_len: The maximum length, in characters, of each chunk
        """
        self.documents = []
        normalized_corpus = []

        for doc in documents:
            title = doc.title
            content = wiki_content(doc)
            for chunk in chunk_text(content, max_chunk_size=doc_len):
                self.documents.append(RetrievalResult(title, chunk))
                normalized_corpus.append(self.tokenize(chunk))

        self.index = BM25Plus(normalized_corpus)

    @staticmethod
    def tokenize(text: str):
        return normalize(text).split(" ")

    def best(self, q: str) -> Iterable[RetrievalResult]:
        """Yield the best matching fragments to the given query."""
        tok_q = self.tokenize(q)
        scores = self.index.get_scores(tok_q)
        idxs = np.argsort(scores)[::-1]
        for idx in idxs:
            yield self.documents[idx]
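

# A short, hedged sketch of taking only the top-k fragments from Corpus.best, e.g. in
# the Evidence Provided setting described in the class docstring. The ``q`` object and
# the use of ``itertools.islice`` are assumptions made for this illustration; they are
# not part of the module itself:
#
#     from itertools import islice
#
#     corpus = Corpus(q.necessary_evidence)
#     top_5_fragments = list(islice(corpus.best(q.question), 5))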


def chunk_text(text, max_chunk_size=1024, chunk_on=("\n\n", "\n", ". ", ", ", " "), chunker_i=0):
    """
    Recursively chunks *text* into a list of str, with each element no longer than *max_chunk_size*.
    Prefers splitting on the elements of *chunk_on*, in order.
    """
    if len(text) <= max_chunk_size:
        # the chunk is small enough
        return [text]
    if chunker_i >= len(chunk_on):
        # we have no more preferred chunk_on characters
        # optimization: instead of merging a thousand characters, just use list slicing
        return [text[:max_chunk_size], *chunk_text(text[max_chunk_size:], max_chunk_size, chunk_on, chunker_i + 1)]

    # split on the current character
    chunks = []
    split_char = chunk_on[chunker_i]
    for chunk in text.split(split_char):
        chunk = f"{chunk}{split_char}"
        if len(chunk) > max_chunk_size:
            # this chunk needs to be split more, recurse
            chunks.extend(chunk_text(chunk, max_chunk_size, chunk_on, chunker_i + 1))
        elif chunks and len(chunk) + len(chunks[-1]) <= max_chunk_size:
            # this chunk can be merged
            chunks[-1] += chunk
        else:
            chunks.append(chunk)

    # if the last chunk is just the split_char, yeet it
    if chunks[-1] == split_char:
        chunks.pop()
    # remove extra split_char from last chunk
    chunks[-1] = chunks[-1][: -len(split_char)]
    return chunks
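

# A minimal, self-contained sketch showing how chunk_text splits a long string.
# The sample text, the 80-character limit, and the __main__ guard are illustrative
# choices for this demo only; they are not part of the library API.
if __name__ == "__main__":
    sample = (
        "FanOutQA questions require evidence from several Wikipedia articles.\n\n"
        "Each article is converted to Markdown before indexing, and long articles are "
        "split into fragments so that BM25+ scores passages rather than whole pages."
    )
    for i, fragment in enumerate(chunk_text(sample, max_chunk_size=80)):
        # every fragment respects the requested maximum length
        assert len(fragment) <= 80
        print(i, repr(fragment))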