# 5_design_editor.py

import os
import json
import argparse
from pathlib import Path
from typing import Dict, Optional, List
import re
import time
from datetime import datetime

from dotenv import load_dotenv
import vertexai
from vertexai.generative_models import GenerativeModel, GenerationResponse, GenerationConfig

# --- Model & Pricing Configuration ---
# Models are selected by their position in this list (CLI --model-index /
# the model_index parameter of the functions below).
#   name:     human-readable label used in log output and the CLI help text
#   model_id: Vertex AI model identifier passed to GenerativeModel
#   location: Vertex AI region passed to vertexai.init() for this model
#   pricing:  USD per single token (list price per million / 1_000_000),
#             read by print_usage_and_cost to estimate request cost
# NOTE(review): these appear to be the standard (short-prompt) list prices;
# very long prompts may be billed at a higher tier — confirm against current
# Vertex AI pricing before relying on the reported costs.
MODELS_CONFIG = [
    {
        "name": "Gemini 2.5 Pro (GA)",
        "model_id": "gemini-2.5-pro",
        "location": "us-east1",
        "pricing": { "input": 1.25 / 1_000_000, "output": 10.00 / 1_000_000 }
    },
    {
        "name": "Gemini 2.5 Flash (GA)",
        "model_id": "gemini-2.5-flash",
        "location": "us-east1",
        "pricing": { "input": 0.30 / 1_000_000, "output": 2.50 / 1_000_000 }
    },
    {
        "name": "Gemini 2.5 Flash Lite (Preview)",
        "model_id": "gemini-2.5-flash-lite",
        "location": "global",
        "pricing": { "input": 0.10 / 1_000_000, "output": 0.40 / 1_000_000 }
    }
]

DEFAULT_MODEL_INDEX = 0  # Gemini 2.5 Pro for design editing (index into MODELS_CONFIG)

def print_timestamp():
    """Return the current local wall-clock time formatted as ``HH:MM:SS``.

    Despite its name this helper prints nothing; callers embed the returned
    string in their own progress messages.
    """
    return format(datetime.now(), "%H:%M:%S")

def setup_gcp_credentials():
    """Populate GCP credential environment variables when they are missing.

    1. If ``GOOGLE_APPLICATION_CREDENTIALS`` is unset, point it at the first
       existing key file among ``../data/credentials/gcp_key.json`` and
       ``./data/credentials/gcp_key.json``.
    2. If ``GCLOUD_PROJECT`` is unset, read ``project_id`` from whatever key
       file ``GOOGLE_APPLICATION_CREDENTIALS`` now names.

    Variables that are already set are never overwritten. A key file that
    cannot be parsed produces a warning, not an exception.
    """
    # Checked in order: one level up (running from a scripts dir), then CWD.
    candidate_keys = (
        Path("../data/credentials/gcp_key.json"),
        Path("./data/credentials/gcp_key.json"),
    )

    if not os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
        found_key = next((p for p in candidate_keys if p.exists()), None)
        if found_key is not None:
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = str(found_key.absolute())
            print(f"🔑 Auto-detected GCP key: {found_key}")

    if os.getenv("GCLOUD_PROJECT"):
        return

    creds_file = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
    if not creds_file or not Path(creds_file).exists():
        return

    try:
        with open(creds_file, 'r') as handle:
            project = json.load(handle).get("project_id")
        if project:
            os.environ["GCLOUD_PROJECT"] = project
            print(f"🌐 Auto-detected GCP project: {project}")
    except Exception as exc:
        print(f"⚠️ Could not read project ID from key file: {exc}")

