Skip to content

Commit

Permalink
[Feature] Standalone MCP server script
Browse files Browse the repository at this point in the history
  • Loading branch information
amidabuddha committed Jan 1, 2025
1 parent c1caf0e commit e523815
Showing 1 changed file with 58 additions and 34 deletions.
92 changes: 58 additions & 34 deletions mcp_servers/mcp_tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import json
import logging
import os
import shutil
import signal
Expand All @@ -12,16 +13,38 @@
from mcp import ClientSession, StdioServerParameters, Tool
from mcp.client.stdio import stdio_client

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from mcp_errors import (CommandNotFoundError, ConfigError, MCPError,
ServerInitError, ToolExecutionError)

from console_gpt.config_manager import _join_and_check
# Configure logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) # Set the base logger level to DEBUG

BASE_PATH = os.path.dirname(os.path.realpath(f"{__file__}/.."))
MCP_SAMPLE_PATH = _join_and_check(BASE_PATH, "mcp_config.json.sample", target="mcp_config.json")
MCP_PATH = _join_and_check(BASE_PATH, "mcp_config.json", create="mcp_config.json")
# Create a FileHandler for DEBUG logs
file_handler = logging.FileHandler(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mcp_tcp_server.log'),
mode='w')
file_handler.setLevel(logging.DEBUG) # Log all messages (DEBUG and above) to the file
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))

# Create a StreamHandler for INFO logs
stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setLevel(logging.INFO) # Only log INFO and above to the console
stream_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))

# Add both handlers to the logger
logger.addHandler(file_handler)
logger.addHandler(stream_handler)

# Get configurations from mcp_config.json
MCP_PATH = os.path.join(os.path.dirname(os.path.realpath(f"{__file__}/..")), "mcp_config.json")

if not os.path.exists(MCP_PATH):
if os.path.exists(MCP_PATH + ".sample"):
shutil.copy(MCP_PATH + ".sample", MCP_PATH)
logger.info('"mcp_config.json" created from sample')
else:
logger.error('"mcp_config.json.sample" is either missing or renamed, please update from source.')
exit(1)

class MCPServer:
def __init__(self, server_name, server_config):
Expand All @@ -31,6 +54,7 @@ def __init__(self, server_name, server_config):
self.session = None
self.tools = {}
self.client_entered = False # Track if client context was entered successfully
self.logger = logging.getLogger(f"{__name__}.MCPServer.{server_name}")

async def __aenter__(self):
return self
Expand All @@ -44,23 +68,23 @@ async def cleanup(self):
try:
await self.session.__aexit__(None, None, None)
except Exception as e:
print(f"Error during session cleanup for {self.server_name}: {e}")
self.logger.error(f"Error during session cleanup: {e}")
self.session = None
if self.client and self.client_entered: # Check if client context was entered
try:
await self.client.__aexit__(None, None, None)
except Exception as e:
print(f"Error during client cleanup for {self.server_name}: {e}")
self.logger.error(f"Error during client cleanup: {e}")
self.client = None


class MCPTCPServer:
def __init__(self, host: str = "localhost", port: int = 8765):
self.host = host
self.port = port
self.servers: Dict[str, MCPServer] = {}
self.initialization_timeout = 30 # 30 seconds timeout for tool initialization
self.server_processes: Dict[str, subprocess.Popen] = {}
self.logger = logging.getLogger(f"{__name__}.MCPTCPServer")

@staticmethod
def validate_config(config: Dict[str, Dict[str, Any]]) -> None:
Expand Down Expand Up @@ -222,27 +246,27 @@ async def init_server(self, server_name: str, server_config: Dict[str, Any]) ->
except (asyncio.CancelledError, Exception) as e:
# Handle cancellation or any other exception
if isinstance(e, asyncio.CancelledError):
print(f"Server initialization for {server_name} was cancelled.")
self.logger.warning(f"Server initialization for {server_name} was cancelled.")
else:
print(f"Server initialization for {server_name} failed: {e}")
self.logger.error(f"Server initialization for {server_name} failed: {e}")

# Terminate the subprocess if it's still running
if server_name in self.server_processes:
process = self.server_processes[server_name]
if process.poll() is None: # Check if the process is still running
print(f"Terminating subprocess for server {server_name}")
self.logger.info(f"Terminating subprocess for server {server_name}")
process.terminate()
try:
process.wait(timeout=5) # Wait for process to terminate with a timeout
except subprocess.TimeoutExpired:
print(f"Forcefully killing subprocess for server {server_name}")
self.logger.warning(f"Forcefully killing subprocess for server {server_name}")
process.kill()

# Clean up resources
await server.cleanup()

if isinstance(e, CommandNotFoundError):
print(f"Command not found error for server {server_name}: {str(e)}")
self.logger.error(f"Command not found error for server {server_name}: {str(e)}")
raise
elif isinstance(e, asyncio.CancelledError):
raise # Re-raise CancelledError to be handled by the caller if needed
Expand All @@ -267,38 +291,38 @@ async def initialize_tools(self) -> Tuple[List[Dict[str, Any]], List[Exception]]
return [], initialization_errors

