Parcourir la source

0816 优化消息处理

Qing il y a 8 mois
Parent
commit
0537e7fb64

+ 38 - 5
consumer-service-demo/spring-ai-demo/src/main/java/com/sf/ai/controller/ChatWebController.java

@@ -1,15 +1,21 @@
 package com.sf.ai.controller;
 
 import com.google.gson.Gson;
-import com.google.gson.internal.LinkedTreeMap;
+import com.sf.ai.dto.ChatInfoDto;
 import com.sf.ai.dto.ChatReqDto;
 import lombok.RequiredArgsConstructor;
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.openai.OpenAiChatModel;
 import org.springframework.stereotype.Controller;
 import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.PostMapping;
 import org.springframework.web.bind.annotation.ResponseBody;
 
+import java.util.ArrayList;
 import java.util.List;
 
 // 要返回的是页面
@@ -26,13 +32,40 @@ public class ChatWebController {
         return "chat";
     }
 
+    // 设定消息列表的最大长度
+    static int maxLen = 10;
+
     @PostMapping("/chat")
     @ResponseBody
     public String chat(ChatReqDto chatReqDto) {
         System.out.println("chatReqDto = " + chatReqDto);
-        List<LinkedTreeMap> list = new Gson().fromJson(chatReqDto.getPrompts(), List.class);
-        String content = (String) list.get(0).get("content");
-        String called = chatModel.call(content);
-        return called;
+        // 由前端传递历史消息
+        List<Message> historyMessages = new ArrayList<>();
+        // [{"role":"user","content":"讲一个笑话"},{"role":"user","content":"再讲一个笑话"}]
+        String prompts = chatReqDto.getPrompts();
+        ChatInfoDto[] chatInfoDtos = new Gson().fromJson(prompts, ChatInfoDto[].class);
+        for (ChatInfoDto chatInfoDto : chatInfoDtos) {
+            if(chatInfoDto.getRole().equals("user")) {
+                UserMessage userMessage = new UserMessage(chatInfoDto.getContent());
+                historyMessages.add(userMessage);
+            }else if(chatInfoDto.getRole().equals("assistant")) {
+                AssistantMessage assistantMessage = new AssistantMessage(chatInfoDto.getContent());
+                historyMessages.add(assistantMessage);
+            }
+        }
+
+        // 对消息列表的长度进行检查
+        if(historyMessages.size() > maxLen) {
+            // 截取最近的10条
+            historyMessages = historyMessages.subList(historyMessages.size() - maxLen,
+                    historyMessages.size());
+        }
+
+        Prompt prompt = new Prompt(historyMessages);
+        ChatResponse response = chatModel.call(prompt);
+        AssistantMessage assistantMessage = response.getResult().getOutput();
+
+        // 将本次问题的答案返回
+        return assistantMessage.getContent();
     }
 }

+ 11 - 0
consumer-service-demo/spring-ai-demo/src/main/java/com/sf/ai/dto/ChatInfoDto.java

@@ -0,0 +1,11 @@
+package com.sf.ai.dto;
+
+import lombok.Data;
+
+@Data
+public class ChatInfoDto {
+
+    // [{"role":"user","content":"讲一个笑话"}]
+    private String role;
+    private String content;
+}