def get_gemini_model(model_index: int):
    """Initialise Vertex AI and return a ``GenerativeModel`` for ``MODELS_CONFIG[model_index]``.

    Credentials are auto-detected first via ``setup_gcp_credentials``. The
    model is created with a fixed sampling temperature of 0.8.

    Raises:
        IndexError: if ``model_index`` is out of range for ``MODELS_CONFIG``.
        ValueError: if no ``GCLOUD_PROJECT`` could be determined.
    """
    setup_gcp_credentials()

    selected = MODELS_CONFIG[model_index]
    project = os.getenv("GCLOUD_PROJECT")
    if not project:
        raise ValueError("GCLOUD_PROJECT not found. Please check your GCP setup.")

    print(f"   🔗 Connecting to project: {project}")
    vertexai.init(project=project, location=selected["location"])

    return GenerativeModel(
        selected["model_id"],
        generation_config=GenerationConfig(temperature=0.8),
    )

def print_usage_and_cost(response: GenerationResponse, model_index: int, stage_name: str) -> dict:
    """Calculate, print and return the token usage and cost of one Gemini call.

    Args:
        response: The response returned by ``GenerativeModel.generate_content``.
        model_index: Index into ``MODELS_CONFIG`` selecting the pricing table.
        stage_name: Label used in the printed summary.

    Returns:
        ``{"cost": <USD float>, "tokens": <int>}``. Both stay 0 when usage
        metadata is unavailable. Reporting is best-effort and never raises;
        a failure is printed as a warning instead of being silently ignored.
    """
    usage_data = {"cost": 0, "tokens": 0}
    try:
        pricing = MODELS_CONFIG[model_index]["pricing"]
        usage = response.usage_metadata
        # Treat missing counts as zero rather than failing the whole report.
        input_tokens = usage.prompt_token_count or 0
        output_tokens = usage.candidates_token_count or 0
        # Gemini 2.5 "thinking" tokens are reported separately from the
        # candidate tokens but billed at the output rate. Older SDK versions
        # may not expose the field at all, hence getattr.
        thinking_tokens = getattr(usage, "thoughts_token_count", 0) or 0
        billed_output = output_tokens + thinking_tokens

        usage_data["cost"] = (input_tokens * pricing["input"]) + (billed_output * pricing["output"])
        usage_data["tokens"] = input_tokens + billed_output

        print(f"   💰 Usage for {stage_name}: {(usage_data['tokens']):,} tokens, cost: ${usage_data['cost']:.6f}")
        if thinking_tokens:
            print(f"   📊 Breakdown: {input_tokens:,} input + {output_tokens:,} output + {thinking_tokens:,} thinking tokens")
        else:
            print(f"   📊 Breakdown: {input_tokens:,} input + {output_tokens:,} output tokens")
    except Exception as e:
        # Cost accounting must never break an edit, but say why it is missing.
        print(f"   ⚠️ Could not compute usage for {stage_name}: {e}")
    return usage_data

def extract_design_number_from_filename(filename: str) -> Optional[int]:
    """Return N for names that start with ``design_N`` (e.g. ``design_1_version2`` -> 1), else None.

    The match is anchored at the start of ``filename``; any suffix after the
    digits (``.html``, ``_version3``, ...) is ignored.
    """
    found = re.match(r'design_(\d+)', filename)
    return int(found.group(1)) if found else None

def get_next_version_number(design_file: Path) -> int:
    """Return the next free version number for ``design_file``.

    Versioned copies live next to the design as ``<base>_version<N>.html``,
    where ``<base>`` is the file stem with any ``_version...`` suffix removed
    (so ``design_1.html`` and ``design_1_version4.html`` share base ``design_1``).

    Returns:
        2 when no versioned copy exists (the unversioned file is implicitly
        version 1); otherwise the highest existing N plus one.
    """
    base_design_name = design_file.stem.split('_version')[0]

    # Match this exact base name only: a loose "design_1*" glob would also
    # pick up design_10_version*.html and inflate design_1's next number.
    version_pattern = re.compile(rf'{re.escape(base_design_name)}_version(\d+)')

    existing_versions = []
    for candidate in design_file.parent.glob(f"{base_design_name}_version*.html"):
        match = version_pattern.fullmatch(candidate.stem)
        if match:
            existing_versions.append(int(match.group(1)))

    # default=1 makes "no versions yet" yield 2.
    return max(existing_versions, default=1) + 1