async def init_with_timeout(server_name: str, server_config: Dict[str, Any]):
print(f"Initializing server: {server_name}")
self.logger.info(f"Initializing server: {server_name}")
try:
server = await asyncio.wait_for(
self.init_server(server_name, server_config), timeout=self.initialization_timeout
)
self.servers[server_name] = server
print(f"Server {server_name} initialized successfully")
self.logger.info(f"Server {server_name} initialized successfully")
return [self.tool_to_dict(tool) for tool in server.tools.values()]
except asyncio.TimeoutError:
error = ServerInitError(
f"Server initialization timed out after {self.initialization_timeout} seconds", server_name
)
self.servers[server_name] = error # Store the error
initialization_errors.append(error)
print(f"TimeoutError initializing server {server_name}")
self.logger.warning(f"TimeoutError initializing server {server_name}")
return []
except Exception as e:
print(f"Exception caught in init_with_timeout for {server_name}: {e}")
self.logger.error(f"Exception caught in init_with_timeout for {server_name}: {e}")

# Try to get more details from specific exception types
if isinstance(e, CommandNotFoundError):
error_details = e.to_dict()
print(f"CommandNotFoundError details: {error_details}")
self.logger.error(f"CommandNotFoundError details: {error_details}")
elif isinstance(e, ServerInitError):
error_details = e.to_dict()
print(f"ServerInitError details: {error_details}")
self.logger.error(f"ServerInitError details: {error_details}")
else:
error_details = {
"error_type": type(e).__name__,
"message": str(e),
}
print(f"Other exception details: {error_details}")
self.logger.error(f"Other exception details: {error_details}")

error = ServerInitError(f"Server initialization failed: {e}", server_name)
error.details = error_details
Expand Down Expand Up @@ -402,7 +426,7 @@ async def handle_client(self, reader: asyncio.StreamReader, writer: asyncio.Stre
await writer.drain() # Make sure data is sent before continuing

except Exception as e:
print(f"Error handling client: {e}")
self.logger.error(f"Error handling client: {e}")
finally:
writer.close()
await writer.wait_closed()
Expand All @@ -417,14 +441,14 @@ async def cleanup(self):
results = await asyncio.gather(*cleanup_tasks, return_exceptions=True)
for res in results:
if isinstance(res, Exception):
print(f"Error during server cleanup: {res}")
self.logger.error(f"Error during server cleanup: {res}")

self.servers.clear()

# Terminate server processes
for server_name, process in self.server_processes.items():
if process.poll() is None: # Check if process is still running
print(f"Terminating server process: {server_name}")
self.logger.info(f"Terminating server process: {server_name}")
try:
if os.name == "nt":
process.send_signal(signal.CTRL_C_EVENT)
Expand All @@ -433,7 +457,7 @@ async def cleanup(self):

process.wait(timeout=5) # Wait for process to terminate
except subprocess.TimeoutExpired:
print(f"Force killing server process: {server_name}")
self.logger.warning(f"Force killing server process: {server_name}")
process.kill()
self.server_processes.clear()

Expand All @@ -446,42 +470,42 @@ async def start(self):
# Check if config load failed and prevent server start
config_error = next((e for e in errors if isinstance(e, ConfigError)), None)
if config_error:
print(f"Failed to start server: {config_error}")
self.logger.error(f"Failed to start server: {config_error}")
exit(1)

# Start TCP server even if some tools failed to initialize
server = await asyncio.start_server(self.handle_client, self.host, self.port)

# Print information about successful tool initialization
if tools:
print(f"Total tools initialized: {len(tools)}")
self.logger.info(f"Total tools initialized: {len(tools)}")
for tool_info in tools:
print(f" - {tool_info['name']}: {tool_info['description']}")
self.logger.info(f" - {tool_info['name']}: {tool_info['description']}")

# Print information about failed tool initializations
if errors:
print(f"Failed to initialize {len(errors)} servers:")
self.logger.warning(f"Failed to initialize {len(errors)} servers")
for error in errors:
if isinstance(error, Exception):
print(f" - {error}")
self.logger.warning(f" - {error}")

async with server:
print(f"Server running on {self.host}:{self.port}")
self.logger.info(f"Server running on {self.host}:{self.port}")
await server.serve_forever()

except Exception as e:
print(f"Server error: {e}")
self.logger.error(f"Server error: {e}")
await self.cleanup()
raise


if __name__ == "__main__":
server = MCPTCPServer()

async def main():
try:
await server.start()
except KeyboardInterrupt:
print("\nShutting down server...")
logger.info("\nShutting down server...")
await server.cleanup()

asyncio.run(main())
asyncio.run(main())

0 comments on commit e523815

Please sign in to comment.