|
|
@@ -1,19 +1,25 @@
|
|
|
package org.dromara.talk.service.impl;
|
|
|
|
|
|
import cn.dev33.satoken.stp.StpUtil;
|
|
|
+import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
import lombok.RequiredArgsConstructor;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
+import org.dromara.common.satoken.utils.LoginHelper;
|
|
|
import org.dromara.talk.domain.vo.TalkAgentVo;
|
|
|
+import org.dromara.talk.domain.vo.TalkSessionVo;
|
|
|
import org.dromara.talk.service.IChatService;
|
|
|
import org.dromara.talk.service.IDifyService;
|
|
|
import org.dromara.talk.service.ITalkAgentService;
|
|
|
+import org.dromara.talk.service.ITalkSessionService;
|
|
|
import org.dromara.talk.service.ITtsService;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
|
|
|
|
-import java.util.HashMap;
|
|
|
-import java.util.List;
|
|
|
-import java.util.Map;
|
|
|
+import java.io.ByteArrayOutputStream;
|
|
|
+import java.util.*;
|
|
|
import java.util.concurrent.CompletableFuture;
|
|
|
+import java.util.concurrent.ConcurrentHashMap;
|
|
|
+import java.util.concurrent.CountDownLatch;
|
|
|
import java.util.concurrent.TimeUnit;
|
|
|
|
|
|
@Slf4j
|
|
|
@@ -24,160 +30,34 @@ public class ChatServiceImpl implements IChatService {
|
|
|
private final ITtsService ttsService;
|
|
|
private final ITalkAgentService talkAgentService;
|
|
|
private final IDifyService difyService;
|
|
|
+ private final ITalkSessionService talkSessionService;
|
|
|
|
|
|
// 存储每个用户的最新请求ID
|
|
|
- private final Map<Long, Integer> latestRequestIdMap = new java.util.concurrent.ConcurrentHashMap<>();
|
|
|
+ private final Map<Long, Integer> latestRequestIdMap = new ConcurrentHashMap<>();
|
|
|
|
|
|
- @Override
|
|
|
- public Map<String, Object> processMessage(String userMessage, Long agentId, String agentGender, List<Map<String, String>> ttsVcnList, String conversationId, Boolean isGreeting, Integer requestId) {
|
|
|
- log.info("处理用户消息: {}, 客服ID: {}, 客服性别: {}, 对话ID: {}, 是否欢迎语: {}, 请求ID: {}", userMessage, agentId, agentGender, conversationId, isGreeting, requestId);
|
|
|
-
|
|
|
- // 获取当前登录用户ID
|
|
|
- Long userId = null;
|
|
|
- try {
|
|
|
- userId = StpUtil.getLoginIdAsLong();
|
|
|
- } catch (Exception e) {
|
|
|
- log.warn("获取登录用户ID失败,使用默认值", e);
|
|
|
- userId = 0L;
|
|
|
- }
|
|
|
-
|
|
|
- // 更新最新请求ID
|
|
|
- if (requestId != null && userId != null) {
|
|
|
- latestRequestIdMap.put(userId, requestId);
|
|
|
- log.info("更新用户 {} 的最新请求ID为: {}", userId, requestId);
|
|
|
- }
|
|
|
-
|
|
|
- // 获取客服配置
|
|
|
- TalkAgentVo agentConfig = null;
|
|
|
- if (agentId != null) {
|
|
|
- agentConfig = talkAgentService.queryById(agentId);
|
|
|
- }
|
|
|
-
|
|
|
- String reply;
|
|
|
- String newConversationId = conversationId;
|
|
|
-
|
|
|
- // 如果是欢迎语,直接使用消息文本,不调用 Dify 工作流
|
|
|
- if (Boolean.TRUE.equals(isGreeting)) {
|
|
|
- log.info("处理欢迎语,跳过 Dify 工作流");
|
|
|
- reply = userMessage;
|
|
|
- } else {
|
|
|
- // 调用 Dify 生成回复
|
|
|
- Map<String, String> aiResult = generateReply(userMessage, agentGender, ttsVcnList, agentConfig, userId, conversationId);
|
|
|
- reply = aiResult.get("replyText");
|
|
|
- String selectedVcn = aiResult.get("ttsVcn");
|
|
|
- newConversationId = aiResult.get("conversationId");
|
|
|
-
|
|
|
- // 如果AI选择了发言人,更新客服配置
|
|
|
- if (selectedVcn != null && agentConfig != null) {
|
|
|
- agentConfig.setTtsVcn(selectedVcn);
|
|
|
- }
|
|
|
- }
|
|
|
+ // 存储每个会话的对话内容
|
|
|
+ private final Map<String, List<Map<String, String>>> conversationMap = new ConcurrentHashMap<>();
|
|
|
|
|
|
- // 检查是否需要合成音频(只为最新请求合成音频)
|
|
|
- boolean needAudio = true;
|
|
|
- if (requestId != null && userId != null) {
|
|
|
- Integer latestRequestId = latestRequestIdMap.get(userId);
|
|
|
- if (latestRequestId != null && !latestRequestId.equals(requestId)) {
|
|
|
- needAudio = false;
|
|
|
- log.info("请求ID {} 不是最新请求(最新为 {}),跳过音频合成", requestId, latestRequestId);
|
|
|
+ @Override
|
|
|
+ public void processMessageStream(String userMessage, Long agentId, String agentGender,
|
|
|
+ List<Map<String, String>> ttsVcnList, String conversationId,
|
|
|
+ Boolean isGreeting, Integer requestId, String customerPhone, SseEmitter emitter) {
|
|
|
+ // 如果没有传递customerPhone,尝试从数据库中查询
|
|
|
+ if (customerPhone == null && conversationId != null) {
|
|
|
+ TalkSessionVo session = talkSessionService.queryBySessionId(conversationId);
|
|
|
+ if (session != null) {
|
|
|
+ customerPhone = session.getCustomerPhone();
|
|
|
+ log.info("从数据库查询到客户手机号: {}", customerPhone);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // 合成语音(只有最新请求才合成)
|
|
|
- String audioBase64 = null;
|
|
|
- if (needAudio) {
|
|
|
- audioBase64 = synthesizeAudio(reply, agentConfig);
|
|
|
- } else {
|
|
|
- log.info("跳过音频合成");
|
|
|
- }
|
|
|
-
|
|
|
- // 构建响应
|
|
|
- Map<String, Object> response = new HashMap<>();
|
|
|
- response.put("reply", reply);
|
|
|
- response.put("audio", audioBase64);
|
|
|
- response.put("timestamp", System.currentTimeMillis());
|
|
|
- if (newConversationId != null) {
|
|
|
- response.put("conversationId", newConversationId);
|
|
|
- }
|
|
|
-
|
|
|
- log.info("消息处理完成: reply长度={}, audio长度={}",
|
|
|
- reply != null ? reply.length() : 0,
|
|
|
- audioBase64 != null ? audioBase64.length() : 0);
|
|
|
-
|
|
|
- return response;
|
|
|
- }
|
|
|
-
|
|
|
- private Map<String, String> generateReply(String userMessage, String agentGender,
|
|
|
- List<Map<String, String>> ttsVcnList,
|
|
|
- TalkAgentVo agentConfig,
|
|
|
- Long userId,
|
|
|
- String conversationId) {
|
|
|
- // 组装发送给 Dify 的数据
|
|
|
- Map<String, Object> inputs = new HashMap<>();
|
|
|
- inputs.put("agentGender", agentGender);
|
|
|
- inputs.put("ttsVcnList", ttsVcnList);
|
|
|
- inputs.put("currentVcn", agentConfig != null ? agentConfig.getTtsVcn() : null);
|
|
|
-
|
|
|
- log.info("调用 Dify 工作流 - 用户ID: {}, 对话ID: {}", userId, conversationId);
|
|
|
-
|
|
|
- // 调用 Dify 工作流
|
|
|
- Map<String, String> aiResponse = difyService.callWorkflow(userMessage, inputs, userId, conversationId);
|
|
|
-
|
|
|
- log.info("Dify 工作流响应: {}", aiResponse);
|
|
|
-
|
|
|
- return aiResponse;
|
|
|
- }
|
|
|
-
|
|
|
- private String synthesizeAudio(String text, TalkAgentVo agentConfig) {
|
|
|
- CompletableFuture<String> audioFuture = new CompletableFuture<>();
|
|
|
-
|
|
|
- ttsService.synthesize(text, agentConfig, new ITtsService.AudioCallback() {
|
|
|
- private final java.io.ByteArrayOutputStream audioBytes = new java.io.ByteArrayOutputStream();
|
|
|
-
|
|
|
- @Override
|
|
|
- public void onAudio(String base64Audio, int status) {
|
|
|
- try {
|
|
|
- //解码base64音频片段
|
|
|
- byte[] decoded = java.util.Base64.getDecoder().decode(base64Audio);
|
|
|
- //追加到字节流
|
|
|
- audioBytes.write(decoded);
|
|
|
- } catch (Exception e) {
|
|
|
- log.error("解码音频数据失败", e);
|
|
|
- }
|
|
|
- if (status == 2) {
|
|
|
- // 将完整音频重新编码为 base64
|
|
|
- String finalBase64 = java.util.Base64.getEncoder().encodeToString(audioBytes.toByteArray());
|
|
|
- //完成异步任务
|
|
|
- audioFuture.complete(finalBase64);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void onError(int code, String message) {
|
|
|
- log.error("TTS合成失败: {}", message);
|
|
|
- audioFuture.complete(null);
|
|
|
- }
|
|
|
- });
|
|
|
-
|
|
|
- // 等待音频合成完成(最多10秒)
|
|
|
- try {
|
|
|
- return audioFuture.get(10, TimeUnit.SECONDS);
|
|
|
- } catch (Exception e) {
|
|
|
- log.error("等待音频合成超时", e);
|
|
|
- return null;
|
|
|
- }
|
|
|
- }
|
|
|
+ String finalCustomerPhone = customerPhone;
|
|
|
|
|
|
- @Override
|
|
|
- public void processMessageStream(String userMessage, Long agentId, String agentGender,
|
|
|
- List<Map<String, String>> ttsVcnList, String conversationId,
|
|
|
- Boolean isGreeting, Integer requestId, org.springframework.web.servlet.mvc.method.annotation.SseEmitter emitter) {
|
|
|
// 在主线程中获取用户ID,避免在异步线程中访问ThreadLocal
|
|
|
- Long userId = 0L;
|
|
|
- try {
|
|
|
- userId = StpUtil.getLoginIdAsLong();
|
|
|
- } catch (Exception e) {
|
|
|
- log.warn("获取登录用户ID失败,使用默认值", e);
|
|
|
+ Long userId = LoginHelper.getUserId();
|
|
|
+ if (userId == null) {
|
|
|
+ userId = 0L;
|
|
|
+ log.warn("获取登录用户ID失败,使用默认值");
|
|
|
}
|
|
|
|
|
|
// 更新最新请求ID
|
|
|
@@ -201,20 +81,20 @@ public class ChatServiceImpl implements IChatService {
|
|
|
TalkAgentVo finalAgentConfig = agentConfig;
|
|
|
|
|
|
// 合成欢迎语语音
|
|
|
- java.util.concurrent.CountDownLatch latch = new java.util.concurrent.CountDownLatch(1);
|
|
|
- java.io.ByteArrayOutputStream mergedAudioBytes = new java.io.ByteArrayOutputStream();
|
|
|
+ CountDownLatch latch = new CountDownLatch(1);
|
|
|
+ ByteArrayOutputStream mergedAudioBytes = new ByteArrayOutputStream();
|
|
|
|
|
|
ttsService.synthesizeStream(userMessage, finalAgentConfig, (audioChunk, status) -> {
|
|
|
try {
|
|
|
- byte[] audioBytes = java.util.Base64.getDecoder().decode(audioChunk);
|
|
|
+ byte[] audioBytes = Base64.getDecoder().decode(audioChunk);
|
|
|
mergedAudioBytes.write(audioBytes);
|
|
|
|
|
|
if (status == 2) {
|
|
|
- String mergedAudioBase64 = java.util.Base64.getEncoder().encodeToString(mergedAudioBytes.toByteArray());
|
|
|
+ String mergedAudioBase64 = Base64.getEncoder().encodeToString(mergedAudioBytes.toByteArray());
|
|
|
Map<String, String> audioEvent = new HashMap<>();
|
|
|
audioEvent.put("name", "audio");
|
|
|
audioEvent.put("data", mergedAudioBase64);
|
|
|
- emitter.send(org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event()
|
|
|
+ emitter.send(SseEmitter.event()
|
|
|
.data(audioEvent));
|
|
|
latch.countDown();
|
|
|
}
|
|
|
@@ -236,7 +116,7 @@ public class ChatServiceImpl implements IChatService {
|
|
|
Map<String, String> doneEvent = new HashMap<>();
|
|
|
doneEvent.put("name", "done");
|
|
|
doneEvent.put("data", "");
|
|
|
- emitter.send(org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event()
|
|
|
+ emitter.send(SseEmitter.event()
|
|
|
.data(doneEvent));
|
|
|
emitter.complete();
|
|
|
return;
|
|
|
@@ -245,14 +125,48 @@ public class ChatServiceImpl implements IChatService {
|
|
|
TalkAgentVo finalAgentConfig = agentConfig;
|
|
|
Integer finalRequestId = requestId;
|
|
|
|
|
|
- difyService.callWorkflowStream(userMessage, agentGender, ttsVcnList, agentConfig, finalUserId, conversationId,
|
|
|
+ // 初始化对话记录列表
|
|
|
+ String finalConversationId = conversationId;
|
|
|
+ if (finalConversationId != null) {
|
|
|
+ conversationMap.putIfAbsent(finalConversationId, new ArrayList<>());
|
|
|
+
|
|
|
+ // 添加用户消息(客户)
|
|
|
+ Map<String, String> userMsg = new HashMap<>();
|
|
|
+ userMsg.put("role", "user");
|
|
|
+ userMsg.put("content", userMessage);
|
|
|
+ userMsg.put("timestamp", String.valueOf(System.currentTimeMillis()));
|
|
|
+ conversationMap.get(finalConversationId).add(userMsg);
|
|
|
+ }
|
|
|
+
|
|
|
+ // 用于累积AI回复的完整文本
|
|
|
+ StringBuilder aiReplyBuilder = new StringBuilder();
|
|
|
+
|
|
|
+ // 检查 conversationId 是否是临时 sessionId
|
|
|
+ // 如果是临时 sessionId(数据库中还未被 Dify 更新),则传递 null 给 Dify
|
|
|
+ String difyConversationId = conversationId;
|
|
|
+ if (conversationId != null) {
|
|
|
+ TalkSessionVo session = talkSessionService.queryBySessionId(conversationId);
|
|
|
+ if (session != null && conversationId.equals(session.getSessionId())) {
|
|
|
+ // 检查这个 sessionId 是否是刚创建的(没有对话内容)
|
|
|
+ // 如果没有对话内容,说明这是第一次对话,不应该传递给 Dify
|
|
|
+ if (session.getConversationJson() == null || session.getConversationJson().isEmpty()) {
|
|
|
+ difyConversationId = null;
|
|
|
+ log.info("检测到临时 sessionId,第一次对话不传递 conversationId 给 Dify");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ difyService.callWorkflowStream(userMessage, agentGender, ttsVcnList, agentConfig, finalUserId, difyConversationId,
|
|
|
(textChunk) -> {
|
|
|
try {
|
|
|
+ // 累积AI回复文本
|
|
|
+ aiReplyBuilder.append(textChunk);
|
|
|
+
|
|
|
// 构建 JSON 格式的事件数据
|
|
|
Map<String, String> eventData = new HashMap<>();
|
|
|
eventData.put("name", "text");
|
|
|
eventData.put("data", textChunk);
|
|
|
- emitter.send(org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event()
|
|
|
+ emitter.send(SseEmitter.event()
|
|
|
.data(eventData));
|
|
|
} catch (Exception e) {
|
|
|
log.error("发送文本失败", e);
|
|
|
@@ -264,10 +178,30 @@ public class ChatServiceImpl implements IChatService {
|
|
|
|
|
|
// 发送conversationId
|
|
|
if (newConversationId != null) {
|
|
|
+ // 确保conversationMap中有对应的列表
|
|
|
+ conversationMap.putIfAbsent(newConversationId, new ArrayList<>());
|
|
|
+
|
|
|
+ // 更新conversationId
|
|
|
+ if (finalConversationId == null) {
|
|
|
+ // 第一次对话,已经在上面创建了列表
|
|
|
+ } else if (!finalConversationId.equals(newConversationId)) {
|
|
|
+ // 如果传入的conversationId与Dify返回的不同,说明传入的是临时sessionId
|
|
|
+ // 需要更新数据库中的sessionId
|
|
|
+ talkSessionService.updateSessionId(finalConversationId, newConversationId);
|
|
|
+ log.info("更新临时sessionId {} 为Dify的conversationId {}", finalConversationId, newConversationId);
|
|
|
+ }
|
|
|
+
|
|
|
+ // 添加用户消息(客户)- 每次对话都要添加
|
|
|
+ Map<String, String> userMsg = new HashMap<>();
|
|
|
+ userMsg.put("role", "user");
|
|
|
+ userMsg.put("content", userMessage);
|
|
|
+ userMsg.put("timestamp", String.valueOf(System.currentTimeMillis()));
|
|
|
+ conversationMap.get(newConversationId).add(userMsg);
|
|
|
+
|
|
|
Map<String, String> conversationEvent = new HashMap<>();
|
|
|
conversationEvent.put("name", "conversationId");
|
|
|
conversationEvent.put("data", newConversationId);
|
|
|
- emitter.send(org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event()
|
|
|
+ emitter.send(SseEmitter.event()
|
|
|
.data(conversationEvent));
|
|
|
}
|
|
|
|
|
|
@@ -285,27 +219,27 @@ public class ChatServiceImpl implements IChatService {
|
|
|
log.info("合成句子音频,长度: {}, 内容: {}", sentence.length(), sentence);
|
|
|
|
|
|
// 使用 CountDownLatch 等待音频合成完成
|
|
|
- java.util.concurrent.CountDownLatch latch = new java.util.concurrent.CountDownLatch(1);
|
|
|
+ CountDownLatch latch = new CountDownLatch(1);
|
|
|
|
|
|
// 用于累积同一句子的所有音频片段(字节数组)
|
|
|
- java.io.ByteArrayOutputStream mergedAudioBytes = new java.io.ByteArrayOutputStream();
|
|
|
+ ByteArrayOutputStream mergedAudioBytes = new ByteArrayOutputStream();
|
|
|
|
|
|
// 对每个句子进行 TTS 合成
|
|
|
ttsService.synthesizeStream(sentence, finalAgentConfig, (audioChunk, status) -> {
|
|
|
try {
|
|
|
// 解码base64音频片段并累积到字节流
|
|
|
- byte[] audioBytes = java.util.Base64.getDecoder().decode(audioChunk);
|
|
|
+ byte[] audioBytes = Base64.getDecoder().decode(audioChunk);
|
|
|
mergedAudioBytes.write(audioBytes);
|
|
|
|
|
|
// 当音频合成完成时(status=2),发送合并后的音频
|
|
|
if (status == 2) {
|
|
|
// 将合并后的字节数组重新编码为base64
|
|
|
- String mergedAudioBase64 = java.util.Base64.getEncoder().encodeToString(mergedAudioBytes.toByteArray());
|
|
|
+ String mergedAudioBase64 = Base64.getEncoder().encodeToString(mergedAudioBytes.toByteArray());
|
|
|
|
|
|
Map<String, String> audioEvent = new HashMap<>();
|
|
|
audioEvent.put("name", "audio");
|
|
|
audioEvent.put("data", mergedAudioBase64);
|
|
|
- emitter.send(org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event()
|
|
|
+ emitter.send(SseEmitter.event()
|
|
|
.data(audioEvent));
|
|
|
|
|
|
log.info("句子音频合成完成,合并后长度: {}, 释放锁", mergedAudioBase64.length());
|
|
|
@@ -320,7 +254,7 @@ public class ChatServiceImpl implements IChatService {
|
|
|
// 等待当前句子的音频合成完成
|
|
|
try {
|
|
|
log.info("等待句子音频合成完成...");
|
|
|
- boolean completed = latch.await(30, java.util.concurrent.TimeUnit.SECONDS);
|
|
|
+ boolean completed = latch.await(30, TimeUnit.SECONDS);
|
|
|
if (completed) {
|
|
|
log.info("句子音频合成等待完成");
|
|
|
} else {
|
|
|
@@ -335,10 +269,38 @@ public class ChatServiceImpl implements IChatService {
|
|
|
// 如果是最后一个句子,发送完成事件
|
|
|
if (isComplete) {
|
|
|
log.info("收到完成标志,准备发送done事件");
|
|
|
+
|
|
|
+ // 保存对话内容到数据库
|
|
|
+ if (newConversationId != null && finalAgentConfig != null) {
|
|
|
+ try {
|
|
|
+ // 获取当前会话的对话内容
|
|
|
+ List<Map<String, String>> messages = conversationMap.getOrDefault(newConversationId, new ArrayList<>());
|
|
|
+
|
|
|
+ // 添加AI回复消息(客服)
|
|
|
+ if (aiReplyBuilder.length() > 0) {
|
|
|
+ Map<String, String> assistantMsg = new HashMap<>();
|
|
|
+ assistantMsg.put("role", "assistant");
|
|
|
+ assistantMsg.put("content", aiReplyBuilder.toString());
|
|
|
+ assistantMsg.put("timestamp", String.valueOf(System.currentTimeMillis()));
|
|
|
+ messages.add(assistantMsg);
|
|
|
+ }
|
|
|
+
|
|
|
+ // 将对话内容转换为JSON字符串
|
|
|
+ String conversationJson = new ObjectMapper().writeValueAsString(messages);
|
|
|
+
|
|
|
+ // 保存到数据库
|
|
|
+ talkSessionService.saveOrUpdateConversation(newConversationId, finalAgentConfig.getId(), conversationJson, finalCustomerPhone, finalUserId);
|
|
|
+
|
|
|
+ log.info("对话内容已保存到数据库,会话ID: {}, 消息数量: {}", newConversationId, messages.size());
|
|
|
+ } catch (Exception e) {
|
|
|
+ log.error("保存对话内容失败", e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
Map<String, String> doneEvent = new HashMap<>();
|
|
|
doneEvent.put("name", "done");
|
|
|
doneEvent.put("data", "");
|
|
|
- emitter.send(org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event()
|
|
|
+ emitter.send(SseEmitter.event()
|
|
|
.data(doneEvent));
|
|
|
log.info("done事件已发送,关闭SSE连接");
|
|
|
emitter.complete();
|