From 3d734a306406291503e4da67c6509fefbd22568c Mon Sep 17 00:00:00 2001
From: Romain Quinet
Date: Fri, 6 Oct 2023 23:22:10 +0200
Subject: [PATCH] Chat mode

---
 main.py | 28 ++++++++++++++++++++--------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/main.py b/main.py
index ad794a8..dfc8c39 100644
--- a/main.py
+++ b/main.py
@@ -2,7 +2,9 @@ from llama_index import (SimpleDirectoryReader, ServiceContext, StorageContext,
                          load_index_from_storage, Document, set_global_service_context)
 from llama_index.node_parser import SimpleNodeParser
 from llama_index import VectorStoreIndex
-from llama_index.llms import OpenAI
+from llama_index.llms import OpenAI, ChatMessage, MessageRole
+from llama_index.prompts import PromptTemplate
+from llama_index.chat_engine.condense_question import CondenseQuestionChatEngine
 import os
 import re
 
@@ -36,22 +38,32 @@ else:
     storage_context = StorageContext.from_defaults(persist_dir="./index")
     index = load_index_from_storage(storage_context)
 
-template = (
+custom_prompt = PromptTemplate(
     "You have been trained on the Darknet Diaries podcast transcripts with data from october 6 2023."
     "You are now an expert about it and will answer as such. You know about every episode up to number 138. \n"
     "----------------\n"
-    "Here is the context: {context_str}"
+    "Chat history: {chat_history}\n"
     "----------------\n"
-    "Please answer this question by referring to the podcast: {query_str}"
+    "Please answer this question by referring to the podcast: {question}"
 )
-qa_template = PromptTemplate(template)
-query_engine = index.as_query_engine(text_qa_template=qa_template)
+
+custom_chat_history = []
+query_engine = index.as_query_engine()
+chat_engine = CondenseQuestionChatEngine.from_defaults(
+    query_engine=query_engine,
+    condense_question_prompt=custom_prompt,
+    chat_history=custom_chat_history,
+    verbose=True
+)
+
 while True:
     try:
         user_prompt = input("Prompt: ")
-        response = query_engine.query(user_prompt)
-        print(response)
+        streaming_response = chat_engine.stream_chat(user_prompt)
+        for token in streaming_response.response_gen:
+            print(token, end="")
+        print("\n")
     except KeyboardInterrupt:
         break
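
Note on the new imports: ChatMessage and MessageRole are added to the import line, but custom_chat_history starts out empty, so nothing in this diff uses them yet. A minimal sketch of how the history could be pre-seeded with them follows; the example messages are purely illustrative and not part of the commit, and this assumes the llama_index 0.8-era API the rest of the patch targets (CondenseQuestionChatEngine.from_defaults accepts a list of ChatMessage objects as chat_history).

    from llama_index.llms import ChatMessage, MessageRole

    # Illustrative only: seed prior turns so the condense-question step
    # has conversation context from the very first user prompt.
    custom_chat_history = [
        ChatMessage(
            role=MessageRole.USER,
            content="Stick to topics covered in Darknet Diaries episodes.",
        ),
        ChatMessage(
            role=MessageRole.ASSISTANT,
            content="Understood, I will only reference the podcast episodes.",
        ),
    ]

At runtime the chat engine records each user prompt and its answer in its own history, so follow-up questions get condensed into standalone queries against the earlier turns; pre-seeding just gives that first condensation something to work with.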