diff --git a/src/lib/crypto/secretbox.mjs b/src/lib/crypto/secretbox.mjs index 67248c3..cf7c961 100644 --- a/src/lib/crypto/secretbox.mjs +++ b/src/lib/crypto/secretbox.mjs @@ -3,6 +3,10 @@ import tweetnacl from 'tweetnacl'; const { randomBytes, secretbox } = tweetnacl; +function is_uint8array(buffer) { + return buffer instanceof Uint8Array; +} + /** * Creates a new key ready for encryption. * @return {string} A new base64-encoded key. @@ -25,11 +29,14 @@ function encrypt(key, data) { /** * Encrypts the given data with the given key. - * @param {Buffer} key The key to use to encrypt the data. - * @param {Buffer} data The data to encrypt. + * @param {Buffer|Uint8Array} key The key to use to encrypt the data. + * @param {Buffer|Uint8Array} data The data to encrypt. * @return {Buffer} The encrypted data. */ function encrypt_bytes(key_bytes, data_bytes) { + if(!is_uint8array(key_bytes)) throw new Error(`Error: Expected key_bytes to be of type Uint8Array, but got ${typeof key_bytes}`); + if(!is_uint8array(data_bytes)) throw new Error(`Error: Expected data_bytes to be of type Uint8Array, but got ${typeof data_bytes}`); + const nonce = randomBytes(secretbox.nonceLength); const cipher_bytes = secretbox(data_bytes, nonce, key_bytes); @@ -59,11 +66,13 @@ function decrypt(key, cipher_text) { /** * Decrypts the given data with the given key. - * @param {Buffer} key The key to use to decrypt the data. - * @param {Buffer} cipher_text The ciphertext to decrypt. + * @param {Buffer|Uint8Array} key The key to use to decrypt the data. + * @param {Buffer|Uint8Array} cipher_text The ciphertext to decrypt. * @return {Buffer} The decoded data. */ function decrypt_bytes(key_bytes, cipher_text_bytes) { + if(!is_uint8array(key_bytes)) throw new Error(`Error: Expected key_bytes to be of type Uint8Array, but got ${typeof key_bytes}`); + if(!is_uint8array(cipher_text_bytes)) throw new Error(`Error: Expected cipher_text_bytes to be of type Uint8Array, but got ${typeof cipher_text_bytes}`); const nonce = cipher_text_bytes.slice(0, secretbox.nonceLength); const cipher_bytes = cipher_text_bytes.slice(secretbox.nonceLength); diff --git a/src/lib/transport/Connection.mjs b/src/lib/transport/Connection.mjs index 2224337..8c269ef 100644 --- a/src/lib/transport/Connection.mjs +++ b/src/lib/transport/Connection.mjs @@ -4,7 +4,8 @@ import crypto from 'crypto'; import net from 'net'; import { EventEmitter, once } from 'events'; -import l from 'log'; +import log from 'log'; +const l = log.get("connection"); import settings from '../../settings.mjs'; import rekey from './rekey.mjs'; @@ -14,11 +15,16 @@ import { encrypt_bytes, decrypt_bytes } from '../crypto/secretbox.mjs'; /** * Represents a connection to a single endpoint. + * @param {string} secret_join The shared join secret, encoded as base64 + * @param {net.Socket?} socket Optional. A pre-existing socket to take over and manage. */ class Connection extends EventEmitter { constructor(secret_join, socket) { super(); + if(typeof secret_join !== "string") + throw new Error(`Error: Expected secret_join to be of type string, but received variable of type ${typeof secret_join}`); + this.socket = socket; this.rekey_last = null; @@ -37,24 +43,27 @@ class Connection extends EventEmitter { */ async connect(address, port) { this.address = address; this.port = port; - this.socket = new new.Socket(); + this.socket = new net.Socket(); this.socket.connect({ address, port }); + this.socket.once("end", () => { + l.notice(`${this.address}:${this.port} disconnected`); + }); await once(this.socket, "connect"); await this.init(); } async init() { + this.address = this.socket.remoteAddress; + this.port = this.socket.remotePort; this.socket.setKeepAlive(true); - await this.rekey(); - this.framer = new FramedTransport(this.socket); - this.framer.on("frame", this.handle_frame); + this.framer.on("frame", this.handle_frame.bind(this)); - this.read_task = read_loop(); + await this.rekey(); } async rekey() { @@ -66,7 +75,7 @@ class Connection extends EventEmitter { this.emit("rekey"); } catch(error) { - l.warn(`Error when rekeying connection ${this.address}:${this.port}: ${settings.cli.verbose ? error : error.message}, killing connection`); + l.warn(`Error when rekeying connection ${this.address}:${this.port}, killing connection`, settings.cli.verbose ? error : error.message); await this.destroy(); } finally { @@ -75,20 +84,28 @@ class Connection extends EventEmitter { } async destroy() { - await this.framer.destroy(); + l.info(`Killing connection to ${this.address}:${this.port}`, new Error().stack); + if(this.framer instanceof FramedTransport) + await this.framer.destroy(); + else { + await this.socket.end(); + await this.socket.destroy(); + } this.emit("destroy"); } async handle_frame(bytes) { try { - let decrypted = decrypt_bytes(this.session_key, bytes); - if(decrypted === null) return; - await handle_message(decrypted.toString("utf-8")); + l.info(`FRAME length`, bytes.length); + let decrypted = decrypt_bytes(this.session_key, new Uint8Array(bytes)); + if(decrypted === null) { + l.warn(`Decryption of message failed`); + return; + } + await this.handle_message(decrypted.toString("utf-8")); } catch(error) { - l.warn(`Warning: Killing connection to ${this.address}:${this.port} after error: ${settings.cli.verbose ? error : error.message}`); - } - finally { + l.warn(`Warning: Killing connection to ${this.address}:${this.port} after error:`, settings.cli.verbose ? error : error.message); this.destroy(); } } @@ -96,9 +113,11 @@ class Connection extends EventEmitter { async handle_message(msg_text) { const msg = JSON.parse(msg_text); - if(msg.event == "rekey") { + l.log(`MESSAGE:${msg.event} content`, msg.message); + + if(msg.event == "rekey" && !this.rekey_in_progress) { // Set and forget here - if(!this.rekey_in_progress) this.rekey(); + this.rekey(); } this.emit("message", msg.event, msg.message); this.emit(`message-${msg.event}`, msg.message); @@ -114,23 +133,27 @@ class Connection extends EventEmitter { // TODO: Consider anonymous TLS, with jpake for mututal authentication // TODO: Consider https://devdocs.io/node/crypto#crypto.createCipheriv() - which lets us use any openssl ciphers we like - e.g. ChaCha20-Poly1305 let payload = JSON.stringify({ event, message }); - payload = encrypt_bytes(this.session_key, payload); - + payload = encrypt_bytes( + this.session_key, + Buffer.from(payload, "utf-8") + ); await this.framer.write(payload); } } Connection.Wrap = async function(secret_join, socket) { - const socket = new Connection(secret_join, socket); - await socket.init(); + const socket_wrap = new Connection(secret_join, socket); + await socket_wrap.init(); - return socket; + return socket_wrap; } Connection.Create = async function(secret_join, address, port) { const socket = new Connection(secret_join); - socket.connect(address, port); + await socket.connect(address, port); + + return socket; } export default Connection; diff --git a/src/lib/transport/FramedTransport.mjs b/src/lib/transport/FramedTransport.mjs index ceb7dcf..2b1560e 100644 --- a/src/lib/transport/FramedTransport.mjs +++ b/src/lib/transport/FramedTransport.mjs @@ -1,5 +1,10 @@ "use strict"; +import os from 'os'; + +import log from 'log'; +const l = log.get("framedtransport"); + import { EventEmitter, once } from 'events'; import { write_safe, end_safe } from '../io/StreamHelpers.mjs'; @@ -10,10 +15,12 @@ import { write_safe, end_safe } from '../io/StreamHelpers.mjs'; * * */ -class FramedTransport { +class FramedTransport extends EventEmitter { constructor(socket) { + super(); + this.socket = socket; - this.socket.on("data", this.handle_chunk); + this.socket.on("data", this.handle_chunk.bind(this)); this.buffer = null; /** The length of a uint in bytes @type {number} */ @@ -29,18 +36,29 @@ class FramedTransport { * @return {void} */ handle_chunk(chunk) { - if(this.buffer instanceof Buffer) - this.buffer = Buffer.concat(this.buffer, chunk); + l.debug(`CHUNK length`, chunk.length, `buffer length`, this.buffer === null ? 0 : this.buffer.length); + if(this.buffer instanceof Buffer) { + this.buffer = Buffer.concat([ this.buffer, chunk ]); + + } + else + this.buffer = chunk; + + l.debug(`CHUNK buffer`, this.buffer); let next_frame_length = this.buffer2uint(this.buffer, 0); + l.debug(`CHUNK total length`, this.buffer.length, `next_frame_length`, next_frame_length); + + // We have enough data! Emit a frame and then start again. if(this.buffer.length - this.uint_length >= next_frame_length) { - this.emit("frame", Buffer.slice(this.uint_length, next_frame_length)); + this.emit("frame", this.buffer.slice(this.uint_length, next_frame_length)); if(this.buffer.length - (this.uint_length + next_frame_length) > 0) - this.buffer = Buffer.slice(this,uint_length + next_frame_length); + this.buffer = this.buffer.slice(this.uint_length + next_frame_length); else this.buffer = null; + l.info(`FRAME length`, next_frame_length, `length_remaining`, this.buffer === null ? 0 : this.buffer.length); } } @@ -52,7 +70,13 @@ class FramedTransport { * @return {number} The parsed number from the buffer. */ buffer2uint(buffer, pos32) { - return new DataView(buffer).getUint32(pos32, false); + l.debug(`BUFFER2UINT buffer`, buffer, `pos32`, pos32); + // DataView doesn't work as expected here, because Node.js optimises small buffer to be views of 1 larger buffer. + let u8 = new Uint8Array(buffer).slice(pos32, this.uint_length); + // Convert from network byte order if necessary + if(os.endianness() == "LE") u8.reverse(); + l.debug(`BUFFER2UINT u8`, u8, `u32`, new Uint32Array(u8.buffer)); + return new Uint32Array(u8.buffer)[0]; } /** @@ -61,8 +85,13 @@ class FramedTransport { * @return {Buffer} A new buffer representing the given number. */ uint2buffer(uint) { + // DataView doesn't work as expected here, because Node.js optimises small buffer to be views of 1 larger buffer. let array = new ArrayBuffer(this.uint_length); - new DataView(array).setUint32(0, uint, false); + const u8 = new Uint8Array(array); + const u32 = new Uint32Array(array); + u32[0] = uint; + // Host to network byte order, if appropriate + if(os.endianness() == "LE") u8.reverse(); return Buffer.from(array); } @@ -75,6 +104,7 @@ class FramedTransport { if(this.writing) await once(this, "write-end"); this.emit("write-start"); this.writing = true; + l.info(`SEND length`, frame.length); await write_safe(this.socket, this.uint2buffer(frame.length)); await write_safe(this.socket, frame); this.writing = false; diff --git a/src/lib/transport/rekey.mjs b/src/lib/transport/rekey.mjs index 83aa17e..cf9a9be 100644 --- a/src/lib/transport/rekey.mjs +++ b/src/lib/transport/rekey.mjs @@ -3,14 +3,14 @@ import { once } from 'events'; import l from 'log'; -import { JPake } from 'jpake'; +import jpake from 'jpake'; export default async function rekey(connection, secret_join) { // 0: Setup jpake - let jpake = new JPake(secret_join); + let jpake_inst = new jpake.JPake(secret_join); // 1: Round 1 - connection.send("rekey", { round: 1, content: jpake.GetRound1Message() }); + connection.send("rekey", { round: 1, content: jpake_inst.GetRound1Message() }); // 2: Round 2 @@ -21,9 +21,12 @@ export default async function rekey(connection, secret_join) { || typeof their_round1.content !== "string") throw new Error(`Error: Received invalid round 1 from peer`); - const our_round2 = jpake.GetRound2Message(their_round1.content); + l.debug(`REKEY GOT ROUND 1`); + + const our_round2 = jpake_inst.GetRound2Message(their_round1.content); if(typeof our_round2 !== "string") throw new Error(`Error: Failed to compute rekey round 2`); + connection.send("rekey", { round: 2, content: our_round2 }); // 3: Compute new shared key @@ -33,11 +36,15 @@ export default async function rekey(connection, secret_join) { || their_round2.round !== 1 || typeof their_round2.content !== "string") throw new Error(`Error: Received invalid round 2 from peer`); + + l.debug(`REKEY GOT ROUND 2`); - const new_shared_key = jpake.ComputeSharedKey(their_round2.content); + const new_shared_key = jpake_inst.ComputeSharedKey(their_round2.content); if(typeof new_shared_key !== "string") throw new Error(`Error: Failed to compute shared key`); + l.debug(`REKEY COMPLETE`); + return Buffer.from(new_shared_key, "hex"); // let data_bytes = response[0].toString("base64"); diff --git a/src/subcommands/test-client/test-client.mjs b/src/subcommands/test-client/test-client.mjs index 2d14a8c..437d288 100644 --- a/src/subcommands/test-client/test-client.mjs +++ b/src/subcommands/test-client/test-client.mjs @@ -3,16 +3,14 @@ import net from 'net'; import l from 'log'; -import make_cert from 'make-cert'; import settings from '../../settings.mjs'; import sleep from '../../lib/async/sleep.mjs'; -import starttls from '../../lib/transport/starttls.mjs'; +import Connection from '../../lib/transport/Connection.mjs'; import { encrypt, decrypt } from '../../lib/crypto/secretbox.mjs'; export default async function() { const test_key = "H7xKSxvJFoZoNjCKAfxn4E3qUzY3Y/4bjY+qIzxg+78="; - const our_cert = make_cert("test_server.localhost"); const test_data = "hello, world"; l.notice(`TEST_DATA`, test_data); @@ -21,25 +19,16 @@ export default async function() { const decrypted = decrypt(test_key, encrypted); l.notice(`DECRYPTED`, decrypted); - const socket = net.createConnection({ - port: settings.cli.port - }, () => { - l.notice("connection established"); - }); + const socket = await Connection.Create(test_key, "::1", settings.cli.port); - await starttls(socket, test_key, our_cert, false); - - socket.on("data", (chunk) => { - l.notice(`<<< ${chunk.toString("utf-8")}`); - }); - socket.on("end", () => { - l.notice("disconnected"); + socket.on("message", (event, msg) => { + l.notice(`<<< ${event}: ${JSON.stringify(msg)}`); }); for(let i = 0; i < 100; i++) { await sleep(1000); l.notice(`>>> hello world ${i}`); - socket.write(`hello world ${i}\n`); + socket.send(`test-client`, `hello world ${i}\n`); } } diff --git a/src/subcommands/test-server/test-server.mjs b/src/subcommands/test-server/test-server.mjs index 5376dd4..8fc34ab 100644 --- a/src/subcommands/test-server/test-server.mjs +++ b/src/subcommands/test-server/test-server.mjs @@ -2,7 +2,6 @@ import net from 'net'; -import make_cert from 'make-cert'; import l from 'log'; import settings from '../../settings.mjs'; @@ -10,13 +9,17 @@ import Connection from '../../lib/transport/Connection.mjs'; export default async function() { const test_key = "H7xKSxvJFoZoNjCKAfxn4E3qUzY3Y/4bjY+qIzxg+78="; - const our_cert = make_cert("test_server.localhost"); const server = net.createServer(async (client) => { l.notice("client connected"); - await starttls(client, test_key, our_cert, true); - client.write("hello\n"); - client.pipe(client); + const connection = await Connection.Wrap(test_key, client); + connection.write("hello\n"); + connection.on("message", (event, msg) => { + l.notice(`<<< ${event}: ${JSON.stringify(msg)}`); + }); + setInterval(async () => { + connection.send("test-server", (new Date()).toString()); + }, 1000); }); server.on("error", (error) => { throw error; });