# sql_interface virtual environment
#
# Streamlit chat UI backed by a LangChain SQL agent that answers natural-
# language questions against a ClickHouse database of bird-detection results.

import json
import os
from pathlib import Path

import streamlit as st
from sqlalchemy import create_engine

from langchain.agents import create_agent
from langchain.chat_models import init_chat_model
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.callbacks import StreamlitCallbackHandler
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)

model = init_chat_model("gpt-4.1")

# NOTE(review): the connection string is hard-coded (private IP, no auth shown).
# Consider reading it from an environment variable before sharing this file.
db = SQLDatabase.from_uri("clickhouse+native://192.168.1.242:9000/nuggets_for_ai")

# Startup sanity checks: confirm the dialect, the visible tables, and that a
# simple query round-trips.  These print to the server console, not the UI.
print(f"Dialect: {db.dialect}")
print(f"Available tables: {db.get_usable_table_names()}")
print(f'Sample output: {db.run("SELECT * FROM object_detection_results LIMIT 5;")}')

toolkit = SQLDatabaseToolkit(db=db, llm=model)
tools = toolkit.get_tools()

# Agent operating instructions.  {dialect} and {top_k} are filled in below via
# str.format; the text itself must stay as-is (it is part of runtime behavior).
system_prompt = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema of the most relevant tables. 

For the "object_name" column, use the scientific name of the birds whenever available
""".format(
    dialect=db.dialect,
    top_k=5,
)


def render_tool_calls(messages: list[BaseMessage], container) -> None:
    """Render tool calls and tool results from a LangChain/LangGraph message list.

    Args:
        messages: Full message history; AIMessages may carry ``tool_calls`` and
            ToolMessages carry the corresponding results.
        container: A Streamlit container (e.g. an expander) to render into.
    """
    with container:
        # Pass 1: tool invocations (attached to AIMessages).
        for msg in messages:
            if not (isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None)):
                continue
            for call in msg.tool_calls:
                tool_name = call.get("name", "tool")
                tool_args = call.get("args", {})
                call_id = call.get("id", None)

                with st.status(f"Tool call: {tool_name}", expanded=False) as status:
                    if call_id:
                        st.caption(f"id: {call_id}")
                    # args can be a dict or a raw string depending on the
                    # langchain version in use.
                    if isinstance(tool_args, (dict, list)):
                        st.json(tool_args)
                    else:
                        st.code(str(tool_args))
                    status.update(state="running")

        # Pass 2: tool results (ToolMessages).
        for msg in messages:
            if not isinstance(msg, ToolMessage):
                continue
            tool_name = getattr(msg, "name", "tool")
            with st.status(f"Tool result: {tool_name}", expanded=False) as status:
                # Content is often JSON-ish text; pretty-print when possible.
                try:
                    st.json(json.loads(msg.content))
                except Exception:
                    st.code(msg.content)
                status.update(state="complete")


agent_executor = create_agent(
    model,
    tools,
    system_prompt=system_prompt,
)


def to_role(m: BaseMessage) -> str:
    """Map a LangChain message to a Streamlit chat role string.

    Unknown message types default to "assistant" so they still render.
    """
    if isinstance(m, AIMessage):
        return "assistant"
    if isinstance(m, HumanMessage):
        return "user"
    if isinstance(m, SystemMessage):
        return "system"
    return "assistant"


st.set_page_config(page_title="LangGraph Chatbot")
st.title("Chatbot")

# BUG FIX: the original assigned st.session_state.agent / .messages
# unconditionally, which wiped the chat history on every Streamlit rerun.
# Initialize them only once per session.
if "agent" not in st.session_state:
    st.session_state.agent = agent_executor
if "messages" not in st.session_state:
    st.session_state.messages = []

st_callback = StreamlitCallbackHandler(st.container())
def last_ai_text(msgs: list[BaseMessage]) -> str:
    """Return the content of the most recent AIMessage, or "" if none."""
    for m in reversed(msgs):
        if isinstance(m, AIMessage):
            return m.content or ""
    return ""


# Backward-compatible alias: the original file defined this identical helper
# twice under two names; keep the second name pointing at the single impl.
get_last_assistant_text = last_ai_text

# Render prior conversation turns (system messages are internal — skip them).
for m in st.session_state.messages:
    role = to_role(m)
    if role == "system":
        continue
    with st.chat_message(role):
        st.markdown(m.content)

# BUG FIX: the original used st.text_input with a non-empty default, so
# `if user_text:` was true on first page load and on EVERY rerun, auto-
# submitting the query and appending duplicate HumanMessages.  st.chat_input
# returns None until the user actually submits.
user_text = st.chat_input("Ask about birds")
if user_text:
    st.session_state.messages.append(HumanMessage(content=user_text))
    with st.chat_message("user"):
        st.markdown(user_text)

    with st.chat_message("assistant"):
        answer_box = st.empty()
        tools_box = st.expander("Tools", expanded=True)
        latest_messages = None

        # Stream agent events; each payload carrying "messages" is the most
        # recent snapshot of the conversation, including tool traffic.
        for event in st.session_state.agent.stream(
            {"messages": st.session_state.messages}
        ):
            for _, payload in event.items():
                if isinstance(payload, dict) and "messages" in payload:
                    latest_messages = payload["messages"]

                    # Live-update the assistant answer.
                    answer_box.markdown(last_ai_text(latest_messages))

                    # Re-render tool activity (simplest approach: clear + redraw).
                    tools_box.empty()
                    render_tool_calls(latest_messages, tools_box)

        # BUG FIX: the original tested `latest_state`, a variable that was
        # never assigned inside the loop, so the final conversation was never
        # persisted.  Persist the last streamed message list instead.
        if latest_messages:
            st.session_state.messages = latest_messages