340 lines
13 KiB
Python
340 lines
13 KiB
Python
#!/usr/bin/env python3
"""Dedicated NL2SQL worker process.

Usage: python nl2sql_worker.py <port>

Tools served:
  run_sql, query_pv_history, get_tag_metadata, list_drawings, query_with_nl

Characteristics:
  - Connects directly to PostgreSQL
  - SQL generation by the LLM is kept separate from execution against the DB
  - Memory: ~1GB (LLM used for SQL generation)
  - Lifecycle: kept alive until the main server shuts down
"""
|
|
|
|
from __future__ import annotations
|
|
import sys
|
|
import os
|
|
|
|
# mcp-server 디렉토리를 Python 경로에 추가
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
import logging
|
|
import asyncio
|
|
from functools import lru_cache
|
|
|
|
from fastapi import FastAPI, Request
|
|
import uvicorn
|
|
import httpx
|
|
|
|
# ── 설정 ─────────────────────────────────────────────────────────────────────
|
|
|
|
# Database settings; each value can be overridden through the environment.
DB_CONNECTION_STRING = os.getenv(
    "DB_CONNECTION_STRING",
    "postgresql://postgres:postgres@localhost:5432/iiot_platform",
)
# Passed to the driver as connect_timeout; must parse as an integer.
DB_TIMEOUT = int(os.getenv("DB_TIMEOUT", "10"))

# OpenAI-compatible endpoint and model name used for SQL generation.
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1")
VLLM_MODEL = os.getenv("VLLM_MODEL", "Qwen3.6-27B-FP8")

# Log to stderr with a worker-specific tag in every record.
logging.basicConfig(
    format="%(asctime)s [nl2sql_worker] %(levelname)s %(message)s",
    stream=sys.stderr,
    level=logging.INFO,
)
|
|
|
|
app = FastAPI()
|
|
|
|
# ── DB 연결 풀 ───────────────────────────────────────────────────────────────
|
|
|
|
def _get_db_connection():
    """Open and return a new psycopg connection to the configured database.

    A fresh connection is created on every call (nothing is pooled or cached
    here); the caller is responsible for closing it. The driver is imported
    lazily so that importing this module does not require psycopg.
    """
    from psycopg import connect

    return connect(DB_CONNECTION_STRING, connect_timeout=DB_TIMEOUT)
|
|
|
|
# ── LLM 클라이언트 ───────────────────────────────────────────────────────────
|
|
|
|
@lru_cache(maxsize=1)
def _llm_client():
    """Return the async OpenAI-compatible client, building it on first use.

    The single-entry cache means every later call gets the same client object.
    The API key is a placeholder value.
    """
    import openai

    return openai.AsyncOpenAI(api_key="dummy", base_url=VLLM_BASE_URL)
