Spring AI Alibaba Graph 初探
介绍 Spring AI Alibaba Graph 框架的使用。涵盖核心概念如 StateGraph、NodeAction,通过快速入门示例展示依赖配置与状态图定义。包含英语学习助手案例,演示条件边与循环边的实现逻辑,以及状态存储与图的可视化打印方法。适合希望构建复杂 Agent 工作流的开发者参考。

介绍 Spring AI Alibaba Graph 框架的使用。涵盖核心概念如 StateGraph、NodeAction,通过快速入门示例展示依赖配置与状态图定义。包含英语学习助手案例,演示条件边与循环边的实现逻辑,以及状态存储与图的可视化打印方法。适合希望构建复杂 Agent 工作流的开发者参考。



实现如下工作流: 开始节点→node1→node2→结束节点 用 node2 的值替换 node1 的值
spring-boot: 3.4.0 spring-ai-alibaba: 1.0.0.4
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-zhipuai</artifactId>
</dependency>
<dependency>
<groupId>com.alibaba.cloud.ai</groupId>
<artifactId>spring-ai-alibaba-graph-core</artifactId>
</dependency>
server:
port: 8889
spring:
application:
name: agent-graph
ai:
zhipuai:
api-key: ${ZHIPU_KEY}
chat:
options:
model: glm-4-flash

GraphConfig.java
@Configuration
@Slf4j
public class GraphConfig {
@Bean("quickStartGraph")
public CompiledGraph quickStartGraph() throws GraphStateException {
KeyStrategyFactory keyStrategyFactory = new KeyStrategyFactory() {
@Override
public Map<String, KeyStrategy> apply() {
// ReplaceStrategy 为替换策略
return Map.of("input1", new ReplaceStrategy(), "input2", new ReplaceStrategy());
}
};
// 定义状态图
StateGraph stateGraph = new StateGraph("quickStartGraph", keyStrategyFactory);
// 添加节点
// AsyncNodeAction.node_async 为异步执行
stateGraph.addNode("node1", AsyncNodeAction.node_async(new NodeAction() {
@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
log.info("node1 state: {}", state);
return Map.of("input1", 1, "input2", 1);
}
}));
stateGraph.addNode("node2", AsyncNodeAction.node_async(new NodeAction() {
@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
log.info("node2 state: {}", state);
return Map.of("input1", 2, "input2", 2);
}
}));
// 定义边
stateGraph.addEdge(StateGraph.START, "node1");
stateGraph.addEdge("node1", "node2");
stateGraph.addEdge("node2", StateGraph.END);
// 编译状态图
return stateGraph.compile();
}
}
GraphController.java
@RestController
@RequestMapping("/graph")
@Slf4j
public class GraphController {
private final CompiledGraph compiledGraph;
public GraphController(CompiledGraph compiledGraph) {
this.compiledGraph = compiledGraph;
}
@GetMapping("/quickStartGraph")
public String quickStartGraph() {
Optional<OverAllState> overAllStateOptional = compiledGraph.call(Map.of());
log.info("overAllStateOptional: {}", overAllStateOptional);
return "OK";
}
}

发现 input1 和 input2 的值被成功替换为 2


状态图的抽象,需要配置状态 (通过 KeyStrategyFactory), 节点,边。 配置好后通过 compile 方法编译成 CompiledGraph 后才可以供调用。
CompiledGraph 是 StateGraph 编译后的结果,CompiledGraph 才能用了执行。 一般我们是把 StateGraph 定义好后调用其 compile 方法得到一个 CompiledGraph 放入 Spring 容器中然后在需要的时候从容器中注入然后再调用。

使用 Graph 开发一个英语学习小助手。 功能如下:输入一个单词,能基于这个单词造句,然后再对句子进行翻译,把造句的译文也返回。
我们可以定义一个工作流,工作流中主要有两个节点: SentenceConstructionNode 造句节点,拿输入的单词让 LLM 进行造句。 TranslationNode 翻译节点,能够把一个英文句子翻译成中文。最终把造句的结果和翻译的结果返回即可。
开始节点(输入一个单词)–>造句节点(根据给定的单词进行造句)–>翻译节点(对句子进行翻译)–>结束节点(输出造句和翻译的结果)

