Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use rayon when fetching pypi mapping #3299

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions crates/pypi_mapping/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ pixi_config = { workspace = true }
pixi_consts = { workspace = true }
rattler_conda_types = { workspace = true }
rattler_digest = { workspace = true }
rayon = "1.10.0"
reqwest = { workspace = true, features = ["json"] }
reqwest-middleware = { workspace = true }
reqwest-retry = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true }
tracing.workspace = true
url = { workspace = true }
uv-configuration = { workspace = true }
3 changes: 1 addition & 2 deletions crates/pypi_mapping/src/custom_pypi_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ pub async fn amend_pypi_purls(
client,
&packages_for_prefix_mapping,
reporter,
)
.await?;
)?;
let compressed_mapping =
prefix_pypi_name_mapping::conda_pypi_name_compressed_mapping(client).await?;

Expand Down
103 changes: 60 additions & 43 deletions crates/pypi_mapping/src/prefix_pypi_name_mapping.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
use std::{
collections::{BTreeSet, HashMap},
sync::Arc,
cell::RefCell, collections::{BTreeSet, HashMap}, sync::{Arc, LazyLock, Mutex}
};
use rayon::prelude::*;

use futures::{stream::FuturesUnordered, StreamExt};
use itertools::Itertools;
use miette::{IntoDiagnostic, WrapErr};
use rattler_conda_types::{PackageUrl, RepoDataRecord};
use rattler_digest::Sha256Hash;
use reqwest::StatusCode;
use reqwest_middleware::ClientWithMiddleware;
use serde::{Deserialize, Serialize};
use tokio::sync::Semaphore;
use tokio::runtime::Runtime;
use url::Url;
use uv_configuration::RAYON_INITIALIZE;

use super::{
build_pypi_purl_from_package_record, custom_pypi_mapping, is_conda_forge_record, PurlSource,
Reporter,
};

thread_local! {
static TOKIO_RT: RefCell<Option<Runtime>> = RefCell::new(None);
}

const STORAGE_URL: &str = "https://conda-mapping.prefix.dev";
const HASH_DIR: &str = "hash-v0";
const COMPRESSED_MAPPING: &str =
Expand Down Expand Up @@ -63,11 +67,15 @@ async fn try_fetch_single_mapping(
}

/// Downloads and caches the conda-forge conda-to-pypi name mapping.
pub async fn conda_pypi_name_mapping<'r>(
pub fn conda_pypi_name_mapping<'r>(
client: &ClientWithMiddleware,
conda_packages: impl IntoIterator<Item = &'r RepoDataRecord>,
reporter: Option<Arc<dyn Reporter>>,
) -> miette::Result<HashMap<Sha256Hash, Package>> {
// Force the initialization of the rayon thread pool to avoid implicit creation
// by the Installer.
LazyLock::force(&RAYON_INITIALIZE);

let filtered_packages = conda_packages
.into_iter()
// because we later skip adding purls for packages
Expand All @@ -85,63 +93,72 @@ pub async fn conda_pypi_name_mapping<'r>(
.collect_vec();

let total_records = filtered_packages.len();
let mut pending_futures = FuturesUnordered::new();
let concurrency_limit = Arc::new(Semaphore::new(100));
for (record, hash) in filtered_packages {
let result_map = Arc::new(Mutex::new(HashMap::with_capacity(total_records)));
let error = Arc::new(Mutex::new(None));

tracing::info!("Downloading conda-pypi mapping for {} packages", total_records);
filtered_packages.par_iter().for_each(|(record, hash)| {
// Check if we've already encountered an error
if error.lock().unwrap().is_some() {
return;
}

if let Some(reporter) = &reporter {
reporter.download_started(record, total_records);
}

let client = client.clone();
let reporter = reporter.clone();
let concurrency_limit = concurrency_limit.clone();

// Create a future that fetches the mapping for the record's hash concurrently
// with the rest of the requests.
pending_futures.push(async move {
// Acquire a permit to limit the number of concurrent requests
let _permit = concurrency_limit
.acquire_owned()
.await
.expect("semaphore error");

// Fetch the mapping by the hash of the record.
let result = try_fetch_single_mapping(&client, &hash).await;

// Report the result to the reporter
if let Some(reporter) = reporter {
match &result {
Ok(_) => reporter.download_finished(record, total_records),
Err(_) => reporter.download_failed(record, total_records),
}
}

match result {
Ok(Some(package)) => Ok(Some((hash, package))),
Ok(None) => Ok(None),
Err(e) => Err(e),
// Get or create the thread-local Tokio runtime
let result = TOKIO_RT.with(|rt| {
let mut rt_ref = rt.borrow_mut();
if rt_ref.is_none() {
*rt_ref = Some(Runtime::new().expect("Failed to create Tokio runtime"));
}

// Execute the async function within the Tokio runtime
rt_ref.as_ref().unwrap().block_on(try_fetch_single_mapping(&client, hash))
});
}

let mut result_map = HashMap::with_capacity(total_records);
while let Some(result) = pending_futures.next().await {
// Report the result to the reporter
if let Some(reporter) = &reporter {
match &result {
Ok(_) => reporter.download_finished(record, total_records),
Err(_) => reporter.download_failed(record, total_records),
}
}

match result {
Ok(Some((hash, package))) => {
Ok(Some(package)) => {
// Add the mapping to the result hashmap
result_map.insert(hash, package);
let mut map = result_map.lock().unwrap();
map.insert(*hash, package);
}
Ok(None) => {
// If no mapping was found, do nothing.
}
Err(e) => {
// If an error occurred, bail out,.
return Err(e);
// If an error occurred, store it
let mut err = error.lock().unwrap();
*err = Some(e);
}
}
});

// Check if any errors occurred
let err = error.lock().unwrap();
if let Some(e) = err.as_ref() {
// tracing::error!("Failed to download conda-pypi mapping: {:?}", e);
miette::bail!("Failed to download conda-pypi mapping: {:?}", e);
}

Ok(result_map)
// Convert Arc<Mutex<HashMap>> back to HashMap
let result = Arc::try_unwrap(result_map)
.expect("There should be no other references to result_map")
.into_inner()
.expect("Mutex should not be poisoned");

Ok(result)
}

/// Downloads and caches prefix.dev conda-pypi mapping.
Expand All @@ -162,7 +179,7 @@ pub async fn amend_pypi_purls(
) -> miette::Result<()> {
let conda_packages = conda_packages.into_iter().collect_vec();
let conda_mapping =
conda_pypi_name_mapping(client, conda_packages.iter().map(|p| *p as &_), reporter).await?;
conda_pypi_name_mapping(client, conda_packages.iter().map(|p| *p as &_), reporter)?;
let compressed_mapping = conda_pypi_name_compressed_mapping(client).await?;

for record in conda_packages {
Expand Down
Loading