Improved chat mode

This commit is contained in:
Romain Quinet 2023-10-07 08:48:08 +02:00
parent bf3fd878ac
commit 1a46ea4816

13
main.py
View File

@ -4,9 +4,13 @@ from llama_index.node_parser import SimpleNodeParser
from llama_index import VectorStoreIndex from llama_index import VectorStoreIndex
from llama_index.llms import OpenAI, ChatMessage, MessageRole from llama_index.llms import OpenAI, ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate from llama_index.prompts import ChatPromptTemplate
from llama_index import set_global_handler
from llama_index.chat_engine.types import ChatMode
import os import os
import re import re
#set_global_handler("simple")
llm = OpenAI(model="gpt-4", temperature=0, max_tokens=256) llm = OpenAI(model="gpt-4", temperature=0, max_tokens=256)
service_context = ServiceContext.from_defaults(llm=llm) service_context = ServiceContext.from_defaults(llm=llm)
set_global_service_context(service_context) set_global_service_context(service_context)
@ -90,16 +94,13 @@ refine_template = ChatPromptTemplate(chat_refine_msgs)
chat_engine = index.as_chat_engine( chat_engine = index.as_chat_engine(
text_qa_template=text_qa_template, text_qa_template=text_qa_template,
refine_template=refine_template refine_template=refine_template,
chat_mode=ChatMode.OPENAI
) )
while True: while True:
try: try:
user_prompt = input("Prompt: ") chat_engine.chat_repl()
streaming_response = chat_engine.stream_chat(user_prompt)
for token in streaming_response.response_gen:
print(token, end="")
print("\n")
except KeyboardInterrupt: except KeyboardInterrupt:
break break