342 lines
9.9 KiB
Python
342 lines
9.9 KiB
Python
|
|
import asyncio
import json
import os
import subprocess
import tempfile
import uuid
from datetime import datetime
from typing import List, Optional, Dict, Any

import ollama
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
|
||
|
|
|
||
|
|
app = FastAPI(title="AI Code Assistant API")

# CORS so the frontend (served from another origin) can call this API.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable; when it is unset the default is "*" (any origin), i.e. the previous
# hard-coded behaviour.
# SECURITY NOTE: "*" lets any web page the user visits call this API --
# including POST /execute, which runs arbitrary code on this machine. Set
# CORS_ALLOW_ORIGINS to the frontend's real origin(s) outside local tinkering.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
|
|
||
|
|
# Models
class Message(BaseModel):
    """A single chat turn.

    ``role`` is forwarded to Ollama as-is (the handlers in this file use
    "user" and "assistant"). ``timestamp`` is an optional ISO-8601 string
    kept for bookkeeping only; it is not part of what is sent to the model.
    """

    role: str  # who spoke this turn
    content: str  # the text of the turn
    timestamp: Optional[str] = None  # when the turn was recorded, if known
|
||
|
|
|
||
|
|
class ChatRequest(BaseModel):
    """Request body for ``POST /chat``.

    ``messages`` is forwarded to the model in the order given. ``model``
    names the Ollama model to query. ``stream`` picks a server-sent-events
    reply (True) or a single JSON object with the whole answer (False).
    """

    messages: List[Message]  # conversation to send to the model
    model: str = "code-expert"  # Ollama model name
    stream: bool = True  # SSE streaming vs. one-shot response
|
||
|
|
|
||
|
|
class CodeExecutionRequest(BaseModel):
    """Request body for ``POST /execute`` and ``POST /analyze-code``.

    ``language`` is not validated here: ``/execute`` only accepts "python",
    "javascript" and "bash", while ``/analyze-code`` simply interpolates it
    into the review prompt.
    """

    code: str  # source text to run or review
    language: str = "python"  # language name, see docstring
|
||
|
|
|
||
|
|
class ConversationHistory:
    """In-memory store of chat messages, keyed by conversation id.

    Nothing is persisted: data lives in this process only and is lost on
    restart. A conversation stays in memory until ``remove_conversation``
    is called for it.
    """

    def __init__(self):
        # conversation id -> messages, in the order they were added
        self.conversations: Dict[str, List["Message"]] = {}

    def add_message(self, conversation_id: str, message: "Message"):
        """Append *message* to *conversation_id*, creating the conversation if needed."""
        self.conversations.setdefault(conversation_id, []).append(message)

    def get_conversation(self, conversation_id: str) -> List["Message"]:
        """Return the messages of *conversation_id*, or a new empty list if unknown.

        For a known id this is the stored list itself, not a copy.
        """
        return self.conversations.get(conversation_id, [])

    def remove_conversation(self, conversation_id: str) -> None:
        """Forget *conversation_id* and its messages; unknown ids are ignored.

        Without this, every conversation ever started stays in memory for the
        lifetime of the process.
        """
        self.conversations.pop(conversation_id, None)


history = ConversationHistory()
|
||
|
|
|
||
|
|
# WebSocket manager
class ConnectionManager:
    """Keeps track of the WebSocket connections accepted by this process."""

    def __init__(self):
        self.active_connections: List["WebSocket"] = []

    async def connect(self, websocket: "WebSocket"):
        """Accept the handshake on *websocket* and start tracking it."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: "WebSocket"):
        """Stop tracking *websocket*.

        Idempotent: a socket that is not tracked (never connected, or
        already removed) is ignored instead of raising ``ValueError``, so
        cleanup paths can call this unconditionally.
        """
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            pass  # already gone; nothing to clean up

    async def send_message(self, message: str, websocket: "WebSocket"):
        """Send *message* to *websocket* as one text frame."""
        await websocket.send_text(message)


manager = ConnectionManager()
|
||
|
|
|
||
|
|
# Routes
@app.get("/")
async def root():
    """Identify the service by returning its name and API version."""
    payload = {"message": "AI Code Assistant API"}
    payload["version"] = "1.0"
    return payload
|
||
|
|
|
||
|
|
@app.get("/models")
async def list_models():
    """List the models available in the local Ollama instance.

    Returns ``{"models": [{"name": ..., "size": ..., "modified": ...}, ...]}``.

    Raises:
        HTTPException: 500 with the error text if Ollama cannot be queried.
    """
    try:
        # ollama.list() is a blocking HTTP call; run it in the default thread
        # pool so it does not stall every other request on the event loop.
        loop = asyncio.get_running_loop()
        models = await loop.run_in_executor(None, ollama.list)
        return {
            "models": [
                {
                    # Newer ollama-python releases (0.4+) expose the model name
                    # as "model" and m['name'] raises KeyError there; older ones
                    # return plain dicts keyed "name". Accept either.
                    "name": m.get('name') or m.get('model'),
                    "size": m.get('size', 0),
                    "modified": m.get('modified_at', ''),
                }
                for m in models['models']
            ]
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
|
|
||
|
|
@app.post("/chat")
async def chat(request: ChatRequest):
    """Chat with an Ollama model, either streamed (SSE) or in one response.

    With ``request.stream`` true the reply is ``text/event-stream``: one
    ``data: {"content": ...}`` event per generated chunk, then
    ``data: [DONE]``. Otherwise the reply is ``{"response": <full text>}``.

    Raises:
        HTTPException: 500 with the error text if Ollama fails, including
        when opening the stream or reading its first chunk.
    """
    messages = [{"role": m.role, "content": m.content} for m in request.messages]
    loop = asyncio.get_running_loop()

    try:
        if not request.stream:
            # Blocking call: run it in the thread pool so other clients keep
            # being served while the model generates.
            response = await loop.run_in_executor(
                None,
                lambda: ollama.chat(model=request.model, messages=messages),
            )
            return {"response": response['message']['content']}

        def open_stream():
            # Read the first chunk before committing to a 200 response, so an
            # unknown model or an unreachable Ollama becomes an HTTP 500
            # instead of a stream that breaks after the headers were sent.
            chunks = iter(ollama.chat(model=request.model, messages=messages, stream=True))
            return chunks, next(chunks, None)

        stream, first_chunk = await loop.run_in_executor(None, open_stream)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    def generate():
        # A plain (sync) generator: StreamingResponse iterates it in a worker
        # thread, so the blocking reads from Ollama don't stall the event loop.
        chunk = first_chunk
        while chunk is not None:
            if 'message' in chunk and 'content' in chunk['message']:
                content = chunk['message']['content']
                yield f"data: {json.dumps({'content': content})}\n\n"
            chunk = next(stream, None)
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
|
||
|
|
|
||
|
|
@app.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket):
    """Real-time chat over a WebSocket, with history kept per connection.

    Each client frame must be JSON text with a ``message`` key and may carry
    a ``model`` key (default ``"code-expert"``). For every frame the server
    replies with JSON frames shaped ``{"type": ..., "content": ...}``:
    one ``status`` ("Processing..."), one ``stream`` per generated chunk,
    then one ``done`` carrying the full reply. On any other error an
    ``error`` frame is sent (if the socket still allows it) and the handler
    returns. The conversation history is discarded when the handler exits.
    """
    await manager.connect(websocket)
    # A random id rather than a timestamp: two clients connecting within the
    # same clock tick would otherwise share one conversation history.
    conversation_id = uuid.uuid4().hex
    loop = asyncio.get_running_loop()
    end_of_stream = object()  # sentinel returned by next() when Ollama is done

    try:
        while True:
            # Receive the client's next message
            data = await websocket.receive_text()
            request_data = json.loads(data)

            user_message = Message(
                role="user",
                content=request_data['message'],
                timestamp=datetime.now().isoformat(),
            )
            history.add_message(conversation_id, user_message)

            # The model sees the whole conversation so far
            conversation = history.get_conversation(conversation_id)
            messages = [{"role": m.role, "content": m.content} for m in conversation]

            # Acknowledge receipt before generation starts
            await manager.send_message(
                json.dumps({"type": "status", "content": "Processing..."}),
                websocket,
            )

            model = request_data.get('model', 'code-expert')
            full_response = ""

            # Opening the stream and reading each chunk both block on network
            # I/O, so they run in the thread pool; otherwise one long
            # generation would freeze every other connection on this loop.
            stream = await loop.run_in_executor(
                None,
                lambda: iter(ollama.chat(model=model, messages=messages, stream=True)),
            )
            while True:
                chunk = await loop.run_in_executor(None, next, stream, end_of_stream)
                if chunk is end_of_stream:
                    break
                if 'message' in chunk and 'content' in chunk['message']:
                    content = chunk['message']['content']
                    full_response += content
                    await manager.send_message(
                        json.dumps({"type": "stream", "content": content}),
                        websocket,
                    )

            # Save the complete reply so the next turn has context
            assistant_message = Message(
                role="assistant",
                content=full_response,
                timestamp=datetime.now().isoformat(),
            )
            history.add_message(conversation_id, assistant_message)

            # End-of-reply signal
            await manager.send_message(
                json.dumps({"type": "done", "content": full_response}),
                websocket,
            )

    except WebSocketDisconnect:
        pass  # client went away; cleanup happens in `finally`
    except Exception as e:
        try:
            await manager.send_message(
                json.dumps({"type": "error", "content": str(e)}),
                websocket,
            )
        except Exception:
            # The socket may already be unusable. The error has been reported
            # as far as possible, and cleanup below must still run.
            pass
    finally:
        manager.disconnect(websocket)
        # History is per-connection, so drop it to keep memory bounded.
        history.conversations.pop(conversation_id, None)
|
||
|
|
|
||
|
|
# language -> (temp-file suffix, interpreter command) accepted by /execute
_EXECUTION_COMMANDS = {
    "python": (".py", "python3"),
    "javascript": (".js", "node"),
    "bash": (".sh", "bash"),
}


@app.post("/execute")
async def execute_code(request: CodeExecutionRequest):
    """Run a code snippet in a subprocess and return its output.

    SECURITY WARNING: this is NOT sandboxed. The snippet runs with this
    server's user privileges; the only limit is a 10-second timeout. Anyone
    who can reach this endpoint can execute arbitrary commands.

    Returns ``{"stdout", "stderr", "returncode", "success"}``.

    Raises:
        HTTPException: 400 for an unsupported ``language``; 408 if the program
        runs longer than 10 seconds; 500 for any other failure (e.g. the
        interpreter is not installed).
    """
    # Validate before touching the filesystem. Raising this inside the try
    # below would be caught by `except Exception` and turned into a 500.
    if request.language not in _EXECUTION_COMMANDS:
        raise HTTPException(status_code=400, detail=f"Language {request.language} not supported")
    suffix, interpreter = _EXECUTION_COMMANDS[request.language]

    temp_file = None
    try:
        # delete=False: the file must outlive the `with` so the interpreter can
        # open it after it is closed (and flushed); it is removed in `finally`.
        with tempfile.NamedTemporaryFile(
            mode='w',
            suffix=suffix,
            delete=False,
            encoding='utf-8',  # python3 and node read sources as UTF-8
        ) as f:
            temp_file = f.name
            f.write(request.code)

        # subprocess.run can block for up to 10 s; keep it off the event loop.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            lambda: subprocess.run(
                [interpreter, temp_file],
                capture_output=True,
                text=True,
                timeout=10,
            ),
        )
    except subprocess.TimeoutExpired:
        raise HTTPException(status_code=408, detail="Code execution timeout")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if temp_file is not None and os.path.exists(temp_file):
            os.unlink(temp_file)

    return {
        "stdout": result.stdout,
        "stderr": result.stderr,
        "returncode": result.returncode,
        "success": result.returncode == 0,
    }
|
||
|
|
|
||
|
|
@app.post("/analyze-code")
async def analyze_code(request: CodeExecutionRequest):
    """Ask the ``code-reviewer`` Ollama model to review a snippet.

    The (French) prompt asks for a short description of the code, potential
    problems (bugs, security, performance) and improvement suggestions.
    Returns ``{"analysis": <model reply>}``.

    Raises:
        HTTPException: 500 with the error text if the model call fails.
    """
    prompt = f'''Analyse ce code {request.language}:

```{request.language}
{request.code}
```

Fournis:
1. Une brève description de ce que fait le code
2. Les problèmes potentiels (bugs, sécurité, performance)
3. Des suggestions d'amélioration
'''
    try:
        # Generation can take a long time; run the blocking client call in
        # the thread pool instead of freezing the event loop.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: ollama.chat(
                model='code-reviewer',
                messages=[{'role': 'user', 'content': prompt}],
            ),
        )
        return {"analysis": response['message']['content']}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
|
|
||
|
|
@app.post("/fix-code")
async def fix_code(code: str, error: str, language: str = "python"):
    """Ask the ``debugger`` Ollama model to fix *code* given its *error*.

    Note: FastAPI reads plain ``str`` parameters like these from the query
    string, not from a JSON body, so very long snippets may run into URL
    length limits.

    Returns ``{"fixed_code": <model reply>}``. This is the raw reply to a
    prompt asking the model to identify the problem and give corrected code,
    so it may contain explanation as well as code.

    Raises:
        HTTPException: 500 with the error text if the model call fails.
    """
    prompt = f'''Code {language} avec erreur:

```{language}
{code}
```

Erreur:
```
{error}
```

Identifie le problème et fournis le code corrigé.
'''
    try:
        # Blocking client call: keep it off the event loop.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: ollama.chat(
                model='debugger',
                messages=[{'role': 'user', 'content': prompt}],
            ),
        )
        return {"fixed_code": response['message']['content']}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
|
|
||
|
|
@app.get("/health")
async def health_check():
    """Report whether the Ollama server is reachable.

    Always answers with HTTP 200. The body is
    ``{"status": "healthy", "ollama": "connected"}`` when ``ollama.list()``
    succeeds, and ``{"status": "unhealthy", "error": <text>}`` otherwise.
    """
    try:
        # A slow or hung Ollama connection must not block the event loop.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, ollama.list)
    except Exception as e:
        return {"status": "unhealthy", "error": str(e)}
    return {"status": "healthy", "ollama": "connected"}
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Development server. reload=True requires the app as an import string;
    # "main:app" assumes this module is saved as main.py. host 0.0.0.0 listens
    # on all interfaces, which also exposes /execute to the network.
    server_options = dict(
        host="0.0.0.0",
        port=9001,
        reload=True,
        ws_ping_interval=20,
        ws_ping_timeout=20,
    )
    uvicorn.run("main:app", **server_options)
|