
# TTS_Engine.py

from PyQt6.QtWidgets import (
    QDialog, QVBoxLayout, QLabel, QComboBox,
    QPushButton, QFileDialog, QHBoxLayout,
    QSpinBox, QLineEdit, QMessageBox, QGroupBox,
    QGridLayout
)
from PyQt6.QtCore import Qt, QThread, pyqtSignal, QTimer
import os
import subprocess
import re
import json
import numpy as np
from TTS_Progress import TTSProgressDialog


# TTS_Engine.py
# ... (импорты остаются без изменений)

class SileroTTSWorker(QThread):
    progress_updated = pyqtSignal(int, int, str)
    finished = pyqtSignal()
    error_occurred = pyqtSignal(str)

    def __init__(self, text, model_path, output_dir, filename, audio_format,
                 split_chars, bitrate_type, bitrate_value, speaker='xenia'):
        super().__init__()
        self.text = text
        self.model_path = model_path
        self.output_dir = output_dir
        self.filename = filename
        self.audio_format = audio_format
        self.split_chars = split_chars
        self.bitrate_type = bitrate_type
        self.bitrate_value = bitrate_value
        self.speaker = speaker
        self.is_cancelled = False
        self.sample_rate = 48000
        self.MODEL_MAX_CHARS = 800  # Уменьшаем для безопасности

    def safe_for_silero(self, text):
        """
        Очистка текста для Silero TTS.
        """
        if not text:
            return "[пустой текст]"
        
        # 1. Базовые замены
        text = text.replace('\xa0', ' ')  # Неразрывный пробел
        text = text.replace('\u200b', '')  # Zero-width space
        text = text.replace('…', '...')    # Многоточие
        
        # 2. Удаляем управляющие символы
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text)
        text = re.sub(r'[\u2000-\u200F\u2028-\u202F]', '', text)

        # 3. Расшифровка символов
        replacements = {
            '@': 'собака',
            '#': 'решетка',
            '$': 'доллар',
            '%': 'процент',
            '^': 'степень',
            '&': 'амперсанд',
            '*': 'звездочка',
            '{': 'скобка', '}': 'скобка',
            '[': 'скобка', ']': 'скобка',
            '<': 'меньше', '>': 'больше',
            '=': 'равно',
            '~': 'тильда',
            '|': 'вертикальная черта',
            '\\': 'обратный слэш',
            '/': 'слэш',
            '`': 'апостроф',
            '–': '-', '—': '-', '−': '-',
        }
        
        for symbol, replacement in replacements.items():
            text = text.replace(symbol, f' {replacement} ')
        
        # 4. Оставляем только разрешенные символы
        text = re.sub(r'[^а-яёА-ЯЁ0-9\s.,!?:;"\'()-]', '', text)
        
        # 5. Удаляем лишние пробелы
        text = re.sub(r'\s+', ' ', text)
        text = re.sub(r'\s+([.,!?:;])', r'\1', text)
        text = text.strip()
        
        return text if text else "[пустой текст]"

    def split_text_into_chunks(self, text):
        """
        Основная функция разбивки текста на чанки.
        Возвращает список чанков, каждый не более MODEL_MAX_CHARS символов.
        """
        chunks = []
        
        # Разбиваем на абзацы
        paragraphs = [p.strip() for p in text.split('\n') if p.strip()]
        
        for paragraph in paragraphs:
            # Если абзац короткий
            if len(paragraph) <= self.MODEL_MAX_CHARS:
                chunks.append(paragraph)
                continue
            
            # Разбиваем длинный абзац на предложения
            sentences = self._split_by_sentences(paragraph)
            current_chunk = ""
            
            for sentence in sentences:
                # Если предложение слишком длинное (редкий случай)
                if len(sentence) > self.MODEL_MAX_CHARS:
                    # Сохраняем текущий чанк, если он есть
                    if current_chunk:
                        chunks.append(current_chunk)
                        current_chunk = ""
                    
                    # Разбиваем очень длинное предложение по словам
                    words = sentence.split()
                    temp_chunk = ""
                    
                    for word in words:
                        if len(temp_chunk) + len(word) + 1 <= self.MODEL_MAX_CHARS:
                            temp_chunk = f"{temp_chunk} {word}".strip()
                        else:
                            if temp_chunk:
                                chunks.append(temp_chunk)
                            temp_chunk = word
                    
                    if temp_chunk:
                        chunks.append(temp_chunk)
                    continue
                
                # Нормальное предложение
                if len(current_chunk) + len(sentence) + 1 <= self.MODEL_MAX_CHARS:
                    current_chunk = f"{current_chunk} {sentence}".strip()
                else:
                    if current_chunk:
                        chunks.append(current_chunk)
                    current_chunk = sentence
            
            # Добавляем последний чанк абзаца
            if current_chunk:
                chunks.append(current_chunk)
        
        return chunks

    def _split_by_sentences(self, text):
        """Разбивает текст на предложения"""
        # Улучшенное разбиение
        pattern = r'(?<=[.!?…])\s+(?=[А-ЯA-Z0-9"«\\(])'
        sentences = re.split(pattern, text)
        return [s.strip() for s in sentences if s.strip()]

    def create_file_groups(self, chunks):
        """
        Группирует чанки в файлы.
        Если split_chars = 0: каждый чанк -> отдельный файл
        Если split_chars > 0: группируем чанки по указанному размеру
        """
        if self.split_chars <= 0:
            return [[chunk] for chunk in chunks]
        
        file_groups = []
        current_group = []
        current_size = 0
        
        for chunk in chunks:
            chunk_size = len(chunk)
            
            # Если чанк больше целевого размера
            if chunk_size > self.split_chars:
                # Сохраняем предыдущую группу
                if current_group:
                    file_groups.append(current_group)
                    current_group = []
                    current_size = 0
                
                # Этот большой чанк становится отдельным файлом
                file_groups.append([chunk])
                continue
            
            # Проверяем, помещается ли чанк в текущую группу
            if current_size + chunk_size <= self.split_chars or not current_group:
                current_group.append(chunk)
                current_size += chunk_size
            else:
                # Начинаем новую группу
                file_groups.append(current_group)
                current_group = [chunk]
                current_size = chunk_size
        
        # Добавляем последнюю группу
        if current_group:
            file_groups.append(current_group)
        
        return file_groups

    def run(self):
        try:
            import torch
            import traceback

            self.progress_updated.emit(0, 100, "Загрузка модели Silero...")

            if not os.path.exists(self.model_path):
                raise RuntimeError(f"Файл модели не найден: {self.model_path}")

            # Загрузка модели
            print(f"Загружаем модель из: {self.model_path}")
            importer = torch.package.PackageImporter(self.model_path)
            model = importer.load_pickle("tts_models", "model")

            if model is None:
                raise RuntimeError("Не удалось загрузить модель")

            if not hasattr(model, 'apply_tts'):
                raise AttributeError("Загруженная модель не имеет метода 'apply_tts'")

            print("Модель загружена успешно")
            
            # 1. ОЧИСТКА ТЕКСТА
            self.progress_updated.emit(0, 100, "Очистка текста...")
            cleaned_text = self.safe_for_silero(self.text)
            
            # Проверяем, не пустой ли текст после очистки
            if not cleaned_text or cleaned_text == "[пустой текст]":
                raise ValueError("Текст после очистки пуст или содержит только неподдерживаемые символы")
            
            print(f"Текст после очистки: {len(cleaned_text)} символов")
            
            # 2. РАЗБИВКА НА ЧАНКИ ДЛЯ МОДЕЛИ
            self.progress_updated.emit(0, 100, "Разбивка текста на чанки...")
            model_chunks = self.split_text_into_chunks(cleaned_text)
            
            print(f"Создано чанков для модели: {len(model_chunks)}")
            
            # Проверяем размеры чанков
            for i, chunk in enumerate(model_chunks[:5]):  # Проверяем первые 5 чанков
                print(f"Чанк {i+1}: {len(chunk)} символов")
                if len(chunk) > self.MODEL_MAX_CHARS:
                    print(f"  ВНИМАНИЕ: чанк {i+1} слишком длинный! ({len(chunk)} > {self.MODEL_MAX_CHARS})")
            
            # 3. ГРУППИРОВКА В ФАЙЛЫ
            file_groups = self.create_file_groups(model_chunks)
            
            print(f"Создано групп для файлов: {len(file_groups)}")
            
            total_files = len(file_groups)
            
            # 4. ОБРАБОТКА КАЖДОЙ ГРУППЫ
            for file_idx, group in enumerate(file_groups, 1):
                if self.is_cancelled:
                    break
                
                # Отправляем прогресс
                group_chars = sum(len(chunk) for chunk in group)
                status = f"Файл {file_idx}/{total_files} ({group_chars} символов, {len(group)} частей)"
                self.progress_updated.emit(file_idx, total_files, status)
                print(f"Обрабатываем {status}")
                
                # Генерируем аудио для каждого чанка в группе
                all_audio = []
                for chunk_idx, chunk in enumerate(group, 1):
                    if self.is_cancelled:
                        break
                    
                    print(f"  Чанк {chunk_idx}/{len(group)}: {len(chunk)} символов")
                    
                    # Дополнительная проверка длины чанка
                    if len(chunk) > self.MODEL_MAX_CHARS:
                        print(f"    ПРЕДУПРЕЖДЕНИЕ: чанк длиннее {self.MODEL_MAX_CHARS} символов, обрезаем")
                        chunk = chunk[:self.MODEL_MAX_CHARS]
                    
                    try:
                        audio = model.apply_tts(
                            text=chunk,
                            speaker=self.speaker,
                            sample_rate=self.sample_rate,
                            put_accent=True,
                            put_yo=True
                        )
                        all_audio.append(audio)
                        print(f"    Успешно сгенерировано")
                    except Exception as e:
                        print(f"    Ошибка при генерации чанка: {str(e)}")
                        # Пытаемся обработать ошибку - обрезаем чанк еще больше
                        if len(chunk) > 500:
                            print(f"    Пробуем обрезать до 500 символов")
                            chunk = chunk[:500]
                            audio = model.apply_tts(
                                text=chunk,
                                speaker=self.speaker,
                                sample_rate=self.sample_rate,
                                put_accent=True,
                                put_yo=True
                            )
                            all_audio.append(audio)
                        else:
                            raise
                
                # Проверяем, есть ли аудио для склейки
                if not all_audio:
                    print(f"  Нет аудио для склейки в группе {file_idx}")
                    continue
                
                # Склеиваем аудио части группы
                if len(all_audio) > 1:
                    audio = torch.cat(all_audio, dim=0)
                else:
                    audio = all_audio[0]
                
                # Сохраняем временный WAV
                temp_wav = os.path.join(self.output_dir, f"temp_file_{file_idx:03d}.wav")
                print(f"  Сохраняем временный файл: {temp_wav}")
                self.save_audio(temp_wav, audio, self.sample_rate)
                
                # Конвертируем в нужный формат
                final_file = os.path.join(
                    self.output_dir, 
                    f"{self.filename}_{file_idx:03d}.{self.audio_format}"
                )
                
                print(f"  Конвертируем в {self.audio_format}: {final_file}")
                if self.audio_format != "wav":
                    self.convert_audio(temp_wav, final_file, self.audio_format)
                    os.remove(temp_wav)
                else:
                    os.rename(temp_wav, final_file)
                
                print(f"  Файл {file_idx} готов: {final_file}")
            
            if not self.is_cancelled:
                print("Конвертация завершена успешно")
                self.finished.emit()

        except Exception as e:
            error_details = traceback.format_exc()
            print(f"Критическая ошибка в run(): {error_details}")
            self.error_occurred.emit(f"Ошибка: {str(e)}")

    def save_audio(self, filename, audio_tensor, sample_rate):
        """Сохраняет аудио в WAV"""
        try:
            import scipy.io.wavfile as wavfile
            
            if hasattr(audio_tensor, 'cpu'):
                audio_np = audio_tensor.cpu().numpy()
            else:
                audio_np = audio_tensor.numpy()
            
            if audio_np.ndim > 1:
                audio_np = audio_np.squeeze()
            
            # Нормализуем до int16
            audio_np = np.asarray(audio_np)
            if np.abs(audio_np).max() > 1.0:
                audio_np = audio_np / np.abs(audio_np).max()
            
            audio_np_int16 = np.int16(audio_np * 32767)
            wavfile.write(filename, sample_rate, audio_np_int16)
            
        except ImportError:
            import wave
            
            if hasattr(audio_tensor, 'cpu'):
                audio_np = audio_tensor.cpu().numpy()
            else:
                audio_np = audio_tensor.numpy()
            
            if audio_np.ndim > 1:
                audio_np = audio_np.squeeze()
            
            # Нормализуем
            audio_np = np.asarray(audio_np)
            if np.abs(audio_np).max() > 1.0:
                audio_np = audio_np / np.abs(audio_np).max()
            
            audio_np_int16 = np.int16(audio_np * 32767)
            
            with wave.open(filename, 'w') as wav_file:
                wav_file.setnchannels(1)
                wav_file.setsampwidth(2)
                wav_file.setframerate(sample_rate)
                wav_file.writeframes(audio_np_int16.tobytes())

    def convert_audio(self, input_file, output_file, audio_format):
        """Конвертирует аудио через FFmpeg"""
        try:
            if self.bitrate_type == "Variable":
                bitrate_param = ["-q:a", str(self.bitrate_value)]
            else:
                bitrate_param = ["-b:a", f"{self.bitrate_value}k"]
            
            cmd = [
                'ffmpeg',
                '-i', input_file,
                '-y',
                *bitrate_param,
                '-ac', '1',
                output_file
            ]
            
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
            if result.returncode != 0:
                raise Exception(f"FFmpeg ошибка: {result.stderr}")
        except Exception as e:
            raise Exception(f"Ошибка конвертации: {str(e)}")

