#!/usr/bin/env python3
"""NL2SQL dedicated worker process.

Usage:
    python nl2sql_worker.py [port]      # port defaults to 5003

Tools served: run_sql, query_pv_history, get_tag_metadata, list_drawings,
query_with_nl.

Characteristics:
- Connects to PostgreSQL directly (a fresh, short-lived connection per call).
- LLM SQL generation and DB execution are separate steps.
- Memory: ~1GB (LLM used for SQL generation).
- Lifecycle: kept alive until the main server shuts down.
"""
from __future__ import annotations

import sys
import os
import re

# Put the mcp-server directory on the Python path (needed for `config`).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import logging  # noqa: E402
import asyncio  # noqa: E402
from functools import lru_cache  # noqa: E402

from fastapi import FastAPI, Request  # noqa: E402
import uvicorn  # noqa: E402
import httpx  # noqa: E402,F401

# ── Settings ─────────────────────────────────────────────────────────────────
DB_CONNECTION_STRING = os.environ.get(
    "DB_CONNECTION_STRING",
    "postgresql://postgres:postgres@localhost:5432/iiot_platform",
)
DB_TIMEOUT = int(os.environ.get("DB_TIMEOUT", "10"))
VLLM_BASE_URL = os.environ.get("VLLM_BASE_URL", "http://localhost:8000/v1")

from config import get_vllm_model  # noqa: E402

VLLM_MODEL = get_vllm_model()

logging.basicConfig(
    level=logging.INFO,
    stream=sys.stderr,
    format="%(asctime)s [nl2sql_worker] %(levelname)s %(message)s",
)

app = FastAPI()


# ── DB access ────────────────────────────────────────────────────────────────
def _get_db_connection():
    """Open a new (blocking) psycopg connection with a DB_TIMEOUT-second connect timeout."""
    import psycopg

    return psycopg.connect(DB_CONNECTION_STRING, connect_timeout=DB_TIMEOUT)


async def _aget_db_connection():
    """Open a DB connection without blocking the event loop (connect runs in a thread)."""
    return await asyncio.to_thread(_get_db_connection)


def _fetch_sync(
    statement: str,
    params: tuple | None = None,
    *,
    guarded: bool = False,
) -> tuple[list[str] | None, list[tuple], int]:
    """Run one statement on a fresh connection and return ``(columns, rows, rowcount)``.

    Blocking — call through ``_afetch`` from async code.

    Args:
        statement: SQL text. When ``params`` is None it is sent as-is (no
            placeholder parsing), so literal ``%`` characters are safe.
        params: Optional bind parameters for ``%s`` placeholders.
        guarded: When True, the statement runs inside a READ ONLY transaction
            with ``statement_timeout`` applied. Use this for user/LLM-supplied SQL.

    Returns:
        ``columns`` is None when the statement produced no result set.

    The connection is always closed without committing, so the transaction is
    rolled back.
    """
    conn = _get_db_connection()
    try:
        with conn.cursor() as cur:
            if guarded:
                # Read-only transaction: blocks side effects that keyword
                # filtering misses (e.g. SELECT ... INTO, nextval/setval).
                cur.execute("SET TRANSACTION READ ONLY")
                cur.execute(f"SET statement_timeout = {SQL_STATEMENT_TIMEOUT_MS}")
            cur.execute(statement, params)
            if cur.description is None:
                return None, [], cur.rowcount
            columns = [desc[0] for desc in cur.description]
            return columns, cur.fetchall(), cur.rowcount
    finally:
        conn.close()


async def _afetch(
    statement: str,
    params: tuple | None = None,
    *,
    guarded: bool = False,
) -> tuple[list[str] | None, list[tuple], int]:
    """Async wrapper: run ``_fetch_sync`` in a worker thread so queries never block the loop."""
    return await asyncio.to_thread(_fetch_sync, statement, params, guarded=guarded)


# ── SQL guards ───────────────────────────────────────────────────────────────
SQL_MAX_ROWS = int(os.environ.get("SQL_MAX_ROWS", "1000"))
SQL_STATEMENT_TIMEOUT_MS = int(os.environ.get("SQL_STATEMENT_TIMEOUT_MS", "30000"))

# Trailing "LIMIT n [OFFSET m]" at the very end of the (outermost) statement.
_RE_LIMIT_TAIL = re.compile(
    r"\bLIMIT\b\s+(?P<n>\d+)(\s+OFFSET\s+\d+)?\s*$", re.IGNORECASE
)
_DANGEROUS_KW = (
    'EXEC', 'DROP', 'DELETE', 'UPDATE', 'INSERT', 'ALTER', 'CREATE',
    'GRANT', 'REVOKE', 'TRUNCATE', 'COPY',
)


