片段向量
我们已经完成了数据库设计、用户登录、知识库管理以及文件拆分功能。接下来,我们将实现片段向量化功能。
概述
片段向量化需要调用向量(Embedding)模型。为了减少向量模型的调用次数,我们将向量模型的返回结果缓存到 `max_kb_embedding_cache` 表中。以下是相关接口及代码的实现细节。
接口文档
批量处理文档向量
请求方法: POST
请求路径: /api/dataset/{id}/document/_bach
请求示例:
[
{
"name": "ICS111_31391_Miller_Syllabus_F24.pdf",
"id": "file_id",
"paragraphs": [
{
"title": "",
"content": "ICS 111- Introduction to Computer Science I, 31391\nFall"
}
]
}
]
响应示例:
{
"code": 200,
"message": "success",
"data": [
{
"create_time": "2024-11-05T14:24:12.315",
"update_time": "2024-11-05T14:24:12.315",
"id": "xxx",
"name": "ICS111_31391_Miller_Syllabus_F24.pdf",
"char_length": 20755,
"status": "3",
"is_active": true,
"type": "0",
"meta": {},
"dataset_id": "xxx",
"hit_handling_method": "optimization",
"directly_return_similarity": 0.9,
"paragraph_count": 6
}
]
}
代码实现
控制器层
package com.litongjava.maxkb.controller;
import java.util.List;
import com.litongjava.annotation.Post;
import com.litongjava.annotation.RequestPath;
import com.litongjava.jfinal.aop.Aop;
import com.litongjava.maxkb.model.DocumentBatchVo;
import com.litongjava.maxkb.service.DatasetDocumentVectorService;
import com.litongjava.model.result.ResultVo;
import com.litongjava.tio.boot.http.TioRequestContext;
import com.litongjava.tio.http.common.HttpRequest;
import com.litongjava.tio.utils.json.JsonUtils;
/**
 * HTTP entry points for dataset documents, mounted under {@code /api/dataset}.
 */
@RequestPath("/api/dataset")
public class AapiDatasetController {

  /**
   * Handles {@code POST /api/dataset/{id}/document/_bach}.
   *
   * <p>Reads the request body as a JSON array of {@link DocumentBatchVo} and hands it, together
   * with the current user's id and the dataset id from the path, to
   * {@link DatasetDocumentVectorService#batch}.
   *
   * @param id      dataset id taken from the request path
   * @param request incoming HTTP request whose body carries the JSON array
   * @return whatever the vector service returns for this batch
   */
  @Post("/{id}/document/_bach")
  public ResultVo batch(Long id, HttpRequest request) {
    String json = request.getBodyString();
    Long currentUserId = TioRequestContext.getUserIdLong();
    List<DocumentBatchVo> documents = JsonUtils.parseArray(json, DocumentBatchVo.class);
    return Aop.get(DatasetDocumentVectorService.class).batch(currentUserId, id, documents);
  }
}
服务层
DatasetDocumentVectorService
package com.litongjava.maxkb.service;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.Future;
import org.postgresql.util.PGobject;
import com.jfinal.kit.Kv;
import com.litongjava.db.TableInput;
import com.litongjava.db.TableResult;
import com.litongjava.db.activerecord.Db;
import com.litongjava.db.activerecord.Row;
import com.litongjava.jfinal.aop.Aop;
import com.litongjava.maxkb.constant.TableNames;
import com.litongjava.maxkb.utils.ExecutorServiceUtils;
import com.litongjava.maxkb.vo.DocumentBatchVo;
import com.litongjava.maxkb.vo.Paragraph;
import com.litongjava.model.result.ResultVo;
import com.litongjava.table.services.ApiTable;
import com.litongjava.tio.utils.crypto.Md5Utils;
import com.litongjava.tio.utils.hutool.FilenameUtils;
import com.litongjava.tio.utils.snowflake.SnowflakeIdUtils;
/**
 * Stores uploaded documents in a dataset and persists one embedded paragraph row per paragraph.
 */
public class DatasetDocumentVectorService {

