| // Copyright 2023 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| //! Provides a sample ukey2 shell app which can be run from the command line |
| |
| #![allow(clippy::expect_used)] |
| //TODO: remove this and fix instances of unwrap |
| #![allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)] |
| |
| use std::io::{Read, Write}; |
| use std::process::exit; |
| |
| use clap::Parser; |
| |
| use crypto_provider_rustcrypto::RustCrypto; |
| use ukey2_connections::{ |
| D2DConnectionContextV1, D2DHandshakeContext, HandshakeImplementation, |
| InitiatorD2DHandshakeContext, NextProtocol, ServerD2DHandshakeContext, |
| }; |
| |
| const MODE_INITIATOR: &str = "initiator"; |
| const MODE_RESPONDER: &str = "responder"; |
| |
| #[derive(Parser, Debug)] |
| struct Ukey2Cli { |
| /// initiator or responder mode |
| #[arg(short, long)] |
| mode: String, |
| /// length of auth string/next proto secret |
| #[arg(short, long, default_value_t = 32)] |
| verification_string_length: i32, |
| } |
| |
| /// Framing functions |
| /* |
| // Writes |message| to stdout in the frame format. |
| void WriteFrame(const string& message) { |
| // Write length of |message| in little-endian. |
| const uint32_t length = message.length(); |
| fputc((length >> (3 * 8)) & 0xFF, stdout); |
| fputc((length >> (2 * 8)) & 0xFF, stdout); |
| fputc((length >> (1 * 8)) & 0xFF, stdout); |
| fputc((length >> (0 * 8)) & 0xFF, stdout); |
| |
| // Write message to stdout. |
| CHECK_EQ(message.length(), |
| fwrite(message.c_str(), 1, message.length(), stdout)); |
| CHECK_EQ(0, fflush(stdout)); |
| } |
| */ |
| fn write_frame(message: Vec<u8>) { |
| let length: u32 = message.len() as u32; |
| let length_bytes = length.to_be_bytes(); |
| std::io::stdout().write_all(&length_bytes).unwrap(); |
| std::io::stdout().write_all(message.as_slice()).expect("failed to write message"); |
| let _ = std::io::stdout().flush(); |
| } |
| |
| /* |
| // Returns a message read from stdin after parsing it from the frame format. |
| string ReadFrame() { |
| // Read length of the frame from the stream. |
| uint8_t length_data[sizeof(uint32_t)]; |
| CHECK_EQ(sizeof(uint32_t), fread(&length_data, 1, sizeof(uint32_t), stdin)); |
| |
| uint32_t length = 0; |
| length |= static_cast<uint32_t>(length_data[0]) << (3 * 8); |
| length |= static_cast<uint32_t>(length_data[1]) << (2 * 8); |
| length |= static_cast<uint32_t>(length_data[2]) << (1 * 8); |
| length |= static_cast<uint32_t>(length_data[3]) << (0 * 8); |
| |
| // Read |length| bytes from the stream. |
| absl::FixedArray<char> buffer(length); |
| CHECK_EQ(length, fread(buffer.data(), 1, length, stdin)); |
| |
| return string(buffer.data(), length); |
| } |
| |
| */ |
| const LENGTH: usize = std::mem::size_of::<u32>(); |
| |
| fn read_frame() -> Vec<u8> { |
| let mut length_buf = [0u8; LENGTH]; |
| assert_eq!(LENGTH, std::io::stdin().read(&mut length_buf).unwrap()); |
| let length_usize = u32::from_be_bytes(length_buf); |
| let mut buffer = vec![0u8; length_usize as usize]; |
| std::io::stdin().read_exact(buffer.as_mut_slice()).expect("failed to read frame"); |
| buffer |
| } |
| |
| struct Ukey2Shell { |
| verification_string_length: usize, |
| } |
| |
| impl Ukey2Shell { |
| fn new(verification_string_length: i32) -> Self { |
| Self { verification_string_length: verification_string_length as usize } |
| } |
| |
| fn run_secure_connection_loop(connection_ctx: &mut D2DConnectionContextV1) -> bool { |
| loop { |
| let input = read_frame(); |
| let idx = input.iter().enumerate().find(|(_index, &byte)| byte == 0x20).unwrap().0; |
| let (cmd, payload) = (&input[0..idx], &input[idx + 1..]); |
| if cmd == b"encrypt" { |
| let result = |
| connection_ctx.encode_message_to_peer::<RustCrypto, &[u8]>(payload, None); |
| write_frame(result); |
| } else if cmd == b"decrypt" { |
| let result = |
| connection_ctx.decode_message_from_peer::<RustCrypto, &[u8]>(payload, None); |
| if result.is_err() { |
| println!("failed to decode payload"); |
| return false; |
| } |
| write_frame(result.unwrap()); |
| } else if cmd == b"session_unique" { |
| let result = connection_ctx.get_session_unique::<RustCrypto>(); |
| write_frame(result); |
| } else { |
| println!("unknown command"); |
| return false; |
| } |
| } |
| } |
| |
| fn run_as_initiator(&self) -> bool { |
| let mut initiator_ctx = InitiatorD2DHandshakeContext::<RustCrypto, _>::new( |
| HandshakeImplementation::PublicKeyInProtobuf, |
| vec![NextProtocol::Aes256CbcHmacSha256, NextProtocol::Aes256GcmSiv], |
| ); |
| write_frame(initiator_ctx.get_next_handshake_message().unwrap()); |
| let server_init_msg = read_frame(); |
| initiator_ctx |
| .handle_handshake_message(server_init_msg.as_slice()) |
| .expect("Failed to handle message"); |
| write_frame(initiator_ctx.get_next_handshake_message().unwrap_or_default()); |
| // confirm auth str |
| let auth_str = initiator_ctx |
| .to_completed_handshake() |
| .ok() |
| .and_then(|h| h.auth_string::<RustCrypto>().derive_vec(self.verification_string_length)) |
| .unwrap_or_else(|| vec![0; self.verification_string_length]); |
| write_frame(auth_str); |
| let ack = read_frame(); |
| if ack != "ok".to_string().into_bytes() { |
| println!("handshake failed"); |
| return false; |
| } |
| // upgrade to connection context |
| let mut initiator_conn_ctx = initiator_ctx.to_connection_context().unwrap(); |
| Self::run_secure_connection_loop(&mut initiator_conn_ctx) |
| } |
| |
| fn run_as_responder(&self) -> bool { |
| let mut server_ctx = ServerD2DHandshakeContext::<RustCrypto, _>::new( |
| HandshakeImplementation::PublicKeyInProtobuf, |
| &[NextProtocol::Aes256GcmSiv, NextProtocol::Aes256CbcHmacSha256], |
| ); |
| let initiator_init_msg = read_frame(); |
| server_ctx.handle_handshake_message(initiator_init_msg.as_slice()).unwrap(); |
| let server_next_msg = server_ctx.get_next_handshake_message().unwrap(); |
| write_frame(server_next_msg); |
| let initiator_finish_msg = read_frame(); |
| server_ctx |
| .handle_handshake_message(initiator_finish_msg.as_slice()) |
| .expect("Failed to handle message"); |
| // confirm auth str |
| let auth_str = server_ctx |
| .to_completed_handshake() |
| .ok() |
| .and_then(|h| h.auth_string::<RustCrypto>().derive_vec(self.verification_string_length)) |
| .unwrap_or_else(|| vec![0; self.verification_string_length]); |
| write_frame(auth_str); |
| let ack = read_frame(); |
| if ack != "ok".to_string().into_bytes() { |
| println!("handshake failed"); |
| return false; |
| } |
| // upgrade to connection context |
| let mut server_conn_ctx = server_ctx.to_connection_context().unwrap(); |
| Self::run_secure_connection_loop(&mut server_conn_ctx) |
| } |
| } |
| |
| fn main() { |
| let args = Ukey2Cli::parse(); |
| let shell = Ukey2Shell::new(args.verification_string_length); |
| if args.mode == MODE_INITIATOR { |
| let _ = shell.run_as_initiator(); |
| } else if args.mode == MODE_RESPONDER { |
| let _ = shell.run_as_responder(); |
| } else { |
| exit(1); |
| } |
| exit(0) |
| } |