#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
LLM Inference Module
A simplified module for making inferences with various LLM providers

Requirements:
pip install -r requirements.txt
'''

import os
import sys
import time

import requests
import ollama
import google.generativeai as genai
from huggingface_hub import InferenceClient
from together import Together
from groq import Groq
from openai import OpenAI  # Used for both NVIDIA and GitHub endpoints

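# Example: API keys are read from the environment when this module is imported, so they
# must be set beforehand. A minimal sketch (the key value and module name are placeholders):
#
#   import os
#   os.environ["GROQ_API_KEY"] = "<your-key>"   # must happen before importing this module
#   import inference                            # hypothetical name for this file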
CONFIG = {
    "api_keys": {
        "HF_API_KEY": os.environ.get("HF_API_KEY"),
        "TOGETHER_API_KEY": os.environ.get("TOGETHER_API_KEY"),
        "GEMINI_API_KEY": os.environ.get("GEMINI_API_KEY"),
        "AIQL_API_KEY": os.environ.get("AIQL_API_KEY"),
        "GROQ_API_KEY": os.environ.get("GROQ_API_KEY"),
        "NVIDIA_API_KEY": os.environ.get("NVIDIA_API_KEY"),
        "GITHUB_TOKEN": os.environ.get("GITHUB_TOKEN")
    },
    "models": {
        "aiql": {
            "Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct",
            "Llama-3.3-70B-Chat": "meta-llama/Llama-3.3-70B-Chat"
        },
        "together": {
            "DeepSeek-70B": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free",
            "Llama-3-3-70B-Turbo": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"
        },
        "gemini": {
            "gemini-2.5-pro-preview": "gemini-2.5-pro-preview-03-25",
            "gemini-2.5-flash-preview": "gemini-2.5-flash-preview-04-17",
            "gemini-1.5-flash": "gemini-1.5-flash",
            "gemini-1.5-pro": "gemini-1.5-pro",
            "gemini-1.5-flash-002": "gemini-1.5-flash-002",
            "gemini-1.5-flash-001": "gemini-1.5-flash-001",
            "gemini-1.5-pro-002": "gemini-1.5-pro-002",
            "gemini-1.5-pro-001": "gemini-1.5-pro-001",
            "gemini-2.0-flash": "gemini-2.0-flash",
            "gemini-2.0-flash-exp": "gemini-2.0-flash-exp",
            "gemini-2.0-flash-thinking-exp-01-21": "gemini-2.0-flash-thinking-exp-01-21"
        },
        "hf": {
            "DeepSeek-R1-Distill-Qwen-32B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
            "Qwen2.5-Coder-32B": "Qwen/Qwen2.5-Coder-32B-Instruct"
        },
        "ollama": [
            "falcon3:10b"
        ],
        "groq": {
            "llama-3.3-70b-versatile": "llama-3.3-70b-versatile",
            "deepseek-r1-distill-llama-70b": "deepseek-r1-distill-llama-70b"
        },
        "nvidia": {
            "qwen2.5-coder-32b": "qwen/qwen2.5-coder-32b-instruct",
            "llama2-70b": "meta-llama/llama-2-70b-chat",
            "mixtral-8x7b": "mistralai/mixtral-8x7b-instruct",
            "yi-34b": "01-ai/yi-34b-chat"
        },
        "github": {
            "gpt-4o": "gpt-4o",
            "gpt-4o-mini": "gpt-4o-mini",
            "mistral-small": "mistral-small-2503",
            "deepseek-v3": "deepseek-v3",
            "phi-4": "phi-4",
            "llama-3.3-70b": "llama-3.3-70b-instruct"
        }
    },
    "defaults": {
        "ollama": "llama3.2:3b",
        "hf": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
        "together": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free",
        "gemini": "gemini-1.5-flash",
        "aiql": "meta-llama/Llama-3.3-70B-Instruct",
        "groq": "llama-3.3-70b-versatile",
        "nvidia": "qwen/qwen2.5-coder-32b-instruct",
        "github": "gpt-4o"
    }
}

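# Example: CONFIG["models"] maps a friendly display name to the provider's model ID, and
# CONFIG["defaults"] picks the model used when none is specified (a sketch of the lookups):
#
#   CONFIG["models"]["github"]["mistral-small"]   # -> "mistral-small-2503"
#   CONFIG["defaults"]["groq"]                    # -> "llama-3.3-70b-versatile"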
class InferenceHandler:
    @staticmethod
    def preload_ollama_model(model: str):
        """Preload an Ollama model by sending a simple query to it"""
        try:
            print(f"Loading Ollama model {model}...")
            client = ollama.Client()

            # Send a simple query to load the model into memory
            client.chat(model=model, messages=[{'role': 'user', 'content': 'hello'}])
            print(f"Model {model} loaded successfully")
            return True
        except Exception as e:
            print(f"Failed to preload model {model}: {str(e)}")
            return False

    @staticmethod
    def ollama(prompt: str, model: str, system_content: str = None, preload: bool = False) -> str:
        try:
            client = ollama.Client()

            # Preload the model if requested
            if preload:
                InferenceHandler.preload_ollama_model(model)

            # If system_content is provided, use the chat API with messages
            if system_content:
                messages = [
                    {'role': 'system', 'content': system_content},
                    {'role': 'user', 'content': prompt}
                ]
                response = client.chat(model=model, messages=messages)

                # Add response validation
                if not response or 'message' not in response or 'content' not in response['message']:
                    return "Error: Empty response from Ollama chat API"

                return response['message']['content']
            else:
                # Use the generate API without system content
                response = client.generate(model=model, prompt=prompt)

                # Add response validation
                if not response or 'response' not in response:
                    return "Error: Empty response from Ollama generate API"

                return response['response']

        except Exception as e:
            error_msg = str(e)
            if "connection" in error_msg.lower():
                return "Error: Could not connect to Ollama server"
            return f"Ollama Error: {error_msg}"

    @staticmethod
    def hf(prompt: str, model: str) -> str:
        try:
            client = InferenceClient(token=CONFIG['api_keys']['HF_API_KEY'])
            response = client.text_generation(prompt, model=model)

            # Add response validation
            if not response or response.isspace():
                return "Error: Empty response from HuggingFace"

            return response

        except Exception as e:
            # Improve error message
            error_msg = str(e)
            if "Expecting value" in error_msg:
                return "Error: Invalid response format from HuggingFace API"
            return f"HF Error: {error_msg}"

    @staticmethod
    def together(prompt: str, model: str) -> str:
        try:
            client = Together(api_key=CONFIG['api_keys']['TOGETHER_API_KEY'])
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=2048  # Add reasonable token limit
            )

            # Add response validation
            if not response or not response.choices:
                return "Error: Empty response from Together"

            return response.choices[0].message.content

        except Exception as e:
            error_msg = str(e)
            if "authentication" in error_msg.lower():
                return "Error: Invalid Together API key"
            return f"Together Error: {error_msg}"

    @staticmethod
    def gemini(prompt: str, model: str) -> str:
        try:
            genai.configure(api_key=CONFIG['api_keys']['GEMINI_API_KEY'])
            # Use a separate name so the model parameter is not shadowed
            gen_model = genai.GenerativeModel(model)
            response = gen_model.generate_content(prompt)

            # Add response validation
            if not response or not response.text:
                return "Error: Empty response from Gemini"

            return response.text

        except Exception as e:
            error_msg = str(e)
            if "invalid" in error_msg.lower() and "model" in error_msg.lower():
                return "Error: Invalid Gemini model"
            return f"Gemini Error: {error_msg}"

    @staticmethod
    def aiql(prompt: str, model: str) -> str:
        try:
            headers = {
                "Authorization": f"Bearer {CONFIG['api_keys']['AIQL_API_KEY']}",
                "Content-Type": "application/json"
            }
            data = {
                "model": model,
                "messages": [{"role": "user", "content": prompt}]
            }
            response = requests.post(
                "https://ai.aiql.com/v1/chat/completions",
                headers=headers,
                json=data,
                timeout=60  # Avoid hanging indefinitely on a stalled connection
            )

            # Add response validation
            if not response or response.status_code != 200:
                return f"Error: API request failed with status {response.status_code}"

            response_json = response.json()
            if not response_json:
                return "Error: Invalid response format from AIQL"

            # Try different response formats
            if 'choices' in response_json:
                return response_json['choices'][0]['message']['content']
            elif 'response' in response_json:
                return response_json['response']
            elif 'content' in response_json:
                return response_json['content']
            else:
                return "Error: Could not find response content in AIQL response"

        except Exception as e:
            error_msg = str(e)
            if "Expecting value" in error_msg:
                return "Error: Invalid response format from AIQL API"
            return f"AIQL Error: {error_msg}"

    @staticmethod
    def groq(prompt: str, model: str) -> str:
        try:
            client = Groq(api_key=CONFIG['api_keys']['GROQ_API_KEY'])
            response = client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": "you are a helpful assistant."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                model=model,
                temperature=0.7,
                max_completion_tokens=2048,
                top_p=1,
                stream=False
            )

            # Add response validation
            if not response or not response.choices:
                return "Error: Empty response from Groq"

            return response.choices[0].message.content

        except Exception as e:
            error_msg = str(e)
            if "authentication" in error_msg.lower():
                return "Error: Invalid Groq API key"
            return f"Groq Error: {error_msg}"

    @staticmethod
    def nvidia(prompt: str, model: str) -> str:
        try:
            # Resolve a friendly model name to the provider's model ID;
            # the API expects the ID, not the display name
            model_id = model
            if model in CONFIG['models']['nvidia']:
                model_id = CONFIG['models']['nvidia'][model]

            print(f"NVIDIA: Initializing client with model {model} (ID: {model_id})")
            client = OpenAI(
                base_url="https://integrate.api.nvidia.com/v1",
                api_key=CONFIG['api_keys']['NVIDIA_API_KEY']
            )

            print(f"NVIDIA: Sending request to model {model_id}")
            completion = client.chat.completions.create(
                model=model_id,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.2,
                top_p=0.7,
                max_tokens=1024
            )

            # Add response validation
            if not completion or not completion.choices:
                print("NVIDIA: Empty response received")
                return "Error: Empty response from NVIDIA API"

            response_content = completion.choices[0].message.content
            print(f"NVIDIA: Response received, length: {len(response_content)}")
            return response_content

        except Exception as e:
            error_msg = str(e)
            print(f"NVIDIA Error: {error_msg}")
            if "authentication" in error_msg.lower():
                return "Error: Invalid NVIDIA API key"
            return f"NVIDIA Error: {error_msg}"

    @staticmethod
    def github(prompt: str, model: str) -> str:
        try:
            # GitHub endpoint for OpenAI API
            ENDPOINT = "https://models.inference.ai.azure.com"

            # Get the actual model ID from the models dictionary
            model_id = model
            if model in CONFIG['models']['github']:
                model_id = CONFIG['models']['github'][model]

            print(f"GitHub: Initializing client with model {model} (ID: {model_id})")

            client = OpenAI(
                base_url=ENDPOINT,
                api_key=CONFIG['api_keys']['GITHUB_TOKEN']
            )

            print(f"GitHub: Sending request to model {model_id}")
            response = client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=model_id,
                max_tokens=1024,
                temperature=0.7
            )

            # Add response validation
            if not response or not response.choices:
                print("GitHub: Empty response received")
                return "Error: Empty response from GitHub API"

            response_content = response.choices[0].message.content
            print(f"GitHub: Response received, length: {len(response_content)}")
            return response_content

        except Exception as e:
            error_msg = str(e)
            print(f"GitHub Error: {error_msg}")
            if "authentication" in error_msg.lower():
                return "Error: Invalid GitHub token"
            return f"GitHub Error: {error_msg}"

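# The handlers can also be called directly, bypassing run_inference (a minimal sketch;
# assumes the matching API key is set in the environment):
#
#   text = InferenceHandler.groq("Why is the sky blue?", CONFIG["defaults"]["groq"])
#   print(text)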
def get_available_models():
    """Returns a dictionary of all available models"""
    return CONFIG['models']

def get_default_models():
    """Returns a dictionary of default models for each provider"""
    return CONFIG['defaults']

def get_ollama_models():
    """Get available Ollama models from local server using subprocess"""
    try:
        import subprocess

        # Execute the shell command and capture the output
        result = subprocess.run(['ollama', 'list'], capture_output=True, text=True)

        # Check if the command was successful
        if result.returncode == 0:
            # Split the output into lines and skip the first line (header)
            lines = result.stdout.strip().split('\n')[1:]

            # Extract the first field from each line (model name)
            models = [line.split()[0] for line in lines]
            return models
        else:
            print(f"Error executing 'ollama list': {result.stderr}")
            return CONFIG['models']['ollama']
    except Exception as e:
        print(f"Exception in get_ollama_models: {str(e)}")
        return CONFIG['models']['ollama']

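# Note: the Python client could be used instead of shelling out, along the lines of the
# sketch below. The exact shape of the list() response depends on the installed ollama
# package version, so this is illustrative rather than a drop-in replacement:
#
#   names = [m["name"] for m in ollama.Client().list()["models"]]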
def check_provider_key_available(provider):
    """Check if the API key for a specific provider is available.

    Args:
        provider (str): The provider to check

    Returns:
        bool: True if the key is available, False otherwise
    """
    # Ollama is a local service, so no API key is needed
    if provider == "ollama":
        try:
            client = ollama.Client()
            client.list()  # Raises if the local server is not reachable
            return True
        except Exception:
            return False

    # For other providers, check if the API key is available
    key_mapping = {
        "hf": "HF_API_KEY",
        "together": "TOGETHER_API_KEY",
        "gemini": "GEMINI_API_KEY",
        "aiql": "AIQL_API_KEY",
        "groq": "GROQ_API_KEY",
        "nvidia": "NVIDIA_API_KEY",
        "github": "GITHUB_TOKEN"
    }

    if provider not in key_mapping:
        return False

    key_name = key_mapping[provider]
    return bool(CONFIG["api_keys"][key_name])

def run_inference(prompt, provider=None, model=None, system_content=None):
    """Run inference with specified provider and model.

    Args:
        prompt (str): The prompt to send to the model
        provider (str, optional): The provider to use (ollama, hf, together, gemini, aiql, groq, nvidia, github)
        model (str, optional): The specific model to use
        system_content (str, optional): Custom system role content for models that support it

    Returns:
        str: The model's response
    """
    # If no provider specified, use the first available one
    if not provider:
        available = check_available_apis()
        if not available:
            return "Error: No available providers found. Please check your API keys and Ollama installation."
        provider = available[0]

    # Check if the API key for the provider is available
    if not check_provider_key_available(provider):
        return f"Error: API key for {provider} is not available. Please set the appropriate environment variable."

    # If no model specified, use the default for the provider
    if not model:
        model = CONFIG["defaults"][provider]

    # For ollama, we need to check if the model exists locally
    if provider == "ollama" and model not in get_ollama_models():
        return f"Error: Model '{model}' not found in Ollama. Please pull it first with 'ollama pull {model}'."

    print(f"Running inference with provider: {provider}, model: {model}")
    print(f"Prompt: {prompt[:50]}..." if len(prompt) > 50 else f"Prompt: {prompt}")
    start_time = time.time()

    # Call the appropriate provider method
    try:
        if provider == "ollama":
            response = InferenceHandler.ollama(prompt, model, system_content)
        elif provider == "hf":
            response = InferenceHandler.hf(prompt, model)
        elif provider == "together":
            response = InferenceHandler.together(prompt, model)
        elif provider == "gemini":
            response = InferenceHandler.gemini(prompt, model)
        elif provider == "aiql":
            response = InferenceHandler.aiql(prompt, model)
        elif provider == "groq":
            response = InferenceHandler.groq(prompt, model)
        elif provider == "nvidia":
            print(f"Calling NVIDIA handler with model: {model}")
            response = InferenceHandler.nvidia(prompt, model)
        elif provider == "github":
            print(f"Calling GitHub handler with model: {model}")
            response = InferenceHandler.github(prompt, model)
        else:
            return f"Error: Unknown provider '{provider}'"

        end_time = time.time()
        print(f"Inference completed in {end_time - start_time:.2f} seconds")
        return response
    except Exception as e:
        print(f"Error during inference: {str(e)}")
        return f"Error with {provider}: {str(e)}"

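# Example usage when importing this module (a minimal sketch; the module name is a
# placeholder and a local Ollama server with the default model pulled is assumed):
#
#   from inference import run_inference          # hypothetical module name
#   reply = run_inference("Why is the sky blue?", provider="ollama",
#                         model="llama3.2:3b",
#                         system_content="You are a helpful assistant.")
#   print(reply)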
def check_available_apis():
    """Check which API tokens are available in the environment and return available providers."""
    available = []

    # Check Ollama by attempting to connect
    try:
        client = ollama.Client()
        models = client.list()
        if models:
            available.append("ollama")
    except Exception as e:
        print(f"Ollama not available: {e}")

    # Check API keys
    if CONFIG["api_keys"]["HF_API_KEY"]:
        available.append("hf")

    if CONFIG["api_keys"]["TOGETHER_API_KEY"]:
        available.append("together")

    if CONFIG["api_keys"]["GEMINI_API_KEY"]:
        available.append("gemini")

    if CONFIG["api_keys"]["AIQL_API_KEY"]:
        available.append("aiql")

    if CONFIG["api_keys"]["GROQ_API_KEY"]:
        available.append("groq")

    if CONFIG["api_keys"]["NVIDIA_API_KEY"]:
        print("NVIDIA API key found")
        available.append("nvidia")

    if CONFIG["api_keys"]["GITHUB_TOKEN"]:
        print("GitHub token found")
        available.append("github")

    return available

def print_available_apis():
    """Print information about available APIs and possible requests"""
    available_providers = check_available_apis()

    print("\n" + "=" * 60)
    print("AVAILABLE API PROVIDERS")
    print("=" * 60)

    if not available_providers:
        print("\nNo API providers are available. Please set environment variables for API keys:")
        for key in CONFIG['api_keys'].keys():
            print(f" - {key}")
        print("\nOr start Ollama locally to use local models.")
        return False

    print(f"\nFound {len(available_providers)} available API providers:\n")

    for provider in available_providers:
        print(f"- {provider.upper()}:")

        # Special handling for Ollama to show actual local models
        if provider == "ollama":
            ollama_models = get_ollama_models()
            for model in ollama_models:
                print(f" - {model}")
        else:
            models = CONFIG['models'][provider]

            if isinstance(models, dict):
                for model_name, model_id in models.items():
                    print(f" - {model_name} ({model_id})")
            else:  # It's a list
                for model in models:
                    print(f" - {model}")

    print("\n" + "=" * 60)
    return True

def main():
    """Command-line entry point: run a prompt on a single provider, or on all available providers and models."""
    import argparse

    parser = argparse.ArgumentParser(description='LLM Inference Module')
    parser.add_argument('prompt', nargs='?', type=str, help='The prompt to send to the model', default="Why is the sky blue?")
    parser.add_argument('--provider', type=str, help='The provider to use (ollama, hf, together, gemini, aiql, groq, nvidia, github)')
    parser.add_argument('--model', type=str, help='The specific model to use')
    parser.add_argument('--system', type=str, help='System content for chat models', default="You are a helpful assistant.")
    parser.add_argument('--list', action='store_true', help='List available providers and models')
    parser.add_argument('--debug', action='store_true', help='Enable debug output')
    parser.add_argument('-a', '--all', action='store_true', help='Run inference on all available providers and models')

    args = parser.parse_args()

    # Check if the specified provider's API key is available
    if args.provider and not check_provider_key_available(args.provider):
        print(f"Error: API key for {args.provider} is not available. Please set the appropriate environment variable.")
        return

    if args.list:
        print_available_apis()
        return

    # If a provider is specified but no model, use the default model for that provider
    if args.provider and not args.model:
        args.model = CONFIG["defaults"][args.provider]

    if args.debug:
        print(f"Running with provider: {args.provider}, model: {args.model}")
        print(f"Prompt: {args.prompt[:50]}..." if len(args.prompt) > 50 else f"Prompt: {args.prompt}")

    # If a specific provider is given and -a/--all is not set, run only on that provider
    if args.provider and not args.all:
        start_time = time.time()
        response = run_inference(args.prompt, args.provider, args.model, args.system)
        end_time = time.time()

        if args.debug:
            print(f"Inference completed in {end_time - start_time:.2f} seconds")

        # Print the response
        print("\nResponse:")
        print(response)
        return

    # Otherwise (-a/--all given, or no provider specified), run on all providers and models
    print(f"\nPrompt: {args.prompt}\n")
    print("Running inference on all models for each provider...\n")

    # Get available providers (only those with API keys)
    available_providers = check_available_apis()

    # Store response times for leaderboard
    response_times = []

    # Import colorama for colored terminal output
    try:
        from colorama import init, Fore, Style
        init()  # Initialize colorama
        color_enabled = True
    except ImportError:
        color_enabled = False
        print("Note: Install 'colorama' package for colored error messages (pip install colorama)")

    # Run inference on each provider with all its models
    for provider in available_providers:
        print(f"\n{'=' * 30}\n{provider.upper()} MODELS\n{'=' * 30}\n")

        # Special handling for Ollama to use actual local models
        if provider == "ollama":
            ollama_models = get_ollama_models()
            model_items = [(model, model) for model in ollama_models]
        else:
            models = CONFIG['models'][provider]
            # Handle different model formats (list vs dict)
            if isinstance(models, dict):
                model_items = list(models.items())
            else:  # It's a list
                model_items = [(model, model) for model in models]

        for model_name, model_id in model_items:
            try:
                print(f"\n----- {model_name} -----\n")

                # For Ollama models, preload the model first
                if provider == "ollama":
                    # Preload the model with a dummy query
                    InferenceHandler.preload_ollama_model(model_id)
                    print("Warming up model...")
                    time.sleep(1)  # Short pause for UI feedback

                # Start timing after preloading
                start_time = time.time()

                # Run inference with the user's prompt and system message
                response = run_inference(args.prompt, provider, model_id, args.system)

                # End timing
                end_time = time.time()
                elapsed_time = end_time - start_time

                # Store response time for leaderboard
                full_model_name = f"{provider}/{model_name}"
                response_times.append((full_model_name, elapsed_time))

                # Print response and timing information
                print(response)
                print(f"\nResponse time: {elapsed_time:.2f} seconds")
                print("\n" + "-" * 50)
            except Exception as e:
                # Print error in red if colorama is available
                if color_enabled:
                    error_msg = f"{Fore.RED}Error with {provider}/{model_name}: {str(e)}{Style.RESET_ALL}"
                else:
                    error_msg = f"Error with {provider}/{model_name}: {str(e)}"
                print(error_msg)
                print("\n" + "-" * 50)

    # Only display leaderboard if we have results
    if response_times:
        # Display leaderboard
        print("\n" + "=" * 50)
        print("RESPONSE TIME LEADERBOARD")
        print("=" * 50)

        # Sort by response time (fastest first)
        response_times.sort(key=lambda x: x[1])

        # Print leaderboard
        print(f"{'Rank':<6}{'Model':<40}{'Time (seconds)':<15}")
        print("-" * 61)
        for i, (model, time_taken) in enumerate(response_times, 1):
            print(f"{i:<6}{model:<40}{time_taken:.2f}")

        print("\n" + "=" * 50)
    else:
        print("\nNo successful responses to display in leaderboard.")

if __name__ == "__main__":
|
|
main() |
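# Example CLI invocations (a sketch; the script filename is an assumption):
#
#   python inference.py --list
#   python inference.py "Why is the sky blue?" --provider groq
#   python inference.py --all --debug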