// blob: da3bd57f76c59765df1d836c65b219c5804f522b [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use core::ops::{Deref, DerefMut};
use lock_adapter::std::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use lock_adapter::RwLock as _;
use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicU32, Ordering};
use crate::guard::{
ObjectReadGuardImpl, ObjectReadGuardMapping, ObjectReadWriteGuardImpl,
ObjectReadWriteGuardMapping,
};
use crate::{Handle, HandleNotPresentError};
// Type aliases that keep the shard code below readable.
// The map stored inside a shard: handle -> stored object.
type ShardMapType<T> = HashMap<Handle, T>;
// The lock wrapping a shard's map.
type ShardReadWriteLock<T> = RwLock<ShardMapType<T>>;
// Shared (read) guard over a shard's whole map.
type ShardReadGuard<'a, T> = RwLockReadGuard<'a, ShardMapType<T>>;
// Exclusive (write) guard over a shard's whole map.
type ShardReadWriteGuard<'a, T> = RwLockWriteGuard<'a, ShardMapType<T>>;
/// Crate-internal reasons why allocating into a particular shard failed.
pub(crate) enum ShardAllocationError<T, F>
where
    F: FnOnce() -> T,
{
    /// The slot for the requested handle is already taken. The unused
    /// object-provider is handed back so the caller can retry with a
    /// different handle-id.
    EntryOccupied(F),
    /// Performing the allocation would go past the maximum number of
    /// allowed allocations.
    ExceedsAllocationLimit,
}
/// One shard of a handle-map: a hash-map from handles to stored
/// objects, kept behind a read-write lock.
pub(crate) struct HandleMapShard<T>
where
    T: Send + Sync,
{
    /// The lock-protected map holding this shard's entries.
    data: RwLock<ShardMapType<T>>,
}
impl<T> Default for HandleMapShard<T>
where
    T: Send + Sync,
{
    /// Builds a shard whose map starts out empty.
    fn default() -> Self {
        let empty_map = HashMap::new();
        Self { data: RwLock::new(empty_map) }
    }
}
impl<T: Send + Sync> HandleMapShard<T> {
/// Looks up `handle` in this shard and, if present, returns a read guard
/// narrowed to the object stored under it.
///
/// Yields [`HandleNotPresentError`] when the shard has no entry for `handle`.
pub fn get(&self, handle: Handle) -> Result<ObjectReadGuardImpl<T>, HandleNotPresentError> {
    let shard_guard = self.data.read();
    if !shard_guard.deref().contains_key(&handle) {
        // Returning here drops the shard read guard.
        return Err(HandleNotPresentError);
    }
    // Narrow the whole-shard read guard down to the single requested object.
    let mapping = ObjectReadGuardMapping { handle, _marker: PhantomData };
    let guard = ShardReadGuard::<T>::map(shard_guard, mapping);
    Ok(ObjectReadGuardImpl { guard })
}
/// Gets a read-write guard on the entire shard map if an entry for the given
/// handle exists, but if not, yields [`HandleNotPresentError`].
///
/// On success, the entry for `handle` is present for as long as the returned
/// guard is held, because its presence is verified under the write lock itself.
fn get_read_write_guard_if_entry_exists(
    &self,
    handle: Handle,
) -> Result<ShardReadWriteGuard<T>, HandleNotPresentError> {
    // Fast path: reject absent handles while holding only the read lock.
    // The read guard is scoped to this block, so it is released before
    // the write lock is requested below.
    let present_under_read_lock = {
        let map_ref = self.data.read();
        map_ref.contains_key(&handle)
    };
    if !present_under_read_lock {
        return Err(HandleNotPresentError);
    }
    // The read lock has already been released, so nothing stops another
    // thread from removing the entry before we obtain the write lock.
    // Re-check under the write lock so that callers never receive a guard
    // for a handle which is no longer in the map.
    let write_guard = self.data.write();
    if write_guard.contains_key(&handle) {
        Ok(write_guard)
    } else {
        // Dropping the write guard here releases the lock; no write is needed.
        Err(HandleNotPresentError)
    }
}
/// Looks up `handle` in this shard and, if present, returns a read-write
/// guard narrowed to the object stored under it.
///
/// Propagates [`HandleNotPresentError`] from the shard-guard lookup when
/// there is no entry for `handle`.
pub fn get_mut(
    &self,
    handle: Handle,
) -> Result<ObjectReadWriteGuardImpl<T>, HandleNotPresentError> {
    let shard_guard = self.get_read_write_guard_if_entry_exists(handle)?;
    // Narrow the whole-shard write guard so that callers can only reach
    // the object stored under `handle`.
    let mapping = ObjectReadWriteGuardMapping { handle, _marker: PhantomData };
    let guard = ShardReadWriteGuard::<T>::map(shard_guard, mapping);
    Ok(ObjectReadWriteGuardImpl { guard })
}
pub fn deallocate(
&self,
handle: Handle,
outstanding_allocations_counter: &AtomicU32,
) -> Result<T, HandleNotPresentError> {
let mut map_read_write_guard = self.get_read_write_guard_if_entry_exists(handle)?;
// We don't need to worry about double-decrements, since the above call
// got us an upgradable read guard for our read, which means it's the only
// outstanding upgradeable guard on the shard. See `spin` documentation.
// Remove the pointed-to object from the map, and return it,
// releasing the lock when the guard goes out of scope.
let removed_object = map_read_write_guard.deref_mut().remove(&handle).unwrap();
// Decrement the allocations counter. Release ordering because we want
// to ensure that clearing the map entry never gets re-ordered to after when
// this counter gets decremented.
let _ = outstanding_allocations_counter.fetch_sub(1, Ordering::Release);
Ok(removed_object)
}
/// Attempts to store the object produced by `object_provider` under `handle`.
///
/// The shard's write lock is held for the whole call. If the slot for
/// `handle` is already taken, the provider is handed back inside
/// [`ShardAllocationError::EntryOccupied`]. Otherwise the shared
/// `outstanding_allocations_counter` is bumped atomically, failing with
/// [`ShardAllocationError::ExceedsAllocationLimit`] when it has already
/// reached `max_active_handles`. The provider is only invoked once the
/// counter bump has succeeded.
pub fn try_allocate<F>(
    &self,
    handle: Handle,
    object_provider: F,
    outstanding_allocations_counter: &AtomicU32,
    max_active_handles: u32,
) -> Result<(), ShardAllocationError<T, F>>
where
    F: FnOnce() -> T,
{
    let mut shard_map = self.data.write();
    let open_slot = match shard_map.entry(handle) {
        // This handle-id is already in use: give the provider back so
        // that the caller can try again with another handle.
        Occupied(_) => return Err(ShardAllocationError::EntryOccupied(object_provider)),
        Vacant(slot) => slot,
    };
    // The slot is free, but the global allocation count is still unchecked.
    // Reserve one allocation only while the count is below the limit.
    let reserve_one = |current_total: u32| {
        if current_total < max_active_handles {
            Some(current_total + 1)
        } else {
            None
        }
    };
    // Acquire on success so that the provider is not invoked before the
    // slot is guaranteed; on failure the ordering of surrounding operations
    // does not matter, hence Relaxed.
    let reservation = outstanding_allocations_counter.fetch_update(
        Ordering::Acquire,
        Ordering::Relaxed,
        reserve_one,
    );
    if reservation.is_err() {
        // Allocating would exceed the allowed number of allocations;
        // the write lock is released as we return.
        return Err(ShardAllocationError::ExceedsAllocationLimit);
    }
    // A slot is reserved, so it is now fine to build and store the object.
    open_slot.insert(object_provider());
    Ok(())
}
/// Reports how many entries this shard currently stores, read under
/// the shard's read lock.
/// Only suitable for single-threaded sections of tests.
#[cfg(test)]
pub fn len(&self) -> usize {
    self.data.read().len()
}
}