"""CBT chatbot graph (LangGraph).

Stage machine:
  1. ``detect_distortion``   – greet, then scan user messages for a cognitive
     distortion while interviewing for situation / emotion / thought.
  2. ``get_distortion_def``  – fetch the distortion's Polish definition from Neo4j.
  3. ``talk_about_distortion`` – explain the distortion and confirm understanding.

``validate_input`` guards user turns; ``global_router`` dispatches between nodes.
All user-facing text is Polish by design.
"""

from langgraph.graph import StateGraph, START, END
from bielik import llm
from guardian import check_input
from helpful_functions import get_last_user_message, check_situation, beliefs_check_function, introduction_talk, create_interview
from neo4j_driver import driver
from classifier import predict_raw, predict_raw1
from state import ChatState
from prompts import build_system_prompt_introduction_chapter_ellis_distortion


def _llm_text(reply) -> str:
    """Normalize an LLM reply (plain str or message object) to text."""
    if isinstance(reply, str):
        return reply
    return getattr(reply, "content", str(reply))


def _assistant_turn(state: ChatState, system_prompt: str) -> None:
    """Invoke the LLM with a single system prompt and append the answer
    to the conversation as an assistant message (mutates ``state``)."""
    reply = llm.invoke([{"role": "system", "content": system_prompt}])
    state["messages"].append({"role": "assistant", "content": _llm_text(reply)})


def _merge_slot(current: str, new: str) -> str:
    """Merge a newly extracted interview field into the stored one.

    Empty stored value → take the new one; non-empty new value → combine
    both via ``create_interview``; otherwise keep what we have.
    """
    if current == "":
        return new
    if new != "":
        return create_interview(new, current)
    return current


def detect_distortion(state: ChatState):
    """Stage 1: greeting on first contact, then per-turn distortion detection
    plus an ongoing interview filling situation / emotion / thought slots.

    Returns the mutated state; ``awaitingUser``/``stage`` drive the router.
    """
    if not state.get("messages"):
        # First contact: greet and wait for the user's first message.
        print("Siema")
        state["messages"] = [{
            "role": "assistant",
            # Fixed typo: "checią" → "chęcią".
            "content": "Cześć! Cieszę się, że jesteś. Co u ciebie, czy masz jakiś problem? Z chęcią ci pomogę!",
        }]
        state["awaitingUser"] = True
        state["stage"] = "detect_distortion"
        return state

    # Guarded increment: the original `+= 1` raised KeyError when the
    # counter had not been initialized by the caller.
    state["first_stage_iterations"] = state.get("first_stage_iterations", 0) + 1
    print(state["first_stage_iterations"])
    print("Siema1")

    last_message = get_last_user_message(state)
    user_text = (last_message["content"] or "").strip()

    if state.get("distortion") is None:
        # Two-step detection: coarse yes/no classifier, belief check,
        # then the fine-grained distortion classifier.
        coarse_label = predict_raw(user_text)
        if coarse_label != "No Distortion":
            thought = beliefs_check_function(user_text)
            if thought:
                distortion = predict_raw1(user_text)
                print(distortion)
                state["distortion"] = distortion
                state["distortion_text"] = user_text

    print("Siema2")
    # The interview runs on every turn — even before a distortion is found —
    # so the situation/emotion/thought slots keep filling up.
    system_prompt = build_system_prompt_introduction_chapter_ellis_distortion(
        state["distortion"], state["situation"], state["think"], state["emotion"]
    )
    interview = introduction_talk(state["messages"], system_prompt)

    state["situation"] = _merge_slot(state["situation"], interview.situation)
    state["emotion"] = _merge_slot(state["emotion"], interview.emotion)
    state["think"] = _merge_slot(state["think"], interview.think)
    state["introduction_end_flag"] = interview.chapter_end

    if (
        state["distortion"] is not None
        and state["situation"] != ""
        and state["think"] != ""
        and state["emotion"] != ""
    ):
        # Everything collected → advance to the definition lookup stage.
        print("Next")
        state["awaitingUser"] = False
        state["messages_detect"] = state["messages"]
        state["stage"] = "get_distortion_def"
        return state

    # Still missing information: show the interview follow-up and wait.
    state["messages"].append({"role": "assistant", "content": interview.model_output})
    state["awaitingUser"] = True
    state["stage"] = "detect_distortion"
    return state


def get_distortion_def(state: ChatState):
    """Stage 2 entry: look up the detected distortion's definition in Neo4j.

    Stores ``None`` when no matching node exists, then hands control to
    ``talk_about_distortion`` without waiting for user input.
    """
    print("Siema4")
    distortion = state["distortion"]
    query = """
    MATCH (d:Distortion {name: $name})
    RETURN d.definicja AS definicja
    """
    records, _summary, _keys = driver.execute_query(
        query,
        parameters_={"name": distortion},
    )
    state["distortion_def"] = records[0]["definicja"] if records else None
    state["stage"] = "talk_about_distortion"
    state["awaitingUser"] = False
    return state


