from PIL import Image, ImageOps
from os.path import join, basename
from os import walk
import numpy as np
from progressbar import progressbar
from random import choice, shuffle
from flask import Blueprint, Flask, render_template, request, redirect, url_for, flash, jsonify, make_response
from flask_cors import CORS
import base64
import urllib.parse
from dotenv import load_dotenv
import mysql.connector
import boto3
from botocore.exceptions import ClientError, NoCredentialsError, PartialCredentialsError
from io import BytesIO
import tempfile
import os
import subprocess
import threading
from queue import Queue
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor, as_completed
import time
import pyvips
from botocore.config import Config
import logging
import sys
import random
from mysql.connector import connect, Error as MySQLError
from mysql.connector import pooling
import math
import asyncio
import aiobotocore
from aiobotocore.session import get_session
from boto3.s3.transfer import TransferConfig
from pathlib import Path
from .utils.paths import get_shared_path, get_shared_url
from threading import Lock
import redis
import json


# Add a lock for thread-safe access to future_to_task
future_lock = Lock()

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

env_path = Path(__file__).parent / '.env'  
load_dotenv(env_path)



bp = Blueprint('snmo_app', __name__)
# Also load environment variables from the default .env search path, if present
# (values already loaded from the blueprint-local .env above are not overridden)
load_dotenv()

# snmo_app = Flask(__name__)
# CORS(snmo_app, resources={r"/*": {"origins": "*"}})

CORS(bp, resources={r"/*": {"origins": "*"}})

# Note: this basicConfig() call is a no-op because a handler was already installed
# above; pass force=True (Python 3.8+) if file logging is actually intended here.
logging.basicConfig(filename='error.log', level=logging.ERROR)

# Constants
OUTPUT_HEIGHT = 7500
N_TILES = 10
RECENT_TILE_WINDOW = 10  # Number of recent tiles to track

executor = ProcessPoolExecutor(max_workers=2)

future_to_task = {}

future = None
# # Initialize the executor globally
# executor = concurrent.futures.ProcessPoolExecutor(max_workers=4)


# # Custom exceptions
# class S3ImagesInvalidExtension(Exception):
#     pass

# class S3ImagesUploadFailed(Exception):
#     pass

# class S3Images(object):
#     def __init__(self, aws_access_key_id, aws_secret_access_key, region_name):
#         config = Config(s3={"use_accelerate_endpoint": True})
#         self.s3 = boto3.client('s3', aws_access_key_id=aws_access_key_id, 
#                                      aws_secret_access_key=aws_secret_access_key, 
#                                      region_name=region_name,
#                                      config=config)

#     def _download_image(self, bucket, key):
#         try:
#             response = self.s3.get_object(Bucket=bucket, Key=key)
#             img = Image.open(response['Body'])
#             return img, key  # Return both the image and its key
#         except ClientError as e:
#             raise S3ImagesUploadFailed(f'Failed to retrieve image {key} from bucket {bucket}: {e}')

    

#     def from_s3(self, bucket, key):
#         if key.endswith('/'):  # If key ends with '/', it's a directory
#             images = []
#             try:
#                 response = self.s3.list_objects_v2(Bucket=bucket, Prefix=key)
                
#                 # Collect keys with their LastModified dates
#                 keys_with_dates = [
#                     {'Key': obj['Key'], 'LastModified': obj['LastModified']}
#                     for obj in response.get('Contents', [])
#                     if obj['Key'].lower().endswith(('.jpg', '.jpeg', '.png'))
#                 ]

#                 # Sort keys by LastModified in descending order
#                 keys_with_dates.sort(key=lambda x: x['LastModified'], reverse=True)

#                 # Parallel download
#                 with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
#                     futures = {executor.submit(self._download_image, bucket, k['Key']): k for k in keys_with_dates}
#                     for future in concurrent.futures.as_completed(futures):
#                         images.append(future.result())
#             except ClientError as e:
#                 raise S3ImagesUploadFailed(f'Failed to retrieve images from bucket {bucket} with prefix {key}: {e}')
#             return images
#         else:  # If key doesn't end with '/', it's a single image
#             try:
#                 return [self._download_image(bucket, key)]
#             except ClientError as e:
#                 raise S3ImagesUploadFailed(f'Failed to retrieve image {key} from bucket {bucket}: {e}')
             
                
#     def to_s3(self, img, bucket, key):
#         buffer = BytesIO()
#         img_format = key.split('.')[-1].upper()
#         if img_format not in ['JPG', 'JPEG', 'PNG', 'DZI']:
#             raise S3ImagesInvalidExtension(f'Invalid image extension for key {key}. Allowed extensions are .jpg, .jpeg, .png', '.dzi')

#         img.save(buffer, img_format)
#         buffer.seek(0)
#         try:
#             self.s3.put_object(Bucket=bucket, Key=key, Body=buffer, ContentType=f'image/{img_format.lower()}')
#         except ClientError as e:
#             raise S3ImagesUploadFailed(f'Failed to upload image {key} to bucket {bucket}: {e}')

#         buffer.close()

#     def parallel_upload_images(self, images, bucket, keys):
#         with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
#             futures = [executor.submit(self.to_s3, img, bucket, key) for img, key in zip(images, keys)]
#             for future in concurrent.futures.as_completed(futures):
#                 future.result()  # Raise any exception that occurred during the upload



# Custom exceptions
class S3ImagesInvalidExtension(Exception):
    pass

class S3ImagesUploadFailed(Exception):
    pass

class S3Images(object):
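    """Shared-folder drop-in replacement for the original S3-backed image store.

    Images are read from and written to a local shared uploads folder; the
    ``bucket`` and AWS credential arguments are accepted only for interface
    compatibility and are ignored.
    """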
    def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, region_name=None):
        # Images live in a shared uploads folder on local disk instead of S3
        self.shared_folder = "/var/www/fastuser/data/www/shared/uploads"
        os.makedirs(self.shared_folder, exist_ok=True)

    def _download_image(self, bucket, key):
        try:
            local_path = os.path.join(self.shared_folder, key)
            img = Image.open(local_path)
            return img, key
        except Exception as e:
            raise S3ImagesUploadFailed(f'Failed to retrieve image {key} from shared folder: {e}')

    def from_s3(self, bucket, key):
        if key.endswith('/'):
            images = []
            try:
                full_path = os.path.join(self.shared_folder, key)
                matching_files = []
                
                for root, _, files in os.walk(full_path):
                    for file in files:
                        if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                            file_path = os.path.join(root, file)
                            rel_path = os.path.relpath(file_path, self.shared_folder)
                            matching_files.append({
                                'Key': rel_path,
                                'LastModified': os.path.getmtime(file_path)
                            })
                
                matching_files.sort(key=lambda x: x['LastModified'], reverse=True)
                
                with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                    futures = {
                        executor.submit(self._download_image, bucket, file['Key']): file 
                        for file in matching_files
                    }
                    for future in concurrent.futures.as_completed(futures):
                        images.append(future.result())
            except Exception as e:
                raise S3ImagesUploadFailed(f'Failed to retrieve images from shared folder: {e}')
            return images
        else:
            try:
                return [self._download_image(bucket, key)]
            except Exception as e:
                raise S3ImagesUploadFailed(f'Failed to retrieve image {key} from shared folder: {e}')

    def to_s3(self, img, bucket, key):
        extension = key.split('.')[-1].upper()
        if extension not in ['JPG', 'JPEG', 'PNG', 'DZI']:
            raise S3ImagesInvalidExtension(f'Invalid image extension for key {key}')

        try:
            full_path = os.path.join(self.shared_folder, key)
            os.makedirs(os.path.dirname(full_path), exist_ok=True)
            # Let Pillow infer the output format from the file extension; passing the
            # raw extension would fail for '.jpg' keys because Pillow expects 'JPEG'
            img.save(full_path)
        except Exception as e:
            raise S3ImagesUploadFailed(f'Failed to upload image {key} to shared folder: {e}')

    def parallel_upload_images(self, images, bucket, keys):
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(self.to_s3, img, bucket, key) 
                      for img, key in zip(images, keys)]
            for future in concurrent.futures.as_completed(futures):
                future.result()

# Helper functions
def crop_center(img, new_x, new_y):
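    """Center-crop ``img`` to the aspect ratio of (new_x, new_y) and resize it
    to exactly that size with Lanczos resampling."""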
    old_x, old_y = img.size
    old_ratio = old_x / old_y
    new_ratio = new_x / new_y
    if old_ratio > new_ratio:
        x = new_ratio * old_y
        dx = (old_x - x) // 2
        img = img.crop((dx, 0, dx + x, old_y))
    elif old_ratio < new_ratio:
        y = old_x / new_ratio
        dy = (old_y - y) // 2
        img = img.crop((0, dy, old_x, dy + y))
    return img.resize((new_x, new_y), Image.LANCZOS)

def find_images(path):
    for root, dirs, files in walk(path):
        for file in files:
            yield join(root, file)

def get_average_color(img):
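    """Return the mean (R, G, B) of an H x W x 3 numpy array, rounded to ints."""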
    return tuple(round(v) for v in img.mean(axis=(0, 1)))



def load_tiles(tiles):
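    """Convert (PIL image, name) pairs into (average_color, pixel_array, name)
    tuples, cropping each tile to the module-level TILE_X_SIZE x TILE_Y_SIZE
    (set by generate_mosaic_route)."""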
    tile_images = []
    for img, name in tiles:
        # Ensure images are properly processed
        tile = np.asarray(crop_center(ImageOps.exif_transpose(img).convert('RGB'), TILE_X_SIZE, TILE_Y_SIZE))
        color = get_average_color(tile)
        tile_images.append((color, tile, name))
    return tile_images


def get_best(ranked, recent_tiles):
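    """Choose one tile from the best-distance candidates that was not used
    recently; if every candidate was recently used, pick from the full ranked
    list instead."""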
    best = []
    best_dist = ranked[0][0]

    for dist, avg_color, tile, name in ranked:
        if dist == best_dist and name not in recent_tiles:
            best.append((dist, avg_color, tile, name))
        elif dist > best_dist:
            break

    # If every tile at the best distance was used recently, fall back to choosing
    # from the full ranked list
    if not best:
        best = ranked

    return random.choice(best)




def get_tile(tiles, color, recent_tiles, unused_source_3_tiles, prioritize_source_3=False):
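    """Pick the tile whose average color is closest (L1 distance) to ``color``.

    When ``prioritize_source_3`` is set and unused source-3 tiles remain, those
    are preferred. The chosen tile's pixels are shifted toward the target color
    before it is returned along with its name and the updated unused-tile list.
    """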
    # Rank the tiles based on color distance
    ranked = [
        (
            abs(color[0] - avg_color[0]) + abs(color[1] - avg_color[1]) + abs(color[2] - avg_color[2]),
            avg_color,
            tile,
            name
        )
        for avg_color, tile, name in tiles
    ]
    ranked.sort(key=lambda v: v[0])

    # Prioritize using source 3 tiles if requested
    if prioritize_source_3 and unused_source_3_tiles:
        # Each unused_source_3_tiles entry is (avg_color, tile, name); match on the name
        source_3_names = {t[2] for t in unused_source_3_tiles}
        source_3_ranked = [item for item in ranked if item[3] in source_3_names]
        if source_3_ranked:
            dist, avg_color, tile, name = get_best(source_3_ranked, recent_tiles)
            # Remove the used tile from the unused_source_3_tiles list
            unused_source_3_tiles = [t for t in unused_source_3_tiles if t[2] != name]
        else:
            dist, avg_color, tile, name = get_best(ranked, recent_tiles)
    else:
        dist, avg_color, tile, name = get_best(ranked, recent_tiles)

    # Adjust the tile's color
    diff = np.array((color[0] - avg_color[0], color[1] - avg_color[1], color[2] - avg_color[2]), dtype=np.int16)
    tile = np.add(tile, diff).clip(0, 255).astype(np.uint8)

    return Image.fromarray(tile), name, unused_source_3_tiles





# def generate_mosaic(image, tiles, mosaic_width, mosaic_height, tile_image_names):
#     mosaic = Image.new('RGB', (mosaic_width * TILE_X_SIZE, mosaic_height * TILE_Y_SIZE))
#     x_size, y_size = image.size
#     coords = [(x, y) for x in range(x_size) for y in range(y_size)]
#     shuffle(coords)  # Shuffle the coordinates to randomize tile placement
#     recent_tiles = []  # List to keep track of recently used tiles
#     tile_positions = []
#     for x, y in coords:
#         x_pos = x * TILE_X_SIZE
#         y_pos = y * TILE_Y_SIZE
#         tile = get_tile(tiles, image.getpixel((x, y)), recent_tiles)

#         mosaic.paste(tile, (x_pos, y_pos, x_pos + TILE_X_SIZE, y_pos + TILE_Y_SIZE))
#         recent_tiles.append(tile)
        
#         tile_positions.append((tile.filename, x_pos, y_pos))
#         if len(recent_tiles) > RECENT_TILE_WINDOW:
#             recent_tiles.pop(0)
#     return mosaic



# current live one

# def generate_mosaic(image, tiles, mosaic_width, mosaic_height):
#     mosaic = Image.new('RGB', (mosaic_width * TILE_X_SIZE, mosaic_height * TILE_Y_SIZE))
#     coords = [(x, y) for x in range(mosaic_width) for y in range(mosaic_height)]

#     recent_tiles = []  # List to keep track of recently used tiles
#     tile_positions = []

#     for x, y in coords:
#         x_pos = x * TILE_X_SIZE
#         y_pos = y * TILE_Y_SIZE
#         tile, name = get_tile(tiles, image.getpixel((x, y)), recent_tiles)

#         print(tile, name)
#         mosaic.paste(tile, (x_pos, y_pos))
#         recent_tiles.append(name)

#         filename = name.split('/')[-1]
#         tile_positions.append((filename, x, y))

#         if len(recent_tiles) > RECENT_TILE_WINDOW:
#             recent_tiles.pop(0)

#     return mosaic, tile_positions




def generate_mosaic(image, tiles, mosaic_width, mosaic_height):
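    """Assemble the mosaic by matching every pixel of the reduced source image
    to the closest tile, using each source-3 tile before falling back to the
    remaining tiles.

    Returns the mosaic as a PIL image and a list of (filename, x, y) placements.
    """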
    mosaic = Image.new('RGB', (mosaic_width * TILE_X_SIZE, mosaic_height * TILE_Y_SIZE))
    coords = [(x, y) for x in range(mosaic_width) for y in range(mosaic_height)]

    recent_tiles = []  # List to keep track of recently used tiles
    tile_positions = []

    # Use the same source_3_identifier as in load_tiles_from_s3_and_db
    source_3_identifier = '3'  # Assuming '3' is the identifier for source 3 tiles

    # Separate source 3 tiles from other tiles; tile[2] is the tile name/key
    unused_source_3_tiles = [tile for tile in tiles if source_3_identifier in tile[2]]
    non_source_3_tiles = [tile for tile in tiles if source_3_identifier not in tile[2]]

    for x, y in coords:
        x_pos = x * TILE_X_SIZE
        y_pos = y * TILE_Y_SIZE

        # Pass the unused_source_3_tiles and ensure all are used
        if unused_source_3_tiles:
            tile, name, unused_source_3_tiles = get_tile(tiles, image.getpixel((x, y)), recent_tiles, unused_source_3_tiles, prioritize_source_3=True)
        else:
            tile, name, unused_source_3_tiles = get_tile(non_source_3_tiles, image.getpixel((x, y)), recent_tiles, unused_source_3_tiles, prioritize_source_3=False)

        logging.debug(f"Placed tile {name} at grid position ({x}, {y})")
        mosaic.paste(tile, (x_pos, y_pos))
        recent_tiles.append(name)

        filename = name.split('/')[-1]
        tile_positions.append((filename, x, y))

        if len(recent_tiles) > RECENT_TILE_WINDOW:
            recent_tiles.pop(0)

    return mosaic, tile_positions



# def insert_tile_positions(tablename, tile_positions):
#     try:
#         # Connect to the database
#         db_host = os.getenv('dbHost')
#         db_user = os.getenv('dbUser')
#         db_pass = os.getenv('dbPass')
#         db_name = os.getenv('dbName')
#         db_port = int(os.getenv('dbPort'))

#         db_connection = connect(
#             host=db_host,
#             port=db_port,
#             user=db_user,
#             password=db_pass,
#             database=db_name
#         )
        
#         # Create cursor
#         cursor = db_connection.cursor()

#         # Delete all existing rows
#         delete_query = f"DELETE FROM `{tablename}`;"
#         cursor.execute(delete_query)

#         # Insert new rows
#         insert_query = f"INSERT INTO `{tablename}` (tile_name, x, y) VALUES (%s, %s, %s);"
#         cursor.executemany(insert_query, tile_positions)

#         # Commit the transaction
#         db_connection.commit()

#     except MySQLError as e:
#         logging.error(f"MySQL Error occurred during insertion: {str(e)}")
#         raise  # Reraise the exception to propagate it

#     finally:
#         if 'cursor' in locals():
#             cursor.close()
#         if 'db_connection' in locals():
#             db_connection.close()


def insert_tile_positions(project_name, tile_positions):
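    """Write the computed (x, y) placements back to ``tiles_image_fixed_info``,
    updating rows by tile_name and project_name in batches of 500 and retrying
    a batch up to three times when MySQL reports a deadlock."""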
    try:
        db_host = os.getenv('dbHost')
        db_user = os.getenv('dbUser')
        db_pass = os.getenv('dbPass')
        db_name = os.getenv('dbName')
        db_port = int(os.getenv('dbPort'))

        db_connection = connect(
            host=db_host,
            port=db_port,
            user=db_user,
            password=db_pass,
            database=db_name
        )
        cursor = db_connection.cursor()

        update_query = """
            UPDATE `tiles_image_fixed_info`
            SET x = %s, y = %s
            WHERE tile_name = %s AND project_name = %s;
        """

        # Sort to reduce deadlocks: sort by tile_name and x
        sorted_tile_positions = sorted(tile_positions, key=lambda t: (t[0], t[1]))

        # Reorder to match the UPDATE parameter order: (x, y, tile_name, project_name)
        update_data = [(x, y, tile_name, project_name) for tile_name, x, y in sorted_tile_positions]

        logging.debug(f"Prepared {len(update_data)} update entries. Sample: {update_data[:2]}")
        
        batch_size = 500
        total_updated = 0
        max_retries = 3

        for i in range(0, len(update_data), batch_size):
            batch = update_data[i:i + batch_size]
            for attempt in range(max_retries):
                try:
                    cursor.executemany(update_query, batch)
                    db_connection.commit()
                    total_updated += cursor.rowcount
                    break  # success, break retry loop
                except MySQLError as e:
                    if "Deadlock found" in str(e) and attempt < max_retries - 1:
                        logging.warning(f"Deadlock detected. Retrying batch {i // batch_size + 1} (Attempt {attempt + 1})")
                        time.sleep(1)
                        continue
                    else:
                        raise

        logging.info(f"Successfully updated {total_updated} tile positions for project: {project_name}")

    except MySQLError as e:
        logging.exception(f"MySQL Error during tile position update: {str(e)}")
        raise

    except Exception as e:
        logging.exception(f"General error during insert_tile_positions: {str(e)}")
        raise

    finally:
        if 'cursor' in locals():
            cursor.close()
        if 'db_connection' in locals():
            db_connection.close()




# async def upload_file(session, bucket_name, file_path, s3_key):
#     try:
#         async with session.create_client('s3') as s3:
#             # Open file and upload it
#             with open(file_path, 'rb') as data:
#                 # Use upload_fileobj for async S3 uploads
#                 await s3.upload_fileobj(data, Bucket=bucket_name, Key=s3_key)
#             print(f"Uploaded {file_path} to {bucket_name}/{s3_key}")
#     except Exception as e:
#         print(f"Failed to upload {file_path}: {e}")


# async def upload_files(files_to_upload, aws_credentials, bucket_name):
#     aws_access_key_id, aws_secret_access_key, region_name = aws_credentials
    
#     session = get_session()
    
#     async with session.create_client('s3', region_name=region_name,
#                                       aws_access_key_id=aws_access_key_id,
#                                       aws_secret_access_key=aws_secret_access_key) as s3_client:
#         semaphore = asyncio.Semaphore(10)  # Adjust the number of concurrent uploads
        
#         async def upload_with_semaphore(file_to_upload):
#             async with semaphore:
#                 # Unpack the file path and S3 key from the tuple
#                 file_path, s3_key = file_to_upload
#                 await upload_file(session, bucket_name, file_path, s3_key)

#         tasks = [upload_with_semaphore(file) for file in files_to_upload]
#         await asyncio.gather(*tasks)


# def process_file(root, file, output_dir, s3_prefix):
#     file_path = os.path.join(root, file)
#     relative_path = os.path.relpath(file_path, output_dir)
#     s3_key = os.path.join(s3_prefix, relative_path).replace("\\", "/")
#     return (file_path, s3_key)

# def run_vips_command(input_image_path, output_folder, aws_credentials, bucket_name, s3_prefix):
#     # Create a temporary directory to store the deep zoom image pyramid
#     with tempfile.TemporaryDirectory() as temp_dir:
#         output_dir = os.path.join(temp_dir, output_folder)
#         os.makedirs(output_dir, exist_ok=True)

#         # Create Deep Zoom image pyramid from source using pyvips
#         image = pyvips.Image.new_from_file(input_image_path, access='sequential')
#         dz = pyvips.Image.dzsave(image, output_dir, suffix=".png", overlap=2, tile_size=96)

#         # Collect file paths and S3 keys
#         files_to_upload = []
#         for root, _, files in os.walk(output_dir):
#             for file in files:
#                 files_to_upload.append(process_file(root, file, output_dir, s3_prefix))

#         # Run the asynchronous upload in an event loop
#         asyncio.run(upload_files(files_to_upload, aws_credentials, bucket_name))

#         # Clean up the temporary file after all uploads complete
#         os.remove(input_image_path)
    
#     return "Success"



    

# def run_vips_command(input_image_path, output_folder, aws_credentials, bucket_name, s3_prefix):
#     # Unpack AWS credentials
#     aws_access_key_id, aws_secret_access_key, region_name = aws_credentials
    
#     s3_config = Config(
#         s3={'use_accelerate_endpoint': True},
#     )

#     # Initialize S3 client
#     s3 = boto3.client('s3', 
#                       aws_access_key_id=aws_access_key_id, 
#                       aws_secret_access_key=aws_secret_access_key, 
#                       region_name=region_name, config=s3_config)

#     # Create a temporary directory to store the deep zoom image pyramid
#     with tempfile.TemporaryDirectory() as temp_dir:
#         output_dir = os.path.join(temp_dir, output_folder)
#         os.makedirs(output_dir, exist_ok=True)

#         # Create Deep Zoom image pyramid from source using pyvips
#         image = pyvips.Image.new_from_file(input_image_path, access='sequential')
#         dz = pyvips.Image.dzsave(image, output_dir, suffix=".png", overlap=2, tile_size=96)

#         # Upload the generated files to S3 in parallel using upload_fileobj
#         def upload_file(file_path, s3_key):
#             try:
#                 with open(file_path, 'rb') as file_obj:
#                     s3.upload_fileobj(file_obj, bucket_name, s3_key)
#                 print(f"Uploaded {file_path} to {bucket_name}/{s3_key}")
#             except boto3.exceptions.S3UploadFailedError as e:
#                 raise S3ImagesUploadFailed(f"Failed to upload file {file_path} to bucket {bucket_name}: {e}")
#             except NoCredentialsError:
#                 raise S3ImagesUploadFailed("AWS credentials not found.")
#             except PartialCredentialsError:
#                 raise S3ImagesUploadFailed("Incomplete AWS credentials provided.")
#             except Exception as e:
#                 raise S3ImagesUploadFailed(f"An error occurred during file upload: {e}")

#         def process_file(root, file):
#             file_path = os.path.join(root, file)
#             # Construct the S3 key for the file
#             relative_path = os.path.relpath(file_path, output_dir)
#             s3_key = os.path.join(s3_prefix, relative_path).replace("\\", "/")
#             return (file_path, s3_key)

#         # Collect file paths and S3 keys
#         files_to_upload = []
#         for root, _, files in os.walk(output_dir):
#             for file in files:
#                 files_to_upload.append(process_file(root, file))

#         # Increase max_workers based on your system’s capability
#         with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
#             futures = [executor.submit(upload_file, file_path, s3_key) for file_path, s3_key in files_to_upload]
#             for future in concurrent.futures.as_completed(futures):
#                 try:
#                     future.result()  
#                 except Exception as e:
#                     # Log or handle individual file upload errors if needed
#                     print(f"File upload failed: {e}")
        
#         # Clean up the temporary file after all uploads complete
#         os.remove(input_image_path)
#     return "Success"

redis_client = redis.Redis(host='localhost', port=6379, db=0, decode_responses=True)


def run_vips_command(input_image_path, output_folder, aws_credentials, bucket_name, s3_prefix):
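    """Build a Deep Zoom pyramid (.dzi plus tile folder) from the rendered
    mosaic with pyvips and write it under the shared uploads folder.

    The aws_credentials and bucket_name arguments are kept only for interface
    compatibility and are not used. The temporary input image is deleted once
    the pyramid has been written.
    """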
    # Deep Zoom output is written to the shared uploads folder on local disk
    shared_folder = "/var/www/fastuser/data/www/shared/uploads"
    
    project_name = s3_prefix.split('/')[1]
    base_output_dir = os.path.join(shared_folder, s3_prefix)
    os.makedirs(base_output_dir, exist_ok=True)
    
    output_dzi_name = f"{project_name}_output.dzi"
    output_basename = f"{project_name}_output"
    output_dzi_path = os.path.join(base_output_dir, output_dzi_name)
    
    image = pyvips.Image.new_from_file(input_image_path, access='sequential')
    image.dzsave(output_dzi_path,
                 basename=output_basename,
                 suffix=".png",
                 overlap=2,
                 tile_size=96)

    os.remove(input_image_path)
    return "Success"
    

@bp.route('/task_status', methods=['GET'])
def task_status():
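    """Return a JSON object mapping every Redis key to the ``status`` field of
    its stored JSON value; keys whose values cannot be parsed are skipped."""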
    task_statuses = {}
    try:
        keys = redis_client.keys("*")
        for key in keys:
            try:
                value = redis_client.get(key)
                if value:
                    decoded_value = json.loads(value)
                    status = decoded_value.get("status", "Unknown")
                    task_statuses[key] = status  # key is already str if decode_responses=True
            except Exception as e:
                logging.error(f"Error reading Redis key {key}: {e}")
                continue
    except Exception as e:
        logging.error(f"Error fetching task statuses: {e}")
    return jsonify(task_statuses)




def clear_task_statuses():
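    """Delete all Redis keys matching the ``snmo-project-*`` pattern and return
    how many keys were removed (or an error message)."""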
    try:
        keys = redis_client.keys("snmo-project-*")  # or use "*" if no strict pattern
        for key in keys:
            redis_client.delete(key)
            logging.info(f"Deleted Redis key: {key}")
        return {"deleted_keys": len(keys)}
    except Exception as e:
        logging.error(f"Error clearing Redis keys: {e}")
        return {"error": str(e)}



@bp.route('/clear_task_statuses', methods=['GET'])
def clear_task_statuses_route():
    result = clear_task_statuses()
    return jsonify(result)

# def get_tiles_from_db(project_name):
#     try:
#         # Create a new database connection
#         db_host = os.getenv('dbHost')
#         db_user = os.getenv('dbUser')
#         db_pass = os.getenv('dbPass')
#         db_name = os.getenv('dbName')
#         db_port = int(os.getenv('dbPort'))

#         db_connection = connect(
#             host=db_host,
#             port=db_port,
#             user=db_user,
#             password=db_pass,
#             database=db_name
#         )
        
#         # Create cursor
#         cursor = db_connection.cursor()

#         # Dynamically create the table name, replacing any hyphens with underscores if needed
#         tiles_table = 'tiles-pos-' + project_name

#         #query = f"SELECT * FROM `{tiles_table}`;"
        
#         #second last query in live
#         #query = f"SELECT * FROM `{tiles_table}` WHERE queued IS NULL OR queued = '' OR queued = '0' OR queued = 0;"
        
#         query = f"SELECT * FROM `{tiles_table}` WHERE (queued IS NULL OR queued = '' OR queued = '0' OR queued = 0) AND tile_terms_and_condition = '1';"
#         print(f"Executing query: {query}")  # Debugging line
#         cursor.execute(query)  # No need to pass parameters for a simple query
#         return cursor.fetchall()
    
#     except MySQLError as e:
#         print(f"An error occurred: {e}")
#         return []  # Return an empty list on error
#     finally:
#         cursor.close()  # Close the cursor if it was created
#         db_connection.close()  # Close the database connection

def get_tiles_from_db(project_name):
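    """Return all rows from ``tiles_image_fixed_info`` for ``project_name`` that
    are not queued and have tile_terms_and_condition = '1'."""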
    try:
        # Create a new database connection
        db_host = os.getenv('dbHost')
        db_user = os.getenv('dbUser')
        db_pass = os.getenv('dbPass')
        db_name = os.getenv('dbName')
        db_port = int(os.getenv('dbPort'))

        db_connection = connect(
            host=db_host,
            port=db_port,
            user=db_user,
            password=db_pass,
            database=db_name
        )
        
        # Create cursor
        cursor = db_connection.cursor()

        # Use the new table name
        tiles_table = 'tiles_image_fixed_info'

        # Query to select tiles based on project_name and conditions
        query = f"SELECT * FROM `{tiles_table}` WHERE project_name = %s AND (queued IS NULL OR queued = '' OR queued = '0' OR queued = 0) AND tile_terms_and_condition = '1';"
        print(f"Executing query: {query}")  # Debugging line
        cursor.execute(query, (project_name,))  # Pass project_name as a parameter
        return cursor.fetchall()
    
    except MySQLError as e:
        print(f"An error occurred: {e}")
        return []  # Return an empty list on error
    finally:
        # Close only what was actually created (connect() may have failed)
        if 'cursor' in locals():
            cursor.close()
        if 'db_connection' in locals():
            db_connection.close()

# def load_tiles_from_s3_and_db(bucket, project_name, s3_images):
#     tiles_from_db = get_tiles_from_db(project_name)

#     # Separate tiles based on their source using source_3_identifier
#     source_3_identifier = '3'  # Assuming '3' is the identifier for source 3
#     source_3_tiles = [(tile[1], tile[2]) for tile in tiles_from_db if tile[1] == source_3_identifier]  # Using index 2 for image name, index 11 for source
#     other_tiles = [(tile[1], tile[2]) for tile in tiles_from_db if tile[1] != source_3_identifier]

#     # Load all tiles from S3 with the new tile_key_prefix
#     tile_key_prefix = f'projects/{project_name}/tiles-pos-{project_name}/'
#     all_tiles_s3 = s3_images.from_s3(bucket=bucket, key=tile_key_prefix)

#     tile_images = []

#     # Load source 3 tiles first
#     for tile_name, _ in source_3_tiles:
#         for img, key in all_tiles_s3:
#             if str(tile_name) in key:
#                 tile_images.append((img, key))
    
#     # Load other tiles only if no source 3 tiles were found
#     if not tile_images:  # If no source 3 tiles, then load others
#         for tile_name, _ in other_tiles:
#             for img, key in all_tiles_s3:
#                 if str(tile_name) in key:
#                     tile_images.append((img, key))

#     return tile_images


def load_tiles_from_s3_and_db(bucket, project_name, s3_images):
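    """Load tile images from the shared folder for a project, keeping the files
    whose key contains a tile name from the database.

    Source-3 tiles are collected first; other tiles are used only when no
    source-3 tiles are found.
    """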
    tiles_from_db = get_tiles_from_db(project_name)

    # Separate tiles based on their source using source_3_identifier
    source_3_identifier = '3'  # Assuming '3' is the identifier for source 3
    # Keep columns 1 and 2 of each row for matching against tile keys;
    # column 6 is assumed to hold the source identifier
    source_3_tiles = [(tile[1], tile[2]) for tile in tiles_from_db if str(tile[6]) == source_3_identifier]
    other_tiles = [(tile[1], tile[2]) for tile in tiles_from_db if str(tile[6]) != source_3_identifier]

    # Load all tiles from S3 with the new tile_key_prefix
    tile_key_prefix = f'projects/{project_name}/tiles-pos-{project_name}/'
    all_tiles_s3 = s3_images.from_s3(bucket=bucket, key=tile_key_prefix)

    tile_images = []

    # Load source 3 tiles first
    for tile_name, _ in source_3_tiles:
        for img, key in all_tiles_s3:
            if str(tile_name) in key:
                tile_images.append((img, key))
    
    # Load other tiles only if no source 3 tiles were found
    if not tile_images:  # If no source 3 tiles, then load others
        for tile_name, _ in other_tiles:
            for img, key in all_tiles_s3:
                if str(tile_name) in key:
                    tile_images.append((img, key))

    return tile_images



@bp.route('/generate_mosaic', methods=['POST'])
def generate_mosaic_route():
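    """Handle POST /generate_mosaic.

    Decodes the base64/URL-encoded project name, looks the project up in
    ``snmo_projects``, renders the mosaic, writes it to the shared uploads
    folder, queues Deep Zoom generation on the process pool, tracks task
    status in Redis, and persists tile placements via insert_tile_positions().
    """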
    try:
        start_time = time.time()  

        logging.info("Starting mosaic generation")
        images = S3Images()
        project_name = request.form.get("project_name")
        user_id = request.form.get("user_id")

        url_decoded_string = urllib.parse.unquote(project_name)
        decoded_bytes = base64.b64decode(url_decoded_string)
        project_name = decoded_bytes.decode('utf-8')

        task_id = str((project_name, user_id))
        redis_client.set(task_id, json.dumps({"status": "Started"}))

        db_port = int(os.getenv('dbPort', '3306'))

        db_connection = mysql.connector.connect(
            host=os.getenv('dbHost'),
            port=db_port,
            user=os.getenv('dbUser'),
            password=os.getenv('dbPass'),
            database=os.getenv('dbName')
        )
        cursor = db_connection.cursor(dictionary=True)

        cursor.execute(
            "SELECT * FROM snmo_projects WHERE user_id = %s AND project_slug = %s;",
            (user_id, project_name)
        )
        results = cursor.fetchall()
        cursor.close()
        db_connection.close()

        if not results:
            return jsonify({"error": "Project not found"}), 404

        image_filename = results[0]["main_image"]
        tile_size = int(os.getenv('TILE_IMAGE_DENSITY'))
        global TILE_X_SIZE, TILE_Y_SIZE
        TILE_X_SIZE = TILE_Y_SIZE = tile_size

        image_data = images.from_s3("ignored", f'projects/{project_name}/{image_filename}')
        image = image_data[0][0]

        tiles = load_tiles_from_s3_and_db("ignored", project_name, images)

        aspect_ratio = image.width / image.height  

        mosaic_height = OUTPUT_HEIGHT // TILE_Y_SIZE
        mosaic_width = int((OUTPUT_HEIGHT * aspect_ratio) // TILE_X_SIZE)
        image = crop_center(image, mosaic_width, mosaic_height)

        tile_images = load_tiles(tiles)
        mosaic, tile_positions = generate_mosaic(image, tile_images, mosaic_width, mosaic_height)

        mosaic_key = f'projects/{project_name}/{project_name}_mosaic.png'
        images.to_s3(mosaic, "ignored", mosaic_key)

        output_filename = f"temp_mosaic_{project_name}.png"
        output_dir = "/tmp"
        mosaic_output_image_path = os.path.join(output_dir, output_filename)
        mosaic.save(mosaic_output_image_path)

        future = executor.submit(
            run_vips_command,
            mosaic_output_image_path,
            f'{project_name}_output/',
            ("ignored", "ignored", "ignored"),
            "ignored",
            f'projects/{project_name}/'
        )

        # Track the future under the lock so the done-callback thread sees a
        # consistent mapping
        with future_lock:
            future_to_task[future] = (project_name, user_id)

        def update_status(fut):
            try:
                with future_lock:
                    proj_name, usr_id = future_to_task.get(fut, (None, None))
                task_key = str((proj_name, usr_id))
                try:
                    fut.result()
                    redis_client.set(task_key, json.dumps({"status": "Completed"}))
                except Exception as e:
                    logging.error(f"Task failed: {e}")
                    redis_client.set(task_key, json.dumps({"status": f"Failed: {str(e)}"}))
            finally:
                with future_lock:
                    future_to_task.pop(fut, None)
                if os.path.exists(mosaic_output_image_path):
                    os.remove(mosaic_output_image_path)

        future.add_done_callback(update_status)

        redis_client.set(task_id, json.dumps({"status": "In Progress"}))

        insert_tile_positions(project_name, tile_positions)

        # if os.path.exists(mosaic_output_image_path):
        #     os.remove(mosaic_output_image_path)

        end_time = time.time()  
        duration = round(end_time - start_time, 2)

        logging.info(f"[{project_name}] Mosaic generated in {duration} sec with aspect ratio: {round(aspect_ratio, 2)}")

        response = make_response(jsonify({"success": True, "message": "Mosaic generation started"}), 200)
        response.headers['Connection'] = 'keep-alive'
        return response

    except Exception as e:
        logging.error(f"An error occurred: {str(e)}", exc_info=True)
        if 'mosaic_output_image_path' in locals() and os.path.exists(mosaic_output_image_path):
            os.remove(mosaic_output_image_path)
        return jsonify({"error": str(e)}), 507
