浏览代码

对话流式实现

Zhangbw 2 月之前
父节点
当前提交
09cf164142

+ 2 - 0
ruoyi-admin/src/main/resources/application.yml

@@ -116,6 +116,8 @@ security:
     - /*/api-docs
     - /*/api-docs/**
     - /warm-flow-ui/config
+    - /talk/config/**
+    - /talk/message/stream
 
 # 多租户配置
 tenant:

+ 26 - 0
ruoyi-modules/yp-talk/src/main/java/org/dromara/talk/controller/api/ChatController.java

@@ -10,6 +10,7 @@ import org.dromara.talk.domain.vo.TalkAgentVo;
 import org.dromara.talk.service.IChatService;
 import org.dromara.talk.service.ITalkAgentService;
 import org.springframework.web.bind.annotation.*;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 
 import java.util.List;
 import java.util.Map;
@@ -124,4 +125,29 @@ public class ChatController {
         log.info("客服 {} 挂断电话,状态已改为正常", id);
         return Map.of("success", success);
     }
+
+    /**
+     * 流式处理用户消息
+     * 使用SSE实时推送文本和音频数据
+     */
+    @PostMapping("/message/stream")
+    public SseEmitter handleMessageStream(@RequestBody Map<String, Object> request) {
+        String userMessage = (String) request.get("message");
+        Long agentId = request.get("agentId") != null ?
+            Long.valueOf(request.get("agentId").toString()) : null;
+        String agentGender = (String) request.get("agentGender");
+        @SuppressWarnings("unchecked")
+        List<Map<String, String>> ttsVcnList = (List<Map<String, String>>) request.get("ttsVcnList");
+        String conversationId = (String) request.get("conversationId");
+        Boolean isGreeting = request.get("isGreeting") != null ?
+            Boolean.valueOf(request.get("isGreeting").toString()) : false;
+        Integer requestId = request.get("requestId") != null ?
+            Integer.valueOf(request.get("requestId").toString()) : null;
+
+        log.info("收到流式消息请求: {}, 客服ID: {}, 对话ID: {}, 请求ID: {}", userMessage, agentId, conversationId, requestId);
+
+        SseEmitter emitter = new SseEmitter(60000L);
+        chatService.processMessageStream(userMessage, agentId, agentGender, ttsVcnList, conversationId, isGreeting, requestId, emitter);
+        return emitter;
+    }
 }

+ 16 - 0
ruoyi-modules/yp-talk/src/main/java/org/dromara/talk/service/IChatService.java

@@ -1,5 +1,7 @@
 package org.dromara.talk.service;
 
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
 import java.util.List;
 import java.util.Map;
 
@@ -21,4 +23,18 @@ public interface IChatService {
      * @return 响应数据(包含回复文本和音频)
      */
     Map<String, Object> processMessage(String userMessage, Long agentId, String agentGender, List<Map<String, String>> ttsVcnList, String conversationId, Boolean isGreeting, Integer requestId);
+
+    /**
+     * 流式处理用户消息
+     *
+     * @param userMessage 用户消息
+     * @param agentId 客服ID
+     * @param agentGender 客服性别
+     * @param ttsVcnList 发言人字典列表
+     * @param conversationId 对话ID
+     * @param isGreeting 是否为欢迎语
+     * @param requestId 请求ID(用于判断是否为最新请求)
+     * @param emitter SSE发射器
+     */
+    void processMessageStream(String userMessage, Long agentId, String agentGender, List<Map<String, String>> ttsVcnList, String conversationId, Boolean isGreeting, Integer requestId, SseEmitter emitter);
 }

+ 21 - 0
ruoyi-modules/yp-talk/src/main/java/org/dromara/talk/service/IDifyService.java

@@ -1,9 +1,30 @@
 package org.dromara.talk.service;
 
+import org.dromara.talk.domain.vo.TalkAgentVo;
+
+import java.util.List;
 import java.util.Map;
