Whisper-JNI

简介

Whisper 是 OpenAI 开发的一个基于 Transformer 的语音识别模型,能够处理多种语言并生成高度准确的语音到文本转录。该模型最初是基于 Python 和机器学习框架(如 PyTorch)构建的,因此开发者必须依赖 Python 和相应的库来进行语音识别任务。

Whisper

OpenAI 开源了 Whisper 模型,最初设计是为了在 Python 生态系统中进行高效的语音识别。以下是其原理和技术要点:

  1. Python 和 Transformer 模型:Whisper 基于 Transformer 神经网络架构,这种架构在自然语言处理 (NLP) 和语音识别任务中表现非常出色。它通过堆叠多个自注意力 (self-attention) 机制来处理语音和文本之间的转换。Whisper 模型的核心组件使用 Python 编写,并依赖于诸如 PyTorch 这样的深度学习框架。
  2. 依赖 Python 和高性能硬件:Whisper 的默认实现需要 Python 环境和 GPU 支持,这使得它适合构建在研究环境或具有强大计算能力的服务器上。然而,这也带来了部署的复杂性,尤其是在轻量级或嵌入式设备上运行时。

Whisper.cpp

whisper.cpp 是由 ggerganov 开源的一个项目,旨在解决 OpenAI Whisper 模型的 Python 依赖问题。它实现了一个纯 C++ 的版本,避免了 Python 和 GPU 的强依赖,能够更方便地在没有 GPU 或轻量级的系统上运行。

whisper.cpp 解决的问题

  • 跨平台支持:通过 C++ 实现,whisper.cpp 可以在 Windows、Linux 和 macOS 等平台上运行,且不依赖于 Python 生态。这极大简化了部署环境的复杂性。
  • 轻量级和高效:由于它是用 C++ 编写的,whisper.cpp 可以更好地利用 CPU 进行推理,适合在资源有限的设备(如嵌入式系统)上运行。
  • 无需 GPU 加速:与原始的 OpenAI Whisper 不同,whisper.cpp 允许在没有 GPU 的情况下运行模型,尽管在性能上可能稍有折中。

Whisper-JNI 的解决方案

whisper-jni 是基于 whisper.cpp 进一步封装的项目,由 GiviMAD 开发。该项目为 Java 环境提供了 JNI(Java Native Interface)绑定,使得 Java 开发者可以直接调用 whisper.cpp 的功能,而无需深入理解 C++。

whisper-jni 解决的问题

  • Java 开发集成:许多企业应用和服务是基于 Java 构建的。whisper-jniwhisper.cpp 的功能集成到 Java 环境中,允许开发者使用 Java 调用 Whisper 模型进行语音识别,而无需切换到其他编程语言(如 Python)。
  • 简化 C++ 调用:通过 JNI 封装,开发者无需直接与 C++ 代码交互,可以使用 Java 方法调用 Whisper 模型。这大大简化了应用集成的难度。

笔者 Whisper-JNI 的改造:Java 1.8 版本

虽然 whisper-jni 实现了将 whisper.cpp 集成到 Java 中,但默认情况下该库依赖于 Java 11 及以上的特性。然而,许多企业级应用仍然在使用 Java 1.8,这导致 whisper-jni 在这些环境中不可用。

为了满足 Java 1.8 用户的需求,笔者对 whisper-jni 进行了改造,使其能够在 Java 1.8 环境下运行。主要的改造点包括:

  1. Java 1.8 兼容性:通过修改代码中的语法和依赖项,移除了 Java 11 中引入的特性,使得该项目能够在 Java 1.8 环境下正常编译和运行。
  2. 优化 JNI 调用:对 JNI 代码进行调整,以确保在 Java 1.8 环境中调用 C++ 库时保持高效和稳定。
  3. 推送 Maven 仓库:笔者将改造后的 whisper-jni 项目推送到 Maven 中,以便开发者可以轻松依赖并使用这个适用于 Java 1.8 的版本。

总结

  • Whisper (OpenAI):基于 Python 和 Transformer 模型的语音识别系统,性能出色但依赖复杂。
  • Whisper.cpp:通过 C++ 实现的轻量级版本,摆脱了对 Python 和 GPU 的依赖,能够在更多平台上运行。
  • Whisper-JNI:基于 whisper.cpp 的 Java 封装,使 Java 开发者能够在 Java 项目中使用 Whisper 模型。
  • 我对 Whisper-JNI 的改造:专为 Java 1.8 环境进行的改造,确保企业应用能够无缝集成 Whisper 模型,并且推送到 Maven 仓库,便于开发者使用。

