diff --git a/core/src/main/scala/com/nutomic/ensichat/core/ConnectionHandler.scala b/core/src/main/scala/com/nutomic/ensichat/core/ConnectionHandler.scala index 6bd16fb..90aa9f1 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/ConnectionHandler.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/ConnectionHandler.scala @@ -63,7 +63,7 @@ class ConnectionHandler(settings: Settings, database: DatabaseInterface, settings.put("message_id", messageId + 1) val msg = new Message(header, body) - val encrypted = crypto.encrypt(crypto.sign(msg)) + val encrypted = crypto.encryptAndSign(msg) router.onReceive(encrypted) onNewMessage(msg) } @@ -77,12 +77,12 @@ class ConnectionHandler(settings: Settings, database: DatabaseInterface, */ def onMessageReceived(msg: Message): Unit = { if (msg.header.target == crypto.localAddress) { - val decrypted = crypto.decrypt(msg) - if (!crypto.verify(decrypted)) { - Log.i(Tag, "Ignoring message with invalid signature from " + msg.header.origin) - return + crypto.verifyAndDecrypt(msg) match { + case Some(msg) => onNewMessage(msg) + case None => + Log.i(Tag, "Ignoring message with invalid signature from " + msg.header.origin) + return } - onNewMessage(decrypted) } else { router.onReceive(msg) } @@ -131,7 +131,7 @@ class ConnectionHandler(settings: Settings, database: DatabaseInterface, return false } - if (crypto.havePublicKey(sender) && !crypto.verify(msg, crypto.getPublicKey(sender))) { + if (crypto.havePublicKey(sender) && !crypto.verify(msg, Option(crypto.getPublicKey(sender)))) { Log.i(Tag, "Ignoring ConnectionInfo message with invalid signature") return false } diff --git a/core/src/main/scala/com/nutomic/ensichat/core/Crypto.scala b/core/src/main/scala/com/nutomic/ensichat/core/Crypto.scala index c5d6365..5e72d31 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/Crypto.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/Crypto.scala @@ -135,15 +135,13 @@ class Crypto(settings: Settings, keyFolder: File) { val key = loadKey(PrivateKeyAlias, classOf[PrivateKey]) sig.initSign(key) sig.update(msg.body.write) - new Message(msg.header, new CryptoData(Option(sig.sign), None), msg.body) + new Message(msg.header, new CryptoData(Option(sig.sign), msg.crypto.key), msg.body) } - def verify(msg: Message, key: PublicKey = null): Boolean = { - val publicKey = - if (key != null) key - else loadKey(msg.header.origin.toString, classOf[PublicKey]) + def verify(msg: Message, key: Option[PublicKey] = None): Boolean = { val sig = Signature.getInstance(SigningAlgorithm) - sig.initVerify(publicKey) + lazy val defaultKey = loadKey(msg.header.origin.toString, classOf[PublicKey]) + sig.initVerify(key.getOrElse(defaultKey)) sig.update(msg.body.write) sig.verify(msg.crypto.signature.get) } @@ -223,9 +221,18 @@ class Crypto(settings: Settings, keyFolder: File) { } } - def encrypt(msg: Message, key: PublicKey = null): Message = { - assert(msg.crypto.signature.isDefined, "Message must be signed before encryption") + def encryptAndSign(msg: Message, key: Option[PublicKey] = None): Message = { + sign(encrypt(msg, key)) + } + def verifyAndDecrypt(msg: Message, key: Option[PublicKey] = None): Option[Message] = { + if (verify(msg, key)) + Option(decrypt(msg)) + else + None + } + + private def encrypt(msg: Message, key: Option[PublicKey] = None): Message = { // Symmetric encryption of data val secretKey = makeSecretKey() val symmetricCipher = Cipher.getInstance(SymmetricKeyAlgorithm) @@ -233,17 +240,15 @@ class Crypto(settings: Settings, keyFolder: File) { val encrypted = new EncryptedBody(copyThroughCipher(symmetricCipher, msg.body.write)) // Asymmetric encryption of secret key - val publicKey = - if (key != null) key - else loadKey(msg.header.target.toString, classOf[PublicKey]) val asymmetricCipher = Cipher.getInstance(PublicKeyAlgorithm) - asymmetricCipher.init(Cipher.WRAP_MODE, publicKey) + lazy val defaultKey = loadKey(msg.header.target.toString, classOf[PublicKey]) + asymmetricCipher.init(Cipher.WRAP_MODE, key.getOrElse(defaultKey)) new Message(msg.header, - new CryptoData(msg.crypto.signature, Option(asymmetricCipher.wrap(secretKey))), encrypted) + new CryptoData(None, Option(asymmetricCipher.wrap(secretKey))), encrypted) } - def decrypt(msg: Message): Message = { + private def decrypt(msg: Message): Message = { // Asymmetric decryption of secret key val asymmetricCipher = Cipher.getInstance(PublicKeyAlgorithm) asymmetricCipher.init(Cipher.UNWRAP_MODE, loadKey(PrivateKeyAlias, classOf[PrivateKey])) diff --git a/core/src/test/scala/com/nutomic/ensichat/core/CryptoTest.scala b/core/src/test/scala/com/nutomic/ensichat/core/CryptoTest.scala index 42b00b2..7bad8be 100644 --- a/core/src/test/scala/com/nutomic/ensichat/core/CryptoTest.scala +++ b/core/src/test/scala/com/nutomic/ensichat/core/CryptoTest.scala @@ -32,7 +32,7 @@ class CryptoTest extends TestCase { def testSignVerify(): Unit = { MessageTest.messages.foreach { m => val signed = crypto.sign(m) - assertTrue(crypto.verify(signed, crypto.getLocalPublicKey)) + assertTrue(crypto.verify(signed, Option(crypto.getLocalPublicKey))) assertEquals(m.header, signed.header) assertEquals(m.body, signed.body) } @@ -40,9 +40,9 @@ class CryptoTest extends TestCase { def testEncryptDecrypt(): Unit = { MessageTest.messages.foreach{ m => - val encrypted = crypto.encrypt(crypto.sign(m), crypto.getLocalPublicKey) - val decrypted = crypto.decrypt(encrypted) - assertEquals(m.body, decrypted.body) + val encrypted = crypto.encryptAndSign(m, Option(crypto.getLocalPublicKey)) + val decrypted = crypto.verifyAndDecrypt(encrypted, Option(crypto.getLocalPublicKey)) + assertEquals(m.body, decrypted.get.body) assertEquals(m.header, encrypted.header) } } diff --git a/core/src/test/scala/com/nutomic/ensichat/core/MessageTest.scala b/core/src/test/scala/com/nutomic/ensichat/core/MessageTest.scala index cb63f82..881d42f 100644 --- a/core/src/test/scala/com/nutomic/ensichat/core/MessageTest.scala +++ b/core/src/test/scala/com/nutomic/ensichat/core/MessageTest.scala @@ -48,21 +48,19 @@ class MessageTest extends TestCase { val read = Message.read(new ByteArrayInputStream(bytes)) assertEquals(signed, read) - assertTrue(crypto.verify(read, crypto.getLocalPublicKey)) + assertTrue(crypto.verify(read, Option(crypto.getLocalPublicKey))) } def testSerializeEncrypted(): Unit = { MessageTest.messages.foreach{ m => - val signed = crypto.sign(m) - val encrypted = crypto.encrypt(signed, crypto.getLocalPublicKey) + val encrypted = crypto.encryptAndSign(m, Option(crypto.getLocalPublicKey)) val bytes = encrypted.write val read = Message.read(new ByteArrayInputStream(bytes)) assertEquals(encrypted.crypto, read.crypto) - val decrypted = crypto.decrypt(read) - assertEquals(m.header, decrypted.header) - assertEquals(m.body, decrypted.body) - assertTrue(crypto.verify(decrypted, crypto.getLocalPublicKey)) + val decrypted = crypto.verifyAndDecrypt(read, Option(crypto.getLocalPublicKey)) + assertEquals(m.header, decrypted.get.header) + assertEquals(m.body, decrypted.get.body) } }