+import java.util.function.Consumer;
 
 public interface IDifyService {
 
    Map<String, String> callWorkflow(String userMessage, Map<String, Object> inputs, Long userId, String conversationId);
 
+   /**
+    * 流式调用Dify工作流(按句子分段)
+    */
+   void callWorkflowStream(String userMessage, String agentGender, List<Map<String, String>> ttsVcnList,
+                          TalkAgentVo agentConfig, Long userId, String conversationId,
+                          Consumer<String> onTextChunk,
+                          SentenceCallback onSentence);
+
+   @FunctionalInterface
+   interface TriConsumer<T, U, V> {
+       void accept(T t, U u, V v);
+   }
+
+   @FunctionalInterface
+   interface SentenceCallback {
+       void onSentence(String sentence, String conversationId, boolean isComplete);
+   }
 }

+ 7 - 0
ruoyi-modules/yp-talk/src/main/java/org/dromara/talk/service/ITtsService.java

@@ -2,6 +2,8 @@ package org.dromara.talk.service;
 
 import org.dromara.talk.domain.vo.TalkAgentVo;
 
+import java.util.function.BiConsumer;
+
 /**
  * TTS语音合成服务接口
  */
@@ -16,6 +18,11 @@ public interface ITtsService {
      */
     void synthesize(String text, TalkAgentVo agentConfig, AudioCallback callback);
 
+    /**
+     * 流式合成语音
+     */
+    void synthesizeStream(String text, TalkAgentVo agentConfig, BiConsumer<String, Integer> callback);
+
     /**
      * 音频回调接口
      */

+ 163 - 0
ruoyi-modules/yp-talk/src/main/java/org/dromara/talk/service/impl/ChatServiceImpl.java

@@ -167,4 +167,167 @@ public class ChatServiceImpl implements IChatService {
             return null;
         }
     }
