darknet_diaries_llm/main.py

114 lines
4.2 KiB
Python
Raw Normal View History

2023-10-07 00:57:45 +02:00
from llama_index import (ServiceContext, StorageContext,
                         load_index_from_storage, Document, set_global_service_context)
from llama_index.node_parser import SimpleNodeParser
from llama_index import VectorStoreIndex
from llama_index.llms import OpenAI, ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate
# from llama_index import set_global_handler
from llama_index.chat_engine.types import ChatMode
from dotenv import load_dotenv
import os
import re

# set_global_handler("simple")

# Load configuration from .env.  The API key must come from the environment:
# never hard-code secrets in source (the previous revision shipped a live
# OpenAI key; it should be revoked).
# NOTE(review): the env var is named OPEN_API_KEY in .env -- looks like a typo
# of the conventional OPENAI_API_KEY; confirm against the .env file.
load_dotenv()
OPEN_API_KEY = os.getenv('OPEN_API_KEY')

# Configure the LLM used for both index construction and the chat engine,
# and install it as the global default service context.
llm = OpenAI(model="gpt-4", temperature=0, max_tokens=256, api_key=OPEN_API_KEY)

service_context = ServiceContext.from_defaults(llm=llm)
set_global_service_context(service_context)
2023-10-07 16:55:18 +02:00
def _build_documents(transcript_dir: str = "./transcripts") -> list:
    """Read every transcript file into a Document carrying episode metadata.

    Each transcript is expected to hold the episode title on line 1 and the
    download count on line 2, followed by the transcript body.  Files whose
    name contains no digits are skipped instead of crashing (the original
    ``re.search(...).group()`` raised AttributeError on such names).
    """
    documents = []
    for filename in os.listdir(transcript_dir):
        match = re.search(r'\d+', filename)
        if match is None:
            continue  # no episode number in the file name -> not a transcript
        episode_number = match.group()
        # Explicit UTF-8: transcripts are web text, do not rely on the
        # platform default encoding.
        with open(os.path.join(transcript_dir, filename), 'r', encoding='utf-8') as f:
            title = f.readline().strip()
            downloads = f.readline().strip()
            content = f.read()
        documents.append(Document(
            text=content,
            doc_id=filename,
            metadata={
                "episode_number": episode_number,
                "episode_title": title,
                "episode_downloads": downloads,
                "episode_url": f"https://darknetdiaries.com/episode/{episode_number}/"
            }
        ))
    return documents


def _get_index(index_dir: str = "./index"):
    """Build and persist the vector index on first run, load it afterwards.

    A ``lock`` marker file inside *index_dir* records that the index has
    already been built, so subsequent runs load from disk instead of
    re-embedding every transcript.
    """
    lock_path = os.path.join(index_dir, "lock")
    if os.path.exists(lock_path):
        print("Loading index...")
        storage_context = StorageContext.from_defaults(persist_dir=index_dir)
        return load_index_from_storage(storage_context)

    documents = _build_documents()
    parser = SimpleNodeParser.from_defaults()
    nodes = parser.get_nodes_from_documents(documents)
    index = VectorStoreIndex(nodes, show_progress=True)
    index.storage_context.persist(persist_dir=index_dir)
    open(lock_path, 'a').close()  # marker: index is persisted
    return index


def _make_chat_engine(index):
    """Wire the QA and refine prompt templates into an OpenAI chat engine."""
    chat_text_qa_msgs = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content=(
                "You have been trained on the Darknet Diaries podcast transcripts with data from october 6 2023."
                "You are an expert about it and will answer as such. You know about every episode up to number 138."
                "Always answer the question, even if the context isn't helpful."
                "Mention the number and title of the episodes you are referring to."
            )
        ),
        ChatMessage(
            role=MessageRole.USER,
            content=(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given the context information and not prior knowledge,"
                "answer the question: {query_str}\n"
            )
        )
    ]
    text_qa_template = ChatPromptTemplate(chat_text_qa_msgs)

    chat_refine_msgs = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content="Always answer the question, even if the context isn't helpful.",
        ),
        ChatMessage(
            role=MessageRole.USER,
            content=(
                "We have the opportunity to refine the original answer "
                "(only if needed) with some more context below.\n"
                "------------\n"
                "{context_msg}\n"
                "------------\n"
                "Given the new context, refine the original answer to better "
                "answer the question: {query_str}. "
                "If the context isn't useful, output the original answer again.\n"
                "Original Answer: {existing_answer}"
            ),
        ),
    ]
    refine_template = ChatPromptTemplate(chat_refine_msgs)

    return index.as_chat_engine(
        text_qa_template=text_qa_template,
        refine_template=refine_template,
        chat_mode=ChatMode.OPENAI
    )


if __name__ == '__main__':
    chat_engine = _make_chat_engine(_get_index())
    # Run the interactive REPL until the user exits with Ctrl-C.
    while True:
        try:
            chat_engine.chat_repl()
        except KeyboardInterrupt:
            break