这使得 Java 开发者能够在不升级 JDK 的前提下,在 Java 1.8 环境中高效地使用 Whisper 进行语音识别。

whisper-jni 入门示例

Maven 坐标

<dependency>
  <groupId>com.litongjava</groupId>
  <artifactId>whisper-jni</artifactId>
  <version>1.6.1</version>
</dependency>

加载模型和推理示例

以下代码展示了如何加载 Whisper 模型并进行音频文件的推理操作:

package com.litongjava.ai.server.utils;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.ShortBuffer;
import java.nio.file.Path;
import java.nio.file.Paths;

import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import javax.sound.sampled.UnsupportedAudioFileException;

import io.github.givimad.whisperjni.WhisperContext;
import io.github.givimad.whisperjni.WhisperFullParams;
import io.github.givimad.whisperjni.WhisperJNI;
import io.github.givimad.whisperjni.WhisperSamplingStrategy;

public class WhisperSpeechRecognitionExample {

  public static void main(String[] args) throws Exception {
    // 加载模型文件
    Path modelFile = Paths.get("ggml-large-v3-turbo.bin");
    File file = modelFile.toFile();
    if (!file.exists() || !file.isFile()) {
      throw new RuntimeException("缺少模型文件: " + file.getAbsolutePath());
    }

    // 加载音频文件
    Path audioFilePath = Paths.get("samples/jfk.wav");
    File sampleFile = audioFilePath.toFile();
    if (!sampleFile.exists() || !sampleFile.isFile()) {
      throw new RuntimeException("缺少音频样本文件");
    }

    // 初始化 Whisper 库
    WhisperJNI.LoadOptions loadOptions = new WhisperJNI.LoadOptions();
    loadOptions.logger = System.out::println;
    WhisperJNI.loadLibrary(loadOptions);
    WhisperJNI.setLibraryLogger(null);

    WhisperJNI whisper = new WhisperJNI();
    WhisperContext ctx = whisper.init(modelFile);

    // 读取音频样本数据
    float[] samples = readAudioSamples(sampleFile);

    // 设置推理参数
    WhisperFullParams params = new WhisperFullParams(WhisperSamplingStrategy.GREEDY);
    int result = whisper.full(ctx, params, samples, samples.length);

    if (result != 0) {
      throw new RuntimeException("识别失败,错误码: " + result);
    }

    // 输出识别结果
    String transcription = whisper.fullGetSegmentText(ctx, 0);
    System.out.println("识别结果: " + transcription);

    ctx.close();
  }

  // 读取音频样本数据
  private static float[] readAudioSamples(File audioFile) throws UnsupportedAudioFileException, IOException {
    AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(audioFile);

    ByteBuffer captureBuffer = ByteBuffer.allocate(audioInputStream.available());
    captureBuffer.order(ByteOrder.LITTLE_ENDIAN);

    int read = audioInputStream.read(captureBuffer.array());
    if (read == -1) {
      throw new IOException("文件为空");
    }

    ShortBuffer shortBuffer = captureBuffer.asShortBuffer();
    float[] samples = new float[captureBuffer.capacity() / 2];
    int i = 0;
    while (shortBuffer.hasRemaining()) {
      samples[i++] = Math.max(-1f, Math.min(((float) shortBuffer.get()) / (float) Short.MAX_VALUE, 1f));
    }
    return samples;
  }
}

代码解释

  1. 模型加载:代码通过 WhisperJNI.init() 方法加载 .bin 模型文件。
  2. 音频读取readAudioSamples() 方法读取音频文件,并将其转换为浮点数数组,以便 Whisper 进行处理。
  3. 推理过程:调用 whisper.full() 方法对音频数据进行推理。识别成功后,可以通过 whisper.fullGetSegmentText() 获取文本结果。
  4. 日志管理:通过 WhisperJNI.LoadOptions 配置日志输出。此处使用 System.out::println 来打印日志。

封装为 Web 服务

由于 WhisperJNI 实例是线程不安全的,当多个线程同时调用 WhisperJNI 时,可能会出现调用异常。为了解决这个问题,我们将 WhisperJNI 封装为 WhisperJniService,并将 WhisperJniService 放入 ThreadLocal 中,以确保每个线程都有自己的 WhisperJniService 实例,从而避免线程安全问题。

为什么使用 ThreadLocal

1. 线程安全问题

WhisperJNI 实例是线程不安全的,这意味着多个线程同时访问同一个实例可能会导致竞态条件、数据不一致或其他不可预见的异常。这在高并发环境下尤为重要,因为多个请求可能会并行处理音频转录任务。

2. ThreadLocal 的作用

ThreadLocal 提供了一种机制,使得每个线程都可以拥有自己独立的变量副本。这意味着即使多个线程同时访问同一个 ThreadLocal 变量,每个线程都会看到自己独立的实例,避免了线程间的干扰。

在本项目中,通过 ThreadLocal<WhisperJniService>,每个线程在第一次访问时会初始化一个独立的 WhisperJniService 实例,并将其绑定到当前线程。后续该线程再次访问时,会直接使用该实例,而不会与其他线程共享。这确保了 WhisperJNI 的线程安全性。

3. 提高并发性能

使用 ThreadLocal 的主要优势在于:

  • 避免锁机制:传统的线程安全措施(如使用 synchronized 关键字)会引入锁竞争,影响性能。而 ThreadLocal 通过为每个线程提供独立实例,避免了共享资源的竞争,从而提升了并发性能。

  • 减少资源开销:尽管每个线程都有自己的 WhisperJniService 实例,但由于 WhisperJNI 的实例创建和初始化开销较高,通过 ThreadLocal 可以确保每个线程只初始化一次,避免了重复创建的资源浪费。

  • 简化代码设计:无需在方法内部进行复杂的线程同步或资源管理,ThreadLocal 自动为每个线程管理独立实例,简化了代码逻辑,提高了代码的可维护性。

4. 使用 ThreadLocal 的注意事项

  • 内存泄漏:在使用 ThreadLocal 时,需要注意及时清理不再使用的实例,尤其是在使用线程池的环境下。如果线程被复用而 ThreadLocal 没有被清理,可能会导致内存泄漏。

    在本项目中,可以在适当的位置(如服务关闭时)调用 threadLocalWhisper.remove() 来清理 ThreadLocal 中的实例。

  • 实例共享:确保通过 ThreadLocal 管理的实例仅在单个线程中使用,避免在不同线程间传递或共享这些实例。

WhisperJniService

package com.litongjava.ai.asr.service;

import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

import com.litongjava.ai.asr.model.WhisperSegment;

import io.github.givimad.whisperjni.WhisperContext;
import io.github.givimad.whisperjni.WhisperFullParams;
import io.github.givimad.whisperjni.WhisperJNI;

public class WhisperJniService {

  private WhisperJNI whisper = null;
  private WhisperContext ctx = null;

  /**
   * 初始化 WhisperJNI 上下文
   *
   * @param path 模型文件路径
   * @throws IOException 如果初始化失败
   */
  public void initContext(Path path) throws IOException {
    whisper = new WhisperJNI();
    ctx = whisper.init(path);
  }

  /**
   * 进行完整的转录并返回带时间戳的转录段
   *
   * @param params     转录参数
   * @param samples    音频样本数据
   * @param numSamples 样本数量
   * @return 转录段列表
   */
  public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams params, float[] samples, int numSamples) {
    int result = whisper.full(ctx, params, samples, numSamples);
    if (result != 0) {
      throw new RuntimeException("转录失败,错误代码:" + result);
    }
    int numSegments = whisper.fullNSegments(ctx);
    ArrayList<WhisperSegment> segments = new ArrayList<>(numSegments);

    for (int i = 0; i < numSegments; i++) {
      String text = whisper.fullGetSegmentText(ctx, i);
      long start = whisper.fullGetSegmentTimestamp0(ctx, i);
      long end = whisper.fullGetSegmentTimestamp1(ctx, i);
      segments.add(new WhisperSegment(start, end, text));
    }
    return segments;
  }

  /**
   * 关闭 WhisperJNI 上下文
   */
  public void close() {
    if (ctx != null) {
      ctx.close();
    }
  }
}

LocalTinyWhisper 单例类

package com.litongjava.ai.asr.single;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import com.litongjava.ai.asr.model.WhisperSegment;
import com.litongjava.ai.asr.service.WhisperJniService;

import io.github.givimad.whisperjni.WhisperFullParams;
import io.github.givimad.whisperjni.WhisperJNI;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public enum LocalTinyWhisper {
  INSTANCE;

  private final ExecutorService executorService;
  private final ThreadLocal<WhisperJniService> threadLocalWhisper;
  private final WhisperFullParams defaultParams = new WhisperFullParams();

  LocalTinyWhisper() {
    try {
      WhisperJNI.loadLibrary();
    } catch (IOException e1) {
      e1.printStackTrace();
    }
    // 模型文件路径
    String userHome = System.getProperty("user.home");
    String modelName = "ggml-tiny.bin";
    Path path = Paths.get(userHome, ".cache", "whisper", modelName);

    int availableProcessors = Runtime.getRuntime().availableProcessors();
    log.info("可用处理器数量: {}", availableProcessors);
    this.executorService = Executors.newFixedThreadPool(Math.max(1, availableProcessors - 1));

    // 初始化 ThreadLocal,以确保每个线程有独立的 WhisperJniService 实例
    threadLocalWhisper = ThreadLocal.withInitial(() -> {
      WhisperJniService whisper = new WhisperJniService();
      try {
        whisper.initContext(path);
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
      return whisper;
    });

    defaultParams.printProgress = false;
  }

  /**
   * 进行完整的转录并返回带时间戳的转录段
   *
   * @param audioData  音频数据
   * @param numSamples 样本数量
   * @param params     转录参数
   * @return 转录段列表
   */
  public List<WhisperSegment> fullTranscribeWithTime(float[] audioData, int numSamples, WhisperFullParams params) {
    Callable<List<WhisperSegment>> task = () -> {
      WhisperJniService whisper = threadLocalWhisper.get();
      if (params != null) {
        log.info("使用自定义参数: {}", params);
        return whisper.fullTranscribeWithTime(params, audioData, numSamples);
      } else {
        return whisper.fullTranscribeWithTime(defaultParams, audioData, numSamples);
      }
    };

    try {
      return executorService.submit(task).get();
    } catch (InterruptedException | ExecutionException e) {
      log.error("转录任务执行失败", e);
      Thread.currentThread().interrupt();
      return null;
    }
  }

  /**
   * 重载方法,使用音频数据和默认参数进行转录
   *
   * @param floats 音频数据
   * @param params 转录参数
   * @return 转录段列表
   */
  public List<WhisperSegment> fullTranscribeWithTime(float[] floats, WhisperFullParams params) {
    return fullTranscribeWithTime(floats, floats.length, params);
  }
}

WhisperCppTinyService 服务类

package com.litongjava.ai.asr.service;

import java.io.IOException;
import java.net.URL;
import java.util.List;

import javax.sound.sampled.UnsupportedAudioFileException;

import com.litongjava.ai.asr.enumeration.AudioType;
import com.litongjava.ai.asr.enumeration.TextType;
import com.litongjava.ai.asr.model.WhisperSegment;
import com.litongjava.ai.asr.single.LocalTinyWhisper;
import com.litongjava.ai.asr.utils.Mp3Util;
import com.litongjava.ai.asr.utils.WhisperAudioUtils;
import com.litongjava.jfinal.aop.Aop;

import io.github.givimad.whisperjni.WhisperFullParams;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class WhisperCppTinyService {
  private final TextService textService = Aop.get(TextService.class);

  /**
   * 根据音频 URL 进行转录
   *
   * @param url    音频文件的 URL
   * @param params 转录参数
   * @return 转录段列表
   */
  public List<WhisperSegment> index(URL url, WhisperFullParams params) {
    try {
      float[] floats = WhisperAudioUtils.toAudioData(url);
      log.info("音频数据长度: {}", floats.length);
      List<WhisperSegment> segments = LocalTinyWhisper.INSTANCE.fullTranscribeWithTime(floats, floats.length, params);
      log.info("转录段数量: {}", segments.size());
      return segments;
    } catch (UnsupportedAudioFileException | IOException e) {
      log.error("音频处理或转录失败", e);
      return null;
    }
  }

  /**
   * 根据音频数据进行转录
   *
   * @param data   音频数据
   * @param params 转录参数
   * @return 转录段列表
   */
  public List<WhisperSegment> index(byte[] data, WhisperFullParams params) {
    float[] floats = WhisperAudioUtils.toFloat(data);
    return LocalTinyWhisper.INSTANCE.fullTranscribeWithTime(floats, params);
  }

  /**
   * 生成 SRT 字幕文件
   *
   * @param url    音频文件的 URL
   * @param params 转录参数
   * @return SRT 字符串
   * @throws IOException 如果处理失败
   */
  public StringBuffer outputSrt(URL url, WhisperFullParams params) throws IOException {
    List<WhisperSegment> segments = this.index(url, params);
    return textService.generateSrt(segments);
  }

  /**
   * 生成 VTT 字幕文件
   *
   * @param url    音频文件的 URL
   * @param params 转录参数
   * @return VTT 字符串
   * @throws IOException 如果处理失败
   */
  public StringBuffer outputVtt(URL url, WhisperFullParams params) throws IOException {
    List<WhisperSegment> segments = this.index(url, params);
    return textService.generateVtt(segments);
  }

  /**
   * 生成 LRC 歌词文件
   *
   * @param url    音频文件的 URL
   * @param params 转录参数
   * @return LRC 字符串
   * @throws IOException 如果处理失败
   */
  public StringBuffer outputLrc(URL url, WhisperFullParams params) throws IOException {
    List<WhisperSegment> segments = this.index(url, params);
    return textService.generateLrc(segments);
  }

  /**
   * 根据音频数据和指定类型进行转录并返回指定格式的结果
   *
   * @param data        音频数据
   * @param inputType   输入音频类型
   * @param outputType  输出文本类型
   * @return 转录结果
   * @throws IOException                    如果处理失败
   * @throws UnsupportedAudioFileException  如果音频文件不支持
   */
  public Object index(byte[] data, String inputType, String outputType)
      throws IOException, UnsupportedAudioFileException {
    return index(data, inputType, outputType, null);
  }

  /**
   * 根据音频数据和指定类型进行转录并返回指定格式的结果
   *
   * @param data        音频数据
   * @param inputType   输入音频类型
   * @param outputType  输出文本类型
   * @param params      转录参数
   * @return 转录结果
   * @throws IOException                    如果处理失败
   * @throws UnsupportedAudioFileException  如果音频文件不支持
   */
  public Object index(byte[] data, String inputType, String outputType, WhisperFullParams params)
      throws IOException, UnsupportedAudioFileException {
    log.info("输入类型: {}, 输出类型: {}", inputType, outputType);
    AudioType audioType = AudioType.fromString(inputType);
    TextType textType = TextType.fromString(outputType);

    // 如果输入音频是 MP3 格式,进行格式转换
    if (audioType == AudioType.MP3) {
      log.info("进行格式转换: MP3 转 WAV");
      data = Aop.get(Mp3Util.class).convertToWav(data, 16000, 1);
    }

    List<WhisperSegment> segments = index(data, params);
    if (segments == null) {
      return null;
    }

    switch (textType) {
      case SRT:
        return textService.generateSrt(segments).toString();
      case VTT:
        return textService.generateVtt(segments).toString();
      case LRC:
        return textService.generateLrc(segments).toString();
      default:
        return segments;
    }
  }
}

控制器示例

@RequestPath("/test/tiny")
public HttpResponse testTiny(HttpRequest request, WhisperFullParams params) {
  URL resource = ResourceUtil.getResource("audios/jfk.wav");
  if (resource != null) {
    List<WhisperSegment> list = whisperCppTinyService.index(resource, params);
    return Resps.json(request, Resp.ok(list));
  }
  return null;
}

访问地址即可测试:

http://localhost/whisper/test/tiny

示例响应:

{
  "data": [
    {
      "start": "0",
      "end": "1090",
      "sentence": "And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country."
    }
  ],
  "ok": true,
  "code": null,
  "msg": null
}

总结

通过使用 ThreadLocal,我们为每个线程提供了独立的 WhisperJniService 实例,确保了 WhisperJNI 的线程安全性。同时,这种设计避免了锁机制带来的性能开销,提升了系统在高并发环境下的处理能力和响应速度。这对于需要处理大量并发音频转录请求的 Web 服务来说,是一种高效且可靠的解决方案。