+
+    @Override
+    public void processMessageStream(String userMessage, Long agentId, String agentGender,
+                                     List<Map<String, String>> ttsVcnList, String conversationId,
+                                     Boolean isGreeting, Integer requestId, org.springframework.web.servlet.mvc.method.annotation.SseEmitter emitter) {
+        // 在主线程中获取用户ID,避免在异步线程中访问ThreadLocal
+        Long userId = 0L;
+        try {
+            userId = StpUtil.getLoginIdAsLong();
+        } catch (Exception e) {
+            log.warn("获取登录用户ID失败,使用默认值", e);
+        }
+
+        // 更新最新请求ID
+        if (requestId != null) {
+            latestRequestIdMap.put(userId, requestId);
+            log.info("流式请求 - 更新用户 {} 的最新请求ID为: {}", userId, requestId);
+        }
+
+        Long finalUserId = userId;
+        CompletableFuture.runAsync(() -> {
+            try {
+
+                // 获取客服配置
+                TalkAgentVo agentConfig = null;
+                if (agentId != null) {
+                    agentConfig = talkAgentService.queryById(agentId);
+                }
+
+                // 如果是欢迎语,直接发送
+                if (Boolean.TRUE.equals(isGreeting)) {
+                    Map<String, String> textEvent = new HashMap<>();
+                    textEvent.put("name", "text");
+                    textEvent.put("data", userMessage);
+                    emitter.send(org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event()
+                        .data(textEvent));
+
+                    Map<String, String> doneEvent = new HashMap<>();
+                    doneEvent.put("name", "done");
+                    doneEvent.put("data", "");
+                    emitter.send(org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event()
+                        .data(doneEvent));
+                    emitter.complete();
+                    return;
+                }
+
+                TalkAgentVo finalAgentConfig = agentConfig;
+                Integer finalRequestId = requestId;
+
+                difyService.callWorkflowStream(userMessage, agentGender, ttsVcnList, agentConfig, finalUserId, conversationId,
+                    (textChunk) -> {
+                        try {
+                            // 构建 JSON 格式的事件数据
+                            Map<String, String> eventData = new HashMap<>();
+                            eventData.put("name", "text");
+                            eventData.put("data", textChunk);
+                            emitter.send(org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event()
+                                .data(eventData));
+                        } catch (Exception e) {
+                            log.error("发送文本失败", e);
+                        }
+                    },
+                    (sentence, newConversationId, isComplete) -> {
+                        try {
+                            log.info("句子回调 - 句子: {}, isComplete: {}", sentence != null ? sentence : "(null)", isComplete);
+
+                            // 发送conversationId
+                            if (newConversationId != null) {
+                                Map<String, String> conversationEvent = new HashMap<>();
+                                conversationEvent.put("name", "conversationId");
+                                conversationEvent.put("data", newConversationId);
+                                emitter.send(org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event()
+                                    .data(conversationEvent));
+                            }
+
+                            // 检查是否需要合成音频(只为最新请求合成音频)
+                            boolean needAudio = true;
+                            if (finalRequestId != null) {
+                                Integer latestRequestId = latestRequestIdMap.get(finalUserId);
+                                if (latestRequestId != null && !latestRequestId.equals(finalRequestId)) {
+                                    needAudio = false;
+                                    log.info("流式请求ID {} 不是最新请求(最新为 {}),跳过音频合成", finalRequestId, latestRequestId);
+                                }
+                            }
+
+                            if (needAudio && sentence != null && !sentence.trim().isEmpty()) {
+                                log.info("合成句子音频,长度: {}, 内容: {}", sentence.length(), sentence);
+
+                                // 使用 CountDownLatch 等待音频合成完成
+                                java.util.concurrent.CountDownLatch latch = new java.util.concurrent.CountDownLatch(1);
+
+                                // 用于累积同一句子的所有音频片段(字节数组)
+                                java.io.ByteArrayOutputStream mergedAudioBytes = new java.io.ByteArrayOutputStream();
+
+                                // 对每个句子进行 TTS 合成
+                                ttsService.synthesizeStream(sentence, finalAgentConfig, (audioChunk, status) -> {
+                                    try {
+                                        // 解码base64音频片段并累积到字节流
+                                        byte[] audioBytes = java.util.Base64.getDecoder().decode(audioChunk);
+                                        mergedAudioBytes.write(audioBytes);
+
+                                        // 当音频合成完成时(status=2),发送合并后的音频
+                                        if (status == 2) {
+                                            // 将合并后的字节数组重新编码为base64
+                                            String mergedAudioBase64 = java.util.Base64.getEncoder().encodeToString(mergedAudioBytes.toByteArray());
+
+                                            Map<String, String> audioEvent = new HashMap<>();
+                                            audioEvent.put("name", "audio");
+                                            audioEvent.put("data", mergedAudioBase64);
+                                            emitter.send(org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event()
+                                                .data(audioEvent));
+
+                                            log.info("句子音频合成完成,合并后长度: {}, 释放锁", mergedAudioBase64.length());
+                                            latch.countDown();
+                                        }
+                                    } catch (Exception e) {
+                                        log.error("发送音频失败", e);
+                                        latch.countDown(); // 出错时也要释放锁
+                                    }
+                                });
+
+                                // 等待当前句子的音频合成完成
+                                try {
+                                    log.info("等待句子音频合成完成...");
+                                    boolean completed = latch.await(30, java.util.concurrent.TimeUnit.SECONDS);
+                                    if (completed) {
+                                        log.info("句子音频合成等待完成");
+                                    } else {
+                                        log.warn("句子音频合成等待超时");
+                                    }
+                                } catch (InterruptedException e) {
+                                    log.error("等待音频合成被中断", e);
+                                    Thread.currentThread().interrupt();
+                                }
+                            }
+
+                            // 如果是最后一个句子,发送完成事件
+                            if (isComplete) {
+                                log.info("收到完成标志,准备发送done事件");
+                                Map<String, String> doneEvent = new HashMap<>();
+                                doneEvent.put("name", "done");
+                                doneEvent.put("data", "");
+                                emitter.send(org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event()
+                                    .data(doneEvent));
+                                log.info("done事件已发送,关闭SSE连接");
+                                emitter.complete();
+                            }
+                        } catch (Exception e) {
+                            log.error("处理句子回调失败", e);
+                            emitter.completeWithError(e);
+                        }
+                    });
+
+            } catch (Exception e) {
+                log.error("流式处理失败", e);
+                try {
+                    emitter.completeWithError(e);
+                } catch (Exception ex) {
+                    log.error("发送错误失败", ex);
+                }
+            }
+        });
+    }
 }

