blob: 50875f8f0de44f3458e95923a7119075e3738820 [file] [log] [blame]
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
borrow::Borrow,
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
hash::{BuildHasher, Hash},
};
/// The value(s) stored under one key when two maps are zipped together by key.
///
/// [`zip_hash_map`] and [`zip_btree_map`] produce `Left` for a key only the first map has,
/// `Right` for a key only the second map has, and `Both` for a key present in each.
///
/// The derived ordering follows variant declaration order (`Left < Right < Both`), then compares
/// the payloads field by field.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum ZipItem<L, R> {
    /// Only the first (left) map had the key.
    Left(L),
    /// Only the second (right) map had the key.
    Right(R),
    /// Both maps had the key: `(left value, right value)`.
    Both(L, R),
}
/// Zips two hash maps by key, yielding every key present in either map exactly once.
///
/// Each key is paired with a [`ZipItem`]:
/// - [`ZipItem::Left`] if only `map1` contains it,
/// - [`ZipItem::Right`] if only `map2` contains it,
/// - [`ZipItem::Both`] if both do.
///
/// The two maps may use different owned key types as long as both borrow as `K`.
///
/// The union of both key sets is first collected into a `HashSet<&K, S>` built with
/// `S::default()`. The output order is that set's (unspecified) iteration order.
pub fn zip_hash_map<'m, K, K1, K2, V1, V2, S>(
    map1: &'m HashMap<K1, V1, S>,
    map2: &'m HashMap<K2, V2, S>,
) -> impl Iterator<Item = (&'m K, ZipItem<&'m V1, &'m V2>)>
where
    K1: Borrow<K> + Hash + Eq,
    K2: Borrow<K> + Hash + Eq,
    K: Eq + Hash + 'm,
    S: BuildHasher + Default,
{
    let left_keys = map1.keys().map(|owned| owned.borrow());
    let right_keys = map2.keys().map(|owned| owned.borrow());
    // One set entry per distinct key, so keys shared by both maps are yielded only once.
    let all_keys: HashSet<&K, S> = left_keys.chain(right_keys).collect();

    all_keys.into_iter().map(move |key| {
        let item = match (map1.get(key), map2.get(key)) {
            (Some(left), Some(right)) => ZipItem::Both(left, right),
            (Some(left), None) => ZipItem::Left(left),
            (None, Some(right)) => ZipItem::Right(right),
            // Every key in `all_keys` came from one of the two maps.
            (None, None) => unreachable!("Iterator was built from a.keys.chain(b.keys)"),
        };
        (key, item)
    })
}
/// Zips two B-tree maps by key, yielding every key present in either map exactly once, in
/// ascending key order.
///
/// Each key is paired with a [`ZipItem`]:
/// - [`ZipItem::Left`] if only `map1` contains it,
/// - [`ZipItem::Right`] if only `map2` contains it,
/// - [`ZipItem::Both`] if both do.
///
/// The two maps may use different owned key types as long as both borrow as `K`.
pub fn zip_btree_map<'m, K, K1, K2, V1, V2>(
    map1: &'m BTreeMap<K1, V1>,
    map2: &'m BTreeMap<K2, V2>,
) -> impl Iterator<Item = (&'m K, ZipItem<&'m V1, &'m V2>)>
where
    K1: Borrow<K> + Ord,
    K2: Borrow<K> + Ord,
    K: Eq + Ord + 'm,
{
    // Sorted union of both key sets. The `Borrow` contract requires `K1`/`K2` to order the same
    // way as `K`, so walking this union in order meets each map's entries in that map's order.
    let union: BTreeSet<&K> = map1
        .keys()
        .map(|owned| owned.borrow())
        .chain(map2.keys().map(|owned| owned.borrow()))
        .collect();

    // One cursor per map. A cursor advances only when its head entry matches the current key.
    let mut left_entries = map1.iter().peekable();
    let mut right_entries = map2.iter().peekable();

    union.into_iter().map(move |key| {
        let left = left_entries
            .next_if(|(k, _)| (**k).borrow() == key)
            .map(|(_, value)| value);
        let right = right_entries
            .next_if(|(k, _)| (**k).borrow() == key)
            .map(|(_, value)| value);
        let item = match (left, right) {
            (Some(l), Some(r)) => ZipItem::Both(l, r),
            (Some(l), None) => ZipItem::Left(l),
            (None, Some(r)) => ZipItem::Right(r),
            // Every key in `union` came from one of the two maps.
            (None, None) => unreachable!("Iterator was built from a.keys.chain(b.keys)"),
        };
        (key, item)
    })
}
/// Lazily merges two iterators that are each sorted in descending order into one descending
/// sequence, like the merge step of merge sort.
///
/// When both heads compare equal, the element from `iter2` is yielded first. If either input is
/// not sorted in descending order, the output will not be sorted either.
pub fn merge_descending_iters<'t, T: Ord + 't>(
    iter1: impl IntoIterator<Item = &'t T> + 't,
    iter2: impl IntoIterator<Item = &'t T> + 't,
) -> impl Iterator<Item = &'t T> {
    let mut lhs = iter1.into_iter().peekable();
    let mut rhs = iter2.into_iter().peekable();
    std::iter::from_fn(move || {
        // `Option` orders `None` below every `Some(_)`, so an exhausted side never wins. Once
        // both sides are exhausted, `rhs.next()` returns `None` and the merge ends.
        let lhs_is_larger = lhs.peek() > rhs.peek();
        if lhs_is_larger {
            lhs.next()
        } else {
            rhs.next()
        }
    })
}
/// A value holding exactly one of two alternatives: `Left(L)` or `Right(R)`.
///
/// Derives the same basic traits as `ZipItem` (all conditional on `L` and `R`), so an `Either`
/// can be debug-printed, cloned and compared whenever its payloads can.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Either<L, R> {
    /// The first alternative.
    Left(L),
    /// The second alternative.
    Right(R),
}
impl<L, R> Either<L, R> {
pub fn into_iter<I>(self) -> impl Iterator<Item = I>
where
L: Iterator<Item = I>,
R: Iterator<Item = I>,
{
let (l_iter, r_iter) = match self {
Either::Left(l) => (Some(l), None),
Either::Right(r) => (None, Some(r)),
};
l_iter
.into_iter()
.flatten()
.chain(r_iter.into_iter().flatten())
}
}
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, HashMap};

    // Import everything through `super` so the tests keep working if this module is moved.
    use super::{merge_descending_iters, zip_btree_map, zip_hash_map, Either, ZipItem};

    #[test]
    fn test_zip_hash_map() {
        let english = HashMap::<_, _>::from_iter([(0, "zero"), (1, "one"), (2, "two")]);
        let spanish = HashMap::<_, _>::from_iter([(1, "uno"), (2, "dos"), (3, "tres")]);
        let zipped: HashMap<_, _> = zip_hash_map(&english, &spanish).collect();
        assert_eq!(
            zipped,
            HashMap::from_iter([
                (&0, ZipItem::Left(&"zero")),
                (&1, ZipItem::Both(&"one", &"uno")),
                (&2, ZipItem::Both(&"two", &"dos")),
                (&3, ZipItem::Right(&"tres")),
            ])
        )
    }

    #[test]
    fn test_zip_btree_map() {
        let english = BTreeMap::from_iter([(0, "zero"), (1, "one"), (2, "two")]);
        let spanish = BTreeMap::from_iter([(1, "uno"), (2, "dos"), (3, "tres")]);
        let zipped: BTreeMap<_, _> = zip_btree_map(&english, &spanish).collect();
        assert_eq!(
            zipped,
            BTreeMap::from_iter([
                (&0, ZipItem::Left(&"zero")),
                (&1, ZipItem::Both(&"one", &"uno")),
                (&2, ZipItem::Both(&"two", &"dos")),
                (&3, ZipItem::Right(&"tres")),
            ])
        )
    }

    #[test]
    fn test_zip_btree_map_with_empty_map() {
        let english = BTreeMap::from_iter([(0, "zero"), (1, "one")]);
        let empty = BTreeMap::<i32, &str>::new();
        let zipped: Vec<_> = zip_btree_map(&english, &empty).collect();
        assert_eq!(
            zipped,
            [(&0, ZipItem::Left(&"zero")), (&1, ZipItem::Left(&"one"))]
        );
        let zipped: Vec<_> = zip_btree_map(&empty, &english).collect();
        assert_eq!(
            zipped,
            [(&0, ZipItem::Right(&"zero")), (&1, ZipItem::Right(&"one"))]
        );
        assert_eq!(zip_btree_map(&empty, &empty).count(), 0);
    }

    /// The implementations for [`zip_hash_map`] and [`zip_btree_map`] should be the same, except
    /// for iteration order.
    #[derive_fuzztest::proptest]
    fn zip_results_equal(map1: Vec<(u8, u8)>, map2: Vec<(u8, u8)>) {
        let hashmap1 = HashMap::<_, _>::from_iter(map1.clone());
        let hashmap2 = HashMap::<_, _>::from_iter(map2.clone());
        let btreemap1 = BTreeMap::from_iter(map1);
        let btreemap2 = BTreeMap::from_iter(map2);
        // Compare sorted vectors rather than sets: a set would silently collapse a key that one
        // implementation yields twice. `zip_btree_map` already yields keys in ascending order.
        let mut zipped_hashmap: Vec<_> = zip_hash_map(&hashmap1, &hashmap2).collect();
        zipped_hashmap.sort();
        let zipped_btreemap: Vec<_> = zip_btree_map(&btreemap1, &btreemap2).collect();
        assert_eq!(zipped_hashmap, zipped_btreemap);
    }

    #[test]
    fn test_merge_iters() {
        let iter1 = [5, 4, 3, 2, 1];
        let iter2 = [8, 6, 4, 2];
        let merged = merge_descending_iters(&iter1, &iter2).collect::<Vec<_>>();
        assert_eq!([&8, &6, &5, &4, &4, &3, &2, &2, &1].as_slice(), &merged);
    }

    #[test]
    fn test_merge_iters_with_empty_input() {
        let empty: [i32; 0] = [];
        let values = [3, 2, 1];
        assert_eq!(
            merge_descending_iters(&values, &empty).collect::<Vec<_>>(),
            [&3, &2, &1]
        );
        assert_eq!(
            merge_descending_iters(&empty, &values).collect::<Vec<_>>(),
            [&3, &2, &1]
        );
        assert_eq!(merge_descending_iters(&empty, &empty).count(), 0);
    }

    #[test]
    fn test_either_into_iter() {
        let left: Either<std::vec::IntoIter<i32>, std::ops::Range<i32>> =
            Either::Left(vec![7, 3, 5].into_iter());
        assert_eq!(left.into_iter().collect::<Vec<_>>(), [7, 3, 5]);
        let right: Either<std::vec::IntoIter<i32>, std::ops::Range<i32>> = Either::Right(1..4);
        assert_eq!(right.into_iter().collect::<Vec<_>>(), [1, 2, 3]);
    }
}