def load_design_data(site_dir: Path, design_number: int) -> Dict:
    """Load one design record and the site brief from ``site_dir``.

    Reads ``designs.json`` (its ``designs`` list is indexed 1-based by
    ``design_number``) and ``site_brief.json``.

    Returns:
        Dict with ``design_data`` (the selected design record),
        ``site_brief`` (parsed brief) and ``designs_data`` (whole designs file).

    Raises:
        FileNotFoundError: if either JSON file is missing.
        ValueError: if ``design_number`` is outside ``1..len(designs)``.
    """
    print(f"   [{print_timestamp()}] Loading design data...")

    designs_path = site_dir / "designs.json"
    if not designs_path.exists():
        raise FileNotFoundError(f"Designs file not found at {designs_path}")
    with open(designs_path, 'r', encoding='utf-8') as fh:
        designs_data = json.load(fh)

    all_designs = designs_data.get('designs', [])
    if not 1 <= design_number <= len(all_designs):
        raise ValueError(f"Invalid design number {design_number}. Available: 1-{len(all_designs)}")
    design_data = all_designs[design_number - 1]

    # The site brief supplies brand context for the edit prompt.
    brief_path = site_dir / "site_brief.json"
    if not brief_path.exists():
        raise FileNotFoundError(f"Site brief not found at {brief_path}")
    with open(brief_path, 'r', encoding='utf-8') as fh:
        site_brief = json.load(fh)

    label = design_data.get('design_name', f'Design {design_number}')
    print(f"   ✅ Loaded design: {label}")

    return {
        'design_data': design_data,
        'site_brief': site_brief,
        'designs_data': designs_data,
    }

def _generate_with_retry(model, prompt: str, max_retries: int = 5, initial_delay: float = 10):
    """Call ``model.generate_content(prompt)``, retrying transient API failures.

    A failure counts as transient when the exception carries an HTTP-style
    ``code`` of 429 (rate limited / resource exhausted) or 503 (service
    unavailable), or when its message mentions "503" or "Socket". Transient
    failures are retried with exponential backoff starting at
    ``initial_delay`` seconds. Any other error, or a transient one on the
    final attempt, is re-raised unchanged.
    """
    attempts = max(1, max_retries)
    delay = initial_delay
    for attempt in range(1, attempts + 1):
        try:
            return model.generate_content(prompt)
        except Exception as e:
            message = str(e)
            transient = (
                getattr(e, "code", None) in (429, 503)
                or "503" in message
                or "Socket" in message
            )
            if not transient or attempt == attempts:
                raise
            print(f"   ⚠️ Transient API error, retrying in {delay} seconds... (attempt {attempt}/{attempts})")
            time.sleep(delay)
            delay *= 2  # Exponential backoff