  /**
   * Registers each document of the batch (or reuses its existing row) and replaces the
   * document's paragraph rows with freshly embedded ones.
   *
   * <p>Embeddings for the paragraphs of one document are computed concurrently on the executor
   * returned by {@code ExecutorServiceUtils.getExecutorService()}. If any paragraph of a document
   * cannot be embedded, the method stops and returns a failure <em>before</em> touching that
   * document's stored paragraphs, so a partial result never overwrites existing data. Documents
   * processed earlier in the same call remain saved.
   *
   * <p>NOTE(review): the existing-document lookup is keyed by {@code user_id} and {@code file_id}
   * only, and an existing document row is returned as-is (its {@code char_length} and
   * {@code paragraph_count} are not refreshed) — confirm this is intended.
   *
   * @param userId     id of the calling user; any user other than {@code 1} is restricted to
   *                   datasets whose {@code user_id} matches
   * @param dataset_id id of the target dataset
   * @param list       documents with their paragraphs, as parsed from the request body
   * @return {@code ResultVo.ok} carrying one {@link Kv} per document row, or {@code ResultVo.fail}
   *         when the dataset or a document row is missing, the list is absent, an embedding
   *         fails, or the paragraph transaction fails
   */
  public ResultVo batch(Long userId, Long dataset_id, List<DocumentBatchVo> list) {
    TableInput tableInput = new TableInput();
    tableInput.set("id", dataset_id);
    // User 1 is exempt from the ownership filter (presumably the administrator account).
    if (!userId.equals(1L)) {
      tableInput.set("user_id", userId);
    }

    TableResult<Row> result = ApiTable.get(TableNames.max_kb_dataset, tableInput);
    Row dataset = result.getData();
    if (dataset == null) {
      return ResultVo.fail("Dataset not found.");
    }
    // A missing or unparsable body would otherwise surface as a NullPointerException below.
    if (list == null) {
      return ResultVo.fail("Document list is missing.");
    }

    Long embedding_mode_id = dataset.getLong("embedding_mode_id");
    String sqlModelName = String.format("SELECT model_name FROM %s WHERE id = ?", TableNames.max_kb_model);
    String modelName = Db.queryStr(sqlModelName, embedding_mode_id);

    String sqlDocumentId = String.format("SELECT id FROM %s WHERE user_id = ? AND file_id = ?", TableNames.max_kb_document);

    List<Kv> kvs = new ArrayList<>();
    CompletionService<Row> completionService = new ExecutorCompletionService<>(ExecutorServiceUtils.getExecutorService());

    for (DocumentBatchVo documentBatchVo : list) {
      Long fileId = documentBatchVo.getId();
      String filename = documentBatchVo.getName();
      Long documentId = Db.queryLong(sqlDocumentId, userId, fileId);

      // Treat an omitted "paragraphs" field like an empty array instead of failing with an NPE.
      List<Paragraph> submitted = documentBatchVo.getParagraphs();
      List<Paragraph> paragraphs = submitted == null ? new ArrayList<Paragraph>() : submitted;

      int char_length = 0;
      for (Paragraph p : paragraphs) {
        if (p.getContent() != null) {
          char_length += p.getContent().length();
        }
      }

      String type = FilenameUtils.getSuffix(filename);

      if (documentId == null) {
        documentId = SnowflakeIdUtils.id();
        Row row = Row.by("id", documentId).set("file_id", fileId).set("user_id", userId).set("name", filename).set("char_length", char_length).set("status", "1").set("is_active", true)
            .set("type", type).set("dataset_id", dataset_id).set("paragraph_count", paragraphs.size()).set("hit_handling_method", "optimization").set("directly_return_similarity", 0.9);
        Db.save(TableNames.max_kb_document, row);
        kvs.add(row.toKv());
      } else {
        Row existingRecord = Db.findById(TableNames.max_kb_document, documentId);
        if (existingRecord == null) {
          // The id lookup succeeded but the row is gone by the time it is loaded.
          return ResultVo.fail("Document not found for ID: " + documentId);
        }
        kvs.add(existingRecord.toKv());
      }

      MaxKbEmbeddingService maxKbEmbeddingService = Aop.get(MaxKbEmbeddingService.class);
      final long documentIdFinal = documentId;

      // Kept in submission order so that still-queued tasks can be cancelled on failure.
      List<Future<Row>> futures = new ArrayList<>(paragraphs.size());
      for (Paragraph p : paragraphs) {
        futures.add(completionService.submit(() -> {
          String title = p.getTitle();
          String content = p.getContent();
          PGobject vector = maxKbEmbeddingService.getVector(content, modelName);
          return Row.by("id", SnowflakeIdUtils.id())
              .set("source_id", fileId)
              .set("source_type", type)
              .set("title", title)
              .set("content", content)
              .set("md5", Md5Utils.getMD5(content))
              .set("status", "1")
              .set("hit_num", 0)
              .set("is_active", true).set("dataset_id", dataset_id).set("document_id", documentIdFinal)
              .set("embedding", vector);
        }));
      }

      // Results arrive in completion order, not paragraph order.
      List<Row> batchRecord = new ArrayList<>(paragraphs.size());
      for (int i = 0; i < paragraphs.size(); i++) {
        try {
          Row row = completionService.take().get();
          if (row != null) {
            batchRecord.add(row);
          }
        } catch (InterruptedException e) {
          // Restore the flag so the caller's thread can observe the interruption.
          Thread.currentThread().interrupt();
          cancelPending(futures);
          return ResultVo.fail("Interrupted while embedding paragraphs for document ID: " + documentIdFinal);
        } catch (Exception e) {
          // Abort instead of replacing the stored paragraphs with an incomplete set.
          e.printStackTrace();
          cancelPending(futures);
          return ResultVo.fail("Failed to embed paragraphs for document ID: " + documentIdFinal + ": " + e.getMessage());
        }
      }

      // Replace the document's paragraphs atomically: delete the old rows, insert the new ones.
      boolean transactionSuccess = Db.tx(() -> {
        Db.delete(TableNames.max_kb_paragraph, Row.by("document_id", documentIdFinal));
        Db.batchSave(TableNames.max_kb_paragraph, batchRecord, 2000);
        return true;
      });

      if (!transactionSuccess) {
        return ResultVo.fail("Transaction failed while saving paragraphs for document ID: " + documentIdFinal);
      }
    }
    return ResultVo.ok(kvs);
  }

