Retrieval-Augmented Generation в техподдержке на основе YandexGPT

Большие языковые модели обладают обширными знаниями о мире, однако они не всезнающи. Из-за длительности процесса обучения, информация, использованная во время последнего обучения, может быть устаревшей. Несмотря на то что эти модели знакомы с общедоступной информацией из Интернета, они не обладают знаниями о специфических данных, необходимых для реализации бизнес-логики приложений на основе искусственного интеллекта. До появления LLM, обновление данных моделей достигалось путем их дополнительного обучения. Однако с увеличением масштаба и объема данных, используемых для обучения, дополнительное обучение стало эффективным только для ограниченного числа сценариев использования. RAG представляет собой метод интеграции пользовательских данных в LLM, который основывается на использовании промптов.

Статья посвящена рассмотрению процесса создания системы генерации ответов службы технической поддержки. Для этого используется методика с расширенным поиском, известная как Retrieval-Augmented Generation (RAG). Процесс основан на использовании шаблонов и реальных вопросов-ответов техподдержки. В качестве основных инструментов применяются YandexGPT и ChromaDB.

Архитектура системы

RAG с векторным поиском

RAG с векторным поиском

Input — первое сообщение клиента или текущий диалог,
Embedding — получение эмбеддингов в YandexGPT API,
Search service — ChromaDB, будет хранить эмбеддинги, метаданные, документы и производить поиск по схожим эмбеддингам
Documents — вопросы-ответы из внутреннего сервиса техподдержки,
LLM — YandexGPT, TextGenerationAsync.,
Output — ответ, требующий модерации.

Создание обертки для YandexGPT

[Исходный код]

Все вызовы YandexGPT API имеют одинаковые заголовки и требуют URI модели.

class YandexGpt:
    def __init__(self, api_key: str, model_uri: str):
        self.api_str = api_key
        self.model_uri = model_uri

    def get_headers(self):
        return {
            "Content-Type": "application/json",
            "Authorization": f"Api-Key {self.api_str}",
            "x-data-logging-enabled": "false"
        }

В закрытом тестировании было много ошибок, возникающих по непонятным причинам. Чтобы не нарушать логику приложения, приходилось один и тот же запрос отправлять, пока не будет корректного ответа. Сейчас, когда идет открытое тестирование, подобных ошибок я не видел, но в рейт лимит упирался (ошибка 429). Для обработки ошибок написан декоратор, которые пытается заново отправить запрос при status_code != 200:

def retry_yandex_gpt_factory(reties=2):
    def retry_yandex_gpt(func):
        def wrapper_retry_yandex_gpt(*args, **kwargs):
            for retry in range(reties):
                res = func(*args, **kwargs)
                if (res.status_code) == 200:
                    return res.json()
                else:
                    print(f"Request failed {res.status_code}: {res.json()}, retry number: {retry + 1}")
                    if res.status_code == 429:
                        sleep(5)

        return wrapper_retry_yandex_gpt

    return retry_yandex_gpt

Обертка над API для создания эмбеддингов (численных векторов):

class Embeddings(YandexGpt):
    @retry_yandex_gpt_factory(5)
    def text_embedding(self, text: str):
        url = "https://llm.api.cloud.yandex.net/foundationModels/v1/textEmbedding"
        data = {
            "modelUri": self.model_uri,
            "text": text
        }

        return requests.post(url, json=data, headers=self.get_headers())

YandexGPT API имеет синхронный и асинхронный вызовы метода completion. Согласно документации, асинхронный вызов медленнее синхронного (8–15 сек), более точный и дешевле в 2 раза. Для техподдержки более критична точность ответа, нежели ожидание в 15 сек. Метод sync_completion делает completion из асинхронного в синхронный.

class MessageRole(Enum):
    SYSTEM = 'system'
    ASSISTANT = 'assistant'
    USER = 'user'


class Message:
    def __init__(self, role: MessageRole, text: str):
        self.role = role
        self.text = text


