blob: bce397691540170a5ab64ae4edc1189364d8d65f [file] [log] [blame]
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Code generation for the `call_method!` series of proc-macros. This module is used by creating a
//! [`MethodCall`] instance and calling [`MethodCall::generate`]. The generated code will vary
//! for each macro based on the contained [`Receiver`] value. If there is a detected issue with the
//! [`MethodCall`] provided, then `generate` will return a [`CodegenError`] with information on
//! what is wrong. This error can be converted into a [`syn::Error`] for reporting any failures to
//! the user.
use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned, TokenStreamExt};
use syn::{parse_quote, spanned::Spanned, Expr, Ident, ItemStatic, LitStr, Type};
use crate::type_parser::{JavaType, MethodSig, NonArray, Primitive, ReturnType};
/// The errors that can be encountered during codegen. Used in [`CodegenError`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ErrorKind {
InvalidArgsLength { expected: usize, found: usize },
ConstructorRetValShouldBeVoid,
InvalidTypeSignature,
}
/// An error encountered during codegen with span information. Can be converted to [`syn::Error`].
/// using [`From`].
#[derive(Clone, Debug)]
pub struct CodegenError(pub Span, pub ErrorKind);
impl From<CodegenError> for syn::Error {
fn from(CodegenError(span, kind): CodegenError) -> syn::Error {
use ErrorKind::*;
match kind {
InvalidArgsLength { expected, found } => {
syn::Error::new(span, format!("The number of args does not match the type signature provided: expected={expected}, found={found}"))
}
ConstructorRetValShouldBeVoid => {
syn::Error::new(span, "Return type should be `void` (`V`) for constructor methods")
}
InvalidTypeSignature => syn::Error::new(span, "Failed to parse type signature"),
}
}
}
/// Codegen can fail with [`CodegenError`].
pub type CodegenResult<T> = Result<T, CodegenError>;
/// Describes a method that will be generated. Create one with [`MethodCall::new`] and generate
/// code with [`MethodCall::generate`]. This should be given AST nodes from the macro input so that
/// errors are associated properly to the input span.
pub struct MethodCall {
env: Expr,
method: MethodInfo,
receiver: Receiver,
arg_exprs: Vec<Expr>,
}
impl MethodCall {
/// Create a new MethodCall instance
pub fn new(env: Expr, method: MethodInfo, receiver: Receiver, arg_exprs: Vec<Expr>) -> Self {
Self {
env,
method,
receiver,
arg_exprs,
}
}
/// Generate code to call the described method.
pub fn generate(&self) -> CodegenResult<TokenStream> {
// Needs to be threaded manually to other methods since self-referential structs can't
// exist.
let sig = self.method.sig()?;
let args = self.generate_args(&sig)?;
let method_call = self
.receiver
.generate_call(&self.env, &self.method, &sig, &args)?;
// Wrap the generated code in a closure so that we can access the outer scope but any
// variables we define aren't accessible by the outer scope. There is small hygiene issue
// where the arg exprs have our `env` variable in scope. If this becomes an issue we can
// refactor these exprs to be passed as closure parameters instead.
Ok(quote! {
(|| {
#method_call
})()
})
}
/// Generate the `&[jni::jvalue]` arguments slice that will be passed to the `jni` method call.
/// This validates the argument count and types.
fn generate_args(&self, sig: &MethodSig<'_>) -> CodegenResult<Expr> {
// Safety: must check that arg count matches the signature
if self.arg_exprs.len() != sig.args.len() {
return Err(CodegenError(
self.method.sig.span(),
ErrorKind::InvalidArgsLength {
expected: sig.args.len(),
found: self.arg_exprs.len(),
},
));
}
// Create each `jvalue` expression
let type_expr_pairs = core::iter::zip(sig.args.iter().copied(), self.arg_exprs.iter());
let jvalues = type_expr_pairs.map(|(ty, expr)| generate_jvalue(ty, expr));
// Put the `jvalue` expressions in a slice.
Ok(parse_quote! {
&[#(#jvalues),*]
})
}
}
/// The receiver of the method call and the type of the method.
pub enum Receiver {
/// A constructor.
Constructor,
/// A static method.
Static,
/// An instance method. The `Expr` here is the `this` object.
Instance(Expr),
}
impl Receiver {
/// Generate the code that performs the JNI call.
fn generate_call(
&self,
env: &Expr,
method_info: &MethodInfo,
sig: &MethodSig<'_>,
args: &Expr,
) -> CodegenResult<TokenStream> {
// Constructors are void methods. Validate this fact.
if matches!(*self, Receiver::Constructor) && !sig.ret.is_void() {
return Err(CodegenError(
method_info.sig.span(),
ErrorKind::ConstructorRetValShouldBeVoid,
));
}
// The static item containing the `pourover::[Static]MethodDesc`.
let method_desc = self.generate_method_desc(method_info);
// The `jni::signature::ReturnType` that the `jni` crate uses to perform the correct native
// call.
let return_type = return_type_from_sig(sig.ret);
// A conversion expression to convert from `jni::object::JValueOwned` to the actual return
// type. We have this information from the parsed method signature whereas the `jni` crate
// only knows this at runtime.
let conversion = return_value_conversion_from_sig(sig.ret);
// This preamble is used to evaluate all the client-provided expressions outside of the
// `unsafe` block. This is the same for all receiver kinds.
let mut method_call = quote! {
#method_desc
let env: &mut ::jni::JNIEnv = #env;
let method_id = ::jni::descriptors::Desc::lookup(&METHOD_DESC, env)?;
let args: &[::jni::sys::jvalue] = #args;
};
// Generate the unsafe JNI call.
//
// Safety: `args` contains the arguments to this method. The type signature of this
// method is `#sig`.
//
// `args` must adhere to the following:
// - `args.len()` must match the number of arguments given in the type signature.
// - The union value of each arg in `args` must match the type specified in the type
// signature.
//
// These conditions are upheld by this proc macro and a compile error will be caused if
// they are broken. No user-provided code is executed within the `unsafe` block.
method_call.append_all(match self {
Self::Constructor => quote! {
unsafe {
env.new_object_unchecked(
METHOD_DESC.cls(),
method_id,
args,
)
}
},
Self::Static => quote! {
unsafe {
env.call_static_method_unchecked(
METHOD_DESC.cls(),
method_id,
#return_type,
args,
)
}#conversion
},
Self::Instance(this) => quote! {
let this_obj: &JObject = #this;
unsafe {
env.call_method_unchecked(
this_obj,
method_id,
#return_type,
args,
)
}#conversion
},
});
Ok(method_call)
}
fn generate_method_desc(&self, MethodInfo { cls, name, sig, .. }: &MethodInfo) -> ItemStatic {
match self {
Self::Constructor => parse_quote! {
static METHOD_DESC: ::pourover::desc::MethodDesc = (#cls).constructor(#sig);
},
Self::Static => parse_quote! {
static METHOD_DESC: ::pourover::desc::StaticMethodDesc = (#cls).static_method(#name, #sig);
},
Self::Instance(_) => parse_quote! {
static METHOD_DESC: ::pourover::desc::MethodDesc = (#cls).method(#name, #sig);
},
}
}
}
/// Information about the method being called
pub struct MethodInfo {
cls: Expr,
name: Expr,
sig: LitStr,
/// Derived from `sig.value()`. This string must be stored in the struct so that we can return
/// a `MethodSig` instance that references it from `MethodInfo::sig()`.
sig_str: String,
}
impl MethodInfo {
pub fn new(cls: Expr, name: Expr, sig: LitStr) -> Self {
let sig_str = sig.value();
Self {
cls,
name,
sig,
sig_str,
}
}
/// Parse the type signature from `sig`. Will return a [`CodegenError`] if the signature cannot
/// be parsed.
fn sig(&self) -> CodegenResult<MethodSig<'_>> {
MethodSig::try_from_str(&self.sig_str)
.ok_or_else(|| CodegenError(self.sig.span(), ErrorKind::InvalidTypeSignature))
}
}
/// Generate a `jni::sys::jvalue` instance given a Java type and a Rust value.
///
/// Safety: The generated `jvalue` must match the given type `ty`.
fn generate_jvalue(ty: JavaType<'_>, expr: &Expr) -> TokenStream {
// The `jvalue` field to inhabit
let union_field: Ident;
// The expected input type
let type_name: Type;
// Whether we need to call `JObject::as_raw()` on the input type
let needs_as_raw: bool;
// Fill the above values based the type signature.
match ty {
JavaType::Array { depth, ty } => {
union_field = parse_quote![l];
if let NonArray::Primitive(p) = ty {
if depth.get() == 1 {
let prim_type = prim_to_sys_type(p);
type_name = parse_quote![::jni::objects::JPrimitiveArray<'_, #prim_type>]
} else {
type_name = parse_quote![&::jni::objects::JObjectArray<'_>];
}
} else {
type_name = parse_quote![&::jni::objects::JObjectArray<'_>];
}
needs_as_raw = true;
}
JavaType::NonArray(NonArray::Object { cls }) => {
union_field = parse_quote![l];
type_name = match cls {
"java/lang/String" => parse_quote![&::jni::objects::JString<'_>],
"java/util/List" => parse_quote![&::jni::objects::JList<'_>],
"java/util/Map" => parse_quote![&::jni::objects::JMap<'_>],
_ => parse_quote![&::jni::objects::JObject<'_>],
};
needs_as_raw = true;
}
JavaType::NonArray(NonArray::Primitive(p)) => {
union_field = prim_to_union_field(p);
type_name = prim_to_sys_type(p);
needs_as_raw = false;
}
}
// The as_raw() tokens if required.
let as_raw = if needs_as_raw {
quote! { .as_raw() }
} else {
quote![]
};
// Create the `jvalue` expression. This uses `identity` to produce nice type error messages.
quote_spanned! { expr.span() =>
::jni::sys::jvalue {
#union_field: ::core::convert::identity::<#type_name>(#expr) #as_raw
}
}
}
/// Get a `::jni::signature::ReturnType` expression from a [`crate::type_parser::ReturnType`]. This
/// value is passed to the `jni` crate so that it knows which JNI method to call.
fn return_type_from_sig(ret: ReturnType<'_>) -> Expr {
let prim_type = |prim| parse_quote![::jni::signature::ReturnType::Primitive(::jni::signature::Primitive::#prim)];
use crate::type_parser::{JavaType::*, NonArray::*, Primitive::*};
match ret {
ReturnType::Void => prim_type(quote![Void]),
ReturnType::Returns(NonArray(Primitive(Boolean))) => prim_type(quote![Boolean]),
ReturnType::Returns(NonArray(Primitive(Byte))) => prim_type(quote![Byte]),
ReturnType::Returns(NonArray(Primitive(Char))) => prim_type(quote![Char]),
ReturnType::Returns(NonArray(Primitive(Double))) => prim_type(quote![Double]),
ReturnType::Returns(NonArray(Primitive(Float))) => prim_type(quote![Float]),
ReturnType::Returns(NonArray(Primitive(Int))) => prim_type(quote![Int]),
ReturnType::Returns(NonArray(Primitive(Long))) => prim_type(quote![Long]),
ReturnType::Returns(NonArray(Primitive(Short))) => prim_type(quote![Short]),
ReturnType::Returns(NonArray(Object { .. })) => {
parse_quote![::jni::signature::ReturnType::Object]
}
ReturnType::Returns(Array { .. }) => parse_quote![::jni::signature::ReturnType::Array],
}
}
/// A postfix call on a `jni::objects::JValueOwned` instance to convert it to the type specified by
/// `ret`. Since we have this information from the type signature we can perform this conversion
/// in the macro.
fn return_value_conversion_from_sig(ret: ReturnType<'_>) -> TokenStream {
use crate::type_parser::{JavaType::*, NonArray::*};
match ret {
ReturnType::Void => quote! { .and_then(::jni::objects::JValueOwned::v) },
ReturnType::Returns(NonArray(Primitive(p))) => {
let prim = prim_to_union_field(p);
quote! { .and_then(::jni::objects::JValueOwned::#prim) }
}
ReturnType::Returns(NonArray(Object { cls })) => {
let mut conversion = quote! { .and_then(::jni::objects::JValueOwned::l) };
match cls {
"java/lang/String" => {
conversion.append_all(quote! { .map(::jni::objects::JString::from) });
}
"java/util/List" => {
conversion.append_all(quote! { .map(::jni::objects::JList::from) });
}
"java/util/Map" => {
conversion.append_all(quote! { .map(::jni::objects::JMap::from) });
}
_ => {
// Already a JObject, so we are good here
}
}
conversion
}
ReturnType::Returns(Array {
depth,
ty: Primitive(p),
}) if depth.get() == 1 => {
let sys_type = prim_to_sys_type(p);
quote! {
.and_then(::jni::objects::JValueOwned::l)
.map(::jni::objects::JPrimitiveArray::<#sys_type>::from)
}
}
ReturnType::Returns(Array { .. }) => quote! {
.and_then(::jni::objects::JValueOwned::l)
.map(::jni::objects::JObjectArray::from)
},
}
}
/// From a [`Primitive`], this gets the `jni::sys::jvalue` union field name for that type. This is
/// also the `jni::objects::JValueGen` getter name.
fn prim_to_union_field(p: Primitive) -> Ident {
quote::format_ident!("{}", p.as_char().to_ascii_lowercase())
}
/// From a [`Primitive`], this gets the matching `jvalue::sys` type.
fn prim_to_sys_type(p: Primitive) -> Type {
match p {
Primitive::Boolean => parse_quote![::jni::sys::jboolean],
Primitive::Byte => parse_quote![::jni::sys::jbyte],
Primitive::Char => parse_quote![::jni::sys::jchar],
Primitive::Double => parse_quote![::jni::sys::jdouble],
Primitive::Float => parse_quote![::jni::sys::jfloat],
Primitive::Int => parse_quote![::jni::sys::jint],
Primitive::Long => parse_quote![::jni::sys::jlong],
Primitive::Short => parse_quote![::jni::sys::jshort],
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
use super::*;
use crate::test_util::contains_ident;
use quote::ToTokens;
use syn::parse_quote;
fn example_method_call() -> MethodCall {
MethodCall::new(
parse_quote![&mut env],
MethodInfo::new(
parse_quote![&FOO_CLS],
parse_quote!["example"],
parse_quote!["(II)I"],
),
Receiver::Instance(parse_quote![&foo]),
vec![parse_quote![123], parse_quote![2 + 3]],
)
}
#[test]
fn args_are_counted() {
let mut call = example_method_call();
call.arg_exprs.push(parse_quote![too_many]);
let CodegenError(_span, kind) = call.generate().unwrap_err();
assert_eq!(
ErrorKind::InvalidArgsLength {
expected: 2,
found: 3
},
kind
);
}
#[test]
fn constructor_return_type_is_void() {
let mut call = example_method_call();
call.receiver = Receiver::Constructor;
let CodegenError(_span, kind) = call.generate().unwrap_err();
assert_eq!(ErrorKind::ConstructorRetValShouldBeVoid, kind);
}
#[test]
fn invalid_type_sig_is_error() {
let mut call = example_method_call();
call.method.sig = parse_quote!["L"];
call.method.sig_str = call.method.sig.value();
let CodegenError(_span, kind) = call.generate().unwrap_err();
assert_eq!(ErrorKind::InvalidTypeSignature, kind);
}
#[test]
fn jni_types_are_used_for_stdlib_classes_input() {
let types = [
("Ljava/lang/String;", "JString"),
("Ljava/util/Map;", "JMap"),
("Ljava/util/List;", "JList"),
("[Ljava/lang/String;", "JObjectArray"),
("[[I", "JObjectArray"),
("[I", "JPrimitiveArray"),
("Lcom/example/MyObject;", "JObject"),
("Z", "jboolean"),
("C", "jchar"),
("B", "jbyte"),
("S", "jshort"),
("I", "jint"),
("J", "jlong"),
("F", "jfloat"),
("D", "jdouble"),
];
for (desc, jni_type) in types {
let jt = JavaType::try_from_str(desc).unwrap();
let expr = parse_quote![some_value];
let jvalue = generate_jvalue(jt, &expr);
assert!(
contains_ident(jvalue, jni_type),
"desc: {desc}, jni_type: {jni_type}"
);
}
}
#[test]
fn jni_types_are_used_for_stdlib_classes_output() {
let types = [
("Ljava/lang/String;", "JString"),
("Ljava/util/Map;", "JMap"),
("Ljava/util/List;", "JList"),
("[Ljava/lang/String;", "JObjectArray"),
("[[I", "JObjectArray"),
("[I", "JPrimitiveArray"),
];
for (desc, jni_type) in types {
let rt = ReturnType::try_from_str(desc).unwrap();
let conversion = return_value_conversion_from_sig(rt);
assert!(
contains_ident(conversion, jni_type),
"desc: {desc}, jni_type: {jni_type}"
);
}
}
#[test]
fn return_type_passed_to_jni_is_correct() {
let types = [
("Ljava/lang/String;", "Object"),
("Ljava/util/Map;", "Object"),
("Ljava/util/List;", "Object"),
("[Ljava/lang/String;", "Array"),
("[[I", "Array"),
("[I", "Array"),
("V", "Void"),
("Z", "Boolean"),
("C", "Char"),
("B", "Byte"),
("S", "Short"),
("I", "Int"),
("J", "Long"),
("F", "Float"),
("D", "Double"),
];
for (desc, return_type) in types {
let rt = ReturnType::try_from_str(desc).unwrap();
let expr = return_type_from_sig(rt).into_token_stream();
assert!(
contains_ident(expr, return_type),
"desc: {desc}, return_type: {return_type}"
);
}
}
#[test]
fn method_desc_is_correct() {
let mut call = example_method_call();
call.method.sig = parse_quote!["(II)V"];
call.method.sig_str = call.method.sig.value();
let tests = [
(Receiver::Constructor, "constructor"),
(Receiver::Static, "static_method"),
(Receiver::Instance(parse_quote![this_value]), "method"),
];
for (receiver, method_ident) in tests {
let desc = receiver.generate_method_desc(&call.method);
let rhs = desc.expr.into_token_stream();
assert!(contains_ident(rhs, method_ident), "method: {method_ident}");
}
}
#[test]
fn jni_call_is_correct() {
let mut call = example_method_call();
call.method.sig = parse_quote!["(II)V"];
call.method.sig_str = call.method.sig.value();
let sig = call.method.sig().unwrap();
let args = parse_quote![test_stub];
let tests = [
(Receiver::Constructor, "new_object_unchecked"),
(Receiver::Static, "call_static_method_unchecked"),
(
Receiver::Instance(parse_quote![this_value]),
"call_method_unchecked",
),
];
for (receiver, method_ident) in tests {
let call = receiver
.generate_call(&call.env, &call.method, &sig, &args)
.unwrap();
assert!(contains_ident(call, method_ident), "method: {method_ident}");
}
}
}