74 lines
3.5 KiB
Python
74 lines
3.5 KiB
Python
import streamlit as st
|
|
from typing import Set
|
|
import requests
|
|
import json
|
|
import yaml
|
|
import st_auth
|
|
import boto3
|
|
from botocore.exceptions import ClientError
|
|
import jwt
|
|
# HTTP headers of the request that opened this Streamlit session; the
# 'x-amzn-oidc-data' entry is read from it further down to identify the user.
headers = st.context.headers
|
|
import streamlit as st
|
|
|
|
# --- Page setup: language choice, user identification, session state ---------

# Languages the user can pick for the assistant's answer.
options = ["English", "Portugues", "Espanhol",]
language = st.selectbox("Response language:", options)

# The user's claims arrive as a JWT in the 'x-amzn-oidc-data' request header
# (presumably injected by an AWS load balancer doing OIDC auth -- confirm).
id_token = headers.get('x-amzn-oidc-data')

decoded = None
if id_token:
    try:
        # SECURITY NOTE(review): the signature is NOT verified here, so the
        # claims are only trustworthy if nothing but the authenticating proxy
        # can reach this app. Consider verifying against the issuer's key.
        decoded = jwt.decode(id_token, options={"verify_signature": False})
    except jwt.InvalidTokenError:
        decoded = None

if decoded is None:
    # Without a decodable token there is no user identity; stop this run
    # with a readable message instead of crashing with a traceback.
    st.error("Token de autenticação ausente ou inválido.")
    st.stop()

# Try the different claims where the user id may be stored.
user_id = (
    decoded.get("sub")                  # Subject (standard JWT claim)
    or decoded.get("cognito:username")  # Cognito username
    or decoded.get("username")          # Alternative username claim
    or decoded.get("user_id")           # Custom claim
)
email = decoded.get("email")

st.header("Assistente Produtos Servicos")

# Backend endpoint that receives the chat request.
url = "https://xexm2wsz07-vpce-05915540d0592b921.execute-api.us-east-1.amazonaws.com/dev"
payload = []
message_history = []

# Per-session conversation state; created once and kept across script reruns.
for history_key in ("user_prompt_history", "chat_answer_history", "chat_history"):
    if history_key not in st.session_state:
        st.session_state[history_key] = []

prompt = st.chat_input(placeholder="Digite uma mensagem...", key="prompt")

# Replay the conversation so far (the script reruns on every interaction).
for past_answer, past_query in zip(
    st.session_state["chat_answer_history"],
    st.session_state["user_prompt_history"],
):
    st.chat_message("user").write(past_query)
    st.chat_message("assistant").write(past_answer)
|
|
def create_sources_string(source_urls: Set[str]) -> str:
    """Format a collection of source URLs as a numbered, sorted listing.

    Args:
        source_urls: The sources to list. An empty (or otherwise falsy)
            value yields an empty string.

    Returns:
        ``""`` when there are no sources; otherwise the header line
        ``"source:"`` followed by one ``"<n>, <source>"`` line per source,
        in ascending order and numbered from 1. Every line, including the
        last, ends with a newline.
    """
    if not source_urls:
        return ""

    numbered_lines = [
        f"{position}, {source}\n"
        for position, source in enumerate(sorted(source_urls), start=1)
    ]
    # Join once instead of growing the string with += inside a loop.
    return "source:\n" + "".join(numbered_lines)
|
|
# Handle a newly submitted chat message: send it to the backend, show the
# reply, and record the exchange in the session history.
# NOTE(review): the nesting below was reconstructed from text whose
# indentation had been lost -- confirm it matches the intended structure.
if prompt:
    st.chat_message("user").write(prompt)

    with st.spinner("Generating response.."):
        # Reply shown when the backend could not answer in time.
        timeout_notice = "Falta de dados: Por favor encaminhe a conversa para a equipe desenvolvedora"

        payload = [{"role": "user", "content": prompt}]
        content = {
            "message": payload,
            "chat_history": st.session_state["chat_history"],
            "username": user_id,
            "origem": "Front",
            "email": email,
            "language": language,
        }
        # Deliberately not named `headers`: that module-level name holds the
        # incoming request headers and must not be overwritten.
        request_headers = {
            "Content-type": "application/json",
            "x-api-key": json.loads(st_auth.get_secret())['api-gateway-api-key'],
        }

        try:
            # (connect, read) timeout so a stalled backend cannot hang the
            # session forever; the read limit is generous on purpose.
            response = requests.post(
                url, json=content, headers=request_headers, timeout=(10, 120)
            )
        except requests.exceptions.Timeout:
            generated_response = timeout_notice
        else:
            generated_response = json.loads(response.text)

        if isinstance(generated_response, dict):
            # Seed the session history from the backend only when the session
            # has none yet.
            if 'chat_history' in generated_response and not st.session_state["chat_history"]:
                st.session_state["chat_history"] = generated_response['chat_history']
            # Unwrap the answer when the backend nests it under 'json'.
            if 'json' in generated_response:
                generated_response = generated_response['json']

        # Only a dict can carry the timeout marker. Guarding on the type avoids
        # a TypeError when the unwrapped answer is plain text that happens to
        # contain the word "message".
        if (
            isinstance(generated_response, dict)
            and generated_response.get('message') == "Endpoint request timed out"
        ):
            generated_response = timeout_notice

        # Source attribution (create_sources_string) is currently not applied;
        # the backend answer is displayed as received.
        formatted_response = generated_response

        st.chat_message("assistant").write(formatted_response)

        st.session_state["user_prompt_history"].append(prompt)
        st.session_state["chat_answer_history"].append(formatted_response)
        # Rebind to a new list rather than mutating in place, so a history
        # list taken from the response object is never modified.
        st.session_state["chat_history"] = [
            *st.session_state["chat_history"],
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": formatted_response},
        ]
        st.session_state.user_input = ""
|
|
|