Spring AI: Customizing Tool Call Return Values to Inject TodoList Reminders
This article shows how to customize tool call return values in Spring AI. By overriding the execution logic at the MiniMaxChatModel layer and letting a StreamAdvisor take over tool calls, it automatically injects a reminder when todoUpdate has not been triggered for three consecutive tool calls. It also adds a retry mechanism for malformed JSON, which makes the LLM's tool calls more fault tolerant.
灰度发布 · Published 2026/3/21 · Updated 2026/4/18
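I recently came across the documentation for a minimal Claude Code implementation, and it contains a practical trick: how do you remind the AI to update its TodoList at the right moment? The approach there is: when three consecutive tool calls have gone by without a todo update, insert a reminder at the first position of the Function Call return value:
<reminder>Update your todos.</reminder>
The corresponding Python implementation is:
if rounds_since_todo >= 3 and messages:
    last = messages[-1]
    if last["role"] == "user" and isinstance(last.get("content"), list):
        last["content"].insert(0, {"type": "text", "text": "<reminder>Update your todos.</reminder>"})
Can the same effect be achieved in Spring AI? After some digging, the answer is yes. This article records how.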
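Before moving to Spring AI, here is a minimal plain-Java sketch of the same rule as the Python snippet above. It uses no Spring AI types; the class name `ReminderSketch`, the method `withReminder`, and the map-based message shape are assumptions made for illustration only.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Plain-Java rendering of the Python snippet; hypothetical names, no Spring AI types. */
final class ReminderSketch {

    static final String REMINDER = "<reminder>Update your todos.</reminder>";
    static final int THRESHOLD = 3;

    /**
     * Returns the content blocks of the latest tool-result message, with the
     * reminder prepended once the threshold is reached. The input list is not modified.
     */
    static List<Map<String, String>> withReminder(int roundsSinceTodo, List<Map<String, String>> content) {
        if (roundsSinceTodo < THRESHOLD) {
            return content;
        }
        List<Map<String, String>> copy = new ArrayList<>(content);
        // Index 0 so the reminder is the first thing the model reads
        copy.add(0, Map.of("type", "text", "text", REMINDER));
        return copy;
    }
}
```

Returning a copy instead of mutating the list is a deliberate difference from the Python version, which inserts in place.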
Dependency Versions
The Spring AI and Spring Boot versions used here:
<properties>
    <java.version>21</java.version>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <spring-boot.version>4.0.1</spring-boot.version>
    <spring-ai.version>2.0.0-M2</spring-ai.version>
</properties>
<dependencies>
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-model-minimax</artifactId>
    </dependency>
</dependencies>
<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-dependencies</artifactId>
            <version>${spring-boot.version}</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-bom</artifactId>
            <version>${spring-ai.version}</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>
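Defining the TodoList Tool
First, define the Tool that the LLM will call. The full implementation below provides a read and a write operation, and stores the list per conversation in a local Caffeine cache: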
@Component
public class TodolistTools extends BaseTools {

    private static final int MAX_TODOS = 20;
    private static final Set<String> VALID_STATUSES = Set.of("pending", "in_progress", "completed");
    private static final Map<String, String> STATUS_MARKERS = Map.of(
            "pending", "[ ]",
            "in_progress", "[>]",
            "completed", "[x]"
    );

    private record TodoItem(String id, String text, String status) {}

    private static final Cache<String, List<TodoItem>> TODOLIST_CACHE = Caffeine.newBuilder()
            .maximumSize(1000)
            .expireAfterWrite(Duration.ofMinutes(30))
            .build();

    @Tool(description = "Update task list. Track progress on multi-step tasks. Pass the full list of items each time (replaces previous list). Each item must have id, text, status. Status: pending, in_progress, completed. Only one item can be in_progress.")
    public String todoUpdate(
            @ToolParam(description = "Full list of todo items. Each item: id (string), text (string), status (pending|in_progress|completed).") List<Map<String, Object>> items,
            ToolContext toolContext) {
        try {
            String conversationId = ConversationUtils.getToolsContext(toolContext).appId();
            if (items == null || items.isEmpty()) {
                TODOLIST_CACHE.invalidate(conversationId);
                return "No todos.";
            }
            if (items.size() > MAX_TODOS) {
                return "Error: Max " + MAX_TODOS + " todos allowed";
            }
            List<TodoItem> validated = validateAndConvert(items);
            TODOLIST_CACHE.put(conversationId, validated);
            return render(validated);
        } catch (IllegalArgumentException e) {
            return "Error: " + e.getMessage();
        }
    }

    private List<TodoItem> validateAndConvert(List<Map<String, Object>> items) {
        int inProgressCount = 0;
        List<TodoItem> result = new ArrayList<>(items.size());
        for (int i = 0; i < items.size(); i++) {
            Map<String, Object> item = items.get(i);
            String id = String.valueOf(item.getOrDefault("id", String.valueOf(i + 1))).trim();
            String text = String.valueOf(item.getOrDefault("text", "")).trim();
            String status = String.valueOf(item.getOrDefault("status", "pending")).toLowerCase();
            if (StringUtils.isBlank(text)) {
                throw new IllegalArgumentException("Item " + id + ": text required");
            }
            if (!VALID_STATUSES.contains(status)) {
                throw new IllegalArgumentException("Item " + id + ": invalid status '" + status + "'");
            }
            if ("in_progress".equals(status)) {
                inProgressCount++;
            }
            result.add(new TodoItem(id, text, status));
        }
        if (inProgressCount > 1) {
            throw new IllegalArgumentException("Only one task can be in_progress at a time");
        }
        return result;
    }

    @Tool(description = "Read the current todo list for this conversation. Use this to check progress and see what tasks remain.")
    public String todoRead(ToolContext toolContext) {
        String conversationId = ConversationUtils.getToolsContext(toolContext).appId();
        List<TodoItem> items = TODOLIST_CACHE.getIfPresent(conversationId);
        return items == null || items.isEmpty() ? "No todos." : render(items);
    }

    private String render(List<TodoItem> items) {
        if (items == null || items.isEmpty()) {
            return "No todos.";
        }
        StringBuilder sb = new StringBuilder("\n\n");
        for (TodoItem item : items) {
            String marker = STATUS_MARKERS.getOrDefault(item.status(), "[ ]");
            sb.append(marker).append(" #").append(item.id()).append(": ").append(item.text()).append("\n\n");
        }
        long done = items.stream().filter(t -> "completed".equals(t.status())).count();
        sb.append("(").append(done).append("/").append(items.size()).append(" completed)");
        return sb.append("\n\n").toString();
    }

    @Override
    String getToolName() {
        return "Todo List Tool";
    }

    @Override
    String getToolDes() {
        return "Read and write task todo lists to track progress";
    }
}
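To make the text returned to the model concrete, here is a standalone sketch of the rendering step. It mirrors the markers and the completed counter of render above, but it is a simplification: it uses single newlines instead of the blank-line padding, and `TodoRenderSketch` and `Todo` are hypothetical names, not part of the project.

```java
import java.util.List;
import java.util.Map;

/** Simplified, stdlib-only version of the todo rendering; hypothetical names. */
final class TodoRenderSketch {

    record Todo(String id, String text, String status) {}

    static final Map<String, String> MARKERS = Map.of(
            "pending", "[ ]",
            "in_progress", "[>]",
            "completed", "[x]");

    /** One line per item, followed by a completed/total summary line. */
    static String render(List<Todo> todos) {
        StringBuilder sb = new StringBuilder();
        for (Todo t : todos) {
            sb.append(MARKERS.getOrDefault(t.status(), "[ ]"))
              .append(" #").append(t.id()).append(": ").append(t.text()).append('\n');
        }
        long done = todos.stream().filter(t -> "completed".equals(t.status())).count();
        return sb.append('(').append(done).append('/').append(todos.size()).append(" completed)").toString();
    }
}
```

For two items, one completed and one in progress, this yields a `[x]` line, a `[>]` line, and a final `(1/2 completed)` summary.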
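Problem Analysis: Why Can't an Ordinary Advisor Intercept Tool Calls?
Reading the source of org.springframework.ai.minimax.MiniMaxChatModel#stream shows that the framework executes Tool calls directly inside the ChatModel layer instead of passing them through to the Advisor chain. The core execution logic is: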
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
    if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) {
        return Flux.deferContextual(ctx -> {
            ToolExecutionResult toolExecutionResult;
            try {
                ToolCallReactiveContextHolder.setContext(ctx);
                toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
            } finally {
                ToolCallReactiveContextHolder.clearContext();
            }
            return Flux.just(ChatResponse.builder()
                    .from(response)
                    .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                    .build());
        }).subscribeOn(Schedulers.boundedElastic());
    }
    return Flux.just(response);
})
.doOnError(observation::error)
.doFinally(signalType -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
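This means that by the time an outer Advisor tries to intercept a tool_call, the tool has already been executed, and the Advisor cannot even tell that a tool call took place. So I decided to shadow this framework logic with my own MiniMaxChatModel, and then let an Advisor take over Tool execution.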
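The visibility problem can be illustrated with a toy model that has nothing to do with the real Spring AI classes. In this sketch, `Chunk`, `modelStream`, and `advisorSeesToolCall` are hypothetical stand-ins: when the model layer runs the tool itself, only the executed result flows downstream, so a wrapper around it never observes a pending tool call.

```java
import java.util.List;

/** Toy simulation of internal vs. external tool execution; not Spring AI code. */
final class ToolVisibilitySketch {

    /** A simplified response chunk: some text plus a flag for a pending tool call. */
    record Chunk(String text, boolean hasToolCalls) {}

    /** Hypothetical stand-in for a chat model's stream(). */
    static List<Chunk> modelStream(boolean internalToolExecution) {
        if (internalToolExecution) {
            // The model layer executes the tool and emits only the result
            return List.of(new Chunk("tool result", false));
        }
        // The tool call request is passed through untouched
        return List.of(new Chunk("", true));
    }

    /** What a wrapper around the model (the advisor's position) can observe. */
    static boolean advisorSeesToolCall(boolean internalToolExecution) {
        return modelStream(internalToolExecution).stream().anyMatch(Chunk::hasToolCalls);
    }
}
```

With internal execution enabled the wrapper sees no tool call, and with it disabled the request becomes visible, which is the behaviour the rest of this article relies on.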
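Verifying the Idea: Can an Advisor Take Over Tool Execution?
We need to create org.springframework.ai.minimax.MiniMaxChatModel inside our own project sources; see the complete code for the full file contents. A class with the same fully qualified name in the application's own sources is typically loaded ahead of the one in the dependency jar, which is what makes this override work. With that in place, the tool call signal is passed through to the Advisor layer, where we can check whether a Tool call is present. The Advisor used for verification is: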
@Slf4j
public class FindToolAdvisor implements StreamAdvisor {

    @Override
    public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
        return Flux.deferContextual(contextView -> {
            log.info("Advising stream");
            return streamAdvisorChain.nextStream(chatClientRequest).doOnNext(streamResponse -> {
                // chatResponse() may be null for some chunks, so guard before checking for tool calls
                ChatResponse chatResponse = streamResponse.chatResponse();
                boolean hasToolCalls = chatResponse != null && chatResponse.hasToolCalls();
                log.info("Found tool calls: {}", hasToolCalls);
            });
        });
    }

    @Override
    public String getName() {
        return "FindToolAdvisor";
    }

    @Override
    public int getOrder() {
        return 0;
    }
}

@Component
public class StreamApplication implements CommandLineRunner {

    @Resource
    private ChatModel chatModel;

    @Override
    public void run(String... args) throws Exception {
        ChatClient chatClient = ChatClient.builder(chatModel)
                .defaultTools(FileSystemTools.builder().build())
                .defaultAdvisors(new FindToolAdvisor())
                .build();
        // Backslashes must be escaped inside a text block, otherwise \T is an invalid escape sequence
        ChatClient.StreamResponseSpec stream = chatClient.prompt("""
                Write a simple HTML page for me. The path is E:\\TEMPLATE\\spring-skills, no more than 300 lines of code
                """).stream();
        stream.content().subscribe(System.out::println);
    }
}

spring:
  ai:
    minimax:
      api-key: sk-cp-xxxxx
      chat:
        options:
          model: MiniMax-M2.5
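The test confirms that the tool call signal can be intercepted, so the approach is feasible.
Reworking the Project: Implementing ExecuteToolAdvisor
Based on a not-yet-merged PR in the Spring AI community, we implemented ExecuteToolAdvisor. It does two main things:
Tolerance for malformed tool call JSON: exceptions thrown while executing a tool call (typically JSON parse errors in the arguments) are caught and the call is retried up to 3 times before the exception is rethrown, which makes the system more tolerant of sloppy Tool call formatting from the LLM.
TodoList reminder injection: when 3 consecutive tool calls have not triggered todoUpdate, a reminder is injected at the first position of the ToolResponseMessage to nudge the AI into updating its task list.
⚠️ Mind the order: because this Advisor takes over tool execution, its order value should be as large as possible (so that it runs late in the chain). With a small order, the doFinally of the Advisors after it may fire on every tool call instead of once at the end of the whole conversation (for example, the later buildAdvisor and versionAdvisor only need to run once). This implementation uses Integer.MAX_VALUE - 100.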
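The retry behaviour described above can be sketched without any Spring AI types. In this minimal sketch, `callWithRetry`, the `model` function, and the `tool` function are hypothetical stand-ins: the model produces tool arguments from the message history, the tool throws when the arguments are unusable, and the error text is fed back so the next attempt can correct it.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/** Bounded retry with error feedback; hypothetical names, not the advisor itself. */
final class ToolRetrySketch {

    static final int MAX_TOOL_RETRY = 3;

    /**
     * Asks the model for tool arguments and runs the tool. On failure, the error
     * is appended to the history and the model is asked again, for at most
     * MAX_TOOL_RETRY retries after the first attempt.
     */
    static String callWithRetry(Function<List<String>, String> model,
                                Function<String, String> tool,
                                List<String> history) {
        List<String> messages = new ArrayList<>(history);
        for (int retry = 0; ; retry++) {
            String toolArgs = model.apply(messages);
            try {
                return tool.apply(toolArgs);
            } catch (RuntimeException e) {
                if (retry >= MAX_TOOL_RETRY) {
                    throw e; // retries exhausted, surface the original failure
                }
                messages.add("Tool call failed, fix and retry: " + e.getMessage());
            }
        }
    }
}
```

Feeding the error back as a message, rather than silently re-sending the same prompt, is what gives the model a chance to produce valid arguments on the next attempt.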
@Slf4j
@Component
public class ExecuteToolAdvisor implements StreamAdvisor {

    private static final String TODO_REMINDER = "<reminder>Update your todos.</reminder>";
    private static final String JSON_ERROR_MESSAGE = "Tool call JSON parse failed. Fix and retry.\nRules: strict RFC8259 JSON, no trailing commas, no comments, no unescaped control chars in strings (escape newlines as \\n, tabs as \\t), all keys double-quoted.";
    private static final int MAX_TOOL_RETRY = 3;
    private static final int ORDER = Integer.MAX_VALUE - 100;
    private static final String TODO_METHOD = "todoUpdate";
    private static final int REMINDER_THRESHOLD = 3;

    private final Cache<String, Integer> roundsSinceTodo = Caffeine.newBuilder()
            .maximumSize(1_000)
            .expireAfterWrite(Duration.ofMinutes(30))
            .build();

    @Resource
    private ToolCallingManager toolCallingManager;

    @Override
    public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
        Assert.notNull(streamAdvisorChain, "streamAdvisorChain must not be null");
        Assert.notNull(chatClientRequest, "chatClientRequest must not be null");
        if (chatClientRequest.prompt().getOptions() == null
                || !(chatClientRequest.prompt().getOptions() instanceof ToolCallingChatOptions)) {
            throw new IllegalArgumentException("ExecuteToolAdvisor requires ToolCallingChatOptions to be set in the ChatClientRequest options.");
        }
        var optionsCopy = (ToolCallingChatOptions) chatClientRequest.prompt().getOptions().copy();
        optionsCopy.setInternalToolExecutionEnabled(false);
        return internalStream(streamAdvisorChain, chatClientRequest, optionsCopy,
                chatClientRequest.prompt().getInstructions(), 0);
    }

    private Flux<ChatClientResponse> internalStream(StreamAdvisorChain streamAdvisorChain, ChatClientRequest originalRequest,
            ToolCallingChatOptions optionsCopy, List<Message> instructions, int jsonRetryCount) {
        return Flux.deferContextual(contextView -> {
            var processedRequest = ChatClientRequest.builder()
                    .prompt(new Prompt(instructions, optionsCopy))
                    .context(originalRequest.context())
                    .build();
            StreamAdvisorChain chainCopy = streamAdvisorChain.copy(this);
            Flux<ChatClientResponse> responseFlux = chainCopy.nextStream(processedRequest);
            AtomicReference<ChatClientResponse> aggregatedResponseRef = new AtomicReference<>();
            AtomicReference<List<ChatClientResponse>> chunksRef = new AtomicReference<>(new ArrayList<>());
            return new ChatClientMessageAggregator().aggregateChatClientResponse(responseFlux, aggregatedResponseRef::set)
                    .doOnNext(chunk -> chunksRef.get().add(chunk))
                    .ignoreElements()
                    .cast(ChatClientResponse.class)
                    .concatWith(Flux.defer(() -> processAggregatedResponse(aggregatedResponseRef.get(), chunksRef.get(),
                            processedRequest, streamAdvisorChain, originalRequest, optionsCopy, jsonRetryCount)));
        });
    }

    private Flux<ChatClientResponse> processAggregatedResponse(ChatClientResponse aggregatedResponse, List<ChatClientResponse> chunks,
            ChatClientRequest finalRequest, StreamAdvisorChain streamAdvisorChain, ChatClientRequest originalRequest,
            ToolCallingChatOptions optionsCopy, int retryCount) {
        if (aggregatedResponse == null) {
            return Flux.fromIterable(chunks);
        }
        ChatResponse chatResponse = aggregatedResponse.chatResponse();
        boolean isToolCall = chatResponse != null && chatResponse.hasToolCalls();
        if (isToolCall) {
            Assert.notNull(chatResponse, "chatResponse must not be null when hasToolCalls is true");
            ChatClientResponse finalAggregatedResponse = aggregatedResponse;
            Flux<ChatClientResponse> toolCallFlux = Flux.deferContextual(ctx -> {
                ToolExecutionResult toolExecutionResult;
                try {
                    ToolCallReactiveContextHolder.setContext(ctx);
                    toolExecutionResult = toolCallingManager.executeToolCalls(finalRequest.prompt(), chatResponse);
                } catch (Exception e) {
                    if (retryCount < MAX_TOOL_RETRY) {
                        List<Message> retryInstructions = buildRetryInstructions(finalRequest, chatResponse, e);
                        if (retryInstructions != null) {
                            return internalStream(streamAdvisorChain, originalRequest, optionsCopy, retryInstructions, retryCount + 1);
                        }
                    }
                    throw e;
                } finally {
                    ToolCallReactiveContextHolder.clearContext();
                }
                List<Message> historyWithReminder = injectReminderIntoConversationHistory(
                        toolExecutionResult.conversationHistory(), getAppId(finalRequest));
                if (toolExecutionResult.returnDirect()) {
                    return Flux.just(buildReturnDirectResponse(finalAggregatedResponse, chatResponse, toolExecutionResult, historyWithReminder));
                }
                return internalStream(streamAdvisorChain, originalRequest, optionsCopy, historyWithReminder, 0);
            });
            return toolCallFlux.subscribeOn(Schedulers.boundedElastic());
        }
        return Flux.fromIterable(chunks);
    }

    private String getAppId(ChatClientRequest finalRequest) {
        if (finalRequest.prompt().getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
            return toolCallingChatOptions.getToolContext().get(CONVERSATION_ID).toString();
        }
        throw new BusinessException(ErrorCode.SYSTEM_ERROR);
    }

    private static List<Message> buildRetryInstructions(ChatClientRequest finalRequest, ChatResponse chatResponse, Throwable error) {
        AssistantMessage assistantMessage = extractAssistantMessage(chatResponse);
        if (assistantMessage == null || assistantMessage.getToolCalls() == null || assistantMessage.getToolCalls().isEmpty()) {
            return null;
        }
        List<Message> instructions = new ArrayList<>(finalRequest.prompt().getInstructions());
        instructions.add(assistantMessage);
        String errorMessage = buildJsonErrorMessage(error);
        List<ToolResponseMessage.ToolResponse> responses = assistantMessage.getToolCalls().stream()
                .map(toolCall -> new ToolResponseMessage.ToolResponse(toolCall.id(), toolCall.name(), errorMessage))
                .toList();
        instructions.add(ToolResponseMessage.builder().responses(responses).build());
        return instructions;
    }

    private static AssistantMessage extractAssistantMessage(ChatResponse chatResponse) {
        if (chatResponse == null) {
            return null;
        }
        Generation result = chatResponse.getResult();
        if (result != null && result.getOutput() != null) {
            return result.getOutput();
        }
        List<Generation> results = chatResponse.getResults();
        if (results != null && !results.isEmpty() && results.get(0).getOutput() != null) {
            return results.get(0).getOutput();
        }
        return null;
    }

    private static String buildJsonErrorMessage(Throwable error) {
        String detail = ExceptionUtils.getRootCauseMessage(error);
        if (detail.isBlank()) {
            return JSON_ERROR_MESSAGE;
        }
        return JSON_ERROR_MESSAGE + "\nError: " + detail;
    }
    private List<Message> injectReminderIntoConversationHistory(List<Message> conversationHistory, String appId) {
        if (conversationHistory == null || conversationHistory.isEmpty()) {
            return conversationHistory;
        }
        if (!(conversationHistory.getLast() instanceof ToolResponseMessage toolMsg)) {
            return conversationHistory;
        }
        List<ToolResponseMessage.ToolResponse> responses = toolMsg.getResponses();
        if (responses.isEmpty()) {
            return conversationHistory;
        }
        // Only the first response's tool name is inspected when counting rounds
        ToolResponseMessage.ToolResponse firstResponse = responses.getFirst();
        if (!updateRoundsAndCheckReminder(appId, firstResponse.name())) {
            return conversationHistory;
        }
        // Prepend the reminder so it sits at the first position and the real responses keep their order.
        // (Removing the first response and appending it after the reminder only works for a single response.)
        List<ToolResponseMessage.ToolResponse> newResponses = new ArrayList<>(responses);
        newResponses.addFirst(new ToolResponseMessage.ToolResponse(firstResponse.id(), "text", TODO_REMINDER));
        List<Message> result = new ArrayList<>(conversationHistory.subList(0, conversationHistory.size() - 1));
        result.add(ToolResponseMessage.builder().responses(newResponses).build());
        return result;
    }
    private static ChatClientResponse buildReturnDirectResponse(ChatClientResponse aggregatedResponse, ChatResponse chatResponse,
            ToolExecutionResult originalResult, List<Message> historyWithReminder) {
        ToolExecutionResult resultWithReminder = ToolExecutionResult.builder()
                .conversationHistory(historyWithReminder)
                .returnDirect(originalResult.returnDirect())
                .build();
        ChatResponse newChatResponse = ChatResponse.builder()
                .from(chatResponse)
                .generations(ToolExecutionResult.buildGenerations(resultWithReminder))
                .build();
        return aggregatedResponse.mutate().chatResponse(newChatResponse).build();
    }

    private boolean updateRoundsAndCheckReminder(String appId, String methodName) {
        if (TODO_METHOD.equals(methodName)) {
            roundsSinceTodo.put(appId, 0);
            return false;
        }
        int count = roundsSinceTodo.asMap().merge(appId, 1, Integer::sum);
        return count >= REMINDER_THRESHOLD;
    }

    @Override
    public String getName() {
        return "ExecuteToolAdvisor";
    }

    @Override
    public int getOrder() {
        return ORDER;
    }
}
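The counting rule in updateRoundsAndCheckReminder is easy to check in isolation. The following sketch reproduces it with a plain ConcurrentHashMap instead of the Caffeine cache, so entries never expire here; `RoundCounterSketch` and `shouldRemind` are hypothetical names used only for this illustration.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** Per-conversation round counter; stdlib-only stand-in for the Caffeine-backed version. */
final class RoundCounterSketch {

    static final String TODO_METHOD = "todoUpdate";
    static final int REMINDER_THRESHOLD = 3;

    private final Map<String, Integer> roundsSinceTodo = new ConcurrentHashMap<>();

    /** Records one tool round and reports whether a reminder should be injected. */
    boolean shouldRemind(String conversationId, String toolName) {
        if (TODO_METHOD.equals(toolName)) {
            roundsSinceTodo.put(conversationId, 0); // a todo update resets the streak
            return false;
        }
        // merge() atomically increments the counter, starting at 1 for a new conversation
        return roundsSinceTodo.merge(conversationId, 1, Integer::sum) >= REMINDER_THRESHOLD;
    }
}
```

Note that with this rule the reminder keeps firing on every round past the threshold until todoUpdate is called. If one reminder per streak is preferred, resetting the counter when the reminder fires would be a small change.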
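Because this Advisor also relies on a copy method on the StreamAdvisorChain interface, we need to override the framework's StreamAdvisorChain and implement the method in the corresponding class. The following code lives in the package org.springframework.ai.chat.client.advisor.api: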
public interface StreamAdvisorChain extends AdvisorChain {

    Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest);

    List<StreamAdvisor> getStreamAdvisors();

    StreamAdvisorChain copy(StreamAdvisor after);
}
The implementation below lives in the package org.springframework.ai.chat.client.advisor:
public class DefaultAroundAdvisorChain implements BaseAdvisorChain {

    public static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention();
    private static final ChatClientMessageAggregator CHAT_CLIENT_MESSAGE_AGGREGATOR = new ChatClientMessageAggregator();

    private final List<CallAdvisor> originalCallAdvisors;
    private final List<StreamAdvisor> originalStreamAdvisors;
    private final Deque<CallAdvisor> callAdvisors;
    private final Deque<StreamAdvisor> streamAdvisors;
    private final ObservationRegistry observationRegistry;
    private final AdvisorObservationConvention observationConvention;

    DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, Deque<CallAdvisor> callAdvisors,
            Deque<StreamAdvisor> streamAdvisors, @Nullable AdvisorObservationConvention observationConvention) {
        Assert.notNull(observationRegistry, "the observationRegistry must be non-null");
        Assert.notNull(callAdvisors, "the callAdvisors must be non-null");
        Assert.notNull(streamAdvisors, "the streamAdvisors must be non-null");
        this.observationRegistry = observationRegistry;
        this.callAdvisors = callAdvisors;
        this.streamAdvisors = streamAdvisors;
        this.originalCallAdvisors = List.copyOf(callAdvisors);
        this.originalStreamAdvisors = List.copyOf(streamAdvisors);
        this.observationConvention = observationConvention != null ? observationConvention : DEFAULT_OBSERVATION_CONVENTION;
    }

    public static Builder builder(ObservationRegistry observationRegistry) {
        return new Builder(observationRegistry);
    }

    @Override
    public ChatClientResponse nextCall(ChatClientRequest chatClientRequest) {
        Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
        if (this.callAdvisors.isEmpty()) {
            throw new IllegalStateException("No CallAdvisors available to execute");
        }
        var advisor = this.callAdvisors.pop();
        var observationContext = AdvisorObservationContext.builder()
                .advisorName(advisor.getName())
                .chatClientRequest(chatClientRequest)
                .order(advisor.getOrder())
                .build();
        return AdvisorObservationDocumentation.AI_ADVISOR
                .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
                .observe(() -> {
                    var chatClientResponse = advisor.adviseCall(chatClientRequest, this);
                    observationContext.setChatClientResponse(chatClientResponse);
                    return chatClientResponse;
                });
    }

    @Override
    public Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest) {
        Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
        return Flux.deferContextual(contextView -> {
            if (this.streamAdvisors.isEmpty()) {
                return Flux.error(new IllegalStateException("No StreamAdvisors available to execute"));
            }
            var advisor = this.streamAdvisors.pop();
            AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
                    .advisorName(advisor.getName())
                    .chatClientRequest(chatClientRequest)
                    .order(advisor.getOrder())
                    .build();
            var observation = AdvisorObservationDocumentation.AI_ADVISOR
                    .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry);
            observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
            Flux<ChatClientResponse> chatClientResponse = Flux.defer(() -> advisor.adviseStream(chatClientRequest, this)
                    .doOnError(observation::error)
                    .doFinally(s -> observation.stop())
                    .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)));
            return CHAT_CLIENT_MESSAGE_AGGREGATOR.aggregateChatClientResponse(chatClientResponse, observationContext::setChatClientResponse);
        });
    }

    @Override
    public CallAdvisorChain copy(CallAdvisor after) {
        return this.copyAdvisorsAfter(this.getCallAdvisors(), after);
    }

    @Override
    public StreamAdvisorChain copy(StreamAdvisor after) {
        return this.copyAdvisorsAfter(this.getStreamAdvisors(), after);
    }

    private DefaultAroundAdvisorChain copyAdvisorsAfter(List<? extends Advisor> advisors, Advisor after) {
        Assert.notNull(after, "The after advisor must be non-null");
        Assert.notNull(advisors, "The advisors must be non-null");
        int afterAdvisorIndex = advisors.indexOf(after);
        if (afterAdvisorIndex < 0) {
            throw new IllegalArgumentException("The specified advisor is not part of the chain: " + after.getName());
        }
        var remainingStreamAdvisors = advisors.subList(afterAdvisorIndex + 1, advisors.size());
        return DefaultAroundAdvisorChain.builder(this.getObservationRegistry()).pushAll(remainingStreamAdvisors).build();
    }

    @Override
    public List<CallAdvisor> getCallAdvisors() {
        return this.originalCallAdvisors;
    }

    @Override
    public List<StreamAdvisor> getStreamAdvisors() {
        return this.originalStreamAdvisors;
    }

    @Override
    public ObservationRegistry getObservationRegistry() {
        return this.observationRegistry;
    }

    public static final class Builder {

        private final ObservationRegistry observationRegistry;
        private final Deque<CallAdvisor> callAdvisors;
        private final Deque<StreamAdvisor> streamAdvisors;
        private @Nullable AdvisorObservationConvention observationConvention;

        public Builder(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            this.callAdvisors = new ConcurrentLinkedDeque<>();
            this.streamAdvisors = new ConcurrentLinkedDeque<>();
        }

        public Builder observationConvention(@Nullable AdvisorObservationConvention observationConvention) {
            this.observationConvention = observationConvention;
            return this;
        }

        public Builder push(Advisor advisor) {
            Assert.notNull(advisor, "the advisor must be non-null");
            return this.pushAll(List.of(advisor));
        }

        public Builder pushAll(List<? extends Advisor> advisors) {
            Assert.notNull(advisors, "the advisors must be non-null");
            Assert.noNullElements(advisors, "the advisors must not contain null elements");
            if (!CollectionUtils.isEmpty(advisors)) {
                List<CallAdvisor> callAroundAdvisorList = advisors.stream()
                        .filter(a -> a instanceof CallAdvisor)
                        .map(a -> (CallAdvisor) a)
                        .toList();
                if (!CollectionUtils.isEmpty(callAroundAdvisorList)) {
                    callAroundAdvisorList.forEach(this.callAdvisors::push);
                }
                List<StreamAdvisor> streamAroundAdvisorList = advisors.stream()
                        .filter(a -> a instanceof StreamAdvisor)
                        .map(a -> (StreamAdvisor) a)
                        .toList();
                if (!CollectionUtils.isEmpty(streamAroundAdvisorList)) {
                    streamAroundAdvisorList.forEach(this.streamAdvisors::push);
                }
                this.reOrder();
            }
            return this;
        }

        private void reOrder() {
            ArrayList<CallAdvisor> callAdvisors = new ArrayList<>(this.callAdvisors);
            OrderComparator.sort(callAdvisors);
            this.callAdvisors.clear();
            callAdvisors.forEach(this.callAdvisors::addLast);
            ArrayList<StreamAdvisor> streamAdvisors = new ArrayList<>(this.streamAdvisors);
            OrderComparator.sort(streamAdvisors);
            this.streamAdvisors.clear();
            streamAdvisors.forEach(this.streamAdvisors::addLast);
        }

        public DefaultAroundAdvisorChain build() {
            return new DefaultAroundAdvisorChain(this.observationRegistry, this.callAdvisors, this.streamAdvisors, this.observationConvention);
        }
    }
}
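The effect of copy(after) is easier to see on a plain list. The sketch below mirrors the index arithmetic of copyAdvisorsAfter using advisor names instead of Advisor objects; `ChainCopySketch` and the names in the usage note are hypothetical and serve only as an illustration.

```java
import java.util.List;

/** List-based illustration of "copy everything after this advisor"; hypothetical names. */
final class ChainCopySketch {

    /** Returns the advisors that come strictly after the given one in the ordered chain. */
    static List<String> copyAfter(List<String> orderedAdvisors, String after) {
        int index = orderedAdvisors.indexOf(after);
        if (index < 0) {
            throw new IllegalArgumentException("The specified advisor is not part of the chain: " + after);
        }
        // subList is a view, so copy it to detach the result from the original chain
        return List.copyOf(orderedAdvisors.subList(index + 1, orderedAdvisors.size()));
    }
}
```

For a chain of "memory", "executeTool", "modelCall", copying after "executeTool" leaves only "modelCall". That is why each round of the tool loop re-enters the chain below ExecuteToolAdvisor and does not re-run the advisors ahead of it.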
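Verification and Frontend Display
The reminder injection is triggered once the TodoList has gone three consecutive tool calls without an update.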