Library
The differences in how each service is called are absorbed by a library called LangChain. Grok's compatibility with the OpenAI and Anthropic APIs, however, ends up backfiring here (see the sketch after the listing).
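To make that unified interface concrete before the full listing, here is a minimal sketch. It assumes the relevant API keys (OPENAI_API_KEY or GROQ_API_KEY) are set as environment variables; the model names are the same ones used in the configuration table below.

import os
from langchain.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq

prompt = ChatPromptTemplate.from_messages([("human", "{input}")])

# The chain definition is identical no matter which provider backs it;
# only the chat-model class changes.
llm = ChatOpenAI(model="gpt-4")  # or: ChatGroq(model="llama3-70b-8192")
chain = prompt | llm
print(chain.invoke({"input": "Hello"}).content)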
# Code after refactoring by Grok
import sys
import os
import csv
import time
import importlib
from datetime import datetime
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder
)
from langchain.schema.runnable import (
    RunnableLambda,
    RunnableSequence
)
from langchain.schema import AIMessage, HumanMessage
import pyperclip
# AI Assistants configuration
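# Each entry maps an assistant name to its LangChain integration module,
# chat-model class, and default model; load_assistant() imports the module
# dynamically with importlib, so adding a provider only needs a new entry here.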
AI_ASSISTANTS = {
    'ChatGPT': {'module': 'langchain_openai', 'class': 'ChatOpenAI', 'model': 'gpt-4'},
    'Gemini': {'module': 'langchain_google_genai', 'class': 'ChatGoogleGenerativeAI', 'model': 'gemini-pro'},
    'Groq': {'module': 'langchain_groq', 'class': 'ChatGroq', 'model': 'llama3-70b-8192'},
    'Mistral': {'module': 'langchain_mistralai', 'class': 'ChatMistralAI', 'model': 'open-mixtral-8x22b'},
    'Together': {'module': 'langchain_together', 'class': 'ChatTogether', 'model': 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'}
}
MAX_HISTORY_LENGTH = 5
# Q&A
def put_user_question(user_question, ai_assistant, conversation, memory):
    start_time = time.time()
    try:
        response = conversation.invoke({"input": user_question, "history": memory.load_memory_variables({})['history']})
        ai_response = response.content if hasattr(response, 'content') else 'No output found'
        memory.save_context({"input": user_question}, {"output": ai_response})
        exec_status = 'Success'
    except Exception as e:
        print(f"An error occurred: {e}")
        return None, 'Failure'
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"\nAssistant ({ai_assistant}):\n{ai_response}")
    pyperclip.copy(ai_response)
    print(f"\n--------------------\nExecution time: {elapsed_time:.3f}s\n--------------------")
    return ai_response, exec_status
# Write results to the log
def log_execution(ai_assistant, model_name, id_num, status, exec_time, run_time, log_filename="AI_Assistant_Log.csv"):
    with open(log_filename, 'a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([ai_assistant, model_name, id_num, status, exec_time, run_time])
# Helper that wraps the windowed memory for use in a chain
def get_memory(input):
    history = memory.load_memory_variables({})['history']
    return {"input": input["input"], "history": history}

# Helper that records one question/answer turn in memory
def update_memory(user_input, ai_response):
    memory.save_context({"input": user_input}, {"output": ai_response})
    return {"input": user_input, "history": memory.load_memory_variables({})['history']}
def parse_arguments():
    if len(sys.argv) == 1:
        return "Groq", AI_ASSISTANTS["Groq"]['model']
    elif len(sys.argv) == 2:
        return sys.argv[1], AI_ASSISTANTS[sys.argv[1]]['model']
    else:
        return sys.argv[1], sys.argv[2]
def load_assistant(ai_assistant, model_name):
    module_name = AI_ASSISTANTS[ai_assistant]['module']
    class_name = AI_ASSISTANTS[ai_assistant]['class']
    module = importlib.import_module(module_name)
    AssistantClass = getattr(module, class_name)
    return AssistantClass(model=model_name)
def create_prompt(ai_assistant):
    # Gemini's chat integration did not accept system messages at the time,
    # so it gets a prompt without one
    if ai_assistant != 'Gemini':
        with open("system_message.txt", "r", encoding="utf-8") as file:
            system_content = file.read()
        return ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(system_content),
            MessagesPlaceholder(variable_name="history"),
            HumanMessagePromptTemplate.from_template("{input}")
        ])
    return ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="history"),
        HumanMessagePromptTemplate.from_template("{input}")
    ])
def main():
    global memory  # not ideal, but kept to preserve the original structure
    ai_assistant, model_name = parse_arguments()
    print("AI Assistant:", ai_assistant)
    print("Model name:", model_name)
    log_filename = 'AI_Assistant_Log.csv'
    if not os.path.exists(log_filename):
        with open(log_filename, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['AI Assistant', 'Model', 'Question', 'ExecStatus', 'ExecTime', 'RunTime'])
    llm = load_assistant(ai_assistant, model_name)
    prompt = create_prompt(ai_assistant)
    memory = ConversationBufferWindowMemory(k=MAX_HISTORY_LENGTH, return_messages=True)
    conversation = RunnableSequence(
        RunnableLambda(lambda x: x),  # pass the input dict through unchanged
        prompt,
        llm
    )
    # First turn: greet the model so the session starts with some history
    user_question = "Hello"
    initial_input = {"input": user_question, "history": memory.load_memory_variables({})['history']}
    response = conversation.invoke(initial_input)
    ai_response = response.content if hasattr(response, 'content') else 'No output found'
    memory.save_context({"input": user_question}, {"output": ai_response})
    print(ai_response)
    pyperclip.copy(ai_response)
    id_num = 1
    while True:
        user_question = ""
        print("\n---------------------------\nEnter your question (enter three newlines in a row to finish input): \n")
        while True:
            try:
                line = input()
                user_question += line + "\n"
                if user_question.endswith("\n\n\n"):
                    break
                elif line.strip().lower() in ["goodbye", "bye"]:
                    break
            except EOFError:
                user_question = ""
                break
        if user_question.strip().lower() in ["", "goodbye", "bye"]:
            print("Goodbye")
            pyperclip.copy("Goodbye")
            break
        if user_question:
            start_time = time.time()
            response, exec_status = put_user_question(user_question, ai_assistant, conversation, memory)
            end_time = time.time()
            run_time = round(end_time - start_time, 3)
            exec_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')  # timestamp for the log entry
            # write the log entry
            log_execution(ai_assistant, model_name, id_num, exec_status, exec_time, run_time)
            id_num += 1

if __name__ == "__main__":
    main()
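If the script is saved as, say, assistant.py (the filename here is hypothetical), it can be run as "python assistant.py" (defaults to Groq), "python assistant.py Mistral", or "python assistant.py Together some-model-name", matching parse_arguments() above.

As for Grok itself: xAI exposes an OpenAI-compatible API rather than a dedicated LangChain integration, so there is no clean entry to add to AI_ASSISTANTS. The closest workaround is to reuse ChatOpenAI and point it at xAI's endpoint. Below is a minimal sketch; the base URL, model name, and XAI_API_KEY variable are assumptions, not something taken from the listing. Note also that load_assistant() only passes model=, so this shape does not fit the AI_ASSISTANTS table without extending it, which is exactly where the compatibility backfires.

import os
from langchain_openai import ChatOpenAI

# Assumed endpoint and model name for xAI's OpenAI-compatible API
grok = ChatOpenAI(
    model="grok-beta",                  # assumed model name
    base_url="https://api.x.ai/v1",     # assumed OpenAI-compatible endpoint
    api_key=os.environ["XAI_API_KEY"],  # hypothetical environment variable
)
print(grok.invoke("Hello").content)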