blob: 03cd24c0eb5230891513ca84ceeb4081ab865c9a [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use jni::objects::{JByteArray, JClass};
use jni::sys::{jboolean, jbyteArray, jint, jlong, JNI_TRUE};
use jni::JNIEnv;
use lazy_static::lazy_static;
use rand::Rng;
use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha20Rng;
use spin::Mutex;
use ukey2_connections::{
D2DConnectionContextV1, D2DHandshakeContext, DecodeError, DeserializeError, HandleMessageError,
HandshakeError, HandshakeImplementation, InitiatorD2DHandshakeContext,
ServerD2DHandshakeContext,
};
cfg_if::cfg_if! {
if #[cfg(feature = "rustcrypto")] {
use crypto_provider_rustcrypto::RustCrypto as CryptoProvider;
} else {
use crypto_provider_openssl::Openssl as CryptoProvider;
}
}
// Handle management
type D2DBox = Box<dyn D2DHandshakeContext>;
type ConnectionBox = Box<D2DConnectionContextV1>;
lazy_static! {
static ref HANDLE_MAPPING: Mutex<HashMap<u64, D2DBox>> = Mutex::new(HashMap::new());
static ref CONNECTION_HANDLE_MAPPING: Mutex<HashMap<u64, ConnectionBox>> =
Mutex::new(HashMap::new());
static ref RNG: Mutex<ChaCha20Rng> = Mutex::new(ChaCha20Rng::from_entropy());
}
fn generate_handle() -> u64 {
RNG.lock().gen()
}
pub(crate) fn insert_handshake_handle(item: D2DBox) -> u64 {
let handle = generate_handle();
HANDLE_MAPPING.lock().insert(handle, item);
handle
}
pub(crate) fn insert_conn_handle(item: ConnectionBox) -> u64 {
let handle = generate_handle();
CONNECTION_HANDLE_MAPPING.lock().insert(handle, item);
handle
}
#[derive(Debug)]
enum JniError {
BadHandle,
DecodeError(DecodeError),
HandleMessageError(HandleMessageError),
HandshakeError(HandshakeError),
}
// D2DHandshakeContext
#[no_mangle]
pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_is_1handshake_1complete(
mut env: JNIEnv,
_: JClass,
context_handle: jlong,
) -> jboolean {
let mut is_complete = false;
if let Some(ctx) = HANDLE_MAPPING.lock().get(&(context_handle as u64)) {
is_complete = ctx.is_handshake_complete();
} else {
env.throw_new(
"com/google/security/cryptauth/lib/securegcm/BadHandleException",
"",
)
.expect("failed to find error class");
}
is_complete as jboolean
}
#[no_mangle]
pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_create_1context(
_: JNIEnv,
_: JClass,
is_client: jboolean,
) -> jlong {
if is_client == JNI_TRUE {
let client_obj = Box::new(InitiatorD2DHandshakeContext::<CryptoProvider>::new(
HandshakeImplementation::PublicKeyInProtobuf,
));
insert_handshake_handle(client_obj) as jlong
} else {
let server_obj = Box::new(ServerD2DHandshakeContext::<CryptoProvider>::new(
HandshakeImplementation::PublicKeyInProtobuf,
));
insert_handshake_handle(server_obj) as jlong
}
}
#[no_mangle]
pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_get_1next_1handshake_1message(
mut env: JNIEnv,
_: JClass,
context_handle: jlong,
) -> jbyteArray {
let empty_arr = env.new_byte_array(0).unwrap();
let next_message = if let Some(ctx) = HANDLE_MAPPING.lock().get(&(context_handle as u64)) {
ctx.get_next_handshake_message()
} else {
env.throw_new(
"com/google/security/cryptauth/lib/securegcm/BadHandleException",
"",
)
.expect("failed to find error class");
None
};
// TODO error handling
if let Some(message) = next_message {
env.byte_array_from_slice(message.as_slice()).unwrap()
} else {
empty_arr
}
.into_raw()
}
#[no_mangle]
#[allow(clippy::not_unsafe_ptr_arg_deref)]
/// Safety: We know the message pointer is safe as it is coming directly from the JVM.
pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_parse_1handshake_1message(
mut env: JNIEnv,
_: JClass,
context_handle: jlong,
message: jbyteArray,
) {
let rust_buffer = env
.convert_byte_array(unsafe { JByteArray::from_raw(message) })
.unwrap();
let result = if let Some(ctx) = HANDLE_MAPPING.lock().get_mut(&(context_handle as u64)) {
ctx.handle_handshake_message(rust_buffer.as_slice())
.map_err(JniError::HandleMessageError)
} else {
env.throw_new(
"com/google/security/cryptauth/lib/securegcm/BadHandleException",
"",
)
.expect("failed to find error class");
Err(JniError::BadHandle)
};
if let Err(e) = result {
if !env.exception_check().unwrap() {
env.throw_new(
"com/google/security/cryptauth/lib/securegcm/HandshakeException",
match e {
JniError::BadHandle => "Bad handle",
JniError::DecodeError(_) => "Unable to decode message",
JniError::HandleMessageError(_) => "Unable to handle message",
JniError::HandshakeError(_) => "Handshake incomplete",
},
)
.expect("failed to find error class");
}
}
}
#[no_mangle]
pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_get_1verification_1string(
mut env: JNIEnv,
_: JClass,
context_handle: jlong,
length: jint,
) -> jbyteArray {
let empty_array = env.new_byte_array(0).unwrap();
let result = if let Some(ctx) = HANDLE_MAPPING.lock().get_mut(&(context_handle as u64)) {
ctx.to_completed_handshake()
.map_err(|_| JniError::HandshakeError(HandshakeError::HandshakeNotComplete))
.map(|h| {
h.auth_string::<CryptoProvider>()
.derive_vec(length as usize)
.unwrap()
})
} else {
env.throw_new(
"com/google/security/cryptauth/lib/securegcm/BadHandleException",
"",
)
.expect("failed to find error class");
Err(JniError::BadHandle)
};
if let Err(e) = result {
if !env.exception_check().unwrap() {
env.throw_new(
"com/google/security/cryptauth/lib/securegcm/HandshakeException",
match e {
JniError::BadHandle => "Bad handle",
JniError::DecodeError(_) => "Unable to decode message",
JniError::HandleMessageError(_) => "Unable to handle message",
JniError::HandshakeError(_) => "Handshake incomplete",
},
)
.expect("failed to find error class");
}
empty_array
} else {
let ret_vec = result.unwrap();
env.byte_array_from_slice(&ret_vec).unwrap()
}
.into_raw()
}
#[no_mangle]
pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_to_1connection_1context(
mut env: JNIEnv,
_: JClass,
context_handle: jlong,
) -> jlong {
let conn_context = if let Some(ctx) = HANDLE_MAPPING.lock().get_mut(&(context_handle as u64)) {
ctx.to_connection_context()
.map_err(JniError::HandshakeError)
} else {
Err(JniError::BadHandle)
};
if let Err(error) = conn_context {
env.throw_new(
"com/google/security/cryptauth/lib/securegcm/HandshakeException",
match error {
JniError::BadHandle => "Bad context handle",
JniError::HandshakeError(_) => "Handshake not complete",
JniError::DecodeError(_) | JniError::HandleMessageError(_) => "Unknown exception",
},
)
.expect("failed to find error class");
return -1;
} else {
HANDLE_MAPPING.lock().remove(&(context_handle as u64));
}
insert_conn_handle(Box::new(conn_context.unwrap())) as jlong
}
// D2DConnectionContextV1
#[no_mangle]
#[allow(clippy::not_unsafe_ptr_arg_deref)]
/// Safety: We know the payload and associated_data pointers are safe as they are coming directly
/// from the JVM.
pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_encode_1message_1to_1peer(
mut env: JNIEnv,
_: JClass,
context_handle: jlong,
payload: jbyteArray,
associated_data: jbyteArray,
) -> jbyteArray {
// We create the empty array here so we don't run into issues requesting a new byte array from
// the JNI env while an exception is being thrown.
let empty_array = env.new_byte_array(0).unwrap();
let result = if let Some(ctx) = CONNECTION_HANDLE_MAPPING
.lock()
.get_mut(&(context_handle as u64))
{
Ok(ctx.encode_message_to_peer::<CryptoProvider, _>(
env.convert_byte_array(unsafe { JByteArray::from_raw(payload) })
.unwrap()
.as_slice(),
if associated_data.is_null() {
None
} else {
Some(
env.convert_byte_array(unsafe { JByteArray::from_raw(associated_data) })
.unwrap(),
)
},
))
} else {
Err(JniError::BadHandle)
};
if let Ok(ret_vec) = result {
env.byte_array_from_slice(ret_vec.as_slice())
.expect("unable to create jByteArray")
} else {
env.throw_new(
"com/google/security/cryptauth/lib/securegcm/BadHandleException",
"",
)
.expect("failed to find error class");
empty_array
}
.into_raw()
}
#[no_mangle]
#[allow(clippy::not_unsafe_ptr_arg_deref)]
/// Safety: We know the message and associated_data pointers are safe as they are coming directly
/// from the JVM.
pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_decode_1message_1from_1peer(
mut env: JNIEnv,
_: JClass,
context_handle: jlong,
message: jbyteArray,
associated_data: jbyteArray,
) -> jbyteArray {
let empty_array = env.new_byte_array(0).unwrap();
let result = if let Some(ctx) = CONNECTION_HANDLE_MAPPING
.lock()
.get_mut(&(context_handle as u64))
{
ctx.decode_message_from_peer::<CryptoProvider, _>(
env.convert_byte_array(unsafe { JByteArray::from_raw(message) })
.unwrap()
.as_slice(),
if associated_data.is_null() {
None
} else {
Some(
env.convert_byte_array(unsafe { JByteArray::from_raw(associated_data) })
.unwrap(),
)
},
)
.map_err(JniError::DecodeError)
} else {
Err(JniError::BadHandle)
};
if let Ok(message) = result {
env.byte_array_from_slice(message.as_slice())
.expect("unable to create jByteArray")
} else {
env.throw_new(
"com/google/security/cryptauth/lib/securegcm/CryptoException",
match result.unwrap_err() {
JniError::BadHandle => "Bad context handle",
JniError::DecodeError(e) => match e {
DecodeError::BadData => "Bad data",
DecodeError::BadSequenceNumber => "Bad sequence number",
},
// None of these should ever occur in this case.
JniError::HandleMessageError(_) | JniError::HandshakeError(_) => "Unknown error",
},
)
.expect("failed to find exception class");
empty_array
}
.into_raw()
}
#[no_mangle]
pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1sequence_1number_1for_1encoding(
mut env: JNIEnv,
_: JClass,
context_handle: jlong,
) -> jint {
if let Some(ctx) = CONNECTION_HANDLE_MAPPING
.lock()
.get(&(context_handle as u64))
{
ctx.get_sequence_number_for_encoding() as jint
} else {
env.throw_new(
"com/google/security/cryptauth/lib/securegcm/BadHandleException",
"",
)
.expect("failed to find error class");
-1 as jint
}
}
#[no_mangle]
pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1sequence_1number_1for_1decoding(
mut env: JNIEnv,
_: JClass,
context_handle: jlong,
) -> jint {
if let Some(ctx) = CONNECTION_HANDLE_MAPPING
.lock()
.get(&(context_handle as u64))
{
ctx.get_sequence_number_for_decoding() as jint
} else {
env.throw_new(
"com/google/security/cryptauth/lib/securegcm/BadHandleException",
"",
)
.expect("failed to find error class");
-1 as jint
}
}
#[no_mangle]
pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_save_1session(
mut env: JNIEnv,
_: JClass,
context_handle: jlong,
) -> jbyteArray {
let empty_array = env.new_byte_array(0).unwrap();
if let Some(ctx) = CONNECTION_HANDLE_MAPPING
.lock()
.get(&(context_handle as u64))
{
env.byte_array_from_slice(ctx.save_session().as_slice())
.expect("unable to save session")
} else {
env.throw_new(
"com/google/security/cryptauth/lib/securegcm/BadHandleException",
"",
)
.expect("failed to find error class");
empty_array
}
.into_raw()
}
#[no_mangle]
#[allow(clippy::not_unsafe_ptr_arg_deref)]
/// Safety: We know the session_info pointer is safe because it is coming directly from the JVM.
pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_from_1saved_1session(
mut env: JNIEnv,
_: JClass,
session_info: jbyteArray,
) -> jlong {
let session_info_rust = env
.convert_byte_array(unsafe { JByteArray::from_raw(session_info) })
.expect("bad session_info data");
let ctx = D2DConnectionContextV1::from_saved_session(session_info_rust.as_slice());
if ctx.is_err() {
env.throw_new(
"com/google/security/cryptauth/lib/securegcm/SessionRestoreException",
match ctx.err().unwrap() {
DeserializeError::BadDataLength => "DeserializeError: bad session_info length",
DeserializeError::BadProtocolVersion => "DeserializeError: bad protocol version",
DeserializeError::BadData => "DeserializeError: bad data",
},
)
.expect("failed to find exception class");
return -1;
}
let final_ctx = ctx.ok().unwrap();
let conn_context_final = Box::new(final_ctx);
insert_conn_handle(conn_context_final) as jlong
}
#[no_mangle]
pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1session_1unique(
mut env: JNIEnv,
_: JClass,
context_handle: jlong,
) -> jbyteArray {
let empty_array = env.new_byte_array(0).unwrap();
if let Some(ctx) = CONNECTION_HANDLE_MAPPING
.lock()
.get(&(context_handle as u64))
{
env.byte_array_from_slice(ctx.get_session_unique::<CryptoProvider>().as_slice())
.expect("unable to get unique session id")
} else {
env.throw_new(
"com/google/security/cryptauth/lib/securegcm/BadHandleException",
"",
)
.expect("failed to find error class");
empty_array
}
.into_raw()
}