"""Streamlit page for browsing per-speech sentiment scores and SHAP plots.

For each company and quarter a ``DataSHAP`` object is loaded from disk and
kept in ``st.session_state``. The page shows its data frame, statistics,
the SHAP text plot, a waterfall plot and a term-rank bar plot.
"""
import streamlit as st
import streamlit.components.v1 as components
import shap
from datashap import DataSHAP as ds
import matplotlib.pyplot as plt
import warnings

warnings.simplefilter("ignore", category=DeprecationWarning)


def select_fala():
    """Sync the speech number and label radio with the selected table row.

    Used as the ``on_select`` callback of the ``df_falas`` data frame.
    When a row is selected, ``num_fala`` is set to that row's position.
    The ``rotulo`` radio is then set to the tag of the current speech.
    """
    selected_rows = st.session_state.df_falas['selection']['rows']
    if selected_rows:
        st.session_state.num_fala = selected_rows[0]

    num = st.session_state.num_fala
    df = st.session_state[st.session_state.empresa][st.session_state.trimestre - 1].df
    rotulo = df.iloc[num]['tag']
    # The data frame stores the translated tag (see get_dataSHAP); the radio
    # widget is keyed by the original upper-case label, so map it back.
    tag_to_label = {
        'Neutro': 'NEUTRAL',
        'Positivo': 'POSITIVE',
        'Negativo': 'NEGATIVE',
    }
    st.session_state.rotulo = tag_to_label[rotulo]


@st.cache_resource
def get_dataSHAP(file, company, trim):
    """Build a ``DataSHAP`` for ``file`` and translate its ``tag`` column.

    The result is cached by Streamlit, so the in-place replacement of the
    ``tag`` values happens once per distinct argument combination.
    """
    shap_value = ds.DataSHAP(file, company, trim)
    shap_value.df['tag'] = shap_value.df['tag'].replace(
        {'POSITIVE': 'Positivo', 'NEGATIVE': 'Negativo', 'NEUTRAL': 'Neutro'}
    )
    return shap_value


def init_session(key, val):
    """Load the four quarterly ``DataSHAP`` objects for company ``key``.

    ``val`` is the folder name used in the file path. Nothing is done when
    ``key`` is already present in ``st.session_state``.
    """
    if key in st.session_state:
        return

    loaded = []
    for i in range(1, 5):
        arquivo = f'spaces/marcossuzuki/TCC_PoliUSPPro/transcrição audio RI/{val}/valores_shap-{key}{i}t24.save'
        loaded.append(get_dataSHAP(arquivo, empresa_dict[key], i))
    # Store only after every quarter loaded: a failure midway must not leave
    # a short list behind that later reruns would treat as complete.
    st.session_state[key] = loaded


st.set_page_config(page_title="TCCPoliUSPPro", )

pasta = {'vale': 'VALE', 'petr': 'Petrobras', 'bb': 'BB'}
empresa_dict = {'petr': 'Petrobras', 'vale': 'Vale', 'bb': 'Banco do Brasil'}
option_map = {'NEUTRAL': 'Neutro', 'POSITIVE': 'Positivo', 'NEGATIVE': 'Negativo', }
shap_values = {}
title_score = ['positive_score', 'negative_score', 'neutral_score']

for key, val in pasta.items():
    init_session(key, val)
    shap_values[key] = st.session_state[key]

st.header("Sentimento da fala e Scores")

col1, col2, col3, col4 = st.columns([1.7, 1.2, 1.2, 2], gap="small", vertical_alignment="bottom")

empresa = col1.selectbox(
    "**Qual empresa quer analisar:**",
    ("vale", "bb", "petr"),
    format_func=lambda option: empresa_dict[option],
    key='empresa',
)

trim = col2.number_input("**Trimestre de 2024:**", 1, max_value=4, key='trimestre')

# Everything below refers to this company/quarter selection.
current = shap_values[empresa][trim - 1]

# The speech index persists in session state across company/quarter changes.
# Clamp it before the widget is created so it is never out of range for the
# newly selected data.
last_fala = len(current.shap_value) - 1
if 'num_fala' in st.session_state and st.session_state.num_fala > last_fala:
    st.session_state.num_fala = last_fala

text_num = col3.number_input(
    "**Fala número:**",
    0,
    max_value=last_fala,
    key='num_fala',
)

df = current.df

total_tokens, h, m, s = current.get_performance()
# Two trailing spaces before "\n" give a Markdown hard line break.
# ".2f" is used for the seconds: a bare ".2" means two significant digits and
# switches to scientific notation from 10 upwards (e.g. 12.3 -> "1.2e+01").
# Assumes s is a float -- the bare ".2" already rejects ints.
col4.write(
    f"**Total tokens:** {total_tokens}  \n"
    f"**Compute time:** {h}h {m}m {s:.2f}s"
)

tab1, tab2, tab3 = st.tabs(["**Data Frame**", "**Estatística Score**", '**Gráfico Estatística**'])

with tab1:
    st.dataframe(
        df.style.highlight_max(axis=1, color='lightgreen', subset=title_score),
        selection_mode='single-row',
        key='df_falas',
        on_select=select_fala,
        column_config={
            'speech': st.column_config.Column('Fala', width=100),
            'qty_tokens': st.column_config.NumberColumn("Qtde. Tokens", format='%d'),
            'positive_score': st.column_config.NumberColumn("Score Positivo",),
            'negative_score': st.column_config.NumberColumn("Score Negativo",),
            'neutral_score': st.column_config.NumberColumn("Score Neutro",),
            'tag': "Rótulo",
        },
        height=200,
    )

with tab2:
    st.dataframe(current.statistic, )

with tab3:
    st.plotly_chart(current.plot)

score_positive, score_negative, score_neutral = df.loc[text_num, title_score]

# Captions follow the key order of option_map: NEUTRAL, POSITIVE, NEGATIVE.
rotulo = st.radio(
    "**Rótulo**",
    option_map.keys(),
    horizontal=True,
    format_func=lambda option: option_map[option],
    captions=[f'{score_neutral:.4}', f'{score_positive:.4}', f'{score_negative:.4}'],
    key='rotulo'
)

plot_text = current.shap_plot_text(text_num, rotulo)
components.html(plot_text, height=180, scrolling=True)

st.header("Gráfico waterfall dos termos e Valores de Shapley")
with st.expander("Expand"):
    n_tokens = int(df['qty_tokens'][text_num])
    max_display = st.slider(
        "**Máximo de exibição:**",
        1,
        max_value=n_tokens,
        value=n_tokens // 3 + 1,  # default: about a third of the tokens
    )
    plot_waterfall = current.shap_waterfall(text_num, rotulo, max_display)
    st.pyplot(plot_waterfall)

st.header('Rank de termos do documento em Gráfico Barra')
with st.expander("Expand"):
    plot_bar, _ax, rank = current.get_plot_rank()
    # One bar plot per label, titled with its translated name.
    for label, label_pt in option_map.items():
        st.subheader(label_pt)
        st.pyplot(plot_bar[label])
    st.dataframe(rank)