public class SentenceConstructionNode implements NodeAction {
private final ChatClient chatClient;
public SentenceConstructionNode(ChatClient.Builder builder) {
this.chatClient = builder.build();
}
@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
// 从 stage 中获取要造句的单词
String word = state.value("word", "");
// 定义提示词
PromptTemplate promptTemplate = new PromptTemplate(
"你是一个英语造句专家,能够基于给定的单词进行造句。" +
"要求只返回最终造好的句子,不要返回其他信息。给定的单词:{word}"
);
promptTemplate.add("word", word);
// 替换占位符
String prompt = promptTemplate.render();
// 渲染提示词
// 模型调用
String content = chatClient.prompt().user(prompt).call().content();
// 把句子存入 stage
return Map.of("sentence", content);
}
}
public class TranslationNode implements NodeAction {
private final ChatClient chatClient;
public TranslationNode(ChatClient.Builder builder) {
this.chatClient = builder.build();
}
@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
// 从 stage 中获取要翻译的句子
String sentence = state.value("sentence", "");
// 定义提示词
PromptTemplate promptTemplate = new PromptTemplate(
"你是一个英语翻译专家,能够把英文翻译成中文。" +
"要求只返回翻译的中文结果,不要返回英文原句。要翻译的英文句子:{sentence}"
);
promptTemplate.add("sentence", sentence);
// 替换占位符
String prompt = promptTemplate.render();
// 渲染提示词
// 模型调用
String content = chatClient.prompt().user(prompt).call().content();
// 把翻译结果存入 stage
return Map.of("translation", content);
}
}
config/GraphConfig.java,在 quickStartGraph 下面增加如下内容

@Bean("simpleGraph")
public CompiledGraph simpleGraph(ChatClient.Builder clientBuilder) throws GraphStateException {
KeyStrategyFactory keyStrategyFactory = () -> {
HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
keyStrategyHashMap.put("word", new ReplaceStrategy());
keyStrategyHashMap.put("sentence", new ReplaceStrategy());
keyStrategyHashMap.put("translation", new ReplaceStrategy());
return keyStrategyHashMap;
};
// 创建状态图
StateGraph stateGraph = new StateGraph("simpleGraph", keyStrategyFactory);
// 添加节点
stateGraph.addNode("SentenceConstructionNode", AsyncNodeAction.node_async(new SentenceConstructionNode(clientBuilder)));
stateGraph.addNode("TranslationNode", AsyncNodeAction.node_async(new TranslationNode(clientBuilder)));
// 定义边
stateGraph.addEdge(StateGraph.START, "SentenceConstructionNode");
stateGraph.addEdge("SentenceConstructionNode", "TranslationNode");
stateGraph.addEdge("TranslationNode", StateGraph.END);
// 编译状态图,放入容器
return stateGraph.compile();
}

@RestController
@RequestMapping("/graph")
@Slf4j
public class GraphController {
private final CompiledGraph compiledGraph;
private final CompiledGraph simpleGraph;
public GraphController(@Qualifier("quickStartGraph") CompiledGraph compiledGraph,
@Qualifier("simpleGraph") CompiledGraph simpleGraph) {
this.compiledGraph = compiledGraph;
this.simpleGraph = simpleGraph;
}
@GetMapping("/quickStartGraph")
public String quickStartGraph() {
Optional<OverAllState> overAllStateOptional = compiledGraph.call(Map.of());
log.info("overAllStateOptional: {}", overAllStateOptional);
return "OK";
}
@GetMapping("/simpleGraph")
public Map<String, Object> simpleGraph(@RequestParam("word") String word) {
Optional<OverAllState> overAllStateOptional = simpleGraph.call(Map.of("word", word));
Map<String, Object> data = overAllStateOptional.map(OverAllState::data).orElse(Map.of());
return data;
}
}