class ModelManager:
    def __init__(self, models_dir):
        self.models_dir = models_dir
        self.models_config = {}
        self.available_models = []
        self.load_models()
    
    def load_models(self):
        config_path = os.path.join(self.models_dir, "models.json")
        
        if os.path.exists(config_path):
            try:
                with open(config_path, 'r', encoding='utf-8') as f:
                    self.models_config = json.load(f)
            except Exception:
                self.models_config = {}
        
        self.scan_for_models()
    
    def scan_for_models(self):
        if not os.path.exists(self.models_dir):
            return
        
        for file in os.listdir(self.models_dir):
            if file.endswith('.pt'):
                model_id = file.replace('.pt', '')
                
                if model_id in self.models_config:
                    self.available_models.append(model_id)
    
    def get_model_path(self, model_id):
        if model_id in self.models_config:
            return os.path.join(self.models_dir, self.models_config[model_id]["file"])
        return None
    
    def get_model_voices(self, model_id):
        if model_id in self.models_config:
            return self.models_config[model_id].get("voices", ["aidar", "baya", "kseniya", "xenia", "eugene"])
        return ["aidar", "baya", "kseniya", "xenia", "eugene"]


class TTSDialog(QDialog):
    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Настройки Silero TTS")
        self.setFixedSize(480, 520)
        self.main_window = parent
        self.output_dir = os.path.expanduser(r"C:\Users\1\Desktop")
        
        self.models_dir = r"D:\TTS\Silero"
        self.model_manager = ModelManager(self.models_dir)
        
        self.init_ui()
    
    def init_ui(self):
        layout = QVBoxLayout()
        layout.setSpacing(10)
        layout.setContentsMargins(12, 12, 12, 12)
        
        # === Модель Silero ===
        model_group = QGroupBox("Модель Silero")
        model_layout = QGridLayout()
        model_layout.setSpacing(8)
        
        model_layout.addWidget(QLabel("Модель:"), 0, 0, Qt.AlignmentFlag.AlignLeft)
        self.model_combo = QComboBox()
        model_layout.addWidget(self.model_combo, 0, 1)
        
        model_layout.addWidget(QLabel("Голос:"), 1, 0, Qt.AlignmentFlag.AlignLeft)
        self.voice_combo = QComboBox()
        model_layout.addWidget(self.voice_combo, 1, 1)
        
        model_group.setLayout(model_layout)
        layout.addWidget(model_group)
        
        # === Настройки аудио ===
        audio_group = QGroupBox("Настройки аудио")
        audio_layout = QGridLayout()
        audio_layout.setSpacing(8)
        
        audio_layout.addWidget(QLabel("Формат:"), 0, 0)
        self.format_combo = QComboBox()
        self.format_combo.addItems(["wav", "mp3", "aac"])
        self.format_combo.setCurrentText('mp3')  # ← MP3 по умолчанию
        audio_layout.addWidget(self.format_combo, 0, 1)
        
        audio_layout.addWidget(QLabel("Тип битрейта:"), 1, 0)
        self.bitrate_type = QComboBox()
        self.bitrate_type.addItems(["Constant", "Variable"])
        audio_layout.addWidget(self.bitrate_type, 1, 1)
        
        audio_layout.addWidget(QLabel("Значение:"), 2, 0)
        self.bitrate_spin = QSpinBox()
        self.bitrate_spin.setRange(32, 320)
        self.bitrate_spin.setValue(128)
        audio_layout.addWidget(self.bitrate_spin, 2, 1)
        
        self.bitrate_hint = QLabel("кбит/с")
        self.bitrate_hint.setStyleSheet("color: gray;")
        audio_layout.addWidget(self.bitrate_hint, 2, 2)
        
        audio_group.setLayout(audio_layout)
        layout.addWidget(audio_group)
        
        # === Настройки вывода ===
        output_group = QGroupBox("Настройки вывода")
        output_layout = QVBoxLayout()
        output_layout.setSpacing(8)
        
        # Разделение текста
        split_row = QHBoxLayout()
        split_row.addWidget(QLabel("Символов в файле:"))
        self.split_spin = QSpinBox()
        self.split_spin.setRange(0, 50000)
        self.split_spin.setValue(10000)
        self.split_spin.setSingleStep(1000)
        split_row.addWidget(self.split_spin)
        split_row.addStretch()
        output_layout.addLayout(split_row)
        
        # Подсказка
        split_hint = QLabel("0 = не разбивать. Модель принимает до 1000 символов.\nТекст разбивается по абзацам, длинные абзацы - по предложениям.\nЗатем части склеиваются в файлы указанного размера.")
        split_hint.setStyleSheet("color: gray; font-size: 9px;")
        split_hint.setWordWrap(True)
        output_layout.addWidget(split_hint)
        
        # Имя файла
        name_row = QHBoxLayout()
        name_row.addWidget(QLabel("Имя файла:"))
        self.filename_edit = QLineEdit("output")
        name_row.addWidget(self.filename_edit)
        output_layout.addLayout(name_row)
        
        # Папка сохранения
        folder_row = QHBoxLayout()
        folder_row.addWidget(QLabel("Папка сохранения:"))
        self.folder_btn = QPushButton("Выбрать...")
        self.folder_btn.clicked.connect(self.select_output_folder)
        folder_row.addWidget(self.folder_btn)
        folder_row.addStretch()
        output_layout.addLayout(folder_row)
        
        output_group.setLayout(output_layout)
        layout.addWidget(output_group)
        
        layout.addStretch()
        
        # === Кнопка конвертации ===
        self.convert_btn = QPushButton("Конвертировать в речь")
        self.convert_btn.setMinimumHeight(40)
        self.convert_btn.clicked.connect(self.start_conversion)
        layout.addWidget(self.convert_btn)
        
        self.setLayout(layout)
        
        # Заполняем модели
        if self.model_manager.available_models:
            self.model_combo.addItems(self.model_manager.available_models)
            self.update_voices()
            self.voice_combo.setCurrentText('kseniya')  # ← Ксения по умолчанию
        else:
            self.model_combo.addItem("Модели не найдены")
            self.convert_btn.setEnabled(False)
        
        # При изменении модели обновляем голоса
        self.model_combo.currentTextChanged.connect(self.on_model_changed)
        
        # При изменении типа битрейта
        self.bitrate_type.currentTextChanged.connect(self.update_bitrate_hint)
        self.update_bitrate_hint(self.bitrate_type.currentText())
        
        self.progress_dialog = None
        self.worker = None
    
    def on_model_changed(self, model_id):
        """Обновляет доступные голоса при смене модели"""
        self.update_voices()
    
    def update_voices(self):
        """Обновляет список голосов для выбранной модели"""
        self.voice_combo.clear()
        model_id = self.model_combo.currentText()
        voices = self.model_manager.get_model_voices(model_id)
        self.voice_combo.addItems(voices)
    
    def update_bitrate_hint(self, bitrate_type):
        if bitrate_type == "Variable":
            self.bitrate_hint.setText("(0=лучшее, 9=худшее)")
            self.bitrate_spin.setRange(0, 9)
            self.bitrate_spin.setValue(4)
        else:
            self.bitrate_hint.setText("кбит/с")
            self.bitrate_spin.setRange(32, 320)
            self.bitrate_spin.setValue(128)
    
    def select_output_folder(self):
        folder = QFileDialog.getExistingDirectory(
            self,
            "Выберите папку для сохранения",
            self.output_dir
        )
        if folder:
            self.output_dir = folder
    
    def start_conversion(self):
        if self.main_window and hasattr(self.main_window, 'text_browser'):
            text = self.main_window.text_browser.toPlainText()
            if text.strip():
                model_id = self.model_combo.currentText()
                model_path = self.model_manager.get_model_path(model_id)
                
                if not model_path or not os.path.exists(model_path):
                    QMessageBox.warning(self, "Ошибка", 
                                      f"Модель Silero не найдена: {model_id}")
                    return
                
                if model_id == "Модели не найдены":
                    QMessageBox.warning(self, "Ошибка", 
                                      "Пожалуйста, поместите модели в папку D:\\TTS\\Silero")
                    return
                
                self.progress_dialog = TTSProgressDialog(self)
                
                self.worker = SileroTTSWorker(
                    text=text,
                    model_path=model_path,
                    output_dir=self.output_dir,
                    filename=self.filename_edit.text(),
                    audio_format=self.format_combo.currentText(),
                    split_chars=self.split_spin.value(),
                    bitrate_type=self.bitrate_type.currentText(),
                    bitrate_value=self.bitrate_spin.value(),
                    speaker=self.voice_combo.currentText()
                )
                
                self.progress_dialog.set_worker(self.worker)
                self.worker.progress_updated.connect(self.update_progress)
                self.worker.finished.connect(self.on_conversion_finished)
                self.worker.error_occurred.connect(self.on_conversion_error)
                
                self.worker.start()
                self.progress_dialog.start()
            else:
                QMessageBox.warning(self, "Внимание", 
                                  "Нет текста для конвертации.")
    
    def update_progress(self, current, total, status):
        if self.progress_dialog:
            self.progress_dialog.update_progress(current, total)
            self.progress_dialog.status_label.setText(status)
    
    def on_conversion_finished(self):
        if self.progress_dialog:
            self.progress_dialog.progress.setValue(100)
            self.progress_dialog.status_label.setText("Готово!")
            QTimer.singleShot(1000, self.progress_dialog.close)
        self.worker = None
    
    def on_conversion_error(self, error_msg):
        if self.progress_dialog:
            self.progress_dialog.status_label.setText(f"Ошибка: {error_msg}")
            QTimer.singleShot(3000, self.progress_dialog.close)
        self.worker = None
    
    def closeEvent(self, event):
        if self.worker and self.worker.isRunning():
            self.worker.is_cancelled = True
            self.worker.wait(1000)
        if self.progress_dialog:
            self.progress_dialog.close()
        event.accept()