def _validate_sql(sql: str) -> tuple[bool, str]:
    """Allow only single SELECT/WITH statements; reject dangerous keywords.

    Returns ``(True, "")`` when accepted, otherwise ``(False, reason)``.
    This is a first-line filter only; guarded execution additionally runs the
    query in a READ ONLY transaction.
    """
    if not sql or len(sql) > 2000:
        return False, "쿼리가 비어있거나 2000자를 초과했습니다."
    upper = sql.upper()
    for kw in _DANGEROUS_KW:
        if re.search(rf"\b{kw}\b", upper):
            return False, f"허용되지 않은 키워드 '{kw}'"
    head = upper.lstrip().lstrip('(').lstrip()
    if not (head.startswith('SELECT') or head.startswith('WITH')):
        return False, "SELECT 또는 WITH 쿼리만 허용됩니다."
    # Trailing semicolons are tolerated; any other ';' means multiple statements.
    if ';' in sql.rstrip().rstrip(';'):
        return False, "다중 문장(세미콜론)은 허용되지 않습니다."
    return True, ""


def _apply_sql_guards(sql: str, max_rows: int = SQL_MAX_ROWS) -> str:
    """Return *sql* capped to at most *max_rows* rows.

    A query that already ends in ``LIMIT n`` with ``n <= max_rows`` is returned
    unchanged (minus trailing semicolons). Anything else, including a larger
    LIMIT, is wrapped in an outer ``LIMIT max_rows``. The wrapper puts the
    closing paren on its own line so a trailing ``--`` comment cannot
    swallow it.
    """
    s = sql.strip().rstrip(';').strip()
    m = _RE_LIMIT_TAIL.search(s)
    if m and int(m.group("n")) <= max_rows:
        return s
    return f"SELECT * FROM (\n{s}\n) _capped LIMIT {max_rows}"


# ── LLM client ───────────────────────────────────────────────────────────────
@lru_cache(maxsize=1)
def _llm_client():
    """Return a process-wide AsyncOpenAI client pointed at the vLLM endpoint."""
    from openai import AsyncOpenAI

    return AsyncOpenAI(base_url=VLLM_BASE_URL, api_key="dummy")


# DB schema + SQL system prompt are centralized in worker/sql_prompt.py
# (shared by production and eval).
from sql_prompt import DB_SCHEMA, SQL_SYSTEM_PROMPT  # noqa: E402,F401


async def _generate_sql(natural_language: str) -> str:
    """Translate a natural-language question into SQL via the LLM.

    Returns the SQL text with one surrounding Markdown code fence removed.
    Returns "" when the model produced no content.
    """
    client = _llm_client()
    response = await client.chat.completions.create(
        model=VLLM_MODEL,
        messages=[
            {"role": "system", "content": SQL_SYSTEM_PROMPT},
            {"role": "user", "content": natural_language},
        ],
        max_tokens=8192,
        temperature=0.1,
    )
    # message.content can be None (e.g. empty/refused completion).
    sql = (response.choices[0].message.content or "").strip()
    # Strip a Markdown code block such as ```sql ... ```
    if sql.startswith("```"):
        lines = sql.splitlines()
        body = lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
        sql = "\n".join(body).strip()
    return sql


# ── NL2SQL tool implementations ──────────────────────────────────────────────
async def _execute_guarded(sql: str) -> dict:
    """Execute already-validated SQL with row cap, timeout and READ ONLY mode.

    Returns the success payload fields (without ``success``/``sql``).
    Raises whatever the DB driver raises.
    """
    columns, rows, rowcount = await _afetch(_apply_sql_guards(sql), guarded=True)
    if columns is None:
        # No result set. Nothing is committed (the transaction is read-only
        # and rolled back on close).
        return {"message": f"Query executed successfully. {rowcount} rows affected."}
    data = [dict(zip(columns, row)) for row in rows]
    return {
        "columns": columns,
        "count": len(data),
        "row_limit": SQL_MAX_ROWS,
        "data": data,
    }


async def _run_sql(sql: str) -> dict:
    """Execute caller-supplied SQL (SELECT/WITH only, auto-LIMIT, statement_timeout)."""
    valid, err = _validate_sql(sql)
    if not valid:
        return {"success": False, "error": f"SQL 검증 실패: {err}"}
    try:
        result = await _execute_guarded(sql)
    except Exception as e:
        return {"success": False, "error": f"SQL 실행 실패: {e}"}
    return {"success": True, **result}


