|
@@ -1,15 +1,21 @@
|
|
package com.sf.ai.controller;
|
|
package com.sf.ai.controller;
|
|
|
|
|
|
import com.google.gson.Gson;
|
|
import com.google.gson.Gson;
|
|
-import com.google.gson.internal.LinkedTreeMap;
|
|
|
|
|
|
+import com.sf.ai.dto.ChatInfoDto;
|
|
import com.sf.ai.dto.ChatReqDto;
|
|
import com.sf.ai.dto.ChatReqDto;
|
|
import lombok.RequiredArgsConstructor;
|
|
import lombok.RequiredArgsConstructor;
|
|
|
|
+import org.springframework.ai.chat.messages.AssistantMessage;
|
|
|
|
+import org.springframework.ai.chat.messages.Message;
|
|
|
|
+import org.springframework.ai.chat.messages.UserMessage;
|
|
|
|
+import org.springframework.ai.chat.model.ChatResponse;
|
|
|
|
+import org.springframework.ai.chat.prompt.Prompt;
|
|
import org.springframework.ai.openai.OpenAiChatModel;
|
|
import org.springframework.ai.openai.OpenAiChatModel;
|
|
import org.springframework.stereotype.Controller;
|
|
import org.springframework.stereotype.Controller;
|
|
import org.springframework.web.bind.annotation.GetMapping;
|
|
import org.springframework.web.bind.annotation.GetMapping;
|
|
import org.springframework.web.bind.annotation.PostMapping;
|
|
import org.springframework.web.bind.annotation.PostMapping;
|
|
import org.springframework.web.bind.annotation.ResponseBody;
|
|
import org.springframework.web.bind.annotation.ResponseBody;
|
|
|
|
|
|
|
|
+import java.util.ArrayList;
|
|
import java.util.List;
|
|
import java.util.List;
|
|
|
|
|
|
// 要返回的是页面
|
|
// 要返回的是页面
|
|
@@ -26,13 +32,40 @@ public class ChatWebController {
|
|
return "chat";
|
|
return "chat";
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ // 设定消息列表的最大长度
|
|
|
|
+ static int maxLen = 10;
|
|
|
|
+
|
|
@PostMapping("/chat")
|
|
@PostMapping("/chat")
|
|
@ResponseBody
|
|
@ResponseBody
|
|
public String chat(ChatReqDto chatReqDto) {
|
|
public String chat(ChatReqDto chatReqDto) {
|
|
System.out.println("chatReqDto = " + chatReqDto);
|
|
System.out.println("chatReqDto = " + chatReqDto);
|
|
- List<LinkedTreeMap> list = new Gson().fromJson(chatReqDto.getPrompts(), List.class);
|
|
|
|
- String content = (String) list.get(0).get("content");
|
|
|
|
- String called = chatModel.call(content);
|
|
|
|
- return called;
|
|
|
|
|
|
+ // 由前端传递历史消息
|
|
|
|
+ List<Message> historyMessages = new ArrayList<>();
|
|
|
|
+ // [{"role":"user","content":"讲一个笑话"},{"role":"user","content":"再讲一个笑话"}]
|
|
|
|
+ String prompts = chatReqDto.getPrompts();
|
|
|
|
+ ChatInfoDto[] chatInfoDtos = new Gson().fromJson(prompts, ChatInfoDto[].class);
|
|
|
|
+ for (ChatInfoDto chatInfoDto : chatInfoDtos) {
|
|
|
|
+ if(chatInfoDto.getRole().equals("user")) {
|
|
|
|
+ UserMessage userMessage = new UserMessage(chatInfoDto.getContent());
|
|
|
|
+ historyMessages.add(userMessage);
|
|
|
|
+ }else if(chatInfoDto.getRole().equals("assistant")) {
|
|
|
|
+ AssistantMessage assistantMessage = new AssistantMessage(chatInfoDto.getContent());
|
|
|
|
+ historyMessages.add(assistantMessage);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // 对消息列表的长度进行检查
|
|
|
|
+ if(historyMessages.size() > maxLen) {
|
|
|
|
+ // 截取最近的10条
|
|
|
|
+ historyMessages = historyMessages.subList(historyMessages.size() - maxLen,
|
|
|
|
+ historyMessages.size());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ Prompt prompt = new Prompt(historyMessages);
|
|
|
|
+ ChatResponse response = chatModel.call(prompt);
|
|
|
|
+ AssistantMessage assistantMessage = response.getResult().getOutput();
|
|
|
|
+
|
|
|
|
+ // 将本次问题的答案返回
|
|
|
|
+ return assistantMessage.getContent();
|
|
}
|
|
}
|
|
}
|
|
}
|