class TextGenerationAsync(YandexGpt):
    @retry_yandex_gpt_factory()
    def completion(self, messages: list[Message], stream: bool, temperature: int, max_tokens: int):
        url = "https://llm.api.cloud.yandex.net/foundationModels/v1/completionAsync"
        data = {
            "modelUri": self.model_uri,
            "completionOptions": {
                "stream": stream,
                "temperature": temperature,
                "maxTokens": max_tokens
            },
            "messages": [{"role": str(msg.role.value), "text": msg.text} for msg in messages]
        }
        return requests.post(url, json=data, headers=self.get_headers())

    def get_operation(self, operation_id: str):
        url = "https://operation.api.cloud.yandex.net/operations/" + operation_id
        return requests.get(url, headers=self.get_headers()).json()

    def sync_completion(self, messages: list[Message], stream: bool, temperature: float, max_tokens: int, max_wait_secs: int):
        operation_id = self.completion(messages, stream, temperature, max_tokens)['id']

        for i in range(max_wait_secs):
            res = self.get_operation(operation_id)
            if res["done"]:
                return res
            sleep(1)

Извлечение данных

[Исходный код]

Допустим, что имеются 2 источника информации для расширенного поиска — intents и messages.

Код извлечения intents

def get_intents_df():
    intents_url = "https://example.com/api/bot/findIntent"
    total_count = 3815
    limit = 500
    headers = {
        "Cookie": os.getenv("COOKIE")
    }
    
    df = pd.DataFrame(columns=['id', 'text', 'pattern', 'intentId', 'groupId', 'answer'])
    
    for page in tqdm(range(math.ceil(total_count / limit))):
        res = requests.get(
            intents_url, 
            params={"limit": limit, "count": True, "page": page+1}, 
            headers=headers
        )
        rows = res.json()["rows"]
        df = pd.concat([df, pd.DataFrame(rows)], ignore_index=True)
    return df

intents_df = get_intents_df()
intents_df.to_csv("data/intents.csv", index=False)

Структура исходных данных:

intents_df.info()
RangeIndex: 3815 entries, 0 to 3814
Data columns (total 6 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   id        3815 non-null   object
 1   text      3815 non-null   object
 2   pattern   0 non-null      object
 3   intentId  3815 non-null   object
 4   groupId   0 non-null      object
 5   answer    3815 non-null   object

Код извлечения messages

def remove_now_answered_column(df):
    return df.drop(['nowAnswered'], axis=1)

def get_messages_df():
    messages_url = "https://example.com/api/bot/findMessage"
    total_count = 35574
    limit = 500
    page = 1
    headers = {
        "Cookie": os.getenv("COOKIE")
    }
    
    df = pd.DataFrame(columns=['answer', 'answered', 'chatId', 
                               'clientId', 'messageId', "success", 
                               "text"])
    
    for page in tqdm(range(math.ceil(total_count / limit))):
        res = requests.get(
            messages_url, 
            params={"limit": limit, "count": True, "page": page + 1, 
                    "sortBy": "TIMESTAMP", "ascending": False}, 
            headers=headers
        )
        rows = res.json()["rows"]
        df = pd.concat([df, pd.DataFrame(rows)], ignore_index=True) 
    return df.pipe(remove_now_answered_column)

messages_df = get_messages_df()
messages_df.to_csv("data/messages.csv", index=False)

Структура исходных данных. В БД будут внесены только те строки, которые имеют ненулевое поле answer, т.к. эмбеддинги этого поля будут ключем к полю text.

intents_df.info()
RangeIndex: 35576 entries, 0 to 35575
Data columns (total 7 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   answer     8414 non-null   object
 1   answered   35576 non-null  object
 2   chatId     35576 non-null  object
 3   clientId   35576 non-null  object
 4   messageId  35576 non-null  object
 5   success    35576 non-null  object
 6   text       35576 non-null  object

Получение эмбеддингов и создание векторной базы данных в ChromaDB

[Исходный код]

Создание и сохранение эмбеддингов в csv

def get_emdeddings(text):
    sleep(0.5)
    return embeddings.text_embedding(text)["embedding"]

embeddings = yandexgpt.Embeddings(os.getenv("YANDEX_GPT_KEY"), 
                                  os.getenv("YANDEX_GPT_EMBEDDINGS_URI")
                              )

# Получение эмбеддиногов для intents
df = pd.read_csv("data/intents.csv")
text_embeddings = []
for txt in tqdm(df["text"]):
    text_embeddings.append(get_emdeddings(txt))
df["text_embeddings"] = text_embeddings
df.to_csv("data/intents_with_embeddings.csv", index=False)

# Получение эмбеддиногов для messages
df = pd.read_csv("data/messages.csv").dropna(subset=['answer'])
text_embeddings = []
for txt in tqdm(df["text"]):
    text_embeddings.append(get_emdeddings(txt))
df["text_embeddings"] = text_embeddings
df.to_csv("data/messages_with_embeddings.csv", index=False)

Запуск ChromaDB

docker pull chromadb/chroma
docker run -p 8000:8000 chromadb/chroma

Подключение к СУБД к коллекции

chroma_client = chromadb.HttpClient(host='localhost', 
                                    port="8000", 
                                    settings=Settings(anonymized_telemetry=False))

collection = chroma_client.get_or_create_collection("intents")

Загрузка эмбеддингов вопросов, источников и вопросов (metadata), ответов (documents) в БД:

df = pd.read_csv('data/intents_with_embeddings.csv')
texts = df["text"].tolist()
text_embeddings = list(map(
    lambda str_arr: ast.literal_eval(str_arr), 
    df["text_embeddings"].tolist()))
ids = df["id"].astype(str).tolist()
answers = df["answer"].tolist()
collection.upsert(
    ids=ids,
    embeddings=text_embeddings,
    metadatas=[{"source": "intents", "text": txt} for txt in texts],
    documents=answers
)


df = pd.read_csv('data/messages_with_embeddings.csv')
texts = df["text"].tolist()
text_embeddings = list(map(
    lambda str_arr: ast.literal_eval(str_arr), 
    df["text_embeddings"].tolist()))
ids = df["messageId"].astype(str).tolist()
answers = df["answer"].tolist()
collection.upsert(
    ids=ids,
    embeddings=text_embeddings,
    metadatas=[{"source": "messages", "text": txt} for txt in texts],
    documents=answers
)

Генерация ответов

Исходный код с примерами генерации

Подключение к ChromaDB и создание оберток YandexGPT для эмбеддингов и генерации текста:

chroma_client = chromadb.HttpClient(host='localhost', 
                                    port="8000", 
                                    settings=Settings(anonymized_telemetry=False))
embeddings = yandexgpt.Embeddings(os.getenv("YANDEX_GPT_KEY"), 
                                  os.getenv("YANDEX_GPT_EMBEDDINGS_URI"))
textGenerationAsync = yandexgpt.TextGenerationAsync(os.getenv("YANDEX_GPT_KEY"), 
                                  os.getenv("YANDEX_GPT_URI"))
  1. получение эмбеддингов вопроса,

  2. поиск релеваных эмбеддингов с документами и метаданными (6 шт.),

  3. отброс тех, которы имеют дистанцию больше 1,

  4. задание промпта, создание диалога на основе 6 релевантных вопрос-ответов, генерация описания с температурой 0, макс. кол-вом генерируемых токенов 250.

def format_qa(func):
    def wrapper_format_qa(*args, **kwargs):
        print("Вопрос:", args[0])
        print("Ответ:", func(*args, **kwargs))
    return wrapper_format_qa

@format_qa
def generate_answer(question: str):
    result = collection.query(
        query_embeddings=[embeddings.text_embedding(question)["embedding"]],
        n_results=6,
    )
    
    messages = [yandexgpt.Message(yandexgpt.MessageRole.SYSTEM, "Ты специалист технической поддежки. На основе сообщений, написанных тобой выше, сгенерируй сообщение")]
    for distance, metadata, document in zip(
            result["distances"][0], result["metadatas"][0], result["documents"][0]
    ):
        if distance < 1:
            messages.append(yandexgpt.Message(yandexgpt.MessageRole.USER, metadata["text"]))
            messages.append(yandexgpt.Message(yandexgpt.MessageRole.ASSISTANT, document))
    
    messages.append(yandexgpt.Message(yandexgpt.MessageRole.USER, question))
    return textGenerationAsync.sync_completion(messages, False, 0, 250, 20)["response"]["alternatives"][0]["message"]["text"]

Вывод

Сообщения, сгенерированные RAG могут ускорить работу ТП, но нельзя допускать отправку клиентам сообщений, сгенерированных таким образом без модерации.

© Habrahabr.ru