  /** Cancels tasks that have not started yet; tasks already running are left to finish. */
  private static void cancelPending(List<Future<Row>> futures) {
    for (Future<Row> future : futures) {
      future.cancel(false);
    }
  }
}
MaxKbEmbeddingService
package com.litongjava.maxkb.service;
import java.util.Arrays;
import org.postgresql.util.PGobject;
import com.litongjava.db.activerecord.Db;
import com.litongjava.db.activerecord.Row;
import com.litongjava.db.utils.PgVectorUtils;
import com.litongjava.maxkb.constant.TableNames;
import com.litongjava.openai.client.OpenAiClient;
import com.litongjava.tio.utils.crypto.Md5Utils;
import com.litongjava.tio.utils.snowflake.SnowflakeIdUtils;
/**
 * Produces text embeddings through {@link OpenAiClient} and caches them in the
 * {@code max_kb_embedding_cache} table, keyed by the MD5 of the text and the model name.
 *
 * <p>Cache inserts are serialized by {@code writeLock} and re-check the cache while holding it,
 * so concurrent requests for the same text on this instance insert a single row. This only
 * guards callers sharing this instance; NOTE(review): a unique index on {@code (md5, m)} would
 * be needed to guarantee the same across processes.
 */
public class MaxKbEmbeddingService {
  /** Total number of model calls {@link #getVector} makes before letting the failure propagate. */
  private static final int MAX_EMBEDDING_ATTEMPTS = 3;

  private final Object vectorLock = new Object();
  private final Object writeLock = new Object();