def edit_design(
    design_file: Path,
    edit_instructions: str,
    model_index: int = DEFAULT_MODEL_INDEX,
    preserve_versions: bool = True
) -> Dict:
    """
    Edit a design file based on natural-language instructions.

    The design number is parsed from the file name (``design_N...``). The
    site directory is taken to be the parent of the file's directory
    (``<site>/final_output/design_N.html``).

    Args:
        design_file: Path to the HTML file to edit
        edit_instructions: Natural language instructions for changes
        model_index: Which Gemini model to use (index into MODELS_CONFIG)
        preserve_versions: Whether to save as a new version (True) or overwrite (False)

    Returns:
        Dict with:
            - success: bool
            - output_file: Path to the edited file
            - version_number: int (if versioned, else None)
            - usage_data: dict with cost and tokens
            - design_name: str (on success)
            - error: str (if failed)

    Never raises: any failure is reported through ``success=False`` and ``error``.
    """
    try:
        print(f"\n🎨 DESIGN EDITOR - Editing {design_file.name}")
        print(f"   [{print_timestamp()}] Starting edit process...")

        # Read the current HTML
        if not design_file.exists():
            raise FileNotFoundError(f"Design file not found: {design_file}")

        with open(design_file, 'r', encoding='utf-8') as f:
            current_html = f.read()

        print(f"   📖 Loaded design file: {design_file.name} ({len(current_html):,} bytes)")

        # The design number links the file back to its record in designs.json
        design_number = extract_design_number_from_filename(design_file.stem)
        if not design_number:
            raise ValueError(f"Could not extract design number from filename: {design_file.name}")

        # Get the site directory (parent of final_output)
        site_dir = design_file.parent.parent

        # Load original design data and context
        context = load_design_data(site_dir, design_number)
        design_data = context['design_data']
        site_brief = context['site_brief']

        # Initialize Gemini
        model = get_gemini_model(model_index)

        preview = edit_instructions if len(edit_instructions) <= 100 else edit_instructions[:100] + "..."
        print(f"   🤖 Using model: {MODELS_CONFIG[model_index]['name']}")
        print(f"   📝 Edit instructions: {preview}")

        # Build the editing prompt
        prompt = build_edit_prompt(
            current_html=current_html,
            edit_instructions=edit_instructions,
            design_data=design_data,
            site_brief=site_brief
        )

        print(f"   🚀 Sending to Gemini for editing...")
        start_time = time.time()
        response = _generate_with_retry(model, prompt)
        print(f"   ✅ Response received in {time.time() - start_time:.1f} seconds")

        # Extract usage data
        usage_data = print_usage_and_cost(response, model_index, "Design Edit")

        # Process and validate the response before touching any file on disk
        edited_html = extract_html_from_response(response.text)
        if not validate_edited_html(edited_html, design_number):
            raise ValueError("Generated HTML failed validation")

        # Determine output filename
        if preserve_versions:
            version_number = get_next_version_number(design_file)
            base_name = design_file.stem.split('_version')[0]  # Remove any existing version
            output_file = design_file.parent / f"{base_name}_version{version_number}.html"
        else:
            version_number = None
            output_file = design_file

        # Save the edited HTML
        print(f"   💾 Saving edited design to: {output_file.name}")
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(edited_html)

        file_size = output_file.stat().st_size
        print(f"   ✅ Design saved successfully! ({file_size:,} bytes)")

        return {
            'success': True,
            'output_file': output_file,
            'version_number': version_number,
            'usage_data': usage_data,
            'design_name': design_data.get('design_name', f'Design {design_number}')
        }

    except Exception as e:
        print(f"   ❌ Error editing design: {e}")
        return {
            'success': False,
            'error': str(e),
            'usage_data': {'cost': 0, 'tokens': 0}
        }

def build_edit_prompt(current_html: str, edit_instructions: str, design_data: Dict, site_brief: Dict) -> str:
    """Build the LLM prompt for editing an existing design.

    Args:
        current_html: Full HTML of the design being edited (embedded verbatim).
        edit_instructions: The user's natural-language change request.
        design_data: One design record; ``design_name`` and
            ``design_philosophy`` are used.
        site_brief: Parsed site brief; ``brand_profile.brand_identity``
            (truncated to 200 chars) and the first three
            ``brand_profile.keywords`` are used.

    Returns:
        The prompt string.

    Keys that are missing or explicitly null in the JSON are treated as empty.
    """
    # `or` fallbacks also cover keys present with a null value, which a
    # .get() default alone would pass through as None.
    design_name = design_data.get('design_name') or 'Unknown Design'
    design_philosophy = design_data.get('design_philosophy') or ''

    brand_profile = site_brief.get('brand_profile') or {}
    brand_text = str(brand_profile.get('brand_identity') or '')[:200]
    business_keywords = brand_profile.get('keywords') or []
    if isinstance(business_keywords, str):
        # A bare string would otherwise be joined character by character.
        business_keywords = [business_keywords]
    keywords_text = ', '.join(str(keyword) for keyword in list(business_keywords)[:3])

    prompt = f"""You are an expert web developer tasked with editing an existing website design based on specific instructions.

**CRITICAL REQUIREMENTS:**
1. Make ONLY the requested changes - preserve everything else
2. Maintain the design's core philosophy and style
3. Keep all technical requirements (responsive, contrast, etc.)
4. Preserve the existing CSS scoping (#design-X)
5. Return ONLY the complete edited HTML

**CURRENT DESIGN CONTEXT:**
- Design: {design_name}
- Philosophy: {design_philosophy}
- Business: {keywords_text}
- Brand: {brand_text}

**EDIT INSTRUCTIONS FROM USER:**
{edit_instructions}

**IMPORTANT RULES:**
1. PRESERVE all existing functionality (navigation, responsiveness, etc.)
2. MAINTAIN the design's visual identity unless specifically asked to change
3. KEEP the existing CSS scope selector pattern
4. ENSURE all text remains readable (contrast ratios)
5. UPDATE only what's specifically requested
6. If changing images, use appropriate Unsplash URLs or placeholders
7. If changing text, keep it relevant to the business context

**CURRENT HTML TO EDIT:**
{current_html}

**VALIDATION CHECKLIST:**
- ✅ Changes match the user's instructions
- ✅ Original design philosophy preserved (unless instructed otherwise)
- ✅ All functionality still works
- ✅ Text contrast ratios maintained
- ✅ Responsive design intact
- ✅ CSS properly scoped

Return ONLY the complete edited HTML code, starting with <!DOCTYPE html>."""

    return prompt

