Skip to content

Instantly share code, notes, and snippets.

@abel533
Created April 10, 2024 03:19
Show Gist options
  • Save abel533/300e642cb4e2548830981ce824036586 to your computer and use it in GitHub Desktop.
Save abel533/300e642cb4e2548830981ce824036586 to your computer and use it in GitHub Desktop.
使用 Spring AI 参考 https://github.com/mshumer/ai-journalist 实现的 AI 记者
package io.mybatis.ai;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.http.*;
import org.springframework.web.client.RestTemplate;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
 * AI journalist that researches a topic on the web and then writes (and optionally edits)
 * an article about it with the help of a chat model.
 *
 * <p>The workflow driven by {@link #start()} is: ask the model for search terms, query the
 * search API for each term, let the model pick the most relevant result URLs, download the
 * text of those pages and finally ask the model to write the article.
 */
public class AIJournalist {

    /** Endpoint of the search API that is queried for every search term. */
    private static final String SEARCH_ENDPOINT = "https://google.serper.dev/search";

    /** Maximum number of times a single chat request is attempted before giving up. */
    private static final int MAX_CALL_ATTEMPTS = 7;

    /** The wait before a retry grows by this many seconds after every further failure. */
    private static final long RETRY_STEP_SECONDS = 10L;

    /** Downloaded page texts must be longer than this to be used as reference material. */
    private static final int MIN_USEFUL_TEXT_LENGTH = 75;

    /** Chat client used for every model request. */
    private final ChatClient chatClient;

    /** Template used to send HTTP requests to the search API. */
    private final RestTemplate restTemplate = new RestTemplate();

    /** Request headers (content type and API key) sent with every search request. */
    private final HttpHeaders serpApiHeader;

    /**
     * Creates a journalist that talks to the given chat client and authenticates against the
     * search API with the given key.
     *
     * @param chatClient chat client used for all model requests
     * @param serpApiKey API key sent in the {@code X-API-KEY} header of search requests
     */
    public AIJournalist(ChatClient chatClient, String serpApiKey) {
        this.chatClient = chatClient;
        this.serpApiHeader = new HttpHeaders();
        serpApiHeader.setContentType(MediaType.APPLICATION_JSON);
        serpApiHeader.set("X-API-KEY", serpApiKey);
    }

    /**
     * Command-line entry point: builds an OpenAI chat client and runs the interactive workflow.
     * Both token placeholders must be replaced with real credentials before running.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Create the chat client
        var openAiApi = new OpenAiApi("替换token");
        var chatClient = new OpenAiChatClient(openAiApi, OpenAiChatOptions.builder()
                .withModel("gpt-4-turbo").withTemperature(0.4F).build());
        // Start the workflow
        AIJournalist journalist = new AIJournalist(chatClient, "替换token");
        journalist.start();
    }

    /**
     * Runs the interactive research, writing and (optional) editing workflow on the console.
     */
    public void start() {
        // Deliberately not closed: closing the Scanner would also close System.in.
        Scanner scanner = new Scanner(System.in);
        System.out.print("输入要写的主题:");
        String topic = scanner.nextLine();
        System.out.print("初稿完成后,是否要进行自动编辑?这可能会提高性能,但有点不可靠。回答“是”或“否”:");
        String doEdit = scanner.nextLine();

        // Generate search terms
        List<String> searchTerms = getSearchTerms(topic);
        System.out.println("\n------------------------------------------------");
        System.out.println("\n搜索词 '" + topic + "':");
        System.out.println(String.join(", ", searchTerms));

        // Perform searches and select relevant URLs
        List<String> relevantUrls = collectRelevantUrls(searchTerms);
        String urls = IntStream.range(0, relevantUrls.size())
                .mapToObj(i -> (i + 1) + ". " + relevantUrls.get(i))
                .collect(Collectors.joining("\n"));
        System.out.println("\n------------------------------------------------");
        System.out.println("要阅读的相关 URL:\n" + urls);

        // Get article text from relevant URLs
        List<String> articleTexts = fetchArticleTexts(relevantUrls);
        System.out.println("\n------------------------------------------------");
        System.out.println("参考文章:" + articleTexts);
        System.out.println("\n\n正在写文章...");

        // Write the article
        String article = writeArticle(topic, articleTexts);
        System.out.println("\n------------------------------------------------");
        System.out.println("\n生成的文章:");
        System.out.println(article);

        if (doEdit.contains("是")) {
            // Edit the article
            String editedArticle = editArticle(article);
            System.out.println("\n------------------------------------------------");
            System.out.println("\n编辑文章:");
            System.out.println(editedArticle);
        }
    }

    /**
     * Searches for every term and gathers the URLs the model considers relevant, in the order
     * they were first selected and without duplicates.
     */
    private List<String> collectRelevantUrls(List<String> searchTerms) {
        // Different search terms often surface the same page; fetch and cite it only once.
        Set<String> uniqueUrls = new LinkedHashSet<>();
        for (String term : searchTerms) {
            uniqueUrls.addAll(selectRelevantUrls(getSearchResults(term)));
        }
        return new ArrayList<>(uniqueUrls);
    }

    /** Downloads the text of every URL and keeps only texts long enough to be useful. */
    private List<String> fetchArticleTexts(List<String> urls) {
        List<String> articleTexts = new ArrayList<>();
        for (String url : urls) {
            // getArticleText never throws; failed downloads come back as an empty string.
            String text = getArticleText(url);
            if (text.length() > MIN_USEFUL_TEXT_LENGTH) {
                articleTexts.add(text);
            }
        }
        return articleTexts;
    }

    /**
     * Downloads a page and returns the text of its body.
     *
     * @param url URL of the article
     * @return the body text, or an empty string when the page could not be fetched or parsed
     */
    public String getArticleText(String url) {
        try {
            Document doc = Jsoup.connect(url).get();
            return doc.body().text();
        } catch (Exception e) {
            // Best effort: a single unreachable or malformed page must not abort the whole run.
            System.out.println("解析URL" + url + " 错误: " + e.getMessage());
            return "";
        }
    }

    /**
     * Asks the model to write an article on a topic, based on the given reference texts.
     *
     * @param topic        topic of the article
     * @param articleTexts reference article texts
     * @return the article written by the model
     */
    public String writeArticle(String topic, List<String> articleTexts) {
        List<Message> messages = new ArrayList<>();
        messages.add(new SystemMessage("你是一位世界级的记者。根据以下的参考文章和主题,撰写一篇关于该主题的文章。"));
        StringBuilder articlesText = new StringBuilder();
        for (int i = 0; i < articleTexts.size(); i++) {
            articlesText.append(i + 1).append(". ").append(articleTexts.get(i)).append("\n");
        }
        messages.add(new UserMessage("参考文章:\n" + articlesText + "\n\n主题: " + topic + "\n\n请撰写一篇关于该主题的文章。"));
        return call(messages);
    }

    /**
     * Sends the messages to the chat client and returns the content of its reply.
     *
     * <p>A failed request is retried up to {@value #MAX_CALL_ATTEMPTS} attempts in total. The
     * first retry happens immediately; every further retry waits {@value #RETRY_STEP_SECONDS}
     * seconds longer than the previous one.
     *
     * @param messages messages to send
     * @return content of the model's reply
     * @throws RuntimeException      if every attempt failed; the last failure is the cause
     * @throws IllegalStateException if the thread is interrupted while waiting for a retry
     */
    public String call(List<Message> messages) {
        RuntimeException lastFailure = null;
        for (int attempt = 1; attempt <= MAX_CALL_ATTEMPTS; attempt++) {
            try {
                ChatResponse response = chatClient.call(new Prompt(messages));
                return response.getResult().getOutput().getContent();
            } catch (RuntimeException e) {
                lastFailure = e;
                System.out.println("请求出错: " + e.getMessage());
            }
            if (attempt < MAX_CALL_ATTEMPTS) {
                pauseBeforeRetry(attempt);
            }
        }
        // Thrown outside of any try block so that it really terminates the retry loop.
        throw new RuntimeException("失败" + MAX_CALL_ATTEMPTS + "次,停止调用", lastFailure);
    }

    /**
     * Announces and performs the wait that precedes a retry.
     *
     * @param failedAttempts number of consecutive failed attempts so far (at least 1)
     */
    private void pauseBeforeRetry(int failedAttempts) {
        long delaySeconds = (failedAttempts - 1) * RETRY_STEP_SECONDS;
        System.out.println("等待 " + delaySeconds + "s 重试");
        try {
            Thread.sleep(delaySeconds * 1000L);
        } catch (InterruptedException e) {
            // Preserve the interrupt status and stop retrying instead of ignoring the request.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("等待重试时被中断", e);
        }
    }

    /**
     * Asks the model to edit an article in order to improve its quality.
     *
     * @param article article to edit
     * @return the edited article
     */
    public String editArticle(String article) {
        List<Message> messages = new ArrayList<>();
        messages.add(new SystemMessage("你是一位世界级的编辑。根据以下的文章,进行编辑以提高其质量。"));
        messages.add(new UserMessage("请编辑以下文章以提高其质量:\n" + article));
        return call(messages);
    }

    /**
     * Asks the model for search terms suitable for researching the topic.
     *
     * @param topic topic to research
     * @return the search terms from the model's reply, trimmed and without empty entries
     */
    public List<String> getSearchTerms(String topic) {
        List<Message> messages = new ArrayList<>();
        messages.add(new SystemMessage("你是一位世界级的记者。生成一个包含5个搜索词的列表,用于研究和撰写关于该主题的文章。"));
        messages.add(new UserMessage("主题: " + topic + "\n\n请提供一个与'" + topic + "'相关的5个搜索词的列表,用于研究和撰写文章。以逗号分隔的Java可解析列表形式回复。"));
        return parseListReply(call(messages));
    }

    /**
     * Asks the model which of the search results are worth reading and returns their URLs.
     *
     * <p>Entries of the model's reply that are not a valid result number, as well as results
     * without a string {@code link}, are skipped.
     *
     * @param searchResults search results, each expected to carry its URL under {@code link}
     * @return URLs of the selected results; empty when there are no search results
     */
    public List<String> selectRelevantUrls(List<Map<String, Object>> searchResults) {
        List<String> relevantUrls = new ArrayList<>();
        if (searchResults == null || searchResults.isEmpty()) {
            // Nothing to choose from, so there is no point in asking the model.
            return relevantUrls;
        }
        List<Message> messages = new ArrayList<>();
        messages.add(new SystemMessage("你是一位记者助手。从给定的搜索结果中,选择出看起来最相关和信息丰富的 URL,用于撰写关于该主题的文章。"));
        StringBuilder searchResultsText = new StringBuilder();
        for (int i = 0; i < searchResults.size(); i++) {
            searchResultsText.append(i + 1).append(". ").append(searchResults.get(i).get("link")).append("\n");
        }
        messages.add(new UserMessage("搜索结果:\n" + searchResultsText + "\n\n请选择看起来最相关和信息丰富的 URL 的编号,用于撰写关于该主题的文章。以逗号分隔的 Java 可解析列表形式回复(如 [1,2,4])。"));

        for (String token : parseListReply(call(messages))) {
            int index;
            try {
                index = Integer.parseInt(token) - 1;
            } catch (NumberFormatException ignored) {
                // The model's reply is free text; tolerate entries that are not numbers.
                System.out.println("忽略无法解析的编号: " + token);
                continue;
            }
            if (index < 0 || index >= searchResults.size()) {
                continue;
            }
            Object link = searchResults.get(index).get("link");
            if (link instanceof String) {
                relevantUrls.add((String) link);
            }
        }
        return relevantUrls;
    }

    /**
     * Splits a model reply such as {@code ["a", "b"]} or {@code [1,2,4]} into its items.
     * Square brackets and double quotes are removed, items are separated by ASCII or
     * full-width commas, surrounding whitespace is trimmed and empty items are dropped.
     */
    private static List<String> parseListReply(String reply) {
        if (reply == null) {
            return new ArrayList<>();
        }
        String stripped = reply.replace("[", "").replace("]", "").replace("\"", "");
        return Arrays.stream(stripped.split("[,,]"))
                .map(String::trim)
                .filter(item -> !item.isEmpty())
                .collect(Collectors.toList());
    }

    /**
     * Queries the search API for a search term.
     *
     * @param searchTerm search term
     * @return the entries of the response's {@code organic} list; empty when the response has
     *         no body or no such list
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public List<Map<String, Object>> getSearchResults(String searchTerm) {
        // Let the message converter serialize the body so that quotes, backslashes and other
        // special characters in the search term are escaped correctly.
        Map<String, Object> body = new LinkedHashMap<>();
        body.put("q", searchTerm);
        body.put("hl", "en");
        body.put("num", 10);
        HttpEntity<Map<String, Object>> entity = new HttpEntity<>(body, serpApiHeader);

        ResponseEntity<Map> response = restTemplate.exchange(
                SEARCH_ENDPOINT,
                HttpMethod.POST,
                entity,
                Map.class);

        Map responseBody = response.getBody();
        Object organic = responseBody == null ? null : responseBody.get("organic");
        if (!(organic instanceof List)) {
            return new ArrayList<>();
        }
        return (List<Map<String, Object>>) organic;
    }
}
@abel533
Copy link
Author

abel533 commented Apr 10, 2024

@qinfengge
Copy link

大佬太🐂了,搜spring ai 看到大佬写的文章。拜服🫡

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment