# darknet_diaries_llm/main.py — chat with GPT-4 over Darknet Diaries podcast
# transcripts using a persisted llama_index vector index.
from llama_index import (SimpleDirectoryReader, ServiceContext, StorageContext, PromptTemplate,
load_index_from_storage, Document, set_global_service_context)
from llama_index.node_parser import SimpleNodeParser
from llama_index import VectorStoreIndex
from llama_index.llms import OpenAI, ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate
import os
import re
# Deterministic GPT-4 (temperature=0), completions capped at 256 tokens.
llm = OpenAI(
    model="gpt-4",
    temperature=0,
    max_tokens=256,
)

# Register globally so every index/query operation below picks up this LLM.
service_context = ServiceContext.from_defaults(llm=llm)
set_global_service_context(service_context)
# First run: build the vector index from ./data and persist it; the empty
# "lock" sentinel file marks the persisted index as complete. Later runs
# just reload the persisted index.
if not os.path.exists("./index/lock"):
    print("Generating index...")
    # Wrap each transcript in a Document carrying its episode number so
    # retrieved chunks can be traced back to an episode.
    # BUG FIX: the original built Documents in this loop but never appended
    # them, then clobbered `documents` with SimpleDirectoryReader output,
    # discarding the episode_number metadata entirely.
    documents = []
    for filename in os.listdir("./data"):
        match = re.search(r'\d+', filename)
        if match is None:
            # Filename carries no episode number; skip rather than crash
            # on `.group()` of None.
            continue
        with open("./data/" + filename, 'r') as f:
            content = f.read()
        documents.append(Document(
            text=content,
            metadata={
                "episode_number": match.group()
            }
        ))
    parser = SimpleNodeParser.from_defaults()
    nodes = parser.get_nodes_from_documents(documents)
    index = VectorStoreIndex(nodes, show_progress=True)
    index.storage_context.persist(persist_dir="./index")
    # Create the sentinel only after a successful persist.
    open("./index/lock", 'a').close()
else:
    print("Loading index...")
    storage_context = StorageContext.from_defaults(persist_dir="./index")
    index = load_index_from_storage(storage_context)
# Prompt that frames the user's question together with the chat history.
# NOTE(review): custom_prompt is defined but never passed to the chat engine
# below — presumably intended for a condense-question mode; confirm intent.
custom_prompt = PromptTemplate(
    "----------------\nChat history: {chat_history}\n----------------\nPlease answer this question by referring to the podcast: {question}"
)
# System + user messages for the initial question-answering pass.
# BUG FIX: the original's adjacent string literals were concatenated with no
# separating spaces, producing run-together prompt text such as
# "2023.You are", "138.Always" and "knowledge,answer" — spaces added.
chat_text_qa_msgs = [
    ChatMessage(
        role=MessageRole.SYSTEM,
        content=(
            "You have been trained on the Darknet Diaries podcast transcripts with data from october 6 2023. "
            "You are an expert about it and will answer as such. You know about every episode up to number 138. "
            "Always answer the question, even if the context isn't helpful."
        )
    ),
    ChatMessage(
        role=MessageRole.USER,
        content=(
            "Context information is below.\n"
            "---------------------\n"
            "{context_str}\n"
            "---------------------\n"
            "Given the context information and not prior knowledge, "
            "answer the question: {query_str}\n"
        )
    )
]
text_qa_template = ChatPromptTemplate(chat_text_qa_msgs)
# Messages used when llama_index refines a draft answer with extra context.
_refine_system = ChatMessage(
    role=MessageRole.SYSTEM,
    content="Always answer the question, even if the context isn't helpful.",
)
_refine_user = ChatMessage(
    role=MessageRole.USER,
    content=(
        "We have the opportunity to refine the original answer "
        "(only if needed) with some more context below.\n"
        "------------\n"
        "{context_msg}\n"
        "------------\n"
        "Given the new context, refine the original answer to better "
        "answer the question: {query_str}. "
        "If the context isn't useful, output the original answer again.\n"
        "Original Answer: {existing_answer}"
    ),
)
chat_refine_msgs = [_refine_system, _refine_user]
refine_template = ChatPromptTemplate(chat_refine_msgs)

# Chat engine over the index, wired to the custom QA and refine prompts.
chat_engine = index.as_chat_engine(
    text_qa_template=text_qa_template,
    refine_template=refine_template,
)
# Interactive loop: read a prompt, stream the answer token by token.
# Exit with Ctrl-C (KeyboardInterrupt) or Ctrl-D (EOFError — the original
# let EOFError escape as an ugly traceback).
while True:
    try:
        user_prompt = input("Prompt: ")
        streaming_response = chat_engine.stream_chat(user_prompt)
        for token in streaming_response.response_gen:
            # flush=True so tokens appear as they stream instead of
            # waiting for stdout's buffer to fill.
            print(token, end="", flush=True)
        print("\n")
    except (KeyboardInterrupt, EOFError):
        break