def extract_html_from_response(response_text: str) -> str:
    """Extract a clean HTML document from a raw model response.

    Handles the following:
      * Markdown code fences, with or without an ``html`` language tag. The
        first fenced block that starts with a doctype or ``<html`` wins.
      * Prose before the document.
      * A doctype in any letter case (``<!doctype html>`` is common in HTML5).
      * A missing doctype.

    The doctype is normalized to ``<!DOCTYPE html>``, and any text after the
    last ``</html>`` is dropped.

    Raises:
        ValueError: if neither a doctype nor an ``<html`` tag can be found.
    """
    html_content = response_text.strip()

    # Prefer a fenced block that actually holds an HTML document.
    if '```' in html_content:
        for part in html_content.split('```'):
            candidate = part.strip()
            # Drop a leading language tag such as ```html / ```HTML.
            if candidate.lower().startswith('html'):
                candidate = candidate[4:].strip()
            lowered = candidate.lower()
            if lowered.startswith('<!doctype html') or lowered.startswith('<html'):
                html_content = candidate
                break

    html_content = html_content.strip()

    doctype_re = re.compile(r'<!doctype\s+html\s*>', re.IGNORECASE)
    leading_doctype = doctype_re.match(html_content)
    if leading_doctype:
        # Normalize casing/whitespace of an existing doctype.
        html_content = '<!DOCTYPE html>' + html_content[leading_doctype.end():]
    elif html_content.lower().startswith('<html'):
        # Add DOCTYPE if missing but HTML tag present
        html_content = '<!DOCTYPE html>\n' + html_content
    else:
        # Last resort: the document is embedded in surrounding prose.
        embedded_doctype = doctype_re.search(html_content)
        if embedded_doctype:
            html_content = '<!DOCTYPE html>' + html_content[embedded_doctype.end():]
        else:
            html_tag_index = html_content.lower().find('<html')
            if html_tag_index < 0:
                raise ValueError("Response does not contain valid HTML starting with <!DOCTYPE html>")
            html_content = '<!DOCTYPE html>\n' + html_content[html_tag_index:]

    # Drop any commentary the model appended after the closing tag.
    closing_index = html_content.lower().rfind('</html>')
    if closing_index >= 0:
        html_content = html_content[:closing_index + len('</html>')]

    return html_content

