import ollama
import time
from typing import List, Dict, Any
import json
from statistics import mean
import re
import ast
import argparse
import requests
import os
from together import Together
from cpuinfo import get_cpu_info
import subprocess

# Global verbose flag; overridden by the --verbose CLI argument in main()
verbose = False


# ANSI color codes
SUCCESS = '\033[38;5;78m'   # Soft mint green for success
ERROR = '\033[38;5;203m'    # Soft coral red for errors
INFO = '\033[38;5;75m'      # Sky blue for info
HEADER = '\033[38;5;147m'   # Soft purple for headers
WARNING = '\033[38;5;221m'  # Warm gold for warnings
EMPHASIS = '\033[38;5;159m' # Cyan for emphasis
MUTED = '\033[38;5;246m'    # Subtle gray for less important text
ENDC = '\033[0m'
BOLD = '\033[1m'

# Color aliases used throughout the script
GREEN = SUCCESS
RED = ERROR
BLUE = INFO
YELLOW = WARNING
WHITE = MUTED

# Server configurations
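# 'local' is the default Ollama endpoint; 'z60' is a second host on the local network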
SERVERS = {
    'local': 'http://localhost:11434',
    'z60': 'http://192.168.196.60:11434'
}

class Timer:
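    """Minimal wall-clock timer used to time each model request."""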
    def __init__(self):
        self.start_time = None
        self.end_time = None

    def start(self):
        self.start_time = time.time()

    def stop(self):
        self.end_time = time.time()

    def elapsed_time(self):
        if self.start_time is None:
            return 0
        if self.end_time is None:
            return time.time() - self.start_time
        return self.end_time - self.start_time

def extract_code_from_response(response: str) -> str:
    """Extract Python code from a markdown-formatted string."""
    code_blocks = re.findall(r'```python\n(.*?)```', response, re.DOTALL)
    if code_blocks:
        return code_blocks[0].strip()
    return response

def is_valid_python(code: str) -> bool:
    """Check if the code is valid Python syntax."""
    try:
        ast.parse(code)
        return True
    except SyntaxError:
        return False

def analyze_failed_code(code: str, test_case: Any, expected: Any, actual: Any, function_name: str, model: str) -> bool:
    """Analyze why code failed using Together API. Returns True if Together thinks the code should work."""
    prompt = f"""Analyze this Python code and explain why it failed the test case. Format your response EXACTLY as follows:

ASSESSMENT: [Write a one-line assessment: either "SHOULD PASS" or "SHOULD FAIL" followed by a brief reason]

ANALYSIS:
[Detailed analysis of why the code failed and how to fix it]

Code:
{code}

Test case:
Input: {test_case}
Expected output: {expected}
Actual output: {actual}
Function name required: {function_name}
Model: {model}"""

    try:
        TOGETHER_API_KEY = os.environ["TOGETHER_API_KEY"]
        together_client = Together(api_key=TOGETHER_API_KEY)
        response = together_client.chat.completions.create(
            model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
            messages=[
                {"role": "system", "content": "You are a Python expert analyzing code failures. Always format your response with ASSESSMENT and ANALYSIS sections."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=1000,
            temperature=0.7,
            top_p=0.7,
            top_k=50,
            repetition_penalty=1,
            stop=["<|eot_id|>", "<|eom_id|>"]
        )
        
        analysis = response.choices[0].message.content
        should_pass = "SHOULD PASS" in analysis.upper()
        if verbose: print(f"\n{BLUE}[{model}] Together Analysis:{ENDC}")
        if verbose: print(f"{GREEN if should_pass else RED}{analysis}{ENDC}")
        return should_pass
    except Exception as e:
        print(f"\n{RED}Error getting Together API analysis: {e}{ENDC}")
        return False

def validate_with_debug(code: str, function_name: str, test_cases: List[tuple], model: str) -> tuple[bool, str, List[bool]]:
    """Validate code with detailed debug information. Returns (success, debug_info, test_results)"""
    debug_info = []
    test_results = []  # Track individual test case results
    test_outputs = []  # Store test outputs for combined display
    
    try:
        # Create a local namespace
        namespace = {}
        debug_info.append(f"Executing code:\n{code}")
        
        try:
            # Redirect stdout to capture prints from the executed code
            import io
            import sys
            stdout = sys.stdout
            sys.stdout = io.StringIO()
            
            # Execute the code
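            # Note: exec() runs model-generated code in-process without sandboxing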
            exec(code, namespace)
            
            # Restore stdout
            sys.stdout = stdout
            
        except Exception as e:
            if 'sys' in locals():  # Restore stdout if it was changed
                sys.stdout = stdout
            if verbose: print(f"\n{RED}Failed code:{ENDC}\n{code}")
            return False, f"Error executing code: {str(e)}", test_results
            
        if function_name not in namespace:
            if verbose: print(f"\n{RED}Failed code:{ENDC}\n{code}")
            together_opinion = analyze_failed_code(code, "N/A", f"Function named '{function_name}'", 
                                                f"Found functions: {list(namespace.keys())}", function_name, model)
            print(f"\nTests passed: ❌ Together opinion: {'✅' if together_opinion else '❌'}")
            return False, f"Function '{function_name}' not found in code. Available names: {list(namespace.keys())}", test_results
            
        function = namespace[function_name]
        debug_info.append(f"Function {function_name} found")
        
        # Run test cases
        all_passed = True
        for i, (test_input, expected) in enumerate(test_cases):
            try:
                # Redirect stdout for each test case
                stdout = sys.stdout
                sys.stdout = io.StringIO()
                
                if isinstance(test_input, tuple):
                    result = function(*test_input)
                else:
                    result = function(test_input)
                
                # Restore stdout
                sys.stdout = stdout
                
                # Store result but don't print individually
                test_outputs.append(str(result))
                test_passed = result == expected
                test_results.append(test_passed)
                
                if not test_passed:
                    if verbose: print(f"\n{RED}Failed code:{ENDC}\n{code}")
                    print(f"\n{RED}Test case {i+1} failed:{ENDC}")
                    print(f"Input: {test_input} Expected: {expected} Got: {result}")
                    
                    together_opinion = analyze_failed_code(code, test_input, expected, result, function_name, model)
                    print(f"Tests passed: ❌ Together opinion: {'✅' if together_opinion else '❌'}")
                    
                    all_passed = False
                    continue
                
                debug_info.append(f"Test case {i+1} passed: {test_input} → {result}")
            except Exception as e:
                if 'sys' in locals():  # Restore stdout if it was changed
                    sys.stdout = stdout
                test_outputs.append(f"Error: {str(e)}")
                if verbose: print(f"\n{RED}Failed code:{ENDC}\n{code}")
                print(f"\n{RED}{str(e)} in test case {i+1} Input: {test_input} Expected: {expected}")
                
                together_opinion = analyze_failed_code(code, test_input, expected, f"Error: {str(e)}", function_name, model)
                print(f"Tests passed: ❌ Together opinion: {'✅' if together_opinion else '❌'}")
                
                test_results.append(False)
                all_passed = False
                continue
            finally:
                if 'sys' in locals():  # Always restore stdout
                    sys.stdout = stdout
        
        # Print all test outputs on one line
        print(f"{WHITE}Test outputs: {', '.join(test_outputs)}{ENDC}")
                
        if all_passed:
            print(f"Tests passed: ✅")
            return True, "All tests passed!\n" + "\n".join(debug_info), test_results
        print(f"Tests passed: ❌")
        return False, "Some tests failed", test_results
    except Exception as e:
        if 'sys' in locals():  # Restore stdout if it was changed
            sys.stdout = stdout
        print(f"\n{RED}Error in validate_with_debug: {str(e)}{ENDC}")
        return False, f"Unexpected error: {str(e)}", test_results

def test_fibonacci():
    question = """Write a Python function named EXACTLY 'fibonacci' (not fibonacci_dp or any other name) that returns the nth Fibonacci number.
The function signature must be: def fibonacci(n)

Requirements:
1. Handle edge cases:
   - For n = 0, return 0
   - For n = 1 or n = 2, return 1
   - For negative numbers, return -1
2. For n > 2: F(n) = F(n-1) + F(n-2)
3. Use dynamic programming or memoization for efficiency
4. Do NOT use any print statements - just return the values

Example sequence: 0,1,1,2,3,5,8,13,21,...
Example calls:
- fibonacci(6) returns 8
- fibonacci(0) returns 0
- fibonacci(-1) returns -1"""

    test_cases = [
        (0, 0),    # Edge case: n = 0
        (1, 1),    # Edge case: n = 1
        (2, 1),    # Edge case: n = 2
        (6, 8),    # Regular case
        (10, 55),  # Larger number
        (-1, -1),  # Edge case: negative input
    ]
    
    def validate(code: str) -> bool:
        success, debug_info, test_results = validate_with_debug(code, 'fibonacci', test_cases, "N/A")
        return success
    
    return (question, validate, test_cases)

def test_binary_search():
    question = """Write a Python function named EXACTLY 'binary_search' that performs binary search on a sorted list.
The function signature must be: def binary_search(arr, target)

Requirements:
1. The function takes two arguments:
   - arr: a sorted list of integers
   - target: the integer to find
2. Return the index of the target if found
3. Return -1 if the target is not in the list
4. Do NOT use any print statements - just return the values

Example:
- binary_search([1,2,3,4,5], 3) returns 2
- binary_search([1,2,3,4,5], 6) returns -1"""

    test_cases = [
        (([1,2,3,4,5], 3), 2),     # Regular case: target in middle
        (([1,2,3,4,5], 1), 0),     # Edge case: target at start
        (([1,2,3,4,5], 5), 4),     # Edge case: target at end
        (([1,2,3,4,5], 6), -1),    # Edge case: target not in list
        (([], 1), -1),             # Edge case: empty list
        (([1], 1), 0),             # Edge case: single element list
    ]
    
    def validate(code: str) -> bool:
        success, debug_info, test_results = validate_with_debug(code, 'binary_search', test_cases, "N/A")
        return success
    
    return (question, validate, test_cases)

def test_palindrome():
    question = """Write a Python function named EXACTLY 'is_palindrome' that checks if a string is a palindrome.
The function signature must be: def is_palindrome(s)

Requirements:
1. The function takes one argument:
   - s: a string to check
2. Return True if the string is a palindrome, False otherwise
3. Ignore case (treat uppercase and lowercase as the same)
4. Ignore non-alphanumeric characters (spaces, punctuation)
5. Do NOT use any print statements - just return the values

Example:
- is_palindrome("A man, a plan, a canal: Panama") returns True
- is_palindrome("race a car") returns False"""

    test_cases = [
        ("A man, a plan, a canal: Panama", True),   # Regular case with punctuation
        ("race a car", False),                      # Regular case, not palindrome
        ("", True),                                 # Edge case: empty string
        ("a", True),                                # Edge case: single character
        ("Was it a car or a cat I saw?", True),     # Complex case with punctuation
        ("hello", False),                           # Simple case, not palindrome
    ]
    
    def validate(code: str) -> bool:
        success, debug_info, test_results = validate_with_debug(code, 'is_palindrome', test_cases, "N/A")
        return success
    
    return (question, validate, test_cases)

def test_anagram():
    question = """Write a Python function named EXACTLY 'are_anagrams' that checks if two strings are anagrams.
The function signature must be: def are_anagrams(str1, str2)

Requirements:
1. The function takes two arguments:
   - str1: first string
   - str2: second string
2. Return True if the strings are anagrams, False otherwise
3. Ignore case (treat uppercase and lowercase as the same)
4. Ignore spaces
5. Consider only alphanumeric characters
6. Do NOT use any print statements - just return the values

Example:
- are_anagrams("listen", "silent") returns True
- are_anagrams("hello", "world") returns False"""

    test_cases = [
        (("listen", "silent"), True),           # Regular case
        (("hello", "world"), False),            # Not anagrams
        (("", ""), True),                       # Edge case: empty strings
        (("a", "a"), True),                     # Edge case: single char
        (("Debit Card", "Bad Credit"), True),   # Case and space test
        (("Python", "Java"), False),            # Different lengths
    ]
    
    def validate(code: str) -> bool:
        success, debug_info, test_results = validate_with_debug(code, 'are_anagrams', test_cases, "N/A")
        return success
    
    return (question, validate, test_cases)

# All coding questions, stored as (question, validator, test_cases) tuples
CODING_QUESTIONS = [
    test_fibonacci(),
    test_binary_search(),
    test_palindrome(),
    test_anagram()
]

# Friendly display names keyed by the required function name (see also get_test_name)
TEST_NAMES = {
    "fibonacci": "Fibonacci",
    "binary_search": "Binary Search",
    "is_palindrome": "Palindrome",
    "are_anagrams": "Anagram Check"
}

def get_test_name(question: str) -> str:
    """Get a friendly name for the test based on the question."""
    if "fibonacci" in question.lower():
        return "Fibonacci"
    elif "binary_search" in question.lower():
        return "Binary Search"
    elif "palindrome" in question.lower():
        return "Palindrome"
    elif "anagram" in question.lower():
        return "Anagram Check"
    return question[:20] + "..."

def get_model_stats(model_name: str, question_tuple: tuple, server_url: str) -> Dict:
    """
    Get performance statistics for a specific model and validate the response.
    """
    question, validator, test_cases = question_tuple
    timer = Timer()
    results = {
        'model': model_name,
        'total_duration': 0,
        'tokens_per_second': 0,
        'code_valid': False,
        'tests_passed': False,
        'error': None,
        'test_results': []
    }
    
    try:
        timer.start()
        print(f'{WHITE}Requesting code from {server_url} with {model_name}{ENDC}')
        response = requests.post(
            f"{server_url}/api/chat",
            json={
                "model": model_name,
                "messages": [{'role': 'user', 'content': question}],
                "stream": False
            },
            headers={'Content-Type': 'application/json'}
        ).json()
        timer.stop()

        # Get performance metrics from response
        total_tokens = response.get('eval_count', 0)
        total_duration = response.get('total_duration', 0)
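        # total_duration is reported by Ollama in nanoseconds; convert to seconds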
        total_response_time = float(total_duration) / 1e9
            
        results['total_duration'] = total_response_time
        if total_tokens > 0 and total_response_time > 0:
            results['tokens_per_second'] = total_tokens / total_response_time
                
        # Print concise performance metrics
        print(f"Total Duration (s): {total_response_time:.2f} / Total Tokens: {total_tokens} / Tokens per Second: {results['tokens_per_second']:.2f}")

        # Extract code from response
        if 'message' in response and 'content' in response['message']:
            code = extract_code_from_response(response['message']['content'])
            
            # Validate code
            results['code_valid'] = is_valid_python(code)
            
            if results['code_valid']:
                print(f"Code validation: ✅")
                # Get validation results
                print(f'{WHITE}Running tests...{ENDC}')
                for test_case in CODING_QUESTIONS:
                    if test_case[0] == question:  # Found matching test case
                        function_name = get_function_name_from_question(question)
                        test_cases = test_case[2]  # Get test cases from tuple
                        success, debug_info, test_results = validate_with_debug(code, function_name, test_cases, model_name)
                        results['tests_passed'] = success
                        results['test_results'] = test_results
                        break
            else:
                print(f"Code Validation: ❌")

        else:
            results['error'] = f"Unexpected response format: {response}"
        
    except Exception as e:
        print(f"\n{RED}Error in get_model_stats: {str(e)}{ENDC}")
        results['error'] = str(e)
    
    return results

def get_function_name_from_question(question: str) -> str:
    """Extract function name from question."""
    if "fibonacci" in question.lower():
        return "fibonacci"
    elif "binary_search" in question.lower():
        return "binary_search"
    elif "palindrome" in question.lower():
        return "is_palindrome"
    elif "anagram" in question.lower():
        return "are_anagrams"
    return ""

def run_model_benchmark(model: str, server_url: str, num_runs: int = 4) -> Dict:
    """
    Run multiple benchmarks for a model and calculate average metrics.
    """
    metrics = []
    
    for i in range(num_runs):
        print(f"\n{YELLOW}[{model}] Run {i+1}/{num_runs}:{ENDC}")
        
        run_results = {}
        for question_tuple in CODING_QUESTIONS:
            test_name = get_test_name(question_tuple[0])
            print(f"\n{BOLD}Testing {test_name}...{ENDC}")
            try:
                result = get_model_stats(model, question_tuple, server_url)
                # Count passed test cases recorded for this run
                result['passed_cases'] = len([r for r in result.get('test_results', []) if r])
                result['total_cases'] = len(question_tuple[2])
                run_results[test_name] = result
            except Exception as e:
                print(f"Error in run {i+1}: {e}")
                continue
        
        if run_results:
            metrics.append(run_results)
    
    # Average over the last 3 runs only; the first run often includes model load time
    metrics = metrics[-3:]
    
    if not metrics:
        return {}
        
    # Aggregate results
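    # Note: the model-level total_duration and tokens_per_second are averaged over
    # the first question of each kept run; per-question averages are computed below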
    aggregated = {
        'model': model,
        'total_duration': mean([m[list(m.keys())[0]]['total_duration'] for m in metrics if m]),
        'tokens_per_second': mean([m[list(m.keys())[0]]['tokens_per_second'] for m in metrics if m]),
        'test_results': {}
    }
    
    # Calculate results per test
    for test_name in metrics[-1].keys():
        # Total cases for this test across the kept runs (each test suite has 6 cases)
        total_cases_this_test = sum(m[test_name]['total_cases'] for m in metrics)
        
        # Sum up actual passed cases from the test results
        total_passed_this_test = 0
        for m in metrics:
            test_results = m[test_name].get('test_results', [])
            passed_in_run = len([r for r in test_results if r])
            total_passed_this_test += passed_in_run
        
        success_rate = (total_passed_this_test / total_cases_this_test * 100)
        status = '✅' if success_rate == 100 else '❌'
        
        # Print cumulative results header and results
        if test_name == list(metrics[-1].keys())[0]:
            print(f"\n{BOLD}Cumulative Results for each code question:{ENDC}")
            
        print(f"{test_name}: {status} ({total_passed_this_test}/{total_cases_this_test} cases)")
        
        aggregated['test_results'][test_name] = {
            'success_rate': success_rate,
            'passed_cases': total_passed_this_test,
            'total_cases': total_cases_this_test,
            'success_cases_rate': total_passed_this_test / total_cases_this_test,
            'avg_duration': mean([m[test_name]['total_duration'] for m in metrics]),
            'avg_tokens_sec': mean([m[test_name]['tokens_per_second'] for m in metrics])
        }
    
    # Calculate overall success rate and add min/max metrics
    total_passed = sum(t['passed_cases'] for t in aggregated['test_results'].values())
    total_cases = sum(t['total_cases'] for t in aggregated['test_results'].values())
    aggregated['overall_success_rate'] = (total_passed / total_cases * 100) if total_cases > 0 else 0
    aggregated['overall_success_cases_rate'] = (total_passed / total_cases) if total_cases > 0 else 0
    
    # Add min and max metrics for both duration and tokens/sec
    avg_durations = [t['avg_duration'] for t in aggregated['test_results'].values()]
    avg_tokens_sec = [t['avg_tokens_sec'] for t in aggregated['test_results'].values()]
    aggregated['min_avg_duration'] = min(avg_durations) if avg_durations else 0
    aggregated['max_avg_duration'] = max(avg_durations) if avg_durations else 0
    aggregated['min_tokens_per_second'] = min(avg_tokens_sec) if avg_tokens_sec else 0
    aggregated['max_tokens_per_second'] = max(avg_tokens_sec) if avg_tokens_sec else 0
    
    return aggregated

def print_leaderboard(results: List[Dict]):
    """Print leaderboard of model results."""
    if not results:
        print("No results to display")
        return
        
    # Sort by success rate first, then by tokens per second
    sorted_results = sorted(results, key=lambda x: (
        sum(t['passed_cases'] for t in x['test_results'].values()) / sum(t['total_cases'] for t in x['test_results'].values()) if sum(t['total_cases'] for t in x['test_results'].values()) > 0 else 0,
        x['tokens_per_second']
    ), reverse=True)
    
    print(f"\n{HEADER}{BOLD}🏆 Final Model Leaderboard:{ENDC}")
    for i, result in enumerate(sorted_results, 1):
        # Calculate stats for each model
        total_passed = sum(t['passed_cases'] for t in result['test_results'].values())
        total_cases = sum(t['total_cases'] for t in result['test_results'].values())
        success_rate = (total_passed / total_cases * 100) if total_cases > 0 else 0
        
        print(f"\n{BOLD}{YELLOW}{result['model']}{ENDC}")
        print(f"   {BOLD}Overall Success Rate:{ENDC} {success_rate:.1f}% ({total_passed}/{total_cases} cases)")
        print(f"   {BOLD}Average Tokens/sec:{ENDC} {result['tokens_per_second']:.2f} ({result['min_tokens_per_second']:.2f} - {result['max_tokens_per_second']:.2f})")
        print(f"   {BOLD}Average Duration:{ENDC} {result['total_duration']:.2f}s")
        print(f"   {BOLD}Min/Max Avg Duration:{ENDC} {result['min_avg_duration']:.2f}s / {result['max_avg_duration']:.2f}s")
        print(f"   {BOLD}Test Results:{ENDC}")
        for test_name, test_result in result['test_results'].items():
            status = '✅' if test_result['success_rate'] == 100 else '❌'
            print(f"   - {test_name}: {status} {test_result['passed_cases']}/{test_result['total_cases']} cases ({test_result['success_rate']:.1f}%)")

def get_available_models(server_url: str) -> List[str]:
    """Get list of available models from the specified Ollama server."""
    try:
        response = requests.get(f"{server_url}/api/tags").json()
        return [model['name'] for model in response['models']]
    except Exception as e:
        print(f"{RED}Error getting model list from {server_url}: {e}{ENDC}")
        return []

def get_model_details(model_name):
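    """Print the output of `ollama show <model_name>` as aligned sections; always returns None."""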
    try:
        result = subprocess.run(
            ["ollama", "show", model_name],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding='utf-8',
            errors='replace'
        )
        
        if result.returncode != 0:
            print(f"Error: {result.stderr.strip()}")
            return None

        if not result.stdout.strip():
            print(f"No details available for model: {model_name}")
            return None

        raw_output = result.stdout.strip()
        lines = raw_output.split('\n')
        current_section = None
        
        for line in lines:
            line = line.rstrip()
            if line and not line.startswith('  '):  # Section headers
                current_section = line.strip()
                print(f"\n  {current_section}")
            elif line and current_section:  # Section content
                # Split by multiple spaces and filter out empty parts
                parts = [part for part in line.split('  ') if part.strip()]
                if len(parts) >= 2:
                    key, value = parts[0].strip(), parts[-1].strip()
                    # Ensure consistent spacing for alignment
                    print(f"    {key:<16} {value}")
                elif len(parts) == 1:
                    # Handle single-value lines (like license text)
                    print(f"    {parts[0].strip()}")

        return None

    except Exception as e:
        print(f"An error occurred while getting model details: {e}")
        return None

def update_server_results(server_url: str, results: List[Dict]) -> None:
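    """Append this run's results to a per-server JSON file and mirror the leaderboard to a .log file."""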
    try:
        # Get CPU brand and format it for filename
        cpu_info = get_cpu_info()
        cpu_brand = cpu_info.get('brand_raw', 'Unknown_CPU').replace(' ', '_')
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        
        # Create a unique filename for this server's results
        server_id = server_url.replace('http://', '').replace(':', '_').replace('/', '_')
        results_dir = "benchmark_results"
        
        os.makedirs(results_dir, exist_ok=True)
        
        # Include CPU brand in filename
        base_filename = f"{cpu_brand}_{server_id}"
        json_filename = os.path.join(results_dir, f"{base_filename}.json")
        log_filename = os.path.join(results_dir, f"{base_filename}.log")
        
        # Load existing results or create new file
        try:
            with open(json_filename, 'r') as f:
                existing_data = json.load(f)
        except FileNotFoundError:
            existing_data = {
                'server_url': server_url,
                'benchmarks': []
            }
        
        # Add new results with timestamp and ensure overall success rate is included
        benchmark_entry = {
            'timestamp': timestamp,
            'results': []
        }
        
        # Add overall success rate to each model's results
        for result in results:
            total_passed = sum(t['passed_cases'] for t in result['test_results'].values())
            total_cases = sum(t['total_cases'] for t in result['test_results'].values())
            result['overall_success_rate'] = (total_passed / total_cases * 100) if total_cases > 0 else 0
            result['min_avg_duration'] = min(t['avg_duration'] for t in result['test_results'].values()) if result['test_results'] else 0
            result['max_avg_duration'] = max(t['avg_duration'] for t in result['test_results'].values()) if result['test_results'] else 0
            benchmark_entry['results'].append(result)
        
        existing_data['benchmarks'].append(benchmark_entry)
        
        # Save updated results
        with open(json_filename, 'w') as f:
            json.dump(existing_data, f, indent=2)
            
        print(f"{GREEN}Successfully saved results to {json_filename}{ENDC}")
        
        # Save console output to log file
        with open(log_filename, 'w') as f:
            # Redirect stdout to capture the leaderboard output
            import io
            import sys
            stdout = sys.stdout
            str_output = io.StringIO()
            sys.stdout = str_output
            
            # Print CPU info
            print("CPU Information:")
            for key, value in cpu_info.items():
                print(f"{key}: {value}")
            print("\nBenchmark Results:")
            print_leaderboard(results)
            
            # Restore stdout and get the captured output
            sys.stdout = stdout
            log_content = str_output.getvalue()
            
            # Write to log file
            f.write(f"Benchmark Run: {timestamp}\n")
            f.write(f"Server: {server_url}\n\n")
            f.write(log_content)
            
        print(f"{GREEN}Console output saved to {log_filename}{ENDC}")
        
    except Exception as e:
        print(f"{RED}Failed to save results: {str(e)}{ENDC}")

def main():
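    """Parse CLI arguments, select models, run the benchmarks, and print/save the results."""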
    parser = argparse.ArgumentParser(description='Run Ollama model benchmarks')
    parser.add_argument('--server', choices=['local', 'z60'], default='local',
                      help='Choose Ollama server (default: local)')
    parser.add_argument('--model', type=str, help='Specific model to benchmark')
    parser.add_argument('--number', type=str, help='Number of models to benchmark (number or "all")')
    parser.add_argument('--verbose', action='store_true', help='Enable verbose output')
    args = parser.parse_args()

    # Propagate the --verbose flag to the module-level switch used by the helpers
    global verbose
    verbose = args.verbose

    server_url = SERVERS[args.server]
    
    print()
    print(f"{HEADER}{BOLD}CPU Information:{ENDC}")
    cpu_info = get_cpu_info()
    for key, value in cpu_info.items():
        print(f"{MUTED}{key}: {value}{ENDC}")
    
    print()
    print(f"{INFO}Using Ollama server at {server_url}...{ENDC}")

    # Get available models or use specified model
    if args.model:
        models = [args.model]
    else:
        models = get_available_models(server_url)
        
    if not models:
        print(f"{RED}No models found on server {server_url}. Exiting.{ENDC}")
        return

    # Handle number of models to test
    if args.number and args.number.lower() != 'all':
        try:
            num_models = int(args.number)
            if num_models > 0:
                models = models[:num_models]
            else:
                print(f"{WARNING}Invalid number of models. Using all available models.{ENDC}")
        except ValueError:
            print(f"{WARNING}Invalid number format. Using all available models.{ENDC}")

    print(f"{INFO}Testing {len(models)} models :{ENDC}")
    for i, model in enumerate(models, 1):
        print(f"{YELLOW}{i}. {model}{ENDC}")
    
    # Run benchmarks
    all_results = []
    
    for model in models:
        print(f"\n{HEADER}{BOLD}Benchmarking {model}...{ENDC}")
        # get_model_details prints the `ollama show` output directly and returns None
        get_model_details(model)
        result = run_model_benchmark(model, server_url)
        if result and 'error' not in result:
            all_results.append(result)
    
    # Print and save results
    print_leaderboard(all_results)
    update_server_results(server_url, all_results)

if __name__ == "__main__":
    main()