+ 108 - 0
ruoyi-modules/yp-talk/src/main/java/org/dromara/talk/service/impl/DifyServiceImpl.java

@@ -142,4 +142,112 @@ public class DifyServiceImpl implements IDifyService {
             throw new RuntimeException("AI 服务调用失败", e);
         }
     }
+
+    @Override
+    public void callWorkflowStream(String userMessage, String agentGender,
+                                   java.util.List<java.util.Map<String, String>> ttsVcnList,
+                                   org.dromara.talk.domain.vo.TalkAgentVo agentConfig,
+                                   Long userId, String conversationId,
+                                   java.util.function.Consumer<String> onTextChunk,
+                                   IDifyService.SentenceCallback onSentence) {
+        try {
+            log.info("流式调用 Dify 工作流 - userId: {}, conversationId: {}", userId, conversationId);
+
+            // 构建请求体
+            Map<String, Object> inputs = new HashMap<>();
+            inputs.put("agentGender", agentGender);
+            inputs.put("ttsVcnList", ttsVcnList);
+            inputs.put("currentVcn", agentConfig != null ? agentConfig.getTtsVcn() : null);
+
+            JSONObject requestBody = new JSONObject();
+            requestBody.set("inputs", inputs);
+            requestBody.set("query", userMessage);
+            requestBody.set("response_mode", "streaming");
+            requestBody.set("user", "user-" + userId);
+
+            if (conversationId != null && !conversationId.isEmpty()) {
+                requestBody.set("conversation_id", conversationId);
+            }
+
+            Request request = new Request.Builder()
+                .url(difyConfig.getApiUrl() + "/chat-messages")
+                .post(RequestBody.create(
+                    requestBody.toString(),
+                    MediaType.parse("application/json")))
+                .addHeader("Authorization", "Bearer " + difyConfig.getApiKey())
+                .addHeader("Content-Type", "application/json")
+                .build();
+
+            Response response = httpClient.newCall(request).execute();
+
+            log.info("Dify API 响应状态码: {}", response.code());
+
+            if (!response.isSuccessful()) {
+                String errorBody = response.body() != null ? response.body().string() : "无响应体";
+                log.error("Dify API 调用失败,状态码: {}, 响应: {}", response.code(), errorBody);
+                throw new RuntimeException("Dify API 调用失败: " + response.code());
+            }
+
+            StringBuilder currentSentence = new StringBuilder();
+            String newConversationId = null;
+
+            try (ResponseBody responseBody = response.body()) {
+                if (responseBody == null) {
+                    throw new RuntimeException("响应体为空");
+                }
+
+                java.io.BufferedReader reader = new java.io.BufferedReader(
+                    new java.io.InputStreamReader(responseBody.byteStream(), java.nio.charset.StandardCharsets.UTF_8)
+                );
+
+                String line;
+                while ((line = reader.readLine()) != null) {
+                    if (line.startsWith("data: ")) {
+                        String jsonData = line.substring(6);
+
+                        try {
+                            JSONObject event = new JSONObject(jsonData);
+                            String eventType = event.getStr("event");
+
+                            if ("message".equals(eventType)) {
+                                String answer = event.getStr("answer");
+                                if (answer != null) {
+                                    currentSentence.append(answer);
+                                    onTextChunk.accept(answer);
+
+                                    // 检查是否包含句子结束符
+                                    String sentenceText = currentSentence.toString();
+                                    if (sentenceText.matches(".*[。!?.!?]\\s*$")) {
+                                        // 发现句子结束,触发回调
+                                        onSentence.onSentence(sentenceText.trim(), newConversationId, false);
+                                        currentSentence.setLength(0);
+                                    }
+                                }
+                            } else if ("message_end".equals(eventType)) {
+                                newConversationId = event.getStr("conversation_id");
+
+                                // 处理最后剩余的文本(如果有)
+                                if (currentSentence.length() > 0) {
+                                    onSentence.onSentence(currentSentence.toString().trim(), newConversationId, true);
+                                } else {
+                                    // 没有剩余文本,直接标记完成
+                                    onSentence.onSentence("", newConversationId, true);
+                                }
+                            } else if ("error".equals(eventType)) {
+                                String errorMessage = event.getStr("message");
+                                log.error("Dify API 返回错误: {}", errorMessage);
+                                throw new RuntimeException("Dify API 错误: " + errorMessage);
+                            }
+                        } catch (Exception e) {
+                            log.warn("解析 SSE 事件失败: {}", line, e);
+                        }
+                    }
+                }
+            }
+
+        } catch (Exception e) {
+            log.error("流式调用 Dify 工作流失败", e);
+            throw new RuntimeException("AI 服务调用失败", e);
+        }
+    }
 }

