# Forked from phito/darknet_diaries_llm
# Standard library.
import os
import re

# llama_index: document loading, vector indexing, storage, and prompting.
from llama_index import (
    Document,
    PromptTemplate,
    ServiceContext,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
    set_global_service_context,
)
from llama_index.llms import OpenAI
from llama_index.node_parser import SimpleNodeParser

# GPT-4, deterministic output (temperature=0), answers capped at 256 tokens.
llm = OpenAI(model="gpt-4", temperature=0, max_tokens=256)

# Register the LLM in a global service context so every llama_index
# component below picks it up by default.
service_context = ServiceContext.from_defaults(llm=llm)
set_global_service_context(service_context)
|
|
|
|
if not os.path.exists("./index/lock"):
    # First run: build the vector index from the transcripts and persist it.
    print("Generating index...")
    documents = []
    for filename in os.listdir("./data"):
        # Transcript filenames carry the episode number; extract it so each
        # document's metadata can tie answers back to an episode. Filenames
        # without digits get None instead of crashing on .group().
        match = re.search(r'\d+', filename)
        episode_number = match.group() if match else None
        with open(os.path.join("./data", filename), 'r') as f:
            content = f.read()
        # BUG FIX: the original built this Document but never appended it,
        # then overwrote `documents` with SimpleDirectoryReader.load_data(),
        # discarding the episode_number metadata entirely.
        documents.append(Document(
            text=content,
            metadata={
                "episode_number": episode_number
            }
        ))

    # Chunk the documents into nodes and embed them into a vector index.
    parser = SimpleNodeParser.from_defaults()
    nodes = parser.get_nodes_from_documents(documents)

    index = VectorStoreIndex(nodes, show_progress=True)
    index.storage_context.persist(persist_dir="./index")
    # Touch the lock file so subsequent runs load the persisted index.
    open("./index/lock", 'a').close()
else:
    # Lock file present: reuse the index persisted by a previous run.
    print("Loading index...")
    storage_context = StorageContext.from_defaults(persist_dir="./index")
    index = load_index_from_storage(storage_context)
|
|
|
|
# QA prompt: persona framing, then the retrieved context, then the question.
# BUG FIX: the original omitted "\n" after the first sentence and after
# {context_str}, gluing the "----------------" separator onto the context.
template = (
    "You have been trained on the Darknet Diaries podcast transcripts with data from october 6 2023.\n"
    "You are now an expert about it and will answer as such. You know about every episode up to number 138. \n"
    "----------------\n"
    "Here is the context: {context_str}\n"
    "----------------\n"
    "Please answer this question by referring to the podcast: {query_str}"
)
|
|
# Wrap the raw prompt string and hand it to a query engine so retrieved
# context is injected into our custom template instead of the default one.
qa_template = PromptTemplate(template)
query_engine = index.as_query_engine(text_qa_template=qa_template)
|
|
|
|
# Interactive loop: read a question, answer it from the index, repeat.
while True:
    try:
        user_prompt = input("Prompt: ")
        response = query_engine.query(user_prompt)
        print(response)
    except (KeyboardInterrupt, EOFError):
        # BUG FIX: input() raises EOFError on Ctrl-D or exhausted piped
        # stdin; the original only caught KeyboardInterrupt and crashed
        # with a traceback. Exit cleanly on either.
        break