|
|
|
|
# DB 스키마 — server.py::_DB_SCHEMA와 동일
|
|
DB_SCHEMA = """
|
|
PostgreSQL 시계열 데이터베이스 스키마
|
|
|
|
테이블: history_table (시계열 이력)
|
|
tagname TEXT - 태그명 (모두 소문자, 예: 'ficq-6113.pv') — 대소문자 구분
|
|
node_id TEXT - OPC UA 노드 ID
|
|
value TEXT - 측정값, 수치 연산 시 ::double precision 캐스트 필요
|
|
recorded_at TIMESTAMPTZ - 기록 시각(UTC), 스냅샷 주기 약 60초
|
|
|
|
테이블: realtime_table (실시간 최신값)
|
|
tagname TEXT - 태그명 (모두 소문자)
|
|
node_id TEXT - OPC UA 노드 ID
|
|
livevalue TEXT - 현재값
|
|
timestamp TIMESTAMPTZ - 최종 갱신 시각
|
|
|
|
N분 간격 집계 공식 (time_bucket 금지, date_trunc 사용):
|
|
1분 버킷: date_trunc('minute', recorded_at) AS bucket
|
|
2분 버킷: to_timestamp(FLOOR(EXTRACT(EPOCH FROM recorded_at)/120)*120) AS bucket
|
|
5분 버킷: to_timestamp(FLOOR(EXTRACT(EPOCH FROM recorded_at)/300)*300) AS bucket
|
|
10분 버킷: to_timestamp(FLOOR(EXTRACT(EPOCH FROM recorded_at)/600)*600) AS bucket
|
|
N분 버킷: to_timestamp(FLOOR(EXTRACT(EPOCH FROM recorded_at)/(N*60))*(N*60)) AS bucket
|
|
|
|
예시 (2분 간격, 여러 태그):
|
|
SELECT to_timestamp(FLOOR(EXTRACT(EPOCH FROM recorded_at)/120)*120) AS bucket,
|
|
tagname, AVG(value::double precision) AS avg_val
|
|
FROM history_table
|
|
WHERE tagname IN ('tag1', 'tag2')
|
|
AND recorded_at >= NOW() - INTERVAL '3 hours'
|
|
GROUP BY bucket, tagname ORDER BY bucket, tagname
|
|
|
|
규칙:
|
|
- SELECT만 허용 (INSERT/UPDATE/DELETE/DROP 등 불가)
|
|
- tagname은 모두 소문자로 정확히 입력
|
|
- value 컬럼은 TEXT이므로 집계 시 ::double precision 캐스트 필수
|
|
- time_bucket 함수 사용 금지 — 위의 to_timestamp/FLOOR/EPOCH 공식 사용
|
|
"""
|
|
|
|
async def _generate_sql(natural_language: str) -> str:
    """Convert a natural-language question into a PostgreSQL SELECT statement.

    The question is sent to the LLM together with a rule list and DB_SCHEMA.
    A surrounding markdown code fence is stripped from the reply.

    Args:
        natural_language: The user's question.

    Returns:
        The SQL text, or an empty string when the model returned no content.
    """
    from datetime import datetime, timedelta, timezone

    client = _llm_client()

    # The prompt tells the model that user input is KST, so "the current year"
    # is taken from the KST calendar instead of being hard-coded.
    year = datetime.now(timezone(timedelta(hours=9))).year

    system = (
        "You are a PostgreSQL SQL expert.\n"
        "Convert the user's question into a SELECT SQL using the schema below.\n"
        "IMPORTANT rules:\n"
        "- Use ONLY PostgreSQL syntax. No DATE_FORMAT, no INTERVAL N DAY.\n"
        "- Time column is 'recorded_at' (TIMESTAMPTZ). Do NOT use 'timestamp'.\n"
        "- NEVER use time_bucket(). For N-minute buckets use to_timestamp/FLOOR/EPOCH formula.\n"
        "- INTERVAL rule:\n"
        " * If the question specifies an interval (e.g. '2분 간격', '5-minute interval'):\n"
        " use: to_timestamp(FLOOR(EXTRACT(EPOCH FROM recorded_at)/(N*60))*(N*60)) AS bucket\n"
        " with GROUP BY bucket, tagname and AVG(value::double precision) AS avg_val\n"
        " * If NO interval is specified: SELECT recorded_at, tagname, value — NO GROUP BY.\n"
        f"- Current year is {year}. '4월 27일' means {year}-04-27.\n"
        "- All times in DB are UTC. Korean input is KST (UTC+9). Convert: KST 12:00 = UTC 03:00.\n"
        "- value column is TEXT; cast with ::double precision only when aggregating.\n"
        "- All tagnames are lowercase (e.g. 'ficq-6113.pv'). Match exactly.\n"
        "- PostgreSQL LIKE: dot has no special meaning, no escaping needed.\n"
        "- Return ONLY the SQL statement. No explanation, no markdown.\n\n"
        f"{DB_SCHEMA}"
    )

    response = await client.chat.completions.create(
        model=VLLM_MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": natural_language},
        ],
        max_tokens=8192,
        temperature=0.1,
    )

    # The content may be None and the choice list may be empty; treat both as
    # "no SQL produced" rather than raising on .strip().
    choices = response.choices
    content = choices[0].message.content if choices else None
    sql = (content or "").strip()

    # Strip a markdown code fence: drop the opening ``` line and, if present,
    # the closing ``` line.
    if sql.startswith("```"):
        lines = sql.splitlines()
        inner = lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
        sql = "\n".join(inner).strip()
    return sql