+ 50 - 0
ruoyi-modules/yp-talk/src/main/java/org/dromara/talk/service/impl/TtsServiceImpl.java

@@ -78,6 +78,56 @@ public class TtsServiceImpl implements ITtsService {
         }
     }
 
+    @Override
+    public void synthesizeStream(String text, TalkAgentVo agentConfig, java.util.function.BiConsumer<String, Integer> callback) {
+        try {
+            String wsUrl = getAuthUrl(HOST_URL, xunfeiConfig.getApiKey(), xunfeiConfig.getApiSecret()).replace("https://", "wss://");
+            URI uri = new URI(wsUrl);
+            String requestJson = buildRequest(text, agentConfig);
+
+            WebSocketClient client = new WebSocketClient(uri) {
+                @Override
+                public void onOpen(ServerHandshake handshake) {
+                    log.info("TTS WebSocket连接成功(流式)");
+                    send(requestJson);
+                }
+
+                @Override
+                public void onMessage(String message) {
+                    log.info("TTS流式收到响应: {}", message.length() > 200 ? message.substring(0, 200) + "..." : message);
+                    TtsResponse response = JSON.parseObject(message, TtsResponse.class);
+
+                    if (response.code != 0) {
+                        log.error("TTS返回错误: code={}, message={}", response.code, response.message);
+                        return;
+                    }
+
+                    if (response.data != null && response.data.audio != null) {
+                        log.info("TTS流式收到音频数据: status={}, 音频长度={}", response.data.status, response.data.audio.length());
+                        callback.accept(response.data.audio, response.data.status);
+                    } else {
+                        log.warn("TTS流式响应中没有音频数据: data={}", response.data);
+                    }
+                }
+
+                @Override
+                public void onClose(int code, String reason, boolean remote) {
+                    log.info("TTS WebSocket连接关闭(流式)");
+                }
+
+                @Override
+                public void onError(Exception e) {
+                    log.error("TTS WebSocket错误(流式)", e);
+                }
+            };
+
+            client.connect();
+
+        } catch (Exception e) {
+            log.error("TTS流式合成失败", e);
+        }
+    }
+
     private String buildRequest(String text, TalkAgentVo agentConfig) {
         // 使用客服配置的TTS参数,如果没有则使用默认值
         String vcn = agentConfig != null && agentConfig.getTtsVcn() != null ? agentConfig.getTtsVcn() : "x4_yezi";