def validate_edited_html(html: str, design_number: int, min_length: int = 5000) -> bool:
    """Check that edited HTML still has the structure the design pipeline relies on.

    All of the following must hold:
      * it starts with ``<!DOCTYPE html>``;
      * it contains the wrapper ``id="design-N"`` (either quote style);
      * it contains the CSS scope selector ``#design-N``, not merely a longer
        id such as ``#design-N0``;
      * it is at least ``min_length`` characters long (a truncation guard).

    Args:
        html: The candidate HTML document.
        design_number: The design's number N.
        min_length: Minimum acceptable length; defaults to 5000.

    Returns:
        True when every check passes. Otherwise False, after printing the
        first failing check.
    """
    print(f"   🔍 Validating edited HTML...")

    # Check basic structure
    if not html.startswith('<!DOCTYPE html>'):
        print(f"   ❌ Missing DOCTYPE declaration")
        return False

    # Check for the design wrapper. The closing quote stops design-1 matching design-10.
    wrapper_pattern = rf'''id\s*=\s*["']design-{design_number}["']'''
    if not re.search(wrapper_pattern, html):
        print(f"   ❌ Missing design wrapper id='design-{design_number}'")
        return False

    # Check for CSS scoping. The lookahead stops '#design-10' satisfying '#design-1'.
    if not re.search(rf'#design-{design_number}(?!\d)', html):
        print(f"   ❌ Missing CSS scope selector #design-{design_number}")
        return False

    # Check minimum size (should be substantial)
    if len(html) < min_length:
        print(f"   ❌ HTML seems too small ({len(html)} bytes)")
        return False

    print(f"   ✅ HTML validation passed")
    return True

def list_design_versions(site_dir: Path, design_number: int) -> List[Dict]:
    """List the original and all numbered versions of a design.

    Looks in ``<site_dir>/final_output`` for ``design_N.html`` (labelled
    ``'original'``) and ``design_N_version<K>.html`` (labelled ``'version<K>'``).

    Returns:
        A list of dicts with ``version``, ``file`` (Path), ``size`` (bytes) and
        ``modified`` (local ``YYYY-MM-DD HH:MM:SS``). The original comes
        first, followed by the versions in ascending numeric order. The list
        is empty if the output directory does not exist.
    """
    output_dir = site_dir / "final_output"
    if not output_dir.exists():
        return []

    base_name = f"design_{design_number}"

    def _describe(label: str, path: Path) -> Dict:
        """Build one listing entry, calling stat() only once."""
        info = path.stat()
        return {
            'version': label,
            'file': path,
            'size': info.st_size,
            'modified': datetime.fromtimestamp(info.st_mtime).strftime('%Y-%m-%d %H:%M:%S'),
        }

    versions = []

    # Check for original
    original_file = output_dir / f"{base_name}.html"
    if original_file.exists():
        versions.append(_describe('original', original_file))

    version_pattern = re.compile(rf'{re.escape(base_name)}_version(\d+)')
    numbered = []
    for file in output_dir.glob(f"{base_name}_version*.html"):
        match = version_pattern.fullmatch(file.stem)
        if match:
            numbered.append((int(match.group(1)), file))

    # Sort numerically. A plain path sort puts version10 before version2,
    # which breaks callers that treat the last entry as the latest.
    for version_num, file in sorted(numbered):
        versions.append(_describe(f'version{version_num}', file))

    return versions