async def _query_pv_history(
    tag_names: list[str],
    time_from: str,
    time_to: str,
    limit: int = 100,
) -> dict:
    """Fetch historical values (PV) for the given tags in ``[time_from, time_to]``, newest first."""
    if not tag_names:
        return {"success": False, "error": "tag_names is required"}
    if isinstance(tag_names, str):
        # Accept a single tag name as a convenience; ANY(%s) needs a list.
        tag_names = [tag_names]
    _, rows, _ = await _afetch(
        """
        SELECT recorded_at AS time, tagname AS tag_name, value
        FROM history_table
        WHERE tagname = ANY(%s)
          AND recorded_at >= %s
          AND recorded_at <= %s
        ORDER BY recorded_at DESC, tagname
        LIMIT %s
        """,
        (tag_names, time_from, time_to, limit),
    )
    columns = ["time", "tag_name", "value"]
    data = [dict(zip(columns, row)) for row in rows]
    return {
        "success": True,
        "tag_names": tag_names,
        "time_range": {"from": time_from, "to": time_to},
        "limit": limit,
        "count": len(data),
        "data": data,
    }


async def _get_tag_metadata(query: str, limit: int = 10) -> dict:
    """Search realtime_table for tags whose name contains *query* (case-insensitive)."""
    _, rows, _ = await _afetch(
        """
        SELECT tagname, livevalue, timestamp, node_id
        FROM realtime_table
        WHERE tagname ILIKE %s
        ORDER BY tagname
        LIMIT %s
        """,
        (f"%{query}%", limit),
    )
    data = [
        {
            "tag_name": r[0],
            "current_value": r[1],
            "last_updated": r[2].isoformat() if r[2] else None,
            "node_id": r[3],
        }
        for r in rows
    ]
    return {
        "success": True,
        "query": query,
        "count": len(data),
        "tags": data,
    }


async def _list_drawings(unit_no: str | None = None) -> dict:
    """List distinct drawing names from node_map_master, optionally filtered by *unit_no* prefix."""
    if unit_no:
        _, rows, _ = await _afetch(
            """
            SELECT DISTINCT name
            FROM node_map_master
            WHERE name LIKE %s
            ORDER BY name
            """,
            (f"{unit_no}%",),
        )
    else:
        _, rows, _ = await _afetch(
            """
            SELECT DISTINCT name
            FROM node_map_master
            ORDER BY name
            """
        )
    names = [row[0] for row in rows]
    return {
        "success": True,
        "unit_no": unit_no,
        "count": len(names),
        "names": names,
    }


async def _query_with_nl(question: str) -> dict:
    """Generate SQL from a natural-language question and execute it under the same guards."""
    sql = await _generate_sql(question)
    if not sql:
        return {"success": False, "sql": "", "error": "LLM이 SQL을 생성하지 못했습니다."}

    # LLM-generated SQL gets the same guards as caller-supplied SQL.
    valid, err = _validate_sql(sql)
    if not valid:
        return {"success": False, "sql": sql, "error": f"SQL 검증 실패: {err}"}
    try:
        result = await _execute_guarded(sql)
    except Exception as db_error:
        return {"success": False, "sql": sql, "error": str(db_error)}
    return {"success": True, "sql": sql, **result}


# Tool name -> async implementation, used by /execute.
_TOOL_HANDLERS = {
    "run_sql": _run_sql,
    "query_pv_history": _query_pv_history,
    "get_tag_metadata": _get_tag_metadata,
    "list_drawings": _list_drawings,
    "query_with_nl": _query_with_nl,
}


# ── HTTP endpoints ───────────────────────────────────────────────────────────
@app.get("/health")
async def health():
    """Worker health check."""
    return {"status": "ok"}


@app.post("/execute")
async def execute(request: Request):
    """Translate an HTTP request ``{"tool": ..., "params": {...}}`` into a tool call.

    Always returns a JSON object. Failures are reported as
    ``{"success": False, "error": ...}`` instead of an HTTP 500.
    """
    try:
        body = await request.json()
        tool = body["tool"]
        params = body.get("params") or {}
    except Exception as e:  # malformed JSON, non-object body, missing "tool"
        logging.warning("Invalid /execute request: %s", e)
        return {"success": False, "error": f"Invalid request: {e}"}

    handler = _TOOL_HANDLERS.get(tool) if isinstance(tool, str) else None
    if handler is None:
        return {"success": False, "error": f"Unknown tool: {tool}"}
    try:
        return await handler(**params)
    except Exception as e:
        logging.exception("Error executing %s: %s", tool, e)
        return {"success": False, "error": str(e)}


# ── Main ─────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 5003
    logging.info(f"Starting NL2SQL worker on port {port}")
    uvicorn.run(app, host="0.0.0.0", port=port)