use super::*;

use anyhow::{anyhow, bail, Context, Result};
use fancy_regex::Regex;
use futures_util::{stream, StreamExt};
use http::header::CONTENT_TYPE;
use reqwest::Url;
use scraper::{Html, Selector};
use serde::Deserialize;
use serde_json::Value;
use std::sync::LazyLock;
use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
    time::Duration,
};
use tokio::io::AsyncWriteExt;
use tokio::sync::Semaphore;

pub const URL_LOADER: &str = "url";
pub const RECURSIVE_URL_LOADER: &str = "recursive_url";

pub const MEDIA_URL_EXTENSION: &str = "media_url";
pub const DEFAULT_EXTENSION: &str = "txt";

const MAX_CRAWLS: usize = 5;
const BREAK_ON_ERROR: bool = false;
const USER_AGENT: &str = "curl/8.6.0";

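/// Shared HTTP client with a 16-second request timeout; it is built lazily,
/// and any construction error is surfaced at each call site.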
static CLIENT: LazyLock<Result<reqwest::Client>> = LazyLock::new(|| {
    let builder = reqwest::ClientBuilder::new().timeout(Duration::from_secs(16));
    let client = builder.build()?;
    Ok(client)
});

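/// Built-in crawl presets: the first regex that matches the start URL
/// supplies its `CrawlOptions` (see `CrawlOptions::preset`).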
static PRESET: LazyLock<Vec<(Regex, CrawlOptions)>> = LazyLock::new(|| {
    vec![
        (
            Regex::new(r"github.com/([^/]+)/([^/]+)/tree/([^/]+)").unwrap(),
            CrawlOptions {
                exclude: vec!["changelog".into(), "changes".into(), "license".into()],
                ..Default::default()
            },
        ),
        (
            Regex::new(r"github.com/([^/]+)/([^/]+)/wiki").unwrap(),
            CrawlOptions {
                exclude: vec!["_history".into()],
                extract: Some("#wiki-body".into()),
                ..Default::default()
            },
        ),
    ]
});

static EXTENSION_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\.[^.]+$").unwrap());
static GITHUB_REPO_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)").unwrap());

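/// Fetches `url` with the shared client and returns the response body as text.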
pub async fn fetch(url: &str) -> Result<String> {
    let client = match *CLIENT {
        Ok(ref client) => client,
        Err(ref err) => bail!("{err}"),
    };
    let res = client.get(url).send().await?;
    let output = res.text().await?;
    Ok(output)
}

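/// Fetches `path` and prepares its contents for ingestion, returning
/// `(contents, extension)`: a configured `url` loader takes precedence; media
/// responses become `data:` URLs (only when `allow_media` is set); responses
/// with a matching extension loader are saved to a temp file and converted;
/// HTML is converted to Markdown; everything else is returned as text.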
pub async fn fetch_with_loaders(
    loaders: &HashMap<String, String>,
    path: &str,
    allow_media: bool,
) -> Result<(String, String)> {
    if let Some(loader_command) = loaders.get(URL_LOADER) {
        let contents = run_loader_command(path, URL_LOADER, loader_command)?;
        return Ok((contents, DEFAULT_EXTENSION.into()));
    }
    let client = match *CLIENT {
        Ok(ref client) => client,
        Err(ref err) => bail!("{err}"),
    };
    let mut res = client.get(path).send().await?;
    if !res.status().is_success() {
        bail!("Invalid status: {}", res.status());
    }
    let content_type = res
        .headers()
        .get(CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .map(|v| match v.split_once(';') {
            Some((mime, _)) => mime.trim(),
            None => v,
        })
        .map(|v| v.to_string())
        .unwrap_or_else(|| {
            format!(
                "_/{}",
                get_patch_extension(path).unwrap_or_else(|| DEFAULT_EXTENSION.into())
            )
        });
    let mut is_media = false;
    let extension = match content_type.as_str() {
        "application/pdf" => "pdf".into(),
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document" => "docx".into(),
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" => "xlsx".into(),
        "application/vnd.openxmlformats-officedocument.presentationml.presentation" => {
            "pptx".into()
        }
        "application/vnd.oasis.opendocument.text" => "odt".into(),
        "application/vnd.oasis.opendocument.spreadsheet" => "ods".into(),
        "application/vnd.oasis.opendocument.presentation" => "odp".into(),
        "application/rtf" => "rtf".into(),
        "text/javascript" => "js".into(),
        "text/html" => "html".into(),
        _ => content_type
            .rsplit_once('/')
            .map(|(first, last)| {
                if ["image", "video", "audio"].contains(&first) {
                    is_media = true;
                    MEDIA_URL_EXTENSION.into()
                } else {
                    last.to_lowercase()
                }
            })
            .unwrap_or_else(|| DEFAULT_EXTENSION.into()),
    };
    let result = if is_media {
        if !allow_media {
            bail!("Unexpected media type")
        }
        let image_bytes = res.bytes().await?;
        let image_base64 = base64_encode(&image_bytes);
        let contents = format!("data:{content_type};base64,{image_base64}");
        (contents, extension)
    } else {
        match loaders.get(&extension) {
            Some(loader_command) => {
                let save_path = temp_file("-download-", &format!(".{extension}"))
                    .display()
                    .to_string();
                let mut save_file = tokio::fs::File::create(&save_path).await?;
                let mut size = 0;
                while let Some(chunk) = res.chunk().await? {
                    size += chunk.len();
                    save_file.write_all(&chunk).await?;
                }
                let contents = if size == 0 {
                    println!("{}", warning_text(&format!("No content at '{path}'")));
                    String::new()
                } else {
                    run_loader_command(&save_path, &extension, loader_command)?
                };
                (contents, DEFAULT_EXTENSION.into())
            }
            None => {
                let contents = res.text().await?;
                if extension == "html" {
                    (html_to_md(&contents), "md".into())
                } else {
                    (contents, extension)
                }
            }
        }
    };
    Ok(result)
}

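/// Queries an OpenAI-compatible `GET {api_base}/models` endpoint and returns
/// the sorted model ids found in the `data` array.
///
/// Illustrative usage (the endpoint and key are placeholders):
/// ```ignore
/// let models = fetch_models("https://api.openai.com/v1", Some("sk-...")).await?;
/// ```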
pub async fn fetch_models(api_base: &str, api_key: Option<&str>) -> Result<Vec<String>> {
    let client = match *CLIENT {
        Ok(ref client) => client,
        Err(ref err) => bail!("{err}"),
    };
    let mut builder = client.get(format!("{}/models", api_base.trim_end_matches('/')));
    if let Some(api_key) = api_key {
        builder = builder.bearer_auth(api_key);
    }
    let res_body: Value = builder.send().await?.json().await?;
    let mut result: Vec<String> = res_body
        .get("data")
        .and_then(|v| v.as_array())
        .map(|v| {
            v.iter()
                .filter_map(|v| v.get("id").and_then(|v| v.as_str().map(|v| v.to_string())))
                .collect()
        })
        .unwrap_or_default();
    if result.is_empty() {
        bail!("No valid models")
    }
    result.sort_unstable();
    Ok(result)
}

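/// Options controlling a crawl: an optional CSS selector to extract text
/// from, link names to exclude, and a flag to suppress progress logging.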
#[derive(Debug, Clone, Default)]
pub struct CrawlOptions {
    extract: Option<String>,
    exclude: Vec<String>,
    no_log: bool,
}

impl CrawlOptions {
    pub fn preset(start_url: &str) -> CrawlOptions {
        for (re, options) in PRESET.iter() {
            if let Ok(true) = re.is_match(start_url) {
                return options.clone();
            }
        }
        CrawlOptions::default()
    }
}

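/// Crawls a website breadth-first starting from `start_url`, fetching up to
/// `MAX_CRAWLS` pages concurrently. GitHub `tree` URLs are expanded through
/// the GitHub API instead of being crawled link-by-link.
///
/// Illustrative usage (the URL is a placeholder):
/// ```ignore
/// let options = CrawlOptions::preset("https://example.com/docs/");
/// let pages = crawl_website("https://example.com/docs/", options).await?;
/// ```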
pub async fn crawl_website(start_url: &str, options: CrawlOptions) -> Result<Vec<Page>> {
    let start_url = Url::parse(start_url)?;
    let mut paths = vec![start_url.path().to_string()];
    let normalized_start_url = normalize_start_url(&start_url);
    if !options.no_log {
        println!(
            "Start crawling url={start_url} exclude={} extract={}",
            options.exclude.join(","),
            options.extract.as_deref().unwrap_or_default()
        );
    }

    if let Ok(true) = GITHUB_REPO_RE.is_match(start_url.as_str()) {
        paths = crawl_gh_tree(&start_url, &options.exclude)
            .await
            .with_context(|| "Failed to crawl github repo".to_string())?;
    }

    let semaphore = Arc::new(Semaphore::new(MAX_CRAWLS));
    let mut result_pages = Vec::new();

    let mut index = 0;
    while index < paths.len() {
        let batch = paths[index..std::cmp::min(index + MAX_CRAWLS, paths.len())].to_vec();

        let tasks: Vec<_> = batch
            .iter()
            .map(|path| {
                let options = options.clone();
                let permit = semaphore.clone().acquire_owned(); // acquire a permit for concurrency control
                let normalized_start_url = normalized_start_url.clone();
                let path = path.clone();

                async move {
                    let _permit = permit.await?;
                    let url = normalized_start_url
                        .join(&path)
                        .map_err(|_| anyhow!("Invalid crawl page at {}", path))?;
                    let mut page = crawl_page(&normalized_start_url, &path, options)
                        .await
                        .with_context(|| format!("Failed to crawl {}", url.as_str()))?;
                    page.0 = url.as_str().to_string();
                    Ok(page)
                }
            })
            .collect();

        let results = stream::iter(tasks)
            .buffer_unordered(MAX_CRAWLS)
            .collect::<Vec<_>>()
            .await;

        let mut new_paths = Vec::new();

        for res in results {
            match res {
                Ok((path, text, links)) => {
                    if !options.no_log {
                        println!("Crawled {path}");
                    }
                    if !text.is_empty() {
                        result_pages.push(Page { path, text });
                    }
                    for link in links {
                        if !paths.iter().any(|p| match_link(p, &link)) {
                            new_paths.push(link);
                        }
                    }
                }
                Err(err) => {
                    if BREAK_ON_ERROR {
                        return Err(err);
                    } else if !options.no_log {
                        println!("{}", error_text(&pretty_error(&err)));
                    }
                }
            }
        }
        paths.extend(new_paths);

        index += batch.len();
    }

    Ok(result_pages)
}

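/// A crawled page: its resolved URL and the extracted Markdown text.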
#[derive(Debug, Deserialize)]
pub struct Page {
    pub path: String,
    pub text: String,
}

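/// Expands a GitHub `tree` URL into the Markdown files it contains by
/// resolving the branch ref and walking the repository tree through the
/// GitHub REST API, returning `raw.githubusercontent.com` URLs.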
async fn crawl_gh_tree(start_url: &Url, exclude: &[String]) -> Result<Vec<String>> {
    let path_segs: Vec<&str> = start_url.path().split('/').collect();
    // need at least ["", owner, repo, "tree", branch]
    if path_segs.len() < 5 {
        bail!("Invalid gh tree {}", start_url.as_str());
    }
    let client = match *CLIENT {
        Ok(ref client) => client,
        Err(ref err) => bail!("{err}"),
    };
    let owner = path_segs[1];
    let repo = path_segs[2];
    let branch = path_segs[4];
    let root_path = path_segs[5..].join("/");

    let url = format!("https://api.github.com/repos/{owner}/{repo}/git/ref/heads/{branch}");

    let res_body: Value = client
        .get(&url)
        .header("User-Agent", USER_AGENT)
        .header("Accept", "application/vnd.github+json")
        .header("X-GitHub-Api-Version", "2022-11-28")
        .send()
        .await?
        .json()
        .await?;

    let sha = res_body["object"]["sha"]
        .as_str()
        .ok_or_else(|| anyhow!("Branch or tag not found"))?;

    let url = format!("https://api.github.com/repos/{owner}/{repo}/git/trees/{sha}?recursive=true");

    let res_body: Value = client
        .get(&url)
        .header("User-Agent", USER_AGENT)
        .header("Accept", "application/vnd.github+json")
        .header("X-GitHub-Api-Version", "2022-11-28")
        .send()
        .await?
        .json()
        .await?;
    let tree = res_body["tree"]
        .as_array()
        .ok_or_else(|| anyhow!("Invalid github repo tree"))?;
    let paths = tree
        .iter()
        .flat_map(|v| {
            let typ = v["type"].as_str()?;
            let path = v["path"].as_str()?;
            if typ == "blob"
                && (path.ends_with(".md") || path.ends_with(".MD"))
                && path.starts_with(&root_path)
                && !should_exclude_link(path, exclude)
            {
                Some(format!(
                    "https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
                ))
            } else {
                None
            }
        })
        .collect();

    Ok(paths)
}

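/// Fetches a single page, collects in-scope links from its `<a href>`
/// attributes, and converts the page (or just the `extract` selection) to
/// Markdown, returning `(path, text, links)`.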
async fn crawl_page(
    start_url: &Url,
    path: &str,
    options: CrawlOptions,
) -> Result<(String, String, Vec<String>)> {
    let client = match *CLIENT {
        Ok(ref client) => client,
        Err(ref err) => bail!("{err}"),
    };
    let location = start_url.join(path)?;
    let response = client
        .get(location.as_str())
        .header("User-Agent", USER_AGENT)
        .send()
        .await?;
    let body = response.text().await?;

    if let Ok(true) = GITHUB_REPO_RE.is_match(start_url.as_str()) {
        return Ok((path.to_string(), body, vec![]));
    }

    let mut links = HashSet::new();
    let document = Html::parse_document(&body);
    let selector = Selector::parse("a").map_err(|err| anyhow!("Invalid link selector, {}", err))?;

    for element in document.select(&selector) {
        if let Some(href) = element.value().attr("href") {
            let href = Url::parse(href).ok().or_else(|| location.join(href).ok());
            match href {
                None => continue,
                Some(href) => {
                    if href.as_str().starts_with(location.as_str())
                        && !should_exclude_link(href.path(), &options.exclude)
                    {
                        links.insert(href.path().to_string());
                    }
                }
            }
        }
    }

    let text = if let Some(selector) = &options.extract {
        let selector = Selector::parse(selector)
            .map_err(|err| anyhow!("Invalid extract selector, {}", err))?;
        document
            .select(&selector)
            .map(|v| html_to_md(&v.html()))
            .collect::<Vec<String>>()
            .join("\n\n")
    } else {
        html_to_md(&body)
    };

    Ok((path.to_string(), text, links.into_iter().collect()))
}

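/// Returns `true` for links that should be skipped: anything containing `#`,
/// or whose final path segment matches an `exclude` entry (compared with the
/// file extension when the entry has one, without it otherwise).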
fn should_exclude_link(link: &str, exclude: &[String]) -> bool {
    if link.contains("#") {
        return true;
    }
    let parts: Vec<&str> = link.trim_end_matches('/').split('/').collect();
    let name = parts.last().unwrap_or(&"").to_lowercase();

    for exclude_name in exclude {
        let cond = match EXTENSION_RE.is_match(exclude_name) {
            Ok(true) => exclude_name.to_lowercase() == name.to_lowercase(),
            _ => exclude_name.to_lowercase() == EXTENSION_RE.replace(&name, "").to_lowercase(),
        };
        if cond {
            return true;
        }
    }
    false
}

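/// Strips the query and fragment from the start URL and truncates its path
/// to the containing directory, so that crawled paths can be joined onto it.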
fn normalize_start_url(start_url: &Url) -> Url {
    let mut start_url = start_url.clone();
    start_url.set_query(None);
    start_url.set_fragment(None);
    let new_path = match start_url.path().rfind('/') {
        Some(last_slash_index) => start_url.path()[..last_slash_index + 1].to_string(),
        None => start_url.path().to_string(),
    };
    start_url.set_path(&new_path);
    start_url
}

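/// Treats a queued `path` and a discovered `link` as the same page when they
/// are equal, or when the link differs only by a trailing `/index.html` or
/// `/index.htm`.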
fn match_link(path: &str, link: &str) -> bool {
    path == link
        || path
            == link
                .trim_end_matches("/index.html")
                .trim_end_matches("/index.htm")
}