public class GenerateJokeNode implements NodeAction {
private final ChatClient chatClient;
public GenerateJokeNode(ChatClient.Builder builder) {
this.chatClient = builder.build();
}
@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
// 从 stage 中获取笑话主题
String topic = state.value("topic", "");
// 定义提示词
PromptTemplate promptTemplate = new PromptTemplate(
"你需要写一个关于指定主题的短笑话。要求返回的结果中只能包含笑话的内容" +
"主题:{topic}"
);
promptTemplate.add("topic", topic);
// 替换占位符
String prompt = promptTemplate.render();
// 渲染提示词
// 模型调用
String content = chatClient.prompt().user(prompt).call().content();
// 把结果存入 stage
return Map.of("joke", content);
}
}
public class EvaluateJokesNode implements NodeAction {
private final ChatClient chatClient;
public EvaluateJokesNode(ChatClient.Builder builder) {
this.chatClient = builder.build();
}
@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
// 从 stage 中获取待评估笑话
String joke = state.value("joke", "");
// 定义提示词
PromptTemplate promptTemplate = new PromptTemplate(
"你是一个笑话评分专家,能够对笑话进行评分,基于效果的搞笑程度给出 0 到 10 分的打分。" +
"0 到 3 分是不够优秀,4 到 10 分是优秀。要求结果只返回优秀或者不够优秀,不能输出其他内容。" +
"要评分的笑话:{joke}"
);
promptTemplate.add("joke", joke);
// 替换占位符
String prompt = promptTemplate.render();
// 渲染提示词
// 模型调用
String content = chatClient.prompt().user(prompt).call().content();
// 把结果存入 stage
return Map.of("result", content.trim());
}
}
public class EnhanceJokeQualityNode implements NodeAction {
private final ChatClient chatClient;
public EnhanceJokeQualityNode(ChatClient.Builder builder) {
this.chatClient = builder.build();
}
@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
// 从 stage 中获取待评估笑话
String joke = state.value("joke", "");
// 定义提示词
PromptTemplate promptTemplate = new PromptTemplate(
"你是一个笑话优化专家,你能够优化笑话,让它更加搞笑" +
"要优化的话:{joke}"
);
promptTemplate.add("joke", joke);
// 替换占位符
String prompt = promptTemplate.render();
// 渲染提示词
// 模型调用
String content = chatClient.prompt().user(prompt).call().content();
// 把结果存入 stage
return Map.of("newJoke", content);
}
}
@Bean("conditionalGraph")
public CompiledGraph conditionalGraph(ChatClient.Builder clientBuilder) throws GraphStateException {
KeyStrategyFactory keyStrategyFactory = () -> Map.of("topic", new ReplaceStrategy());
// 定义状态图
StateGraph stateGraph = new StateGraph("conditionalGraph", keyStrategyFactory);
// 定义节点
stateGraph.addNode("生成笑话", AsyncNodeAction.node_async(new GenerateJokeNode(clientBuilder)));
stateGraph.addNode("评估笑话", AsyncNodeAction.node_async(new EvaluateJokesNode(clientBuilder)));
stateGraph.addNode("优化笑话", AsyncNodeAction.node_async(new EnhanceJokeQualityNode(clientBuilder)));
// 定义边
stateGraph.addEdge(StateGraph.START, "生成笑话");
stateGraph.addEdge("生成笑话", "评估笑话");
stateGraph.addConditionalEdges("评估笑话", AsyncEdgeAction.edge_async(
state -> state.value("result", "优秀"),
Map.of("优秀", StateGraph.END, "不够优秀", "优化笑话")
));
stateGraph.addEdge("优化笑话", StateGraph.END);
return stateGraph.compile();
}
private final CompiledGraph compiledGraph;
private final CompiledGraph simpleGraph;
private final CompiledGraph conditionalGraph;
public GraphController(@Qualifier("quickStartGraph") CompiledGraph compiledGraph,
@Qualifier("simpleGraph") CompiledGraph simpleGraph,
@Qualifier("conditionalGraph") CompiledGraph conditionalGraph) {
this.compiledGraph = compiledGraph;
this.simpleGraph = simpleGraph;
this.conditionalGraph = conditionalGraph;
}
@GetMapping("/conditionalGraph")
public Map<String, Object> conditionalGraph(@RequestParam("topic") String topic) {
Optional<OverAllState> overAllStateOptional = conditionalGraph.call(Map.of("topic", topic));
Map<String, Object> data = overAllStateOptional.map(OverAllState::data).orElse(Map.of());
return data;
}
GET 方式:http://localhost:8889/graph/conditionalGraph?topic=爱情 评估结果是优秀

直接输出结果

在断点处右击,选择'Evaluate Expression'

篡改评估结果为'不够优秀',回车后关闭

修改成功

就会走优化节点,生成新的笑话