def talk_about_distortion(state: ChatState):
    """Stage 2: explain the distortion, then loop until the user signals
    understanding (``check_situation`` → "understand") before advancing."""
    distortion = state["distortion"]
    distortion_def = state["distortion_def"]
    print("Siema5")

    if not state.get("distortion_explained"):
        # First pass: deliver the explanation and ask for consent to proceed.
        print("Siema6")
        system_prompt_talk = f"""
Jesteś empatycznym asystentem CBT.
Użytkownikowi wykryto zniekształcenie poznawcze:
Nazwa: {distortion}
Definicja: {distortion_def}
Przedstaw mu, że wykryłeś u niego zniekształcenie i wyjaśnij je w prosty, życzliwy sposób i zapytaj, czy chce, abyś pomógł mu to wspólnie przepracować.
Język: polski, maksymalnie 2–3 zdania.
"""
        _assistant_turn(state, system_prompt_talk)
        state["awaitingUser"] = True
        state["stage"] = "talk_about_distortion"
        state["distortion_explained"] = True
        return state

    print("Siema7")
    last_user_msg = get_last_user_message(state)
    if not last_user_msg:
        # Nothing to classify yet — keep waiting for the user.
        state["awaitingUser"] = True
        return state

    classify_result = check_situation(last_user_msg["content"])
    state["classify_result"] = classify_result

    if classify_result == "understand":
        print("Siema8")
        state["messages"].append({
            "role": "assistant",
            "content": "Super! To przejdźmy teraz do kolejnego kroku"
        })
        state["stage"] = "get_intention"
        state["awaitingUser"] = False
        return state

    # User did not understand: re-explain in simpler terms with an example.
    print("Siema9")
    system_prompt = f"""
WEJSCIE
Historia wiadomości - {state["messages"]}

Użytkownik nie zrozumiał wyjaśnienia zniekształcenia.
Nazwa: {distortion}
Definicja: {distortion_def}
Język tylko polski.
Twoje zadanie:
- Wyjaśnij prostszymi słowami (1–2 zdania).
- Dodaj przykład z życia (1–2 zdania).
- Zapytaj, czy teraz jest to jasne i czy możemy przejść do działania.
Maksymalnie 3-4 zdania
"""
    _assistant_turn(state, system_prompt)
    state["awaitingUser"] = True
    state["stage"] = "talk_about_distortion"
    return state


# Stage → chapter label for the input guardian. Replaces an if/elif chain
# that tested "enter_alt_thought" and "handle_alt_thought_input" twice each.
_STAGE_TO_CHAPTER = {
    "detect_distortion": "ETAP 1",
    "talk_about_distortion": "ETAP 2",
    "get_distortion_def": "ETAP 2",
    "create_socratic_question": "ETAP 3",
    "get_intention": "ETAP 3",
    "select_intention": "ETAP 3",
    "analyze_output": "ETAP 3",
    "enter_alt_thought": "ETAP 4",
    "handle_alt_thought_input": "ETAP 4",
}


def validate_input(state: ChatState):
    """Run the guardian (``check_input``) over the latest user message for the
    current chapter; either mark the turn validated or push an explanation
    back to the user and wait."""
    chapter = _STAGE_TO_CHAPTER.get(state.get("stage"), "None")

    # NOTE(review): this reads "last_user_msg_content" but clears
    # "last_user_msg" (which global_router checks) — the key names look
    # inconsistent with ChatState; confirm against the state definition.
    last_user_msg = state.get("last_user_msg_content")
    result = check_input(state["messages"], chapter, last_user_msg)
    state["last_user_msg"] = False

    if result.decision:
        state["validated"] = True
        state["awaitingUser"] = False
    else:
        state["noValidated"] = f"{chapter} - {last_user_msg}"
        state["explanation"] = result.explanation
        state["messages"].append({"role": "assistant", "content": result.message_to_user})
        state["awaitingUser"] = True
    return state


def global_router(state: ChatState) -> str:
    """Conditional-edge function: pick the next node from the current state.

    Priority: waiting for the user → end; unvalidated fresh user message →
    guardian; otherwise dispatch on ``stage`` with ``detect_distortion`` as
    the fallback.
    """
    if state.get("awaitingUser"):
        print("[ROUTER] awaitingUser=True → __end__")
        return "__end__"

    stage = state.get("stage")
    print(f"[ROUTER] stage={stage} (fallback)")

    if not state.get("validated") and state.get("last_user_msg"):
        return "validate_input"
    if stage == "end":
        return "__end__"
    if stage == "get_distortion_def":
        return "get_distortion_def"
    if stage == "talk_about_distortion":
        return "talk_about_distortion"

    print("[ROUTER] default → detect_distortion")
    return "detect_distortion"


# ---------------------------------------------------------------------------
# Graph wiring: every node routes through the same global router.
# ---------------------------------------------------------------------------
graph_builder = StateGraph(ChatState)
graph_builder.add_node("detect_distortion", detect_distortion)
graph_builder.add_node("get_distortion_def", get_distortion_def)
graph_builder.add_node("talk_about_distortion", talk_about_distortion)
graph_builder.add_node("validate_input", validate_input)

# Single shared edge map (the START edges previously duplicated this literal).
edge_map = {
    "detect_distortion": "detect_distortion",
    "get_distortion_def": "get_distortion_def",
    "talk_about_distortion": "talk_about_distortion",
    "validate_input": "validate_input",
    "__end__": END,
}

graph_builder.add_conditional_edges(START, global_router, edge_map)
for node in ["detect_distortion", "get_distortion_def", "talk_about_distortion", "validate_input"]:
    graph_builder.add_conditional_edges(node, global_router, edge_map)

graph = graph_builder.compile()