Кратко про Multi-Head RAG: решение многоаспектных задач с помощью LLM

cfc3cc75d90255fa46cce2b2b93388dd.png

Привет, Хабр!

Современные языковые модельки обладают огромным потенциалом, но они часто сталкиваются с трудностями, когда дело доходит до решения комплексных задач, требующих доступа к разнообразным источникам данных. Multi-Head RAG объявился на нашем свете для того, чтобы изменить эту ситуацию. Эта модель сочетает генерацию и поиск информации, что позволяет ей справляться с многогранными задачами, которые традиционно сложны для обычных LLM.

Multi-Head RAG

RAG

Retrieval-Augmented Generation (RAG) — это архитектурный подход, который усиливает возможности LLM путем интеграции поиска информации в процесс генерации текста. Традиционные LLM обучаются на статичных данных, что ограничивает их актуальность и точность при работе с динамической или специализированной информацией. RAG решает эту проблему, добавляя возможность извлечения актуальных данных из внешних источников в реальном времени.

Процесс RAG включает три основные этапа:

  1. Поиск: на этом этапе моделька извлекает релевантную информацию из внешних БД. Для этого используются методы индексирования и поиска, типо Locality-Sensitive Hashing и k-Nearest Neighbors. Процесс начинается с преобразования текста в эмбеддинги, которые хранятся в векторной БД. Эти векторы помогают модели быстро находить наиболее релевантную информацию.

  2. Дополнение: извлеченная информация добавляется к исходным данным, передаваемым в модель, обогащая контекст. Этот процесс включает в себя добавление релевантных данных к пользовательскому запросу, что позволяет модели использовать дополненную информацию для генерации более точного и актуального ответа. На этом этапе начинается тот самый prompt engineering для интеграции новых данных с исходным контекстом​

  3. Генерация: модель использует дополненный контекст для создания информированного и актуального ответа. На этом этапе ответ синтезируется на основе внутренних данных модели и доп. информации, полученной на этапе поиска. Используя механизмы трансформеров, типо self-attention, модель генерирует текст, который не только релевантен, но и основан на последних данных​

Например, RAG с использованием Hugging Face Transformers и Faiss:

import torch
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

# инициализация токенизатора, ретривера и модели
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact")
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

# ввод пользователя
question = "What is the capital of France?" # интересный вопрос

# токенизация и извлечение контекста
input_ids = tokenizer(question, return_tensors="pt").input_ids
retrieved_docs = retriever(input_ids)

# генерация ответа
outputs = model.generate(input_ids, num_beams=5, num_return_sequences=1)
generated_answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

print("Question:", question)
print("Answer:", generated_answer)

Multi-Head RAG

Представляет собой расширение традиционной архитектуры RAG, внедряя многоголовый подход. В ванильном RAG используется один механизм поиска и генерации, что может быть недостаточно для сложных и многоаспектных задач. Multi-Head RAG решает эту проблему следующим образом:

В Multi-Head RAG каждая голова в модели отвечает за обработку определенного аспекта задачи.

Каждая голова модели может быть настроена для работы с разными типами данных, такими как текстовые документы, изображения или аудио. Многоголовая архитектура позволяет ускорить обработку данных за счет параллельной работы нескольких голов.

Каждая голова в Multi-Head RAG может быть настроена для работы с определенными типами информации или контекстами. Например:

  1. Тексты научных статей: одна голова может быть специально обучена на текстах из научных журналов и баз данных, что дает фичу анализы научной литературы.

  2. Данные из социальных сетей: другая голова может быть настроена для анализа данных из соц. сетей, что позволяет учитывать контекст и специфику соц. медиа.

  3. Коммерческая информация: третья голова может обрабатывать данные из коммерческих источников, вроде новостей, различных фин. отчетов и т.п, предоставляя тем самым актуальные ответы.

Результаты от каждой головы комбинируются для формирования финального ответа.

Извлеченные данные сопоставляются и фильтруются для удаления дублирующейся или нерелевантной информации.

Комбинируются данные из различных источников для создания единого, консистентного ответа.

Таким образом, архитектура будет выглядеть так:

  1. Модуль поиска: Включает несколько голов, каждая из которых отвечает за поиск информации в специфичных источниках данных.

  2. Модуль генерации: Использует агрегированные данные для создания информированного ответа. Генеративная модель интегрируется с результатами поиска.

  3. Интерфейс пользователя.

Пример реализации:

import torch
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
from faiss import IndexFlatL2
import numpy as np

# инициализация токенизатора и модели
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever_1 = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact")
retriever_2 = RagRetriever.from_pretrained("facebook/rag-token-wiki", index_name="exact")
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=[retriever_1, retriever_2])

# пример пользовательского вопроса
question = "What are the recent advancements in quantum computing?"
input_ids = tokenizer(question, return_tensors="pt").input_ids

# извлечение релевантных документов из нескольких источников
retrieved_docs_1 = retriever_1(input_ids)
retrieved_docs_2 = retriever_2(input_ids)
combined_docs = retrieved_docs_1 + retrieved_docs_2

# генерация ответа
outputs = model.generate(input_ids, context_input_ids=combined_docs, num_beams=5, num_return_sequences=1)
generated_answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

print("Question:", question)
print("Answer:", generated_answer)

# инициализация Faiss для кастомного индекса
dimension = 768
index = IndexFlatL2(dimension)
data = np.random.random((1000, dimension)).astype('float32')
index.add(data)
query_vector = np.random.random((1, dimension)).astype('float32')
D, I = index.search(query_vector, 10)

print("Nearest Neighbors:", I)

# функция для интеграции Faiss с RAG
def retrieve_custom_index(query, index, tokenizer, model):
    query_vector = tokenizer(query, return_tensors="pt").input_ids.numpy()
    D, I = index.search(query_vector, 10)
    retrieved_docs = [tokenizer.decode(idx) for idx in I[0]]
    return retrieved_docs

# пример пользовательского вопроса
custom_question = "What are the recent trends in AI?"
custom_retrieved_docs = retrieve_custom_index(custom_question, index, tokenizer, model)

# генерация ответа с использованием кастомного индекса
custom_input_ids = tokenizer(custom_question, return_tensors="pt").input_ids
custom_outputs = model.generate(custom_input_ids, context_input_ids=custom_retrieved_docs, num_beams=5, num_return_sequences=1)
custom_generated_answer = tokenizer.decode(custom_outputs[0], skip_special_tokens=True)

print("Custom Question:", custom_question)
print("Custom Answer:", custom_generated_answer)

Используются два ретривера для работы с различными индексами данных.

Также токенизируем запроса юзера и извлекаем документы из двух различных источников.

Multi-Head RAG улучшает точность за счет параллельной обработки различных аспектов задачи, позволяя учитывать больше контекста и разнообразие данных. Традиционные модели часто ограничены в этом отношении и могут допускать ошибки.

Про RAG и другие модели и инструменты эксперты OTUS рассказывают в рамках практических курсов по машинному обучению. Переходите в каталог и выбирайте подходящее направление.

Habrahabr.ru прочитано 2265 раз