@Slf4j
public class LoopEvaluateJokesNode implements NodeAction {
private final ChatClient chatClient;
private final Integer targetScore;
private final Integer maxLoopCount;
public LoopEvaluateJokesNode(ChatClient.Builder builder, Integer targetScore, Integer maxLoopCount) {
this.chatClient = builder.build();
this.targetScore = targetScore;
this.maxLoopCount = maxLoopCount;
}
@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
// 从 stage 中获取待评估笑话
String joke = state.value("joke", "");
// 循环次数
Integer loopCount = state.value("loopCount", 0);
// 定义提示词
PromptTemplate promptTemplate = new PromptTemplate(
"你是一个笑话评分专家,能够对笑话进行评分,基于效果的搞笑程度给出 0 到 10 分的打分。" +
"要求结果只返回最后的打分,打分必须是整数,不能输出其他内容。" +
"要评分的笑话:{joke}"
);
promptTemplate.add("joke", joke);
// 替换占位符
String prompt = promptTemplate.render();
// 渲染提示词
// 模型调用
String content = chatClient.prompt().user(prompt).call().content();
// content 转为整数
Integer score = Integer.parseInt(content.trim());
log.info("joke: {},score: {},循环次数:{}", joke, score, loopCount);
// 根据分数判断是否继续循环,循环最多执行 5 次
String result = "loop";
if (score >= targetScore || loopCount >= maxLoopCount) {
result = "break";
}
loopCount++;
// 把结果存入 stage
return Map.of("result", result, "loopCount", loopCount);
}
}
@Bean("loopGraph")
public CompiledGraph loopGraph(ChatClient.Builder clientBuilder) throws GraphStateException {
KeyStrategyFactory keyStrategyFactory = () -> Map.of("topic", new ReplaceStrategy());
// 定义状态图
StateGraph stateGraph = new StateGraph("loopGraph", keyStrategyFactory);
// 定义节点
stateGraph.addNode("生成笑话", AsyncNodeAction.node_async(new GenerateJokeNode(clientBuilder)));
stateGraph.addNode("评估笑话", AsyncNodeAction.node_async(new LoopEvaluateJokesNode(clientBuilder, 8, 5)));
// 定义边
stateGraph.addEdge(StateGraph.START, "生成笑话");
stateGraph.addEdge("生成笑话", "评估笑话");
stateGraph.addConditionalEdges("评估笑话", AsyncEdgeAction.edge_async(
state -> state.value("result", "loop"),
Map.of("loop", "生成笑话", "break", StateGraph.END)
));
return stateGraph.compile();
}
@GetMapping("/loopGraph")
public Map<String, Object> loopGraph(@RequestParam("topic") String topic) {
Optional<OverAllState> overAllStateOptional = loopGraph.call(Map.of("topic", topic));
Map<String, Object> data = overAllStateOptional.map(OverAllState::data).orElse(Map.of());
return data;
}
GET 方式:http://localhost:8889/graph/loopGraph?topic=爱情 当 score 为 8 时,退出循环,输出结果


我们可以把图中的状态数据进行存储。默契情况下 Graph 会把状态存储到内存中。
@Bean("saveGraph")
public CompiledGraph saveGraph(ChatClient.Builder clientBuilder) throws GraphStateException {
KeyStrategyFactory keyStrategyFactory = () -> Map.of();
// 定义状态图
StateGraph stateGraph = new StateGraph("saveGraph", keyStrategyFactory);
stateGraph.addNode("对话存储", AsyncNodeAction.node_async(new NodeAction() {
@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
String msg = state.value("msg", "");
ArrayList<Object> historyMsg = state.value("historyMsg", new ArrayList<>());
historyMsg.add(msg);
return Map.of("historyMsg", historyMsg);
}
}));
// 定义边
stateGraph.addEdge(StateGraph.START, "对话存储");
stateGraph.addEdge("对话存储", StateGraph.END);
return stateGraph.compile();
}
@GetMapping("/saveGraph")
// 通过 conversationId 来隔离不同请求者的数据
public Map<String, Object> saveGraph(@RequestParam("msg") String msg,
@RequestParam("conversationId") String conversationId) {
RunnableConfig runnableConfig = RunnableConfig.builder().threadId(conversationId).build();
Optional<OverAllState> overAllStateOptional = saveGraph.call(Map.of("msg", msg), runnableConfig);
Map<String, Object> data = overAllStateOptional.map(OverAllState::data).orElse(Map.of());
return data;
}

第二次调用,发现前面的值存储了下来

修改会话 ID,历史数据只有最新的一条数据

我们可以把定义好的状态图进行打印,更直观的看到当前图的情况
在图的下面添加如下代码:
// 添加 PlantUML 打印
GraphRepresentation representation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML, "stateGraph");
log.info("\n===打印 UML Flow===");
log.info(representation.content());
log.info("====================\n");

启动服务,复制如下内容

打开网址:http://www.plantuml.com/plantuml/ 粘贴内容,就可以看到图的效果了


微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
查找任何按下的键的javascript键代码、代码、位置和修饰符。 在线工具,Keycode 信息在线工具,online
JavaScript 字符串转义/反转义;Java 风格 \uXXXX(Native2Ascii)编码与解码。 在线工具,Escape 与 Native 编解码在线工具,online
使用 Prettier 在浏览器内格式化 JavaScript 或 HTML 片段。 在线工具,JavaScript / HTML 格式化在线工具,online
Terser 压缩、变量名混淆,或 javascript-obfuscator 高强度混淆(体积会增大)。 在线工具,JavaScript 压缩与混淆在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online