본 포스팅은 RAG를 처음 입문하는 스타터들을 위해 간단하게 RAG를 코드로 구현하는 방법을 설명합니다.
먼저 RAG 코드 구현에 있어 RAG가 뭔지 모르겠다면 아래 글을 참고해주길 바랍니다.
Prompt Engineering으로 ChatGPT 제대로 활용하기 - (5) 프롬프트 엔지니어링 심화편 III (Tree of Thought, RAG)
Tree of Thought기존 Chain of Thought(CoT) 방식은 LLM이 문제 해결 시 연속된 단일 추론 경로(chain)에 의존하게 된다.그러나 탐색(exploration)이나 전략적인 미래 예측(lookahead)이 필요한 복잡한 문제에서는 한
kangth97.tistory.com
0. Prerequisite
실행환경은 local Jupyter Notebook이고, python 3.11, jupyter 1.1.1, torch 2.5.1 버전을 사용했습니다.
먼저 아래 필요한 라이브러리들을 pip로 설치해 줍니다.
!pip install transformers==4.28.0 sentence-transformers==2.2.2 faiss-cpu datasets==2.13.0
1. Dataset 준비
과도하게 큰 코퍼스 대신, HuggingFace의 ag_news 같은 소규모 텍스트 데이터를 예시로 사용한다.
ag_news는 뉴스 토픽 classification dataset으로, 127,600개의 뉴스데이터와 4개의 토픽이 존재한다.
토픽 라벨은 {1: World 2: Sports 3: Business 4: Sci/Tech} 으로 분류되어 있다.
import pandas as pd
import numpy as np
import torch
from datasets import load_dataset
# 예시: AG News 데이터셋 (4개 카테고리의 뉴스 제목+본문)
dataset = load_dataset("ag_news", split="train[:1000]") # 샘플로 1000개만 사용
dataset = dataset.to_pandas() # Pandas로 변환
dataset.head(3)
### output
text label
0 Wall St. Bears Claw Back Into the Black (Reute... 2
1 Carlyle Looks Toward Commercial Aerospace (Reu... 2
2 Oil and Economy Cloud Stocks' Outlook (Reuters... 2
2. 모델 불러오기
Generator를 위한 모델은 GPU 메모리 제한에 맞춰 facebook/bart-base를 선택했다.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
GENERATOR_MODEL_NAME = "facebook/bart-base"
generator_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL_NAME)
generator_model = AutoModelForSeq2SeqLM.from_pretrained(GENERATOR_MODEL_NAME)
Retrieval을 위한 임베딩 모델로 sentence-transformers 계열을 많이 사용한다.
그 중에서 "sentence-transformers/all-MiniLM-L6-v2" 모델이 가볍고 성능도 준수해서 많이들 사용한다.
from sentence_transformers import SentenceTransformer
QUERY_ENCODER_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
query_encoder = SentenceTransformer(QUERY_ENCODER_MODEL_NAME)
3. Query Encoder
사용자의 질문을 받아 query encoder로 임베딩 벡터를 생성하는 함수이다.
def encode_query(query: str):
# SentenceTransformers 모델로 임베딩 추출
embedding = query_encoder.encode([query], convert_to_numpy=True)
# shape: (384, )
return embedding[0]
tensor = encode_query("안녕하세요! 내일 날씨가 어떻습니까?")
tensor.shape
### output
(384,)
4. Document Indexing 하기
먼저 ag_news 데이터셋의 본문 텍스트 1000개에 대해 임베딩 모델에 넣어 각 문서에 대한 임베딩 벡터를 생성한다.
이후 FAISS index를 구성하여 문서 임베딩을 추가한다.
index 객체가 우리의 Document Index가 됩니다. 여기에 쿼리 임베딩을 입력하게 되면 각 문서 임베딩 벡터와 유사도를 계산하여 top-k개의 문서를 검색할 수 있다.
import faiss
documents = dataset["text"].tolist()
doc_embeddings = query_encoder.encode(documents, convert_to_numpy=True) # shape: (num_docs, 384)
normalized_embeddings = doc_embeddings / np.linalg.norm(doc_embeddings, axis=1, keepdims=True)
dim = normalized_embeddings.shape[1] # 임베딩 차원
index = faiss.IndexFlatL2(dim) # L2 거리기반
index.add(normalized_embeddings) # 문서 임베딩 등록
index.ntotal
### output
1000
5. Retriever
Retrieval 함수는 쿼리를 입력받아
- 쿼리 임베딩을 계산하고
- FAISS index에서 유사한 문서 top-k개를 찾고
- 해당 문서들의 텍스트를 반환합니다.
def retrieve_top_k_docs(query: str, k=3):
q_emb = encode_query(query).reshape(1, -1) # (1, dim)
distances, indices = index.search(q_emb, k) # (1, k) shape
# indices: (1, k) 형태, 실제 문서 인덱스
top_k_docs = [documents[i] for i in indices[0]]
return top_k_docs, distances[0]
query_example = "Find news about microsoft"
top_docs, dists = retrieve_top_k_docs(query_example, k=5)
for i, (doc, dist) in enumerate(zip(top_docs, dists)):
print(f"Top {i+1} doc (dist={dist:.4f}):\n{doc}\n")
테스트로 "Find news about microsoft" 라는 쿼리를 날려 faiss에서 관련 문서를 요청하면 다음과 같은 결과가 나온다.
L2 distance 기반이기 때문에 거리가 작을수록 유사한 문서이다.
Top 1 doc (dist=0.7838):
Taking the Microsoft Rorschach test CNET News.com's Charles Cooper asks what it is about Microsoft that pushes so many people straight over the edge?
Top 2 doc (dist=0.8133):
Microsoft to Introduce Cheaper Version of Windows SEATTLE (Reuters) - Microsoft Corp. <MSFT.O> said it will begin selling a stripped-down, low-cost version of its Windows XP operating system in the emerging markets of Indonesia, Malaysia and Thailand in order to spread the use of computing and develop technology markets.
Top 3 doc (dist=0.9460):
Microsoft Lists XP SP2 Problems (NewsFactor) NewsFactor - With automatic download of Microsoft's (Nasdaq: MSFT) enormous SP2 security patch to the Windows XP operating system set to begin, the industry still waits to understand its ramifications. Home users that have their preferences set to receive operating-system updates as they are made available by Microsoft may be surprised to learn that some of the software they already run on their systems could be disabled by SP2 or may run very differently.
Top 4 doc (dist=0.9588):
Microsoft Upgrades Software for Digital Pictures SEATTLE (Reuters) - Microsoft Corp. <MSFT.O> released on Tuesday the latest version of its software for editing and organizing digital photographs and images to tap into widespread demand for digital cameras and photography.
Top 5 doc (dist=1.0524):
Microsoft Corp. 2.0: a kinder corporate culture Even a genius can mess up. Bill Gates was a brilliant technologist when he cofounded Microsoft , but as he guided it to greatness in both size and historical consequence, he blundered. He terrorized underlings with his temper and parceled out praise like Scrooge gave to charity. Only the lash inspired the necessary aggressiveness to beat the competition, he thought.
6. Generator
이제 RAG 방식으로 "사용자 쿼리 + 검색 문서"를 결합하여 seq2seq 모델로 최종 답변을 생성한다.
여기서는 간단히 "문서들을 전부 합쳐서" in-context를 생성하고, BART에 입력을 넣어 inference를 한다.
실제 RAG는 각 문서마다 답변을 내고 확률을 결합하는 late fusion을 할 수 있으나 여기서는 이해를 돕기위해 단순하게 구현했다.
def rag_generate_answer(query: str, k=3, max_length=128):
# 1) Retrieval
top_docs, _ = retrieve_top_k_docs(query, k)
context_str = ""
# 2) Context 만들기 (단순 예시)
for idx, doc in enumerate(top_docs):
context_str = context_str + f"\n[news no.{idx}]\n{doc}"
combined_prompt = f"Question: {query}\nContext: {context_str}\nresult:"
# 3) Generator로 답 생성
inputs = generator_tokenizer([combined_prompt], return_tensors="pt", truncation=True)
# GPU 사용 가능 시 -> inputs = inputs.to('cuda'), model도 .to('cuda') 가능
with torch.no_grad():
outputs = generator_model.generate(
**inputs,
max_length=max_length,
num_beams=4,
early_stopping=True,
do_sample=True,
temperature=0.2
)
answer = generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
return answer
이제 예시 쿼리로 테스트를 해보자.
# 예시 쿼리
query_sample = "when does google's auction open?"
answer_output = rag_generate_answer(query_sample, k=3)
print("=== RAG-style Generation ===")
print("Query:", query_sample)
print("Answer:", answer_output)
### prompt
Question: when does google's auction open?
Context:
[news no.0]
Google auction begins on Friday An auction of shares in Google, the web search engine which could be floated for as much as \$36bn, takes place on Friday.
[news no.1]
Google IPO Bidding Opens Google IPO Bidding Opens\\Google's IPO bidding is officially open. Google and its underwriters expect to open the auction for the shares of Google rsquo;s Class A common stock at 9:00 a.m. EST (press time) on Friday, August 13, 2004. Google bidders must have obtained a bidder ID from ipo.google.com if you ...
[news no.2]
In Google's Auction, It's Not Easy to Tell a Bid From a Bet In a competition combining suspense and strategy, countless brave souls are hoping to buy a small piece of Google in an auction this week.
result:
### output
=== RAG-style Generation ===
Query: when does google's auction open?
Answer: Question: when does google's auction open?Context: ߣ[news no.0]ߣGoogle auction begins on Friday An auction of shares in Google, the web search engine which could be floated for as much as \$36bn, takes place on Friday. The auction will take place at 9:00 a.m. EST (press time) on Friday, August 13, 2004. _______________[newsno.1]\\\\\\\\Google IPO Bidding Opens Google IPO Bids Opens\\Google's IPO bidding is officially open. Google and its underwriters expect to open the auction for the
"when does google's auction open?" 이라는 간단한 쿼리를 날려보았다.
prompt는 question-context-result로 구성하였고, faiss index에서 google과 auction과 관련된 뉴스들을 잘 가져온것을 확인할 수 있다.
다만 output의 퀄리티가 굉장히 좋지 않다. prompt의 내용을 거의 그대로 가져오는 듯 한데, 이는 generator 모델의 크기가 매우 작고 성능이 낮기 때문에 그렇다.
다음 포스팅에서는 이를 해결하기 위해 OpenAI API와 LangChain을 이용하여 ChatGPT 모델로 RAG 시스템을 구현하는 글을 올릴 예정이다.