How to access the f5-tts model without a user interface


I'm building a project in which I need to access the f5-tts model programmatically so I can use it on the fly.

I've tried using it through its user interface, but that doesn't cut it. I want to feed input to f5-tts in real time and get the result back in real time. Is there any way to do this? I haven't been able to find useful information anywhere else.

python text-to-speech
1 Answer

I've been working on something similar, but I keep running into missing keys when loading the checkpoint, which leaves me with static-sounding audio. This is the code I have so far. If you solve the model's missing-keys problem, please let me know.

    import torch
    import numpy as np
    import sounddevice as sd
    from f5_tts.model.backbones.dit import DiT  # Use DiT as the backbone
    from vocos import Vocos
    from vocos.feature_extractors import MelSpectrogramFeatures
    from vocos.heads import ISTFTHead

    # Paths to cached files
    model_cache_path = ""
    vocoder_cache_path = ""

    # Initialize device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Define reproducibility function
    def seed_everything(seed: int):
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # Create feature extractor
    feature_extractor = MelSpectrogramFeatures(
        sample_rate=24000, 
        n_fft=1024, 
        hop_length=256, 
        n_mels=100, 
        padding="center"
    )

    # Use DiT backbone as specified in the log
    backbone = DiT(
        dim=1024, 
        depth=22, 
        heads=16, 
        ff_mult=2, 
        text_dim=512, 
        conv_layers=4
    )

    # Define head (to process the final output into a waveform)
    head = ISTFTHead(
        dim=512, 
        n_fft=1024, 
        hop_length=256, 
        padding="center"
    )

    # Initialize the vocoder with the backbone, head, and feature extractor
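    # (note: vocoder_cache_path is never actually loaded here, so the vocoder
    # weights stay random -- another likely contributor to the noisy output)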
    vocoder = Vocos(
        backbone=backbone, 
        head=head, 
        feature_extractor=feature_extractor
    ).to(device)

    # Model configuration based on the log
    model_config = {
        "dim": 1024,
        "depth": 22,
        "heads": 16,
        "ff_mult": 2,
        "text_dim": 512,
        "conv_layers": 4
    }

    # Initialize the DiT model with the provided configuration
    model = DiT(**model_config).to(device)

    # Load the model weights from the cache path with strict=False
    # (strict=False tolerates key mismatches instead of raising an error)
    state_dict = torch.load(model_cache_path, map_location=device)
    load_result = model.load_state_dict(state_dict, strict=False)
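    # Missing keys stay at their random initialization, which is almost certainly
    # what makes the output sound like static -- print them to see what isn't loading:
    print("Missing keys:", load_result.missing_keys)
    print("Unexpected keys:", load_result.unexpected_keys)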

    # Set model to evaluation mode
    model.eval()

    # Placeholder for `infer_process` function
    def infer_process(ref_audio, ref_text, gen_text, model, vocoder, mel_spec_type, device):
        # Implement the logic for inference based on your framework
        # This is a placeholder to ensure the script runs
        # Replace this with the actual inference logic
        wav = torch.randn(1, 24000 * 5)  # Simulated 5-second waveform
        sr = 24000  # Sample rate
        spect = None  # Spectrogram, if needed
        return wav, sr, spect
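    # Note: the real inference helper ships with the package itself -- in the
    # version I looked at it was `from f5_tts.infer.utils_infer import infer_process`,
    # but double-check that module path against your install.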

    # Function to generate TTS
    def generate_tts(gen_text: str):
        print(f"Generating speech for: '{gen_text}'")
        seed_everything(42)  # For reproducibility

        # Placeholder reference audio and text (not actually wired into the call
        # below -- replace with real reference data in production)
        dummy_ref_audio = torch.randn(1, 80, 100).to(device)  # Example mel spectrogram
        dummy_ref_text = torch.tensor([1, 2, 3, 4]).to(device)  # Example tokenized text

        # Infer process
        wav, sr, spect = infer_process(
            ref_audio="",
            ref_text="",
            gen_text=gen_text,
            model=model,
            vocoder=vocoder,
            mel_spec_type="vocos",  # Using vocos as per the log
            device=device,
        )
        return wav, sr

    # Run test
    if __name__ == "__main__":
        gen_text = "Hello, this is a quick test of the TTS system."
        wav, sr = generate_tts(gen_text)
        print(f"Generated audio waveform shape: {wav.shape}")
        print(f"Sample rate: {sr}")

        # Convert to a float32 waveform in [-1, 1] and normalize before playback
        audio = wav.cpu().numpy()[0].astype(np.float32)
        audio = audio / np.max(np.abs(audio))

        # Debugging statements
        print("Audio waveform min/max values:", audio.min(), audio.max())
        print("Default audio device sample rate:", sd.query_devices(None, 'output')['default_samplerate'])

        # Resample the audio to 44.1 kHz
        import scipy.signal
        audio_resampled = scipy.signal.resample(audio, int(44100 / sr * len(audio)))

        # Play the resampled audio
        sd.play(audio_resampled, 44100)
        sd.wait()
        print("Audio waveform values:", audio[:10])
        import matplotlib.pyplot as plt

        plt.plot(audio)
        plt.show()
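
Also, rather than rebuilding the model and vocoder by hand, it may be simpler to go through the package's own Python API. I'm writing this from memory and haven't verified it against the latest release, so treat it as a rough sketch (the class name, module path, and keyword arguments may differ in your installed version):

    import sounddevice as sd
    from f5_tts.api import F5TTS  # high-level API shipped with the f5_tts package

    tts = F5TTS()  # loads the default F5-TTS checkpoint and vocoder

    # Generate speech from text, cloning the voice in the reference clip
    wav, sr, spect = tts.infer(
        ref_file="path/to/reference.wav",             # short reference audio clip
        ref_text="Transcript of the reference clip.",
        gen_text="Hello, this is a quick test of the TTS system.",
    )

    sd.play(wav, sr)
    sd.wait()

As far as I can tell this is roughly what the Gradio UI does under the hood, so it should give you the no-UI, feed-text-get-audio loop the question is asking about.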