  /**
   * Returns the embedding of {@code text} for {@code model}, computing and caching it on a miss.
   *
   * <p>On a cache miss the model is called up to three times; the exception of the last attempt
   * propagates to the caller.
   *
   * @param text  text to embed
   * @param model embedding model name, also part of the cache key
   * @return the cached or newly computed vector
   */
  public PGobject getVector(String text, String model) {
    String md5 = Md5Utils.getMD5(text);
    String sql = String.format("select v from %s where md5=? and m=?", TableNames.max_kb_embedding_cache);
    PGobject cached = Db.queryFirst(sql, md5, model);
    if (cached != null) {
      return cached;
    }

    float[] embeddingArray = embedWithRetry(text, model);
    PGobject vector = PgVectorUtils.getPgVector(Arrays.toString(embeddingArray));

    synchronized (writeLock) {
      // Another thread may have cached the same text while the model call was in flight.
      PGobject storedMeanwhile = Db.queryFirst(sql, md5, model);
      if (storedMeanwhile != null) {
        return storedMeanwhile;
      }
      Db.save(TableNames.max_kb_embedding_cache, newCacheRow(SnowflakeIdUtils.id(), text, vector, md5, model));
    }
    return vector;
  }

  /**
   * Returns the id of the cache row holding the embedding of {@code text} for {@code model},
   * computing and caching the embedding on a miss.
   *
   * <p>Unlike {@link #getVector}, the model call here is made once (no retry) and is serialized
   * by {@code vectorLock}.
   *
   * @param text  text to embed
   * @param model embedding model name, also part of the cache key
   * @return id of the existing or newly inserted cache row
   */
  public Long getVectorId(String text, String model) {
    String md5 = Md5Utils.getMD5(text);
    String sql = String.format("select id from %s where md5=? and m=?", TableNames.max_kb_embedding_cache);
    Long id = Db.queryLong(sql, md5, model);
    if (id != null) {
      return id;
    }

    float[] embeddingArray;
    synchronized (vectorLock) {
      embeddingArray = OpenAiClient.embeddingArray(text, model);
    }
    PGobject vector = PgVectorUtils.getPgVector(Arrays.toString(embeddingArray));

    synchronized (writeLock) {
      // Reuse a row inserted by a concurrent caller rather than adding a duplicate.
      Long storedMeanwhile = Db.queryLong(sql, md5, model);
      if (storedMeanwhile != null) {
        return storedMeanwhile;
      }
      long newId = SnowflakeIdUtils.id();
      Db.save(TableNames.max_kb_embedding_cache, newCacheRow(newId, text, vector, md5, model));
      return newId;
    }
  }

  /** Calls the embedding model, retrying on failure; the final attempt's exception propagates. */
  private static float[] embedWithRetry(String text, String model) {
    for (int attempt = 1; attempt < MAX_EMBEDDING_ATTEMPTS; attempt++) {
      try {
        return OpenAiClient.embeddingArray(text, model);
      } catch (RuntimeException ignored) {
        // Deliberately ignored: a later attempt follows.
      }
    }
    return OpenAiClient.embeddingArray(text, model);
  }

  /** Builds the row stored in the embedding cache table. */
  private static Row newCacheRow(long id, String text, PGobject vector, String md5, String model) {
    return new Row().set("t", text).set("v", vector).set("id", id).set("md5", md5).set("m", model);
  }
}
说明
- **缓存机制**:为了优化性能,`MaxKbEmbeddingService` 在生成向量时会先查询 `max_kb_embedding_cache` 表。如果缓存中存在对应的向量,则直接使用缓存;否则,调用 OpenAI 接口生成向量,并将结果缓存到数据库中。
- **线程安全**:在 `MaxKbEmbeddingService` 中,`writeLock` 用于串行化缓存表的写入;`vectorLock` 用于在 `getVectorId` 中串行化对向量模型的调用(`getVector` 调用模型时不使用该锁)。
- **事务管理**:在 `DatasetDocumentVectorService` 中,删除旧段落与批量保存新段落记录在同一个事务中执行,确保数据的一致性。
- **唯一标识**:使用 `SnowflakeIdUtils` 生成唯一的 ID,确保记录的唯一性。
- **MD5 校验**:使用段落内容的 MD5 值(与模型名称一起)作为缓存的查询条件,以便快速查询和缓存。
总结
通过以上实现,我们完成了片段向量化的功能,利用缓存机制有效减少了对向量模型的调用次数,提升了系统的整体性能和响应速度。后续可以根据需求进一步优化缓存策略和并发处理机制。