def main():
    """Command-line entry point: edit a design, or list its versions.

    Two invocation styles are supported:

    * Traditional: positional ``design_file instructions``, where
      ``design_file`` must live inside a ``final_output`` directory.
    * API-style: ``--domain``, ``--design`` and ``--edit-instructions``
      (plus an optional ``--version``). These are resolved against the first
      existing website-data directory among a few relative candidates.

    With ``--list-versions`` no edit instructions are required. The versions
    of the targeted design are printed and the command exits.
    """
    print(f"\n🎨 DESIGN EDITOR - Edit existing website designs")
    print(f"⏰ Started at: {print_timestamp()}")

    load_dotenv()

    model_options_text = ', '.join([f"{i}='{config['name']}'" for i, config in enumerate(MODELS_CONFIG)])

    parser = argparse.ArgumentParser(description="Edit existing website designs using AI.")

    # Support both old positional arguments and new API-style arguments
    parser.add_argument("design_file", type=str, nargs='?', help="Path to the design HTML file to edit (e.g., final_output/design_1.html)")
    parser.add_argument("instructions", type=str, nargs='?', help="Edit instructions in natural language")

    # New API-style arguments
    parser.add_argument("--domain", type=str, help="Domain name (for API calls)")
    parser.add_argument("--design", type=int, help="Design number (1, 2, or 3)")
    parser.add_argument("--edit-instructions", type=str, help="Edit instructions")
    parser.add_argument("--version", type=str, help="Version to edit (e.g., 'version2')")

    parser.add_argument("--model-index", type=int, default=DEFAULT_MODEL_INDEX,
                       choices=range(len(MODELS_CONFIG)),
                       help=f"Model to use ({model_options_text})")
    parser.add_argument("--overwrite", action="store_true",
                       help="Overwrite the original file instead of creating a version")
    parser.add_argument("--list-versions", action="store_true",
                       help="List all versions of this design and exit (edit instructions not required)")

    args = parser.parse_args()

    # Listing versions only needs to identify the design, not instructions.
    listing_only = args.list_versions

    # Handle both API-style and traditional arguments
    if args.domain and args.design and (args.edit_instructions or listing_only):
        # API-style call - construct paths from domain and design info
        print(f"🌐 API mode: domain={args.domain}, design={args.design}, version={args.version}")

        # Find the website data directory (same logic as other scripts)
        possible_data_dirs = [
            Path("../data/website_data"),  # From scripts directory
            Path("data/website_data"),     # From project root
            Path("dashboard/data/website_data")  # Alternative path
        ]
        website_data_dir = next((d for d in possible_data_dirs if d.exists()), None)

        if not website_data_dir:
            print(f"❌ Error: Could not find website data directory")
            return

        site_dir = website_data_dir / args.domain
        output_dir = site_dir / "final_output"

        if not output_dir.exists():
            print(f"❌ Error: Output directory not found: {output_dir}")
            return

        # Determine which file to edit
        if args.version and args.version != 'original':
            design_file = output_dir / f"design_{args.design}_{args.version}.html"
        else:
            design_file = output_dir / f"design_{args.design}.html"

        instructions = args.edit_instructions

    elif args.design_file and (args.instructions or listing_only):
        # Traditional positional arguments
        print(f"📁 Traditional mode: file={args.design_file}")
        design_file = Path(args.design_file)
        instructions = args.instructions

        # Extract site directory
        if 'final_output' in design_file.parts:
            site_dir = design_file.parent.parent
        else:
            print(f"❌ Error: Design file should be in a final_output directory")
            return
    else:
        print(f"❌ Error: Must provide either (design_file + instructions) or (--domain + --design + --edit-instructions)")
        parser.print_help()
        return

    # If listing versions, do that and exit
    if args.list_versions:
        if args.domain and args.design:
            design_number = args.design
        else:
            design_number = extract_design_number_from_filename(design_file.stem)
            if not design_number:
                print(f"❌ Error: Could not extract design number from {design_file.name}")
                return

        print(f"\n📋 Versions of Design {design_number}:")
        versions = list_design_versions(site_dir, design_number)

        if not versions:
            print(f"   No versions found")
        else:
            for v in versions:
                print(f"   • {v['version']:10} - {v['file'].name:30} ({v['size']:,} bytes) - Modified: {v['modified']}")
        return

    # Perform the edit
    print(f"\n🎯 Target file: {design_file}")
    print(f"🤖 Using model: {MODELS_CONFIG[args.model_index]['name']}")
    print(f"📝 Instructions: {instructions}")
    print(f"💾 Version mode: {'Overwrite original' if args.overwrite else 'Create new version'}")

    result = edit_design(
        design_file=design_file,
        edit_instructions=instructions,
        model_index=args.model_index,
        preserve_versions=not args.overwrite
    )

    if result['success']:
        print(f"\n✅ EDIT SUCCESSFUL!")
        print(f"   📄 Output file: {result['output_file']}")
        if result.get('version_number'):
            print(f"   🔢 Version: {result['version_number']}")
        print(f"   🎨 Design: {result.get('design_name', 'Unknown')}")
        print(f"   💰 Cost: ${result['usage_data']['cost']:.6f}")
        print(f"   🔢 Tokens: {result['usage_data']['tokens']:,}")

        print(f"\n🌐 To view the edited design:")
        print(f"   1. Open the file directly in your browser:")
        print(f"      {result['output_file']}")

        if result.get('version_number'):
            print(f"\n📋 Version Management:")
            print(f"   • This is version {result['version_number']} of the design")
            print(f"   • To revert, simply use the original or any previous version")
            print(f"   • Run with --list-versions to see all versions")

    else:
        print(f"\n❌ EDIT FAILED!")
        print(f"   Error: {result.get('error', 'Unknown error')}")

    print(f"\n⏰ Finished at: {print_timestamp()}")