|
|
|
|
# ── NL2SQL 도구 구현 ─────────────────────────────────────────────────────────
|
|
|
|
@app.get("/health")
async def health():
    """Health-check endpoint: always answers with a fixed "ok" status."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
@app.post("/execute")
async def execute(request: Request):
    """Dispatch an HTTP tool-call request to the matching tool coroutine.

    Expects a JSON object body of the form ``{"tool": <name>, "params": {...}}``.
    ``params`` is expanded into keyword arguments for the tool and may be
    omitted for tools whose arguments all have defaults.

    Returns:
        Whatever the tool returns on success, otherwise
        ``{"success": False, "error": <message>}`` when the body is malformed,
        the tool name is unknown, or the tool raises.
    """
    tool = None
    try:
        # Parsing and field access live inside the try so that a malformed
        # request yields the structured error instead of an unhandled exception.
        body = await request.json()
        if not isinstance(body, dict):
            return {"success": False, "error": "Request body must be a JSON object"}

        tool = body.get("tool")
        if not tool:
            return {"success": False, "error": "Missing required field: tool"}

        params = body.get("params") or {}
        if not isinstance(params, dict):
            return {"success": False, "error": "params must be a JSON object"}

        # Built per call so the names resolve after the whole module has loaded.
        handlers = {
            "run_sql": _run_sql,
            "query_pv_history": _query_pv_history,
            "get_tag_metadata": _get_tag_metadata,
            "list_drawings": _list_drawings,
            "query_with_nl": _query_with_nl,
        }
        handler = handlers.get(tool) if isinstance(tool, str) else None
        if handler is None:
            return {"success": False, "error": f"Unknown tool: {tool}"}

        return await handler(**params)
    except Exception as e:
        # logging.exception keeps the traceback, which the plain error log lost.
        logging.exception("Error executing %s: %s", tool, e)
        return {"success": False, "error": str(e)}
|
|
|
|
async def _run_sql(sql: str) -> str:
|
|
"""SQL 실행."""
|
|
conn = _get_db_connection()
|
|
try:
|
|
with conn.cursor() as cur:
|
|
cur.execute(sql)
|
|
if cur.description:
|
|
columns = [desc[0] for desc in cur.description]
|
|
rows = cur.fetchall()
|
|
data = [dict(zip(columns, row)) for row in rows]
|
|
return {
|
|
"success": True,
|
|
"columns": columns,
|
|
"count": len(data),
|
|
"data": data,
|
|
}
|
|
else:
|
|
conn.commit()
|
|
return {
|
|
"success": True,
|
|
"message": f"Query executed successfully. {cur.rowcount} rows affected.",
|
|
}
|
|
finally:
|
|
conn.close()
|
|
|
|
async def _query_pv_history(tag_names: list[str], time_from: str, time_to: str, limit: int = 100) -> str:
|
|
"""과거 값(PV) 히스토리 조회."""
|
|
if not tag_names:
|
|
return {"success": False, "error": "tag_names is required"}
|
|
|
|
conn = _get_db_connection()
|
|
try:
|
|
with conn.cursor() as cur:
|
|
# TimescaleDB의 time_bucket 함수 사용
|
|
cur.execute(
|
|
"""
|
|
SELECT time_bucket('1 min', ts) AS time, tag_name, value
|
|
FROM realtime_table
|
|
WHERE tag_name = ANY(%s)
|
|
AND ts >= %s
|
|
AND ts <= %s
|
|
ORDER BY time DESC
|
|
LIMIT %s
|
|
""",
|
|
(tag_names, time_from, time_to, limit),
|
|
)
|
|
columns = ["time", "tag_name", "value"]
|
|
rows = cur.fetchall()
|
|
data = [dict(zip(columns, row)) for row in rows]
|
|
return {
|
|
"success": True,
|
|
"tag_names": tag_names,
|
|
"time_range": {"from": time_from, "to": time_to},
|
|
"limit": limit,
|
|
"count": len(data),
|
|
"data": data,
|
|
}
|
|
finally:
|
|
conn.close()
|
|
|
|
async def _get_tag_metadata(query: str, limit: int = 10) -> dict:
    """Search tags whose name contains ``query`` (case-insensitive match).

    Args:
        query: Substring to look for; wrapped in ``%`` wildcards for ILIKE.
        limit: Maximum number of tags to return.

    Returns:
        ``success``, the echoed ``query``, ``count``, and ``tags`` — one dict
        per match keyed ``tag_name``/``unit``/``description``.

    Raises:
        Any driver exception raised while connecting or executing.
    """

    def _fetch_blocking() -> dict:
        conn = _get_db_connection()
        try:
            with conn.cursor() as cur:
                # NOTE(review): DB_SCHEMA in this file lists realtime_table as
                # tagname/node_id/livevalue/timestamp; tag_name, unit and
                # description are not in that listing — confirm these columns
                # exist in the real table.
                cur.execute(
                    """
                    SELECT DISTINCT tag_name, unit, description
                    FROM realtime_table
                    WHERE tag_name ILIKE %s
                    ORDER BY tag_name
                    LIMIT %s
                    """,
                    (f"%{query}%", limit),
                )
                columns = ["tag_name", "unit", "description"]
                data = [dict(zip(columns, row)) for row in cur.fetchall()]
                return {
                    "success": True,
                    "query": query,
                    "count": len(data),
                    "tags": data,
                }
        finally:
            conn.close()

    # The driver calls block, so run them on a worker thread to keep the
    # event loop free for other requests.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _fetch_blocking)
|
|
|
|
async def _list_drawings(unit_no: str | None = None) -> dict:
    """List the distinct ``name`` values in ``node_map_master``.

    Args:
        unit_no: Optional prefix; when given, only names starting with it are
            returned. When omitted or empty, all names are returned.

    Returns:
        ``success``, the echoed ``unit_no``, ``count``, and ``names`` (sorted
        list of full name values).

    Raises:
        Any driver exception raised while connecting or executing.
    """

    def _fetch_blocking() -> dict:
        conn = _get_db_connection()
        try:
            with conn.cursor() as cur:
                if unit_no:
                    cur.execute(
                        """
                        SELECT DISTINCT name
                        FROM node_map_master
                        WHERE name LIKE %s
                        ORDER BY name
                        """,
                        (f"{unit_no}%",),
                    )
                else:
                    cur.execute(
                        """
                        SELECT DISTINCT name
                        FROM node_map_master
                        ORDER BY name
                        """
                    )
                # Take the first column of each row as the whole name. Zipping
                # over row[0] iterated its characters and kept only the first.
                names = [row[0] for row in cur.fetchall()]
                return {
                    "success": True,
                    "unit_no": unit_no,
                    "count": len(names),
                    "names": names,
                }
        finally:
            conn.close()

    # The driver calls block, so run them on a worker thread to keep the
    # event loop free for other requests.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _fetch_blocking)
|
|
|
|
async def _query_with_nl(question: str) -> dict:
    """Generate SQL from a natural-language question and execute it.

    Args:
        question: The user's question, forwarded to the SQL generator.

    Returns:
        A dict on every path. On success: ``success``, ``sql`` and either
        ``columns``/``count``/``data`` or a ``message``. On failure (no SQL
        generated, connection error, or query error): ``success=False``,
        ``sql`` and ``error``.
    """
    sql = await _generate_sql(question)

    # Return a dict here like every other path; a pre-serialized JSON string
    # would be encoded a second time by the HTTP layer.
    if not sql:
        return {"success": False, "sql": "", "error": "LLM이 SQL을 생성하지 못했습니다."}

    def _execute_blocking() -> dict:
        conn = None
        try:
            conn = _get_db_connection()
            # The generator is instructed to emit SELECT only, but its output
            # is untrusted; read-only transactions make the database reject
            # any write it produces.
            conn.read_only = True
            with conn.cursor() as cur:
                cur.execute(sql)
                if cur.description:
                    columns = [desc[0] for desc in cur.description]
                    data = [dict(zip(columns, row)) for row in cur.fetchall()]
                    return {
                        "success": True,
                        "sql": sql,
                        "columns": columns,
                        "count": len(data),
                        "data": data,
                    }
                conn.commit()
                return {
                    "success": True,
                    "sql": sql,
                    "message": f"Query executed successfully. {cur.rowcount} rows affected.",
                }
        except Exception as db_error:
            # Connection failures are reported here too, so the caller always
            # sees which SQL was attempted.
            return {
                "success": False,
                "sql": sql,
                "error": str(db_error),
            }
        finally:
            if conn is not None:
                conn.close()

    # The driver calls block, so run them on a worker thread to keep the
    # event loop free for other requests.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _execute_blocking)
|
|
|
|
# ── 메인 ─────────────────────────────────────────────────────────────────────
|
|
|
|
if __name__ == "__main__":
    # The first CLI argument is the listen port; fall back to 5003 without one.
    cli_args = sys.argv[1:]
    port = int(cli_args[0]) if cli_args else 5003
    logging.info(f"Starting NL2SQL worker on port {port}")
    uvicorn.run(app, host="0.0.0.0", port=port)
|