Created
April 10, 2024 03:19
-
-
Save abel533/300e642cb4e2548830981ce824036586 to your computer and use it in GitHub Desktop.
使用 Spring AI 参考 https://github.com/mshumer/ai-journalist 实现的 AI 记者
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.mybatis.ai; | |
import org.jsoup.Jsoup; | |
import org.jsoup.nodes.Document; | |
import org.springframework.ai.chat.ChatClient; | |
import org.springframework.ai.chat.ChatResponse; | |
import org.springframework.ai.chat.messages.Message; | |
import org.springframework.ai.chat.messages.SystemMessage; | |
import org.springframework.ai.chat.messages.UserMessage; | |
import org.springframework.ai.chat.prompt.Prompt; | |
import org.springframework.ai.openai.OpenAiChatClient; | |
import org.springframework.ai.openai.OpenAiChatOptions; | |
import org.springframework.ai.openai.api.OpenAiApi; | |
import org.springframework.http.*; | |
import org.springframework.web.client.RestTemplate; | |
import java.util.*; | |
import java.util.stream.Collectors; | |
import java.util.stream.IntStream; | |
/** | |
* 记者类,用于撰写和编辑文章 | |
*/ | |
public class AIJournalist { | |
private ChatClient chatClient; // 聊天客户端 | |
private RestTemplate restTemplate = new RestTemplate(); // 用于发送HTTP请求的模板 | |
private HttpHeaders serpApiHeader; // SERP API的请求头 | |
/** | |
* 构造函数,初始化聊天客户端和SERP API的请求头 | |
* | |
* @param chatClient 聊天客户端 | |
* @param serpApiKey SERP API的密钥 | |
*/ | |
public AIJournalist(ChatClient chatClient, String serpApiKey) { | |
this.chatClient = chatClient; | |
this.serpApiHeader = new HttpHeaders(); | |
serpApiHeader.setContentType(MediaType.APPLICATION_JSON); | |
serpApiHeader.set("X-API-KEY", serpApiKey); | |
} | |
public static void main(String[] args) { | |
// 创建聊天客户端 | |
var openAiApi = new OpenAiApi("替换token"); | |
var chatClient = new OpenAiChatClient(openAiApi, OpenAiChatOptions.builder() | |
.withModel("gpt-4-turbo").withTemperature(0.4F).build()); | |
// 启动 | |
AIJournalist journalist = new AIJournalist(chatClient, "替换token"); | |
journalist.start(); | |
} | |
/** | |
* 开始撰写和编辑文章的过程 | |
*/ | |
public void start() { | |
// User input | |
Scanner scanner = new Scanner(System.in); | |
System.out.print("输入要写的主题:"); | |
String topic = scanner.nextLine(); | |
System.out.print("初稿完成后,是否要进行自动编辑?这可能会提高性能,但有点不可靠。回答“是”或“否”:"); | |
String doEdit = scanner.nextLine(); | |
// Generate search terms | |
List<String> searchTerms = getSearchTerms(topic); | |
System.out.println("\n------------------------------------------------"); | |
System.out.println("\n搜索词 '" + topic + "':"); | |
System.out.println(String.join(", ", searchTerms)); | |
// Perform searches and select relevant URLs | |
List<String> relevantUrls = new ArrayList<>(); | |
for (String term : searchTerms) { | |
List<Map<String, Object>> searchResults = getSearchResults(term); | |
List<String> urls = selectRelevantUrls(searchResults); | |
relevantUrls.addAll(urls); | |
} | |
String urls = IntStream.range(0, relevantUrls.size()) | |
.mapToObj(i -> (i + 1) + ". " + relevantUrls.get(i)) | |
.collect(Collectors.joining("\n")); | |
System.out.println("\n------------------------------------------------"); | |
System.out.println("要阅读的相关 URL:\n" + urls); | |
// Get article text from relevant URLs | |
List<String> articleTexts = new ArrayList<>(); | |
for (String url : relevantUrls) { | |
try { | |
String text = getArticleText(url); | |
if (text.length() > 75) { | |
articleTexts.add(text); | |
} | |
} catch (Exception e) { | |
e.printStackTrace(); | |
} | |
} | |
System.out.println("\n------------------------------------------------"); | |
System.out.println("参考文章:" + articleTexts); | |
System.out.println("\n\n正在写文章..."); | |
// Write the article | |
String article = writeArticle(topic, articleTexts); | |
System.out.println("\n------------------------------------------------"); | |
System.out.println("\n生成的文章:"); | |
System.out.println(article); | |
if (doEdit.toLowerCase().contains("是")) { | |
// Edit the article | |
String editedArticle = editArticle(article); | |
System.out.println("\n------------------------------------------------"); | |
System.out.println("\n编辑文章:"); | |
System.out.println(editedArticle); | |
} | |
} | |
/** | |
* 从给定的URL获取文章文本 | |
* | |
* @param url 文章的URL | |
* @return 文章的文本 | |
*/ | |
public String getArticleText(String url) { | |
try { | |
Document doc = Jsoup.connect(url).get(); | |
return doc.body().text(); | |
} catch (Exception e) { | |
System.out.println("解析URL" + url + " 错误: " + e.getMessage()); | |
return ""; | |
} | |
} | |
/** | |
* 根据给定的主题和参考文章文本撰写文章 | |
* | |
* @param topic 主题 | |
* @param articleTexts 参考文章文本 | |
* @return 撰写的文章 | |
*/ | |
public String writeArticle(String topic, List<String> articleTexts) { | |
List<Message> messages = new ArrayList<>(); | |
messages.add(new SystemMessage("你是一位世界级的记者。根据以下的参考文章和主题,撰写一篇关于该主题的文章。")); | |
StringBuilder articlesText = new StringBuilder(); | |
for (int i = 0; i < articleTexts.size(); i++) { | |
String article = articleTexts.get(i); | |
articlesText.append(i + 1).append(". ").append(article).append("\n"); | |
} | |
messages.add(new UserMessage("参考文章:\n" + articlesText + "\n\n主题: " + topic + "\n\n请撰写一篇关于该主题的文章。")); | |
return call(messages); | |
} | |
private int error = 0; | |
/** | |
* 调用聊天客户端,发送消息并获取回复 | |
* | |
* @param messages 要发送的消息 | |
* @return 聊天客户端的回复 | |
*/ | |
public String call(List<Message> messages) { | |
try { | |
if (error > 6) { | |
throw new RuntimeException("失败" + error + "次,停止调用"); | |
} | |
ChatResponse response = chatClient.call(new Prompt(messages)); | |
error = 0; | |
return response.getResult().getOutput().getContent(); | |
} catch (Exception e) { | |
System.out.println("请求出错: " + e.getMessage()); | |
System.out.println("等待 " + error * 10000 + "s 重试"); | |
try { | |
Thread.sleep(error * 10000); | |
} catch (InterruptedException ex) { | |
} | |
error++; | |
return call(messages); | |
} | |
} | |
/** | |
* 编辑文章以提高其质量 | |
* | |
* @param article 要编辑的文章 | |
* @return 编辑后的文章 | |
*/ | |
public String editArticle(String article) { | |
List<Message> messages = new ArrayList<>(); | |
messages.add(new SystemMessage("你是一位世界级的编辑。根据以下的文章,进行编辑以提高其质量。")); | |
messages.add(new UserMessage("请编辑以下文章以提高其质量:\n" + article)); | |
return call(messages); | |
} | |
/** | |
* 根据给定的主题生成搜索词 | |
* | |
* @param topic 主题 | |
* @return 搜索词列表 | |
*/ | |
public List<String> getSearchTerms(String topic) { | |
List<Message> messages = new ArrayList<>(); | |
messages.add(new SystemMessage("你是一位世界级的记者。生成一个包含5个搜索词的列表,用于研究和撰写关于该主题的文章。")); | |
messages.add(new UserMessage("主题: " + topic + "\n\n请提供一个与'" + topic + "'相关的5个搜索词的列表,用于研究和撰写文章。以逗号分隔的Java可解析列表形式回复。")); | |
String responseText = call(messages); | |
return Arrays.asList(responseText.replace("[", "") | |
.replace("]", "").replace("\"", "").split(",")); | |
} | |
/** | |
* 从给定的搜索结果中选择相关的URL | |
* | |
* @param searchResults 搜索结果 | |
* @return 相关的URL列表 | |
*/ | |
public List<String> selectRelevantUrls(List<Map<String, Object>> searchResults) { | |
List<Message> messages = new ArrayList<>(); | |
messages.add(new SystemMessage("你是一位记者助手。从给定的搜索结果中,选择出看起来最相关和信息丰富的 URL,用于撰写关于该主题的文章。")); | |
StringBuilder searchResultsText = new StringBuilder(); | |
for (int i = 0; i < searchResults.size(); i++) { | |
searchResultsText.append(i + 1).append(". ").append(searchResults.get(i).get("link")).append("\n"); | |
} | |
messages.add(new UserMessage("搜索结果:\n" + searchResultsText + "\n\n请选择看起来最相关和信息丰富的 URL 的编号,用于撰写关于该主题的文章。以逗号分隔的 Java 可解析列表形式回复(如 [1,2,4])。")); | |
String responseText = call(messages); | |
String[] numbers = responseText.replace("[", "") | |
.replace("]", "").replace("\"", "").split(","); | |
List<String> relevantUrls = new ArrayList<>(); | |
for (String num : numbers) { | |
int index = Integer.parseInt(num.trim()) - 1; | |
relevantUrls.add((String) searchResults.get(index).get("link")); | |
} | |
return relevantUrls; | |
} | |
/** | |
* 根据给定的搜索词获取搜索结果 | |
* | |
* @param searchTerm 搜索词 | |
* @return 搜索结果 | |
*/ | |
@SuppressWarnings({"unchecked", "rawtypes"}) | |
public List<Map<String, Object>> getSearchResults(String searchTerm) { | |
// Create request body | |
String body = "{\"q\":\"" + searchTerm + "\",\"hl\":\"en\",\"num\":10}"; | |
// Create entity | |
HttpEntity<String> entity = new HttpEntity<>(body, serpApiHeader); | |
// Execute request | |
ResponseEntity<Map> response = restTemplate.exchange( | |
"https://google.serper.dev/search", | |
HttpMethod.POST, | |
entity, | |
Map.class); | |
return (List<Map<String, Object>>) response.getBody().get("organic"); | |
} | |
} |
大佬太🐂了,搜spring ai 看到大佬写的文章。拜服🫡
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
关联文章: https://mp.weixin.qq.com/s/IajaX7S0vOxhlD5gr9lLUg