Three issues you might encounter when using LLMs like ChatGPT are: first, its training data only goes up to 2022, so it knows nothing more recent; second, it can confidently present wrong information as fact when it doesn't actually know the answer, a phenomenon known as hallucination; and third, it can be very difficult to trace where a piece of information came from, which makes source attribution challenging. If you've ever been sent on a wild goose chase by ChatGPT, Retrieval-Augmented Generation (RAG) might be just what you need.
Retrieval-Augmented Generation (RAG) is a technique that improves the output of a large language model by retrieving information from a knowledge base outside its training data. The process involves "retrieval," where relevant information is pulled from external sources, and "augmentation," where that retrieved information is combined with the user's prompt and passed to the LLM so it can produce accurate, up-to-date responses.
How do RAGs work?
Now that we understand what RAG is and the problems it can solve, let's take a closer look under the hood at how it all works. The five-step flow below describes how retrieval and augmentation are wired into a chatbot; a minimal code sketch of the same flow follows the list.
User Interaction: The user interacts with the chatbot by asking a question or making a request.
Search and Retrieval Initiation: The chatbot forwards the user's query to the search and retrieval component.
Information Retrieval: The retrieval mechanism searches the external knowledge base or database you provided for relevant information and sends back the matching context.
Information Augmentation: The retrieved context is combined with the initial question into a single prompt, which is sent to the LLM to generate a response.
Response Generation: The chatbot uses the retrieved information to generate a coherent and accurate response.
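To make these steps concrete, here is a minimal, self-contained sketch of the same loop in plain Python. The tiny knowledge base, the word-overlap scoring, and the call_llm stub are all illustrative placeholders, not part of any real library; the rest of this article replaces them with a proper embedding model and Llama 3.
## Toy knowledge base standing in for your external documents
knowledge_base = [
    "Law 34 deals with the striker hitting the ball twice.",
    "Another law deals with a different form of dismissal.",
]

def retrieve(query, documents, top_k=1):
    ## Steps 2-3: score each chunk against the query (here: naive word overlap)
    def score(doc):
        return len(set(query.lower().split()) & set(doc.lower().split()))
    return sorted(documents, key=score, reverse=True)[:top_k]

def call_llm(prompt):
    ## Placeholder for a real LLM call (later replaced by HuggingFaceLLM)
    return "[answer generated from the augmented prompt]"

## Step 1: the user asks a question
query = "What happens if the striker hits the ball twice?"

## Step 4: augment the prompt with the retrieved context
context = "\n".join(retrieve(query, knowledge_base))
prompt = f"Answer using only this context:\n{context}\n\nQuestion: {query}"

## Step 5: generate the final response
print(call_llm(prompt))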
Building a RAG from Scratch in Python
Time for some code: I put together a simple RAG in under 50 lines of Python, so it's time to tinker in Google Colaboratory. Our RAG model will be built using Meta's open-source Llama-3-8B-Instruct model loaded from Hugging Face, along with helper functions from the llama-index package.
There are a few prerequisites before you start coding.
Sign up to Hugging Face and generate an access token.
Request gated access to the Llama model you wish to use by following the link here and filling in the form. You should receive an email once the owners at Meta approve your request.
Switch your Google Colaboratory Runtime to use GPU.
With all of this set up, we can start coding.
Installing and Importing Relevant Packages
## For optimising the process: bitsandbytes is used for memory reduction; some of the other packages are needed as dependencies.
!pip install -q transformers einops accelerate langchain bitsandbytes
## For Embedding
!pip install sentence_transformers
## llama-index packages
!pip install llama_index
!pip install llama-index-embeddings-langchain
!pip install llama-index-llms-huggingface
!pip install langchain-community
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
from llama_index.core.prompts.prompts import SimpleInputPrompt
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.legacy.embeddings.langchain import LangchainEmbedding
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
import torch
Importing/Uploading your Additional Documents
documents = SimpleDirectoryReader('/content').load_data()
The documents should be uploaded to the /content folder within Google Colaboratory. For this example, I have uploaded a PDF of the latest MCC cricket laws (2017 Code, 3rd Edition, 2022).
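If you want to confirm the upload worked, here is a quick optional sanity check (with the default PDF reader, load_data() typically returns one Document object per page):
## Optional sanity check on what load_data() returned
print(len(documents))           ## number of Document objects loaded
print(documents[0].text[:300])  ## preview the first few hundred characters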
Prompting and HuggingFace Login
Please note: this is not the best possible prompt, so feel free to engineer better prompts for your specific use case.
system_prompt="""
You are a Q&A assistant. Your goal is to answer questions as
accurately as possible based on the instructions and context provided.
"""
## Default prompt supported by llama2 and 3
query_wrapper_prompt = SimpleInputPrompt("<|USER|>{query_str}<|ASSISTANT|>")
## Access tokens for HuggingFace
!huggingface-cli login
On executing the login command, a prompt appears where we need to paste the Hugging Face access token generated earlier, along with Y/n for whether to add the token as a git credential.
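If you prefer to avoid the interactive prompt, you can also log in programmatically with the huggingface_hub library; the token string below is a placeholder you would replace (ideally read from an environment variable or a Colab secret rather than hard-coded):
from huggingface_hub import login
login(token="hf_xxx_your_token_here")  ## placeholder; use your own access token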
Instantiating the model with Parameters and Embedding
llm = HuggingFaceLLM(
context_window=4096,
max_new_tokens=256,
generate_kwargs={"temperature": 0.0, "do_sample": False},
system_prompt=system_prompt,
query_wrapper_prompt=query_wrapper_prompt,
tokenizer_name="meta-llama/Meta-Llama-3-8B-Instruct",
model_name="meta-llama/Meta-Llama-3-8B-Instruct",
device_map="auto",
##maps it to the correct device (TPU, GPU etc.)
##loading model in 8bit for reducing memory
model_kwargs={"torch_dtype": torch.float16 ,"load_in_8bit":True}
)
This piece of code instantiates the LLM; you can adjust these parameters to suit your requirements. We are using the Meta-Llama-3-8B-Instruct model for this example.
Embedding is when words and sentences are turned into numbers to help computers understand them better. Similar words are put near each other in a vector space so that computers can see how they relate to each other. It's like giving each word its own code so the computer can understand the connections between them and learn the language in the process.
Here, we will be using the open-source "sentence-transformers/all-mpnet-base-v2" as our embedding model. You can find more embedding models by searching for sentence-transformers models on Hugging Face.
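As a quick illustration of what the embedding model does (separate from the pipeline itself, and using the same sentence-transformers model named above), you can embed two related sentences and compare them:
from sentence_transformers import SentenceTransformer, util

st_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
vectors = st_model.encode(["The striker hit the ball twice.",
                           "The batter struck the ball a second time."])
print(vectors.shape)                         ## two 768-dimensional vectors
print(util.cos_sim(vectors[0], vectors[1]))  ## similar sentences -> high cosine similarity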
## Embedding model
embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2"))
Consolidating Documents and Models
The ServiceContext consolidates the LLM and the embedding model for text processing. The chunk size specifies how much text is processed as a single chunk at a time; it is an important parameter for handling large documents efficiently.
service_context = ServiceContext.from_defaults(
chunk_size=1024,
llm=llm,
embed_model=embed_model
)
The service_context carries the models and settings used to process documents and queries. Documents are broken down into chunks, which are converted into high-dimensional vectors (embeddings) and stored in a VectorStoreIndex for efficient searching. When a query is processed, it is also converted into an embedding, and the index is searched to find the chunks that best match it.
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
The variable index now holds the vector index built over all the document chunks, ready to be searched.
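If you are curious about what the index actually returns for a query, you can peek at the retrieved chunks directly; as_retriever() and similarity_top_k are standard llama-index options, and the query string here is just an example:
## Optional: inspect which chunks the index retrieves for a sample query
retriever = index.as_retriever(similarity_top_k=2)
for node_with_score in retriever.retrieve("hitting the ball twice"):
    print(round(node_with_score.score, 3), node_with_score.node.get_content()[:120])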
Querying and Responses
The query engine acts like a taskmaster that manages and coordinates the embedding, matching, indexing, and searching behind the scenes. It takes your query, calls the necessary functions to perform all these steps, and passes the best-matching chunks together with the query to the LLM, which generates the text the chatbot returns as its response.
query_engine = index.as_query_engine()
Finally, we pass our query to the query engine; it runs all the processes described above and returns a text response, which we store in the response variable.
response = query_engine.query("What are the rules regarding hitting the ball twice?")
Now, we can print the response variable as the final output or chatbot response.
print(response)
The output generated by the bot to our question, "What are the rules regarding hitting the ball twice?", is as follows:
Output:
According to Law 34 of the Laws of Cricket 2017 Code (3rd Edition - 2022), a batsman is out Hit the ball twice if, while the ball is in play, it strikes any part of his/her person or is struck by his/her bat and, before the ball has been touched by a fielder, the striker wilfully strikes it again with his/her bat or person, other than a hand not holding the bat, except for the sole purpose of guarding his/her wicket. (Law 34.1.1) The striker will not be out under this Law if he/she strikes the ball a second or subsequent time in order to return the ball to any fielder. (Law 34.2.1) However, if the striker wilfully strikes the ball after it has touched a fielder, he/she will be out. (Law 34.2.2) The striker may lawfully strike the ball a second or subsequent time with the bat, or with any part of his/her person other than a hand not holding the bat, solely in order to guard his/her wicket. (Law 34.3) If the ball is lawfully struck more than once, the umpire shall call and signal Dead.
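Since source attribution was one of the problems that motivated RAG in the first place, it is worth knowing that the llama-index response object also carries the retrieved chunks it answered from; a quick way to inspect them:
## Show which document chunks the answer was generated from
for source_node in response.source_nodes:
    print(round(source_node.score, 3), source_node.node.get_content()[:150])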
If you've followed along this far, you should now have an understanding of what a Retrieval-Augmented Generation (RAG) model is, why it's beneficial, and how to build a simple one in under 50 lines of code.
Additionally, you could extend this by incorporating multiple documents or wrapping it in a Streamlit interface so it can be used as a simple web app; a minimal sketch of the latter is shown below.
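As a rough idea of what that Streamlit wrapper could look like (this is only a sketch, not a production setup: it assumes the query_engine built above is recreated inside the script, ideally cached with st.cache_resource so the model loads only once):
import streamlit as st

## Rebuild the index and query engine here exactly as in the notebook above,
## wrapped in a function decorated with @st.cache_resource so it runs once.

st.title("Ask the MCC Laws of Cricket")
question = st.text_input("Your question")
if question:
    st.write(str(query_engine.query(question)))  ## query_engine built as above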
If you have any questions or feedback, I'd be happy to hear them.
Please see the Github Repository and the Google Colaboratory notebook.
Signing Off,
Yash