blob: 275a7d0d0a2a39f2efeee1b906b25af06593bdc0 [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Tools for checking for, or adding, headers (e.g. licenses, etc) in files.
use std::{
fs,
io::{self, BufRead as _, Write as _},
iter::FromIterator,
path, thread,
};
pub mod license;
/// A file header to check for or add to files.
#[derive(Clone)]
pub struct Header<C: HeaderChecker> {
    /// Decides whether a file already contains the header.
    checker: C,
    /// Plain header text, without any comment syntax.
    header: String,
}

impl<C: HeaderChecker> Header<C> {
    /// Construct a new `Header` with the `checker` used to determine if the header is already
    /// present, and the plain `header` text to add (without any applicable comment syntax, etc).
    pub fn new(checker: C, header: String) -> Self {
        Self { checker, header }
    }

    /// Return true if the file has the desired header, false otherwise.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error the checker hits while reading `input`.
    pub fn header_present(&self, input: &mut impl io::Read) -> io::Result<bool> {
        self.checker.check(input)
    }

    /// Add the header, with appropriate formatting for the type of file indicated by `p`'s
    /// extension, if the header is not already present.
    ///
    /// If the first line of the file contains one of the `MAGIC_FIRST_LINES` markers (compared
    /// ASCII case-insensitively, e.g. a shebang, XML declaration or `<!DOCTYPE html>`), that
    /// line is kept first and the header is inserted after it. This also applies when the file
    /// consists of that single line with no trailing newline. A blank line separates the header
    /// from the original contents.
    ///
    /// Returns true if the header was added.
    ///
    /// # Errors
    ///
    /// - [`AddHeaderError::IoError`] if `p` cannot be read (including non-UTF-8 contents) or
    ///   rewritten.
    /// - [`AddHeaderError::UnknownExtension`] if no comment syntax is known for `p`.
    pub fn add_header_if_missing(&self, p: &path::Path) -> Result<bool, AddHeaderError> {
        let err_mapper = |e| AddHeaderError::IoError(p.to_path_buf(), e);
        let contents = fs::read_to_string(p).map_err(err_mapper)?;
        if self.header_present(&mut contents.as_bytes()).map_err(err_mapper)? {
            return Ok(false);
        }
        let wrapped = header_delimiters(p)
            .map(|d| wrap_header(&self.header, d))
            .ok_or_else(|| AddHeaderError::UnknownExtension(p.to_path_buf()))?;

        // A file without any '\n' is a single (possibly magic) line, e.g. a bare shebang.
        let (first_line, rest) = contents.split_once('\n').unwrap_or((contents.as_str(), ""));
        // The markers are lowercase; compare case-insensitively so that e.g. the canonical
        // `<!DOCTYPE html>` is recognized and the header never lands above it.
        let first_line_lower = first_line.to_ascii_lowercase();
        let (magic_line, after_header) =
            if MAGIC_FIRST_LINES.iter().any(|l| first_line_lower.contains(l)) {
                (Some(first_line), rest)
            } else {
                (None, contents.as_str())
            };

        let mut output = String::with_capacity(contents.len() + wrapped.len() + 2);
        if let Some(line) = magic_line {
            output.push_str(line);
            output.push('\n');
        }
        output.push_str(&wrapped);
        // newline to separate the header from previous contents
        output.push('\n');
        output.push_str(after_header);

        // write the license, rewriting the existing file in place
        let mut f =
            fs::OpenOptions::new().write(true).truncate(true).open(p).map_err(err_mapper)?;
        f.write_all(output.as_bytes()).map_err(err_mapper)?;
        Ok(true)
    }
}
/// Errors produced while adding a header to a single file.
#[derive(Debug, thiserror::Error)]
pub enum AddHeaderError {
    /// Reading, checking or rewriting the file at the given path failed.
    #[error("I/O error at {0:?}: {1}")]
    IoError(std::path::PathBuf, std::io::Error),
    /// Neither the path's extension nor its file name has a known comment syntax.
    #[error("Unknown file extension: {0:?}")]
    UnknownExtension(std::path::PathBuf),
}
/// Checks for headers in files, like licenses or author attribution.
pub trait HeaderChecker: Send + Clone {
    /// Return true if the file has the desired header, false otherwise.
    ///
    /// # Errors
    ///
    /// Returns any I/O error encountered while reading `file`.
    fn check(&self, file: &mut impl io::Read) -> io::Result<bool>;
}

/// Checks for a pattern in the first several lines of each file.
#[derive(Clone)]
pub struct SingleLineChecker {
    /// Pattern to do a substring match on in each of the first `max_lines` lines of the file
    pattern: String,
    /// Number of lines to search through
    max_lines: usize,
}

impl SingleLineChecker {
    /// Construct a `SingleLineChecker` that looks for `pattern` in the first `max_lines` lines.
    pub(crate) fn new(pattern: String, max_lines: usize) -> Self {
        Self { pattern, max_lines }
    }
}

impl HeaderChecker for SingleLineChecker {
    /// Returns true as soon as one of the first `max_lines` lines contains `pattern`, and
    /// false if the input ends or the line budget runs out first.
    ///
    /// # Errors
    ///
    /// A line that is not valid UTF-8 yields an [`io::ErrorKind::InvalidData`] error, unless
    /// the pattern was already found on an earlier line.
    fn check(&self, input: &mut impl io::Read) -> io::Result<bool> {
        let mut reader = io::BufReader::new(input);
        // reuse one buffer across lines to minimize allocation
        let mut line = String::new();
        // only read the first bit of the file
        for _ in 0..self.max_lines {
            line.clear();
            if reader.read_line(&mut line)? == 0 {
                // EOF before the pattern was found
                return Ok(false);
            }
            if line.contains(self.pattern.as_str()) {
                return Ok(true);
            }
        }
        Ok(false)
    }
}
/// Why a file was reported by the recursive header check.
#[derive(Copy, Clone)]
enum CheckStatus {
    /// The file was read but the header was not found.
    MisMatchedHeader,
    /// Reading the file failed with `InvalidData` (not UTF-8 text).
    BinaryFile,
}

/// A single reported file together with the reason it was reported.
#[derive(Clone)]
struct FileResult {
    path: path::PathBuf,
    status: CheckStatus,
}

/// Aggregated outcome of a recursive header check.
#[derive(Clone, Default)]
pub struct FileResults {
    /// Files that were read but lack the header.
    pub mismatched_files: Vec<path::PathBuf>,
    /// Files that could not be read as text.
    pub binary_files: Vec<path::PathBuf>,
}

impl FileResults {
    /// Returns true if at least one file was reported, in either list.
    pub fn has_failure(&self) -> bool {
        !(self.mismatched_files.is_empty() && self.binary_files.is_empty())
    }
}

impl FromIterator<FileResult> for FileResults {
    /// Routes each result into the list matching its status, keeping iteration order.
    fn from_iter<I>(iter: I) -> FileResults
    where
        I: IntoIterator<Item = FileResult>,
    {
        iter.into_iter().fold(FileResults::default(), |mut acc, FileResult { path, status }| {
            let bucket = match status {
                CheckStatus::MisMatchedHeader => &mut acc.mismatched_files,
                CheckStatus::BinaryFile => &mut acc.binary_files,
            };
            bucket.push(path);
            acc
        })
    }
}
/// Recursively check for `header` in every file in `root` that matches `path_predicate`,
/// using `num_threads` worker threads (at least one worker is always used, even if
/// `num_threads` is 0).
///
/// Returns a [`FileResults`] object containing the paths without headers detected, and the
/// paths whose contents could not be read as UTF-8 text.
///
/// # Errors
///
/// Returns [`CheckHeadersRecursivelyError::WalkdirError`] if traversing `root` fails (files
/// found before the failure are still checked), otherwise
/// [`CheckHeadersRecursivelyError::IoError`] for the first other I/O error hit while reading a
/// file. Worker threads are always joined before returning.
///
/// # Panics
///
/// Panics if a worker thread panics.
pub fn check_headers_recursively(
    root: &path::Path,
    path_predicate: impl Fn(&path::Path) -> bool,
    header: Header<impl HeaderChecker + 'static>,
    num_threads: usize,
) -> Result<FileResults, CheckHeadersRecursivelyError> {
    let (path_tx, path_rx) = crossbeam::channel::unbounded::<path::PathBuf>();
    let (result_tx, result_rx) = crossbeam::channel::unbounded();
    // Spawn a few threads to handle files in parallel. At least one is required: with zero
    // workers no file would ever be checked and every tree would silently "pass".
    let handles = (0..num_threads.max(1))
        .map(|_| {
            let path_rx = path_rx.clone();
            let result_tx = result_tx.clone();
            let header = header.clone();
            thread::spawn(move || {
                for p in path_rx {
                    let checked =
                        fs::File::open(&p).and_then(|mut f| header.header_present(&mut f));
                    let res = match checked {
                        Ok(true) => continue,
                        Ok(false) => {
                            Ok(FileResult { path: p, status: CheckStatus::MisMatchedHeader })
                        }
                        // Binary file - add to ignore in license.rs
                        Err(e) if e.kind() == io::ErrorKind::InvalidData => {
                            Ok(FileResult { path: p, status: CheckStatus::BinaryFile })
                        }
                        Err(e) => Err(CheckHeadersRecursivelyError::IoError(p, e)),
                    };
                    // Sending only fails once the caller stopped collecting (after an error);
                    // nobody is listening any more, so stop instead of panicking.
                    if result_tx.send(res).is_err() {
                        break;
                    }
                }
                // no more files
            })
        })
        .collect::<Vec<thread::JoinHandle<()>>>();
    // make sure result channel closes when threads complete
    drop(result_tx);
    // The path channel is unbounded, so the walk never blocks. `find_files` drops `path_tx`
    // when it returns (even on error), which ends the workers' loops.
    let walk = find_files(root, path_predicate, path_tx);
    // Stops at the first I/O error; the receiver is dropped right after, so workers stop too.
    let results = result_rx.into_iter().collect::<Result<FileResults, _>>();
    for h in handles {
        h.join().expect("header check worker thread panicked");
    }
    // A walk failure takes precedence over per-file errors, as before.
    walk?;
    results
}
/// Errors produced while checking a directory tree for headers.
#[derive(Debug, thiserror::Error)]
pub enum CheckHeadersRecursivelyError {
    /// Opening or reading the file at the given path failed.
    #[error("I/O error at {0:?}: {1}")]
    IoError(std::path::PathBuf, std::io::Error),
    /// Walking the directory tree failed.
    #[error("Walkdir error: {0}")]
    WalkdirError(#[from] walkdir::Error),
}
/// Add the provided `header` to every file under `root` that matches `path_predicate` and does
/// not already contain it, as decided by the header's own checker.
///
/// The whole tree is walked first; files are then processed one at a time in walk order, and
/// processing stops at the first error. Files handled before that error keep their new header.
///
/// Returns a list of paths that had headers added.
pub fn add_headers_recursively(
    root: &path::Path,
    path_predicate: impl Fn(&path::Path) -> bool,
    header: Header<impl HeaderChecker>,
) -> Result<Vec<path::PathBuf>, AddHeadersRecursivelyError> {
    // likely no need for threading since adding headers is only done occasionally
    let (path_tx, path_rx) = crossbeam::channel::unbounded::<path::PathBuf>();
    find_files(root, path_predicate, path_tx)?;
    let mut added_paths = Vec::new();
    for candidate in path_rx {
        let added = header.add_header_if_missing(&candidate).map_err(|err| match err {
            AddHeaderError::IoError(bad_path, io_err) => {
                AddHeadersRecursivelyError::IoError(bad_path, io_err)
            }
            AddHeaderError::UnknownExtension(bad_path) => {
                AddHeadersRecursivelyError::UnknownExtension(bad_path)
            }
        })?;
        if added {
            added_paths.push(candidate);
        }
    }
    Ok(added_paths)
}
/// Errors produced while adding headers across a directory tree.
#[derive(Debug, thiserror::Error)]
pub enum AddHeadersRecursivelyError {
    /// Reading, checking or rewriting the file at the given path failed.
    #[error("I/O error at {0:?}: {1}")]
    IoError(std::path::PathBuf, std::io::Error),
    /// Walking the directory tree failed.
    #[error("Walkdir error: {0}")]
    WalkdirError(#[from] walkdir::Error),
    /// No comment syntax is known for the given path.
    #[error("Unknown file extension: {0:?}")]
    UnknownExtension(std::path::PathBuf),
}
/// Walk `root` recursively and send every non-directory path for which `path_predicate`
/// returns true into `dest`.
///
/// Directories are never sent, including symlinks that resolve to directories (because
/// `Path::is_dir` follows symlinks). `dest` is consumed and dropped when this returns.
///
/// # Errors
///
/// Returns the first error reported by the directory walk. Paths found before the error have
/// already been sent.
///
/// # Panics
///
/// Panics if every receiver of `dest` has already been dropped.
fn find_files(
    root: &path::Path,
    path_predicate: impl Fn(&path::Path) -> bool,
    dest: crossbeam::channel::Sender<path::PathBuf>,
) -> Result<(), walkdir::Error> {
    for r in walkdir::WalkDir::new(root).into_iter() {
        let entry = r?;
        if entry.path().is_dir() || !path_predicate(entry.path()) {
            continue;
        }
        dest.send(entry.into_path())
            .expect("find_files: all receivers of the path channel were dropped");
    }
    Ok(())
}
/// Prepare a header for inclusion in a particular file syntax by wrapping it with
/// comment characters as per the provided `delim`.
///
/// Each line of `orig_header` is prefixed with `delim.content_line_prefix`, and trailing
/// spaces/tabs are removed from every resulting line. Non-empty `first_line`/`last_line`
/// delimiters get lines of their own. Input lines may end in `\n` or `\r\n`, but the output
/// always uses `\n`. A single trailing line terminator on `orig_header` ends its last line
/// rather than producing an extra, empty comment line.
fn wrap_header(orig_header: &str, delim: HeaderDelimiters) -> String {
    let mut out = String::new();
    if !delim.first_line.is_empty() {
        out.push_str(delim.first_line);
        out.push('\n');
    }
    // Header text loaded from a file usually ends with a newline; don't turn that into an
    // additional empty comment line.
    let body = orig_header.strip_suffix('\n').unwrap_or(orig_header);
    for line in body.split('\n') {
        // Tolerate CRLF headers so no stray '\r' leaks into the output.
        let line = line.strip_suffix('\r').unwrap_or(line);
        out.push_str(delim.content_line_prefix);
        out.push_str(line);
        // Remove any trailing whitespaces (excluding newlines) from `content_line_prefix + line`.
        // For example, if `content_line_prefix` is `// ` and `line` is empty, the resulting string
        // should be truncated to `//`.
        out.truncate(out.trim_end_matches([' ', '\t']).len());
        out.push('\n');
    }
    if !delim.last_line.is_empty() {
        out.push_str(delim.last_line);
        out.push('\n');
    }
    out
}
/// Looks up the comment delimiters for a header in the file at `p`.
///
/// The file's extension is matched first, case-sensitively. If that fails, the full file name
/// is tried (only `Dockerfile` is known). Returns `None` when neither matches; an extension
/// that is not valid UTF-8 counts as no extension.
fn header_delimiters(p: &path::Path) -> Option<HeaderDelimiters> {
    // if the extension isn't UTF-8, oh well
    let extension = p.extension().and_then(|os_str| os_str.to_str()).unwrap_or("");
    let (first_line, content_line_prefix, last_line) = match extension {
        "c" | "h" | "gv" | "java" | "scala" | "kt" | "kts" => ("/*", " * ", " */"),
        "js" | "mjs" | "cjs" | "jsx" | "tsx" | "css" | "scss" | "sass" | "ts" => {
            ("/**", " * ", " */")
        }
        "cc" | "cpp" | "cs" | "go" | "hcl" | "hh" | "hpp" | "m" | "mm" | "proto" | "rs"
        | "swift" | "dart" | "groovy" | "v" | "sv" | "php" => ("", "// ", ""),
        "py" | "sh" | "yaml" | "yml" | "dockerfile" | "rb" | "gemfile" | "tcl" | "tf" | "bzl"
        | "pl" | "pp" | "build" => ("", "# ", ""),
        "el" | "lisp" => ("", ";; ", ""),
        "erl" => ("", "% ", ""),
        "hs" | "lua" | "sql" | "sdl" => ("", "-- ", ""),
        "html" | "xml" | "vue" | "wxi" | "wxl" | "wxs" => ("<!--", " ", "-->"),
        "ml" | "mli" | "mll" | "mly" => ("(**", " ", "*)"),
        // also handle whole filenames if extensions didn't match
        _ => match p.file_name().and_then(|os_str| os_str.to_str()) {
            Some("Dockerfile") => ("", "# ", ""),
            _ => return None,
        },
    };
    Some(HeaderDelimiters { first_line, content_line_prefix, last_line })
}
/// Comment syntax used to wrap a header for one kind of file.
///
/// An empty `first_line` or `last_line` means no line of its own is emitted for it.
#[derive(Copy, Clone)]
struct HeaderDelimiters {
    /// Line emitted before the header body (e.g. `/*`), or empty for none.
    first_line: &'static str,
    /// Text placed in front of every header line (e.g. `// `).
    content_line_prefix: &'static str,
    /// Line emitted after the header body (e.g. ` */`), or empty for none.
    last_line: &'static str,
}
/// Markers for a first line that must stay the first line of a file (shebang, XML
/// declaration, doctype, Ruby magic comments, PHP open tag, Dockerfile parser directives).
/// When a file's first line contains one of these, the header is inserted after that line
/// instead of before it. All entries are lowercase.
const MAGIC_FIRST_LINES: [&str; 8] = [
    "#!",                       // shell script
    "<?xml",                    // XML declaration
    "<!doctype",                // HTML doctype
    "# encoding:",              // Ruby encoding
    "# frozen_string_literal:", // Ruby interpreter instruction
    "<?php",                    // PHP opening tag
    "# escape",                 // Dockerfile directive https://docs.docker.com/engine/reference/builder/#parser-directives
    "# syntax",                 // Dockerfile directive https://docs.docker.com/engine/reference/builder/#parser-directives
];
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn single_line_checker_finds_header_when_present() {
        let input = "foo\nsome license\nbar";
        assert!(test_header().checker.check(&mut input.as_bytes()).unwrap());
    }

    #[test]
    fn single_line_checker_doesnt_find_header_when_missing() {
        let input = "foo\nwrong license\nbar";
        assert!(!test_header().checker.check(&mut input.as_bytes()).unwrap());
    }

    #[test]
    fn single_line_checker_throws_error_when_missing_and_file_is_non_utf8() {
        let input = b"foo\n\x00\xff\nbar";
        assert_eq!(
            io::ErrorKind::InvalidData,
            test_header().checker.check(&mut input.as_slice()).unwrap_err().kind()
        );
    }

    #[test]
    fn single_line_checker_doesnt_panic_when_file_is_non_utf8() {
        let inputs: [&'static [u8]; 3] = [
            b"foo\n\x00\xff\nbar",
            b"foo\nsome license\n\x00\xff\nbar",
            b"foo\n\x00\xff\nsome license\nbar",
        ];
        for mut input in inputs {
            // Output is not defined for non-utf-8 files, but we should handle them with grace
            let _ = test_header().checker.check(&mut input);
        }
    }

    // `add_header_if_missing` writes the wrapped header, then a blank separator line, then the
    // original contents; the expected strings below spell that out with explicit escapes.

    #[test]
    fn adds_header_with_empty_delimiters() {
        let file = tempfile::Builder::new().suffix(".rs").tempfile().unwrap();
        fs::write(file.path(), "not a license").unwrap();
        assert!(test_header().add_header_if_missing(file.path()).unwrap());
        assert_eq!(
            "// some license etc etc etc\n\nnot a license",
            fs::read_to_string(file.path()).unwrap()
        );
    }

    #[test]
    fn adds_header_with_nonempty_delimiters() {
        let file = tempfile::Builder::new().suffix(".c").tempfile().unwrap();
        fs::write(file.path(), "not a license").unwrap();
        assert!(test_header().add_header_if_missing(file.path()).unwrap());
        assert_eq!(
            "/*\n * some license etc etc etc\n */\n\nnot a license",
            fs::read_to_string(file.path()).unwrap()
        );
    }

    #[test]
    fn adds_header_trim_trailing_whitespace() {
        let file = tempfile::Builder::new().suffix(".c").tempfile().unwrap();
        fs::write(file.path(), "not a license").unwrap();
        assert!(test_header_with_blank_lines_and_trailing_whitespace()
            .add_header_if_missing(file.path())
            .unwrap());
        assert_eq!(
            "/*\n * some license\n * line with trailing whitespace.\n *\n * etc\n */\n\nnot a license",
            fs::read_to_string(file.path()).unwrap()
        );
    }

    #[test]
    fn doesnt_add_header_when_already_present() {
        let file = tempfile::Builder::new().suffix(".rs").tempfile().unwrap();
        let initial_content = "\n// some license etc etc etc already present\nnot a license";
        fs::write(file.path(), initial_content).unwrap();
        assert!(!test_header().add_header_if_missing(file.path()).unwrap());
        assert_eq!(initial_content, fs::read_to_string(file.path()).unwrap());
    }

    #[test]
    fn adds_header_after_magic_first_line() {
        let file = tempfile::Builder::new().suffix(".xml").tempfile().unwrap();
        fs::write(file.path(), "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<root />\n").unwrap();
        assert!(test_header().add_header_if_missing(file.path()).unwrap());
        assert_eq!(
            "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!--\n some license etc etc etc\n-->\n\n<root />\n",
            fs::read_to_string(file.path()).unwrap()
        );
    }

    fn test_header() -> Header<SingleLineChecker> {
        Header::new(
            SingleLineChecker::new("some license".to_string(), 100),
            r#"some license etc etc etc"#.to_string(),
        )
    }

    fn test_header_with_blank_lines_and_trailing_whitespace() -> Header<SingleLineChecker> {
        Header::new(
            SingleLineChecker::new("some license".to_string(), 100),
            "some license\nline with trailing whitespace. \n\netc".to_string(),
        )
    }
}