added example
This commit is contained in:
190
text_sql_agent.py
Normal file
190
text_sql_agent.py
Normal file
@@ -0,0 +1,190 @@
|
||||
#sql_interface virtual environment
|
||||
|
||||
import os
|
||||
from langchain.chat_models import init_chat_model
|
||||
import streamlit as st
|
||||
from pathlib import Path
|
||||
from langchain_community.callbacks import StreamlitCallbackHandler
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
# Chat model used both to answer questions and to drive the SQL toolkit below.
model = init_chat_model("gpt-4.1")

from langchain.agents import create_agent

# NOTE(review): the connection string hardcodes a LAN IP, port, and database —
# consider moving it to an environment variable before sharing this script.
db = SQLDatabase.from_uri("clickhouse+native://192.168.1.242:9000/nuggets_for_ai")

# Startup sanity checks. These hit the database at import time, so they run on
# every Streamlit rerun — output goes to the server console, not the page.
print(f"Dialect: {db.dialect}")
print(f"Available tables: {db.get_usable_table_names()}")
print(f'Sample output: {db.run("SELECT * FROM object_detection_results LIMIT 5;")}')

# Bundle of SQL tools (list tables, fetch schema, run/check query) the agent may call.
toolkit = SQLDatabaseToolkit(db=db, llm=model)

tools = toolkit.get_tools()
|
||||
|
||||
|
||||
# Agent instructions (LangChain SQL-agent tutorial style). Placeholders are
# filled in below with the live database dialect and a default row cap.
_SYSTEM_PROMPT_TEMPLATE = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema of the most relevant tables.

For the "object_name" column, use the scientific name of the birds whenever available
"""

system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(dialect=db.dialect, top_k=5)
|
||||
|
||||
|
||||
import json
|
||||
import streamlit as st
|
||||
from langchain_core.messages import AIMessage, ToolMessage, BaseMessage
|
||||
|
||||
def render_tool_calls(messages: list[BaseMessage], container):
    """Show every tool invocation and every tool result found in *messages*.

    Tool calls (attached to AIMessages) are rendered first, then tool results
    (ToolMessages), each inside its own collapsible ``st.status`` panel.
    """
    with container:
        # Pass 1: tool invocations requested by the model.
        requested_calls = (
            tc
            for msg in messages
            if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None)
            for tc in msg.tool_calls
        )
        for tc in requested_calls:
            tool_name = tc.get("name", "tool")
            tool_args = tc.get("args", {})
            call_id = tc.get("id", None)

            with st.status(f"Tool call: {tool_name}", expanded=False) as panel:
                if call_id:
                    st.caption(f"id: {call_id}")
                # Depending on the langchain version, args may already be
                # parsed (dict/list) or still be a raw string.
                if isinstance(tool_args, (dict, list)):
                    st.json(tool_args)
                else:
                    st.code(str(tool_args))
                panel.update(state="running")

        # Pass 2: the results those tools produced.
        for msg in messages:
            if not isinstance(msg, ToolMessage):
                continue
            tool_name = getattr(msg, "name", "tool")
            with st.status(f"Tool result: {tool_name}", expanded=False) as panel:
                # Results are often JSON text — pretty-print when parseable.
                try:
                    st.json(json.loads(msg.content))
                except Exception:
                    st.code(msg.content)
                panel.update(state="complete")
|
||||
|
||||
|
||||
# Build the tool-calling SQL agent: the chat model plus the SQLDatabaseToolkit
# tools, steered by the system prompt assembled above.
agent_executor = create_agent(
    model,
    tools,
    system_prompt=system_prompt
)
|
||||
|
||||
# %%
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, BaseMessage
|
||||
def to_role(m: BaseMessage) -> str:
    """Map a LangChain message to a Streamlit chat role string.

    Any message type not explicitly listed falls back to "assistant".
    """
    dispatch = (
        (AIMessage, "assistant"),
        (HumanMessage, "user"),
        (SystemMessage, "system"),
    )
    for msg_type, role in dispatch:
        if isinstance(m, msg_type):
            return role
    return "assistant"
|
||||
|
||||
|
||||
st.set_page_config(page_title="LangGraph Chatbot")
st.title("Chatbot")

# The agent itself is stateless between turns, so overwriting it on rerun is fine.
st.session_state.agent = agent_executor

# BUG FIX: the history was previously reset to an empty list on EVERY Streamlit
# rerun, which discarded the conversation that the streaming handler further
# down persists. Initialize it only once per browser session.
if "messages" not in st.session_state:
    st.session_state.messages = []

#messages = agent_executor.invoke({"messages":"What is the cardinal seasonality to showing up?"})

# NOTE(review): this callback handler is created but never passed to the agent —
# confirm whether it is still needed.
st_callback = StreamlitCallbackHandler(st.container())
|
||||
|
||||
|
||||
def get_last_assistant_text(messages: list[BaseMessage]) -> str:
    """Return the content of the most recent AIMessage, or "" when there is none."""
    newest_ai = next(
        (m for m in reversed(messages) if isinstance(m, AIMessage)),
        None,
    )
    if newest_ai is None:
        return ""
    return newest_ai.content or ""
|
||||
# Replay the stored conversation so it survives Streamlit reruns.
for past_msg in st.session_state.messages:
    msg_role = to_role(past_msg)
    # System prompts are internal — never show them in the transcript.
    if msg_role != "system":
        with st.chat_message(msg_role):
            st.markdown(past_msg.content)
|
||||
|
||||
def last_ai_text(msgs: list[BaseMessage]) -> str:
    """Return the newest AIMessage content in *msgs* ("" when none exists)."""
    ai_replies = [m for m in msgs if isinstance(m, AIMessage)]
    if not ai_replies:
        return ""
    return ai_replies[-1].content or ""
|
||||
|
||||
|
||||
user_text = st.text_input("Ask about birds", "What is the seasonality of Mourning doves visiting? I want total duration in hours.")
if user_text:
    # Record and echo the user's turn.
    st.session_state.messages.append(HumanMessage(content=user_text))
    with st.chat_message("user"):
        st.markdown(user_text)

    with st.chat_message("assistant"):
        answer_box = st.empty()
        # NOTE(review): this expander is created but never written to — confirm
        # whether intermediate-step rendering is still planned.
        debug_box = st.expander("Intermediate steps", expanded=True)
        tools_box = st.expander("Tools", expanded=True)

        # Most recent full message list emitted by the agent graph.
        latest_messages = None
        for event in st.session_state.agent.stream({"messages": st.session_state.messages}):
            for _, payload in event.items():
                if isinstance(payload, dict) and "messages" in payload:
                    latest_messages = payload["messages"]

                    # Update the assistant answer in place as steps arrive.
                    answer_box.markdown(last_ai_text(latest_messages))

                    # Re-render the tool panels from scratch (simplest approach).
                    tools_box.empty()
                    render_tool_calls(latest_messages, tools_box)

    # BUG FIX: the original guarded on `latest_state`, a variable that was never
    # assigned inside the streaming loop, so the conversation was never saved.
    # Persist the agent's final message list (which already includes the user
    # turn) so the next rerun replays the full history without duplicates.
    if latest_messages:
        st.session_state.messages = latest_messages
|
||||
|
||||
|
||||
#question = "What is the seasonality of Cardinals visting? I want total duration"
|
||||
# What is the seasonality of cardinals showing up to the railing?
|
||||
#for step in agent.stream(
|
||||
# {"messages": [{"role": "user", "content": question}]},
|
||||
# stream_mode="values",
|
||||
#):
|
||||
# step["messages"][-1].pretty_print()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user