Chat mode

Romain Quinet 2023-10-06 23:22:10 +02:00
parent b41d8288b7
commit 3d734a3064

main.py (28 changed lines)

@@ -2,7 +2,9 @@ from llama_index import (SimpleDirectoryReader, ServiceContext, StorageContext,
     load_index_from_storage, Document, set_global_service_context)
 from llama_index.node_parser import SimpleNodeParser
 from llama_index import VectorStoreIndex
-from llama_index.llms import OpenAI
+from llama_index.llms import OpenAI, ChatMessage, MessageRole
+from llama_index.prompts import PromptTemplate
+from llama_index.chat_engine.condense_question import CondenseQuestionChatEngine
 import os
 import re
@@ -36,22 +38,32 @@ else:
     storage_context = StorageContext.from_defaults(persist_dir="./index")
     index = load_index_from_storage(storage_context)
 
-template = (
+custom_prompt = PromptTemplate(
     "You have been trained on the Darknet Diaries podcast transcripts with data from october 6 2023."
     "You are now an expert about it and will answer as such. You know about every episode up to number 138. \n"
     "----------------\n"
-    "Here is the context: {context_str}"
+    "Chat history: {chat_history}\n"
     "----------------\n"
-    "Please answer this question by referring to the podcast: {query_str}"
+    "Please answer this question by referring to the podcast: {question}"
 )
-qa_template = PromptTemplate(template)
-query_engine = index.as_query_engine(text_qa_template=qa_template)
+
+custom_chat_history = []
+query_engine = index.as_query_engine()
+chat_engine = CondenseQuestionChatEngine.from_defaults(
+    query_engine=query_engine,
+    condense_question_prompt=custom_prompt,
+    chat_history=custom_chat_history,
+    verbose=True
+)
 
 while True:
     try:
         user_prompt = input("Prompt: ")
-        response = query_engine.query(user_prompt)
-        print(response)
+        streaming_response = chat_engine.stream_chat(user_prompt)
+        for token in streaming_response.response_gen:
+            print(token, end="")
+        print("\n")
     except KeyboardInterrupt:
         break
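
Note: the new import line brings in ChatMessage and MessageRole, but custom_chat_history starts as an empty list, so they go unused in this commit. A minimal sketch of how the history could be pre-seeded with llama_index's chat message types (the message contents below are hypothetical, not part of the commit):

    from llama_index.llms import ChatMessage, MessageRole

    # Hypothetical seed messages; the commit itself initialises an empty list.
    custom_chat_history = [
        ChatMessage(role=MessageRole.USER,
                    content="What is the podcast about?"),
        ChatMessage(role=MessageRole.ASSISTANT,
                    content="Darknet Diaries tells true stories from the dark side of the internet."),
    ]

CondenseQuestionChatEngine uses this history together with custom_prompt to rewrite each new user message into a standalone question before passing it to the underlying query engine, which is why the template's placeholders change from {context_str} and {query_str} to {chat_history} and {question}.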