# API function for dashboard integration
def edit_design_api(
    domain: str,
    design_number: int,
    version: Optional[str],
    edit_instructions: str,
    model_index: int = DEFAULT_MODEL_INDEX
) -> Dict:
    """
    API-friendly function for editing designs from the dashboard.

    Paths are resolved under ``data/website_data`` relative to the current
    working directory, so the caller is expected to run from the project root.

    Args:
        domain: The website domain (e.g., 'devramesh.com')
        design_number: Which design to edit (1, 2, or 3)
        version: Version to edit ('version2', etc.). ``None`` or 'original'
            selects the unversioned ``design_N.html``. If the requested file
            does not exist, the highest-numbered existing version (or the
            original, if that is all there is) is edited instead.
        edit_instructions: Natural language instructions for changes
        model_index: Which Gemini model to use

    Returns:
        Dict with:
            - success: bool
            - output_file: str (relative path from website data dir)
            - version_number: int (new version number)
            - usage_data: dict with cost and tokens
            - design_name: str
            - all_versions: list of all versions after edit
            - error: str (if failed)

    Never raises: failures are reported via ``success=False`` and ``error``.
    """
    website_data_dir = Path("data/website_data")

    def _version_rank(entry: Dict) -> int:
        """Numeric rank of a version-listing entry; 'original' ranks 0."""
        label = entry['version']
        suffix = label[len('version'):] if label.startswith('version') else ''
        return int(suffix) if suffix.isdigit() else 0

    try:
        # Construct paths
        website_dir = website_data_dir / domain
        output_dir = website_dir / "final_output"

        if not output_dir.exists():
            raise FileNotFoundError(f"No output directory found for {domain}")

        # Determine which file to edit
        if version and version != 'original':
            design_file = output_dir / f"design_{design_number}_{version}.html"
        else:
            design_file = output_dir / f"design_{design_number}.html"

        if not design_file.exists():
            # Fall back to the latest existing version of this design
            versions = list_design_versions(website_dir, design_number)
            if not versions:
                raise FileNotFoundError(f"No design file found for design {design_number}")
            # Choose by version number rather than list position, so the pick
            # is correct even if the listing is not numerically ordered.
            design_file = max(versions, key=_version_rank)['file']

        # Perform the edit
        result = edit_design(
            design_file=design_file,
            edit_instructions=edit_instructions,
            model_index=model_index,
            preserve_versions=True  # Always preserve versions in API mode
        )

        if result['success']:
            # Get all versions after edit
            all_versions = list_design_versions(website_dir, design_number)

            # Convert paths to strings relative to the website data dir
            result['output_file'] = str(result['output_file'].relative_to(website_data_dir))
            result['all_versions'] = [
                {
                    'version': v['version'],
                    'file': str(v['file'].relative_to(website_data_dir)),
                    'size': v['size'],
                    'modified': v['modified']
                }
                for v in all_versions
            ]

        return result

    except Exception as e:
        return {
            'success': False,
            'error': str(e),
            'usage_data': {'cost': 0, 'tokens': 0}
        }

if __name__ == "__main__":
    # Start the CLI only when executed as a script; importing the module
    # (e.g. to call edit_design_api from the dashboard) does not run main().
    main() 